[
  {
    "path": "LICENSE.txt",
    "content": "BSD 3-Clause License\n\nCopyright (c) 2022 Salesforce, Inc.\nAll rights reserved.\n\nRedistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:\n\n1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.\n\n2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.\n\n3. Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.\n\nTHIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n"
  },
  {
    "path": "MANIFEST.in",
    "content": "recursive-include lavis/configs *.yaml *.json\nrecursive-include lavis/projects *.yaml *.json\n\nrecursive-exclude lavis/datasets/download_scripts *\nrecursive-exclude lavis/output *\n\ninclude requirements.txt\n"
  },
  {
    "path": "README.md",
    "content": "# [NeurIPS 2023] Self-Chained Image-Language Model for Video Localization and Question Answering\n\n* Authors: [Shoubin Yu](https://yui010206.github.io/), [Jaemin Cho](https://j-min.io), [Prateek Yadav](https://prateek-yadav.github.io/), [Mohit Bansal](https://www.cs.unc.edu/~mbansal/)\n* Paper: [arXiv](https://arxiv.org/abs/2305.06988)\n* Online Demo: Try our Gradio demo on Hugging Face[![Open in Spaces](https://huggingface.co/datasets/huggingface/badges/raw/main/open-in-hf-spaces-sm.svg)](https://huggingface.co/spaces/Shoubin/SeViLA)\n\n<img src=\"./assets/teaser.png\" alt=\"teaser image\" width=\"800\"/>\n\n<img src=\"./assets/model.png\" alt=\"teaser image\" width=\"800\"/>\n\n<img src=\"./assets/chain.png\" alt=\"teaser image\" width=\"800\"/>\n\n\n# Code structure\n```bash\n\n# data & data preprocessing\n./sevila_data\n\n# pretrained checkpoints\n./sevila_checkpoints\n\n# SeViLA code\n./lavis/\n\n# running scripts for SeViLA localizer/answerer training/inference\n./run_scripts\n\n```\n\n# Setup\n\n## Install Dependencies\n\n1. (Optional) Creating conda environment\n\n```bash\nconda create -n sevila python=3.8\nconda activate sevila\n```\n\n2. build from source\n\n```bash\npip install -e .\n```\n\n## Download Pretrained Models\nWe pre-train SeViLA localizer on QVHighlights and hold checkpoints via [Hugging Face](https://huggingface.co/Shoubin/SeViLA/resolve/main/sevila_pretrained.pth).\nDownload checkpoints and put it under /sevila_checkpoints.\nThe checkpoints (814.55M) contains pre-trained localizer and zero-shot answerer.\n\nIf you want to pre-train your own localizer, you can download [qformer_loc.pth](https://drive.google.com/file/d/13hE_BQDflkzYrHVmVGddRSt8VMa0ouGB/view?usp=sharing), which is a copy of the original BLIP-2 Q-former to initialize the localizer (with changed model keys).\n\n# Run Gradio Demo Locally\nWe also provide a UI for testing our SeViLA locally that is built with gradio. 
\nRunning the demo locally requires about 12GB of memory.\n\n* Install Gradio:\n\n```bash\npip install gradio==3.30.0\n```\n\n* Run the following command in a terminal to launch the demo:\n\n```bash\npython app.py\n```\n\n# Dataset Preparation\nWe test our model on:\n+ [NExT-QA](https://doc-doc.github.io/docs/nextqa.html)\n\n+ [STAR](https://star.csail.mit.edu/)\n\n+ [How2QA](https://value-benchmark.github.io/index.html)\n\n+ [TVQA](https://tvqa.cs.unc.edu/)\n\n+ [VLEP](https://value-benchmark.github.io/index.html)\n\n+ [QVHighlights](https://github.com/jayleicn/moment_detr)\n\nPlease download the original QA data and preprocess them via our [scripts](sevila_data/).\n\n\n# Training and Inference\nWe provide SeViLA training and inference script examples below.\n\nPlease refer to the [dataset page](sevila_data/) to customize your data paths.\n\n## 1) Localizer Pre-training\n```bash\nsh run_scripts/sevila/pre-train/pretrain_qvh.sh\n```\n\n## 2) Answerer Fine-tuning\n\n```bash\nsh run_scripts/sevila/finetune/nextqa_ft.sh\n```\n\n## 3) Localizer Self-refinement\n\n```bash\nsh run_scripts/sevila/refinement/nextqa_sr.sh\n```\n\n## 4) Inference\n\n```bash\nsh run_scripts/sevila/inference/nextqa_infer.sh\n```\n\n\n# Acknowledgments\nWe thank the developers of [LAVIS](https://github.com/salesforce/LAVIS), [BLIP-2](https://github.com/salesforce/LAVIS/tree/main/projects/blip2), [CLIP](https://github.com/openai/CLIP), and [All-in-One](https://github.com/showlab/all-in-one) for their public code releases.\n\n\n# Reference\nPlease cite our paper if you use our models in your work:\n\n\n```bibtex\n@inproceedings{yu2023self,\n  title   = {Self-Chained Image-Language Model for Video Localization and Question Answering},\n  author  = {Yu, Shoubin and Cho, Jaemin and Yadav, Prateek and Bansal, Mohit},\n  booktitle = {NeurIPS},\n  year    = {2023}\n}\n```\n"
  },
  {
    "path": "app/__init__.py",
    "content": "\"\"\"\n # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nfrom PIL import Image\nimport requests\n\nimport streamlit as st\nimport torch\n\n\n@st.cache()\ndef load_demo_image():\n    img_url = (\n        \"https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg\"\n    )\n    raw_image = Image.open(requests.get(img_url, stream=True).raw).convert(\"RGB\")\n    return raw_image\n\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\ncache_root = \"/export/home/.cache/lavis/\"\n"
  },
  {
    "path": "app/calculate_coco_features.py",
    "content": "\"\"\"\n # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nfrom PIL import Image\nimport requests\nimport torch\n\nimport os\n\nfrom lavis.common.registry import registry\nfrom lavis.processors import *\nfrom lavis.models import *\nfrom lavis.common.utils import build_default_model\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef load_demo_image():\n    img_url = (\n        \"https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg\"\n    )\n    raw_image = Image.open(requests.get(img_url, stream=True).raw).convert(\"RGB\")\n\n    return raw_image\n\n\ndef read_img(filepath):\n    raw_image = Image.open(filepath).convert(\"RGB\")\n\n    return raw_image\n\n\n# model\nmodel_url = \"https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth\"\nfeature_extractor = BlipFeatureExtractor(pretrained=model_url)\n\nfeature_extractor.eval()\nfeature_extractor = feature_extractor.to(device)\n\n# preprocessors\nvis_processor = BlipImageEvalProcessor(image_size=224)\ntext_processor = BlipCaptionProcessor()\n\n# files to process\n# file_root = \"/export/home/.cache/lavis/coco/images/val2014\"\nfile_root = \"/export/home/.cache/lavis/coco/images/train2014\"\nfilepaths = os.listdir(file_root)\n\nprint(len(filepaths))\n\ncaption = \"dummy\"\n\npath2feat = dict()\nbsz = 256\n\nimages_in_batch = []\nfilepaths_in_batch = []\n\nfor i, filename in enumerate(filepaths):\n    if i % bsz == 0 and i > 0:\n        images_in_batch = torch.cat(images_in_batch, dim=0).to(device)\n        with torch.no_grad():\n            image_features = feature_extractor(\n                images_in_batch, caption, mode=\"image\", normalized=True\n            )[:, 0]\n\n        for filepath, image_feat in zip(filepaths_in_batch, image_features):\n            path2feat[os.path.basename(filepath)] = image_feat.detach().cpu()\n\n        images_in_batch = []\n        filepaths_in_batch = []\n\n        print(len(path2feat), image_features.shape)\n    else:\n        filepath = os.path.join(file_root, filename)\n\n        image = read_img(filepath)\n        image = vis_processor(image).unsqueeze(0)\n\n        images_in_batch.append(image)\n        filepaths_in_batch.append(filepath)\n\ntorch.save(path2feat, \"path2feat_coco_train2014.pth\")\n"
  },
  {
    "path": "app/caption.py",
    "content": "\"\"\"\n # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport streamlit as st\nfrom app import device, load_demo_image\nfrom app.utils import load_model_cache\nfrom lavis.processors import load_processor\nfrom PIL import Image\n\n\ndef app():\n    # ===== layout =====\n    model_type = st.sidebar.selectbox(\"Model:\", [\"BLIP_base\", \"BLIP_large\"])\n\n    sampling_method = st.sidebar.selectbox(\n        \"Sampling method:\", [\"Beam search\", \"Nucleus sampling\"]\n    )\n\n    st.markdown(\n        \"<h1 style='text-align: center;'>Image Description Generation</h1>\",\n        unsafe_allow_html=True,\n    )\n\n    instructions = \"\"\"Try the provided image or upload your own:\"\"\"\n    file = st.file_uploader(instructions)\n\n    use_beam = sampling_method == \"Beam search\"\n\n    col1, col2 = st.columns(2)\n\n    if file:\n        raw_img = Image.open(file).convert(\"RGB\")\n    else:\n        raw_img = load_demo_image()\n\n    col1.header(\"Image\")\n\n    w, h = raw_img.size\n    scaling_factor = 720 / w\n    resized_image = raw_img.resize((int(w * scaling_factor), int(h * scaling_factor)))\n\n    col1.image(resized_image, use_column_width=True)\n    col2.header(\"Description\")\n\n    cap_button = st.button(\"Generate\")\n\n    # ==== event ====\n    vis_processor = load_processor(\"blip_image_eval\").build(image_size=384)\n\n    if cap_button:\n        if model_type.startswith(\"BLIP\"):\n            blip_type = model_type.split(\"_\")[1].lower()\n            model = load_model_cache(\n                \"blip_caption\",\n                model_type=f\"{blip_type}_coco\",\n                is_eval=True,\n                device=device,\n            )\n\n        img = vis_processor(raw_img).unsqueeze(0).to(device)\n        captions = generate_caption(\n            model=model, image=img, use_nucleus_sampling=not use_beam\n        )\n\n        col2.write(\"\\n\\n\".join(captions), use_column_width=True)\n\n\ndef generate_caption(\n    model, image, use_nucleus_sampling=False, num_beams=3, max_length=40, min_length=5\n):\n    samples = {\"image\": image}\n\n    captions = []\n    if use_nucleus_sampling:\n        for _ in range(5):\n            caption = model.generate(\n                samples,\n                use_nucleus_sampling=True,\n                max_length=max_length,\n                min_length=min_length,\n                top_p=0.9,\n            )\n            captions.append(caption[0])\n    else:\n        caption = model.generate(\n            samples,\n            use_nucleus_sampling=False,\n            num_beams=num_beams,\n            max_length=max_length,\n            min_length=min_length,\n        )\n        captions.append(caption[0])\n\n    return captions\n"
  },
  {
    "path": "app/classification.py",
    "content": "\"\"\"\n # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport plotly.graph_objects as go\nimport requests\nimport streamlit as st\nimport torch\nfrom lavis.models import load_model\nfrom lavis.processors import load_processor\nfrom lavis.processors.blip_processors import BlipCaptionProcessor\nfrom PIL import Image\n\nfrom app import device, load_demo_image\nfrom app.utils import load_blip_itm_model\nfrom lavis.processors.clip_processors import ClipImageEvalProcessor\n\n\n@st.cache()\ndef load_demo_image(img_url=None):\n    if not img_url:\n        img_url = \"https://img.atlasobscura.com/yDJ86L8Ou6aIjBsxnlAy5f164w1rjTgcHZcx2yUs4mo/rt:fit/w:1200/q:81/sm:1/scp:1/ar:1/aHR0cHM6Ly9hdGxh/cy1kZXYuczMuYW1h/em9uYXdzLmNvbS91/cGxvYWRzL3BsYWNl/X2ltYWdlcy85MDll/MDRjOS00NTJjLTQx/NzQtYTY4MS02NmQw/MzI2YWIzNjk1ZGVk/MGZhMTJiMTM5MmZi/NGFfUmVhcl92aWV3/X29mX3RoZV9NZXJs/aW9uX3N0YXR1ZV9h/dF9NZXJsaW9uX1Bh/cmssX1NpbmdhcG9y/ZSxfd2l0aF9NYXJp/bmFfQmF5X1NhbmRz/X2luX3RoZV9kaXN0/YW5jZV8tXzIwMTQw/MzA3LmpwZw.jpg\"\n    raw_image = Image.open(requests.get(img_url, stream=True).raw).convert(\"RGB\")\n    return raw_image\n\n\n@st.cache(\n    hash_funcs={\n        torch.nn.parameter.Parameter: lambda parameter: parameter.data.detach()\n        .cpu()\n        .numpy()\n    },\n    allow_output_mutation=True,\n)\ndef load_model_cache(model_type, device):\n    if model_type == \"blip\":\n        model = load_model(\n            \"blip_feature_extractor\", model_type=\"base\", is_eval=True, device=device\n        )\n    elif model_type == \"albef\":\n        model = load_model(\n            \"albef_feature_extractor\", model_type=\"base\", is_eval=True, device=device\n        )\n    elif model_type == \"CLIP_ViT-B-32\":\n        model = load_model(\n            \"clip_feature_extractor\", \"ViT-B-32\", is_eval=True, device=device\n        )\n    elif model_type == \"CLIP_ViT-B-16\":\n        model = load_model(\n            \"clip_feature_extractor\", \"ViT-B-16\", is_eval=True, device=device\n        )\n    elif model_type == \"CLIP_ViT-L-14\":\n        model = load_model(\n            \"clip_feature_extractor\", \"ViT-L-14\", is_eval=True, device=device\n        )\n\n    return model\n\n\ndef app():\n    model_type = st.sidebar.selectbox(\n        \"Model:\",\n        [\"ALBEF\", \"BLIP_Base\", \"CLIP_ViT-B-32\", \"CLIP_ViT-B-16\", \"CLIP_ViT-L-14\"],\n    )\n    score_type = st.sidebar.selectbox(\"Score type:\", [\"Cosine\", \"Multimodal\"])\n\n    # ===== layout =====\n    st.markdown(\n        \"<h1 style='text-align: center;'>Zero-shot Classification</h1>\",\n        unsafe_allow_html=True,\n    )\n\n    instructions = \"\"\"Try the provided image or upload your own:\"\"\"\n    file = st.file_uploader(instructions)\n\n    st.header(\"Image\")\n    if file:\n        raw_img = Image.open(file).convert(\"RGB\")\n    else:\n        raw_img = load_demo_image()\n\n    st.image(raw_img)  # , use_column_width=True)\n\n    col1, col2 = st.columns(2)\n\n    col1.header(\"Categories\")\n\n    cls_0 = col1.text_input(\"category 1\", value=\"merlion\")\n    cls_1 = col1.text_input(\"category 2\", value=\"sky\")\n    cls_2 = col1.text_input(\"category 3\", value=\"giraffe\")\n    cls_3 = col1.text_input(\"category 4\", value=\"fountain\")\n    cls_4 = col1.text_input(\"category 5\", value=\"marina bay\")\n\n    cls_names = [cls_0, 
cls_1, cls_2, cls_3, cls_4]\n    cls_names = [cls_nm for cls_nm in cls_names if len(cls_nm) > 0]\n\n    if len(cls_names) != len(set(cls_names)):\n        st.error(\"Please provide unique class names\")\n        return\n\n    button = st.button(\"Submit\")\n\n    col2.header(\"Prediction\")\n\n    # ===== event =====\n\n    if button:\n        if model_type.startswith(\"BLIP\"):\n            text_processor = BlipCaptionProcessor(prompt=\"A picture of \")\n            cls_prompt = [text_processor(cls_nm) for cls_nm in cls_names]\n\n            if score_type == \"Cosine\":\n                vis_processor = load_processor(\"blip_image_eval\").build(image_size=224)\n                img = vis_processor(raw_img).unsqueeze(0).to(device)\n\n                feature_extractor = load_model_cache(model_type=\"blip\", device=device)\n\n                sample = {\"image\": img, \"text_input\": cls_prompt}\n\n                with torch.no_grad():\n                    image_features = feature_extractor.extract_features(\n                        sample, mode=\"image\"\n                    ).image_embeds_proj[:, 0]\n                    text_features = feature_extractor.extract_features(\n                        sample, mode=\"text\"\n                    ).text_embeds_proj[:, 0]\n                    sims = (image_features @ text_features.t())[\n                        0\n                    ] / feature_extractor.temp\n\n            else:\n                vis_processor = load_processor(\"blip_image_eval\").build(image_size=384)\n                img = vis_processor(raw_img).unsqueeze(0).to(device)\n\n                model = load_blip_itm_model(device)\n\n                output = model(img, cls_prompt, match_head=\"itm\")\n                sims = output[:, 1]\n\n            sims = torch.nn.Softmax(dim=0)(sims)\n            inv_sims = [sim * 100 for sim in sims.tolist()[::-1]]\n\n        elif model_type.startswith(\"ALBEF\"):\n            vis_processor = load_processor(\"blip_image_eval\").build(image_size=224)\n            img = vis_processor(raw_img).unsqueeze(0).to(device)\n\n            text_processor = BlipCaptionProcessor(prompt=\"A picture of \")\n            cls_prompt = [text_processor(cls_nm) for cls_nm in cls_names]\n\n            feature_extractor = load_model_cache(model_type=\"albef\", device=device)\n\n            sample = {\"image\": img, \"text_input\": cls_prompt}\n\n            with torch.no_grad():\n                image_features = feature_extractor.extract_features(\n                    sample, mode=\"image\"\n                ).image_embeds_proj[:, 0]\n                text_features = feature_extractor.extract_features(\n                    sample, mode=\"text\"\n                ).text_embeds_proj[:, 0]\n\n                st.write(image_features.shape)\n                st.write(text_features.shape)\n\n                sims = (image_features @ text_features.t())[0] / feature_extractor.temp\n\n            sims = torch.nn.Softmax(dim=0)(sims)\n            inv_sims = [sim * 100 for sim in sims.tolist()[::-1]]\n\n        elif model_type.startswith(\"CLIP\"):\n            if model_type == \"CLIP_ViT-B-32\":\n                model = load_model_cache(model_type=\"CLIP_ViT-B-32\", device=device)\n            elif model_type == \"CLIP_ViT-B-16\":\n                model = load_model_cache(model_type=\"CLIP_ViT-B-16\", device=device)\n            elif model_type == \"CLIP_ViT-L-14\":\n                model = load_model_cache(model_type=\"CLIP_ViT-L-14\", device=device)\n            else:\n                
raise ValueError(f\"Unknown model type {model_type}\")\n\n            if score_type == \"Cosine\":\n                # image_preprocess = ClipImageEvalProcessor(image_size=336)\n                image_preprocess = ClipImageEvalProcessor(image_size=224)\n                img = image_preprocess(raw_img).unsqueeze(0).to(device)\n\n                sample = {\"image\": img, \"text_input\": cls_names}\n\n                with torch.no_grad():\n                    clip_features = model.extract_features(sample)\n\n                    image_features = clip_features.image_embeds_proj\n                    text_features = clip_features.text_embeds_proj\n\n                    sims = (100.0 * image_features @ text_features.T)[0].softmax(dim=-1)\n                    inv_sims = sims.tolist()[::-1]\n            else:\n                st.warning(\"CLIP does not support multimodal scoring.\")\n                return\n\n        fig = go.Figure(\n            go.Bar(\n                x=inv_sims,\n                y=cls_names[::-1],\n                text=[\"{:.2f}\".format(s) for s in inv_sims],\n                orientation=\"h\",\n            )\n        )\n        fig.update_traces(\n            textfont_size=12,\n            textangle=0,\n            textposition=\"outside\",\n            cliponaxis=False,\n        )\n        col2.plotly_chart(fig, use_container_width=True)\n"
  },
  {
    "path": "app/dataset_browser.py",
    "content": "\"\"\"\n # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport random\nfrom collections import OrderedDict\nfrom functools import reduce\nfrom tkinter import N\n\nimport streamlit as st\nfrom lavis.common.registry import registry\nfrom lavis.datasets.builders import dataset_zoo, load_dataset\nfrom lavis.datasets.builders.base_dataset_builder import load_dataset_config\nfrom PIL import Image\n\nIMAGE_LAYOUT = 3, 4\nVIDEO_LAYOUT = 1, 2\n\nPREV_STR = \"Prev\"\nNEXT_STR = \"Next\"\n\n\ndef sample_dataset(dataset, indices):\n    samples = [dataset.displ_item(idx) for idx in indices]\n\n    return samples\n\n\ndef get_concat_v(im1, im2):\n    margin = 5\n\n    canvas_size = (im1.width + im2.width + margin, max(im1.height, im2.height))\n    canvas = Image.new(\"RGB\", canvas_size, \"White\")\n    canvas.paste(im1, (0, 0))\n    canvas.paste(im2, (im1.width + margin, 0))\n\n    return canvas\n\n\ndef resize_img_w(raw_img, new_w=224):\n    if isinstance(raw_img, list):\n        resized_imgs = [resize_img_w(img, 196) for img in raw_img]\n        # concatenate images\n        resized_image = reduce(get_concat_v, resized_imgs)\n    else:\n        w, h = raw_img.size\n        scaling_factor = new_w / w\n        resized_image = raw_img.resize(\n            (int(w * scaling_factor), int(h * scaling_factor))\n        )\n\n    return resized_image\n\n\ndef get_visual_key(dataset):\n    if \"image\" in dataset[0]:\n        return \"image\"\n    elif \"image0\" in dataset[0]:  # NLVR2 dataset\n        return \"image\"\n    elif \"video\" in dataset[0]:\n        return \"video\"\n    else:\n        raise ValueError(\"Visual key not found.\")\n\n\ndef gather_items(samples, exclude=[]):\n    gathered = []\n\n    for s in samples:\n        ns = OrderedDict()\n        for k in s.keys():\n            if k not in exclude:\n                ns[k] = s[k]\n\n        gathered.append(ns)\n\n    return gathered\n\n\n@st.cache(allow_output_mutation=True)\ndef load_dataset_cache(name):\n    return load_dataset(name)\n\n\ndef format_text(text):\n    md = \"\\n\\n\".join([f\"**{k}**: {v}\" for k, v in text.items()])\n\n    return md\n\n\ndef show_samples(dataset, offset=0, is_next=False):\n    visual_key = get_visual_key(dataset)\n\n    num_rows, num_cols = IMAGE_LAYOUT if visual_key == \"image\" else VIDEO_LAYOUT\n    n_samples = num_rows * num_cols\n\n    if not shuffle:\n        if is_next:\n            start = min(int(start_idx) + offset + n_samples, len(dataset) - n_samples)\n        else:\n            start = max(0, int(start_idx) + offset - n_samples)\n\n        st.session_state.last_start = start\n        end = min(start + n_samples, len(dataset))\n\n        indices = list(range(start, end))\n    else:\n        indices = random.sample(range(len(dataset)), n_samples)\n    samples = sample_dataset(dataset, indices)\n\n    visual_info = (\n        iter([resize_img_w(s[visual_key]) for s in samples])\n        if visual_key == \"image\"\n        # else iter([s[visual_key] for s in samples])\n        else iter([s[\"file\"] for s in samples])\n    )\n    text_info = gather_items(samples, exclude=[\"image\", \"video\"])\n    text_info = iter([format_text(s) for s in text_info])\n\n    st.markdown(\n        \"\"\"<hr style=\"height:1px;border:none;color:#c7ccd4;background-color:#c7ccd4;\"/> \"\"\",\n        
unsafe_allow_html=True,\n    )\n    for _ in range(num_rows):\n        with st.container():\n            for col in st.columns(num_cols):\n                # col.text(next(text_info))\n                # col.caption(next(text_info))\n                try:\n                    col.markdown(next(text_info))\n                    if visual_key == \"image\":\n                        col.image(next(visual_info), use_column_width=True, clamp=True)\n                    elif visual_key == \"video\":\n                        col.markdown(\n                            \"![Alt Text](https://media.giphy.com/media/vFKqnCdLPNOKc/giphy.gif)\"\n                        )\n                except StopIteration:\n                    break\n\n            st.markdown(\n                \"\"\"<hr style=\"height:1px;border:none;color:#c7ccd4;background-color:#c7ccd4;\"/> \"\"\",\n                unsafe_allow_html=True,\n            )\n\n    st.session_state.n_display = n_samples\n\n\nif __name__ == \"__main__\":\n    st.set_page_config(\n        page_title=\"LAVIS Dataset Explorer\",\n        # layout=\"wide\",\n        initial_sidebar_state=\"expanded\",\n    )\n\n    dataset_name = st.sidebar.selectbox(\"Dataset:\", dataset_zoo.get_names())\n\n    function = st.sidebar.selectbox(\"Function:\", [\"Browser\"], index=0)\n\n    if function == \"Browser\":\n        shuffle = st.sidebar.selectbox(\"Shuffled:\", [True, False], index=0)\n\n        dataset = load_dataset_cache(dataset_name)\n        split = st.sidebar.selectbox(\"Split:\", dataset.keys())\n\n        dataset_len = len(dataset[split])\n        st.success(\n            f\"Loaded {dataset_name}/{split} with **{dataset_len}** records.  **Image/video directory**: {dataset[split].vis_root}\"\n        )\n\n        if \"last_dataset\" not in st.session_state:\n            st.session_state.last_dataset = dataset_name\n            st.session_state.last_split = split\n\n        if \"last_start\" not in st.session_state:\n            st.session_state.last_start = 0\n\n        if \"start_idx\" not in st.session_state:\n            st.session_state.start_idx = 0\n\n        if \"shuffle\" not in st.session_state:\n            st.session_state.shuffle = shuffle\n\n        if \"first_run\" not in st.session_state:\n            st.session_state.first_run = True\n        elif (\n            st.session_state.last_dataset != dataset_name\n            or st.session_state.last_split != split\n        ):\n            st.session_state.first_run = True\n\n            st.session_state.last_dataset = dataset_name\n            st.session_state.last_split = split\n        elif st.session_state.shuffle != shuffle:\n            st.session_state.shuffle = shuffle\n            st.session_state.first_run = True\n\n        if not shuffle:\n            n_col, p_col = st.columns([0.05, 1])\n\n            prev_button = n_col.button(PREV_STR)\n            next_button = p_col.button(NEXT_STR)\n\n        else:\n            next_button = st.button(NEXT_STR)\n\n        if not shuffle:\n            start_idx = st.sidebar.text_input(f\"Begin from (total {dataset_len})\", 0)\n\n            if not start_idx.isdigit():\n                st.error(f\"Input to 'Begin from' must be digits, found {start_idx}.\")\n            else:\n                if int(start_idx) != st.session_state.start_idx:\n                    st.session_state.start_idx = int(start_idx)\n                    st.session_state.last_start = int(start_idx)\n\n            if prev_button:\n                show_samples(\n                    
dataset[split],\n                    offset=st.session_state.last_start - st.session_state.start_idx,\n                    is_next=False,\n                )\n\n        if next_button:\n            show_samples(\n                dataset[split],\n                offset=st.session_state.last_start - st.session_state.start_idx,\n                is_next=True,\n            )\n\n        if st.session_state.first_run:\n            st.session_state.first_run = False\n\n            show_samples(\n                dataset[split],\n                offset=st.session_state.last_start - st.session_state.start_idx,\n                is_next=True,\n            )\n"
  },
  {
    "path": "app/image_text_match.py",
    "content": "\"\"\"\n # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport numpy as np\nimport streamlit as st\nimport torch\nfrom lavis.models.blip_models.blip_image_text_matching import compute_gradcam\nfrom lavis.processors import load_processor\nfrom PIL import Image\n\nfrom app import device, load_demo_image\nfrom app.utils import getAttMap, init_bert_tokenizer, load_blip_itm_model\n\n\ndef app():\n    model_type = st.sidebar.selectbox(\"Model:\", [\"BLIP_base\", \"BLIP_large\"])\n\n    if model_type.startswith(\"BLIP\"):\n        blip_type = model_type.split(\"_\")[1]\n        model = load_blip_itm_model(device, model_type=blip_type)\n\n    vis_processor = load_processor(\"blip_image_eval\").build(image_size=384)\n\n    st.markdown(\n        \"<h1 style='text-align: center;'>Image Text Matching</h1>\",\n        unsafe_allow_html=True,\n    )\n\n    values = list(range(1, 12))\n    default_layer_num = values.index(7)\n    layer_num = (\n        st.sidebar.selectbox(\"Layer number\", values, index=default_layer_num) - 1\n    )\n\n    instructions = \"\"\"Try the provided image or upload your own:\"\"\"\n    file = st.file_uploader(instructions)\n\n    col1, col2 = st.columns(2)\n    col1.header(\"Image\")\n    col2.header(\"GradCam\")\n    if file:\n        raw_img = Image.open(file).convert(\"RGB\")\n    else:\n        raw_img = load_demo_image()\n\n    w, h = raw_img.size\n    scaling_factor = 720 / w\n    resized_image = raw_img.resize((int(w * scaling_factor), int(h * scaling_factor)))\n    col1.image(resized_image, use_column_width=True)\n\n    col3, col4 = st.columns(2)\n    col3.header(\"Text\")\n    user_question = col3.text_input(\n        \"Input your sentence!\", \"a woman sitting on the beach with a dog\"\n    )\n    submit_button = col3.button(\"Submit\")\n\n    col4.header(\"Matching score\")\n\n    if submit_button:\n        tokenizer = init_bert_tokenizer()\n\n        img = vis_processor(raw_img).unsqueeze(0).to(device)\n        text_processor = load_processor(\"blip_caption\").build()\n\n        qry = text_processor(user_question)\n\n        norm_img = np.float32(resized_image) / 255\n\n        qry_tok = tokenizer(qry, return_tensors=\"pt\").to(device)\n        gradcam, output = compute_gradcam(model, img, qry, qry_tok, block_num=layer_num)\n\n        avg_gradcam = getAttMap(norm_img, gradcam[0][1], blur=True)\n\n        col2.image(avg_gradcam, use_column_width=True, clamp=True)\n        # output = model(img, question)\n        itm_score = torch.nn.functional.softmax(output, dim=1)\n        new_title = (\n            '<p style=\"text-align: left; font-size: 25px;\">\\n{:.3f}%</p>'.format(\n                itm_score[0][1].item() * 100\n            )\n        )\n        col4.markdown(new_title, unsafe_allow_html=True)\n"
  },
  {
    "path": "app/main.py",
    "content": "\"\"\"\n # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nfrom app.multipage import MultiPage\nfrom app import vqa, caption\nfrom app import image_text_match as itm\nfrom app import text_localization as tl\nfrom app import multimodal_search as ms\nfrom app import classification as cl\n\n\nif __name__ == \"__main__\":\n    app = MultiPage()\n\n    app.add_page(\"Image Description Generation\", caption.app)\n    app.add_page(\"Multimodal Search\", ms.app)\n    app.add_page(\"Visual Question Answering\", vqa.app)\n    app.add_page(\"Image Text Matching\", itm.app)\n    app.add_page(\"Text Localization\", tl.app)\n    app.add_page(\"Classification\", cl.app)\n    app.run()\n"
  },
  {
    "path": "app/multimodal_search.py",
    "content": "\"\"\"\n # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport os\n\nimport numpy as np\nimport streamlit as st\nimport torch\nimport torch.nn.functional as F\nfrom app import cache_root, device\nfrom app.utils import (\n    getAttMap,\n    init_bert_tokenizer,\n    load_blip_itm_model,\n    read_img,\n    resize_img,\n)\nfrom lavis.models import load_model\nfrom lavis.processors import load_processor\n\n\n@st.cache(\n    hash_funcs={\n        torch.nn.parameter.Parameter: lambda parameter: parameter.data.detach()\n        .cpu()\n        .numpy()\n    },\n    allow_output_mutation=True,\n)\ndef load_feat():\n    from lavis.common.utils import download_url\n\n    dirname = os.path.join(os.path.dirname(__file__), \"assets\")\n    filename = \"path2feat_coco_train2014.pth\"\n    filepath = os.path.join(dirname, filename)\n    url = \"https://storage.googleapis.com/sfr-vision-language-research/LAVIS/assets/path2feat_coco_train2014.pth\"\n\n    if not os.path.exists(filepath):\n        download_url(url=url, root=dirname, filename=\"path2feat_coco_train2014.pth\")\n\n    path2feat = torch.load(filepath)\n    paths = sorted(path2feat.keys())\n\n    all_img_feats = torch.stack([path2feat[k] for k in paths], dim=0).to(device)\n\n    return path2feat, paths, all_img_feats\n\n\n@st.cache(\n    hash_funcs={\n        torch.nn.parameter.Parameter: lambda parameter: parameter.data.detach()\n        .cpu()\n        .numpy()\n    },\n    allow_output_mutation=True,\n)\ndef load_feature_extractor_model(device):\n    model_url = \"https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth\"\n\n    model = load_model(\n        \"blip_feature_extractor\", model_type=\"base\", is_eval=True, device=device\n    )\n    model.load_from_pretrained(model_url)\n\n    return model\n\n\ndef app():\n    # === layout ===\n    model_type = st.sidebar.selectbox(\"Model:\", [\"BLIP_base\", \"BLIP_large\"])\n    file_root = os.path.join(cache_root, \"coco/images/train2014/\")\n\n    values = [12, 24, 48]\n    default_layer_num = values.index(24)\n    num_display = st.sidebar.selectbox(\n        \"Number of images:\", values, index=default_layer_num\n    )\n    show_gradcam = st.sidebar.selectbox(\"Show GradCam:\", [True, False], index=1)\n    itm_ranking = st.sidebar.selectbox(\"Multimodal re-ranking:\", [True, False], index=0)\n\n    # st.title('Multimodal Search')\n    st.markdown(\n        \"<h1 style='text-align: center;'>Multimodal Search</h1>\", unsafe_allow_html=True\n    )\n\n    # === event ===\n    vis_processor = load_processor(\"blip_image_eval\").build(image_size=384)\n    text_processor = load_processor(\"blip_caption\")\n\n    user_question = st.text_input(\n        \"Search query\", \"A dog running on the grass.\", help=\"Type something to search.\"\n    )\n    user_question = text_processor(user_question)\n    feature_extractor = load_feature_extractor_model(device)\n\n    # ======= ITC =========\n    sample = {\"text_input\": user_question}\n\n    with torch.no_grad():\n        text_feature = feature_extractor.extract_features(\n            sample, mode=\"text\"\n        ).text_embeds_proj[0, 0]\n\n        path2feat, paths, all_img_feats = load_feat()\n        all_img_feats.to(device)\n        all_img_feats = F.normalize(all_img_feats, dim=1)\n\n        num_cols = 4\n        
num_rows = int(num_display / num_cols)\n\n        similarities = text_feature @ all_img_feats.T\n        indices = torch.argsort(similarities, descending=True)[:num_display]\n\n    top_paths = [paths[ind.detach().cpu().item()] for ind in indices]\n    sorted_similarities = [similarities[idx] for idx in indices]\n    filenames = [os.path.join(file_root, p) for p in top_paths]\n\n    # ========= ITM and GradCam ==========\n    bsz = 4  # max number of images to avoid cuda oom\n    if model_type.startswith(\"BLIP\"):\n        blip_type = model_type.split(\"_\")[1]\n\n    itm_model = load_blip_itm_model(device, model_type=blip_type)\n\n    tokenizer = init_bert_tokenizer()\n    queries_batch = [user_question] * bsz\n    queries_tok_batch = tokenizer(queries_batch, return_tensors=\"pt\").to(device)\n\n    num_batches = int(num_display / bsz)\n\n    avg_gradcams = []\n    all_raw_images = []\n    itm_scores = []\n\n    for i in range(num_batches):\n        filenames_in_batch = filenames[i * bsz : (i + 1) * bsz]\n        raw_images, images = read_and_process_images(filenames_in_batch, vis_processor)\n        gradcam, itm_output = compute_gradcam_batch(\n            itm_model, images, queries_batch, queries_tok_batch\n        )\n\n        all_raw_images.extend([resize_img(r_img) for r_img in raw_images])\n        norm_imgs = [np.float32(r_img) / 255 for r_img in raw_images]\n\n        for norm_img, grad_cam in zip(norm_imgs, gradcam):\n            avg_gradcam = getAttMap(norm_img, grad_cam[0], blur=True)\n            avg_gradcams.append(avg_gradcam)\n\n        with torch.no_grad():\n            itm_score = torch.nn.functional.softmax(itm_output, dim=1)\n\n        itm_scores.append(itm_score)\n\n    # ========= ITM re-ranking =========\n    itm_scores = torch.cat(itm_scores)[:, 1]\n    if itm_ranking:\n        itm_scores_sorted, indices = torch.sort(itm_scores, descending=True)\n\n        avg_gradcams_sorted = []\n        all_raw_images_sorted = []\n        for idx in indices:\n            avg_gradcams_sorted.append(avg_gradcams[idx])\n            all_raw_images_sorted.append(all_raw_images[idx])\n\n        avg_gradcams = avg_gradcams_sorted\n        all_raw_images = all_raw_images_sorted\n\n    if show_gradcam:\n        images_to_show = iter(avg_gradcams)\n    else:\n        images_to_show = iter(all_raw_images)\n\n    for _ in range(num_rows):\n        with st.container():\n            for col in st.columns(num_cols):\n                col.image(next(images_to_show), use_column_width=True, clamp=True)\n\n\ndef read_and_process_images(image_paths, vis_processor):\n    raw_images = [read_img(path) for path in image_paths]\n    images = [vis_processor(r_img) for r_img in raw_images]\n    images_tensors = torch.stack(images).to(device)\n\n    return raw_images, images_tensors\n\n\ndef compute_gradcam_batch(model, visual_input, text_input, tokenized_text, block_num=6):\n    model.text_encoder.base_model.base_model.encoder.layer[\n        block_num\n    ].crossattention.self.save_attention = True\n\n    output = model({\"image\": visual_input, \"text_input\": text_input}, match_head=\"itm\")\n    loss = output[:, 1].sum()\n\n    model.zero_grad()\n    loss.backward()\n    with torch.no_grad():\n        mask = tokenized_text.attention_mask.view(\n            tokenized_text.attention_mask.size(0), 1, -1, 1, 1\n        )  # (bsz,1,token_len, 1,1)\n        token_length = mask.sum() - 2\n        token_length = token_length.cpu()\n        # grads and cams [bsz, num_head, seq_len, image_patch]\n        grads = 
model.text_encoder.base_model.base_model.encoder.layer[\n            block_num\n        ].crossattention.self.get_attn_gradients()\n        cams = model.text_encoder.base_model.base_model.encoder.layer[\n            block_num\n        ].crossattention.self.get_attention_map()\n\n        # assume using vit large with 576 num image patch\n        cams = cams[:, :, :, 1:].reshape(visual_input.size(0), 12, -1, 24, 24) * mask\n        grads = (\n            grads[:, :, :, 1:].clamp(0).reshape(visual_input.size(0), 12, -1, 24, 24)\n            * mask\n        )\n\n        gradcam = cams * grads\n        # [enc token gradcam, average gradcam across token, gradcam for individual token]\n        # gradcam = torch.cat((gradcam[0:1,:], gradcam[1:token_length+1, :].sum(dim=0, keepdim=True)/token_length, gradcam[1:, :]))\n        gradcam = gradcam.mean(1).cpu().detach()\n        gradcam = (\n            gradcam[:, 1 : token_length + 1, :].sum(dim=1, keepdim=True) / token_length\n        )\n\n    return gradcam, output\n"
  },
  {
    "path": "app/multipage.py",
    "content": "\"\"\"\n # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\n\"\"\"\nThis file is the framework for generating multiple Streamlit applications\nthrough an object oriented framework.\n\"\"\"\n\n# Import necessary libraries\nimport streamlit as st\n\n# Define the multipage class to manage the multiple apps in our program\nclass MultiPage:\n    \"\"\"Framework for combining multiple streamlit applications.\"\"\"\n\n    def __init__(self) -> None:\n        \"\"\"Constructor class to generate a list which will store all our applications as an instance variable.\"\"\"\n        self.pages = []\n\n    def add_page(self, title, func) -> None:\n        \"\"\"Class Method to Add pages to the project\n        Args:\n            title ([str]): The title of page which we are adding to the list of apps\n\n            func: Python function to render this page in Streamlit\n        \"\"\"\n\n        self.pages.append({\"title\": title, \"function\": func})\n\n    def run(self):\n        # Drodown to select the page to run\n        page = st.sidebar.selectbox(\n            \"Navigation\", self.pages, format_func=lambda page: page[\"title\"]\n        )\n\n        # run the app function\n        page[\"function\"]()\n"
  },
  {
    "path": "app/text_localization.py",
    "content": "\"\"\"\n # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport math\n\nimport numpy as np\nimport streamlit as st\nfrom lavis.models.blip_models.blip_image_text_matching import compute_gradcam\nfrom lavis.processors import load_processor\nfrom PIL import Image\n\nfrom app import device, load_demo_image\nfrom app.utils import getAttMap, init_bert_tokenizer, load_blip_itm_model\n\n\ndef app():\n    model_type = st.sidebar.selectbox(\"Model:\", [\"BLIP_base\", \"BLIP_large\"])\n\n    values = list(range(1, 12))\n    default_layer_num = values.index(7)\n    layer_num = (\n        st.sidebar.selectbox(\"Layer number\", values, index=default_layer_num) - 1\n    )\n\n    st.markdown(\n        \"<h1 style='text-align: center;'>Text Localization</h1>\", unsafe_allow_html=True\n    )\n\n    vis_processor = load_processor(\"blip_image_eval\").build(image_size=384)\n    text_processor = load_processor(\"blip_caption\")\n\n    tokenizer = init_bert_tokenizer()\n\n    instructions = \"Try the provided image and text or use your own ones.\"\n    file = st.file_uploader(instructions)\n\n    query = st.text_input(\n        \"Try a different input.\", \"A girl playing with her dog on the beach.\"\n    )\n\n    submit_button = st.button(\"Submit\")\n\n    col1, col2 = st.columns(2)\n\n    if file:\n        raw_img = Image.open(file).convert(\"RGB\")\n    else:\n        raw_img = load_demo_image()\n\n    col1.header(\"Image\")\n    w, h = raw_img.size\n    scaling_factor = 720 / w\n    resized_image = raw_img.resize((int(w * scaling_factor), int(h * scaling_factor)))\n    col1.image(resized_image, use_column_width=True)\n\n    col2.header(\"GradCam\")\n\n    if submit_button:\n        if model_type.startswith(\"BLIP\"):\n            blip_type = model_type.split(\"_\")[1]\n            model = load_blip_itm_model(device, model_type=blip_type)\n\n        img = vis_processor(raw_img).unsqueeze(0).to(device)\n        qry = text_processor(query)\n\n        qry_tok = tokenizer(qry, return_tensors=\"pt\").to(device)\n\n        norm_img = np.float32(resized_image) / 255\n\n        gradcam, _ = compute_gradcam(model, img, qry, qry_tok, block_num=layer_num)\n\n        avg_gradcam = getAttMap(norm_img, gradcam[0][1], blur=True)\n        col2.image(avg_gradcam, use_column_width=True, clamp=True)\n\n        num_cols = 4.0\n        num_tokens = len(qry_tok.input_ids[0]) - 2\n\n        num_rows = int(math.ceil(num_tokens / num_cols))\n\n        gradcam_iter = iter(gradcam[0][2:-1])\n        token_id_iter = iter(qry_tok.input_ids[0][1:-1])\n\n        for _ in range(num_rows):\n            with st.container():\n                for col in st.columns(int(num_cols)):\n                    token_id = next(token_id_iter, None)\n                    if not token_id:\n                        break\n                    gradcam_img = next(gradcam_iter)\n\n                    word = tokenizer.decode([token_id])\n                    gradcam_todraw = getAttMap(norm_img, gradcam_img, blur=True)\n\n                    new_title = (\n                        '<p style=\"text-align: center; font-size: 25px;\">{}</p>'.format(\n                            word\n                        )\n                    )\n                    col.markdown(new_title, unsafe_allow_html=True)\n                    # st.image(image, channels=\"BGR\")\n               
     col.image(gradcam_todraw, use_column_width=True, clamp=True)\n"
  },
  {
    "path": "app/utils.py",
    "content": "\"\"\"\n # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport numpy as np\nimport streamlit as st\nimport torch\nfrom lavis.models import BlipBase, load_model\nfrom matplotlib import pyplot as plt\nfrom PIL import Image\nfrom scipy.ndimage import filters\nfrom skimage import transform as skimage_transform\n\n\ndef resize_img(raw_img):\n    w, h = raw_img.size\n    scaling_factor = 240 / w\n    resized_image = raw_img.resize((int(w * scaling_factor), int(h * scaling_factor)))\n    return resized_image\n\n\ndef read_img(filepath):\n    raw_image = Image.open(filepath).convert(\"RGB\")\n\n    return raw_image\n\n\n@st.cache(\n    hash_funcs={\n        torch.nn.parameter.Parameter: lambda parameter: parameter.data.detach()\n        .cpu()\n        .numpy()\n    },\n    allow_output_mutation=True,\n)\ndef load_model_cache(name, model_type, is_eval, device):\n    return load_model(name, model_type, is_eval, device)\n\n\n@st.cache(allow_output_mutation=True)\ndef init_bert_tokenizer():\n    tokenizer = BlipBase.init_tokenizer()\n    return tokenizer\n\n\ndef getAttMap(img, attMap, blur=True, overlap=True):\n    attMap -= attMap.min()\n    if attMap.max() > 0:\n        attMap /= attMap.max()\n    attMap = skimage_transform.resize(attMap, (img.shape[:2]), order=3, mode=\"constant\")\n    if blur:\n        attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2]))\n        attMap -= attMap.min()\n        attMap /= attMap.max()\n    cmap = plt.get_cmap(\"jet\")\n    attMapV = cmap(attMap)\n    attMapV = np.delete(attMapV, 3, 2)\n    if overlap:\n        attMap = (\n            1 * (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img\n            + (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV\n        )\n    return attMap\n\n\n@st.cache(\n    hash_funcs={\n        torch.nn.parameter.Parameter: lambda parameter: parameter.data.detach()\n        .cpu()\n        .numpy()\n    },\n    allow_output_mutation=True,\n)\ndef load_blip_itm_model(device, model_type=\"base\"):\n    model = load_model(\n        \"blip_image_text_matching\", model_type, is_eval=True, device=device\n    )\n    return model\n"
  },
  {
    "path": "app/vqa.py",
    "content": "\"\"\"\n # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport streamlit as st\nfrom app import load_demo_image, device\nfrom app.utils import load_model_cache\nfrom lavis.processors import load_processor\nfrom PIL import Image\n\n\ndef app():\n    model_type = st.sidebar.selectbox(\"Model:\", [\"BLIP\"])\n\n    # ===== layout =====\n    st.markdown(\n        \"<h1 style='text-align: center;'>Visual Question Answering</h1>\",\n        unsafe_allow_html=True,\n    )\n\n    instructions = \"\"\"Try the provided image or upload your own:\"\"\"\n    file = st.file_uploader(instructions)\n\n    col1, col2 = st.columns(2)\n\n    col1.header(\"Image\")\n    if file:\n        raw_img = Image.open(file).convert(\"RGB\")\n    else:\n        raw_img = load_demo_image()\n\n    w, h = raw_img.size\n    scaling_factor = 720 / w\n    resized_image = raw_img.resize((int(w * scaling_factor), int(h * scaling_factor)))\n\n    col1.image(resized_image, use_column_width=True)\n    col2.header(\"Question\")\n\n    user_question = col2.text_input(\"Input your question!\", \"What are objects there?\")\n    qa_button = st.button(\"Submit\")\n\n    col2.header(\"Answer\")\n\n    # ===== event =====\n    vis_processor = load_processor(\"blip_image_eval\").build(image_size=480)\n    text_processor = load_processor(\"blip_question\").build()\n\n    if qa_button:\n        if model_type.startswith(\"BLIP\"):\n            model = load_model_cache(\n                \"blip_vqa\", model_type=\"vqav2\", is_eval=True, device=device\n            )\n\n            img = vis_processor(raw_img).unsqueeze(0).to(device)\n            question = text_processor(user_question)\n\n            vqa_samples = {\"image\": img, \"text_input\": [question]}\n            answers = model.predict_answers(vqa_samples, inference_method=\"generate\")\n\n            col2.write(\"\\n\".join(answers), use_column_width=True)\n"
  },
  {
    "path": "app.py",
    "content": "import gradio as gr\nimport os\nimport torch\nfrom torchvision import transforms\nfrom lavis.processors import transforms_video\nfrom lavis.datasets.data_utils import load_video_demo\nfrom lavis.processors.blip_processors import ToUint8, ToTHWC\nfrom lavis.models.sevila_models.sevila import SeViLA\nfrom typing import Optional\nimport warnings\n# model config\nimg_size = 224\nnum_query_token = 32\nt5_model = 'google/flan-t5-xl'\ndrop_path_rate = 0\nuse_grad_checkpoint = False\nvit_precision = \"fp16\"\nfreeze_vit = True\nprompt = ''\nmax_txt_len = 77\nanswer_num = 5\napply_lemmatizer = False\ntask = 'freeze_loc_freeze_qa_vid'\n\n# prompt\nLOC_propmpt = 'Does the information within the frame provide the necessary details to accurately answer the given question?'\nQA_prompt = 'Considering the information presented in the frame, select the correct answer from the options.'\n\n# processors config\nmean = (0.48145466, 0.4578275, 0.40821073)\nstd = (0.26862954, 0.26130258, 0.27577711)\nnormalize = transforms.Normalize(mean, std)\nimage_size = img_size\ntransform = transforms.Compose([ToUint8(), ToTHWC(), transforms_video.ToTensorVideo(), normalize])\n\nprint('Model Loading \\nLoading the SeViLA model can take a few minutes (typically 2-3).')\nsevila = SeViLA(\n    img_size=img_size,\n    drop_path_rate=drop_path_rate,\n    use_grad_checkpoint=use_grad_checkpoint,\n    vit_precision=vit_precision,\n    freeze_vit=freeze_vit,\n    num_query_token=num_query_token,\n    t5_model=t5_model,\n    prompt=prompt,\n    max_txt_len=max_txt_len,\n    apply_lemmatizer=apply_lemmatizer,\n    frame_num=4,\n    answer_num=answer_num,\n    task=task,\n        )\n\nsevila.load_checkpoint(url_or_filename='https://huggingface.co/Shoubin/SeViLA/resolve/main/sevila_pretrained.pth')\nprint('Model Loaded')\n\nANS_MAPPING = {0 : 'A', 1 : 'B', 2 : 'C', 3 : 'D', 4 : 'E'}\n\n# os.mkdir('video')\n\ndef sevila_demo(video, \n    question, \n    option1, option2, option3, \n    video_frame_num, \n    keyframe_num):\n    \n    if torch.cuda.is_available():\n        device = 0\n    else:\n        device = 'cpu'\n        \n    global sevila \n    if device == \"cpu\":\n        sevila = sevila.float()\n    else:\n        sevila = sevila.to(int(device))\n        \n    vpath = video \n    raw_clip, indice, fps, vlen = load_video_demo(\n        video_path=vpath,\n        n_frms=int(video_frame_num),\n        height=image_size,\n        width=image_size,\n        sampling=\"uniform\",\n        clip_proposal=None\n    )\n    clip = transform(raw_clip.permute(1,0,2,3))\n    clip = clip.float().to(int(device))\n    clip = clip.unsqueeze(0)\n    # check\n    if option1[-1] != '.':\n        option1 += '.'\n    if option2[-1] != '.':\n        option2 += '.' 
\n    if option3[-1] != '.':\n        option3 += '.'\n    option_dict = {0:option1, 1:option2, 2:option3}\n    options = 'Option A:{} Option B:{} Option C:{}'.format(option1, option2, option3)\n    text_input_qa = 'Question: ' + question + ' ' + options + ' ' + QA_prompt\n    text_input_loc = 'Question: ' + question + ' ' + options + ' ' + LOC_propmpt\n    \n    out = sevila.generate_demo(clip, text_input_qa, text_input_loc, int(keyframe_num))\n    # print(out)\n    answer_id = out['output_text'][0]\n    answer = option_dict[answer_id]\n    select_index = out['frame_idx'][0]\n    # images = [] \n    keyframes = []\n    timestamps = []\n    \n    # print('raw_clip', len(raw_clip))\n    # for j in range(int(video_frame_num)):\n    #     image = raw_clip[:, j, :, :].int()\n    #     image = image.permute(1, 2, 0).numpy() \n    #     images.append(image)\n    \n    video_len = vlen/fps # seconds\n    \n    for i in select_index:\n        image = raw_clip[:, i, :, :].int()\n        image = image.permute(1, 2, 0).numpy() \n        keyframes.append(image)\n        select_i = indice[i]\n        time = round((select_i / vlen) * video_len, 2)\n        timestamps.append(str(time)+'s')\n    \n    gr.components.Gallery(keyframes)\n    #gr.components.Gallery(images)\n    timestamps_des = ''\n    for i in range(len(select_index)):\n        timestamps_des += 'Keyframe {}: {} \\n'.format(str(i+1), timestamps[i])\n    \n    return keyframes, timestamps_des, answer\n\nwith gr.Blocks(title=\"SeViLA demo\") as demo:\n    description = \"\"\"<p style=\"text-align: center; font-weight: bold;\">\n        <span style=\"font-size: 28px\">Self-Chained Image-Language Model for Video Localization and Question Answering</span>\n        <br>\n        <span style=\"font-size: 18px\" id=\"author-info\">\n            <a href=\"https://yui010206.github.io/\" target=\"_blank\">Shoubin Yu</a>, \n            <a href=\"https://j-min.io/\" target=\"_blank\">Jaemin Cho</a>, \n            <a href=\"https://prateek-yadav.github.io/\" target=\"_blank\">Prateek Yadav</a>, \n            <a href=\"https://www.cs.unc.edu/~mbansal/\" target=\"_blank\">Mohit Bansal</a>\n        </span> \n        <br>\n        <span style=\"font-size: 18px\" id=\"paper-info\">\n            [<a href=\"https://github.com/Yui010206/SeViLA\" target=\"_blank\">GitHub</a>]\n            [<a href=\"https://arxiv.org/abs/2305.06988\" target=\"_blank\">Paper</a>]\n        </span>\n    </p>\n    <p>\n        To locate keyframes in a video and answer a question, please:\n        <br>\n        (1) upload your video; (2) write your question/options and set # video frame/# keyframe; (3) click Locate and Answer!\n        <br>\n        Just a heads up - loading the SeViLA model can take a few minutes (typically 2-3), and running examples requires about 12GB of memory.\n        <br>\n        We've got you covered! We've provided some example videos and questions below to help you get started. 
Feel free to try out SeViLA with these!\n    </p>\n    \"\"\"\n    gr.HTML(description)\n    with gr.Row():\n        with gr.Column(scale=1, min_width=600):\n            video = gr.Video(label='Video') \n            question = gr.Textbox(placeholder=\"Why did the two ladies put their hands above their eyes while staring out?\", label='Question')\n            with gr.Row():\n                option1 = gr.Textbox(placeholder=\"practicing cheer\", label='Option 1')\n                option2 = gr.Textbox(placeholder=\"posing for photo\", label='Option 2')\n                option3 = gr.Textbox(placeholder=\"to see better\", label='Option 3')\n            with gr.Row():\n                video_frame_num = gr.Textbox(placeholder=32, label='# Video Frame')\n                keyframe_num = gr.Textbox(placeholder=4, label='# Keyframe') \n            # device = gr.Textbox(placeholder=0, label='Device') \n            gen_btn = gr.Button(value='Locate and Answer!')\n        with gr.Column(scale=1, min_width=600): \n            keyframes = gr.Gallery(\n                label=\"Keyframes\", show_label=False, elem_id=\"gallery\",\n                ).style(columns=[4], rows=[1], object_fit=\"contain\", max_width=100, max_height=100)\n            #keyframes = gr.Gallery(label='Keyframes')\n            timestamps = gr.outputs.Textbox(label=\"Keyframe Timestamps\")\n            answer = gr.outputs.Textbox(label=\"Output Answer\")\n        \n        gen_btn.click(\n            sevila_demo,\n            inputs=[video, question, option1, option2, option3, video_frame_num, keyframe_num],\n            outputs=[keyframes, timestamps, answer],\n            queue=True\n        )\n        #demo = gr.Interface(sevila_demo,\n        #     inputs=[gr.Video(), question, option1, option2, option3, video_frame_num, keyframe_num, device],\n        #     outputs=['gallery', timestamps, answer],\n        #     examples=[['videos/demo1.mp4', 'Why did the two ladies put their hands above their eyes while staring out?', 'practicing cheer.', 'play ball.', 'to see better.', 32, 4, 0],\n        #               ['videos/demo2.mp4', 'What did both of them do after completing skiing?', 'jump and pose.' , 'bend down.','raised their hands.', 32, 4, 0],\n        #               ['videos/demo3.mp4', 'What room was Wilson breaking into when House found him?', 'the kitchen.' , 'the dining room.','the bathroom.', 32, 4, 0]]\n        #     )\n    with gr.Column():\n        gr.Examples(\n            inputs=[video, question, option1, option2, option3, video_frame_num, keyframe_num],\n            outputs=[keyframes, timestamps, answer],\n            fn=sevila_demo,\n            examples=[['videos/demo1.mp4', 'Why did the two ladies put their hands above their eyes while staring out?', 'practicing cheer', 'to place wreaths', 'to see better', 32, 4],\n                      ['videos/demo2.mp4', 'What did both of them do after completing skiing?', 'jump and pose' , 'bend down','raised their hands', 32, 4],\n                      ['videos/demo3.mp4', 'What room was Wilson breaking into when House found him?', 'the bedroom' , 'the bathroom','the kitchen', 32, 4]],\n            cache_examples=False,\n        )\ndemo.queue(concurrency_count=1, api_open=False)          \ndemo.launch(share=False) "
  },
  {
    "path": "docs/Makefile",
    "content": "# Minimal makefile for Sphinx documentation\n#\n\n# You can set these variables from the command line, and also\n# from the environment for the first two.\nSPHINXOPTS    ?=\nSPHINXBUILD   ?= sphinx-build\nSOURCEDIR     = source\nBUILDDIR      = build\n\n# Put it first so that \"make\" without argument is like \"make help\".\nhelp:\n\t@$(SPHINXBUILD) -M help \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n\n.PHONY: help Makefile\n\n# Catch-all target: route all unknown targets to Sphinx using the new\n# \"make mode\" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).\n%: Makefile\n\t@$(SPHINXBUILD) -M $@ \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n"
  },
  {
    "path": "docs/benchmark.rst",
    "content": "Benchmark\n############\n\nWe provide scripts for evaluating and training models on task datasets. The following benchmark results are included for reference.\n\n\nALBEF\n*******\n.. list-table::\n   :widths: 30 80 20\n\n   * - **Pretraining**\n     - COCO (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)\n     - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/train/pretrain.sh>`__\n   * -\n     - Visual Genome (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_vg.py>`__)\n     -\n   * -\n     - SBU (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_sbu.py>`__)\n     -\n   * -\n     - CC3M (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/DownloadConceptualCaptions/download_data_cc3m.py>`__)\n     -\n   * -\n     - CC12M (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/DownloadConceptualCaptions/download_data_cc12m.py>`__)\n     -\n\n.. list-table::\n   :widths: 30 40 20 20 20 30 30\n   :header-rows: 1\n\n   * -\n     - **Retrieval**\n     - **R1**\n     - **R5**\n     - **R10**\n     - **Training**\n     - **Evaluation**\n   * - TR\n     - COCO (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)\n     - 77.6\n     - 94.1\n     - 97.2\n     - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/train/train_coco_retrieval_albef.sh>`__\n     - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/eval/eval_coco_retrieval.sh>`__\n   * - IR\n     - COCO (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)\n     - 61.0\n     - 84.5\n     - 90.7\n     - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/train/train_coco_retrieval_albef.sh>`__\n     - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/eval/eval_coco_retrieval.sh>`__\n   * - TR\n     - Flickr30k (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_flickr.py>`__)\n     - 77.6\n     - 94.1\n     - 97.2\n     - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/train/train_flickr30k_retrieval_albef.sh>`__\n     - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/eval/eval_flickr30k_retrieval.sh>`__\n   * - IR\n     - Flickr30k (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_flickr.py>`__)\n     - 61.0\n     - 84.5\n     - 90.7\n     - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/train/train_flickr30k_retrieval_albef.sh>`__\n     - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/eval/eval_flickr30k_retrieval.sh>`__\n\n\n.. 
list-table::\n   :widths: 20 20 20 20 20\n   :header-rows: 1\n\n   * - **VQA**\n     - **test-dev**\n     - **test-std/test**\n     - **Training**\n     - **Evaluation**\n   * - VQAv2 (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)\n     - 76.35\n     - 76.54\n     - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/train/train_vqa_albef.sh>`__\n     - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/eval/test_albef_vqa.sh>`__\n   * - OKVQA (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)\n     - NA\n     - 54.7 \n     - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/train/train_okvqa_albef.sh>`__\n     - NA\n   * - AOKVQA (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)\n     - 54.5\n     - NA\n     - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/train/train_aokvqa_albef.sh>`__\n     - NA\n\n  \n.. list-table::\n   :widths: 20 20 20 20 20\n   :header-rows: 1\n\n   * - **Multimodal Classification**\n     - **val**\n     - **test**\n     - **Training**\n     - **Evaluation**\n   * - SNLI-VE (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)\n     - 80.60\n     - 81.04\n     - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/train/train_ve_albef.sh>`__\n     - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/eval/eval_albef_ve.sh>`__\n   * - NLVR2 (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)\n     - 82.47 \n     - 82.91 \n     - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/train/train_nlvr_albef.sh>`__\n     - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/eval/eval_albef_nlvr.sh>`__\n  \nBLIP\n*******\n.. list-table::\n   :widths: 30 80 20\n\n   * - **Pretraining (14M)**\n     - COCO (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)\n     - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/train/pretrain.sh>`__\n   * -\n     - Visual Genome (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_vg.py>`__)\n     -\n   * -\n     - SBU (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_sbu.py>`__)\n     -\n   * -\n     - CC3M (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/DownloadConceptualCaptions/download_data_cc3m.py>`__)\n     -\n   * -\n     - CC12M (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/DownloadConceptualCaptions/download_data_cc12m.py>`__)\n     -\n\n.. 
list-table::\n   :widths: 30 40 20 20 20 30 30\n   :header-rows: 1\n\n   * - **Tasks**\n     - **Retrieval**\n     - **R1**\n     - **R5**\n     - **R10**\n     - **Training**\n     - **Evaluation**\n   * - TR\n     - COCO (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)\n     - 82.0\n     - 95.8\n     - 98.1\n     - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/train/train_retrieval_coco.sh>`__\n     - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/eval/eval_ret_coco.sh>`__\n   * - IR\n     - COCO (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)\n     - 64.5\n     - 86.0\n     - 91.7\n     - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/train/train_retrieval_coco.sh>`__\n     - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/eval/eval_ret_coco.sh>`__\n   * - TR\n     - Flickr30k (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_flickr.py>`__)\n     - 96.9\n     - 99.9\n     - 100.0\n     - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/train/train_retrieval_flickr.sh>`__\n     - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/eval/eval_ret_flickr.sh>`__\n   * - IR\n     - Flickr30k (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_flickr.py>`__)\n     - 87.5\n     - 97.6\n     - 98.9\n     - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/train/train_retrieval_flickr.sh>`__\n     - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/eval/eval_ret_flickr.sh>`__\n\n\n.. list-table::\n   :widths: 20 20 20 20 20\n   :header-rows: 1\n\n   * - **VQA**\n     - **test-dev**\n     - **test-std/test**\n     - **Training**\n     - **Evaluation**\n   * - VQAv2 (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)\n     - 78.23\n     - 78.29\n     - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/train/train_vqa_albef.sh>`__\n     - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/eval/test_albef_vqa.sh>`__\n   * - OKVQA (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)\n     - NA\n     - 55.4 \n     - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/train/train_okvqa.sh>`__\n     - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/eval/eval_okvqa.sh>`__\n   * - AOKVQA (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)\n     - 56.2\n     - 50.1 \n     - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/train/train_aokvqa.sh>`__\n     - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/eval/eval_aokvqa.sh>`__\n\n\n.. 
list-table::\n   :widths: 20 20 20 20 20 20\n   :header-rows: 1\n\n   * - **Image Captioning**\n     - **BLEU@4**\n     - **CIDEr**\n     - **SPICE**\n     - **Training**\n     - **Evaluation**\n   * - COCO (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)\n     - 39.9\n     - 133.5\n     - 23.7\n     - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/train/train_caption_coco.sh>`__\n     - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/eval/eval_coco_cap.sh>`__\n   * - NoCaps (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_nocaps.py>`__)\n     - 31.9\n     - 109.1\n     - 14.7\n     - NA\n     - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/eval/eval_nocaps.sh>`__\n\n\n.. list-table::\n   :widths: 20 20 20 20 20\n   :header-rows: 1\n\n   * - **Multimodal Classification**\n     - **val**\n     - **test**\n     - **Training**\n     - **Evaluation**\n   * - NLVR2 (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)\n     - 82.48\n     - 83.25\n     - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/train/train_nlvr.sh>`__\n     - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/eval/eval_nlvr.sh>`__\n\nCLIP\n*******\n.. list-table::\n   :widths: 30 40 20 20 20 30\n   :header-rows: 1\n\n   * - **Tasks**\n     - **Retrieval (Zero-shot)**\n     - **R1**\n     - **R5**\n     - **R10**\n     - **Evaluation**\n   * - TR\n     - COCO (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)\n     - 57.2\n     - 80.5\n     - 87.8\n     - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/clip/eval/eval_clip_ret_coco.sh>`__\n   * - IR\n     - COCO (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)\n     - 36.5\n     - 60.8\n     - 71.0\n     - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/clip/eval/eval_clip_ret_coco.sh>`__\n   * - TR\n     - Flickr30k (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_flickr.py>`__)\n     - 86.5\n     - 98.0\n     - 99.1\n     - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/clip/eval/eval_clip_ret_flickr.sh>`__\n   * - IR\n     - Flickr30k (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_flickr.py>`__)\n     - 67.0\n     - 88.9\n     - 93.3\n     - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/clip/eval/eval_clip_ret_flickr.sh>`__\n\n.. list-table::\n   :widths: 20 20 20\n   :header-rows: 1\n\n   * - **Multimodal Classification**\n     - **val**\n     - **Evaluation**\n   * - ImageNet \n     - 76.5 \n     - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/clip/eval/eval_clip_zs_imnet.sh>`__\n\n\nALPRO\n*******\n.. 
list-table::\n   :widths: 30 40 20 20 20 20 30\n   :header-rows: 1\n\n   * - **Tasks**\n     - **Retrieval**\n     - **R1**\n     - **R5**\n     - **R10**\n     - **Training**\n     - **Evaluation**\n   * - TR\n     - MSRVTT (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_msrvtt.py>`__)\n     - 33.2\n     - 60.5 \n     - 71.7 \n     - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/alpro/train/train_msrvtt_ret.sh>`__\n     - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/alpro/eval/eval_msrvtt_ret.sh>`__\n   * - VR\n     - MSRVTT (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_msrvtt.py>`__)\n     - 33.8\n     - 61.4\n     - 72.7\n     - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/alpro/train/train_msrvtt_ret.sh>`__\n     - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/alpro/eval/eval_msrvtt_ret.sh>`__\n   * - TR\n     - DiDeMo (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_didemo.py>`__)\n     - 38.8 \n     - 66.4\n     - 76.8\n     - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/alpro/train/train_didemo_ret.sh>`__\n     - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/alpro/eval/eval_didemo_ret.sh>`__\n   * - VR\n     - DiDeMo (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_didemo.py>`__)\n     - 36.6\n     - 67.5\n     - 77.9\n     - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/alpro/train/train_didemo_ret.sh>`__\n     - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/alpro/eval/eval_didemo_ret.sh>`__\n\n.. list-table::\n   :widths: 20 20 20 20\n   :header-rows: 1\n\n   * - **Video QA**\n     - **test**\n     - **Training**\n     - **Evaluation**\n   * - MSRVTT \n     - 42.1 \n     - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/alpro/train/train_msrvtt_qa.sh>`__\n     - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/alpro/eval/eval_msrvtt_qa.sh>`__\n   * - MSVD \n     - 46.0 \n     - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/alpro/train/train_msvd_qa.sh>`__ \n     - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/alpro/eval/eval_msvd_qa.sh>`__"
  },
  {
    "path": "docs/build_docs.sh",
    "content": "#!/bin/bash\nset -euo pipefail\n\n# Change to root directory of repo\nDIRNAME=$(cd \"$( dirname \"${BASH_SOURCE[0]}\" )\" &> /dev/null && pwd)\ncd \"${DIRNAME}/..\"\n\n# # Set up virtual environment\npip3 install setuptools wheel virtualenv\nif [ ! -d venv ]; then\n  rm -f venv\n  virtualenv venv\nfi\nsource venv/bin/activate\n\n# # Get current git branch & stash unsaved changes\nGIT_BRANCH=$(git branch --show-current)\nif [ -z \"${GIT_BRANCH}\" ]; then\n    GIT_BRANCH=\"main\"\nfi\ngit stash\n\n# Set up exit handler to restore git state & delete temp branches\n# function exit_handler {\n#     git reset --hard\n#     git checkout \"${GIT_BRANCH}\" --\n#     git stash pop || true\n#     for version in $(git tag --list 'v[0-9]*'); do\n#         branch=\"${version}_local_docs_only\"\n#         if git show-ref --verify --quiet \"refs/heads/$branch\"; then\n#             git branch -D \"$branch\"\n#         fi\n#     done\n# }\n# trap exit_handler EXIT\n\n# Clean up build directory and install Sphinx requirements\npip3 install -r \"${DIRNAME}/requirements.txt\"\nsphinx-build -M clean \"${DIRNAME}\" \"${DIRNAME}/_build\"\n\n# Build API docs for current head\nexport current_version=\"latest\"\npip3 install \".\"\nsphinx-build -b html \"${DIRNAME}\" \"${DIRNAME}/_build/html/${current_version}\" -W --keep-going\nrm -rf \"${DIRNAME}/_build/html/${current_version}/.doctrees\"\n#pip3 uninstall -y omnixai\n\n# Install all previous released versions\n# and use them to build the appropriate API docs.\n# Uninstall after we're done with each one.\n# versions=()\n# checkout_files=(\"${DIRNAME}/*.rst\" \"lavis\" \"tutorials\" \"setup.py\")\n# for version in $(git tag --list 'v[0-9]*'); do\n#     versions+=(\"$version\")\n#     git checkout -b \"${version}_local_docs_only\"\n#     for f in $(git diff --name-only --diff-filter=A \"tags/${version}\" \"${DIRNAME}/*.rst\"); do\n#         git rm \"$f\"\n#     done\n#     git checkout \"tags/${version}\" -- \"${checkout_files[@]}\"\n#     export current_version=${version}\n#     pip3 install \".[all]\"\n#     sphinx-build -b html \"${DIRNAME}\" \"${DIRNAME}/_build/html/${current_version}\" -W --keep-going\n#     rm -rf \"${DIRNAME}/_build/html/${current_version}/.doctrees\"\n#     #pip3 uninstall -y omnixai\n#     git reset --hard\n#     git checkout \"${GIT_BRANCH}\" --\n# done\n\n# Determine the latest stable version if there is one\n# if (( ${#versions[@]} > 0 )); then\n#   stable_hash=$(git rev-list --tags --max-count=1)\n#   stable_version=$(git describe --tags \"$stable_hash\")\n#   export stable_version\n# else\nexport stable_version=\"latest\"\n# fi\n\n# Create dummy HTML's for the stable version in the base directory\nwhile read -r filename; do\n    filename=$(echo \"$filename\" | sed \"s/\\.\\///\")\n    n_sub=$(echo \"$filename\" | (grep -o \"/\" || true) | wc -l)\n    prefix=\"\"\n    for (( i=0; i<n_sub; i++ )); do\n        prefix+=\"../\"\n    done\n    url=\"${prefix}${stable_version}/$filename\"\n    mkdir -p \"${DIRNAME}/_build/html/$(dirname \"$filename\")\"\n    cat > \"${DIRNAME}/_build/html/$filename\" <<EOF\n<!DOCTYPE html>\n<html>\n   <head>\n      <title>LAVIS Documentation</title>\n      <meta http-equiv = \"refresh\" content=\"0; url='$url'\" />\n   </head>\n   <body>\n      <p>Please wait while you're redirected to our <a href=\"$url\">documentation</a>.</p>\n   </body>\n</html>\nEOF\ndone < <(cd \"${DIRNAME}/_build/html/$stable_version\" && find . -name \"*.html\")\necho \"Finished writing to _build/html.\""
  },
  {
    "path": "docs/conf.py",
    "content": "# Configuration file for the Sphinx documentation builder.\n#\n# This file only contains a selection of the most common options. For a full\n# list see the documentation:\n# https://www.sphinx-doc.org/en/master/usage/configuration.html\n\n# -- Path setup --------------------------------------------------------------\n\n# If extensions (or modules to document with autodoc) are in another directory,\n# add these directories to sys.path here. If the directory is relative to the\n# documentation root, use os.path.abspath to make it absolute, like shown here.\n#\n# import os\n# import sys\n# sys.path.insert(0, os.path.abspath('.'))\n\n\n# -- Project information -----------------------------------------------------\n\nproject = \"LAVIS\"\ncopyright = \"2022, salesforce.com inc.\"\nauthor = (\n    \"Dongxu Li, Junnan Li, Hung Le, Guangsen Wang, Silvio Savarese, Steven C.H. Hoi\"\n)\n\n\n# -- General configuration ---------------------------------------------------\n\n# Add any Sphinx extension module names here, as strings. They can be\n# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom\n# ones.\nextensions = [\"nbsphinx\"]\n\n# Add any paths that contain templates here, relative to this directory.\ntemplates_path = [\"_templates\"]\n\n# List of patterns, relative to source directory, that match files and\n# directories to ignore when looking for source files.\n# This pattern also affects html_static_path and html_extra_path.\nexclude_patterns = []\n\n\n# -- Options for HTML output -------------------------------------------------\n\n# The theme to use for HTML and HTML Help pages.  See the documentation for\n# a list of builtin themes.\n#\n# html_theme = \"alabaster\"\nhtml_theme = \"sphinx_rtd_theme\"\n\n# Add any paths that contain custom static files (such as style sheets) here,\n# relative to this directory. They are copied after the builtin static files,\n# so a file named \"default.css\" will overwrite the builtin \"default.css\".\nhtml_static_path = [\"_static\"]\n\n# pygments_style = \"sphinx\"\n"
  },
  {
    "path": "docs/getting_started.rst",
    "content": "Dataset Zoo\n##################\nLAVIS inherently supports a wide variety of common language-vision datasets by providing automatic download scripts to help download and organize these datasets; \nand implements PyTorch datasets for these datasets. To view supported datasets, use the following code:\n\n.. code-block:: python\n\n    from lavis.datasets.builders import dataset_zoo\n    dataset_names = dataset_zoo.get_names()\n    print(dataset_names)\n    # ['aok_vqa', 'coco_caption', 'coco_retrieval', 'coco_vqa', 'conceptual_caption_12m',\n    #  'conceptual_caption_3m', 'didemo_retrieval', 'flickr30k', 'imagenet', 'laion2B_multi',\n    #  'msrvtt_caption', 'msrvtt_qa', 'msrvtt_retrieval', 'msvd_caption', 'msvd_qa', 'nlvr',\n    #  'nocaps', 'ok_vqa', 'sbu_caption', 'snli_ve', 'vatex_caption', 'vg_caption', 'vg_vqa']\n    print(len(dataset_names))\n    # 23\n\n\nAuto-Downloading and Loading Datasets\n######################################\nWe now take COCO caption dataset as an example to demonstrate how to download and prepare the dataset.\n\nIn ``lavis/datasets/download_scripts/``, we provide tools to download most common public language-vision datasets supported by LAVIS.\nThe COCO caption dataset uses images from COCO dataset. Therefore, we first download COCO images via:\n\n.. code-block:: bash\n    \n    cd lavis/datasets/download_scripts/ && python download_coco.py\n\nThis will automatically download and extract COCO images to the default LAVIS cache location.\nThe default cache location is ``~/.cache/lavis``, defined in ``lavis/configs/default.yaml``.\n\nAfter downloading the images, we can use ``load_dataset()`` to obtain the dataset. On the first run, this will automatically download and cache annotation files.\n\n.. code-block:: python\n\n    from lavis.datasets.builders import load_dataset\n    coco_dataset = load_dataset(\"coco_caption\")\n\n    print(coco_dataset.keys())\n    # dict_keys(['train', 'val', 'test'])\n\n    print(len(coco_dataset[\"train\"]))\n    # 566747\n\n    print(coco_dataset[\"train\"][0])\n    # {'image': <PIL.Image.Image image mode=RGB size=640x480>,\n    #  'text_input': 'A woman wearing a net on her head cutting a cake. ',\n    #  'image_id': 0}\n\nIf you already host a local copy of the dataset, you can pass in the ``vis_path`` argument to change the default location to load images.\n\n.. code-block:: python\n\n    coco_dataset = load_dataset(\"coco_caption\", vis_path=YOUR_LOCAL_PATH)\n\n\nModel Zoo\n####################################\nLAVIS supports a growing list of pre-trained models for different tasks,\ndatatsets and of varying sizes. Let's get started by viewing the supported models.\n\n.. 
code-block:: python\n\n    from lavis.models import model_zoo\n    print(model_zoo)\n    # ==================================================\n    # Architectures                  Types\n    # ==================================================\n    # albef_classification           base, ve\n    # albef_nlvr                     base\n    # albef_pretrain                 base\n    # albef_retrieval                base, coco, flickr\n    # albef_vqa                      base, vqav2\n    # alpro_qa                       base, msrvtt, msvd\n    # alpro_retrieval                base, msrvtt, didemo\n    # blip_caption                   base, base_coco, large, large_coco\n    # blip_classification            base\n    # blip_feature_extractor         base\n    # blip_nlvr                      base\n    # blip_pretrain                  base\n    # blip_retrieval                 base, coco, flickr\n    # blip_vqa                       base, vqav2\n    # clip                           ViT-B-32, ViT-B-16, ViT-L-14, ViT-L-14-336, RN50\n\n    # show total number of supported model variants\n    len(model_zoo)\n    # 33\n\n\nInference with Pre-trained Models\n####################################\n\nNow let's see how to use models in LAVIS to perform inference on example data. We first\nload a sample image from a local path.\n\n.. code-block:: python\n\n    import torch\n    from PIL import Image\n\n    # setup device to use\n    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n    # load sample image\n    raw_image = Image.open(\"docs/_static/merlion.png\").convert(\"RGB\")\n\nThis example image shows `Merlion park <https://en.wikipedia.org/wiki/Merlion>`_ (`image credit <https://theculturetrip.com/asia/singapore/articles/what-exactly-is-singapores-merlion-anyway/>`_), a landmark in Singapore.\n\n.. image:: _static/merlion.png\n\nImage Captioning\n*******************************\nWe now use the BLIP model to generate a caption for the image. To make inference even easier, we also associate each\npre-trained model with its preprocessors (transforms). We use ``load_model_and_preprocess()`` with the following arguments:\n\n- ``name``: The name of the model to load. This could be a pre-trained model, task model, or feature extractor. See ``model_zoo`` for a full list of model names.\n- ``model_type``: Each architecture has variants trained on different datasets and at different scales. See the Types column in ``model_zoo`` for a full list of model types.\n- ``is_eval``: if `True`, set the model to evaluation mode. This is desired for inference or feature extraction.\n- ``device``: device to load the model to.\n\n.. code-block:: python\n\n    from lavis.models import load_model_and_preprocess\n    # loads BLIP caption base model, with finetuned checkpoints on MSCOCO captioning dataset.\n    # this also loads the associated image processors\n    model, vis_processors, _ = load_model_and_preprocess(name=\"blip_caption\", model_type=\"base_coco\", is_eval=True, device=device)\n\n    # preprocess the image\n    # vis_processors stores image transforms for \"train\" and \"eval\" (validation / testing / inference)\n    image = vis_processors[\"eval\"](raw_image).unsqueeze(0).to(device)\n\n    # generate caption\n    model.generate({\"image\": image})\n    # ['a large fountain spewing water into the air']\n\n\nYou may also load models and their preprocessors separately via ``load_model()`` and ``load_processor()``.\nIn BLIP, you can also generate diverse captions by turning nucleus sampling on.\n\n.. 
code-block:: python\n\n    from lavis.processors import load_processor\n    from lavis.models import load_model\n\n    # load image preprocessor used for BLIP\n    vis_processor = load_processor(\"blip_image_eval\").build(image_size=384)\n    model = load_model(name=\"blip_caption\", model_type=\"base_coco\", is_eval=True, device=device)\n\n    image = vis_processor(raw_image).unsqueeze(0).to(device)\n    model.generate({\"image\": image}, use_nucleus_sampling=True)\n    # one generated random sample: ['some very pretty buildings and some water jets']\n\n\nVisual question answering (VQA)\n*******************************\nThe BLIP model is able to answer free-form questions about images in natural language.\nTo access the VQA model, simply replace the ``name`` and ``model_type`` arguments \npassed to ``load_model_and_preprocess()``.\n\n.. code-block:: python\n\n    from lavis.models import load_model_and_preprocess\n    model, vis_processors, txt_processors = load_model_and_preprocess(name=\"blip_vqa\", model_type=\"vqav2\", is_eval=True, device=device)\n\n    # ask a random question.\n    question = \"Which city is this photo taken?\"\n    \n    image = vis_processors[\"eval\"](raw_image).unsqueeze(0).to(device)\n    question = txt_processors[\"eval\"](question)\n\n    model.predict_answers(samples={\"image\": image, \"text_input\": question}, inference_method=\"generate\")\n    # ['singapore']\n\n\nUnified Feature Extraction Interface\n####################################\n\nLAVIS provides a unified interface to extract multimodal features from each architecture.\nTo extract features, we load the feature extractor variants of each model.\nThe multimodal feature can be used for multimodal classification. The low-dimensional unimodal features can be used to compute cross-modal similarity.\n\n.. 
code-block:: python\n\n    from lavis.models import load_model_and_preprocess \n    \n    model, vis_processors, txt_processors = load_model_and_preprocess(name=\"blip_feature_extractor\", model_type=\"base\", is_eval=True, device=device)\n    caption = \"a large fountain spewing water into the air\"\n\n    image = vis_processors[\"eval\"](raw_image).unsqueeze(0).to(device)\n    text_input = txt_processors[\"eval\"](caption)\n\n    sample = {\"image\": image, \"text_input\": [text_input]}\n\n    features_multimodal = model.extract_features(sample)\n    print(features_multimodal.keys())\n    # odict_keys(['image_embeds', 'multimodal_embeds'])\n    print(features_multimodal.multimodal_embeds.shape)\n    # torch.Size([1, 12, 768]), use features_multimodal[:, 0, :] for multimodal classification tasks\n\n    features_image = model.extract_features(sample, mode=\"image\")\n    print(features_image.keys())\n    # odict_keys(['image_embeds', 'image_embeds_proj'])\n    print(features_image.image_embeds.shape)\n    # torch.Size([1, 197, 768])\n    print(features_image.image_embeds_proj.shape)\n    # torch.Size([1, 197, 256])\n\n    features_text = model.extract_features(sample, mode=\"text\")\n    print(features_text.keys())\n    # odict_keys(['text_embeds', 'text_embeds_proj'])\n    print(features_text.text_embeds.shape)\n    # torch.Size([1, 12, 768])\n    print(features_text.text_embeds_proj.shape)\n    # torch.Size([1, 12, 256])\n    \n    similarity = features_image.image_embeds_proj[:, 0, :] @ features_text.text_embeds_proj[:, 0, :].t()\n    print(similarity)\n    # tensor([[0.2622]])\n\nSince LAVIS supports a unified feature extraction interface, minimal changes are necessary to use a different model as feature extractor. For example,\nto use ALBEF as the feature extractor, one only needs to change the following line:\n\n.. code-block:: python\n\n    model, vis_processors, txt_processors = load_model_and_preprocess(name=\"albef_feature_extractor\", model_type=\"base\", is_eval=True, device=device)\n\nSimilarly, to use CLIP as feature extractor: \n\n.. code-block:: python\n\n    model, vis_processors, txt_processors = load_model_and_preprocess(name=\"clip_feature_extractor\", model_type=\"base\", is_eval=True, device=device)\n    # model, vis_processors, txt_processors = load_model_and_preprocess(name=\"clip_feature_extractor\", model_type=\"RN50\", is_eval=True, device=device)\n    # model, vis_processors, txt_processors = load_model_and_preprocess(name=\"clip_feature_extractor\", model_type=\"ViT-L-14\", is_eval=True, device=device)\n"
  },
  {
    "path": "docs/index.rst",
    "content": ".. LAVIS documentation master file, created by\n   sphinx-quickstart on Sun Jul 31 10:32:27 2022.\n   You can adapt this file completely to your liking, but it should at least\n   contain the root `toctree` directive.\n\nWelcome to LAVIS's documentation!\n=================================\n\n.. toctree::\n   :maxdepth: 1\n   :caption: Introduction\n\n   intro\n\n\n.. toctree::\n   :maxdepth: 1\n   :caption: Getting Started\n\n   getting_started\n\n\n..    :maxdepth: 1\n..    :caption: Advanced Training\n\n..    advanced_training\n\n\n.. toctree::\n   :maxdepth: 2\n   :caption: Advanced Usage\n\n   benchmark\n   tutorial\n\n\n.. Documentations\n.. ===================\n\n\nIndices and tables\n==================\n\n* :ref:`genindex`\n* :ref:`modindex`\n* :ref:`search`\n"
  },
  {
    "path": "docs/intro.rst",
    "content": "What is LAVIS?\n####################################\n\nLAVIS is a Python deep learning library for LAnguage-and-VISion research and applications.\nIt features a unified design to access state-of-the-art foundation language-vision models (`ALBEF <https://arxiv.org/pdf/2107.07651.pdf>`_,\n`BLIP <https://arxiv.org/pdf/2201.12086.pdf>`_, `ALPRO <https://arxiv.org/pdf/2112.09583.pdf>`_, `CLIP <https://arxiv.org/pdf/2103.00020.pdf>`_), common tasks \n(retrieval, captioning, visual question answering, multimodal classification etc.) and datasets (COCO, Flickr, Nocaps, Conceptual\nCommons, SBU, etc.).\n\nThis library aims to provide engineers and researchers with a one-stop solution to rapidly develop models for their specific multimodal\nscenarios, and benchmark them across standard and customized datasets. \n\nKey features of LAVIS include:\n\n- **Modular and Extensible Library Design**: facilitating to easily utilize and repurpose existing modules (datasets, models, preprocessors), also to add new modules.\n\n- **Easy Off-the-shelf Inference and Feature Extraction**: readily available pre-trained models let you take advantage of state-of-the-art multimodal understanding and generation capabilities on your own data.\n\n- **Reproducible Model Zoo**: provided training/pre-training recipies to easily replicate and extend state-of-the-art models.\n\n- **Dataset Zoo and Automatic Downloading Tools**: it can be a hassle to prepare the many language-vision datasets. LAVIS provides automatic downloaing scripts to help prepare a large variety of datasets and their annotations.\n\nOther features include:\n\n- **Distributed Training** using multiple GPUs on one machine or across multiple machines.\n\n- **Web Demo**: try supported models on your own pictures, questions etc.\n\n- **Leaderboard**: comparing state-of-the-art models across standard datasets. \n\n- **Dataset Explorer**: help browse and understand language-vision datasets.\n\nSupported Tasks, Models and Datasets\n####################################\n\nThe following table shows the supported models and language-vision tasks by LAVIS. 
Adapting existing models to more tasks is possible and next to come in future releases.\n\n======================================== =========================== ============================================= ============ \nTasks                                     Supported Models            Supported Datasets                            Modalities  \n======================================== =========================== ============================================= ============ \nImage-text Pre-training                   ALBEF, BLIP                 COCO, VisualGenome, SBU, ConceptualCaptions  image, text  \nImage-text Retrieval                      ALBEF, BLIP, CLIP           COCO, Flickr30k                              image, text  \nText-image Retrieval                      ALBEF, BLIP, CLIP           COCO, Flickr30k                              image, text  \nVisual Question Answering                 ALBEF, BLIP                 VQAv2, OKVQA, A-OKVQA                        image, text  \nImage Captioning                          BLIP                        COCO, NoCaps                                 image, text  \nImage Classification                      CLIP                        ImageNet                                     image        \nNatural Language Visual Reasoning (NLVR)  ALBEF, BLIP                 NLVR2                                        image, text  \nVisual Entailment (VE)                    ALBEF                       SNLI-VE                                      image, text  \nVisual Dialogue                           BLIP                        VisDial                                      image, text  \nVideo-text Retrieval                      BLIP, ALPRO                 MSRVTT, DiDeMo                               video, text  \nText-video Retrieval                      BLIP, ALPRO                 MSRVTT, DiDeMo                               video, text  \nVideo Question Answering (VideoQA)        BLIP, ALPRO                 MSRVTT, MSVD                                 video, text  \nVideo Dialogue                            VGD-GPT                     AVSD                                         video, text  \nMultimodal Feature Extraction             ALBEF, CLIP, BLIP, ALPRO    customized                                   image, text  \n======================================== =========================== ============================================= ============ \n\nLibrary Design\n####################################\n\n.. image:: _static/architecture.png\n  :width: 550\n\nLAVIS has six key modules.\n\n- ``lavis.runners`` manages the overall training and evaluation lifecycle. It is also responsible for creating required components lazily as per demand, such as optimizers, learning rate schedulers and dataloaders. Currently ``RunnerBase`` implements epoch-based training and ``RunerIters`` implements iteration-based training.\n- ``lavis.tasks`` implements concrete training and evaluation logic per task. A task could be, for example, retrieval, captioning, pre-training. The rationale to have an abstraction of task is to accomodate task-specific training and evaluation. For example, evaluating a retrieval model is different from a classification model.\n- ``lavis.datasets`` is responsible for creating datasets, where ``lavis.datasets.builders`` loads dataset configurations, downloads annotations and returns a dataset object; ``lavis.datasets.datasets`` defines the supported datasets, each is a ``torch.utils.data.Dataset`` instance. 
We also provide `automatic dataset downloading tools` in ``datasets/download_scripts`` to help prepare common public datasets.\n- ``lavis.models`` holds definitions of the supported models and shared model layers.\n- ``lavis.processors`` handles preprocessing of text and images/videos before feeding the model. For images and videos, a processor can be thought of as the transforms in torchvision; for text input, this may include lowercasing, truncation, etc.\n- ``lavis.common`` module contains shared classes and methods used by multiple other modules. For example,\n\n   - ``lavis.common.config`` contains classes to store and manipulate configuration files used by LAVIS. In particular, we use a hierarchical configuration design, to allow highly customizable training and evaluation.\n   - ``lavis.common.registry`` serves as a centralized place to manage modules that share the same functionalities. It allows building datasets, models, tasks, and learning rate schedulers during runtime, by specifying their names as strings in the configuration file.\n   - ``lavis.common.optims`` contains definitions of learning rate schedulers.\n   - ``lavis.common.dist_utils`` contains utilities for distributed training and evaluation.\n   - ``lavis.common.utils`` contains miscellaneous utilities, mostly IO-related helper functions.\n\n\nInstallation\n############\n1. (Optional) Creating conda environment\n\n.. code-block:: bash\n\n   conda create -n lavis python=3.8\n   conda activate lavis\n\n2. Cloning and building from source\n\n.. code-block:: bash\n\n   git clone https://github.com/salesforce/LAVIS.git\n   cd LAVIS\n   pip install .\n\nIf you would like to develop on LAVIS, you may find it easier to build with editable mode::\n\n   pip install -e .\n\n"
  },
  {
    "path": "docs/make.bat",
    "content": "@ECHO OFF\r\n\r\npushd %~dp0\r\n\r\nREM Command file for Sphinx documentation\r\n\r\nif \"%SPHINXBUILD%\" == \"\" (\r\n\tset SPHINXBUILD=sphinx-build\r\n)\r\nset SOURCEDIR=source\r\nset BUILDDIR=build\r\n\r\nif \"%1\" == \"\" goto help\r\n\r\n%SPHINXBUILD% >NUL 2>NUL\r\nif errorlevel 9009 (\r\n\techo.\r\n\techo.The 'sphinx-build' command was not found. Make sure you have Sphinx\r\n\techo.installed, then set the SPHINXBUILD environment variable to point\r\n\techo.to the full path of the 'sphinx-build' executable. Alternatively you\r\n\techo.may add the Sphinx directory to PATH.\r\n\techo.\r\n\techo.If you don't have Sphinx installed, grab it from\r\n\techo.http://sphinx-doc.org/\r\n\texit /b 1\r\n)\r\n\r\n%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%\r\ngoto end\r\n\r\n:help\r\n%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%\r\n\r\n:end\r\npopd\r\n"
  },
  {
    "path": "docs/requirements.txt",
    "content": "GitPython\nipykernel\nnbsphinx==0.8.7\npandoc\nsphinx\nsphinx_autodoc_typehints\nsphinx_rtd_theme"
  },
  {
    "path": "docs/tutorial.configs.rst",
    "content": ".. _config:\n\nTraining Models on Task Datasets (Commands and Configurations) \n#################################################################\n\nLAVIS provides scripts to pre-train and finetune supported models on standard language-vision tasks, stored at ``lavis/run_scripts/``. \nTo replicate the experiments, just run these bash scripts. For example, to train BLIP model on the image-text retrieval task with MSCOCO dataset, we can run\n\n.. code-block::\n\n    bash run_scripts/lavis/blip/train/train_retrieval_coco.sh\n\nInside the scripts, we can see \n\n.. code-block:: bash\n\n    python -m torch.distributed.run --nproc_per_node=8 train.py --cfg-path lavis/projects/blip/train/retrieval_coco_ft.yaml\n\nwhere we start a pytorch distributed training on 8 GPUs (you may change according to your own hardware setup). The ``--cfg-path`` specifys a `runtime configuration file`, specifying\nthe task, model, dataset and training recipes. \n\nAvailable options and their descriptions are as below.\n\n.. LAVIS executes training and evaluation based on arguments specified in the configuration files. The default model and dataset configurations are defined in ``lavis/configs``. The task-specific configurations are defined in ``lavis/projects``. Task-specific configurations have higher priority over the default configurations.\n\n.. The following tables provide explanations for the arguments in the configuration files.\n\n.. list-table::\n   :widths: 30 40\n   :header-rows: 1\n\n   * - Model Configurations\n     - Functionalities\n   * - arch\n     - | name of the model from the model zoo\n       | default: task-dependent\n   * - model_type\n     - | the type of the model (e.g., base)\n       | default: task-dependent\n   * - load_pretrained\n     - | load pretrained weights\n       | default: True (for finetuning task) | False (for pretraining task) \n   * - load_finetuned\n     - | load task-specific finetuned weights\n       | default: False (for finetuning task) | True (for evaluation) \n   * - pretrained \n     - | URL or local path which stores the pretrained model, defined in the default model configuration file\n       | default: task-dependent \n   * - finetuned\n     - | URL or local path which stores the finetuned model, defined in the default model configuration file\n       | default: task-dependent\n\n.. list-table::\n   :widths: 30 50\n   :header-rows: 1\n\n   * - Dataset Configurations\n     - Functionalities\n   * - vis_processor\n     - | pre-processing of visual input\n       | default: task-dependent\n   * - text_processor\n     - | pre-processing of text input\n       | default: task-dependent\n   * - build_info\n     - | dataset information including the storage location, defined in the default dataset configuration file\n       | default: task-dependent\n\n.. 
list-table::\n   :widths: 30 50\n   :header-rows: 1\n\n   * - Runtime Configurations\n     - Functionalities\n   * - task\n     - | name of the task\n       | default: task-dependent\n   * - lr_sched\n     - | learning rate scheduler\n       | default: linear_warmup_cosine_lr\n   * - init_lr\n     - | initial learning rate (after warmup)\n       | default: task-dependent\n   * - min_lr\n     - | final learning rate after decay\n       | default: task-dependent\n   * - warmup_lr\n     - | starting learning rate for warmup\n       | default: init_lr (no warmup)\n   * - lr_decay_rate\n     - | learning rate decay per epoch for step_lr_schedule\n       | default: 0.9\n   * - warmup_steps\n     - | number of steps for learning rate warmup\n       | default: 0\n   * - max_epoch\n     - | total number of training epochs\n       | default: task-dependent\n   * - weight_decay\n     - | weight decay coefficient for the optimizer\n       | default: 0.05\n   * - batch_size_train\n     - | batch size during training\n       | default: task-dependent\n   * - batch_size_eval\n     - | batch size during evaluation\n       | default: task-dependent\n   * - seed\n     - | pseudo random number generator seed\n       | default: 42\n   * - output_dir\n     - | directory to store logs, results and checkpoints\n       | default: task-dependent\n   * - resume_ckpt_path\n     - | path of the checkpoint to resume training from\n       | default: None\n   * - evaluate\n     - | only perform evaluation without training\n       | default: False\n   * - train_splits\n     - | dataset splits used for training\n       | default: [\"train\"]\n   * - valid_splits\n     - | dataset splits used for validation\n       | default: [\"val\"]\n   * - test\n     - | dataset splits used for test\n       | default: [\"test\"]\n   * - device\n     - | use cpu or gpu (cuda)\n       | default: cuda\n   * - world_size\n     - | number of processes participating in the job\n       | default: 1\n   * - dist_url\n     - | URL specifying how to initialize the process group\n       | default: \"env://\"\n   * - distributed\n     - | use distributed training\n       | default: True\n   * - amp\n     - | use automatic mixed precision training\n       | default: False\n\n.. list-table::\n   :widths: 40 50\n   :header-rows: 1\n\n   * - Text Generation Configurations\n     - Functionalities\n   * - max_len\n     - | maximum number of text tokens to generate\n       | default: 20 (for image captioning)\n   * - min_len\n     - | minimum number of text tokens to generate\n       | default: 5 (for image captioning)\n   * - num_beams\n     - | number of beams to perform beam search\n       | default: 3\n\n.. list-table::\n   :widths: 40 50\n   :header-rows: 1\n\n   * - Multimodal Retrieval Configurations\n     - Functionalities\n   * - negative_all_rank\n     - | collect negatives from all processes for the image-text matching loss\n       | default: True (for coco)\n   * - k_test\n     - | number of retrieval candidates ranked from contrastive similarity\n       | default: 256 (for coco)\n"
  },
  {
    "path": "docs/tutorial.datasets.rst",
    "content": "Adding Datasets\n################################################\n\nThis is a tutorial on adding a new dataset using ``lavis.datasets`` module. \n\nThe LAVIS library includes a standard dataset module, which allows customization to add new datasets. \nThe ``lavis.datasets`` module is designed such that any new dataset class can be easily added and adapted from our code base, including creating dataset configuration, and defining and associating new dataset classes.\n\nIn this tutorial, we will replicate the steps to add a dataset class for the `Audio-Visual Scene-Aware Dialogue (AVSD) <https://arxiv.org/pdf/1901.09107.pdf>`_ benchmark for the video-grounded dialogue task.\n\nDataset Configuration ``lavis.configs.datasets``\n**************************************************************\n\nFirst, we define the basic configurations for this dataset, including a new dataset class ``avsd_dialogue``, dataset card, and data types. \nWe can define any new dataset configuration in ``lavis.configs.datasets``. For instance, under this module, we can set up a configuration file ``avsd/defaults_dial.yaml`` as follows:  \n\n.. code-block:: yaml\n\n    datasets:\n      avsd_dialogue: # name of the dataset builder\n        dataset_card: dataset_card/avsd_dialogue.md # path to the dataset card \n        data_type: features # [images|videos|features] we use features in this case for extracted video features \n\n        build_info:\n          # Be careful not to append minus sign (-) before split to avoid itemizing\n          annotations:\n            train:\n              url: /export/home/data/avsd/train_set4DSTC7-AVSD.json\n              storage: avsd/annotations/train.json\n            val:\n              url: /export/home/data/avsd/valid_set4DSTC7-AVSD.json\n              storage: avsd/annotations/val.json \n            test:\n              url: /export/home/data/avsd/test_set4DSTC7-AVSD.json\n              storage: avsd/annotations/test.json \n          features:\n            storage: /export/home/data/avsd/features/ \n\n\nDataset Card\n===============\nOne optional step to set up dataset configuration is defining a dataset card, which contains more details about the dataset such as description, tasks, and metrics. \nFor instance, we can define a dataset card for the AVSD benchmark in ``dataset_card/avsd_dialogue.md``.\nDepending on the dataset, we included in its corresponding dataset card the command for auto-downloading data (with python code defined in ``lavis.datasets.download_scripts``) that will automatically load the data and store it in a specific folder.\nElse, you should describe in the dataset card the external download instructions from the original data source to load the dataset properly. \n\nOne example of a dataset card for the AVSD benchmark is: \n\n.. code-block:: md\n\n    ![Samples from the AVSD dataset (Image credit: \"https://arxiv.org/pdf/1901.09107.pdf\").](imgs/avsd_dialogue.png)(Samples from the AVSD dataset. Image credit: \"https://arxiv.org/pdf/1901.09107.pdf\")\n    \n    # Audio-Visual Scene-Aware Dialogues (AVSD) \n    \n    ## Description\n    [Audio-Visual Scene-Aware Dialogues (AVSD)](https://github.com/hudaAlamri/DSTC7-Audio-Visual-Scene-Aware-Dialog-AVSD-Challenge) contains more than 10,000 dialogues, each of which is grounded on a unique video. In the test split, for each test sample, 6 reference dialogue responses are provided. 
\n    \n    \n    ## Task\n    \n    (https://github.com/hudaAlamri/DSTC7-Audio-Visual-Scene-Aware-Dialog-AVSD-Challenge)\n    \n    In a **video-grounded dialogue task**, the system must generate responses to user input in the context of a given dialog.\n    This context consists of a dialog history (previous utterances by both user and system) in addition to video and audio information that comprise the scene. The quality of a system’s automatically generated sentences is evaluated using objective measures to determine whether or not the generated responses are natural and informative.\n    \n    ## Metrics\n    Models are typically evaluated according to [BLEU](https://aclanthology.org/P02-1040/), [CIDER](https://www.cv-foundation.org/openaccess/content_cvpr_2015/papers/Vedantam_CIDEr_Consensus-Based_Image_2015_CVPR_paper.pdf), [METEOR](https://aclanthology.org/W05-0909/), and [ROUGE-L](https://aclanthology.org/W04-1013/) metrics. \n    \n    ## Leaderboard\n    \n    ....\n    \n    \n    ## Auto-Downloading\n    \n    Please refer to [benchmark website](https://github.com/hudaAlamri/DSTC7-Audio-Visual-Scene-Aware-Dialog-AVSD-Challenge) for instructions to download the dataset. \n    \n    \n    ## References\n    \"Audio Visual Scene-Aware Dialog\", Huda Alamri, Vincent Cartillier, Abhishek Das, Jue Wang, Anoop Cherian, Irfan Essa, Dhruv Batra, Tim K. Marks, Chiori Hori, Peter Anderson, Stefan Lee, Devi Parikh\n\nVisual Data Type\n==============================\nWe currently limit the visual data types to one of three options: ``images``, ``videos``, and ``features``. \n\"Images\" and \"videos\" refer to the raw visual data, which is appropriate for models processing visual data in their original forms (e.g. ViT models). \n\"Features\" are visual representations extracted from pretrained models (e.g. CNN models). \nIn this tutorial, the AVSD benchmark consists of video features extracted from 3D-CNN models. \n\nBuild Info\n==============================\nBuild info refers to the specific locations where data is stored and cached. \n\nFor text annotations (e.g. captioning or dialogues), by default, we include three data splits, namely \"train\", \"val\", and \"test\", typically used in all machine learning projects. \nFor each split, we specify 2 parameters: ``url`` and ``storage``.\n``url`` can be either an online URL where the dataset can be loaded automatically (e.g. from *googleapis*), or a local directory where data is already downloaded beforehand. \n``storage`` is the directory where the data will be cached over time, avoiding downloading data repeatedly.\n\nFor visual data annotations, ensure the field name matches the data types defined earlier (e.g. one of \"images\", \"videos\" or \"features\"). \nAs visual features are usually large and should be downloaded beforehand, we maintain only a ``storage`` parameter where visual data is cached. \n\nDataset ``lavis.datasets.datasets``\n**************************************************************\n\nBase Dataset ``lavis.datasets.datasets.base_dataset``\n=======================================================\nIn this step, we want to define new dataset classes that inherit our base dataset class ``lavis.datasets.datasets.base_dataset``. This base dataset class already defines standard methods such as ``collater`` which uses the default collator from PyTorch. \n\n.. 
code-block:: python\n\n    import json\n    from typing import Iterable\n    \n    from torch.utils.data import Dataset, ConcatDataset\n    from torch.utils.data.dataloader import default_collate\n        \n    class BaseDataset(Dataset):\n        def __init__(\n            self, vis_processor=None, text_processor=None, vis_root=None, ann_paths=[]\n        ):\n            \"\"\"\n            vis_root (string): Root directory of images (e.g. coco/images/)\n            ann_root (string): directory to store the annotation file\n            \"\"\"\n            self.vis_root = vis_root\n    \n            self.annotation = []\n            for ann_path in ann_paths:\n                self.annotation.extend(json.load(open(ann_path, \"r\")))\n    \n            self.vis_processor = vis_processor\n            self.text_processor = text_processor\n    \n            self._add_instance_ids()\n    \n        def __len__(self):\n            return len(self.annotation)\n    \n        def collater(self, samples):\n            return default_collate(samples)\n    \n        def set_processors(self, vis_processor, text_processor):\n            self.vis_processor = vis_processor\n            self.text_processor = text_processor\n    \n        def _add_instance_ids(self, key=\"instance_id\"):\n            for idx, ann in enumerate(self.annotation):\n                ann[key] = str(idx)\n\nAny dataset subclass will inherit these methods and it is optional to define and overwrite these methods accordingly to the specifications of the dataset. \nWe encourage users not to modify the base dataset class as any modification will have cascading impacts on any other dataset classes that inherit this base dataset. \nInstead, the users should independently create new dataset classes to cater to their specific requirements. \n\nDialogue Datasets ``lavis.datasets.datasets.dialogue_datasets``\n======================================================================\n\nFor example, for the AVSD dataset, we want to define a new dataset subclass ``DialogueDataset`` for dialogue tasks. We can define this dataset class in ``lavis.datasets.datasets.dialogue_datasets`` as following: \n\n.. code-block:: python\n\n    import os\n    from collections import OrderedDict\n        \n    from lavis.datasets.datasets.base_dataset import BaseDataset\n    \n    import json \n    import copy \n\n    class DialogueDataset(BaseDataset):\n        def __init__(self, vis_processor, text_processor, vis_root, ann_paths):\n            \"\"\"\n            vis_processor (string): visual processor \n            text_processor (string): textual processor \n            vis_root (string): Root directory of images (e.g. coco/images/)\n            ann_paths (string): Root directory of images (e.g. 
coco/images/)\n            \"\"\"\n                \n            self.vis_root = vis_root\n    \n            self.annotation = []\n            for ann_path in ann_paths:\n                dialogs = json.load(open(ann_path, \"r\"))['dialogs']\n                for dialog in dialogs: \n                    all_turns = dialog['dialog']\n                    dialogue_context = [] \n                    for turn in all_turns: \n                        dialog_instance = copy.deepcopy(dialog)\n                        question = turn['question']\n                        answer = turn['answer'] \n                        \n                        dialog_instance['dialog'] = copy.deepcopy(dialogue_context) \n                        dialog_instance['question'] = question\n                        dialog_instance['answer'] = answer \n                        self.annotation.append(dialog_instance)\n                        dialogue_context.append(turn)\n                        \n            self.vis_processor = vis_processor\n            self.text_processor = text_processor\n    \n            self._add_instance_ids()\n    \n            self.img_ids = {}\n            n = 0\n            for ann in self.annotation:\n                img_id = ann[\"image_id\"]\n                if img_id not in self.img_ids.keys():\n                    self.img_ids[img_id] = n\n                    n += 1\n\nClass inheritance allows us to define multiple subclasses. For instance, we want another dialogue dataset class that is defined only for the test split. We can define another dataset class ``DialogueEvalDataset`` as similarly defined above but the annotations are processed differently. \nTypically, in dialogue tasks, during test time, only a single test sample is constructed per dialogue (rather than decomposing all dialogue turns as samples during training time).\nThe dataset class can then be defined as: \n\n.. code-block:: python\n\n    class DialogueEvalDataset(BaseDataset):\n        def __init__(self, vis_processor, text_processor, vis_root, ann_paths):\n            # ...\n            # defined similarly as DialogueDataset above \n            # except for the loading of dialogue annotation data            \n    \n            self.annotation = []\n            for ann_path in ann_paths:\n                dialogs = json.load(open(ann_path, \"r\"))['dialogs']\n                for dialog in dialogs: \n                    all_turns = dialog['dialog']\n                    dialogue_context = all_turns[:-1]\n                    last_turn = all_turns[-1] \n                    \n                    question = last_turn['question']\n                    answer = last_turn['answer'] \n                        \n                    dialog['dialog'] = dialogue_context\n                    dialog['question'] = question\n                    dialog['answer'] = answer\n                                        \n                    self.annotation.append(dialog)\n\n\nUsing class inheritance to define datasets also allows us to develop more fine-grain class implementations, each of which is specifically designated for a benchmark. \nFor instance, under the dialogue-based tasks, we can further define another dataset subclass that is specified for the AVSD dataset. \nWe can define a new class ``AVSDDialDataset`` that further specifies how to load individual samples and collate them accordingly to specific requirements: \n\n.. 
code-block:: python\n\n    import os\n    from lavis.datasets.datasets.base_dataset import BaseDataset\n    from lavis.datasets.datasets.dialogue_datasets import DialogueDataset, DialogueEvalDataset\n    \n    import torch \n        \n    class AVSDDialDataset(DialogueDataset):\n        def __init__(self, vis_processor, text_processor, vis_root, ann_paths):\n\n            super().__init__(vis_processor, text_processor, vis_root, ann_paths)\n    \n        def __getitem__(self, index):\n    \n            ann = self.annotation[index]\n    \n            vname = ann[\"image_id\"]\n    \n            video = self.vis_processor(self.vis_root, vname)\n            \n            dialogue = self.text_processor(ann)\n            \n            return {\n                \"video_fts\": video['video_fts'],\n                \"video_token_type_ids\": video['token_type_ids'], \n                \"input_ids\": dialogue['input_ids'], \n                \"token_type_ids\": dialogue['token_type_ids'],\n                \"labels\": dialogue['labels'], \n                \"image_id\": ann[\"image_id\"],\n                \"instance_id\": ann[\"instance_id\"]\n            }\n        \n        def collater(self, samples):\n            \n            input_ids, token_type_ids, labels, video_fts, video_token_type_ids = [], [], [], [], []\n            \n            for i in samples:\n                input_ids.append(i['input_ids'])\n                token_type_ids.append(i['token_type_ids'])\n                labels.append(i['labels'])\n                video_fts.append(i['video_fts'])\n                video_token_type_ids.append(i['video_token_type_ids'])\n    \n            input_ids = self.text_processor.padding(input_ids)\n            \n            labels = self.text_processor.padding(labels, -1)\n            video_fts = self.vis_processor.padding(video_fts)\n            \n            token_type_ids = self.text_processor.padding(token_type_ids)\n            video_token_type_ids = self.text_processor.padding(video_token_type_ids)\n            token_type_ids = torch.cat([video_token_type_ids, token_type_ids], dim=1)\n            \n            attn_mask = self.text_processor.get_attention_mask(input_ids)\n            video_mask = self.vis_processor.get_attention_mask(video_fts)\n            attn_mask = torch.cat([video_mask, attn_mask], dim=1)\n            \n            video_labels = torch.ones((video_fts.size(0), video_fts.size(1))).long() * -1 # ignore token indice -1 by default \n\n            labels = torch.cat([video_labels, labels], dim=1)\n            \n            samples = {}\n            samples['input_ids'] = input_ids\n            samples['token_type_ids'] = token_type_ids\n            samples['labels'] = labels\n            samples['video_fts'] = video_fts\n            samples['attn_mask'] = attn_mask\n            \n            return samples  \n\nNote that in a dataset subclass, if methods such as ``__getitem__`` and ``collater`` are not defined, the same functions from the corresponding superclass will be used. \nFor instance, by default, we always use the collater from the ``BaseDataset`` class to collate data samples. \n\nDataset Builder ``lavis.datasets.builders``\n**************************************************************\nDataset Builder is the data processing module that controls the dataset classes (by training or evaluation split) and associates the specific dataset configurations to these dataset classes. 
\n\nBase Dataset Builder ``lavis.datasets.builders.base_dataset_builder``\n======================================================================\n\nNote that any new builder class definition should inherit the base dataset builder class ``BaseDatasetBuilder``:\n\n.. code-block:: python\n\n    class BaseDatasetBuilder:\n        train_dataset_cls, eval_dataset_cls = None, None\n        ...\n\nThis allows us to standardize the operations of dataset builders across all builder classes. We advise the users to carefully review the standard methods defined in the base builder class, including methods such as ``_download_data`` and ``build_datasets`` that will download the data and create instances of dataset classes: \n\n.. code-block:: python\n\n    class BaseDatasetBuilder:\n    ...\n\n        def build_datasets(self):\n            # download, split, etc...\n            # only called on 1 GPU/TPU in distributed\n    \n            if is_main_process():\n                self._download_data()\n    \n            if is_dist_avail_and_initialized():\n                dist.barrier()\n    \n            # at this point, all the annotations and image/videos should be all downloaded to the specified locations.\n            logging.info(\"Building datasets...\")\n            datasets = self.build()  # dataset['train'/'val'/'test']\n            \n            return datasets\n    \n        def _download_data(self):\n            self._download_ann()\n            self._download_vis()\n    \nWe encourage users not to modify the implementation of the base dataset builder class as this will affect all existing dataset builder subclasses.\n\nDialogue Dataset Builder ``lavis.datasets.builders.dialogue_builder``\n======================================================================\nWe can define any new builder subclass and associate this builder with the corresponding dataset classes and dataset configurations. \nFor instance, for the AVSD dataset, we can define a builder ``lavis.datasets.builders.dialogue_builder`` for dialogue-based datasets as follows: \n\n.. code-block:: python\n\n    from lavis.datasets.builders.base_dataset_builder import BaseDatasetBuilder\n    from lavis.datasets.datasets.avsd_dialogue_datasets import (\n        AVSDDialDataset, \n        AVSDDialEvalDataset \n    )\n    \n    from lavis.common.registry import registry\n    \n    \n    @registry.register_builder(\"avsd_dialogue\")\n    class AVSDDialBuilder(BaseDatasetBuilder):\n        train_dataset_cls = AVSDDialDataset \n        eval_dataset_cls = AVSDDialEvalDataset \n    \n        DATASET_CONFIG_DICT = {\n            \"default\": \"configs/datasets/avsd/defaults_dial.yaml\"\n        }\n\nNote that we chose to separately define the parameters ``train_dataset_cls`` and ``eval_dataset_cls`` to consider cases where data is processed differently between training and test time. \nFor instance, in captioning tasks, during test time, each data sample often includes multiple ground-truth captions rather than just a single ground-truth during training time. \nIf the data processing is the same in both training and test time, the two parameters can be linked to the same dataset class. \n\nFinally, define ``DATASET_CONFIG_DICT`` to associate the dataset configurations to the assigned dataset classes. 
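\nFor illustration, once the builder has been registered and included in ``lavis.datasets.builders.__init__`` (see the next section), it can be retrieved by name from the registry and used to build its datasets, mirroring how LAVIS tasks construct datasets in ``lavis.tasks.base_task``. This is only a usage sketch; ``dataset_config`` below stands for the parsed configuration node of this dataset: \n\n.. code-block:: python\n\n    from lavis.common.registry import registry\n    \n    # look up the builder class by its registered name and instantiate it with its configuration\n    builder_cls = registry.get_builder_class(\"avsd_dialogue\")\n    builder = builder_cls(dataset_config)  # dataset_config: OmegaConf node for this dataset\n    datasets = builder.build_datasets()    # dict keyed by split, e.g. datasets[\"train\"]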
\n\nRegistering Builder ``lavis.datasets.builders.__init__``\n======================================================================\n\nTo add a new builder class, first include the class in ``__init__.py``. For instance, to add the new builder for the AVSD dataset: \n\n.. code-block:: python\n\n    from lavis.datasets.builders.dialogue_builder import (\n        AVSDDialBuilder\n    )\n    \n    __all__ = [\n        ...,\n        \"AVSDDialBuilder\"\n    ]\n\nAssigning Builder \n======================================================================\nNote that during data loading and processing, the builder must be assigned by its registered name so that it can be loaded properly. \nFor instance, the following should be specified in a configuration file e.g. ``dialogue_avsd_ft.yaml``: \n\n.. code-block:: yaml\n\n    datasets:\n      avsd_dialogue: # name of the dataset builder\n        ...\n        # processor configuration \n        ...\n\nSubsequently, any processes (e.g. training) should load this configuration file to assign the correct builder, which will then associate the correct dataset classes to construct data samples. \n\n.. code-block:: sh\n\n    python train.py --cfg-path dialogue_avsd_ft.yaml\n"
  },
  {
    "path": "docs/tutorial.evaluation.rst",
    "content": "Evaluating Pre-trained Models on Task Datasets\n###############################################\nLAVIS provides pre-trained and finetuned models for off-the-shelf evaluation on task datasets. \nLet's now see an example of evaluating the BLIP model on the captioning task, using the MSCOCO dataset.\n\n.. _prep coco:\n\nPreparing Datasets\n******************\nFirst, let's download the dataset. LAVIS provides `automatic downloading scripts` to help prepare \nmost of the public datasets; to download the MSCOCO dataset, simply run\n\n.. code-block:: bash\n\n    cd lavis/datasets/download_scripts && python download_coco.py\n\nThis will put the downloaded dataset at a default cache location ``cache`` used by LAVIS.\n\nIf you want to use a different cache location, you can specify it by updating ``cache_root`` in ``lavis/configs/default.yaml``.\n\nIf you have a local copy of the dataset, it is recommended to create a symlink from the cache location to the local copy, e.g.\n\n.. code-block:: bash\n\n    ln -s /path/to/local/coco cache/coco\n\nEvaluating pre-trained models\n******************************\n\nTo evaluate a pre-trained model, simply run\n\n.. code-block:: bash\n\n    bash run_scripts/lavis/blip/eval/eval_coco_cap.sh\n\nOr to evaluate a large model:\n\n.. code-block:: bash\n\n    bash run_scripts/lavis/blip/eval/eval_coco_cap_large.sh"
  },
  {
    "path": "docs/tutorial.models.rst",
    "content": "Adding Models\n####################################\n\nThis is a tutorial on adding new models using ``lavis.models`` module.\n\nThe LAVIS library includes a standard model module that builds the foundation for many major language-vision models such as `ALBEF <https://arxiv.org/pdf/2107.07651.pdf>`_,\n`BLIP <https://arxiv.org/pdf/2201.12086.pdf>`_, `ALPRO <https://arxiv.org/pdf/2112.09583.pdf>`_, and `CLIP <https://arxiv.org/pdf/2103.00020.pdf>`_. \nThe ``lavis.models`` module is designed such that any new models can be added and integrated into the LAVIS library, with minimal steps to develop training and testing procedures. \nIn this tutorial, we will replicate the steps to add a GPT-style model specifically for `video-grounded dialogue tasks <https://arxiv.org/pdf/1901.09107.pdf>`_. \n\nBase Model ``lavis.models.base_model``\n**************************************************************\n\nNote that any new model definition should inherit the base model class ``BaseModel``:\n\n.. code-block:: python\n\n    from omegaconf import OmegaConf\n    \n    import numpy as np\n    \n    import torch\n    import torch.nn as nn\n    \n    from lavis.common.utils import get_abs_path\n    \n    class BaseModel(nn.Module):\n        \"\"\"Base class for models.\"\"\"\n    \n        def __init__(self):\n            super().__init__()\n    \n        def forward_features(self, *args, **kwargs):\n            \"\"\"Similar to *forward* but only return features.\"\"\"\n            raise NotImplementedError\n    \n        def load_from_pretrained(self, url_or_filename):\n            raise NotImplementedError\n    \n        @classmethod\n        def _from_config(cls, cfg=None, model_type=\"base\"):\n            if not cfg:\n                # useful when building model without a provided configuration file\n                cfg = OmegaConf.load(cls.default_config_path(model_type)).model\n    \n            return cls.from_config(cfg)\n    \n        @classmethod\n        def from_pretrained(cls, model_type=\"base\"):\n            \"\"\"\n            Build a pretrained model from the default configuration file, specified by model_type.\n            \"\"\"\n            return cls._from_config(cfg=None, model_type=model_type)\n    \n        @property\n        def device(self):\n            return list(self.parameters())[0].device\n    \n        @classmethod\n        def default_config_path(cls, model_type=\"base\"):\n            assert (\n                model_type in cls.PRETRAINED_MODEL_CONFIG_DICT\n            ), \"Unknown model type {}\".format(model_type)\n            return get_abs_path(cls.PRETRAINED_MODEL_CONFIG_DICT[model_type])\n    \n        def before_evaluation(self, **kwargs):\n            pass\n    \n        def show_n_params(self, return_str=True):\n            tot = 0\n            for p in self.parameters():\n                w = 1\n                for x in p.shape:\n                    w *= x\n                tot += w\n            if return_str:\n                if tot >= 1e6:\n                    return \"{:.1f}M\".format(tot / 1e6)\n                else:\n                    return \"{:.1f}K\".format(tot / 1e3)\n            else:\n                return tot\n\n\nIn this base model, we already declare and standardize many common methods such as ``_from_config`` and ``_from_pretrained``. \nInheriting this base model class allows us to standardize operations of models across all model classes while still allowing customizations. 
\nWe advise users not to change the implementation of the base model class as this will affect all existing model subclasses.\n\nGPT-style Video-grounded Dialogue Model ``lavis.models.gpt_models.gpt_dialogue``\n********************************************************************************\n\nIn this step, we can define a new model class, e.g. under ``lavis.models.gpt_models.gpt_dialogue``, for GPT-based dialogue models designed specifically for video-grounded dialogues. \nNote that we assume the model class inherits from the standard model super class ``GPT2LMHeadModel`` from the ``transformers`` `library <https://huggingface.co/docs/transformers/index>`_.\nWe also enforce model integration to the LAVIS framework through the inheritance of the ``BaseModel`` from the LAVIS library, as the secondary super class.\n\n.. code-block:: python\n\n    import torch\n    from lavis.common.registry import registry\n    from lavis.models.base_model import BaseModel\n    \n    from transformers import GPT2Model, GPT2LMHeadModel\n    from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions\n    import math\n    import torch\n    import torch.nn as nn\n    from torch.nn import CrossEntropyLoss, MSELoss\n        \n    @registry.register_model(\"gpt_dialogue\")\n    class GPTDialogue(GPT2LMHeadModel, BaseModel):\n        ...\n \nNext, we can modify the architecture of the model during model initialization to fit the tasks of interest, i.e. video-grounded dialogues. \nIn this case, we want to add additional model parameters for a linear network to transform the video feature representations to the model dimension. \n\n.. code-block:: python\n\n    class GPTDialogue(GPT2LMHeadModel, BaseModel):\n\n        def __init__(self, config, len_video_ft=4224):\n            \n            super().__init__(config)\n            \n            self.video_ff = nn.Linear(len_video_ft, config.n_embd)\n       \n            # Model parallel\n            self.model_parallel = False\n            self.device_map = None\n    \n            # Initialize weights and apply final processing\n            self.post_init()\n    \nNote that for each new model class, we advise redefining the ``from_config`` method which is inherited from the ``BaseModel`` class.\nAs each model usually has its own unique configurations, redefining the method will ensure the model instances are created properly. \nFor instance, ``GPTDialogue`` requires an additional parameter of video feature length (``len_video_ft``) which should be part of the model initialization procedure. \nAnother additional parameter is the number of tokens/words (as we include additional special tokens in the vocabulary for dialogue tasks). \n\n.. code-block:: python\n\n    class GPTDialogue(GPT2LMHeadModel, BaseModel):\n        ...\n        @classmethod\n        def from_config(cls, cfg):\n            model = cls.from_pretrained('gpt2', len_video_ft=cfg['len_video_ft']) \n            model.resize_token_embeddings(cfg['len_tokenizer'])\n            return model\n\nOther basic methods should also be defined explicitly in the new model class, including the ``forward`` function. \nFor instance, in GPT models for video-grounded dialogue tasks, we want the forward operation also includes the transformation and integration of video features before passing the representations to the Transformer layers. \n\n.. 
code-block:: python\n\n    class GPTDialogue(GPT2LMHeadModel, BaseModel):\n        ...\n\n        def forward(self, samples, \n                    past_key_values=None,\n                    position_ids=None,\n                    head_mask=None,\n                    encoder_hidden_states=None,\n                    encoder_attention_mask=None,\n                    use_cache=None,\n                    output_attentions=None,\n                    output_hidden_states=None,\n                    return_dict=None):        \n                \n                input_embs = self.transformer.wte(samples['input_ids'])\n                video_embs = self.video_ff(samples['video_fts'])\n                input_embs = torch.cat([video_embs, input_embs], dim=1)\n                        \n                transformer_outputs = self.transformer(\n                    attention_mask=samples['attn_mask'],\n                    token_type_ids=samples['token_type_ids'],\n                    inputs_embeds=input_embs,\n                    position_ids=position_ids,\n                    head_mask=head_mask,\n                    encoder_hidden_states=encoder_hidden_states,\n                    encoder_attention_mask=encoder_attention_mask,\n                    use_cache=use_cache,\n                    output_attentions=output_attentions,\n                    output_hidden_states=output_hidden_states,\n                    return_dict=return_dict,\n                )\n                hidden_states = transformer_outputs[0]\n            \n                lm_logits = self.lm_head(hidden_states)\n                ...\n\nRegistering New Model ``lavis.models.__init__``\n********************************************************************************\n\nAny new model must be officially registered as part of the ``lavis.models`` module. \nFor instance, to add a model class for GPT-based dialogue models, we can modify the ``__init__.py`` as follows:\n\n.. code-block:: python\n\n    from lavis.models.gpt_models.gpt_dialogue import GPTDialogue\n    \n    __all__ = [\n        ...\n        \"GPTDialogue\"\n    ]\n\nAssigning Model\n********************************************************************************\n\nFrom the above example of a model class, note that we define a ``from_config method`` for the new model class. \nThis method will process a configuration file and pass specific parameters to initialize the model classes properly. \nTo do this, we can assign/ associate the correct registry of model classes in a configuration file. \nFor instance, the following should be specified in a configuration file e.g. ``dialogue_avsd_ft.yaml``:\n\n.. code-block:: yaml\n\n    model:\n      arch: gpt_dialogue # name of the model \n      model_type: base\n\n\nSubsequently, any processes (e.g. training) should load this configuration file to assign the correct model.\n\n.. code-block:: sh\n\n    python train.py --cfg-path dialogue_avsd_ft.yaml\n\nNote that to simplify the model configuration, we only enable two main parameters here: ``arch`` and ``model_type``. ``arch`` refers to the model class registry, and ``model_type`` is the corresponding model type under this model family.\nFor instance, with ``gpt_dialogue``, we have a model ``base`` which has its own configuration in a separate configuration file e.g. ``gpt_dialogue_base.yaml``:\n\n.. 
code-block:: yaml\n\n    model:\n      arch: gpt_dialogue\n      len_tokenizer: 50264 # 50257 tokens from gpt2 default tokenizer + additional special tokens       \n      len_video_ft: 4224 # i3d_rgb: 2048 i3d_flow: 2048 vggish: 128 \n\nWe can then load this configuration and pass the parameters to the above ``from_config`` method to initialize the model accordingly. \nWe advise the users to maintain, in the model class definition, a dictionary that contains the default paths to model configurations. \nBy default, the LAVIS framework will search for the configuration of each model class in ``model.PRETRAINED_MODEL_CONFIG_DICT``.\n\n.. code-block:: python\n\n    class GPTDialogue(GPT2LMHeadModel, BaseModel):\n        PRETRAINED_MODEL_CONFIG_DICT = {\n                \"base\": \"configs/models/gpt_dialogue_base.yaml\"\n            }\n        ...\n"
  },
  {
    "path": "docs/tutorial.processors.rst",
    "content": "Adding Processors\n################################################\n\nThis is a tutorial on adding new processors using ``lavis.processors`` module. \n\nThe LAVIS library includes a standard processor module that preprocesses data e.g. image transformation and sequence concatenation.\nThe ``lavis.processors`` module is designed such that any processors can be added, specifically to the requirements of corresponding models of interest. \nIn this tutorial, we will replicate the steps to add visual and textual processors specifically for `video-grounded dialogue tasks <https://arxiv.org/pdf/1901.09107.pdf>`_. \nIn addition, we also want the processors to have processing features to make the data samples compatible with GPT-style models.\n\nBase Processor ``lavis.processors.base_processors``\n*****************************************************\n\nNote that any new processor definition should inherit the base processor class ``BaseProcessor``:\n\n.. code-block:: python\n\n    from omegaconf import OmegaConf\n    \n    class BaseProcessor:\n        def __init__(self):\n            self.transform = lambda x: x\n            return\n    \n        def __call__(self, item):\n            return self.transform(item)\n    \n        @classmethod\n        def from_config(cls, cfg=None):\n            return cls()\n    \n        def build(self, **kwargs):\n            cfg = OmegaConf.create(kwargs)\n    \n            return self.from_config(cfg)\n\nThis allows us to standardize operations of processors across all processor classes while still allowing customization of processors specifically to data and model types. \nWe encourage users not to modify the implementation of the base processor class as this will have an impact on all existing processor subclasses.\n\nGPT-style Processors ``lavis.processors.gpt_processors``\n**************************************************************\nIn this step, we can define new processor classes, e.g. under ``lavis.processors.gpt_processors``, for GPT models designed specifically for video-grounded dialogues. \nFirst, we want to process video features by defining ``GPTVideoFeatureProcessor`` class.\nIn this tutorial, we assume video features are extracted beforehand and this processor simply loads the features from ``npy`` files.\nOther methods that are specifically defined are ``padding`` (which is used by dataset instances to pad multiple video samples) and ``get_attention_mask`` (which creates an attention mask for Transformer attention in GPT models). \n\n.. 
code-block:: python \n\n    SPECIAL_TOKENS_DICT = {'bos_token': \"<bos>\", 'eos_token': \"<eos>\", 'additional_special_tokens': [\"<speaker1>\", \"<speaker2>\", \"<video>\", \"<cap>\"], 'pad_token': \"<pad>\"}\n    ...\n\n    @registry.register_processor(\"gpt_video_ft\")\n    class GPTVideoFeatureProcessor(BaseProcessor):\n        def __init__(self, visual_ft, audio_ft):\n\n            self.visual_ft = visual_ft\n            self.audio_ft = audio_ft\n\n            self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')\n            self.tokenizer.add_special_tokens(SPECIAL_TOKENS_DICT) \n                    \n        def padding(self, seq):\n            padded_seq = torch.nn.utils.rnn.pad_sequence(seq, batch_first=True, padding_value=1.0) \n            return padded_seq\n        \n        def get_attention_mask(self, seq):\n            return torch.sum(seq != 1, dim=2) != 0\n    \n        def __call__(self, ft_root, vname):\n            all_ft = []\n            \n            for ft_name in self.visual_ft:\n                ft_path = os.path.join(ft_root, ft_name, vname)\n                all_ft.append(np.load(ft_path + '.npy'))\n            \n            for ft_name in self.audio_ft: \n                ft_path = os.path.join(ft_root, ft_name, vname)\n                all_ft.append(np.load(ft_path + '.npy'))\n            \n            min_len = min([len(ft) for ft in all_ft])\n            \n            sampled_ft = [ft[:min_len] for ft in all_ft]\n            sampled_ft = np.concatenate(sampled_ft, axis=1)\n            item = {} \n            item['video_fts'] = torch.Tensor(sampled_ft) \n            \n            video_type_token = self.tokenizer.convert_tokens_to_ids('<video>')\n            item['token_type_ids'] = torch.Tensor([video_type_token] * len(sampled_ft)).long() \n            \n            return item \n    \n        @classmethod\n        def from_config(cls, cfg=None):\n            if cfg is None:\n                cfg = OmegaConf.create()\n            \n            visual_ft = cfg.get(\"visual_ft\", [\"i3d_rgb\"])\n            audio_ft = cfg.get(\"audio_ft\", [\"vggish\"])\n            \n            return cls(\n                visual_ft=visual_ft,\n                audio_ft=audio_ft\n            )\n\nAnother processor class that will be useful to have is to process dialogue data. Here we can define a ``GPTDialogueProcessor`` class.\nThis processor class receives raw annotations and constructs inputs as a concatenation of input sequences (questions, dialogue contexts, and responses) to facilitate application in GPT models. \nOther methods that are specifically defined are ``padding`` (which is used by dataset instances to pad multiple sequence samples) and ``get_attention_mask`` (which creates an attention mask for Transformer attention in GPT models). \n\n.. 
code-block:: python \n\n    SPECIAL_TOKENS_DICT = {'bos_token': \"<bos>\", 'eos_token': \"<eos>\", 'additional_special_tokens': [\"<speaker1>\", \"<speaker2>\", \"<video>\", \"<cap>\"], 'pad_token': \"<pad>\"}\n    ...\n\n    @registry.register_processor(\"gpt_dialogue\")\n    class GPTDialogueProcessor(BaseProcessor):\n        def __init__(self, max_turns=3, use_caption=True):\n            self.max_turns = max_turns \n            self.use_caption = use_caption \n            self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')\n            self.tokenizer.add_special_tokens(SPECIAL_TOKENS_DICT) \n            \n        def sample_sequence(self, caption, history, answer):\n            bos, eos, speaker1, speaker2, cap = self.tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS[:-2])\n            instance = {}\n            sequence = [caption] + history + [answer]\n            sequence = [s + [eos] for s in sequence] \n    \n            instance[\"input_ids\"] = list(chain(*sequence))\n            instance[\"token_type_ids\"] = [cap] * len(sequence[0]) + [speaker2 if i % 2 else speaker1 for i, s in enumerate(sequence[1:]) for _ in s]\n            instance[\"labels\"] = ([-1]*sum(len(s) for s in sequence[:-1])) + sequence[-1]\n            \n            assert len(instance[\"input_ids\"])==len(instance[\"token_type_ids\"])\n            assert len(instance[\"token_type_ids\"])==len(instance[\"labels\"])\n            \n            for k,v in instance.items():\n                instance[k] = torch.Tensor(v).long() \n            \n            return instance \n        \n        def padding(self, seq, pad_token=-1):\n            if pad_token==-1: pad_token = self.tokenizer.pad_token_id \n            padded_seq = torch.nn.utils.rnn.pad_sequence(seq, batch_first=True, padding_value=pad_token) \n            return padded_seq\n        \n        def get_attention_mask(self, seq, pad_token=-1):\n            if pad_token==-1: pad_token = self.tokenizer.pad_token_id \n            return seq != pad_token\n        \n        def __call__(self, ann):\n            if self.use_caption:\n                caption = ' '.join([ann['caption'], ann['summary']])\n                caption = self.tokenizer.encode(caption)\n            else:\n                caption = []\n                \n            dial_history = []\n            for turn in ann['dialog'][-self.max_turns:]:\n                dial_history.append(turn['question'])\n                dial_history.append(turn['answer'])\n            dial_history.append(ann['question'])\n            dial_history = [self.tokenizer.encode(t) for t in dial_history]\n            \n            answer = self.tokenizer.encode(ann['answer'])\n            \n            item = self.sample_sequence(caption, dial_history, answer)\n            \n            return item \n    \n        @classmethod\n        def from_config(cls, cfg=None):\n            if cfg is None:\n                cfg = OmegaConf.create()\n    \n            use_caption = cfg.get(\"use_caption\", True)\n            max_turns = cfg.get(\"max_turns\", 3)\n    \n            return cls(max_turns=max_turns, use_caption=use_caption)\n\nRegistering New Processors ``lavis.processors.__init__``\n**************************************************************\n\nFinally, any new processor must be officially registered as part of the ``lavis.processors`` module. 
\nFor instance, to add processor classes for GPT-based dialogue models, including one for dialogue data ``GPTDialogueProcessor`` and one for video features ``GPTVideoFeatureProcessor``, we can modify the ``__init__.py`` as follows: \n\n.. code-block:: python\n\n    from lavis.processors.gpt_processors import (\n        GPTVideoFeatureProcessor,\n        GPTDialogueProcessor,\n    )\n    \n    __all__ = [\n        ...\n        # GPT\n        \"GPTVideoFeatureProcessor\",\n        \"GPTDialogueProcessor\"\n    ]\n\nAssigning Processors \n**************************************************************\nFrom the above example of processor classes, note that we define a ``from_config`` method for each class. \nThis method will process a configuration file and pass specific parameters e.g. ``max_turns``, ``visual_ft``, to initialize the processor classes properly. \nTo do this, we can assign/ associate the correct registry of processor classes in a configuration file.\nFor instance, the following should be specified in a configuration file e.g. ``dialogue_avsd_ft.yaml``:\n\n.. code-block:: yaml \n\n    datasets:\n      avsd_dialogue: # name of the dataset builder\n        vis_processor:\n            train:\n              name: \"gpt_video_ft\" # name of the visual processor for training data\n              visual_ft: [\"i3d_flow\", \"i3d_rgb\"]  \n              audio_ft: [\"vggish\"]    \n            eval:\n              name: \"gpt_video_ft\" # name of the visual processor for evaluation data\n              visual_ft: [\"i3d_flow\", \"i3d_rgb\"]  \n              audio_ft: [\"vggish\"]   \n        text_processor:\n            train:\n              name: \"gpt_dialogue\" # name of the textual processor for training data\n              max_turns:  3\n              use_caption: True \n            eval:\n              name: \"gpt_dialogue\" # name of the textual processor for evaluation data\n              max_turns:  3\n              use_caption: True \n\nSubsequently, any processes (e.g. training) should load this configuration file to assign the correct processors.\n\n.. code-block:: sh\n\n    python train.py --cfg-path dialogue_avsd_ft.yaml\n"
  },
  {
    "path": "docs/tutorial.rst",
    "content": "Tutorials\n==============================\n\n.. toctree::\n   :maxdepth: 1\n\n   tutorial.evaluation\n   tutorial.training-example\n   tutorial.configs\n   tutorial.datasets\n   tutorial.processors\n   tutorial.models\n   tutorial.tasks\n"
  },
  {
    "path": "docs/tutorial.tasks.rst",
    "content": "Adding Tasks\n####################################\n\nThis is a tutorial on adding new machine learning tasks using ``lavis.tasks`` module.\n\nThe LAVIS library includes a standard task module that centralizes the model training and evaluation procedure of machine learning tasks. \nThe ``lavis.tasks`` module is designed such that any new tasks can be added and integrated, catering to any customization in the training and testing procedures. \nIn this tutorial, we will replicate the steps to add a new task into LAVIS for the `video-grounded dialogue tasks <https://arxiv.org/pdf/1901.09107.pdf>`_. \n\nBase Task ``lavis.tasks.base_task``\n********************************************************************************\n\nNote that any new model definition should inherit the base task class ``BaseTask``:\n\n.. code-block:: python\n\n    import logging\n    import os\n    \n    import torch.distributed as dist\n    from lavis.common.dist_utils import get_rank, get_world_size, is_main_process\n    from lavis.common.logger import MetricLogger, SmoothedValue\n    from lavis.common.registry import registry\n    from lavis.datasets.data_utils import prepare_sample\n    \n    class BaseTask:\n        def __init__(self, **kwargs):\n            super().__init__()\n    \n            self.inst_id_key = \"instance_id\"\n    \n        @classmethod\n        def setup_task(cls, **kwargs):\n            return cls()\n    \n        def build_model(self, cfg):\n            model_config = cfg.model_cfg\n    \n            model_cls = registry.get_model_class(model_config.arch)\n            return model_cls.from_config(model_config)\n    \n        def build_datasets(self, cfg):\n            \"\"\"\n            Build a dictionary of datasets, keyed by split 'train', 'valid', 'test'.\n            Download dataset and annotations automatically if not exist.\n    \n            Args:\n                cfg (common.config.Config): _description_\n    \n            Returns:\n                dict: Dictionary of torch.utils.data.Dataset objects by split.\n            \"\"\"\n    \n            datasets = dict()\n    \n            datasets_config = cfg.datasets_cfg\n    \n            assert len(datasets_config) > 0, \"At least one dataset has to be specified.\"\n    \n            for name in datasets_config:\n                dataset_config = datasets_config[name]\n    \n                builder = registry.get_builder_class(name)(dataset_config)\n                dataset = builder.build_datasets()\n    \n                datasets[name] = dataset\n    \n            return datasets\n    \n        def train_step(self, model, samples):\n            loss = model(samples)[\"loss\"]\n            return loss\n    \n        ...\n\nIn this base task, we already declare and standardize many common methods such as ``train_step``, ``build_model``, and ``build_datasets``. \nInheriting this base task class allows us to standardize operations of tasks across all task classes.\nWe recommend users not change the implementation of the base task class as this will have an impact on all existing task subclasses.\n\nDialogue Task ``lavis.tasks.dialogue``\n********************************************************************************\n\nIn this step, we can define a new task class, e.g. under ``lavis.tasks.dialogue``, for video-grounded dialogues.\nFor instance, we define a new task class ``DialogueTask`` that inherits the super task class ``BaseTask``.\n\n.. 
code-block:: python\n\n    import json\n    import os\n    \n    from lavis.common.dist_utils import main_process\n    from lavis.common.logger import MetricLogger\n    from lavis.common.registry import registry\n    from lavis.tasks.base_task import BaseTask\n    from lavis.datasets.data_utils import prepare_sample\n    \n    import numpy as np \n    \n    @registry.register_task(\"dialogue\")\n    class DialogueTask(BaseTask):\n        def __init__(self, num_beams, max_len, min_len, evaluate, report_metric=True):\n            super().__init__()\n    \n            self.num_beams = num_beams\n            self.max_len = max_len\n            self.min_len = min_len\n            self.evaluate = evaluate\n    \n            self.report_metric = report_metric\n    \n        @classmethod\n        def setup_task(cls, cfg):\n            run_cfg = cfg.run_cfg\n    \n            num_beams = run_cfg.num_beams\n            max_len = run_cfg.max_len\n            min_len = run_cfg.min_len\n            evaluate = run_cfg.evaluate\n    \n            report_metric = run_cfg.get(\"report_metric\", True)\n    \n            return cls(\n                num_beams=num_beams,\n                max_len=max_len,\n                min_len=min_len,\n                evaluate=evaluate,\n                report_metric=report_metric,\n            )\n    \n        def valid_step(self, model, samples):\n            results = []        \n            loss = model(samples)[\"loss\"].item() \n            \n            return [loss] \n        ...\n\nNote that for any new task, we advise the users to review carefully the functions implemented within ``BaseTask`` and consider which methods should be modified. \nFor instance, the base task class already contains a standard implementation of model training steps that are common among machine learning steps. \nSome major methods we want to emphasize and should be customized by each task are the ``valid_step`` and ``evaluation``. \nThese operations were not fully implemented in the base task class due to the differences in evaluation procedures among many machine learning tasks. \nAnother method that should be considered is the ``setup_task`` method. \nThis method will receive configurations that set task-specific parameters to initialize any task instance.\n\nRegistering New Task ``lavis.tasks.__init__`` \n********************************************************************************\n\nAny new task must be officially registered as part of the ``lavis.tasks`` module. For instance, to add a new task for video-grounded dialogues, we can modify the ``__init__.py`` as follows:\n\n.. code-block:: python\n\n    from lavis.tasks.dialogue import DialogueTask\n    \n    ...\n    __all__ = [\n        ...\n        \"DialogueTask\"\n    ]\n\nAssigning Task \n***************\n\nFrom the above example of task class, note that we define a ``setup_task`` method for each task class. \nThis method will process a configuration file and pass specific parameters e.g. ``num_beams`` (for beam search generative tasks during the inference stage), to initialize the task classes properly. \nTo assign and associate any task, we need to specify the correct registry of task classes in a configuration file. \nFor instance, the following should be specified in a configuration file e.g. ``dialogue_avsd_ft.yaml``:\n\n.. 
code-block:: yaml\n\n    run:\n      task: dialogue # name of the task \n      \n      # optimizer\n      ...\n    \n      max_len: 20\n      min_len: 5\n      num_beams: 3    \n      ...\n    \nSubsequently, any processes (e.g. training) should load this configuration file to assign the correct task.\n\n.. code-block:: sh\n\n    python train.py --cfg-path dialogue_avsd_ft.yaml"
  },
  {
    "path": "docs/tutorial.training-example.rst",
    "content": "Example on Finetuning BLIP on COCO-Captioning\n################################################\n\nTo finetune BLIP model on the coco caption dataset, first refer to :ref:`prep coco` to prepare the dataset if you have not done so.\n\nTo finetune the model, we have prepared a run script for you, which can run as follows:\n\n.. code-block:: bash\n\n    bash run_scripts/lavis/blip/train/train_caption_coco_large.sh\n\nThis will finetune the pre-trained BLIP large model into a new model that can be used for captioning.\n\nDeep Dive\n**********\nNow let's take a closer look at the script and see what it does.\n\n.. code-block:: bash\n\n    python -m torch.distributed.run --nproc_per_node=8 train.py --cfg-path lavis/projects/blip/train/caption_coco_large_ft.yaml\n\nAs can be seen, the script simply calls the :code:`train.py` with PyTorch distributed training enabled.\nThe :code:`--cfg-path` argument specifies the **runtime config** file to use. The config file is a YAML file that specifies the training parameters, shown as follows:\n\n.. literalinclude:: ../lavis/projects/blip/train/caption_coco_large_ft.yaml\n    :language: yaml\n    :linenos:\n\nThe runtime config file is divided into 3 sections:\n    - :code:`model`: specifies the model architecture and type to use.\n    - :code:`data`: specifies the dataset to use.\n    - :code:`run`: specifies the runner arguments, such as tasks, optimizer, learning rate scheduler, etc.\n\nWe describe each section in detail below.\n\nModel configurations\n=====================\n\n.. literalinclude:: ../lavis/projects/blip/train/caption_coco_large_ft.yaml\n    :language: yaml\n    :linenos:\n    :lines: 6-10\n\nThe :code:`arch` argument specifies the model architecture to use. In this case, we use the :code:`blip_caption` architecture.\nYou can find available architectures by inspecting the :code:`model_zoo`.\nOnce the architecture is specified, the runner will look for the model class registered with the name and try to instantiate a model instance.\nIn this case :code:`BlipCaption` is the model registered with the name :code:`blip_caption`.\n\nThe registry maintains a mapping from the name string to the model class.\nThis allows the runner to find the model class dynamically based on the name string from the config file. \nThe following segment in :code:`lavis/models/blip_models/blip_caption.py` shows how :code:`BlipCaption` is registered with the name string :code:`blip_caption`:\n\n.. literalinclude:: ../lavis/models/blip_models/blip_caption.py\n    :language: python\n    :linenos:\n    :lines: 20-38\n\nOne same model architecture may be pre-trained or finetuned on different datasets or have different model configurations.\nFor example, :code:`BlipCaption` have:\n\n    - :code:`base_coco`: pre-trained base BLIP model adapated for COCO captioning finetuning.\n\n    - :code:`large_coco`: pre-trained large BLIP model adapated for COCO captioning finetuning.\n\nTherefore, we also need to specify :code:`model_type`. 
Here we use :code:`large_coco`.\nAnd we set :code:`load_finetuned` to :code:`False` to indicate that we are finetuning the model from the pre-trained weights.\nIf :code:`load_finetuned` set to :code:`True` as by default, the model will load finetuned weights on coco captioning.\n\nGiven the model architecture and type, the library will then look for the default model config for :code:`large_coco` in :code:`lavis/models/blip_models/blip_caption.py`.\nAs can be seen in the above code snippet, the corresponding config path is stored in :code:`BlipCaption.PRETRAINED_MODEL_CONFIG_DICT`. \nThen the library will load :code:`lavis/configs/models/blip_caption_large_coco.yaml` as the configuration to build the model.\n\n*Priority of Configs*: Note that the priority of the run config is higher than the default model config, meaning that arguments in the run config will override the default model config.\nFor example, in the default model config, :code:`load_finetuned` is set to :code:`True` by default, while in the run config, we set it to :code:`False` and finetuning from the pre-trained weights only.\n\n\nDataset configurations\n=========================\n\nThe second section of the config file specifies the dataset(s) to use.\n\n.. literalinclude:: ../lavis/projects/blip/train/caption_coco_large_ft.yaml\n    :language: yaml\n    :linenos:\n    :lines: 12-24\n\nWe associate each dataset with a :code:`vis_processor` and a :code:`text_processor`, responsible for processing the visual and textual input respectively.\nHere we again use the registry mechanism to dynamically load the processor class based on the name string.\nFor example, :code:`blip_image_train` is the name string for the :code:`BlipImageTrainProcessor` class, which is registered in :code:`lavis/processors/blip_processors.py`.\n\nSimilarly, the dataset name string is also registered in the registry, pointing to a dataset builder :code:`COCOCapBuilder` class.\nBy default, the builder will load the default dataset configuration as in :code:`DATASET_CONFIG_DICT`. You may also add new dataset types by adding new entries to the dictionary.\n\nThe dataset configuration used here is:\n\n.. literalinclude:: ../lavis/configs/datasets/coco/defaults_cap.yaml\n    :language: yaml\n    :linenos:\n    :lines: 6-28\n\nIn this configuration file, we specify the dataset name and mainly its building information.\nThe build information is divided into two parts: :code:`annotation` and :code:`images`. The annotation files will be automatically downloaded upon loading the dataset for the first time.\nThe :code:`images` part specifies the image root directory. This is a relative path to the cache directory, which is :code:`cache` by default. If you have a local copy of the dataset, you can specify the path to the local copy by\noverwriting the :code:`images` part in the runtime config file. For example, you may alter the run config as below to use your local dataset copy:\n\n.. code:: yaml\n\n    datasets:\n        coco_caption: # name of the dataset builder\n            vis_processor:\n                train:\n                name: \"blip_image_train\"\n                eval:\n                name: \"blip_image_eval\"\n            text_processor:\n                train:\n                name: \"blip_caption\"\n                prompt: \"a picture of \"\n                eval:\n                name: \"blip_caption\"\n            images:\n                YOUR_LOCAL_IMAGE_ROOT_DIR\n\nLAVIS supports using multiple datasets for training. 
See an example in :code:`lavis/projects/blip/train/pretrain_14m.yaml`.\n\n\nRunner configurations\n=========================\nThe last section of the config file specifies the arguments for the runner, shown below:\n\n.. literalinclude:: ../lavis/projects/blip/train/caption_coco_large_ft.yaml\n    :language: yaml\n    :linenos:\n    :lines: 26-56\n\nHere we specify runner-related arguments, including\n    - task-specific arguments, such as :code:`task`, :code:`max_len`, :code:`min_len`, etc.\n    - learning rate schedulers, optimizer;\n    - distributed training settings;\n    - logging and checkpointing settings.\n\nAvailable Configurations\n#########################\n\nSee :ref:`config` for the full list of available configurations and their descriptions."
  },
  {
    "path": "evaluate.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport argparse\nimport random\n\nimport numpy as np\nimport torch\nimport torch.backends.cudnn as cudnn\n\nimport lavis.tasks as tasks\nfrom lavis.common.config import Config\nfrom lavis.common.dist_utils import get_rank, init_distributed_mode\nfrom lavis.common.logger import setup_logger\nfrom lavis.common.optims import (\n    LinearWarmupCosineLRScheduler,\n    LinearWarmupStepLRScheduler,\n)\nfrom lavis.common.utils import now\n\n# imports modules for registration\nfrom lavis.datasets.builders import *\nfrom lavis.models import *\nfrom lavis.processors import *\nfrom lavis.runners.runner_base import RunnerBase\nfrom lavis.tasks import *\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Training\")\n\n    parser.add_argument(\"--cfg-path\", required=True, help=\"path to configuration file.\")\n    parser.add_argument(\n        \"--options\",\n        nargs=\"+\",\n        help=\"override some settings in the used config, the key-value pair \"\n        \"in xxx=yyy format will be merged into config file (deprecate), \"\n        \"change to --cfg-options instead.\",\n    )\n\n    args = parser.parse_args()\n    # if 'LOCAL_RANK' not in os.environ:\n    #     os.environ['LOCAL_RANK'] = str(args.local_rank)\n\n    return args\n\n\ndef setup_seeds(config):\n    seed = config.run_cfg.seed + get_rank()\n\n    random.seed(seed)\n    np.random.seed(seed)\n    torch.manual_seed(seed)\n\n    cudnn.benchmark = False\n    cudnn.deterministic = True\n\n\ndef main():\n    # allow auto-dl completes on main process without timeout when using NCCL backend.\n    # os.environ[\"NCCL_BLOCKING_WAIT\"] = \"1\"\n\n    # set before init_distributed_mode() to ensure the same job_id shared across all ranks.\n    job_id = now()\n\n    cfg = Config(parse_args())\n\n    init_distributed_mode(cfg.run_cfg)\n\n    setup_seeds(cfg)\n\n    # set after init_distributed_mode() to only log on master.\n    setup_logger()\n\n    cfg.pretty_print()\n\n    task = tasks.setup_task(cfg)\n    datasets = task.build_datasets(cfg)\n    model = task.build_model(cfg)\n\n    runner = RunnerBase(\n        cfg=cfg, job_id=job_id, task=task, model=model, datasets=datasets\n    )\n    runner.evaluate(skip_reload=True)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "lavis/__init__.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport os\nimport sys\n\nfrom omegaconf import OmegaConf\n\nfrom lavis.common.registry import registry\n\nfrom lavis.datasets.builders import *\nfrom lavis.models import *\nfrom lavis.processors import *\nfrom lavis.tasks import *\n\n\nroot_dir = os.path.dirname(os.path.abspath(__file__))\ndefault_cfg = OmegaConf.load(os.path.join(root_dir, \"configs/default.yaml\"))\n\nregistry.register_path(\"library_root\", root_dir)\nrepo_root = os.path.join(root_dir, \"..\")\nregistry.register_path(\"repo_root\", repo_root)\ncache_root = os.path.join(repo_root, default_cfg.env.cache_root)\nregistry.register_path(\"cache_root\", cache_root)\n\nregistry.register(\"MAX_INT\", sys.maxsize)\nregistry.register(\"SPLIT_NAMES\", [\"train\", \"val\", \"test\"])\n"
  },
  {
    "path": "lavis/common/config.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport logging\nimport json\nfrom typing import Dict\n\nfrom omegaconf import OmegaConf\nfrom lavis.common.registry import registry\n\n\nclass Config:\n    def __init__(self, args):\n        self.config = {}\n\n        self.args = args\n\n        # Register the config and configuration for setup\n        registry.register(\"configuration\", self)\n\n        user_config = self._build_opt_list(self.args.options)\n\n        config = OmegaConf.load(self.args.cfg_path)\n\n        runner_config = self.build_runner_config(config)\n        model_config = self.build_model_config(config, **user_config)\n        dataset_config = self.build_dataset_config(config)\n\n        # Validate the user-provided runner configuration\n        # model and dataset configuration are supposed to be validated by the respective classes\n        # [TODO] validate the model/dataset configuration\n        # self._validate_runner_config(runner_config)\n\n        # Override the default configuration with user options.\n        self.config = OmegaConf.merge(\n            runner_config, model_config, dataset_config, user_config\n        )\n\n    def _validate_runner_config(self, runner_config):\n        \"\"\"\n        This method validates the configuration, such that\n            1) all the user specified options are valid;\n            2) no type mismatches between the user specified options and the config.\n        \"\"\"\n        runner_config_validator = create_runner_config_validator()\n        runner_config_validator.validate(runner_config)\n\n    def _build_opt_list(self, opts):\n        opts_dot_list = self._convert_to_dot_list(opts)\n        return OmegaConf.from_dotlist(opts_dot_list)\n\n    @staticmethod\n    def build_model_config(config, **kwargs):\n        model = config.get(\"model\", None)\n        assert model is not None, \"Missing model configuration file.\"\n\n        model_cls = registry.get_model_class(model.arch)\n        assert model_cls is not None, f\"Model '{model.arch}' has not been registered.\"\n\n        model_type = kwargs.get(\"model.model_type\", None)\n        if not model_type:\n            model_type = model.get(\"model_type\", None)\n        # else use the model type selected by user.\n\n        assert model_type is not None, \"Missing model_type.\"\n\n        model_config_path = model_cls.default_config_path(model_type=model_type)\n\n        model_config = OmegaConf.create()\n        # hiararchy override, customized config > default config\n        model_config = OmegaConf.merge(\n            model_config,\n            OmegaConf.load(model_config_path),\n            {\"model\": config[\"model\"]},\n        )\n\n        return model_config\n\n    @staticmethod\n    def build_runner_config(config):\n        return {\"run\": config.run}\n\n    @staticmethod\n    def build_dataset_config(config):\n        datasets = config.get(\"datasets\", None)\n        if datasets is None:\n            raise KeyError(\n                \"Expecting 'datasets' as the root key for dataset configuration.\"\n            )\n\n        dataset_config = OmegaConf.create()\n\n        for dataset_name in datasets:\n            builder_cls = registry.get_builder_class(dataset_name)\n\n            dataset_config_type = datasets[dataset_name].get(\"type\", \"default\")\n 
           dataset_config_path = builder_cls.default_config_path(\n                type=dataset_config_type\n            )\n\n            # hiararchy override, customized config > default config\n            dataset_config = OmegaConf.merge(\n                dataset_config,\n                OmegaConf.load(dataset_config_path),\n                {\"datasets\": {dataset_name: config[\"datasets\"][dataset_name]}},\n            )\n\n        return dataset_config\n\n    def _convert_to_dot_list(self, opts):\n        if opts is None:\n            opts = []\n\n        if len(opts) == 0:\n            return opts\n\n        has_equal = opts[0].find(\"=\") != -1\n\n        if has_equal:\n            return opts\n\n        return [(opt + \"=\" + value) for opt, value in zip(opts[0::2], opts[1::2])]\n\n    def get_config(self):\n        return self.config\n\n    @property\n    def run_cfg(self):\n        return self.config.run\n\n    @property\n    def datasets_cfg(self):\n        return self.config.datasets\n\n    @property\n    def model_cfg(self):\n        return self.config.model\n\n    def pretty_print(self):\n        logging.info(\"\\n=====  Running Parameters    =====\")\n        logging.info(self._convert_node_to_json(self.config.run))\n\n        logging.info(\"\\n======  Dataset Attributes  ======\")\n        datasets = self.config.datasets\n\n        for dataset in datasets:\n            if dataset in self.config.datasets:\n                logging.info(f\"\\n======== {dataset} =======\")\n                dataset_config = self.config.datasets[dataset]\n                logging.info(self._convert_node_to_json(dataset_config))\n            else:\n                logging.warning(f\"No dataset named '{dataset}' in config. Skipping\")\n\n        logging.info(f\"\\n======  Model Attributes  ======\")\n        logging.info(self._convert_node_to_json(self.config.model))\n\n    def _convert_node_to_json(self, node):\n        container = OmegaConf.to_container(node, resolve=True)\n        return json.dumps(container, indent=4, sort_keys=True)\n\n    def to_dict(self):\n        return OmegaConf.to_container(self.config)\n\n\ndef node_to_dict(node):\n    return OmegaConf.to_container(node)\n\n\nclass ConfigValidator:\n    \"\"\"\n    This is a preliminary implementation to centralize and validate the configuration.\n    May be altered in the future.\n\n    A helper class to validate configurations from yaml file.\n\n    This serves the following purposes:\n        1. Ensure all the options in the yaml are defined, raise error if not.\n        2. when type mismatches are found, the validator will raise an error.\n        3. 
a central place to store and display helpful messages for supported configurations.\n\n    \"\"\"\n\n    class _Argument:\n        def __init__(self, name, choices=None, type=None, help=None):\n            self.name = name\n            self.val = None\n            self.choices = choices\n            self.type = type\n            self.help = help\n\n        def __str__(self):\n            s = f\"{self.name}={self.val}\"\n            if self.type is not None:\n                s += f\", ({self.type})\"\n            if self.choices is not None:\n                s += f\", choices: {self.choices}\"\n            if self.help is not None:\n                s += f\", ({self.help})\"\n            return s\n\n    def __init__(self, description):\n        self.description = description\n\n        self.arguments = dict()\n\n        self.parsed_args = None\n\n    def __getitem__(self, key):\n        assert self.parsed_args is not None, \"No arguments parsed yet.\"\n\n        return self.parsed_args[key]\n\n    def __str__(self) -> str:\n        return self.format_help()\n\n    def add_argument(self, *args, **kwargs):\n        \"\"\"\n        Assume the first argument is the name of the argument.\n        \"\"\"\n        self.arguments[args[0]] = self._Argument(*args, **kwargs)\n\n    def validate(self, config=None):\n        \"\"\"\n        Convert yaml config (dict-like) to list, required by argparse.\n        \"\"\"\n        for k, v in config.items():\n            assert (\n                k in self.arguments\n            ), f\"\"\"{k} is not a valid argument. Support arguments are {self.format_arguments()}.\"\"\"\n\n            if self.arguments[k].type is not None:\n                try:\n                    self.arguments[k].val = self.arguments[k].type(v)\n                except ValueError:\n                    raise ValueError(f\"{k} is not a valid {self.arguments[k].type}.\")\n\n            if self.arguments[k].choices is not None:\n                assert (\n                    v in self.arguments[k].choices\n                ), f\"\"\"{k} must be one of {self.arguments[k].choices}.\"\"\"\n\n        return config\n\n    def format_arguments(self):\n        return str([f\"{k}\" for k in sorted(self.arguments.keys())])\n\n    def format_help(self):\n        # description + key-value pair string for each argument\n        help_msg = str(self.description)\n        return help_msg + \", available arguments: \" + self.format_arguments()\n\n    def print_help(self):\n        # display help message\n        print(self.format_help())\n\n\ndef create_runner_config_validator():\n    validator = ConfigValidator(description=\"Runner configurations\")\n\n    validator.add_argument(\n        \"runner\",\n        type=str,\n        choices=[\"runner_base\", \"runner_iter\"],\n        help=\"\"\"Runner to use. The \"runner_base\" uses epoch-based training while iter-based\n            runner runs based on iters. Default: runner_base\"\"\",\n    )\n    # add argumetns for training dataset ratios\n    validator.add_argument(\n        \"train_dataset_ratios\",\n        type=Dict[str, float],\n        help=\"\"\"Ratios of training dataset. 
This is used in iteration-based runner.\n        Not supported by the epoch-based runner because defining an epoch becomes tricky.\n        Default: None\"\"\",\n    )\n    validator.add_argument(\n        \"max_iters\",\n        type=float,\n        help=\"Maximum number of iterations to run.\",\n    )\n    validator.add_argument(\n        \"max_epoch\",\n        type=int,\n        help=\"Maximum number of epochs to run.\",\n    )\n    # add arguments for iters_per_inner_epoch\n    validator.add_argument(\n        \"iters_per_inner_epoch\",\n        type=float,\n        help=\"Number of iterations per inner epoch. This is required when runner is runner_iter.\",\n    )\n    lr_scheds_choices = registry.list_lr_schedulers()\n    validator.add_argument(\n        \"lr_sched\",\n        type=str,\n        choices=lr_scheds_choices,\n        help=\"Learning rate scheduler to use, from {}\".format(lr_scheds_choices),\n    )\n    task_choices = registry.list_tasks()\n    validator.add_argument(\n        \"task\",\n        type=str,\n        choices=task_choices,\n        help=\"Task to use, from {}\".format(task_choices),\n    )\n    # add arguments for init_lr\n    validator.add_argument(\n        \"init_lr\",\n        type=float,\n        help=\"Initial learning rate. This will be the learning rate after warmup and before decay.\",\n    )\n    # add arguments for min_lr\n    validator.add_argument(\n        \"min_lr\",\n        type=float,\n        help=\"Minimum learning rate (after decay).\",\n    )\n    # add arguments for warmup_lr\n    validator.add_argument(\n        \"warmup_lr\",\n        type=float,\n        help=\"Starting learning rate for warmup.\",\n    )\n    # add arguments for learning rate decay rate\n    validator.add_argument(\n        \"lr_decay_rate\",\n        type=float,\n        help=\"Learning rate decay rate. Required if using a decaying learning rate scheduler.\",\n    )\n    # add arguments for weight decay\n    validator.add_argument(\n        \"weight_decay\",\n        type=float,\n        help=\"Weight decay rate.\",\n    )\n    # add arguments for training batch size\n    validator.add_argument(\n        \"batch_size_train\",\n        type=int,\n        help=\"Training batch size.\",\n    )\n    # add arguments for evaluation batch size\n    validator.add_argument(\n        \"batch_size_eval\",\n        type=int,\n        help=\"Evaluation batch size, including validation and testing.\",\n    )\n    # add arguments for number of workers for data loading\n    validator.add_argument(\n        \"num_workers\",\n        help=\"Number of workers for data loading.\",\n    )\n    # add arguments for warm up steps\n    validator.add_argument(\n        \"warmup_steps\",\n        type=int,\n        help=\"Number of warmup steps. Required if a warmup schedule is used.\",\n    )\n    # add arguments for random seed\n    validator.add_argument(\n        \"seed\",\n        type=int,\n        help=\"Random seed.\",\n    )\n    # add arguments for output directory\n    validator.add_argument(\n        \"output_dir\",\n        type=str,\n        help=\"Output directory to save checkpoints and logs.\",\n    )\n    # add arguments for whether to only run evaluation\n    validator.add_argument(\n        \"evaluate\",\n        help=\"Whether to only evaluate the model. If true, training will not be performed.\",\n    )\n    # add arguments for splits used for training, e.g. 
[\"train\", \"val\"]\n    validator.add_argument(\n        \"train_splits\",\n        type=list,\n        help=\"Splits to use for training.\",\n    )\n    # add arguments for splits used for validation, e.g. [\"val\"]\n    validator.add_argument(\n        \"valid_splits\",\n        type=list,\n        help=\"Splits to use for validation. If not provided, will skip the validation.\",\n    )\n    # add arguments for splits used for testing, e.g. [\"test\"]\n    validator.add_argument(\n        \"test_splits\",\n        type=list,\n        help=\"Splits to use for testing. If not provided, will skip the testing.\",\n    )\n    # add arguments for accumulating gradient for iterations\n    validator.add_argument(\n        \"accum_grad_iters\",\n        type=int,\n        help=\"Number of iterations to accumulate gradient for.\",\n    )\n\n    # ====== distributed training ======\n    validator.add_argument(\n        \"device\",\n        type=str,\n        choices=[\"cpu\", \"cuda\"],\n        help=\"Device to use. Support 'cuda' or 'cpu' as for now.\",\n    )\n    validator.add_argument(\n        \"world_size\",\n        type=int,\n        help=\"Number of processes participating in the job.\",\n    )\n    validator.add_argument(\"dist_url\", type=str)\n    validator.add_argument(\"distributed\", type=bool)\n    # add arguments to opt using distributed sampler during evaluation or not\n    validator.add_argument(\n        \"use_dist_eval_sampler\",\n        type=bool,\n        help=\"Whether to use distributed sampler during evaluation or not.\",\n    )\n\n    # ====== task specific ======\n    # generation task specific arguments\n    # add arguments for maximal length of text output\n    validator.add_argument(\n        \"max_len\",\n        type=int,\n        help=\"Maximal length of text output.\",\n    )\n    # add arguments for minimal length of text output\n    validator.add_argument(\n        \"min_len\",\n        type=int,\n        help=\"Minimal length of text output.\",\n    )\n    # add arguments number of beams\n    validator.add_argument(\n        \"num_beams\",\n        type=int,\n        help=\"Number of beams used for beam search.\",\n    )\n\n    # vqa task specific arguments\n    # add arguments for number of answer candidates\n    validator.add_argument(\n        \"num_ans_candidates\",\n        type=int,\n        help=\"\"\"For ALBEF and BLIP, these models first rank answers according to likelihood to select answer candidates.\"\"\",\n    )\n    # add arguments for inference method\n    validator.add_argument(\n        \"inference_method\",\n        type=str,\n        choices=[\"genearte\", \"rank\"],\n        help=\"\"\"Inference method to use for question answering. If rank, requires a answer list.\"\"\",\n    )\n\n    # ====== model specific ======\n    validator.add_argument(\n        \"k_test\",\n        type=int,\n        help=\"Number of top k most similar samples from ITC/VTC selection to be tested.\",\n    )\n\n    return validator\n"
  },
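For orientation, here is a minimal usage sketch of the validator defined above (not part of the repository). It assumes this module is importable as `lavis.common.config`, as in upstream LAVIS, and the run options shown are made-up values rather than a real project config.

```python
from lavis.common.config import create_runner_config_validator

# Illustrative run options only; real projects load these from a YAML file.
run_options = {
    "runner": "runner_base",   # must be one of the registered choices
    "max_epoch": 10,
    "init_lr": 1e-4,
    "min_lr": 1e-6,
    "batch_size_train": 8,
    "output_dir": "output/demo",
}

validator = create_runner_config_validator()
validator.validate(run_options)  # raises if a key is unknown or a choice is invalid
print(validator.format_help())   # description plus the sorted list of supported keys
```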
  {
    "path": "lavis/common/dist_utils.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport datetime\nimport functools\nimport os\n\nimport torch\nimport torch.distributed as dist\nimport timm.models.hub as timm_hub\n\n\ndef setup_for_distributed(is_master):\n    \"\"\"\n    This function disables printing when not in master process\n    \"\"\"\n    import builtins as __builtin__\n\n    builtin_print = __builtin__.print\n\n    def print(*args, **kwargs):\n        force = kwargs.pop(\"force\", False)\n        if is_master or force:\n            builtin_print(*args, **kwargs)\n\n    __builtin__.print = print\n\n\ndef is_dist_avail_and_initialized():\n    if not dist.is_available():\n        return False\n    if not dist.is_initialized():\n        return False\n    return True\n\n\ndef get_world_size():\n    if not is_dist_avail_and_initialized():\n        return 1\n    return dist.get_world_size()\n\n\ndef get_rank():\n    if not is_dist_avail_and_initialized():\n        return 0\n    return dist.get_rank()\n\n\ndef is_main_process():\n    return get_rank() == 0\n\n\ndef init_distributed_mode(args):\n    if \"RANK\" in os.environ and \"WORLD_SIZE\" in os.environ:\n        args.rank = int(os.environ[\"RANK\"])\n        args.world_size = int(os.environ[\"WORLD_SIZE\"])\n        args.gpu = int(os.environ[\"LOCAL_RANK\"])\n    elif \"SLURM_PROCID\" in os.environ:\n        args.rank = int(os.environ[\"SLURM_PROCID\"])\n        args.gpu = args.rank % torch.cuda.device_count()\n    else:\n        print(\"Not using distributed mode\")\n        args.distributed = False\n        return\n\n    args.distributed = True\n\n    torch.cuda.set_device(args.gpu)\n    args.dist_backend = \"nccl\"\n    print(\n        \"| distributed init (rank {}, world {}): {}\".format(\n            args.rank, args.world_size, args.dist_url\n        ),\n        flush=True,\n    )\n    torch.distributed.init_process_group(\n        backend=args.dist_backend,\n        init_method=args.dist_url,\n        world_size=args.world_size,\n        rank=args.rank,\n        timeout=datetime.timedelta(\n            days=365\n        ),  # allow auto-downloading and de-compressing\n    )\n    torch.distributed.barrier()\n    setup_for_distributed(args.rank == 0)\n\n\ndef get_dist_info():\n    if torch.__version__ < \"1.0\":\n        initialized = dist._initialized\n    else:\n        initialized = dist.is_initialized()\n    if initialized:\n        rank = dist.get_rank()\n        world_size = dist.get_world_size()\n    else:  # non-distributed training\n        rank = 0\n        world_size = 1\n    return rank, world_size\n\n\ndef main_process(func):\n    @functools.wraps(func)\n    def wrapper(*args, **kwargs):\n        rank, _ = get_dist_info()\n        if rank == 0:\n            return func(*args, **kwargs)\n\n    return wrapper\n\n\ndef download_cached_file(url, check_hash=True, progress=False):\n    \"\"\"\n    Download a file from a URL and cache it locally. 
If the file already exists, it is not downloaded again.\n    If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.\n    \"\"\"\n\n    def get_cached_file_path():\n        # a hack to sync the file path across processes\n        parts = torch.hub.urlparse(url)\n        filename = os.path.basename(parts.path)\n        cached_file = os.path.join(timm_hub.get_cache_dir(), filename)\n\n        return cached_file\n\n    if is_main_process():\n        timm_hub.download_cached_file(url, check_hash, progress)\n\n    if is_dist_avail_and_initialized():\n        dist.barrier()\n\n    return get_cached_file_path()\n"
  },
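A small illustrative sketch (not part of the repository) of how these helpers behave in a single-process run where `torch.distributed` is never initialized; the checkpoint function and path are hypothetical.

```python
from lavis.common.dist_utils import get_rank, get_world_size, main_process

@main_process
def save_checkpoint(path):
    # Body runs only on rank 0; in a non-distributed run get_dist_info() is (0, 1).
    print(f"saving checkpoint to {path}")

print(get_rank(), get_world_size())     # 0 1 when torch.distributed is not initialized
save_checkpoint("/tmp/demo_ckpt.pth")   # executes, since this process counts as rank 0
```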
  {
    "path": "lavis/common/gradcam.py",
    "content": "import numpy as np\nfrom matplotlib import pyplot as plt\nfrom scipy.ndimage import filters\nfrom skimage import transform as skimage_transform\n\n\ndef getAttMap(img, attMap, blur=True, overlap=True):\n    attMap -= attMap.min()\n    if attMap.max() > 0:\n        attMap /= attMap.max()\n    attMap = skimage_transform.resize(attMap, (img.shape[:2]), order=3, mode=\"constant\")\n    if blur:\n        attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2]))\n        attMap -= attMap.min()\n        attMap /= attMap.max()\n    cmap = plt.get_cmap(\"jet\")\n    attMapV = cmap(attMap)\n    attMapV = np.delete(attMapV, 3, 2)\n    if overlap:\n        attMap = (\n            1 * (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img\n            + (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV\n        )\n    return attMap\n"
  },
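An illustrative call to `getAttMap` with random arrays (not part of the repository); the shapes are arbitrary, the only assumption being an HxWx3 image scaled to [0, 1] and a 2-D attention map, which the function resizes to the image resolution.

```python
import numpy as np
from lavis.common.gradcam import getAttMap

img = np.random.rand(224, 224, 3)  # RGB image in [0, 1]
att = np.random.rand(24, 24)       # coarse attention map, e.g. one value per patch
overlay = getAttMap(img, att, blur=True, overlap=True)
print(overlay.shape)               # (224, 224, 3): jet-colored map blended with img
```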
  {
    "path": "lavis/common/logger.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport datetime\nimport logging\nimport time\nfrom collections import defaultdict, deque\n\nimport torch\nimport torch.distributed as dist\n\nfrom lavis.common import dist_utils\n\n\nclass SmoothedValue(object):\n    \"\"\"Track a series of values and provide access to smoothed values over a\n    window or the global series average.\n    \"\"\"\n\n    def __init__(self, window_size=20, fmt=None):\n        if fmt is None:\n            fmt = \"{median:.4f} ({global_avg:.4f})\"\n        self.deque = deque(maxlen=window_size)\n        self.total = 0.0\n        self.count = 0\n        self.fmt = fmt\n\n    def update(self, value, n=1):\n        self.deque.append(value)\n        self.count += n\n        self.total += value * n\n\n    def synchronize_between_processes(self):\n        \"\"\"\n        Warning: does not synchronize the deque!\n        \"\"\"\n        if not dist_utils.is_dist_avail_and_initialized():\n            return\n        t = torch.tensor([self.count, self.total], dtype=torch.float64, device=\"cuda\")\n        dist.barrier()\n        dist.all_reduce(t)\n        t = t.tolist()\n        self.count = int(t[0])\n        self.total = t[1]\n\n    @property\n    def median(self):\n        d = torch.tensor(list(self.deque))\n        return d.median().item()\n\n    @property\n    def avg(self):\n        d = torch.tensor(list(self.deque), dtype=torch.float32)\n        return d.mean().item()\n\n    @property\n    def global_avg(self):\n        return self.total / self.count\n\n    @property\n    def max(self):\n        return max(self.deque)\n\n    @property\n    def value(self):\n        return self.deque[-1]\n\n    def __str__(self):\n        return self.fmt.format(\n            median=self.median,\n            avg=self.avg,\n            global_avg=self.global_avg,\n            max=self.max,\n            value=self.value,\n        )\n\n\nclass MetricLogger(object):\n    def __init__(self, delimiter=\"\\t\"):\n        self.meters = defaultdict(SmoothedValue)\n        self.delimiter = delimiter\n\n    def update(self, **kwargs):\n        for k, v in kwargs.items():\n            if isinstance(v, torch.Tensor):\n                v = v.item()\n            assert isinstance(v, (float, int))\n            self.meters[k].update(v)\n\n    def __getattr__(self, attr):\n        if attr in self.meters:\n            return self.meters[attr]\n        if attr in self.__dict__:\n            return self.__dict__[attr]\n        raise AttributeError(\n            \"'{}' object has no attribute '{}'\".format(type(self).__name__, attr)\n        )\n\n    def __str__(self):\n        loss_str = []\n        for name, meter in self.meters.items():\n            loss_str.append(\"{}: {}\".format(name, str(meter)))\n        return self.delimiter.join(loss_str)\n\n    def global_avg(self):\n        loss_str = []\n        for name, meter in self.meters.items():\n            loss_str.append(\"{}: {:.4f}\".format(name, meter.global_avg))\n        return self.delimiter.join(loss_str)\n\n    def synchronize_between_processes(self):\n        for meter in self.meters.values():\n            meter.synchronize_between_processes()\n\n    def add_meter(self, name, meter):\n        self.meters[name] = meter\n\n    def log_every(self, iterable, print_freq, header=None):\n        i = 
0\n        if not header:\n            header = \"\"\n        start_time = time.time()\n        end = time.time()\n        iter_time = SmoothedValue(fmt=\"{avg:.4f}\")\n        data_time = SmoothedValue(fmt=\"{avg:.4f}\")\n        space_fmt = \":\" + str(len(str(len(iterable)))) + \"d\"\n        log_msg = [\n            header,\n            \"[{0\" + space_fmt + \"}/{1}]\",\n            \"eta: {eta}\",\n            \"{meters}\",\n            \"time: {time}\",\n            \"data: {data}\",\n        ]\n        if torch.cuda.is_available():\n            log_msg.append(\"max mem: {memory:.0f}\")\n        log_msg = self.delimiter.join(log_msg)\n        MB = 1024.0 * 1024.0\n        for obj in iterable:\n            data_time.update(time.time() - end)\n            yield obj\n            iter_time.update(time.time() - end)\n            if i % print_freq == 0 or i == len(iterable) - 1:\n                eta_seconds = iter_time.global_avg * (len(iterable) - i)\n                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))\n                if torch.cuda.is_available():\n                    print(\n                        log_msg.format(\n                            i,\n                            len(iterable),\n                            eta=eta_string,\n                            meters=str(self),\n                            time=str(iter_time),\n                            data=str(data_time),\n                            memory=torch.cuda.max_memory_allocated() / MB,\n                        )\n                    )\n                else:\n                    print(\n                        log_msg.format(\n                            i,\n                            len(iterable),\n                            eta=eta_string,\n                            meters=str(self),\n                            time=str(iter_time),\n                            data=str(data_time),\n                        )\n                    )\n            i += 1\n            end = time.time()\n        total_time = time.time() - start_time\n        total_time_str = str(datetime.timedelta(seconds=int(total_time)))\n        print(\n            \"{} Total time: {} ({:.4f} s / it)\".format(\n                header, total_time_str, total_time / len(iterable)\n            )\n        )\n\n\nclass AttrDict(dict):\n    def __init__(self, *args, **kwargs):\n        super(AttrDict, self).__init__(*args, **kwargs)\n        self.__dict__ = self\n\n\ndef setup_logger():\n    logging.basicConfig(\n        level=logging.INFO if dist_utils.is_main_process() else logging.WARN,\n        format=\"%(asctime)s [%(levelname)s] %(message)s\",\n        handlers=[logging.StreamHandler()],\n    )\n"
  },
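A minimal training-loop sketch (not part of the repository) showing how `MetricLogger` and `SmoothedValue` are typically combined; the loss and learning-rate values below are placeholders.

```python
from lavis.common.logger import MetricLogger, SmoothedValue, setup_logger

setup_logger()
metric_logger = MetricLogger(delimiter="  ")
metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}"))

for step in metric_logger.log_every(range(100), print_freq=20, header="Train:"):
    loss = 1.0 / (step + 1)                   # placeholder loss value
    metric_logger.update(loss=loss, lr=1e-4)  # floats/ints (or scalar tensors) only

print("Averaged stats:", metric_logger.global_avg())
```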
  {
    "path": "lavis/common/optims.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport math\n\nfrom lavis.common.registry import registry\n\n\n@registry.register_lr_scheduler(\"linear_warmup_step_lr\")\nclass LinearWarmupStepLRScheduler:\n    def __init__(\n        self,\n        optimizer,\n        max_epoch,\n        min_lr,\n        init_lr,\n        decay_rate=1,\n        warmup_start_lr=-1,\n        warmup_steps=0,\n        **kwargs\n    ):\n        self.optimizer = optimizer\n\n        self.max_epoch = max_epoch\n        self.min_lr = min_lr\n\n        self.decay_rate = decay_rate\n\n        self.init_lr = init_lr\n        self.warmup_steps = warmup_steps\n        self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr\n\n    def step(self, cur_epoch, cur_step):\n        if cur_epoch == 0:\n            warmup_lr_schedule(\n                step=cur_step,\n                optimizer=self.optimizer,\n                max_step=self.warmup_steps,\n                init_lr=self.warmup_start_lr,\n                max_lr=self.init_lr,\n            )\n        else:\n            step_lr_schedule(\n                epoch=cur_epoch,\n                optimizer=self.optimizer,\n                init_lr=self.init_lr,\n                min_lr=self.min_lr,\n                decay_rate=self.decay_rate,\n            )\n\n\n@registry.register_lr_scheduler(\"linear_warmup_cosine_lr\")\nclass LinearWarmupCosineLRScheduler:\n    def __init__(\n        self,\n        optimizer,\n        max_epoch,\n        min_lr,\n        init_lr,\n        warmup_steps=0,\n        warmup_start_lr=-1,\n        **kwargs\n    ):\n        self.optimizer = optimizer\n\n        self.max_epoch = max_epoch\n        self.min_lr = min_lr\n\n        self.init_lr = init_lr\n        self.warmup_steps = warmup_steps\n        self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr\n\n    def step(self, cur_epoch, cur_step):\n        # assuming the warmup iters less than one epoch\n        if cur_epoch == 0:\n            warmup_lr_schedule(\n                step=cur_step,\n                optimizer=self.optimizer,\n                max_step=self.warmup_steps,\n                init_lr=self.warmup_start_lr,\n                max_lr=self.init_lr,\n            )\n        else:\n            cosine_lr_schedule(\n                epoch=cur_epoch,\n                optimizer=self.optimizer,\n                max_epoch=self.max_epoch,\n                init_lr=self.init_lr,\n                min_lr=self.min_lr,\n            )\n\n\ndef cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):\n    \"\"\"Decay the learning rate\"\"\"\n    lr = (init_lr - min_lr) * 0.5 * (\n        1.0 + math.cos(math.pi * epoch / max_epoch)\n    ) + min_lr\n    for param_group in optimizer.param_groups:\n        param_group[\"lr\"] = lr\n\n\ndef warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):\n    \"\"\"Warmup the learning rate\"\"\"\n    lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max(max_step, 1))\n    for param_group in optimizer.param_groups:\n        param_group[\"lr\"] = lr\n\n\ndef step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):\n    \"\"\"Decay the learning rate\"\"\"\n    lr = max(min_lr, init_lr * (decay_rate**epoch))\n    for param_group in optimizer.param_groups:\n        param_group[\"lr\"] = lr\n"
  },
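A usage sketch (not part of the repository) of the warmup-plus-cosine scheduler with a toy optimizer; all hyperparameters below are illustrative. The schedulers are stepped manually with `(cur_epoch, cur_step)`: epoch 0 follows the linear warmup, later epochs follow the cosine decay toward `min_lr`.

```python
import torch
from lavis.common.optims import LinearWarmupCosineLRScheduler

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

sched = LinearWarmupCosineLRScheduler(
    optimizer, max_epoch=10, min_lr=1e-6, init_lr=1e-4,
    warmup_steps=100, warmup_start_lr=1e-8,
)

sched.step(cur_epoch=0, cur_step=50)   # epoch 0: linear warmup, halfway to init_lr
print(optimizer.param_groups[0]["lr"])
sched.step(cur_epoch=5, cur_step=0)    # later epochs: cosine decay toward min_lr
print(optimizer.param_groups[0]["lr"])
```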
  {
    "path": "lavis/common/registry.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\n\nclass Registry:\n    mapping = {\n        \"builder_name_mapping\": {},\n        \"task_name_mapping\": {},\n        \"processor_name_mapping\": {},\n        \"model_name_mapping\": {},\n        \"lr_scheduler_name_mapping\": {},\n        \"runner_name_mapping\": {},\n        \"state\": {},\n        \"paths\": {},\n    }\n\n    @classmethod\n    def register_builder(cls, name):\n        r\"\"\"Register a dataset builder to registry with key 'name'\n\n        Args:\n            name: Key with which the builder will be registered.\n\n        Usage:\n\n            from lavis.common.registry import registry\n            from lavis.datasets.base_dataset_builder import BaseDatasetBuilder\n        \"\"\"\n\n        def wrap(builder_cls):\n            from lavis.datasets.builders.base_dataset_builder import BaseDatasetBuilder\n\n            assert issubclass(\n                builder_cls, BaseDatasetBuilder\n            ), \"All builders must inherit BaseDatasetBuilder class, found {}\".format(\n                builder_cls\n            )\n            if name in cls.mapping[\"builder_name_mapping\"]:\n                raise KeyError(\n                    \"Name '{}' already registered for {}.\".format(\n                        name, cls.mapping[\"builder_name_mapping\"][name]\n                    )\n                )\n            cls.mapping[\"builder_name_mapping\"][name] = builder_cls\n            return builder_cls\n\n        return wrap\n\n    @classmethod\n    def register_task(cls, name):\n        r\"\"\"Register a task to registry with key 'name'\n\n        Args:\n            name: Key with which the task will be registered.\n\n        Usage:\n\n            from lavis.common.registry import registry\n        \"\"\"\n\n        def wrap(task_cls):\n            from lavis.tasks.base_task import BaseTask\n\n            assert issubclass(\n                task_cls, BaseTask\n            ), \"All tasks must inherit BaseTask class\"\n            if name in cls.mapping[\"task_name_mapping\"]:\n                raise KeyError(\n                    \"Name '{}' already registered for {}.\".format(\n                        name, cls.mapping[\"task_name_mapping\"][name]\n                    )\n                )\n            cls.mapping[\"task_name_mapping\"][name] = task_cls\n            return task_cls\n\n        return wrap\n\n    @classmethod\n    def register_model(cls, name):\n        r\"\"\"Register a task to registry with key 'name'\n\n        Args:\n            name: Key with which the task will be registered.\n\n        Usage:\n\n            from lavis.common.registry import registry\n        \"\"\"\n\n        def wrap(model_cls):\n            from lavis.models import BaseModel\n\n            assert issubclass(\n                model_cls, BaseModel\n            ), \"All models must inherit BaseModel class\"\n            if name in cls.mapping[\"model_name_mapping\"]:\n                raise KeyError(\n                    \"Name '{}' already registered for {}.\".format(\n                        name, cls.mapping[\"model_name_mapping\"][name]\n                    )\n                )\n            cls.mapping[\"model_name_mapping\"][name] = model_cls\n            return model_cls\n\n        return wrap\n\n    @classmethod\n    def 
register_processor(cls, name):\n        r\"\"\"Register a processor to registry with key 'name'\n\n        Args:\n            name: Key with which the task will be registered.\n\n        Usage:\n\n            from lavis.common.registry import registry\n        \"\"\"\n\n        def wrap(processor_cls):\n            from lavis.processors import BaseProcessor\n\n            assert issubclass(\n                processor_cls, BaseProcessor\n            ), \"All processors must inherit BaseProcessor class\"\n            if name in cls.mapping[\"processor_name_mapping\"]:\n                raise KeyError(\n                    \"Name '{}' already registered for {}.\".format(\n                        name, cls.mapping[\"processor_name_mapping\"][name]\n                    )\n                )\n            cls.mapping[\"processor_name_mapping\"][name] = processor_cls\n            return processor_cls\n\n        return wrap\n\n    @classmethod\n    def register_lr_scheduler(cls, name):\n        r\"\"\"Register a model to registry with key 'name'\n\n        Args:\n            name: Key with which the task will be registered.\n\n        Usage:\n\n            from lavis.common.registry import registry\n        \"\"\"\n\n        def wrap(lr_sched_cls):\n            if name in cls.mapping[\"lr_scheduler_name_mapping\"]:\n                raise KeyError(\n                    \"Name '{}' already registered for {}.\".format(\n                        name, cls.mapping[\"lr_scheduler_name_mapping\"][name]\n                    )\n                )\n            cls.mapping[\"lr_scheduler_name_mapping\"][name] = lr_sched_cls\n            return lr_sched_cls\n\n        return wrap\n\n    @classmethod\n    def register_runner(cls, name):\n        r\"\"\"Register a model to registry with key 'name'\n\n        Args:\n            name: Key with which the task will be registered.\n\n        Usage:\n\n            from lavis.common.registry import registry\n        \"\"\"\n\n        def wrap(runner_cls):\n            if name in cls.mapping[\"runner_name_mapping\"]:\n                raise KeyError(\n                    \"Name '{}' already registered for {}.\".format(\n                        name, cls.mapping[\"runner_name_mapping\"][name]\n                    )\n                )\n            cls.mapping[\"runner_name_mapping\"][name] = runner_cls\n            return runner_cls\n\n        return wrap\n\n    @classmethod\n    def register_path(cls, name, path):\n        r\"\"\"Register a path to registry with key 'name'\n\n        Args:\n            name: Key with which the path will be registered.\n\n        Usage:\n\n            from lavis.common.registry import registry\n        \"\"\"\n        assert isinstance(path, str), \"All path must be str.\"\n        if name in cls.mapping[\"paths\"]:\n            raise KeyError(\"Name '{}' already registered.\".format(name))\n        cls.mapping[\"paths\"][name] = path\n\n    @classmethod\n    def register(cls, name, obj):\n        r\"\"\"Register an item to registry with key 'name'\n\n        Args:\n            name: Key with which the item will be registered.\n\n        Usage::\n\n            from lavis.common.registry import registry\n\n            registry.register(\"config\", {})\n        \"\"\"\n        path = name.split(\".\")\n        current = cls.mapping[\"state\"]\n\n        for part in path[:-1]:\n            if part not in current:\n                current[part] = {}\n            current = current[part]\n\n        current[path[-1]] = obj\n\n    # @classmethod\n    
# def get_trainer_class(cls, name):\n    #     return cls.mapping[\"trainer_name_mapping\"].get(name, None)\n\n    @classmethod\n    def get_builder_class(cls, name):\n        return cls.mapping[\"builder_name_mapping\"].get(name, None)\n\n    @classmethod\n    def get_model_class(cls, name):\n        return cls.mapping[\"model_name_mapping\"].get(name, None)\n\n    @classmethod\n    def get_task_class(cls, name):\n        return cls.mapping[\"task_name_mapping\"].get(name, None)\n\n    @classmethod\n    def get_processor_class(cls, name):\n        return cls.mapping[\"processor_name_mapping\"].get(name, None)\n\n    @classmethod\n    def get_lr_scheduler_class(cls, name):\n        return cls.mapping[\"lr_scheduler_name_mapping\"].get(name, None)\n\n    @classmethod\n    def get_runner_class(cls, name):\n        return cls.mapping[\"runner_name_mapping\"].get(name, None)\n\n    @classmethod\n    def list_runners(cls):\n        return sorted(cls.mapping[\"runner_name_mapping\"].keys())\n\n    @classmethod\n    def list_models(cls):\n        return sorted(cls.mapping[\"model_name_mapping\"].keys())\n\n    @classmethod\n    def list_tasks(cls):\n        return sorted(cls.mapping[\"task_name_mapping\"].keys())\n\n    @classmethod\n    def list_processors(cls):\n        return sorted(cls.mapping[\"processor_name_mapping\"].keys())\n\n    @classmethod\n    def list_lr_schedulers(cls):\n        return sorted(cls.mapping[\"lr_scheduler_name_mapping\"].keys())\n\n    @classmethod\n    def list_datasets(cls):\n        return sorted(cls.mapping[\"builder_name_mapping\"].keys())\n\n    @classmethod\n    def get_path(cls, name):\n        return cls.mapping[\"paths\"].get(name, None)\n\n    @classmethod\n    def get(cls, name, default=None, no_warning=False):\n        r\"\"\"Get an item from registry with key 'name'\n\n        Args:\n            name (string): Key whose value needs to be retrieved.\n            default: If passed and key is not in registry, default value will\n                     be returned with a warning. Default: None\n            no_warning (bool): If passed as True, warning when key doesn't exist\n                               will not be generated. Useful for MMF's\n                               internal operations. Default: False\n        \"\"\"\n        original_name = name\n        name = name.split(\".\")\n        value = cls.mapping[\"state\"]\n        for subname in name:\n            value = value.get(subname, default)\n            if value is default:\n                break\n\n        if (\n            \"writer\" in cls.mapping[\"state\"]\n            and value == default\n            and no_warning is False\n        ):\n            cls.mapping[\"state\"][\"writer\"].warning(\n                \"Key {} is not present in registry, returning default value \"\n                \"of {}\".format(original_name, default)\n            )\n        return value\n\n    @classmethod\n    def unregister(cls, name):\n        r\"\"\"Remove an item from registry with key 'name'\n\n        Args:\n            name: Key which needs to be removed.\n        Usage::\n\n            from mmf.common.registry import registry\n\n            config = registry.unregister(\"config\")\n        \"\"\"\n        return cls.mapping[\"state\"].pop(name, None)\n\n\nregistry = Registry()\n"
  },
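Typical registry usage, mirroring the decorator pattern described in the docstrings above; the scheduler name and state key below are hypothetical examples, not names used by the repository.

```python
from lavis.common.registry import registry

@registry.register_lr_scheduler("constant_lr_demo")  # hypothetical name
class ConstantLRDemo:
    def __init__(self, optimizer, **kwargs):
        self.optimizer = optimizer

    def step(self, cur_epoch, cur_step):
        pass  # leave the optimizer's learning rate untouched

print("constant_lr_demo" in registry.list_lr_schedulers())                     # True
print(registry.get_lr_scheduler_class("constant_lr_demo") is ConstantLRDemo)   # True

registry.register("demo.seed", 42)  # arbitrary values live under mapping["state"]
print(registry.get("demo.seed"))    # 42
```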
  {
    "path": "lavis/common/utils.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport io\nimport json\nimport logging\nimport os\nimport pickle\nimport re\nimport shutil\nimport urllib\nimport urllib.error\nimport urllib.request\nfrom typing import Optional\nfrom urllib.parse import urlparse\n\nimport numpy as np\nimport pandas as pd\nimport yaml\nfrom iopath.common.download import download\nfrom iopath.common.file_io import file_lock, g_pathmgr\nfrom lavis.common.registry import registry\nfrom torch.utils.model_zoo import tqdm\nfrom torchvision.datasets.utils import (\n    check_integrity,\n    download_file_from_google_drive,\n    extract_archive,\n)\n\n\ndef now():\n    from datetime import datetime\n\n    return datetime.now().strftime(\"%Y%m%d%H%M\")[:-1]\n\n\ndef is_url(url_or_filename):\n    parsed = urlparse(url_or_filename)\n    return parsed.scheme in (\"http\", \"https\")\n\n\ndef get_cache_path(rel_path):\n    return os.path.expanduser(os.path.join(registry.get_path(\"cache_root\"), rel_path))\n\n\ndef get_abs_path(rel_path):\n    return os.path.join(registry.get_path(\"library_root\"), rel_path)\n\n\ndef load_json(filename):\n    with open(filename, \"r\") as f:\n        return json.load(f)\n\n\n# The following are adapted from torchvision and vissl\n# torchvision: https://github.com/pytorch/vision\n# vissl: https://github.com/facebookresearch/vissl/blob/main/vissl/utils/download.py\n\n\ndef makedir(dir_path):\n    \"\"\"\n    Create the directory if it does not exist.\n    \"\"\"\n    is_success = False\n    try:\n        if not g_pathmgr.exists(dir_path):\n            g_pathmgr.mkdirs(dir_path)\n        is_success = True\n    except BaseException:\n        print(f\"Error creating directory: {dir_path}\")\n    return is_success\n\n\ndef get_redirected_url(url: str):\n    \"\"\"\n    Given a URL, returns the URL it redirects to or the\n    original URL in case of no indirection\n    \"\"\"\n    import requests\n\n    with requests.Session() as session:\n        with session.get(url, stream=True, allow_redirects=True) as response:\n            if response.history:\n                return response.url\n            else:\n                return url\n\n\ndef to_google_drive_download_url(view_url: str) -> str:\n    \"\"\"\n    Utility function to transform a view URL of google drive\n    to a download URL for google drive\n    Example input:\n        https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp/view\n    Example output:\n        https://drive.google.com/uc?export=download&id=137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp\n    \"\"\"\n    splits = view_url.split(\"/\")\n    assert splits[-1] == \"view\"\n    file_id = splits[-2]\n    return f\"https://drive.google.com/uc?export=download&id={file_id}\"\n\n\ndef download_google_drive_url(url: str, output_path: str, output_file_name: str):\n    \"\"\"\n    Download a file from google drive\n    Downloading an URL from google drive requires confirmation when\n    the file of the size is too big (google drive notifies that\n    anti-viral checks cannot be performed on such files)\n    \"\"\"\n    import requests\n\n    with requests.Session() as session:\n\n        # First get the confirmation token and append it to the URL\n        with session.get(url, stream=True, allow_redirects=True) as response:\n            for k, v in response.cookies.items():\n         
       if k.startswith(\"download_warning\"):\n                    url = url + \"&confirm=\" + v\n\n        # Then download the content of the file\n        with session.get(url, stream=True, verify=True) as response:\n            makedir(output_path)\n            path = os.path.join(output_path, output_file_name)\n            total_size = int(response.headers.get(\"Content-length\", 0))\n            with open(path, \"wb\") as file:\n                from tqdm import tqdm\n\n                with tqdm(total=total_size) as progress_bar:\n                    for block in response.iter_content(\n                        chunk_size=io.DEFAULT_BUFFER_SIZE\n                    ):\n                        file.write(block)\n                        progress_bar.update(len(block))\n\n\ndef _get_google_drive_file_id(url: str) -> Optional[str]:\n    parts = urlparse(url)\n\n    if re.match(r\"(drive|docs)[.]google[.]com\", parts.netloc) is None:\n        return None\n\n    match = re.match(r\"/file/d/(?P<id>[^/]*)\", parts.path)\n    if match is None:\n        return None\n\n    return match.group(\"id\")\n\n\ndef _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None:\n    with open(filename, \"wb\") as fh:\n        with urllib.request.urlopen(\n            urllib.request.Request(url, headers={\"User-Agent\": \"vissl\"})\n        ) as response:\n            with tqdm(total=response.length) as pbar:\n                for chunk in iter(lambda: response.read(chunk_size), \"\"):\n                    if not chunk:\n                        break\n                    pbar.update(chunk_size)\n                    fh.write(chunk)\n\n\ndef download_url(\n    url: str,\n    root: str,\n    filename: Optional[str] = None,\n    md5: Optional[str] = None,\n) -> None:\n    \"\"\"Download a file from a url and place it in root.\n    Args:\n        url (str): URL to download file from\n        root (str): Directory to place downloaded file in\n        filename (str, optional): Name to save the file under.\n                                  If None, use the basename of the URL.\n        md5 (str, optional): MD5 checksum of the download. If None, do not check\n    \"\"\"\n    root = os.path.expanduser(root)\n    if not filename:\n        filename = os.path.basename(url)\n    fpath = os.path.join(root, filename)\n\n    makedir(root)\n\n    # check if file is already present locally\n    if check_integrity(fpath, md5):\n        print(\"Using downloaded and verified file: \" + fpath)\n        return\n\n    # expand redirect chain if needed\n    url = get_redirected_url(url)\n\n    # check if file is located on Google Drive\n    file_id = _get_google_drive_file_id(url)\n    if file_id is not None:\n        return download_file_from_google_drive(file_id, root, filename, md5)\n\n    # download the file\n    try:\n        print(\"Downloading \" + url + \" to \" + fpath)\n        _urlretrieve(url, fpath)\n    except (urllib.error.URLError, IOError) as e:  # type: ignore[attr-defined]\n        if url[:5] == \"https\":\n            url = url.replace(\"https:\", \"http:\")\n            print(\n                \"Failed download. 
Trying https -> http instead.\"\n                \" Downloading \" + url + \" to \" + fpath\n            )\n            _urlretrieve(url, fpath)\n        else:\n            raise e\n\n    # check integrity of downloaded file\n    if not check_integrity(fpath, md5):\n        raise RuntimeError(\"File not found or corrupted.\")\n\n\ndef download_and_extract_archive(\n    url: str,\n    download_root: str,\n    extract_root: Optional[str] = None,\n    filename: Optional[str] = None,\n    md5: Optional[str] = None,\n    remove_finished: bool = False,\n) -> None:\n    download_root = os.path.expanduser(download_root)\n    if extract_root is None:\n        extract_root = download_root\n    if not filename:\n        filename = os.path.basename(url)\n\n    download_url(url, download_root, filename, md5)\n\n    archive = os.path.join(download_root, filename)\n    print(\"Extracting {} to {}\".format(archive, extract_root))\n    extract_archive(archive, extract_root, remove_finished)\n\n\ndef cache_url(url: str, cache_dir: str) -> str:\n    \"\"\"\n    This implementation downloads the remote resource and caches it locally.\n    The resource will only be downloaded if not previously requested.\n    \"\"\"\n    parsed_url = urlparse(url)\n    dirname = os.path.join(cache_dir, os.path.dirname(parsed_url.path.lstrip(\"/\")))\n    makedir(dirname)\n    filename = url.split(\"/\")[-1]\n    cached = os.path.join(dirname, filename)\n    with file_lock(cached):\n        if not os.path.isfile(cached):\n            logging.info(f\"Downloading {url} to {cached} ...\")\n            cached = download(url, dirname, filename=filename)\n    logging.info(f\"URL {url} cached in {cached}\")\n    return cached\n\n\n# TODO (prigoyal): convert this into RAII-style API\ndef create_file_symlink(file1, file2):\n    \"\"\"\n    Simply create the symlinks for a given file1 to file2.\n    Useful during model checkpointing to symlinks to the\n    latest successful checkpoint.\n    \"\"\"\n    try:\n        if g_pathmgr.exists(file2):\n            g_pathmgr.rm(file2)\n        g_pathmgr.symlink(file1, file2)\n    except Exception as e:\n        logging.info(f\"Could NOT create symlink. 
Error: {e}\")\n\n\ndef save_file(data, filename, append_to_json=True, verbose=True):\n    \"\"\"\n    Common i/o utility to handle saving data to various file formats.\n    Supported:\n        .pkl, .pickle, .npy, .json\n    Specifically for .json, users have the option to either append (default)\n    or rewrite by passing in Boolean value to append_to_json.\n    \"\"\"\n    if verbose:\n        logging.info(f\"Saving data to file: {filename}\")\n    file_ext = os.path.splitext(filename)[1]\n    if file_ext in [\".pkl\", \".pickle\"]:\n        with g_pathmgr.open(filename, \"wb\") as fopen:\n            pickle.dump(data, fopen, pickle.HIGHEST_PROTOCOL)\n    elif file_ext == \".npy\":\n        with g_pathmgr.open(filename, \"wb\") as fopen:\n            np.save(fopen, data)\n    elif file_ext == \".json\":\n        if append_to_json:\n            with g_pathmgr.open(filename, \"a\") as fopen:\n                fopen.write(json.dumps(data, sort_keys=True) + \"\\n\")\n                fopen.flush()\n        else:\n            with g_pathmgr.open(filename, \"w\") as fopen:\n                fopen.write(json.dumps(data, sort_keys=True) + \"\\n\")\n                fopen.flush()\n    elif file_ext == \".yaml\":\n        with g_pathmgr.open(filename, \"w\") as fopen:\n            dump = yaml.dump(data)\n            fopen.write(dump)\n            fopen.flush()\n    else:\n        raise Exception(f\"Saving {file_ext} is not supported yet\")\n\n    if verbose:\n        logging.info(f\"Saved data to file: {filename}\")\n\n\ndef load_file(filename, mmap_mode=None, verbose=True, allow_pickle=False):\n    \"\"\"\n    Common i/o utility to handle loading data from various file formats.\n    Supported:\n        .pkl, .pickle, .npy, .json\n    For the npy files, we support reading the files in mmap_mode.\n    If the mmap_mode of reading is not successful, we load data without the\n    mmap_mode.\n    \"\"\"\n    if verbose:\n        logging.info(f\"Loading data from file: {filename}\")\n\n    file_ext = os.path.splitext(filename)[1]\n    if file_ext == \".txt\":\n        with g_pathmgr.open(filename, \"r\") as fopen:\n            data = fopen.readlines()\n    elif file_ext in [\".pkl\", \".pickle\"]:\n        with g_pathmgr.open(filename, \"rb\") as fopen:\n            data = pickle.load(fopen, encoding=\"latin1\")\n    elif file_ext == \".npy\":\n        if mmap_mode:\n            try:\n                with g_pathmgr.open(filename, \"rb\") as fopen:\n                    data = np.load(\n                        fopen,\n                        allow_pickle=allow_pickle,\n                        encoding=\"latin1\",\n                        mmap_mode=mmap_mode,\n                    )\n            except ValueError as e:\n                logging.info(\n                    f\"Could not mmap {filename}: {e}. Trying without g_pathmgr\"\n                )\n                data = np.load(\n                    filename,\n                    allow_pickle=allow_pickle,\n                    encoding=\"latin1\",\n                    mmap_mode=mmap_mode,\n                )\n                logging.info(\"Successfully loaded without g_pathmgr\")\n            except Exception:\n                logging.info(\"Could not mmap without g_pathmgr. 
Trying without mmap\")\n                with g_pathmgr.open(filename, \"rb\") as fopen:\n                    data = np.load(fopen, allow_pickle=allow_pickle, encoding=\"latin1\")\n        else:\n            with g_pathmgr.open(filename, \"rb\") as fopen:\n                data = np.load(fopen, allow_pickle=allow_pickle, encoding=\"latin1\")\n    elif file_ext == \".json\":\n        with g_pathmgr.open(filename, \"r\") as fopen:\n            data = json.load(fopen)\n    elif file_ext == \".yaml\":\n        with g_pathmgr.open(filename, \"r\") as fopen:\n            data = yaml.load(fopen, Loader=yaml.FullLoader)\n    elif file_ext == \".csv\":\n        with g_pathmgr.open(filename, \"r\") as fopen:\n            data = pd.read_csv(fopen)\n    else:\n        raise Exception(f\"Reading from {file_ext} is not supported yet\")\n    return data\n\n\ndef abspath(resource_path: str):\n    \"\"\"\n    Make a path absolute, but take into account prefixes like\n    \"http://\" or \"manifold://\"\n    \"\"\"\n    regex = re.compile(r\"^\\w+://\")\n    if regex.match(resource_path) is None:\n        return os.path.abspath(resource_path)\n    else:\n        return resource_path\n\n\ndef makedir(dir_path):\n    \"\"\"\n    Create the directory if it does not exist.\n    \"\"\"\n    is_success = False\n    try:\n        if not g_pathmgr.exists(dir_path):\n            g_pathmgr.mkdirs(dir_path)\n        is_success = True\n    except BaseException:\n        logging.info(f\"Error creating directory: {dir_path}\")\n    return is_success\n\n\ndef is_url(input_url):\n    \"\"\"\n    Check if an input string is a url. look for http(s):// and ignoring the case\n    \"\"\"\n    is_url = re.match(r\"^(?:http)s?://\", input_url, re.IGNORECASE) is not None\n    return is_url\n\n\ndef cleanup_dir(dir):\n    \"\"\"\n    Utility for deleting a directory. Useful for cleaning the storage space\n    that contains various training artifacts like checkpoints, data etc.\n    \"\"\"\n    if os.path.exists(dir):\n        logging.info(f\"Deleting directory: {dir}\")\n        shutil.rmtree(dir)\n    logging.info(f\"Deleted contents of directory: {dir}\")\n\n\ndef get_file_size(filename):\n    \"\"\"\n    Given a file, get the size of file in MB\n    \"\"\"\n    size_in_mb = os.path.getsize(filename) / float(1024**2)\n    return size_in_mb\n"
  },
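A small round-trip sketch (not part of the repository) for `save_file`/`load_file`, which dispatch on the file extension; the record and path below are toy values.

```python
from lavis.common.utils import save_file, load_file, get_file_size, is_url

record = {"question": "what is the dog doing?", "answer": "jumping"}  # toy data
save_file(record, "/tmp/demo_record.json", append_to_json=False)      # overwrite mode
print(load_file("/tmp/demo_record.json"))
print(f"{get_file_size('/tmp/demo_record.json'):.6f} MB")
print(is_url("https://example.com/video.mp4"), is_url("/tmp/demo_record.json"))  # True False
```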
  {
    "path": "lavis/common/vqa_tools/__init__.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\n__author__ = \"aagrawal\"\n"
  },
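(The `vqa_tools` package only exposes an author tag; the VQA helper classes themselves live in the two modules that follow.)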
  {
    "path": "lavis/common/vqa_tools/vqa.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\n__author__ = \"aagrawal\"\n__version__ = \"0.9\"\n\n# Interface for accessing the VQA dataset.\n\n# This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link:\n# (https://github.com/pdollar/coco/blob/master/PythonAPI/pycocotools/coco.py).\n\n# The following functions are defined:\n#  VQA        - VQA class that loads VQA annotation file and prepares data structures.\n#  getQuesIds - Get question ids that satisfy given filter conditions.\n#  getImgIds  - Get image ids that satisfy given filter conditions.\n#  loadQA     - Load questions and answers with the specified question ids.\n#  showQA     - Display the specified questions and answers.\n#  loadRes    - Load result file and create result object.\n\n# Help on each function can be accessed by: \"help(COCO.function)\"\n\nimport json\nimport datetime\nimport copy\n\n\nclass VQA:\n    def __init__(self, annotation_file=None, question_file=None):\n        \"\"\"\n        Constructor of VQA helper class for reading and visualizing questions and answers.\n        :param annotation_file (str): location of VQA annotation file\n        :return:\n        \"\"\"\n        # load dataset\n        self.dataset = {}\n        self.questions = {}\n        self.qa = {}\n        self.qqa = {}\n        self.imgToQA = {}\n        if not annotation_file == None and not question_file == None:\n            print(\"loading VQA annotations and questions into memory...\")\n            time_t = datetime.datetime.utcnow()\n            dataset = json.load(open(annotation_file, \"r\"))\n            questions = json.load(open(question_file, \"r\"))\n            self.dataset = dataset\n            self.questions = questions\n            self.createIndex()\n\n    def createIndex(self):\n        # create index\n        print(\"creating index...\")\n        imgToQA = {ann[\"image_id\"]: [] for ann in self.dataset[\"annotations\"]}\n        qa = {ann[\"question_id\"]: [] for ann in self.dataset[\"annotations\"]}\n        qqa = {ann[\"question_id\"]: [] for ann in self.dataset[\"annotations\"]}\n        for ann in self.dataset[\"annotations\"]:\n            imgToQA[ann[\"image_id\"]] += [ann]\n            qa[ann[\"question_id\"]] = ann\n        for ques in self.questions[\"questions\"]:\n            qqa[ques[\"question_id\"]] = ques\n        print(\"index created!\")\n\n        # create class members\n        self.qa = qa\n        self.qqa = qqa\n        self.imgToQA = imgToQA\n\n    def info(self):\n        \"\"\"\n        Print information about the VQA annotation file.\n        :return:\n        \"\"\"\n        for key, value in self.datset[\"info\"].items():\n            print(\"%s: %s\" % (key, value))\n\n    def getQuesIds(self, imgIds=[], quesTypes=[], ansTypes=[]):\n        \"\"\"\n        Get question ids that satisfy given filter conditions. 
default skips that filter\n        :param \timgIds    (int array)   : get question ids for given imgs\n                        quesTypes (str array)   : get question ids for given question types\n                        ansTypes  (str array)   : get question ids for given answer types\n        :return:    ids   (int array)   : integer array of question ids\n        \"\"\"\n        imgIds = imgIds if type(imgIds) == list else [imgIds]\n        quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]\n        ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]\n\n        if len(imgIds) == len(quesTypes) == len(ansTypes) == 0:\n            anns = self.dataset[\"annotations\"]\n        else:\n            if not len(imgIds) == 0:\n                anns = sum(\n                    [self.imgToQA[imgId] for imgId in imgIds if imgId in self.imgToQA],\n                    [],\n                )\n            else:\n                anns = self.dataset[\"annotations\"]\n            anns = (\n                anns\n                if len(quesTypes) == 0\n                else [ann for ann in anns if ann[\"question_type\"] in quesTypes]\n            )\n            anns = (\n                anns\n                if len(ansTypes) == 0\n                else [ann for ann in anns if ann[\"answer_type\"] in ansTypes]\n            )\n        ids = [ann[\"question_id\"] for ann in anns]\n        return ids\n\n    def getImgIds(self, quesIds=[], quesTypes=[], ansTypes=[]):\n        \"\"\"\n         Get image ids that satisfy given filter conditions. default skips that filter\n         :param quesIds   (int array)   : get image ids for given question ids\n        quesTypes (str array)   : get image ids for given question types\n        ansTypes  (str array)   : get image ids for given answer types\n         :return: ids     (int array)   : integer array of image ids\n        \"\"\"\n        quesIds = quesIds if type(quesIds) == list else [quesIds]\n        quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]\n        ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]\n\n        if len(quesIds) == len(quesTypes) == len(ansTypes) == 0:\n            anns = self.dataset[\"annotations\"]\n        else:\n            if not len(quesIds) == 0:\n                anns = sum(\n                    [self.qa[quesId] for quesId in quesIds if quesId in self.qa], []\n                )\n            else:\n                anns = self.dataset[\"annotations\"]\n            anns = (\n                anns\n                if len(quesTypes) == 0\n                else [ann for ann in anns if ann[\"question_type\"] in quesTypes]\n            )\n            anns = (\n                anns\n                if len(ansTypes) == 0\n                else [ann for ann in anns if ann[\"answer_type\"] in ansTypes]\n            )\n        ids = [ann[\"image_id\"] for ann in anns]\n        return ids\n\n    def loadQA(self, ids=[]):\n        \"\"\"\n        Load questions and answers with the specified question ids.\n        :param ids (int array)       : integer ids specifying question ids\n        :return: qa (object array)   : loaded qa objects\n        \"\"\"\n        if type(ids) == list:\n            return [self.qa[id] for id in ids]\n        elif type(ids) == int:\n            return [self.qa[ids]]\n\n    def showQA(self, anns):\n        \"\"\"\n        Display the specified annotations.\n        :param anns (array of object): annotations to display\n        :return: None\n        \"\"\"\n        if 
len(anns) == 0:\n            return 0\n        for ann in anns:\n            quesId = ann[\"question_id\"]\n            print(\"Question: %s\" % (self.qqa[quesId][\"question\"]))\n            for ans in ann[\"answers\"]:\n                print(\"Answer %d: %s\" % (ans[\"answer_id\"], ans[\"answer\"]))\n\n    def loadRes(self, resFile, quesFile):\n        \"\"\"\n        Load result file and return a result object.\n        :param   resFile (str)     : file name of result file\n        :return: res (obj)         : result api object\n        \"\"\"\n        res = VQA()\n        res.questions = json.load(open(quesFile))\n        res.dataset[\"info\"] = copy.deepcopy(self.questions[\"info\"])\n        res.dataset[\"task_type\"] = copy.deepcopy(self.questions[\"task_type\"])\n        res.dataset[\"data_type\"] = copy.deepcopy(self.questions[\"data_type\"])\n        res.dataset[\"data_subtype\"] = copy.deepcopy(self.questions[\"data_subtype\"])\n        res.dataset[\"license\"] = copy.deepcopy(self.questions[\"license\"])\n\n        print(\"Loading and preparing results...     \")\n        time_t = datetime.datetime.utcnow()\n        anns = json.load(open(resFile))\n        assert type(anns) == list, \"results is not an array of objects\"\n        annsQuesIds = [ann[\"question_id\"] for ann in anns]\n        assert set(annsQuesIds) == set(\n            self.getQuesIds()\n        ), \"Results do not correspond to current VQA set. Either the results do not have predictions for all question ids in annotation file or there is atleast one question id that does not belong to the question ids in the annotation file.\"\n        for ann in anns:\n            quesId = ann[\"question_id\"]\n            if res.dataset[\"task_type\"] == \"Multiple Choice\":\n                assert (\n                    ann[\"answer\"] in self.qqa[quesId][\"multiple_choices\"]\n                ), \"predicted answer is not one of the multiple choices\"\n            qaAnn = self.qa[quesId]\n            ann[\"image_id\"] = qaAnn[\"image_id\"]\n            ann[\"question_type\"] = qaAnn[\"question_type\"]\n            ann[\"answer_type\"] = qaAnn[\"answer_type\"]\n        print(\n            \"DONE (t=%0.2fs)\" % ((datetime.datetime.utcnow() - time_t).total_seconds())\n        )\n\n        res.dataset[\"annotations\"] = anns\n        res.createIndex()\n        return res\n"
  },
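A usage sketch of the `VQA` helper (not part of the repository); the annotation and question file names are placeholders for the official VQA v2 JSON files, which this repo does not ship.

```python
from lavis.common.vqa_tools.vqa import VQA

# Placeholder paths: point these at the official VQA v2 annotation/question JSONs.
vqa = VQA(
    annotation_file="v2_mscoco_val2014_annotations.json",
    question_file="v2_OpenEnded_mscoco_val2014_questions.json",
)

yes_no_ids = vqa.getQuesIds(ansTypes=["yes/no"])  # filter question ids by answer type
anns = vqa.loadQA(yes_no_ids[:3])
vqa.showQA(anns)                                  # print questions with their answers
```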
  {
    "path": "lavis/common/vqa_tools/vqa_eval.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\n# coding=utf-8\n\n__author__ = \"aagrawal\"\n\n# This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link:\n# (https://github.com/tylin/coco-caption/blob/master/pycocoevalcap/eval.py).\nimport sys\nimport re\n\n\nclass VQAEval:\n    def __init__(self, vqa=None, vqaRes=None, n=2):\n        self.n = n\n        self.accuracy = {}\n        self.evalQA = {}\n        self.evalQuesType = {}\n        self.evalAnsType = {}\n        self.vqa = vqa\n        self.vqaRes = vqaRes\n        if vqa is not None:\n            self.params = {\"question_id\": vqa.getQuesIds()}\n        self.contractions = {\n            \"aint\": \"ain't\",\n            \"arent\": \"aren't\",\n            \"cant\": \"can't\",\n            \"couldve\": \"could've\",\n            \"couldnt\": \"couldn't\",\n            \"couldn'tve\": \"couldn't've\",\n            \"couldnt've\": \"couldn't've\",\n            \"didnt\": \"didn't\",\n            \"doesnt\": \"doesn't\",\n            \"dont\": \"don't\",\n            \"hadnt\": \"hadn't\",\n            \"hadnt've\": \"hadn't've\",\n            \"hadn'tve\": \"hadn't've\",\n            \"hasnt\": \"hasn't\",\n            \"havent\": \"haven't\",\n            \"hed\": \"he'd\",\n            \"hed've\": \"he'd've\",\n            \"he'dve\": \"he'd've\",\n            \"hes\": \"he's\",\n            \"howd\": \"how'd\",\n            \"howll\": \"how'll\",\n            \"hows\": \"how's\",\n            \"Id've\": \"I'd've\",\n            \"I'dve\": \"I'd've\",\n            \"Im\": \"I'm\",\n            \"Ive\": \"I've\",\n            \"isnt\": \"isn't\",\n            \"itd\": \"it'd\",\n            \"itd've\": \"it'd've\",\n            \"it'dve\": \"it'd've\",\n            \"itll\": \"it'll\",\n            \"let's\": \"let's\",\n            \"maam\": \"ma'am\",\n            \"mightnt\": \"mightn't\",\n            \"mightnt've\": \"mightn't've\",\n            \"mightn'tve\": \"mightn't've\",\n            \"mightve\": \"might've\",\n            \"mustnt\": \"mustn't\",\n            \"mustve\": \"must've\",\n            \"neednt\": \"needn't\",\n            \"notve\": \"not've\",\n            \"oclock\": \"o'clock\",\n            \"oughtnt\": \"oughtn't\",\n            \"ow's'at\": \"'ow's'at\",\n            \"'ows'at\": \"'ow's'at\",\n            \"'ow'sat\": \"'ow's'at\",\n            \"shant\": \"shan't\",\n            \"shed've\": \"she'd've\",\n            \"she'dve\": \"she'd've\",\n            \"she's\": \"she's\",\n            \"shouldve\": \"should've\",\n            \"shouldnt\": \"shouldn't\",\n            \"shouldnt've\": \"shouldn't've\",\n            \"shouldn'tve\": \"shouldn't've\",\n            \"somebody'd\": \"somebodyd\",\n            \"somebodyd've\": \"somebody'd've\",\n            \"somebody'dve\": \"somebody'd've\",\n            \"somebodyll\": \"somebody'll\",\n            \"somebodys\": \"somebody's\",\n            \"someoned\": \"someone'd\",\n            \"someoned've\": \"someone'd've\",\n            \"someone'dve\": \"someone'd've\",\n            \"someonell\": \"someone'll\",\n            \"someones\": \"someone's\",\n            \"somethingd\": \"something'd\",\n            \"somethingd've\": \"something'd've\",\n            \"something'dve\": 
\"something'd've\",\n            \"somethingll\": \"something'll\",\n            \"thats\": \"that's\",\n            \"thered\": \"there'd\",\n            \"thered've\": \"there'd've\",\n            \"there'dve\": \"there'd've\",\n            \"therere\": \"there're\",\n            \"theres\": \"there's\",\n            \"theyd\": \"they'd\",\n            \"theyd've\": \"they'd've\",\n            \"they'dve\": \"they'd've\",\n            \"theyll\": \"they'll\",\n            \"theyre\": \"they're\",\n            \"theyve\": \"they've\",\n            \"twas\": \"'twas\",\n            \"wasnt\": \"wasn't\",\n            \"wed've\": \"we'd've\",\n            \"we'dve\": \"we'd've\",\n            \"weve\": \"we've\",\n            \"werent\": \"weren't\",\n            \"whatll\": \"what'll\",\n            \"whatre\": \"what're\",\n            \"whats\": \"what's\",\n            \"whatve\": \"what've\",\n            \"whens\": \"when's\",\n            \"whered\": \"where'd\",\n            \"wheres\": \"where's\",\n            \"whereve\": \"where've\",\n            \"whod\": \"who'd\",\n            \"whod've\": \"who'd've\",\n            \"who'dve\": \"who'd've\",\n            \"wholl\": \"who'll\",\n            \"whos\": \"who's\",\n            \"whove\": \"who've\",\n            \"whyll\": \"why'll\",\n            \"whyre\": \"why're\",\n            \"whys\": \"why's\",\n            \"wont\": \"won't\",\n            \"wouldve\": \"would've\",\n            \"wouldnt\": \"wouldn't\",\n            \"wouldnt've\": \"wouldn't've\",\n            \"wouldn'tve\": \"wouldn't've\",\n            \"yall\": \"y'all\",\n            \"yall'll\": \"y'all'll\",\n            \"y'allll\": \"y'all'll\",\n            \"yall'd've\": \"y'all'd've\",\n            \"y'alld've\": \"y'all'd've\",\n            \"y'all'dve\": \"y'all'd've\",\n            \"youd\": \"you'd\",\n            \"youd've\": \"you'd've\",\n            \"you'dve\": \"you'd've\",\n            \"youll\": \"you'll\",\n            \"youre\": \"you're\",\n            \"youve\": \"you've\",\n        }\n        self.manualMap = {\n            \"none\": \"0\",\n            \"zero\": \"0\",\n            \"one\": \"1\",\n            \"two\": \"2\",\n            \"three\": \"3\",\n            \"four\": \"4\",\n            \"five\": \"5\",\n            \"six\": \"6\",\n            \"seven\": \"7\",\n            \"eight\": \"8\",\n            \"nine\": \"9\",\n            \"ten\": \"10\",\n        }\n        self.articles = [\"a\", \"an\", \"the\"]\n\n        self.periodStrip = re.compile(\"(?!<=\\d)(\\.)(?!\\d)\")\n        self.commaStrip = re.compile(\"(\\d)(,)(\\d)\")\n        self.punct = [\n            \";\",\n            r\"/\",\n            \"[\",\n            \"]\",\n            '\"',\n            \"{\",\n            \"}\",\n            \"(\",\n            \")\",\n            \"=\",\n            \"+\",\n            \"\\\\\",\n            \"_\",\n            \"-\",\n            \">\",\n            \"<\",\n            \"@\",\n            \"`\",\n            \",\",\n            \"?\",\n            \"!\",\n        ]\n\n    def evaluate(self, quesIds=None):\n        if quesIds == None:\n            quesIds = [quesId for quesId in self.params[\"question_id\"]]\n        gts = {}\n        res = {}\n        for quesId in quesIds:\n            gts[quesId] = self.vqa.qa[quesId]\n            res[quesId] = self.vqaRes.qa[quesId]\n\n        # =================================================\n        # Compute accuracy\n        # 
=================================================\n        accQA = []\n        accQuesType = {}\n        accAnsType = {}\n        print(\"computing accuracy\")\n        step = 0\n        for quesId in quesIds:\n            resAns = res[quesId][\"answer\"]\n            resAns = resAns.replace(\"\\n\", \" \")\n            resAns = resAns.replace(\"\\t\", \" \")\n            resAns = resAns.strip()\n            resAns = self.processPunctuation(resAns)\n            resAns = self.processDigitArticle(resAns)\n            gtAcc = []\n            gtAnswers = [ans[\"answer\"] for ans in gts[quesId][\"answers\"]]\n            if len(set(gtAnswers)) > 1:\n                for ansDic in gts[quesId][\"answers\"]:\n                    ansDic[\"answer\"] = self.processPunctuation(ansDic[\"answer\"])\n            for gtAnsDatum in gts[quesId][\"answers\"]:\n                otherGTAns = [\n                    item for item in gts[quesId][\"answers\"] if item != gtAnsDatum\n                ]\n                matchingAns = [item for item in otherGTAns if item[\"answer\"] == resAns]\n                acc = min(1, float(len(matchingAns)) / 3)\n                gtAcc.append(acc)\n            quesType = gts[quesId][\"question_type\"]\n            ansType = gts[quesId][\"answer_type\"]\n            avgGTAcc = float(sum(gtAcc)) / len(gtAcc)\n            accQA.append(avgGTAcc)\n            if quesType not in accQuesType:\n                accQuesType[quesType] = []\n            accQuesType[quesType].append(avgGTAcc)\n            if ansType not in accAnsType:\n                accAnsType[ansType] = []\n            accAnsType[ansType].append(avgGTAcc)\n            self.setEvalQA(quesId, avgGTAcc)\n            self.setEvalQuesType(quesId, quesType, avgGTAcc)\n            self.setEvalAnsType(quesId, ansType, avgGTAcc)\n            if step % 100 == 0:\n                self.updateProgress(step / float(len(quesIds)))\n            step = step + 1\n\n        self.setAccuracy(accQA, accQuesType, accAnsType)\n        print(\"Done computing accuracy\")\n\n    def processPunctuation(self, inText):\n        outText = inText\n        for p in self.punct:\n            if (p + \" \" in inText or \" \" + p in inText) or (\n                re.search(self.commaStrip, inText) != None\n            ):\n                outText = outText.replace(p, \"\")\n            else:\n                outText = outText.replace(p, \" \")\n        outText = self.periodStrip.sub(\"\", outText, re.UNICODE)\n        return outText\n\n    def processDigitArticle(self, inText):\n        outText = []\n        tempText = inText.lower().split()\n        for word in tempText:\n            word = self.manualMap.setdefault(word, word)\n            if word not in self.articles:\n                outText.append(word)\n            else:\n                pass\n        for wordId, word in enumerate(outText):\n            if word in self.contractions:\n                outText[wordId] = self.contractions[word]\n        outText = \" \".join(outText)\n        return outText\n\n    def setAccuracy(self, accQA, accQuesType, accAnsType):\n        self.accuracy[\"overall\"] = round(100 * float(sum(accQA)) / len(accQA), self.n)\n        self.accuracy[\"perQuestionType\"] = {\n            quesType: round(\n                100 * float(sum(accQuesType[quesType])) / len(accQuesType[quesType]),\n                self.n,\n            )\n            for quesType in accQuesType\n        }\n        self.accuracy[\"perAnswerType\"] = {\n            ansType: round(\n                100 * 
float(sum(accAnsType[ansType])) / len(accAnsType[ansType]), self.n\n            )\n            for ansType in accAnsType\n        }\n\n    def setEvalQA(self, quesId, acc):\n        self.evalQA[quesId] = round(100 * acc, self.n)\n\n    def setEvalQuesType(self, quesId, quesType, acc):\n        if quesType not in self.evalQuesType:\n            self.evalQuesType[quesType] = {}\n        self.evalQuesType[quesType][quesId] = round(100 * acc, self.n)\n\n    def setEvalAnsType(self, quesId, ansType, acc):\n        if ansType not in self.evalAnsType:\n            self.evalAnsType[ansType] = {}\n        self.evalAnsType[ansType][quesId] = round(100 * acc, self.n)\n\n    def updateProgress(self, progress):\n        barLength = 20\n        status = \"\"\n        if isinstance(progress, int):\n            progress = float(progress)\n        if not isinstance(progress, float):\n            progress = 0\n            status = \"error: progress var must be float\\r\\n\"\n        if progress < 0:\n            progress = 0\n            status = \"Halt...\\r\\n\"\n        if progress >= 1:\n            progress = 1\n            status = \"Done...\\r\\n\"\n        block = int(round(barLength * progress))\n        text = \"\\rFinished Percent: [{0}] {1}% {2}\".format(\n            \"#\" * block + \"-\" * (barLength - block), int(progress * 100), status\n        )\n        sys.stdout.write(text)\n        sys.stdout.flush()\n"
  },
  {
    "path": "lavis/configs/datasets/aokvqa/defaults.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\ndatasets:\n  aok_vqa:\n    # data_dir: ${env.data_dir}/datasets\n    data_type: images # [images|videos|features]\n\n    build_info:\n      # Be careful not to append minus sign (-) before split to avoid itemizing\n      annotations:\n        train:\n          url:\n              - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/aokvqa/aokvqa_v1p0_train.json\n          storage:\n              - aokvqa/annotations/aokvqa_v1p0_train.json\n        val:\n          url:\n              - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/aokvqa/aokvqa_v1p0_val.json\n              - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/aokvqa/specialized_vocab_train.json\n          storage:\n              - aokvqa/annotations/aokvqa_v1p0_val.json\n              - aokvqa/annotations/specialized_vocab_train_lavis.json\n              # - aokvqa/annotations/large_vocab_train_lavis.json\n        test:\n          url:\n              - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/aokvqa/aokvqa_v1p0_test.json\n              - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/aokvqa/specialized_vocab_train.json\n          storage:\n              - aokvqa/annotations/aokvqa_v1p0_test.json\n              - aokvqa/annotations/specialized_vocab_train_lavis.json\n      images:\n          storage: coco/images/\n"
  },
  {
    "path": "lavis/configs/datasets/avsd/defaults_dial.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\ndatasets:\n  avsd_dialogue: # name of the dataset builder\n    dataset_card: dataset_card/avsd_dialogue.md \n    data_type: features #extracted features of videos (I3D, VGGish) # [images|videos|features]\n\n    build_info:\n      # Be careful not to append minus sign (-) before split to avoid itemizing\n      annotations:\n        train:\n          url: https://storage.googleapis.com/sfr-vision-language-research/datasets/avsd_dstc7_train.json\n          storage: avsd/annotations/train.json \n        val:\n          url: https://storage.googleapis.com/sfr-vision-language-research/datasets/avsd_dstc7_val.json\n          storage: avsd/annotations/val.json \n        test:\n          url: https://storage.googleapis.com/sfr-vision-language-research/datasets/avsd_dstc7_test.json\n          storage: avsd/annotations/test.json \n      features:\n        storage: avsd/features/ \n"
  },
  {
    "path": "lavis/configs/datasets/coco/defaults_cap.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\ndatasets:\n  coco_caption: # name of the dataset builder\n    dataset_card: dataset_card/coco_caption.md\n    # data_dir: ${env.data_dir}/datasets\n    data_type: images # [images|videos|features]\n\n    build_info:\n      # Be careful not to append minus sign (-) before split to avoid itemizing\n      annotations:\n        train:\n          url: https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_train.json\n          md5: aa31ac474cf6250ebb81d18348a07ed8\n          storage: coco/annotations/coco_karpathy_train.json\n        val:\n          url: https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json\n          md5: b273847456ef5580e33713b1f7de52a0\n          storage:  coco/annotations/coco_karpathy_val.json\n        test:\n          url: https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json\n          md5: 3ff34b0ef2db02d01c37399f6a2a6cd1\n          storage: coco/annotations/coco_karpathy_test.json\n      images:\n        storage: coco/images/\n"
  },
  {
    "path": "lavis/configs/datasets/coco/defaults_ret.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\ndatasets:\n  coco_retrieval:\n    # data_dir: ${env.data_dir}/datasets\n    data_type: images # [images|videos|features]\n\n    build_info:\n      # Be careful not to append minus sign (-) before split to avoid itemizing\n      annotations:\n        train:\n          url: https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_train.json\n          md5: aa31ac474cf6250ebb81d18348a07ed8\n          storage: coco/annotations/coco_karpathy_train.json\n        val:\n          url: https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json\n          md5: b273847456ef5580e33713b1f7de52a0\n          storage:  coco/annotations/coco_karpathy_val.json\n        test:\n          url: https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json\n          md5: 3ff34b0ef2db02d01c37399f6a2a6cd1\n          storage: coco/annotations/coco_karpathy_test.json\n      images:\n          storage: coco/images/\n"
  },
  {
    "path": "lavis/configs/datasets/coco/defaults_vqa.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\ndatasets:\n  coco_vqa:\n    # data_dir: ${env.data_dir}/datasets\n    data_type: images # [images|videos|features]\n\n    build_info:\n      # Be careful not to append minus sign (-) before split to avoid itemizing\n      annotations:\n        train:\n          url:\n              - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/vqa_train.json\n              - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/vqa_val.json\n          storage:\n              - coco/annotations/vqa_train.json\n              - coco/annotations/vqa_val.json\n        val:\n          url:\n              # TODO make this order insensitive\n              - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/vqa_val_eval.json\n              - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/answer_list.json\n              - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/v2_OpenEnded_mscoco_val2014_questions.json\n              - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/v2_mscoco_val2014_annotations.json\n          storage:\n              - coco/annotations/vqa_val_eval.json\n              - coco/annotations/answer_list.json\n              - coco/annotations/v2_OpenEnded_mscoco_val2014_questions.json\n              - coco/annotations/v2_mscoco_val2014_annotations.json\n        test:\n          url:\n              - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/vqa_test.json\n              - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/answer_list.json\n          storage:\n              - coco/annotations/vqa_test.json\n              - coco/annotations/answer_list.json\n      images:\n          storage: coco/images/\n"
  },
  {
    "path": "lavis/configs/datasets/coco/eval_vqa.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\ndatasets:\n  coco_vqa:\n    # data_dir: ${env.data_dir}/datasets\n    data_type: images # [images|videos|features]\n\n    build_info:\n      # Be careful not to append minus sign (-) before split to avoid itemizing\n      annotations:\n        val:\n          url:\n              # TODO make this order insensitive\n              - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/vqa_val_eval.json\n              - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/answer_list.json\n              - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/v2_OpenEnded_mscoco_val2014_questions.json\n              - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/v2_mscoco_val2014_annotations.json\n          storage:\n              - coco/annotations/vqa_val_eval.json\n              - coco/annotations/answer_list.json\n              - coco/annotations/v2_OpenEnded_mscoco_val2014_questions.json\n              - coco/annotations/v2_mscoco_val2014_annotations.json\n      images:\n          storage: coco/images/\n"
  },
  {
    "path": "lavis/configs/datasets/conceptual_caption/defaults_12m.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\ndatasets:\n  conceptual_caption_12m:\n    # data_dir: ${env.data_dir}/datasets\n    data_type: images # [images|videos|features]\n\n    build_info:\n      # Be careful not to append minus sign (-) before split to avoid itemizing\n      annotations:\n        train:\n          url:\n              - /export/home/workspace/datasets/cc12m.json\n          storage:\n              - conceptual_caption/annotations/cc12m.json\n      images:\n          storage: conceptual_caption/images_12m\n"
  },
  {
    "path": "lavis/configs/datasets/conceptual_caption/defaults_3m.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\ndatasets:\n  conceptual_caption_3m:\n    # data_dir: ${env.data_dir}/datasets\n    data_type: images # [images|videos|features]\n\n    build_info:\n      # Be careful not to append minus sign (-) before split to avoid itemizing\n      annotations:\n        train:\n          url:\n              - /export/home/workspace/datasets/cc3m.json\n          storage:\n              - conceptual_caption/annotations/cc3m.json\n      images:\n          storage: conceptual_caption/images\n"
  },
  {
    "path": "lavis/configs/datasets/didemo/defaults_ret.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\ndatasets:\n  didemo_retrieval: # name of the dataset builder\n    # data_dir: ${env.data_dir}/datasets\n    data_type: videos # [images|videos|features]\n\n    build_info:\n      # Be careful not to append minus sign (-) before split to avoid itemizing\n      annotations:\n        train:\n          url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/didemo/retrieval_train.json\n          storage: didemo/annotations/retrieval_train.json\n        val:\n          url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/didemo/retrieval_val.json\n          storage: didemo/annotations/retrieval_val.json\n        test:\n          url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/didemo/retrieval_test.json\n          storage: didemo/annotations/retrieval_test.json\n      videos:\n        storage: didemo/videos\n        # storage: /export/share/dongxuli/data/didemo_retrieval/videos\n"
  },
  {
    "path": "lavis/configs/datasets/flickr30k/defaults.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\ndatasets:\n  flickr30k:\n    # data_dir: ${env.data_dir}/datasets\n    data_type: images\n\n    build_info:\n      annotations:\n        train:\n          url: https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_train.json\n          storage: flickr30k/annotations/train.json\n        val:\n          url: https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_val.json\n          storage: flickr30k/annotations/val.json\n        test:\n          url: https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_test.json\n          storage: flickr30k/annotations/test.json\n      images:\n          storage: flickr30k/images\n          # storage: /export/share/datasets/vision/flickr30k\n"
  },
  {
    "path": "lavis/configs/datasets/gqa/balanced_testdev.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\ndatasets:\n  gqa:\n    # data_dir: ${env.data_dir}/datasets\n    data_type: images # [images|videos|features]\n\n    build_info:\n      # Be careful not to append minus sign (-) before split to avoid itemizing\n      annotations:\n        train:\n          url:\n              - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/gqa/train_balanced_questions.json\n          storage:\n              - gqa/annotations/train_balanced_questions.json\n        val:\n          url:\n            - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/gqa/testdev_balanced_questions.json\n          storage:\n            - gqa/annotations/testdev_balanced_questions.json\n        test:\n          url:\n              - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/gqa/test_balanced_questions.json\n          storage:\n              - gqa/annotations/test_balanced_questions.json\n      images:\n          storage: gqa/images/\n"
  },
  {
    "path": "lavis/configs/datasets/gqa/balanced_val.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\ndatasets:\n  gqa:\n    # data_dir: ${env.data_dir}/datasets\n    data_type: images # [images|videos|features]\n\n    build_info:\n      # Be careful not to append minus sign (-) before split to avoid itemizing\n      annotations:\n        train:\n          url:\n              - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/gqa/train_balanced_questions.json\n          storage:\n              - gqa/annotations/train_balanced_questions.json\n        val:\n          url:\n              - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/gqa/val_balanced_questions.json\n          storage:\n              - gqa/annotations/val_balanced_questions.json\n        test:\n          url:\n              - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/gqa/test_balanced_questions.json\n          storage:\n              - gqa/annotations/test_balanced_questions.json\n      images:\n          storage: gqa/images/\n"
  },
  {
    "path": "lavis/configs/datasets/gqa/defaults.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\ndatasets:\n  gqa:\n    # data_dir: ${env.data_dir}/datasets\n    data_type: images # [images|videos|features]\n\n    build_info:\n      # Be careful not to append minus sign (-) before split to avoid itemizing\n      annotations:\n        train:\n          url:\n              - /export/share/datasets/vision/GQA/questions1.2/train_all_questions/train_all_questions_0.json\n              - /export/share/datasets/vision/GQA/questions1.2/val_all_questions.json\n          storage:\n              - gqa/annotations/train_all_questions_0.json\n              - gqa/annotations/val_all_questions.json\n        val:\n          url:\n              - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/aokvqa/aokvqa_v1p0_val.json\n              - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/aokvqa/large_vocab_train_lavis.json\n          storage:\n              - aokvqa/annotations/aokvqa_v1p0_val.json\n              - aokvqa/annotations/large_vocab_train_lavis.json\n        test:\n          url:\n              - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/aokvqa/aokvqa_v1p0_test.json\n              - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/aokvqa/large_vocab_train_lavis.json\n          storage:\n              - aokvqa/annotations/aokvqa_v1p0_test.json\n              - aokvqa/annotations/large_vocab_train_lavis.json\n      images:\n          storage: gqa/images/\n"
  },
  {
    "path": "lavis/configs/datasets/how2qa/defaults_qa.yaml",
    "content": "datasets:\n  how2qa: # name of the dataset builder\n    # data_dir: ${env.data_dir}/datasets\n    data_type: videos # [images|videos|features]\n    build_info:\n      # Be careful not to append minus sign (-) before split to avoid itemizing\n      annotations:\n        train:\n          url: /nas-ssd/shoubin/datasets/how2qa/train.json\n          storage: /nas-ssd/shoubin/datasets/how2qa/train.json\n        val:\n          url: /nas-ssd/shoubin/datasets/how2qa/val.json\n          storage: /nas-ssd/shoubin/datasets/how2qa/val.json\n        test:\n          url: /nas-ssd/shoubin/datasets/how2qa/val.json\n          storage: /nas-ssd/shoubin/datasets/how2qa/val.json\n      videos:\n        storage: /nas-hdd/shoubin/how2qa/clips/"
  },
  {
    "path": "lavis/configs/datasets/imagenet/defaults.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\ndatasets:\n  imagenet:\n    # data_dir: ${env.data_dir}/datasets\n    data_type: images # [images|videos|features]\n\n    build_info:\n      # Be careful not to append minus sign (-) before split to avoid itemizing\n      splits: [\"val\"]\n      images:\n          storage: /export/share/datasets/vision/imagenet\n"
  },
  {
    "path": "lavis/configs/datasets/laion/defaults_2B_multi.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\ndatasets:\n  laion2B_multi:\n\n    data_type: images\n\n    build_info:\n      # Be careful not to append minus sign (-) before split to avoid itemizing\n      storage: /export/laion/laion2B-multi/part-00000/{00000..01743}.tar\n"
  },
  {
    "path": "lavis/configs/datasets/msrvtt/defaults_cap.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\ndatasets:\n  msrvtt_cap: # name of the dataset builder\n    # data_dir: ${env.data_dir}/datasets\n    data_type: videos # [images|videos|features]\n\n    build_info:\n      # Be careful not to append minus sign (-) before split to avoid itemizing\n      annotations:\n        train:\n          url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/msrvtt/cap_train.json\n          storage: msrvtt/annotations/cap_train.json\n        val:\n          url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/msrvtt/cap_val.json\n          storage: msrvtt/annotations/cap_val.json\n        test:\n          url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/msrvtt/cap_test.json\n          storage: msrvtt/annotations/cap_test.json\n      videos:\n        storage: msrvtt/videos\n"
  },
  {
    "path": "lavis/configs/datasets/msrvtt/defaults_qa.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\ndatasets:\n  msrvtt_qa: # name of the dataset builder\n    # data_dir: ${env.data_dir}/datasets\n    data_type: videos # [images|videos|features]\n\n    build_info:\n      # Be careful not to append minus sign (-) before split to avoid itemizing\n      annotations:\n        train:\n          url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/msrvtt/qa_train.json\n          storage: msrvtt/annotations/qa_train.json\n        val:\n          url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/msrvtt/qa_val.json\n          storage: msrvtt/annotations/qa_val.json\n        test:\n          url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/msrvtt/qa_test.json\n          storage: msrvtt/annotations/qa_test.json\n        ans2label:\n          url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/msrvtt/train_ans2label.json\n          storage: msrvtt/annotations/qa_ans2label.json\n      videos:\n        storage: msrvtt/videos\n"
  },
  {
    "path": "lavis/configs/datasets/msrvtt/defaults_ret.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\ndatasets:\n  msrvtt_retrieval: # name of the dataset builder\n    # data_dir: ${env.data_dir}/datasets\n    data_type: videos # [images|videos|features]\n\n    build_info:\n      # Be careful not to append minus sign (-) before split to avoid itemizing\n      annotations:\n        train:\n          url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/msrvtt/retrieval_train.json\n          storage: msrvtt/annotations/retrieval_train.json\n        val:\n          url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/msrvtt/retrieval_val.json\n          storage: msrvtt/annotations/retrieval_val.json\n        test:\n          url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/msrvtt/retrieval_test.json\n          storage: msrvtt/annotations/retrieval_test.json\n      videos:\n        storage: msrvtt/videos\n"
  },
  {
    "path": "lavis/configs/datasets/msrvttmc/defaults_qa.yaml",
    "content": "datasets:\n  msrvttmc: # name of the dataset builder\n    # data_dir: ${env.data_dir}/datasets\n    data_type: videos # [images|videos|features]\n    build_info:\n      # Be careful not to append minus sign (-) before split to avoid itemizing\n      # no training data for this dataset\n      annotations:\n        train:\n          url: /nas-ssd/shoubin/datasets/msrvttmc/val.json\n          storage: /nas-ssd/shoubin/datasets/msrvttmc/val.json\n        val:\n          url: /nas-ssd/shoubin/datasets/msrvttmc/val.json\n          storage: /nas-ssd/shoubin/datasets/msrvttmc/val.json\n        test:\n          url: /nas-ssd/shoubin/datasets/msrvttmc/val.json\n          storage: /nas-ssd/shoubin/datasets/msrvttmc/val.json\n      videos:\n        storage: /nas-hdd/tarbucket/terran/data/msrvtt/videos/all/"
  },
  {
    "path": "lavis/configs/datasets/msvd/defaults_cap.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\ndatasets:\n  msvd_cap: # name of the dataset builder\n    # data_dir: ${env.data_dir}/datasets\n    data_type: videos # [images|videos|features]\n\n    build_info:\n      # Be careful not to append minus sign (-) before split to avoid itemizing\n      annotations:\n        train:\n          url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/msvd/cap_train.json\n          storage: msvd/annotations/cap_train.json\n        val:\n          url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/msvd/cap_val.json\n          storage: msvd/annotations/cap_val.json\n        test:\n          url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/msvd/cap_test.json\n          storage: msvd/annotations/cap_test.json\n      videos:\n        storage: msvd/videos\n"
  },
  {
    "path": "lavis/configs/datasets/msvd/defaults_qa.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\ndatasets:\n  msvd_qa: # name of the dataset builder\n    # data_dir: ${env.data_dir}/datasets\n    data_type: videos # [images|videos|features]\n\n    build_info:\n      # Be careful not to append minus sign (-) before split to avoid itemizing\n      annotations:\n        train:\n          url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/msvd/qa_train.json\n          storage: msvd/annotations/qa_train.json\n        val:\n          url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/msvd/qa_val.json\n          storage: msvd/annotations/qa_val.json\n        test:\n          url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/msvd/qa_test.json\n          storage: msvd/annotations/qa_test.json\n        ans2label:\n          url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/msvd/train_ans2label.json\n          storage: msvd/annotations/qa_ans2label.json\n      videos:\n        storage: msvd/videos\n\n      instance_id_key: question_id\n"
  },
  {
    "path": "lavis/configs/datasets/nextqa/defaults_qa.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\ndatasets:\n  nextqa: # name of the dataset builder\n    # data_dir: ${env.data_dir}/datasets\n    data_type: videos # [images|videos|features]\n    build_info:\n      # Be careful not to append minus sign (-) before split to avoid itemizing\n      annotations:\n        train:\n          url: /nas-ssd/shoubin/datasets/nextqa/train.json\n          storage: /nas-ssd/shoubin/datasets/nextqa/train.json\n        val:\n          url: /nas-ssd/shoubin/datasets/nextqa/val.json\n          storage: /nas-ssd/shoubin/datasets/nextqa/val.json\n        test:\n          url: /nas-ssd/shoubin/datasets/nextqa/val.json\n          storage: /nas-ssd/shoubin/datasets/nextqa/val.json\n      videos:\n        storage: /nas-hdd/shoubin/videos/vidor/videos/\n"
  },
  {
    "path": "lavis/configs/datasets/nlvr/defaults.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\ndatasets:\n  nlvr:\n    # data_dir: ${env.data_dir}/datasets\n    data_type: images # [images|videos|features]\n\n    build_info:\n      # Be careful not to append minus sign (-) before split to avoid itemizing\n      annotations:\n        train:\n          url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/nlvr/nlvr_train.json\n          storage: nlvr/annotations/train.json\n        val:\n          url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/nlvr/nlvr_dev.json\n          storage: nlvr/annotations/dev.json\n        test:\n          url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/nlvr/nlvr_dev.json\n          storage: nlvr/annotations/test.json\n      images:\n          storage: /export/share/datasets/vision/NLVR2/\n"
  },
  {
    "path": "lavis/configs/datasets/nocaps/defaults.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\ndatasets:\n  nocaps: # name of the dataset builder\n    # data_dir: ${env.data_dir}/datasets\n    data_type: images # [images|videos|features]\n\n    build_info:\n      # Be careful not to append minus sign (-) before split to avoid itemizing\n      annotations:\n        val:\n          url: https://storage.googleapis.com/sfr-vision-language-research/datasets/nocaps_val.json\n          storage:  nocaps/annotations/nocaps_val.json\n        test:\n          url: https://storage.googleapis.com/sfr-vision-language-research/datasets/nocaps_test.json\n          storage: nocaps/annotations/nocaps_test.json\n      images:\n        storage: nocaps/images\n        # storage: /export/share/datasets/vision/nocaps/\n"
  },
  {
    "path": "lavis/configs/datasets/okvqa/defaults.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\ndatasets:\n  ok_vqa:\n    # data_dir: ${env.data_dir}/datasets\n    data_type: images # [images|videos|features]\n\n    build_info:\n      # Be careful not to append minus sign (-) before split to avoid itemizing\n      annotations:\n        train:\n          url:\n              # TODO make this order insensitive\n              - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/okvqa/okvqa_train.json\n              # - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/okvqa/OpenEnded_mscoco_train2014_questions.json\n              # - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/okvqa/mscoco_train2014_annotations.json\n          storage:\n              - okvqa/annotations/okvqa_train.json\n              # - okvqa/annotations/OpenEnded_mscoco_train2014_questions.json\n              # - okvqa/annotations/mscoco_train2014_annotations.json\n        test:\n          url:\n              # TODO make this order insensitive\n              - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/okvqa/okvqa_val_eval.json\n              - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/okvqa/okvqa_answer_list_train.json\n              - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/okvqa/OpenEnded_mscoco_val2014_questions.json\n              - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/okvqa/mscoco_val2014_annotations.json\n          storage:\n              - okvqa/annotations/vqa_val_eval.json\n              - okvqa/annotations/answer_list.json\n              - okvqa/annotations/OpenEnded_mscoco_val2014_questions.json\n              - okvqa/annotations/mscoco_val2014_annotations.json\n      images:\n          storage: coco/images/\n"
  },
  {
    "path": "lavis/configs/datasets/qvh/defaults.yaml",
    "content": "datasets:\n  qvh: # name of the dataset builder\n    # data_dir: ${env.data_dir}/datasets\n    data_type: videos # [images|videos|features]\n    build_info:\n      # Be careful not to append minus sign (-) before split to avoid itemizing\n      annotations:\n        train:\n          url: /nas-ssd/shoubin/datasets/qvh/train.json\n          storage: /nas-ssd/shoubin/datasets/qvh/train.json\n        val:\n          url: /nas-ssd/shoubin/datasets/qvh/val.json\n          storage: /nas-ssd/shoubin/datasets/qvh/val.json\n        test:\n          url: /nas-ssd/shoubin/datasets/qvh/test.json\n          storage: /nas-ssd/shoubin/datasets/qvh/test.json\n      videos:\n        storage: /nas-hdd/shoubin/qvh/videos/"
  },
  {
    "path": "lavis/configs/datasets/sbu_caption/defaults.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\ndatasets:\n  sbu_caption:\n    # data_dir: ${env.data_dir}/datasets\n    data_type: images # [images|videos|features]\n\n    build_info:\n      # Be careful not to append minus sign (-) before split to avoid itemizing\n      annotations:\n        train:\n          url:\n              - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/sbu/sbu.json\n              # - /export/share/dongxuli/data/lavis/sbu/annotation/sbu.json\n          storage:\n              - sbu_captions/annotations/sbu.json\n      images:\n          storage: sbu_captions/images\n          # storage: /export/share/datasets/vision_language/sbu_resize\n"
  },
  {
    "path": "lavis/configs/datasets/snli_ve/defaults.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\ndatasets:\n  snli_ve:\n    # data_dir: ${env.data_dir}/datasets\n    data_type: images # [images|videos|features]\n\n    build_info:\n      # Be careful not to append minus sign (-) before split to avoid itemizing\n      annotations:\n        train:\n          url: /export/share/dongxuli/data/lavis/snli/annotation/ve_train.json\n          storage: snli/annotations/ve_train.json\n        val:\n          url: /export/share/dongxuli/data/lavis/snli/annotation/ve_dev.json\n          storage: snli/annotations/ve_dev.json\n        test:\n          url: /export/share/dongxuli/data/lavis/snli/annotation/ve_test.json\n          storage: snli/annotations/ve_test.json\n      images:\n          storage: flickr30k/images/flickr30k-images\n          # storage: /export/share/datasets/vision/flickr30k/flickr30k-images\n"
  },
  {
    "path": "lavis/configs/datasets/star/defaults_qa.yaml",
    "content": "datasets:\n  star: # name of the dataset builder\n    # data_dir: ${env.data_dir}/datasets\n    data_type: videos # [images|videos|features]\n    build_info:\n      # Be careful not to append minus sign (-) before split to avoid itemizing\n      annotations:\n        train:\n          url: /nas-ssd/shoubin/datasets/star/train.json\n          storage: /nas-ssd/shoubin/datasets/star/train.json\n        val:\n          url: /nas-ssd/shoubin/datasets/star/val.json\n          storage: /nas-ssd/shoubin/datasets/star/val.json\n        test:\n          url: /nas-ssd/shoubin/datasets/star/val.json\n          storage: /nas-ssd/shoubin/datasets/star/val.json\n      videos:\n        storage: /nas-hdd/shoubin/videos/charades/Charades_v1_480/"
  },
  {
    "path": "lavis/configs/datasets/tvqa/defaults_qa.yaml",
    "content": "datasets:\n  tvqa: # name of the dataset builder\n    # data_dir: ${env.data_dir}/datasets\n    data_type: videos # [images|videos|features]\n    build_info:\n      # Be careful not to append minus sign (-) before split to avoid itemizing\n      annotations:\n        train:\n          url: /nas-ssd/shoubin/datasets/tvqa/train.json\n          storage: /nas-ssd/shoubin/datasets/tvqa/train.json\n        val:\n          url: /nas-ssd/shoubin/datasets/tvqa/val.json\n          storage: /nas-ssd/shoubin/datasets/tvqa/val.json\n        test:\n          url: /nas-ssd/shoubin/datasets/tvqa/val.json\n          storage: /nas-ssd/shoubin/datasets/tvqa/val.json\n      videos:\n        storage: /nas-hdd/shoubin/videos/tvqa/videos_3fps_with_audio/"
  },
  {
    "path": "lavis/configs/datasets/vatex/defaults_cap.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\ndatasets:\n  msvd_cap: # name of the dataset builder\n    # data_dir: ${env.data_dir}/datasets\n    data_type: videos # [images|videos|features]\n\n    build_info:\n      # Be careful not to append minus sign (-) before split to avoid itemizing\n      annotations:\n        train:\n          url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vatex/cap_train.json\n          storage: vatex/annotations/cap_train.json\n        val:\n          url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vatex/cap_val.json\n          storage: vatex/annotations/cap_val.json\n        test:\n          url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vatex/cap_private_test.json\n          storage: vatex/annotations/cap_test.json\n      videos:\n        storage: /export/share/dongxuli/data/vatex\n"
  },
  {
    "path": "lavis/configs/datasets/vg/defaults_caption.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\ndatasets:\n  vg_caption:\n    # data_dir: ${env.data_dir}/datasets\n    data_type: images # [images|videos|features]\n\n    build_info:\n      # Be careful not to append minus sign (-) before split to avoid itemizing\n      annotations:\n        train:\n          url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/visual_genome/vg_caption.json\n          storage: vg/annotations/vg_caption.json\n      images:\n        storage: vg/images/\n"
  },
  {
    "path": "lavis/configs/datasets/vg/defaults_vqa.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\ndatasets:\n  vg_vqa:\n    # data_dir: ${env.data_dir}/datasets\n    data_type: images # [images|videos|features]\n\n    build_info:\n      # Be careful not to append minus sign (-) before split to avoid itemizing\n      annotations:\n        train:\n          url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/visual_genome/vg_qa.json\n          storage: vg/annotations/vg_qa.json\n      images:\n        storage: vg/images/\n"
  },
  {
    "path": "lavis/configs/datasets/vlep/defaults_qa.yaml",
    "content": "datasets:\n  vlep: # name of the dataset builder\n    # data_dir: ${env.data_dir}/datasets\n    data_type: videos # [images|videos|features]\n    build_info:\n      # Be careful not to append minus sign (-) before split to avoid itemizing\n      annotations:\n        train:\n          url: /nas-ssd/shoubin/datasets/vlep/train.json\n          storage: /nas-ssd/shoubin/datasets/vlep/train.json\n        val:\n          url: /nas-ssd/shoubin/datasets/vlep/val.json\n          storage: /nas-ssd/shoubin/datasets/vlep/val.json\n        test:\n          url: /nas-ssd/shoubin/datasets/vlep/val.json\n          storage: /nas-ssd/shoubin/datasets/vlep/val.json\n      videos:\n        storage: /nas-hdd/shoubin/videos/charades/Charades_v1_480/"
  },
  {
    "path": "lavis/configs/default.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nenv:\n  # For default users\n  # cache_root: \"cache\"\n  # For internal use with persistent storage\n  cache_root: \"/nas-hdd/shoubin/pretrained_model/\"\n"
  },
  {
    "path": "lavis/configs/models/albef_classification_ve.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: albef_classification\n  load_finetuned: True\n\n  finetuned: \"https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/ALBEF/albef_snli_ve_lavis.pt\"\n  pretrained: \"https://storage.googleapis.com/sfr-pcl-data-research/ALBEF/ALBEF.pth\"\n\n  num_classes: 3\n\n  use_distill: True\n  momentum: 0.995\n  alpha: 0.4\n\n  # vit encoder\n  vit_type: \"base\"\n  vit_grad_ckpt: False\n  vit_ckpt_layer: 0\n  vit_layer_norm_epsilon: 1e-6\n\n  image_size: 384\n\n  # bert config\n  med_config_path: \"configs/models/med_config_albef.json\"\n\npreprocess:\n  vis_processor:\n      train:\n        name: \"blip_image_train\"\n      eval:\n        name: \"blip_image_eval\"\n  text_processor:\n      train:\n        name: \"blip_caption\"\n      eval:\n        name: \"blip_caption\"\n"
  },
  {
    "path": "lavis/configs/models/albef_feature_extractor.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: albef_pretrain\n  pretrained: \"https://storage.googleapis.com/sfr-pcl-data-research/ALBEF/ALBEF.pth\"\n\n  # vit encoder\n  vit_type: \"base\"\n  image_size: 224\n  vit_ckpt_layer: 0\n  vit_drop_path_rate: 0\n  vit_layer_norm_epsilon: 1e-6\n  vit_grad_ckpt: False\n\n  # bert config\n  med_config_path: \"configs/models/med_config_albef.json\"\n\n  embed_dim: 256\n\npreprocess:\n  vis_processor:\n      eval:\n        name: \"blip_image_eval\"\n        image_size: 224\n  text_processor:\n      eval:\n        name: \"blip_caption\"\n"
  },
  {
    "path": "lavis/configs/models/albef_nlvr.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: albef_nlvr\n  load_finetuned: True\n\n  pretrained: \"https://storage.googleapis.com/sfr-pcl-data-research/ALBEF/pretrain_model_nlvr.pth\"\n  finetuned: \"https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/ALBEF/albef_nlvr_lavis.pt\"\n\n  num_classes: 2\n\n  use_distill: True\n  momentum: 0.995\n  alpha: 0.4\n\n  # vit encoder\n  vit_type: \"base\"\n  vit_grad_ckpt: False\n  vit_ckpt_layer: 0\n  vit_layer_norm_epsilon: 1e-6\n\n  image_size: 384\n\n  # bert config\n  med_config_path: \"configs/models/med_config_albef.json\"\n\npreprocess:\n  vis_processor:\n      train:\n        name: \"blip_image_train\"\n        image_size: 384\n      eval:\n        name: \"blip_image_eval\"\n        image_size: 384\n  text_processor:\n      train:\n        name: \"blip_caption\"\n      eval:\n        name: \"blip_caption\"\n"
  },
  {
    "path": "lavis/configs/models/albef_pretrain_base.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: albef_pretrain\n\n  load_pretrained: True\n  pretrained: \"https://storage.googleapis.com/sfr-pcl-data-research/ALBEF/ALBEF.pth\"\n\n  # vit encoder\n  vit_type: \"base\"\n  image_size: 224\n  vit_ckpt_layer: 0\n  vit_drop_path_rate: 0\n  vit_layer_norm_epsilon: 1e-6\n  vit_grad_ckpt: False\n\n  # bert config\n  med_config_path: \"configs/models/med_config_albef.json\"\n  mlm_mask_prob: 0.15\n\n  embed_dim: 256\n  momentum: 0.995\n  alpha: 0.4\n  temp: 0.07\n\n  max_txt_len: 30\n\npreprocess:\n    vis_processor:\n        train:\n          name: \"blip_image_train\"\n          image_size: 256\n    text_processor:\n        train:\n          name: \"blip_caption\"\n"
  },
  {
    "path": "lavis/configs/models/albef_retrieval_coco.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: albef_retrieval\n  load_finetuned: True\n\n  pretrained: \"https://storage.googleapis.com/sfr-pcl-data-research/ALBEF/ALBEF.pth\"\n  finetuned: \"https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/ALBEF/albef_coco_retrieval_lavis.pt\"\n\n  queue_size: 65536\n\n  # vit encoder\n  vit_type: \"base\"\n  image_size: 384\n  vit_ckpt_layer: 0\n  vit_drop_path_rate: 0\n  vit_layer_norm_epsilon: 1e-6\n  vit_grad_ckpt: False\n\n  # bert config\n  med_config_path: \"configs/models/med_config_albef.json\"\n\n  embed_dim: 256\n  momentum: 0.995\n  alpha: 0.4\n  temp: 0.07\n  use_distill: True\n\n  max_txt_len: 30\n\npreprocess:\n  vis_processor:\n      train:\n        name: \"blip_image_train\"\n        image_size: 384\n      eval:\n        name: \"blip_image_eval\"\n        image_size: 384\n  text_processor:\n      train:\n        name: \"blip_caption\"\n      eval:\n        name: \"blip_caption\"\n"
  },
  {
    "path": "lavis/configs/models/albef_retrieval_flickr.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: albef_retrieval\n  load_finetuned: True\n\n  pretrained: \"https://storage.googleapis.com/sfr-pcl-data-research/ALBEF/ALBEF.pth\"\n  finetuned: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/ALBEF/albef_flickr_retrieval_lavis.pt\n\n  queue_size: 65536\n\n  # vit encoder\n  vit_type: \"base\"\n  image_size: 384\n  vit_ckpt_layer: 0\n  vit_drop_path_rate: 0\n  vit_layer_norm_epsilon: 1e-6\n  vit_grad_ckpt: False\n\n  # bert config\n  med_config_path: \"configs/models/med_config_albef.json\"\n\n  embed_dim: 256\n  momentum: 0.995\n  alpha: 0.4\n  temp: 0.07\n  use_distill: True\n\n  max_txt_len: 30\n\npreprocess:\n  vis_processor:\n      train:\n        name: \"blip_image_train\"\n        image_size: 384\n      eval:\n        name: \"blip_image_eval\"\n        image_size: 384\n  text_processor:\n      train:\n        name: \"blip_caption\"\n      eval:\n        name: \"blip_caption\"\n"
  },
  {
    "path": "lavis/configs/models/albef_vqav2.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: albef_vqa\n  load_finetuned: True\n\n  pretrained: \"https://storage.googleapis.com/sfr-pcl-data-research/ALBEF/ALBEF.pth\"\n  finetuned: \"https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/ALBEF/albef_vqav2_lavis.pt\"\n\n  use_distill: True\n  momentum: 0.995\n  alpha: 0.4\n\n  # vit encoder\n  vit_type: \"base\"\n  vit_grad_ckpt: False\n  vit_ckpt_layer: 0\n  vit_layer_norm_epsilon: 1e-6\n\n  image_size: 384\n\n  # bert config\n  med_config_path: \"configs/models/med_config_albef.json\"\n\npreprocess:\n  vis_processor:\n      train:\n        name: \"blip_image_train\"\n        image_size: 384\n      eval:\n        name: \"blip_image_eval\"\n        image_size: 384\n  text_processor:\n      train:\n        name: \"blip_question\"\n      eval:\n        name: \"blip_question\"\n"
  },
  {
    "path": "lavis/configs/models/alpro_qa_msrvtt.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: alpro_qa\n  num_classes: 1500\n\n  load_finetuned: True\n\n  finetuned: \"https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/ALPRO/alpro_msrvtt_qa.pth\"\n  pretrained: \"https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/ALPRO/alpro_pretrain.pt\"\n\n  timesformer:\n    n_frms: 16\n    image_size: 224\n\n    patch_size: 16\n    attn_drop_rate: 0.\n    drop_rate: 0.\n    drop_path_rate: 0.1\n\n    use_grad_ckpt: True\n    ckpt_layer: 12\n\n  # bert config\n  med_config_path: \"configs/models/bert_config_alpro.json\"\n\npreprocess:\n  vis_processor:\n      train:\n        name: \"alpro_video_train\"\n        n_frms: 16\n        image_size: 224\n      eval:\n        name: \"alpro_video_eval\"\n        n_frms: 16\n        image_size: 224\n  text_processor:\n      train:\n        name: \"blip_caption\"\n      eval:\n        name: \"blip_caption\"\n"
  },
  {
    "path": "lavis/configs/models/alpro_qa_msvd.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: alpro_qa\n  num_classes: 2423\n\n  load_finetuned: True\n\n  finetuned: \"https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/ALPRO/alpro_msvd_qa.pth\"\n  pretrained: \"https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/ALPRO/alpro_pretrain.pt\"\n\n  timesformer:\n    n_frms: 16\n    image_size: 224\n\n    patch_size: 16\n    attn_drop_rate: 0.\n    drop_rate: 0.\n    drop_path_rate: 0.1\n    use_grad_ckpt: True\n    ckpt_layer: 12\n\n  # bert config\n  med_config_path: \"configs/models/bert_config_alpro.json\"\n\npreprocess:\n  vis_processor:\n      train:\n        name: \"alpro_video_train\"\n        n_frms: 16\n        image_size: 224\n      eval:\n        name: \"alpro_video_eval\"\n        n_frms: 16\n        image_size: 224\n  text_processor:\n      train:\n        name: \"blip_caption\"\n      eval:\n        name: \"blip_caption\"\n"
  },
  {
    "path": "lavis/configs/models/alpro_retrieval_didemo.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: alpro_retrieval\n\n  load_finetuned: True\n\n  finetuned: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/ALPRO/alpro_didemo_retrieval.pt\n  pretrained: \"https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/ALPRO/alpro_pretrain.pt\"\n\n  timesformer:\n    n_frms: 8\n    image_size: 224\n\n    patch_size: 16\n    attn_drop_rate: 0.\n    drop_rate: 0.\n    drop_path_rate: 0.1\n    use_grad_ckpt: False\n\n  # bert config\n  med_config_path: \"configs/models/bert_config_alpro.json\"\n\npreprocess:\n  vis_processor:\n      eval:\n        name: \"alpro_video_eval\"\n        n_frms: 8\n        image_size: 224\n  text_processor:\n      eval:\n        name: \"blip_caption\"\n"
  },
  {
    "path": "lavis/configs/models/alpro_retrieval_msrvtt.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: alpro_retrieval\n\n  load_finetuned: True\n\n  finetuned: \"https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/ALPRO/alpro_msrvtt_retrieval.pt\"\n  pretrained: \"https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/ALPRO/alpro_pretrain.pt\"\n\n  timesformer:\n    n_frms: 8\n    image_size: 224\n\n    patch_size: 16\n    attn_drop_rate: 0.\n    drop_rate: 0.\n    drop_path_rate: 0.1\n    use_grad_ckpt: False\n\n  # bert config\n  med_config_path: \"configs/models/bert_config_alpro.json\"\n\npreprocess:\n  vis_processor:\n      train:\n        name: \"alpro_video_train\"\n        n_frms: 8\n        image_size: 224\n      eval:\n        name: \"alpro_video_eval\"\n        n_frms: 8\n        image_size: 224\n  text_processor:\n      train:\n        name: \"blip_caption\"\n      eval:\n        name: \"blip_caption\"\n"
  },
  {
    "path": "lavis/configs/models/bert_config.json",
    "content": "{\n  \"architectures\": [\n    \"BertModel\"\n  ],\n  \"attention_probs_dropout_prob\": 0.1,\n  \"hidden_act\": \"gelu\",\n  \"hidden_dropout_prob\": 0.1,\n  \"hidden_size\": 768,\n  \"initializer_range\": 0.02,\n  \"intermediate_size\": 3072,\n  \"layer_norm_eps\": 1e-12,\n  \"max_position_embeddings\": 512,\n  \"model_type\": \"bert\",\n  \"num_attention_heads\": 12,\n  \"num_hidden_layers\": 12,\n  \"pad_token_id\": 0,\n  \"add_type_embeddings\": false,\n  \"vocab_size\": 30522,\n  \"encoder_width\": 768,\n  \"add_cross_attention\": true\n}"
  },
  {
    "path": "lavis/configs/models/bert_config_alpro.json",
    "content": "{\n  \"architectures\": [\n    \"BertModel\"\n  ],\n  \"attention_probs_dropout_prob\": 0.1,\n  \"hidden_act\": \"gelu\",\n  \"hidden_dropout_prob\": 0.1,\n  \"hidden_size\": 768,\n  \"initializer_range\": 0.02,\n  \"intermediate_size\": 3072,\n  \"layer_norm_eps\": 1e-12,\n  \"max_position_embeddings\": 512,\n  \"model_type\": \"bert\",\n  \"num_attention_heads\": 12,\n  \"num_hidden_layers\": 12,\n  \"pad_token_id\": 0,\n  \"add_type_embeddings\": true,\n  \"type_vocab_size\": 2,\n  \"vocab_size\": 30522,\n  \"encoder_width\": 768,\n  \"add_cross_attention\": false,\n  \"fusion_layer\": 6\n}"
  },
  {
    "path": "lavis/configs/models/blip2/blip2_caption_flant5xl.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: caption_coco_flant5xl\n  load_finetuned: True\n\n  pretrained: \"https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xl.pth\"\n  finetuned: \"https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_caption_flant5xl.pth\"\n\n  # vit encoder\n  image_size: 364\n  drop_path_rate: 0\n  use_grad_checkpoint: False\n  vit_precision: \"fp32\"\n  freeze_vit: False\n\n  # Q-Former\n  num_query_token: 32\n\n  # T5\n  t5_model: \"google/flan-t5-xl\"\n\n  # generation configs\n  prompt: \"a photo of\"\n\n\npreprocess:\n    vis_processor:\n        train:\n          name: \"blip_image_train\"\n          image_size: 364\n        eval:\n          name: \"blip_image_eval\"\n          image_size: 364\n    text_processor:\n        train:\n          name: \"blip_caption\"\n        eval:\n          name: \"blip_caption\"\n"
  },
  {
    "path": "lavis/configs/models/blip2/blip2_caption_opt2.7b.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: caption_coco_opt2.7b\n  load_finetuned: True\n\n  pretrained: \"https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_opt2.7b.pth\"\n  finetuned: \"https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_caption_opt2.7b.pth\"\n\n  # vit encoder\n  image_size: 364\n  drop_path_rate: 0\n  use_grad_checkpoint: False\n  vit_precision: \"fp32\"\n  freeze_vit: False\n\n  # Q-Former\n  num_query_token: 32\n\n  # OPT\n  opt_model: \"facebook/opt-2.7b\"\n\n  # generation configs\n  prompt: \"a photo of\"\n\n\npreprocess:\n    vis_processor:\n        train:\n          name: \"blip_image_train\"\n          image_size: 364\n        eval:\n          name: \"blip_image_eval\"\n          image_size: 364\n    text_processor:\n        train:\n          name: \"blip_caption\"\n        eval:\n          name: \"blip_caption\"\n"
  },
  {
    "path": "lavis/configs/models/blip2/blip2_caption_opt6.7b.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: caption_coco_opt6.7b\n  load_finetuned: True\n\n  pretrained: \"https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_opt6.7b.pth\"\n  finetuned: \"https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_caption_opt6.7b.pth\"\n\n  # vit encoder\n  image_size: 364\n  drop_path_rate: 0\n  use_grad_checkpoint: False\n  vit_precision: \"fp32\"\n  freeze_vit: False\n\n  # Q-Former\n  num_query_token: 32\n\n  # OPT\n  opt_model: \"facebook/opt-6.7b\"\n\n  # generation configs\n  prompt: \"a photo of\"\n\n\npreprocess:\n    vis_processor:\n        train:\n          name: \"blip_image_train\"\n          image_size: 364\n        eval:\n          name: \"blip_image_eval\"\n          image_size: 364\n    text_processor:\n        train:\n          name: \"blip_caption\"\n        eval:\n          name: \"blip_caption\"\n"
  },
  {
    "path": "lavis/configs/models/blip2/blip2_coco.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: coco\n  load_finetuned: True\n\n  pretrained: \"https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained.pth\"\n  finetuned: \"https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_finetune_coco.pth\"\n\n  # vit encoder\n  image_size: 364\n  drop_path_rate: 0\n  use_grad_checkpoint: True\n  vit_precision: \"fp32\"\n  freeze_vit: False\n\n  # Q-Former\n  num_query_token: 32\n\n\npreprocess:\n    vis_processor:\n        train:\n          name: \"blip_image_train\"\n          image_size: 364\n        eval:\n          name: \"blip_image_eval\"\n          image_size: 364\n    text_processor:\n        train:\n          name: \"blip_caption\"\n        eval:\n          name: \"blip_caption\"\n"
  },
  {
    "path": "lavis/configs/models/blip2/blip2_pretrain.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: pretrain\n  load_finetuned: False\n\n  pretrained: \"https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained.pth\"\n  finetuned: \"\"\n\n  # vit encoder\n  image_size: 224\n  drop_path_rate: 0\n  use_grad_checkpoint: False\n  vit_precision: \"fp16\"\n  freeze_vit: True\n\n  # Q-Former\n  num_query_token: 32\n\n\npreprocess:\n    vis_processor:\n        train:\n          name: \"blip_image_train\"\n          image_size: 224\n        eval:\n          name: \"blip_image_eval\"\n          image_size: 224\n    text_processor:\n        train:\n          name: \"blip_caption\"\n        eval:\n          name: \"blip_caption\"\n"
  },
  {
    "path": "lavis/configs/models/blip2/blip2_pretrain_flant5xl.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: pretrain_flant5xl\n  load_finetuned: False\n\n  pretrained: \"https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xl.pth\"\n  finetuned: \"\"\n\n  # vit encoder\n  image_size: 224\n  drop_path_rate: 0\n  use_grad_checkpoint: False\n  vit_precision: \"fp16\"\n  freeze_vit: True\n\n  # Q-Former\n  num_query_token: 32\n\n  # T5\n  t5_model: \"google/flan-t5-xl\"\n\n  # generation configs\n  prompt: \"\"\n\n\npreprocess:\n    vis_processor:\n        train:\n          name: \"blip_image_train\"\n          image_size: 224\n        eval:\n          name: \"blip_image_eval\"\n          image_size: 224\n    text_processor:\n        train:\n          name: \"blip_caption\"\n        eval:\n          name: \"blip_caption\"\n"
  },
  {
    "path": "lavis/configs/models/blip2/blip2_pretrain_flant5xxl.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: pretrain_flant5xxl\n  load_finetuned: False\n\n  pretrained: \"https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth\"\n  finetuned: \"\"\n\n  # vit encoder\n  image_size: 224\n  drop_path_rate: 0\n  use_grad_checkpoint: False\n  vit_precision: \"fp16\"\n  freeze_vit: True\n\n  # Q-Former\n  num_query_token: 32\n\n  # T5\n  t5_model: \"google/flan-t5-xxl\"\n\n  # generation configs\n  prompt: \"\"\n\n\npreprocess:\n    vis_processor:\n        train:\n          name: \"blip_image_train\"\n          image_size: 224\n        eval:\n          name: \"blip_image_eval\"\n          image_size: 224\n    text_processor:\n        train:\n          name: \"blip_caption\"\n        eval:\n          name: \"blip_caption\"\n"
  },
  {
    "path": "lavis/configs/models/blip2/blip2_pretrain_opt2.7b.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: pretrain_opt2.7b\n  load_finetuned: False\n\n  pretrained: \"https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_opt2.7b.pth\"\n  finetuned: \"\"\n\n  # vit encoder\n  image_size: 224\n  drop_path_rate: 0\n  use_grad_checkpoint: False\n  vit_precision: \"fp16\"\n  freeze_vit: True\n\n  # Q-Former\n  num_query_token: 32\n\n  # OPT\n  opt_model: \"facebook/opt-2.7b\"\n\n  # generation configs\n  prompt: \"\"\n\n\npreprocess:\n    vis_processor:\n        train:\n          name: \"blip_image_train\"\n          image_size: 224\n        eval:\n          name: \"blip_image_eval\"\n          image_size: 224\n    text_processor:\n        train:\n          name: \"blip_caption\"\n        eval:\n          name: \"blip_caption\"\n"
  },
  {
    "path": "lavis/configs/models/blip2/blip2_pretrain_opt6.7b.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: pretrain_opt6.7b\n  load_finetuned: False\n\n  pretrained: \"https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_opt6.7b.pth\"\n  finetuned: \"\"\n\n  # vit encoder\n  image_size: 224\n  drop_path_rate: 0\n  use_grad_checkpoint: False\n  vit_precision: \"fp16\"\n  freeze_vit: True\n\n  # Q-Former\n  num_query_token: 32\n\n  # OPT\n  opt_model: \"facebook/opt-6.7b\"\n\n  # generation configs\n  prompt: \"\"\n\n\npreprocess:\n    vis_processor:\n        train:\n          name: \"blip_image_train\"\n          image_size: 224\n        eval:\n          name: \"blip_image_eval\"\n          image_size: 224\n    text_processor:\n        train:\n          name: \"blip_caption\"\n        eval:\n          name: \"blip_caption\"\n"
  },
  {
    "path": "lavis/configs/models/blip_caption_base_coco.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: blip_caption\n  load_finetuned: True\n\n  pretrained: \"https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth\"\n  finetuned: \"https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP/blip_coco_caption_base.pth\"\n\n  # vit encoder\n  vit_type: \"base\"\n  vit_grad_ckpt: False\n  vit_ckpt_layer: 0\n\n  image_size: 384\n\n  # bert config\n  med_config_path: \"configs/models/med_config.json\"\n\n  # generation configs\n  prompt: \"a picture of \"\n\n\npreprocess:\n    vis_processor:\n        train:\n          name: \"blip_image_train\"\n        eval:\n          name: \"blip_image_eval\"\n    text_processor:\n        train:\n          name: \"blip_caption\"\n          prompt: \"a picture of \"\n        eval:\n          name: \"blip_caption\"\n"
  },
  {
    "path": "lavis/configs/models/blip_caption_large_coco.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: blip_caption\n  load_finetuned: True\n\n  pretrained: \"https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large.pth\"\n  finetuned: \"https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth\"\n\n  vit_type: \"large\"\n  vit_grad_ckpt: True\n  vit_ckpt_layer: 5\n\n  image_size: 384\n\n  # bert config\n  med_config_path: \"configs/models/med_large_config.json\"\n\n  # generation configs\n  prompt: \"a picture of \"\n\n\npreprocess:\n    vis_processor:\n        train:\n          name: \"blip_image_train\"\n        eval:\n          name: \"blip_image_eval\"\n    text_processor:\n        train:\n          name: \"blip_caption\"\n          prompt: \"a picture of \"\n        eval:\n          name: \"blip_caption\"\n"
  },
  {
    "path": "lavis/configs/models/blip_classification_base.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: blip_classification\n  pretrained: \"https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth\"\n\n  use_distill: True\n  momentum: 0.995\n  alpha: 0.4\n\n  # vit encoder\n  vit_type: \"base\"\n  vit_grad_ckpt: False\n  vit_ckpt_layer: 0\n\n  image_size: 384\n\n  # bert config\n  med_config_path: \"configs/models/med_config.json\"\n"
  },
  {
    "path": "lavis/configs/models/blip_feature_extractor_base.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: blip_pretrain\n  pretrained: \"https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth\"\n\n  # vit encoder\n  vit_type: \"base\"\n  vit_grad_ckpt: False\n  vit_ckpt_layer: 0\n\n  image_size: 224\n\n  # bert config\n  med_config_path: \"configs/models/med_config.json\"\n\n  embed_dim: 256\n\npreprocess:\n  vis_processor:\n      eval:\n        name: \"blip_image_eval\"\n        image_size: 224\n  text_processor:\n      eval:\n        name: \"blip_caption\"\n"
  },
  {
    "path": "lavis/configs/models/blip_itm_base.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: blip_image_text_matching\n\n  load_finetuned: True\n  finetuned: \"https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth\"\n\n  # vit encoder\n  vit_type: \"base\"\n  vit_grad_ckpt: False\n  vit_ckpt_layer: 0\n\n  image_size: 384\n\n  # bert config\n  med_config_path: \"configs/models/med_config.json\"\n\n  embed_dim: 256\n\npreprocess:\n    vis_processor:\n        eval:\n          name: \"blip_image_eval\"\n          image_size: 384\n    text_processor:\n        eval:\n          name: \"blip_caption\"\n"
  },
  {
    "path": "lavis/configs/models/blip_itm_large.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: blip_image_text_matching\n\n  load_finetuned: True\n  finetuned: \"https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_retrieval_coco.pth\"\n\n  # vit encoder\n  vit_type: \"large\"\n  vit_grad_ckpt: False\n  vit_ckpt_layer: 0\n\n  image_size: 384\n\n  # bert config\n  med_config_path: \"configs/models/med_large_config.json\"\n\n  embed_dim: 256\n\npreprocess:\n    vis_processor:\n        eval:\n          name: \"blip_image_eval\"\n          image_size: 384\n    text_processor:\n        eval:\n          name: \"blip_caption\"\n"
  },
  {
    "path": "lavis/configs/models/blip_nlvr.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: blip_nlvr\n  model_type: nlvr\n  load_finetuned: True\n\n  finetuned: \"https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_nlvr.pth\"\n  pretrained: \"https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth\"\n\n  num_classes: 2\n\n  # vit encoder\n  vit_type: \"base\"\n  vit_grad_ckpt: False\n  vit_ckpt_layer: 0\n  vit_layer_norm_epsilon: 1e-6\n\n  image_size: 384\n\n  # bert config\n  med_config_path: \"configs/models/med_config.json\"\n\npreprocess:\n  vis_processor:\n      train:\n        name: \"blip_image_train\"\n        image_size: 384\n      eval:\n        name: \"blip_image_eval\"\n        image_size: 384\n  text_processor:\n      train:\n        name: \"blip_caption\"\n      eval:\n        name: \"blip_caption\"\n"
  },
  {
    "path": "lavis/configs/models/blip_pretrain_base.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: blip_pretrain\n\n  load_pretrained: True\n  pretrained: \"https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth\"\n\n  # vit encoder\n  vit_type: \"base\"\n  vit_grad_ckpt: False\n  vit_ckpt_layer: 0\n\n  image_size: 224\n  alpha: 0.4\n\n  # bert config\n  med_config_path: \"configs/models/bert_config.json\"\n\n  embed_dim: 256\n\n  # generation configs\n  prompt: \"a picture of \"\n\npreprocess:\n    vis_processor:\n        train:\n          name: \"blip_image_train\"\n          image_size: 224\n    text_processor:\n        train:\n          name: \"blip_caption\"\n"
  },
  {
    "path": "lavis/configs/models/blip_pretrain_large.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: blip_pretrain\n\n  # vit encoder\n  vit_type: \"large\"\n  vit_grad_ckpt: True\n  vit_ckpt_layer: 5\n\n  image_size: 224\n\n  # bert config\n  med_config_path: \"configs/models/med_large_config.json\"\n\n  embed_dim: 256\n\n  # generation configs\n  prompt: \"a picture of \"\n"
  },
  {
    "path": "lavis/configs/models/blip_retrieval_coco.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: blip_retrieval\n  load_finetuned: True\n\n  finetuned: \"https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP/blip_coco_retrieval.pth\"\n  pretrained: \"https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth\"\n\n  queue_size: 57600\n\n  # vit encoder\n  vit_type: \"base\"\n  vit_grad_ckpt: True\n  vit_ckpt_layer: 4\n\n  image_size: 384\n\n  # bert config\n  med_config_path: \"configs/models/med_config.json\"\n\n  embed_dim: 256\n\npreprocess:\n  vis_processor:\n      train:\n        name: \"blip_image_train\"\n        image_size: 384\n      eval:\n        name: \"blip_image_eval\"\n        image_size: 384\n  text_processor:\n      train:\n        name: \"blip_caption\"\n      eval:\n        name: \"blip_caption\"\n"
  },
  {
    "path": "lavis/configs/models/blip_retrieval_flickr.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: blip_retrieval\n  load_finetuned: True\n\n  finetuned: \"https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP/blip_flickr_retrieval.pth\"\n  pretrained: \"https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth\"\n\n  queue_size: 57600\n  alpha: 0.4\n\n  negative_all_rank: False\n\n  # vit encoder\n  vit_type: \"base\"\n  vit_grad_ckpt: True\n  vit_ckpt_layer: 4\n\n  image_size: 384\n\n  # bert config\n  med_config_path: \"configs/models/med_config.json\"\n\n  embed_dim: 256\n\npreprocess:\n  vis_processor:\n      train:\n        name: \"blip_image_train\"\n        image_size: 384\n      eval:\n        name: \"blip_image_eval\"\n        image_size: 384\n  text_processor:\n      train:\n        name: \"blip_caption\"\n      eval:\n        name: \"blip_caption\"\n"
  },
  {
    "path": "lavis/configs/models/blip_vqa_aokvqa.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: blip_vqa\n  load_finetuned: True\n\n  finetuned: \"https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP/blip_aokvqa.pth\"\n  pretrained: \"https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_vqa_capfilt_large.pth\"\n\n  # vit encoder\n  vit_type: \"base\"\n  vit_grad_ckpt: False\n  vit_ckpt_layer: 0\n  vit_drop_path_rate: 0.1\n\n  image_size: 480\n\n  # bert config\n  med_config_path: \"configs/models/med_config.json\"\n\npreprocess:\n    vis_processor:\n        train:\n          name: \"blip_image_train\"\n          image_size: 480\n        eval:\n          name: \"blip_image_eval\"\n          image_size: 480\n    text_processor:\n        train:\n          name: \"blip_question\"\n        eval:\n          name: \"blip_question\"\n"
  },
  {
    "path": "lavis/configs/models/blip_vqa_okvqa.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: blip_vqa\n  load_finetuned: True\n\n  finetuned: \"https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP/blip_okvqa.pth\"\n  pretrained: \"https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_vqa_capfilt_large.pth\"\n\n  # vit encoder\n  vit_type: \"base\"\n  vit_grad_ckpt: False\n  vit_ckpt_layer: 0\n  vit_drop_path_rate: 0.1\n\n  image_size: 480\n\n  # bert config\n  med_config_path: \"configs/models/med_config.json\"\n\npreprocess:\n    vis_processor:\n        train:\n          name: \"blip_image_train\"\n          image_size: 480\n        eval:\n          name: \"blip_image_eval\"\n          image_size: 480\n    text_processor:\n        train:\n          name: \"blip_question\"\n        eval:\n          name: \"blip_question\"\n"
  },
  {
    "path": "lavis/configs/models/blip_vqav2.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: blip_vqa\n  load_finetuned: True\n\n  finetuned: \"https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_vqa_capfilt_large.pth\"\n  pretrained: \"https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth\"\n\n  # vit encoder\n  vit_type: \"base\"\n  vit_grad_ckpt: False\n  vit_ckpt_layer: 0\n  vit_drop_path_rate: 0.1\n\n  image_size: 480\n\n  # bert config\n  med_config_path: \"configs/models/med_config.json\"\n\npreprocess:\n    vis_processor:\n        train:\n          name: \"blip_image_train\"\n          image_size: 480\n        eval:\n          name: \"blip_image_eval\"\n          image_size: 480\n    text_processor:\n        train:\n          name: \"blip_question\"\n        eval:\n          name: \"blip_question\"\n"
  },
  {
    "path": "lavis/configs/models/clip/RN101-quickgelu.json",
    "content": "{\n    \"embed_dim\": 512,\n    \"quick_gelu\": true,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": [\n            3,\n            4,\n            23,\n            3\n        ],\n        \"width\": 64,\n        \"patch_size\": null\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 512,\n        \"heads\": 8,\n        \"layers\": 12\n    }\n}\n"
  },
  {
    "path": "lavis/configs/models/clip/RN101.json",
    "content": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": [\n            3,\n            4,\n            23,\n            3\n        ],\n        \"width\": 64,\n        \"patch_size\": null\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 512,\n        \"heads\": 8,\n        \"layers\": 12\n    }\n}\n"
  },
  {
    "path": "lavis/configs/models/clip/RN50-quickgelu.json",
    "content": "{\n    \"embed_dim\": 1024,\n    \"quick_gelu\": true,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": [\n            3,\n            4,\n            6,\n            3\n        ],\n        \"width\": 64,\n        \"patch_size\": null\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 512,\n        \"heads\": 8,\n        \"layers\": 12\n    }\n}\n"
  },
  {
    "path": "lavis/configs/models/clip/RN50.json",
    "content": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": [\n            3,\n            4,\n            6,\n            3\n        ],\n        \"width\": 64,\n        \"patch_size\": null\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 512,\n        \"heads\": 8,\n        \"layers\": 12\n    }\n}\n"
  },
  {
    "path": "lavis/configs/models/clip/RN50x16.json",
    "content": "{\n    \"embed_dim\": 768,\n    \"vision_cfg\": {\n        \"image_size\": 384,\n        \"layers\": [\n            6,\n            8,\n            18,\n            8\n        ],\n        \"width\": 96,\n        \"patch_size\": null\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 768,\n        \"heads\": 12,\n        \"layers\": 12\n    }\n}\n"
  },
  {
    "path": "lavis/configs/models/clip/RN50x4.json",
    "content": "{\n    \"embed_dim\": 640,\n    \"vision_cfg\": {\n        \"image_size\": 288,\n        \"layers\": [\n            4,\n            6,\n            10,\n            6\n        ],\n        \"width\": 80,\n        \"patch_size\": null\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 640,\n        \"heads\": 10,\n        \"layers\": 12\n    }\n}\n"
  },
  {
    "path": "lavis/configs/models/clip/ViT-B-16-plus-240.json",
    "content": "{\n    \"embed_dim\": 640,\n    \"vision_cfg\": {\n        \"image_size\": 240,\n        \"layers\": 12,\n        \"width\": 896,\n        \"patch_size\": 16\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 640,\n        \"heads\": 10,\n        \"layers\": 12\n    }\n}\n"
  },
  {
    "path": "lavis/configs/models/clip/ViT-B-16-plus.json",
    "content": "{\n    \"embed_dim\": 640,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n        \"width\": 896,\n        \"patch_size\": 16\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 640,\n        \"heads\": 10,\n        \"layers\": 12\n    }\n}\n"
  },
  {
    "path": "lavis/configs/models/clip/ViT-B-16.json",
    "content": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n        \"width\": 768,\n        \"patch_size\": 16\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 512,\n        \"heads\": 8,\n        \"layers\": 12\n    }\n}\n"
  },
  {
    "path": "lavis/configs/models/clip/ViT-B-32-plus-256.json",
    "content": "{\n    \"embed_dim\": 640,\n    \"vision_cfg\": {\n        \"image_size\": 256,\n        \"layers\": 12,\n        \"width\": 896,\n        \"patch_size\": 32\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 640,\n        \"heads\": 10,\n        \"layers\": 12\n    }\n}\n"
  },
  {
    "path": "lavis/configs/models/clip/ViT-B-32-quickgelu.json",
    "content": "{\n    \"embed_dim\": 512,\n    \"quick_gelu\": true,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n        \"width\": 768,\n        \"patch_size\": 32\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 512,\n        \"heads\": 8,\n        \"layers\": 12\n    }\n}\n"
  },
  {
    "path": "lavis/configs/models/clip/ViT-B-32.json",
    "content": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n        \"width\": 768,\n        \"patch_size\": 32\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 512,\n        \"heads\": 8,\n        \"layers\": 12\n    }\n}\n"
  },
  {
    "path": "lavis/configs/models/clip/ViT-H-14.json",
    "content": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 32,\n        \"width\": 1280,\n        \"head_width\": 80,\n        \"patch_size\": 14\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 1024,\n        \"heads\": 16,\n        \"layers\": 24\n    }\n}\n"
  },
  {
    "path": "lavis/configs/models/clip/ViT-H-16.json",
    "content": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 32,\n        \"width\": 1280,\n        \"head_width\": 80,\n        \"patch_size\": 16\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 1024,\n        \"heads\": 16,\n        \"layers\": 24\n    }\n}\n"
  },
  {
    "path": "lavis/configs/models/clip/ViT-L-14-280.json",
    "content": "{\n    \"embed_dim\": 768,\n    \"vision_cfg\": {\n        \"image_size\": 280,\n        \"layers\": 24,\n        \"width\": 1024,\n        \"patch_size\": 14\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 768,\n        \"heads\": 12,\n        \"layers\": 12\n    }\n}\n"
  },
  {
    "path": "lavis/configs/models/clip/ViT-L-14-336.json",
    "content": "{\n    \"embed_dim\": 768,\n    \"vision_cfg\": {\n        \"image_size\": 336,\n        \"layers\": 24,\n        \"width\": 1024,\n        \"patch_size\": 14\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 768,\n        \"heads\": 12,\n        \"layers\": 12\n    }\n}\n"
  },
  {
    "path": "lavis/configs/models/clip/ViT-L-14.json",
    "content": "{\n    \"embed_dim\": 768,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 24,\n        \"width\": 1024,\n        \"patch_size\": 14\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 768,\n        \"heads\": 12,\n        \"layers\": 12\n    }\n}\n"
  },
  {
    "path": "lavis/configs/models/clip/ViT-L-16-320.json",
    "content": "{\n    \"embed_dim\": 768,\n    \"vision_cfg\": {\n        \"image_size\": 320,\n        \"layers\": 24,\n        \"width\": 1024,\n        \"patch_size\": 16\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 768,\n        \"heads\": 12,\n        \"layers\": 12\n    }\n}\n"
  },
  {
    "path": "lavis/configs/models/clip/ViT-L-16.json",
    "content": "{\n    \"embed_dim\": 768,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 24,\n        \"width\": 1024,\n        \"patch_size\": 16\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 768,\n        \"heads\": 12,\n        \"layers\": 12\n    }\n}\n"
  },
  {
    "path": "lavis/configs/models/clip/ViT-g-14.json",
    "content": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 40,\n        \"width\": 1408,\n        \"head_width\": 88,\n        \"mlp_ratio\": 4.3637,\n        \"patch_size\": 14\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 1024,\n        \"heads\": 16,\n        \"layers\": 24\n    }\n}\n"
  },
  {
    "path": "lavis/configs/models/clip/timm-efficientnetv2_rw_s.json",
    "content": "{\n    \"embed_dim\": 768,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"efficientnetv2_rw_s\",\n        \"timm_model_pretrained\": false,\n        \"timm_pool\": \"abs_attn\",\n        \"timm_proj\": \"\",\n        \"image_size\": 288\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 768,\n        \"heads\": 8,\n        \"layers\": 12\n    }\n}\n"
  },
  {
    "path": "lavis/configs/models/clip/timm-resnet50d.json",
    "content": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"resnet50d\",\n        \"timm_model_pretrained\": false,\n        \"timm_pool\": \"abs_attn\",\n        \"timm_proj\": \"\",\n        \"image_size\": 224\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 512,\n        \"heads\": 8,\n        \"layers\": 12\n    }\n}\n"
  },
  {
    "path": "lavis/configs/models/clip/timm-resnetaa50d.json",
    "content": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"resnetaa50d\",\n        \"timm_model_pretrained\": false,\n        \"timm_pool\": \"abs_attn\",\n        \"timm_proj\": \"\",\n        \"image_size\": 224\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 512,\n        \"heads\": 8,\n        \"layers\": 12\n    }\n}\n"
  },
  {
    "path": "lavis/configs/models/clip/timm-resnetblur50.json",
    "content": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"resnetblur50\",\n        \"timm_model_pretrained\": false,\n        \"timm_pool\": \"abs_attn\",\n        \"timm_proj\": \"\",\n        \"image_size\": 224\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 512,\n        \"heads\": 8,\n        \"layers\": 12\n    }\n}\n"
  },
  {
    "path": "lavis/configs/models/clip/timm-swin_base_patch4_window7_224.json",
    "content": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"swin_base_patch4_window7_224\",\n        \"timm_model_pretrained\": false,\n        \"timm_pool\": \"\",\n        \"timm_proj\": \"linear\",\n        \"image_size\": 224\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 512,\n        \"heads\": 8,\n        \"layers\": 12\n    }\n}\n"
  },
  {
    "path": "lavis/configs/models/clip/timm-vit_base_patch16_224.json",
    "content": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"vit_base_patch16_224\",\n        \"timm_model_pretrained\": false,\n        \"timm_pool\": \"\",\n        \"timm_proj\": \"linear\",\n        \"image_size\": 224\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 512,\n        \"heads\": 8,\n        \"layers\": 12\n    }\n}\n"
  },
  {
    "path": "lavis/configs/models/clip/timm-vit_base_patch32_224.json",
    "content": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"vit_base_patch32_224\",\n        \"timm_model_pretrained\": false,\n        \"timm_pool\": \"\",\n        \"timm_proj\": \"linear\",\n        \"image_size\": 224\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 512,\n        \"heads\": 8,\n        \"layers\": 12\n    }\n}\n"
  },
  {
    "path": "lavis/configs/models/clip/timm-vit_small_patch16_224.json",
    "content": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"vit_small_patch16_224\",\n        \"timm_model_pretrained\": false,\n        \"timm_pool\": \"\",\n        \"timm_proj\": \"linear\",\n        \"image_size\": 224\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 512,\n        \"heads\": 8,\n        \"layers\": 12\n    }\n}\n"
  },
  {
    "path": "lavis/configs/models/clip_resnet50.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: clip\n\n  model_type: RN50\n\n  pretrained: openai\n"
  },
  {
    "path": "lavis/configs/models/clip_vit_base16.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: clip\n\n  model_type: ViT-B-16\n\n  pretrained: openai\n\npreprocess:\n  vis_processor:\n      eval:\n        name: \"clip_image_eval\"\n        image_size: 224\n"
  },
  {
    "path": "lavis/configs/models/clip_vit_base32.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: clip\n\n  model_type: ViT-B-32\n#   ['RN50',\n#  'RN50-quickgelu',\n#  'RN50x4',\n#  'RN50x16',\n#  'RN101',\n#  'RN101-quickgelu',\n#  'timm-efficientnetv2_rw_s',\n#  'timm-resnet50d',\n#  'timm-resnetaa50d',\n#  'timm-resnetblur50',\n#  'timm-swin_base_patch4_window7_224',\n#  'timm-vit_base_patch16_224',\n#  'timm-vit_base_patch32_224',\n#  'timm-vit_small_patch16_224',\n#  'ViT-B-16',\n#  'ViT-B-16-plus',\n#  'ViT-B-16-plus-240',\n#  'ViT-B-32',\n#  'ViT-B-32-plus-256',\n#  'ViT-B-32-quickgelu',\n#  'ViT-g-14',\n#  'ViT-H-14',\n#  'ViT-H-16',\n#  'ViT-L-14',\n#  'ViT-L-14-280',\n#  'ViT-L-14-336',\n#  'ViT-L-16',\n#  'ViT-L-16-320']\n\n  pretrained: openai\n  # \"openai\"\n  # following not available for all models\n  # \"yfcc15m\"\n  # \"cc12m\"\n  # \"laion400m_e31\"\n  # \"laion400m_e32\"\n  # \"laion400m_avg\"\n\npreprocess:\n  vis_processor:\n      eval:\n        name: \"clip_image_eval\"\n        image_size: 224\n"
  },
  {
    "path": "lavis/configs/models/clip_vit_large14.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: clip\n\n  model_type: ViT-L-14\n#   ['RN50',\n#  'RN50-quickgelu',\n#  'RN50x4',\n#  'RN50x16',\n#  'RN101',\n#  'RN101-quickgelu',\n#  'timm-efficientnetv2_rw_s',\n#  'timm-resnet50d',\n#  'timm-resnetaa50d',\n#  'timm-resnetblur50',\n#  'timm-swin_base_patch4_window7_224',\n#  'timm-vit_base_patch16_224',\n#  'timm-vit_base_patch32_224',\n#  'timm-vit_small_patch16_224',\n#  'ViT-B-16',\n#  'ViT-B-16-plus',\n#  'ViT-B-16-plus-240',\n#  'ViT-B-32',\n#  'ViT-B-32-plus-256',\n#  'ViT-B-32-quickgelu',\n#  'ViT-g-14',\n#  'ViT-H-14',\n#  'ViT-H-16',\n#  'ViT-L-14',\n#  'ViT-L-14-280',\n#  'ViT-L-14-336',\n#  'ViT-L-16',\n#  'ViT-L-16-320']\n\n  pretrained: openai\n  # \"openai\"\n  # following not available for all models\n  # \"yfcc15m\"\n  # \"cc12m\"\n  # \"laion400m_e31\"\n  # \"laion400m_e32\"\n  # \"laion400m_avg\"\n\npreprocess:\n  vis_processor:\n      eval:\n        name: \"clip_image_eval\"\n        image_size: 224\n"
  },
  {
    "path": "lavis/configs/models/clip_vit_large14_336.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: clip\n\n  model_type: ViT-L-14-336\n#   ['RN50',\n#  'RN50-quickgelu',\n#  'RN50x4',\n#  'RN50x16',\n#  'RN101',\n#  'RN101-quickgelu',\n#  'timm-efficientnetv2_rw_s',\n#  'timm-resnet50d',\n#  'timm-resnetaa50d',\n#  'timm-resnetblur50',\n#  'timm-swin_base_patch4_window7_224',\n#  'timm-vit_base_patch16_224',\n#  'timm-vit_base_patch32_224',\n#  'timm-vit_small_patch16_224',\n#  'ViT-B-16',\n#  'ViT-B-16-plus',\n#  'ViT-B-16-plus-240',\n#  'ViT-B-32',\n#  'ViT-B-32-plus-256',\n#  'ViT-B-32-quickgelu',\n#  'ViT-g-14',\n#  'ViT-H-14',\n#  'ViT-H-16',\n#  'ViT-L-14',\n#  'ViT-L-14-280',\n#  'ViT-L-14-336',\n#  'ViT-L-16',\n#  'ViT-L-16-320']\n\n  pretrained: openai\n  # \"openai\"\n  # following not available for all models\n  # \"yfcc15m\"\n  # \"cc12m\"\n  # \"laion400m_e31\"\n  # \"laion400m_e32\"\n  # \"laion400m_avg\"\n\npreprocess:\n  vis_processor:\n      eval:\n        name: \"clip_image_eval\"\n        image_size: 336\n"
  },
  {
    "path": "lavis/configs/models/gpt_dialogue_base.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: gpt_dialogue\n  # pretrained: \"https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth\"\n  # pretrained: \"https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth\"\n\n  len_tokenizer: 50264 # 50257 tokens from gpt2 default tokenizer + additional special tokens \n  \n  len_video_ft: 4224 # i3d_rgb: 2048 i3d_flow: 2048 vggish: 128\n\npreprocess:\n    vis_processor:\n        train:\n          name: \"gpt_video_ft\"\n        eval:\n          name: \"gpt_video_ft\"\n    text_processor:\n        train:\n          name: \"gpt_dialogue\"\n        eval:\n          name: \"gpt_dialogue\""
  },
  {
    "path": "lavis/configs/models/img2prompt-vqa/img2prompt_vqa_base.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: img2prompt_vqa\n  model_type: base\n\n  image_question_matching_model:\n    arch: blip_image_text_matching\n    load_finetuned: True\n\n    finetuned: \"https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_retrieval_coco_train2014.pth\"\n\n    # vit encoder\n    vit_type: \"large\"\n    vit_grad_ckpt: False\n    vit_ckpt_layer: 0\n\n    image_size: 384\n\n    # bert config\n    med_config_path: \"configs/models/med_large_config.json\"\n\n    embed_dim: 256\n\n  image_captioning_model:\n    arch: blip_caption\n    load_finetuned: True\n\n    finetuned: \"https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption_coco_train2014.pth\"\n\n    vit_type: \"large\"\n    vit_grad_ckpt: True\n    vit_ckpt_layer: 5\n\n    image_size: 384\n\n    # bert config\n    med_config_path: \"configs/models/med_large_config.json\"\n\n    # generation configs\n    prompt: \"a picture of \"\n\n  question_generation_moodel:\n    pretrained: \"https://storage.googleapis.com/sfr-vision-language-research/LAVIS/projects/img2prompt/T5_large_QG.pth\"\n\n\n\npreprocess:\n  vis_processor:\n      eval:\n        name: \"blip_image_eval\"\n        image_size: 384\n  text_processor:\n      eval:\n        name: \"blip_caption\"\n"
  },
  {
    "path": "lavis/configs/models/med_config.json",
    "content": "{\n  \"architectures\": [\n    \"BertModel\"\n  ],\n  \"attention_probs_dropout_prob\": 0.1,\n  \"hidden_act\": \"gelu\",\n  \"hidden_dropout_prob\": 0.1,\n  \"hidden_size\": 768,\n  \"initializer_range\": 0.02,\n  \"intermediate_size\": 3072,\n  \"layer_norm_eps\": 1e-12,\n  \"max_position_embeddings\": 512,\n  \"model_type\": \"bert\",\n  \"num_attention_heads\": 12,\n  \"num_hidden_layers\": 12,\n  \"pad_token_id\": 0,\n  \"add_type_embeddings\": false,\n  \"vocab_size\": 30524,\n  \"encoder_width\": 768,\n  \"add_cross_attention\": true\n}"
  },
  {
    "path": "lavis/configs/models/med_config_albef.json",
    "content": "{\n  \"architectures\": [\n    \"BertModel\"\n  ],\n  \"attention_probs_dropout_prob\": 0.1,\n  \"hidden_act\": \"gelu\",\n  \"hidden_dropout_prob\": 0.1,\n  \"hidden_size\": 768,\n  \"initializer_range\": 0.02,\n  \"intermediate_size\": 3072,\n  \"layer_norm_eps\": 1e-12,\n  \"max_position_embeddings\": 512,\n  \"model_type\": \"bert\",\n  \"num_attention_heads\": 12,\n  \"num_hidden_layers\": 12,\n  \"pad_token_id\": 0,\n  \"add_type_embeddings\": false,\n  \"vocab_size\": 30522,\n  \"encoder_width\": 768,\n  \"add_cross_attention\": true,\n  \"fusion_layer\": 6\n}"
  },
  {
    "path": "lavis/configs/models/med_large_config.json",
    "content": "{\n  \"architectures\": [\n    \"BertModel\"\n  ],\n  \"attention_probs_dropout_prob\": 0.1,\n  \"hidden_act\": \"gelu\",\n  \"hidden_dropout_prob\": 0.1,\n  \"hidden_size\": 768,\n  \"initializer_range\": 0.02,\n  \"intermediate_size\": 3072,\n  \"layer_norm_eps\": 1e-12,\n  \"max_position_embeddings\": 512,\n  \"model_type\": \"bert\",\n  \"num_attention_heads\": 12,\n  \"num_hidden_layers\": 12,\n  \"pad_token_id\": 0,\n  \"add_type_embeddings\": false,\n  \"vocab_size\": 30524,\n  \"encoder_width\": 1024,\n  \"add_cross_attention\": true\n}"
  },
  {
    "path": "lavis/configs/models/pnp-vqa/pnp_vqa_3b.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: pnp_vqa\n  model_type: 3b\n\n  image_question_matching_model:\n    arch: blip_image_text_matching\n    load_finetuned: True\n\n    finetuned: \"https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_retrieval_coco_train2014.pth\"\n\n    # vit encoder\n    vit_type: \"large\"\n    vit_grad_ckpt: False\n    vit_ckpt_layer: 0\n\n    image_size: 384\n\n    # bert config\n    med_config_path: \"configs/models/med_large_config.json\"\n\n    embed_dim: 256\n\n  image_captioning_model:\n    arch: blip_caption\n    load_finetuned: True\n\n    finetuned: \"https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption_coco_train2014.pth\"\n\n    vit_type: \"large\"\n    vit_grad_ckpt: True\n    vit_ckpt_layer: 5\n\n    image_size: 384\n\n    # bert config\n    med_config_path: \"configs/models/med_large_config.json\"\n\n    # generation configs\n    prompt: \"a picture of \"\n\n  question_answering_model:\n    arch: pnp_unifiedqav2_fid\n\n    pretrained: \"allenai/unifiedqa-v2-t5-3b-1363200\"\n\n    t5_config_path: \"configs/models/pnp-vqa/unifiedqav2_3b_config.json\"\n\npreprocess:\n  vis_processor:\n      eval:\n        name: \"blip_image_eval\"\n        image_size: 384\n  text_processor:\n      eval:\n        name: \"blip_caption\"\n"
  },
  {
    "path": "lavis/configs/models/pnp-vqa/pnp_vqa_base.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: pnp_vqa\n  model_type: base\n\n  image_question_matching_model:\n    arch: blip_image_text_matching\n    load_finetuned: True\n\n    finetuned: \"https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_retrieval_coco_train2014.pth\"\n\n    # vit encoder\n    vit_type: \"large\"\n    vit_grad_ckpt: False\n    vit_ckpt_layer: 0\n\n    image_size: 384\n\n    # bert config\n    med_config_path: \"configs/models/med_large_config.json\"\n\n    embed_dim: 256\n\n  image_captioning_model:\n    arch: blip_caption\n    load_finetuned: True\n\n    finetuned: \"https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption_coco_train2014.pth\"\n\n    vit_type: \"large\"\n    vit_grad_ckpt: True\n    vit_ckpt_layer: 5\n\n    image_size: 384\n\n    # bert config\n    med_config_path: \"configs/models/med_large_config.json\"\n\n    # generation configs\n    prompt: \"a picture of \"\n  question_answering_model:\n    arch: pnp_unifiedqav2_fid\n\n    pretrained: \"allenai/unifiedqa-v2-t5-base-1363200\"\n\n    t5_config_path: \"configs/models/pnp-vqa/unifiedqav2_base_config.json\"\n\npreprocess:\n  vis_processor:\n      eval:\n        name: \"blip_image_eval\"\n        image_size: 384\n  text_processor:\n      eval:\n        name: \"blip_caption\"\n"
  },
  {
    "path": "lavis/configs/models/pnp-vqa/pnp_vqa_large.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: pnp_vqa\n  model_type: large\n\n  image_question_matching_model:\n    arch: blip_image_text_matching\n    load_finetuned: True\n\n    finetuned: \"https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_retrieval_coco_train2014.pth\"\n\n    # vit encoder\n    vit_type: \"large\"\n    vit_grad_ckpt: False\n    vit_ckpt_layer: 0\n\n    image_size: 384\n\n    # bert config\n    med_config_path: \"configs/models/med_large_config.json\"\n\n    embed_dim: 256\n\n  image_captioning_model:\n    arch: blip_caption\n    load_finetuned: True\n\n    finetuned: \"https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption_coco_train2014.pth\"\n\n    vit_type: \"large\"\n    vit_grad_ckpt: True\n    vit_ckpt_layer: 5\n\n    image_size: 384\n\n    # bert config\n    med_config_path: \"configs/models/med_large_config.json\"\n\n    # generation configs\n    prompt: \"a picture of \"\n\n  question_answering_model:\n    arch: pnp_unifiedqav2_fid\n\n    pretrained: \"allenai/unifiedqa-v2-t5-large-1363200\"\n\n    t5_config_path: \"configs/models/pnp-vqa/unifiedqav2_large_config.json\"\n\npreprocess:\n  vis_processor:\n      eval:\n        name: \"blip_image_eval\"\n        image_size: 384\n  text_processor:\n      eval:\n        name: \"blip_caption\"\n"
  },
  {
    "path": "lavis/configs/models/pnp-vqa/unifiedqav2_3b_config.json",
    "content": "{\n  \"architectures\": [\n    \"T5ForConditionalGeneration\"\n  ],\n  \"d_ff\": 16384,\n  \"d_kv\": 128,\n  \"d_model\": 1024,\n  \"decoder_start_token_id\": 0,\n  \"dense_act_fn\": \"relu\",\n  \"dropout_rate\": 0.1,\n  \"eos_token_id\": 1,\n  \"feed_forward_proj\": \"relu\",\n  \"gradient_checkpointing\": false,\n  \"initializer_factor\": 1.0,\n  \"is_encoder_decoder\": true,\n  \"is_gated_act\": false,\n  \"layer_norm_epsilon\": 1e-06,\n  \"model_type\": \"t5\",\n  \"n_positions\": 512,\n  \"num_decoder_layers\": 24,\n  \"num_heads\": 32,\n  \"num_layers\": 24,\n  \"output_past\": true,\n  \"pad_token_id\": 0,\n  \"relative_attention_max_distance\": 128,\n  \"relative_attention_num_buckets\": 32,\n  \"task_specific_params\": {\n    \"summarization\": {\n      \"early_stopping\": true,\n      \"length_penalty\": 2.0,\n      \"max_length\": 200,\n      \"min_length\": 30,\n      \"no_repeat_ngram_size\": 3,\n      \"num_beams\": 4,\n      \"prefix\": \"summarize: \"\n    },\n    \"translation_en_to_de\": {\n      \"early_stopping\": true,\n      \"max_length\": 300,\n      \"num_beams\": 4,\n      \"prefix\": \"translate English to German: \"\n    },\n    \"translation_en_to_fr\": {\n      \"early_stopping\": true,\n      \"max_length\": 300,\n      \"num_beams\": 4,\n      \"prefix\": \"translate English to French: \"\n    },\n    \"translation_en_to_ro\": {\n      \"early_stopping\": true,\n      \"max_length\": 300,\n      \"num_beams\": 4,\n      \"prefix\": \"translate English to Romanian: \"\n    }\n  },\n  \"torch_dtype\": \"float32\",\n  \"transformers_version\": \"4.21.3\",\n  \"use_cache\": true,\n  \"vocab_size\": 32128\n}"
  },
  {
    "path": "lavis/configs/models/pnp-vqa/unifiedqav2_base_config.json",
    "content": "{\n  \"architectures\": [\n    \"T5ForConditionalGeneration\"\n  ],\n  \"d_ff\": 3072,\n  \"d_kv\": 64,\n  \"d_model\": 768,\n  \"decoder_start_token_id\": 0,\n  \"dense_act_fn\": \"relu\",\n  \"dropout_rate\": 0.1,\n  \"eos_token_id\": 1,\n  \"feed_forward_proj\": \"relu\",\n  \"gradient_checkpointing\": false,\n  \"initializer_factor\": 1.0,\n  \"is_encoder_decoder\": true,\n  \"is_gated_act\": false,\n  \"layer_norm_epsilon\": 1e-06,\n  \"model_type\": \"t5\",\n  \"n_positions\": 512,\n  \"num_decoder_layers\": 12,\n  \"num_heads\": 12,\n  \"num_layers\": 12,\n  \"output_past\": true,\n  \"pad_token_id\": 0,\n  \"relative_attention_max_distance\": 128,\n  \"relative_attention_num_buckets\": 32,\n  \"task_specific_params\": {\n    \"summarization\": {\n      \"early_stopping\": true,\n      \"length_penalty\": 2.0,\n      \"max_length\": 200,\n      \"min_length\": 30,\n      \"no_repeat_ngram_size\": 3,\n      \"num_beams\": 4,\n      \"prefix\": \"summarize: \"\n    },\n    \"translation_en_to_de\": {\n      \"early_stopping\": true,\n      \"max_length\": 300,\n      \"num_beams\": 4,\n      \"prefix\": \"translate English to German: \"\n    },\n    \"translation_en_to_fr\": {\n      \"early_stopping\": true,\n      \"max_length\": 300,\n      \"num_beams\": 4,\n      \"prefix\": \"translate English to French: \"\n    },\n    \"translation_en_to_ro\": {\n      \"early_stopping\": true,\n      \"max_length\": 300,\n      \"num_beams\": 4,\n      \"prefix\": \"translate English to Romanian: \"\n    }\n  },\n  \"transformers_version\": \"4.21.3\",\n  \"use_cache\": true,\n  \"vocab_size\": 32128\n}"
  },
  {
    "path": "lavis/configs/models/pnp-vqa/unifiedqav2_large_config.json",
    "content": "{\n  \"architectures\": [\n    \"T5ForConditionalGeneration\"\n  ],\n  \"d_ff\": 4096,\n  \"d_kv\": 64,\n  \"d_model\": 1024,\n  \"decoder_start_token_id\": 0,\n  \"dense_act_fn\": \"relu\",\n  \"dropout_rate\": 0.1,\n  \"eos_token_id\": 1,\n  \"feed_forward_proj\": \"relu\",\n  \"gradient_checkpointing\": false,\n  \"initializer_factor\": 1.0,\n  \"is_encoder_decoder\": true,\n  \"is_gated_act\": false,\n  \"layer_norm_epsilon\": 1e-06,\n  \"model_type\": \"t5\",\n  \"n_positions\": 512,\n  \"num_decoder_layers\": 24,\n  \"num_heads\": 16,\n  \"num_layers\": 24,\n  \"output_past\": true,\n  \"pad_token_id\": 0,\n  \"relative_attention_max_distance\": 128,\n  \"relative_attention_num_buckets\": 32,\n  \"task_specific_params\": {\n    \"summarization\": {\n      \"early_stopping\": true,\n      \"length_penalty\": 2.0,\n      \"max_length\": 200,\n      \"min_length\": 30,\n      \"no_repeat_ngram_size\": 3,\n      \"num_beams\": 4,\n      \"prefix\": \"summarize: \"\n    },\n    \"translation_en_to_de\": {\n      \"early_stopping\": true,\n      \"max_length\": 300,\n      \"num_beams\": 4,\n      \"prefix\": \"translate English to German: \"\n    },\n    \"translation_en_to_fr\": {\n      \"early_stopping\": true,\n      \"max_length\": 300,\n      \"num_beams\": 4,\n      \"prefix\": \"translate English to French: \"\n    },\n    \"translation_en_to_ro\": {\n      \"early_stopping\": true,\n      \"max_length\": 300,\n      \"num_beams\": 4,\n      \"prefix\": \"translate English to Romanian: \"\n    }\n  },\n  \"transformers_version\": \"4.21.3\",\n  \"use_cache\": true,\n  \"vocab_size\": 32128\n}"
  },
  {
    "path": "lavis/configs/models/sevila.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: sevila\n  model_type: pretrain_flant5xl\n  load_finetuned: False\n  pretrained: \"https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xl.pth\"\n  finetuned: \"\"\n  use_grad_checkpoint: False\n  image_size: 224\n  drop_path_rate: 0\n  vit_precision: \"fp16\"\n  freeze_vit: True\n  frame_num: 8\n  answer_num: 5\n  task: train_loc_freeze_qa_vid\n\n  # Q-Former\n  num_query_token: 32\n  # T5\n  t5_model: \"google/flan-t5-xl\"\n  # generation configs\n  prompt: \"\"\n\n\npreprocess:\n    vis_processor:\n        train:\n          name: \"blip2_video_train\"\n          n_frms: 32\n          image_size: 224\n        eval:\n          name: \"blip_video_eval\"\n          n_frms: 32\n          image_size: 224\n    text_processor:\n        train:\n          name: \"blip_question\"\n          max_words: 50\n        eval:\n          name: \"blip_question\"\n          max_words: 50\n"
  },
  {
    "path": "lavis/datasets/builders/__init__.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nfrom lavis.datasets.builders.base_dataset_builder import load_dataset_config\nfrom lavis.datasets.builders.caption_builder import (\n    COCOCapBuilder,\n    MSRVTTCapBuilder,\n    MSVDCapBuilder,\n    VATEXCapBuilder,\n)\nfrom lavis.datasets.builders.image_text_pair_builder import (\n    ConceptualCaption12MBuilder,\n    ConceptualCaption3MBuilder,\n    VGCaptionBuilder,\n    SBUCaptionBuilder,\n)\nfrom lavis.datasets.builders.classification_builder import (\n    NLVRBuilder,\n    SNLIVisualEntailmentBuilder,\n)\nfrom lavis.datasets.builders.imagefolder_builder import ImageNetBuilder\nfrom lavis.datasets.builders.video_qa_builder import (\n    MSRVTTQABuilder, MSVDQABuilder, MCVideoQABuilder, \n    NextQABuilder, STARBuilder, TVQABuilder, How2QABuilder, VLEPBuilder, QVHBuilder)\n\nfrom lavis.datasets.builders.vqa_builder import (\n    COCOVQABuilder,\n    OKVQABuilder,\n    VGVQABuilder,\n    GQABuilder,\n)\nfrom lavis.datasets.builders.retrieval_builder import (\n    MSRVTTRetrievalBuilder,\n    DiDeMoRetrievalBuilder,\n    COCORetrievalBuilder,\n    Flickr30kBuilder,\n)\nfrom lavis.datasets.builders.dialogue_builder import AVSDDialBuilder\n\nfrom lavis.common.registry import registry\n\n__all__ = [\n    \"COCOCapBuilder\",\n    \"COCORetrievalBuilder\",\n    \"COCOVQABuilder\",\n    \"ConceptualCaption12MBuilder\",\n    \"ConceptualCaption3MBuilder\",\n    \"DiDeMoRetrievalBuilder\",\n    \"Flickr30kBuilder\",\n    \"GQABuilder\",\n    \"ImageNetBuilder\",\n    \"MSRVTTCapBuilder\",\n    \"MSRVTTQABuilder\",\n    \"MSRVTTRetrievalBuilder\",\n    \"MSVDCapBuilder\",\n    \"MSVDQABuilder\",\n    \"NLVRBuilder\",\n    \"OKVQABuilder\",\n    \"SBUCaptionBuilder\",\n    \"SNLIVisualEntailmentBuilder\",\n    \"VATEXCapBuilder\",\n    \"VGCaptionBuilder\",\n    \"VGVQABuilder\",\n    \"AVSDDialBuilder\",\n    \"MCVideoQABuilder\",\n    \"NextQABuilder\",\n    \"STARBuilder\",\n    \"How2QABuilder\",\n    \"TVQABuilder\",\n    \"VLEPBuilder\",\n    \"QVHBuilder\"\n]\n\n\ndef load_dataset(name, cfg_path=None, vis_path=None, data_type=None):\n    \"\"\"\n    Example\n\n    >>> dataset = load_dataset(\"coco_caption\", cfg=None)\n    >>> splits = dataset.keys()\n    >>> print([len(dataset[split]) for split in splits])\n\n    \"\"\"\n    if cfg_path is None:\n        cfg = None\n    else:\n        cfg = load_dataset_config(cfg_path)\n\n    try:\n        builder = registry.get_builder_class(name)(cfg)\n    except TypeError:\n        print(\n            f\"Dataset {name} not found. 
Available datasets:\\n\"\n            + \", \".join([str(k) for k in dataset_zoo.get_names()])\n        )\n        exit(1)\n\n    if vis_path is not None:\n        if data_type is None:\n            # use default data type in the config\n            data_type = builder.config.data_type\n\n        assert (\n            data_type in builder.config.build_info\n        ), f\"Invalid data_type {data_type} for {name}.\"\n\n        builder.config.build_info.get(data_type).storage = vis_path\n\n    dataset = builder.build_datasets()\n    return dataset\n\n\nclass DatasetZoo:\n    def __init__(self) -> None:\n        self.dataset_zoo = {\n            k: list(v.DATASET_CONFIG_DICT.keys())\n            for k, v in sorted(registry.mapping[\"builder_name_mapping\"].items())\n        }\n\n    def get_names(self):\n        return list(self.dataset_zoo.keys())\n\n\ndataset_zoo = DatasetZoo()\n"
  },
  {
    "path": "lavis/datasets/builders/base_dataset_builder.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport logging\nimport os\nimport shutil\nimport warnings\n\nimport lavis.common.utils as utils\nimport torch.distributed as dist\nfrom lavis.common.dist_utils import is_dist_avail_and_initialized, is_main_process\nfrom lavis.common.registry import registry\nfrom lavis.datasets.data_utils import extract_archive\nfrom lavis.processors.base_processor import BaseProcessor\nfrom omegaconf import OmegaConf\nfrom torchvision.datasets.utils import download_url\n\n\nclass BaseDatasetBuilder:\n    train_dataset_cls, eval_dataset_cls = None, None\n\n    def __init__(self, cfg=None):\n        super().__init__()\n\n        if cfg is None:\n            # help to create datasets from default config.\n            self.config = load_dataset_config(self.default_config_path())\n        elif isinstance(cfg, str):\n            self.config = load_dataset_config(cfg)\n        else:\n            # when called from task.build_dataset()\n            self.config = cfg\n\n        self.data_type = self.config.data_type\n        self.vis_processors = {\"train\": BaseProcessor(), \"eval\": BaseProcessor()}\n        self.text_processors = {\"train\": BaseProcessor(), \"eval\": BaseProcessor()}\n\n    def build_datasets(self):\n        # download, split, etc...\n        # only called on 1 GPU/TPU in distributed\n\n        if is_main_process():\n            self._download_data()\n\n        if is_dist_avail_and_initialized():\n            dist.barrier()\n\n        # at this point, all the annotations and image/videos should be all downloaded to the specified locations.\n        logging.info(\"Building datasets...\")\n        datasets = self.build()  # dataset['train'/'val'/'test']\n\n        return datasets\n\n    def build_processors(self):\n        vis_proc_cfg = self.config.get(\"vis_processor\")\n        txt_proc_cfg = self.config.get(\"text_processor\")\n\n        if vis_proc_cfg is not None:\n            vis_train_cfg = vis_proc_cfg.get(\"train\")\n            vis_eval_cfg = vis_proc_cfg.get(\"eval\")\n\n            self.vis_processors[\"train\"] = self._build_proc_from_cfg(vis_train_cfg)\n            self.vis_processors[\"eval\"] = self._build_proc_from_cfg(vis_eval_cfg)\n\n        if txt_proc_cfg is not None:\n            txt_train_cfg = txt_proc_cfg.get(\"train\")\n            txt_eval_cfg = txt_proc_cfg.get(\"eval\")\n\n            self.text_processors[\"train\"] = self._build_proc_from_cfg(txt_train_cfg)\n            self.text_processors[\"eval\"] = self._build_proc_from_cfg(txt_eval_cfg)\n\n    @staticmethod\n    def _build_proc_from_cfg(cfg):\n        return (\n            registry.get_processor_class(cfg.name).from_config(cfg)\n            if cfg is not None\n            else None\n        )\n\n    @classmethod\n    def default_config_path(cls, type=\"default\"):\n        return utils.get_abs_path(cls.DATASET_CONFIG_DICT[type])\n\n    def _download_data(self):\n        self._download_ann()\n        self._download_vis()\n\n    def _download_ann(self):\n        \"\"\"\n        Download annotation files if necessary.\n        All the vision-language datasets should have annotations of unified format.\n\n        storage_path can be:\n          (1) relative/absolute: will be prefixed with env.cache_root to make full path if relative.\n          (2) basename/dirname: 
will be suffixed with base name of URL if dirname is provided.\n\n        Local annotation paths should be relative.\n        \"\"\"\n        anns = self.config.build_info.annotations\n\n        splits = anns.keys()\n\n        cache_root = registry.get_path(\"cache_root\")\n\n        for split in splits:\n            info = anns[split]\n\n            urls, storage_paths = info.get(\"url\", None), info.storage\n\n            if isinstance(urls, str):\n                urls = [urls]\n            if isinstance(storage_paths, str):\n                storage_paths = [storage_paths]\n\n            assert len(urls) == len(storage_paths)\n\n            for url_or_filename, storage_path in zip(urls, storage_paths):\n                # if storage_path is relative, make it full by prefixing with cache_root.\n                if not os.path.isabs(storage_path):\n                    storage_path = os.path.join(cache_root, storage_path)\n\n                dirname = os.path.dirname(storage_path)\n                if not os.path.exists(dirname):\n                    os.makedirs(dirname)\n\n                if os.path.isfile(url_or_filename):\n                    src, dst = url_or_filename, storage_path\n                    if not os.path.exists(dst):\n                        shutil.copyfile(src=src, dst=dst)\n                    else:\n                        logging.info(\"Using existing file {}.\".format(dst))\n                else:\n                    if os.path.isdir(storage_path):\n                        # if only dirname is provided, suffix with basename of URL.\n                        raise ValueError(\n                            \"Expecting storage_path to be a file path, got directory {}\".format(\n                                storage_path\n                            )\n                        )\n                    else:\n                        filename = os.path.basename(storage_path)\n\n                    download_url(url=url_or_filename, root=dirname, filename=filename)\n\n    def _download_vis(self):\n\n        storage_path = self.config.build_info.get(self.data_type).storage\n        storage_path = utils.get_cache_path(storage_path)\n\n        if not os.path.exists(storage_path):\n            warnings.warn(\n                f\"\"\"\n                The specified path {storage_path} for visual inputs does not exist.\n                Please provide a correct path to the visual inputs or\n                refer to datasets/download_scripts/README.md for downloading instructions.\n                \"\"\"\n            )\n\n    def build(self):\n        \"\"\"\n        Create by split datasets inheriting torch.utils.data.Datasets.\n\n        # build() can be dataset-specific. 
Overwrite to customize.\n        \"\"\"\n        self.build_processors()\n\n        build_info = self.config.build_info\n\n        ann_info = build_info.annotations\n        vis_info = build_info.get(self.data_type)\n\n        datasets = dict()\n        for split in ann_info.keys():\n            if split not in [\"train\", \"val\", \"test\"]:\n                continue\n\n            is_train = split == \"train\"\n\n            # processors\n            vis_processor = (\n                self.vis_processors[\"train\"]\n                if is_train\n                else self.vis_processors[\"eval\"]\n            )\n            text_processor = (\n                self.text_processors[\"train\"]\n                if is_train\n                else self.text_processors[\"eval\"]\n            )\n\n            # annotation path\n            ann_paths = ann_info.get(split).storage\n            if isinstance(ann_paths, str):\n                ann_paths = [ann_paths]\n\n            abs_ann_paths = []\n            for ann_path in ann_paths:\n                if not os.path.isabs(ann_path):\n                    ann_path = utils.get_cache_path(ann_path)\n                abs_ann_paths.append(ann_path)\n            ann_paths = abs_ann_paths\n\n            # visual data storage path\n            vis_path = vis_info.storage\n            #print('vis_path',vis_path)\n            if not os.path.isabs(vis_path):\n                # vis_path = os.path.join(utils.get_cache_path(), vis_path)\n                vis_path = utils.get_cache_path(vis_path)\n            #print('vis_path2', vis_path)\n            if not os.path.exists(vis_path):\n                warnings.warn(\"storage path {} does not exist.\".format(vis_path))\n\n            # create datasets\n            dataset_cls = self.train_dataset_cls if is_train else self.eval_dataset_cls\n            datasets[split] = dataset_cls(\n                vis_processor=vis_processor,\n                text_processor=text_processor,\n                ann_paths=ann_paths,\n                vis_root=vis_path,\n            )\n\n        return datasets\n\n\ndef load_dataset_config(cfg_path):\n    cfg = OmegaConf.load(cfg_path).datasets\n    cfg = cfg[list(cfg.keys())[0]]\n\n    return cfg\n"
  },
  {
    "path": "lavis/datasets/builders/caption_builder.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nfrom lavis.datasets.builders.base_dataset_builder import BaseDatasetBuilder\nfrom lavis.datasets.datasets.coco_caption_datasets import (\n    COCOCapDataset,\n    COCOCapEvalDataset,\n    NoCapsEvalDataset,\n)\n\nfrom lavis.common.registry import registry\nfrom lavis.datasets.datasets.video_caption_datasets import (\n    VideoCaptionDataset,\n    VideoCaptionEvalDataset,\n)\n\n@registry.register_builder(\"coco_caption\")\nclass COCOCapBuilder(BaseDatasetBuilder):\n    train_dataset_cls = COCOCapDataset\n    eval_dataset_cls = COCOCapEvalDataset\n\n    DATASET_CONFIG_DICT = {\n        \"default\": \"configs/datasets/coco/defaults_cap.yaml\",\n    }\n\n\n@registry.register_builder(\"nocaps\")\nclass COCOCapBuilder(BaseDatasetBuilder):\n    eval_dataset_cls = NoCapsEvalDataset\n\n    DATASET_CONFIG_DICT = {\n        \"default\": \"configs/datasets/nocaps/defaults.yaml\",\n    }\n\n\n@registry.register_builder(\"msrvtt_caption\")\nclass MSRVTTCapBuilder(BaseDatasetBuilder):\n    train_dataset_cls = VideoCaptionDataset\n    eval_dataset_cls = VideoCaptionEvalDataset\n\n    DATASET_CONFIG_DICT = {\n        \"default\": \"configs/datasets/msrvtt/defaults_cap.yaml\",\n    }\n\n\n@registry.register_builder(\"msvd_caption\")\nclass MSVDCapBuilder(BaseDatasetBuilder):\n    train_dataset_cls = VideoCaptionDataset\n    eval_dataset_cls = VideoCaptionEvalDataset\n\n    DATASET_CONFIG_DICT = {\n        \"default\": \"configs/datasets/msvd/defaults_cap.yaml\",\n    }\n\n\n@registry.register_builder(\"vatex_caption\")\nclass VATEXCapBuilder(BaseDatasetBuilder):\n    train_dataset_cls = VideoCaptionDataset\n    eval_dataset_cls = VideoCaptionEvalDataset\n\n    DATASET_CONFIG_DICT = {\n        \"default\": \"configs/datasets/vatex/defaults_cap.yaml\",\n    }\n\n"
  },
  {
    "path": "lavis/datasets/builders/classification_builder.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nfrom lavis.common.registry import registry\nfrom lavis.datasets.builders.base_dataset_builder import BaseDatasetBuilder\nfrom lavis.datasets.datasets.nlvr_datasets import NLVRDataset, NLVREvalDataset\nfrom lavis.datasets.datasets.snli_ve_datasets import SNLIVisualEntialmentDataset\n\n\n@registry.register_builder(\"nlvr\")\nclass NLVRBuilder(BaseDatasetBuilder):\n    train_dataset_cls = NLVRDataset\n    eval_dataset_cls = NLVREvalDataset\n\n    DATASET_CONFIG_DICT = {\"default\": \"configs/datasets/nlvr/defaults.yaml\"}\n\n\n@registry.register_builder(\"snli_ve\")\nclass SNLIVisualEntailmentBuilder(BaseDatasetBuilder):\n    train_dataset_cls = SNLIVisualEntialmentDataset\n    eval_dataset_cls = SNLIVisualEntialmentDataset\n\n    DATASET_CONFIG_DICT = {\"default\": \"configs/datasets/snli_ve/defaults.yaml\"}\n"
  },
  {
    "path": "lavis/datasets/builders/dialogue_builder.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nfrom lavis.common.registry import registry\nfrom lavis.datasets.builders.base_dataset_builder import BaseDatasetBuilder\nfrom lavis.datasets.datasets.avsd_dialogue_datasets import (\n    AVSDDialDataset,\n    AVSDDialEvalDataset,\n)\n\n\n@registry.register_builder(\"avsd_dialogue\")\nclass AVSDDialBuilder(BaseDatasetBuilder):\n    train_dataset_cls = AVSDDialDataset\n    eval_dataset_cls = AVSDDialEvalDataset\n\n    DATASET_CONFIG_DICT = {\"default\": \"configs/datasets/avsd/defaults_dial.yaml\"}\n"
  },
  {
    "path": "lavis/datasets/builders/image_text_pair_builder.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport os\nfrom lavis.common.registry import registry\n\nfrom lavis.datasets.builders.base_dataset_builder import BaseDatasetBuilder\nfrom lavis.datasets.datasets.image_text_pair_datasets import ImageTextPairDataset\nfrom lavis.datasets.datasets.laion_dataset import LaionDataset\n\n\n@registry.register_builder(\"conceptual_caption_3m\")\nclass ConceptualCaption3MBuilder(BaseDatasetBuilder):\n    train_dataset_cls = ImageTextPairDataset\n\n    DATASET_CONFIG_DICT = {\n        \"default\": \"configs/datasets/conceptual_caption/defaults_3m.yaml\"\n    }\n\n\n@registry.register_builder(\"conceptual_caption_12m\")\nclass ConceptualCaption12MBuilder(BaseDatasetBuilder):\n    train_dataset_cls = ImageTextPairDataset\n\n    DATASET_CONFIG_DICT = {\n        \"default\": \"configs/datasets/conceptual_caption/defaults_12m.yaml\"\n    }\n\n\n@registry.register_builder(\"sbu_caption\")\nclass SBUCaptionBuilder(BaseDatasetBuilder):\n    train_dataset_cls = ImageTextPairDataset\n\n    DATASET_CONFIG_DICT = {\"default\": \"configs/datasets/sbu_caption/defaults.yaml\"}\n\n\n@registry.register_builder(\"vg_caption\")\nclass VGCaptionBuilder(BaseDatasetBuilder):\n    train_dataset_cls = ImageTextPairDataset\n\n    DATASET_CONFIG_DICT = {\"default\": \"configs/datasets/vg/defaults_caption.yaml\"}\n\n\n@registry.register_builder(\"laion2B_multi\")\nclass Laion2BMultiBuilder(BaseDatasetBuilder):\n    train_dataset_cls = LaionDataset\n\n    DATASET_CONFIG_DICT = {\"default\": \"configs/datasets/laion/defaults_2B_multi.yaml\"}\n\n    def _download_ann(self):\n        pass\n\n    def _download_vis(self):\n        pass\n\n    def build(self):\n        self.build_processors()\n\n        build_info = self.config.build_info\n\n        datasets = dict()\n        split = \"train\"  # laion dataset only has train split\n\n        # create datasets\n        # [NOTE] return inner_datasets (wds.DataPipeline)\n        dataset_cls = self.train_dataset_cls\n        datasets[split] = dataset_cls(\n            vis_processor=self.vis_processors[split],\n            text_processor=self.text_processors[split],\n            location=build_info.storage,\n        ).inner_dataset\n\n        return datasets\n"
  },
  {
    "path": "lavis/datasets/builders/imagefolder_builder.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport os\n\nfrom lavis.common.registry import registry\nfrom lavis.datasets.builders.base_dataset_builder import BaseDatasetBuilder\nfrom lavis.datasets.datasets.imagefolder_dataset import ImageFolderDataset\n\n\n@registry.register_builder(\"imagenet\")\nclass ImageNetBuilder(BaseDatasetBuilder):\n    train_dataset_cls = ImageFolderDataset\n    eval_dataset_cls = ImageFolderDataset\n\n    DATASET_CONFIG_DICT = {\"default\": \"configs/datasets/imagenet/defaults.yaml\"}\n\n    def _download_ann(self):\n        pass\n\n    def build(self):\n        self.build_processors()\n\n        build_info = self.config.build_info\n\n        vis_info = build_info.get(self.data_type)\n\n        datasets = dict()\n        for split in build_info.splits:\n            assert split in [\n                \"train\",\n                \"val\",\n            ], \"Invalid split name {}, must be one of 'train', 'val' and 'test'.\"\n\n            is_train = split == \"train\"\n\n            vis_processor = (\n                self.vis_processors[\"train\"]\n                if is_train\n                else self.vis_processors[\"eval\"]\n            )\n\n            vis_path = os.path.join(vis_info.storage, split)\n\n            # create datasets\n            dataset_cls = self.train_dataset_cls if is_train else self.eval_dataset_cls\n            datasets[split] = dataset_cls(\n                vis_processor=vis_processor,\n                vis_root=vis_path,\n                classnames=imagenet_classnames,\n            )\n\n        return datasets\n\n\nimagenet_classnames = [\n    \"tench\",\n    \"goldfish\",\n    \"great white shark\",\n    \"tiger shark\",\n    \"hammerhead shark\",\n    \"electric ray\",\n    \"stingray\",\n    \"rooster\",\n    \"hen\",\n    \"ostrich\",\n    \"brambling\",\n    \"goldfinch\",\n    \"house finch\",\n    \"junco\",\n    \"indigo bunting\",\n    \"American robin\",\n    \"bulbul\",\n    \"jay\",\n    \"magpie\",\n    \"chickadee\",\n    \"American dipper\",\n    \"kite (bird of prey)\",\n    \"bald eagle\",\n    \"vulture\",\n    \"great grey owl\",\n    \"fire salamander\",\n    \"smooth newt\",\n    \"newt\",\n    \"spotted salamander\",\n    \"axolotl\",\n    \"American bullfrog\",\n    \"tree frog\",\n    \"tailed frog\",\n    \"loggerhead sea turtle\",\n    \"leatherback sea turtle\",\n    \"mud turtle\",\n    \"terrapin\",\n    \"box turtle\",\n    \"banded gecko\",\n    \"green iguana\",\n    \"Carolina anole\",\n    \"desert grassland whiptail lizard\",\n    \"agama\",\n    \"frilled-necked lizard\",\n    \"alligator lizard\",\n    \"Gila monster\",\n    \"European green lizard\",\n    \"chameleon\",\n    \"Komodo dragon\",\n    \"Nile crocodile\",\n    \"American alligator\",\n    \"triceratops\",\n    \"worm snake\",\n    \"ring-necked snake\",\n    \"eastern hog-nosed snake\",\n    \"smooth green snake\",\n    \"kingsnake\",\n    \"garter snake\",\n    \"water snake\",\n    \"vine snake\",\n    \"night snake\",\n    \"boa constrictor\",\n    \"African rock python\",\n    \"Indian cobra\",\n    \"green mamba\",\n    \"sea snake\",\n    \"Saharan horned viper\",\n    \"eastern diamondback rattlesnake\",\n    \"sidewinder rattlesnake\",\n    \"trilobite\",\n    \"harvestman\",\n    \"scorpion\",\n    \"yellow garden spider\",\n    
\"barn spider\",\n    \"European garden spider\",\n    \"southern black widow\",\n    \"tarantula\",\n    \"wolf spider\",\n    \"tick\",\n    \"centipede\",\n    \"black grouse\",\n    \"ptarmigan\",\n    \"ruffed grouse\",\n    \"prairie grouse\",\n    \"peafowl\",\n    \"quail\",\n    \"partridge\",\n    \"african grey parrot\",\n    \"macaw\",\n    \"sulphur-crested cockatoo\",\n    \"lorikeet\",\n    \"coucal\",\n    \"bee eater\",\n    \"hornbill\",\n    \"hummingbird\",\n    \"jacamar\",\n    \"toucan\",\n    \"duck\",\n    \"red-breasted merganser\",\n    \"goose\",\n    \"black swan\",\n    \"tusker\",\n    \"echidna\",\n    \"platypus\",\n    \"wallaby\",\n    \"koala\",\n    \"wombat\",\n    \"jellyfish\",\n    \"sea anemone\",\n    \"brain coral\",\n    \"flatworm\",\n    \"nematode\",\n    \"conch\",\n    \"snail\",\n    \"slug\",\n    \"sea slug\",\n    \"chiton\",\n    \"chambered nautilus\",\n    \"Dungeness crab\",\n    \"rock crab\",\n    \"fiddler crab\",\n    \"red king crab\",\n    \"American lobster\",\n    \"spiny lobster\",\n    \"crayfish\",\n    \"hermit crab\",\n    \"isopod\",\n    \"white stork\",\n    \"black stork\",\n    \"spoonbill\",\n    \"flamingo\",\n    \"little blue heron\",\n    \"great egret\",\n    \"bittern bird\",\n    \"crane bird\",\n    \"limpkin\",\n    \"common gallinule\",\n    \"American coot\",\n    \"bustard\",\n    \"ruddy turnstone\",\n    \"dunlin\",\n    \"common redshank\",\n    \"dowitcher\",\n    \"oystercatcher\",\n    \"pelican\",\n    \"king penguin\",\n    \"albatross\",\n    \"grey whale\",\n    \"killer whale\",\n    \"dugong\",\n    \"sea lion\",\n    \"Chihuahua\",\n    \"Japanese Chin\",\n    \"Maltese\",\n    \"Pekingese\",\n    \"Shih Tzu\",\n    \"King Charles Spaniel\",\n    \"Papillon\",\n    \"toy terrier\",\n    \"Rhodesian Ridgeback\",\n    \"Afghan Hound\",\n    \"Basset Hound\",\n    \"Beagle\",\n    \"Bloodhound\",\n    \"Bluetick Coonhound\",\n    \"Black and Tan Coonhound\",\n    \"Treeing Walker Coonhound\",\n    \"English foxhound\",\n    \"Redbone Coonhound\",\n    \"borzoi\",\n    \"Irish Wolfhound\",\n    \"Italian Greyhound\",\n    \"Whippet\",\n    \"Ibizan Hound\",\n    \"Norwegian Elkhound\",\n    \"Otterhound\",\n    \"Saluki\",\n    \"Scottish Deerhound\",\n    \"Weimaraner\",\n    \"Staffordshire Bull Terrier\",\n    \"American Staffordshire Terrier\",\n    \"Bedlington Terrier\",\n    \"Border Terrier\",\n    \"Kerry Blue Terrier\",\n    \"Irish Terrier\",\n    \"Norfolk Terrier\",\n    \"Norwich Terrier\",\n    \"Yorkshire Terrier\",\n    \"Wire Fox Terrier\",\n    \"Lakeland Terrier\",\n    \"Sealyham Terrier\",\n    \"Airedale Terrier\",\n    \"Cairn Terrier\",\n    \"Australian Terrier\",\n    \"Dandie Dinmont Terrier\",\n    \"Boston Terrier\",\n    \"Miniature Schnauzer\",\n    \"Giant Schnauzer\",\n    \"Standard Schnauzer\",\n    \"Scottish Terrier\",\n    \"Tibetan Terrier\",\n    \"Australian Silky Terrier\",\n    \"Soft-coated Wheaten Terrier\",\n    \"West Highland White Terrier\",\n    \"Lhasa Apso\",\n    \"Flat-Coated Retriever\",\n    \"Curly-coated Retriever\",\n    \"Golden Retriever\",\n    \"Labrador Retriever\",\n    \"Chesapeake Bay Retriever\",\n    \"German Shorthaired Pointer\",\n    \"Vizsla\",\n    \"English Setter\",\n    \"Irish Setter\",\n    \"Gordon Setter\",\n    \"Brittany dog\",\n    \"Clumber Spaniel\",\n    \"English Springer Spaniel\",\n    \"Welsh Springer Spaniel\",\n    \"Cocker Spaniel\",\n    \"Sussex Spaniel\",\n    \"Irish Water Spaniel\",\n    
\"Kuvasz\",\n    \"Schipperke\",\n    \"Groenendael dog\",\n    \"Malinois\",\n    \"Briard\",\n    \"Australian Kelpie\",\n    \"Komondor\",\n    \"Old English Sheepdog\",\n    \"Shetland Sheepdog\",\n    \"collie\",\n    \"Border Collie\",\n    \"Bouvier des Flandres dog\",\n    \"Rottweiler\",\n    \"German Shepherd Dog\",\n    \"Dobermann\",\n    \"Miniature Pinscher\",\n    \"Greater Swiss Mountain Dog\",\n    \"Bernese Mountain Dog\",\n    \"Appenzeller Sennenhund\",\n    \"Entlebucher Sennenhund\",\n    \"Boxer\",\n    \"Bullmastiff\",\n    \"Tibetan Mastiff\",\n    \"French Bulldog\",\n    \"Great Dane\",\n    \"St. Bernard\",\n    \"husky\",\n    \"Alaskan Malamute\",\n    \"Siberian Husky\",\n    \"Dalmatian\",\n    \"Affenpinscher\",\n    \"Basenji\",\n    \"pug\",\n    \"Leonberger\",\n    \"Newfoundland dog\",\n    \"Great Pyrenees dog\",\n    \"Samoyed\",\n    \"Pomeranian\",\n    \"Chow Chow\",\n    \"Keeshond\",\n    \"brussels griffon\",\n    \"Pembroke Welsh Corgi\",\n    \"Cardigan Welsh Corgi\",\n    \"Toy Poodle\",\n    \"Miniature Poodle\",\n    \"Standard Poodle\",\n    \"Mexican hairless dog (xoloitzcuintli)\",\n    \"grey wolf\",\n    \"Alaskan tundra wolf\",\n    \"red wolf or maned wolf\",\n    \"coyote\",\n    \"dingo\",\n    \"dhole\",\n    \"African wild dog\",\n    \"hyena\",\n    \"red fox\",\n    \"kit fox\",\n    \"Arctic fox\",\n    \"grey fox\",\n    \"tabby cat\",\n    \"tiger cat\",\n    \"Persian cat\",\n    \"Siamese cat\",\n    \"Egyptian Mau\",\n    \"cougar\",\n    \"lynx\",\n    \"leopard\",\n    \"snow leopard\",\n    \"jaguar\",\n    \"lion\",\n    \"tiger\",\n    \"cheetah\",\n    \"brown bear\",\n    \"American black bear\",\n    \"polar bear\",\n    \"sloth bear\",\n    \"mongoose\",\n    \"meerkat\",\n    \"tiger beetle\",\n    \"ladybug\",\n    \"ground beetle\",\n    \"longhorn beetle\",\n    \"leaf beetle\",\n    \"dung beetle\",\n    \"rhinoceros beetle\",\n    \"weevil\",\n    \"fly\",\n    \"bee\",\n    \"ant\",\n    \"grasshopper\",\n    \"cricket insect\",\n    \"stick insect\",\n    \"cockroach\",\n    \"praying mantis\",\n    \"cicada\",\n    \"leafhopper\",\n    \"lacewing\",\n    \"dragonfly\",\n    \"damselfly\",\n    \"red admiral butterfly\",\n    \"ringlet butterfly\",\n    \"monarch butterfly\",\n    \"small white butterfly\",\n    \"sulphur butterfly\",\n    \"gossamer-winged butterfly\",\n    \"starfish\",\n    \"sea urchin\",\n    \"sea cucumber\",\n    \"cottontail rabbit\",\n    \"hare\",\n    \"Angora rabbit\",\n    \"hamster\",\n    \"porcupine\",\n    \"fox squirrel\",\n    \"marmot\",\n    \"beaver\",\n    \"guinea pig\",\n    \"common sorrel horse\",\n    \"zebra\",\n    \"pig\",\n    \"wild boar\",\n    \"warthog\",\n    \"hippopotamus\",\n    \"ox\",\n    \"water buffalo\",\n    \"bison\",\n    \"ram (adult male sheep)\",\n    \"bighorn sheep\",\n    \"Alpine ibex\",\n    \"hartebeest\",\n    \"impala (antelope)\",\n    \"gazelle\",\n    \"arabian camel\",\n    \"llama\",\n    \"weasel\",\n    \"mink\",\n    \"European polecat\",\n    \"black-footed ferret\",\n    \"otter\",\n    \"skunk\",\n    \"badger\",\n    \"armadillo\",\n    \"three-toed sloth\",\n    \"orangutan\",\n    \"gorilla\",\n    \"chimpanzee\",\n    \"gibbon\",\n    \"siamang\",\n    \"guenon\",\n    \"patas monkey\",\n    \"baboon\",\n    \"macaque\",\n    \"langur\",\n    \"black-and-white colobus\",\n    \"proboscis monkey\",\n    \"marmoset\",\n    \"white-headed capuchin\",\n    \"howler monkey\",\n    \"titi monkey\",\n    \"Geoffroy's 
spider monkey\",\n    \"common squirrel monkey\",\n    \"ring-tailed lemur\",\n    \"indri\",\n    \"Asian elephant\",\n    \"African bush elephant\",\n    \"red panda\",\n    \"giant panda\",\n    \"snoek fish\",\n    \"eel\",\n    \"silver salmon\",\n    \"rock beauty fish\",\n    \"clownfish\",\n    \"sturgeon\",\n    \"gar fish\",\n    \"lionfish\",\n    \"pufferfish\",\n    \"abacus\",\n    \"abaya\",\n    \"academic gown\",\n    \"accordion\",\n    \"acoustic guitar\",\n    \"aircraft carrier\",\n    \"airliner\",\n    \"airship\",\n    \"altar\",\n    \"ambulance\",\n    \"amphibious vehicle\",\n    \"analog clock\",\n    \"apiary\",\n    \"apron\",\n    \"trash can\",\n    \"assault rifle\",\n    \"backpack\",\n    \"bakery\",\n    \"balance beam\",\n    \"balloon\",\n    \"ballpoint pen\",\n    \"Band-Aid\",\n    \"banjo\",\n    \"baluster / handrail\",\n    \"barbell\",\n    \"barber chair\",\n    \"barbershop\",\n    \"barn\",\n    \"barometer\",\n    \"barrel\",\n    \"wheelbarrow\",\n    \"baseball\",\n    \"basketball\",\n    \"bassinet\",\n    \"bassoon\",\n    \"swimming cap\",\n    \"bath towel\",\n    \"bathtub\",\n    \"station wagon\",\n    \"lighthouse\",\n    \"beaker\",\n    \"military hat (bearskin or shako)\",\n    \"beer bottle\",\n    \"beer glass\",\n    \"bell tower\",\n    \"baby bib\",\n    \"tandem bicycle\",\n    \"bikini\",\n    \"ring binder\",\n    \"binoculars\",\n    \"birdhouse\",\n    \"boathouse\",\n    \"bobsleigh\",\n    \"bolo tie\",\n    \"poke bonnet\",\n    \"bookcase\",\n    \"bookstore\",\n    \"bottle cap\",\n    \"hunting bow\",\n    \"bow tie\",\n    \"brass memorial plaque\",\n    \"bra\",\n    \"breakwater\",\n    \"breastplate\",\n    \"broom\",\n    \"bucket\",\n    \"buckle\",\n    \"bulletproof vest\",\n    \"high-speed train\",\n    \"butcher shop\",\n    \"taxicab\",\n    \"cauldron\",\n    \"candle\",\n    \"cannon\",\n    \"canoe\",\n    \"can opener\",\n    \"cardigan\",\n    \"car mirror\",\n    \"carousel\",\n    \"tool kit\",\n    \"cardboard box / carton\",\n    \"car wheel\",\n    \"automated teller machine\",\n    \"cassette\",\n    \"cassette player\",\n    \"castle\",\n    \"catamaran\",\n    \"CD player\",\n    \"cello\",\n    \"mobile phone\",\n    \"chain\",\n    \"chain-link fence\",\n    \"chain mail\",\n    \"chainsaw\",\n    \"storage chest\",\n    \"chiffonier\",\n    \"bell or wind chime\",\n    \"china cabinet\",\n    \"Christmas stocking\",\n    \"church\",\n    \"movie theater\",\n    \"cleaver\",\n    \"cliff dwelling\",\n    \"cloak\",\n    \"clogs\",\n    \"cocktail shaker\",\n    \"coffee mug\",\n    \"coffeemaker\",\n    \"spiral or coil\",\n    \"combination lock\",\n    \"computer keyboard\",\n    \"candy store\",\n    \"container ship\",\n    \"convertible\",\n    \"corkscrew\",\n    \"cornet\",\n    \"cowboy boot\",\n    \"cowboy hat\",\n    \"cradle\",\n    \"construction crane\",\n    \"crash helmet\",\n    \"crate\",\n    \"infant bed\",\n    \"Crock Pot\",\n    \"croquet ball\",\n    \"crutch\",\n    \"cuirass\",\n    \"dam\",\n    \"desk\",\n    \"desktop computer\",\n    \"rotary dial telephone\",\n    \"diaper\",\n    \"digital clock\",\n    \"digital watch\",\n    \"dining table\",\n    \"dishcloth\",\n    \"dishwasher\",\n    \"disc brake\",\n    \"dock\",\n    \"dog sled\",\n    \"dome\",\n    \"doormat\",\n    \"drilling rig\",\n    \"drum\",\n    \"drumstick\",\n    \"dumbbell\",\n    \"Dutch oven\",\n    \"electric fan\",\n    \"electric guitar\",\n    \"electric locomotive\",\n    
\"entertainment center\",\n    \"envelope\",\n    \"espresso machine\",\n    \"face powder\",\n    \"feather boa\",\n    \"filing cabinet\",\n    \"fireboat\",\n    \"fire truck\",\n    \"fire screen\",\n    \"flagpole\",\n    \"flute\",\n    \"folding chair\",\n    \"football helmet\",\n    \"forklift\",\n    \"fountain\",\n    \"fountain pen\",\n    \"four-poster bed\",\n    \"freight car\",\n    \"French horn\",\n    \"frying pan\",\n    \"fur coat\",\n    \"garbage truck\",\n    \"gas mask or respirator\",\n    \"gas pump\",\n    \"goblet\",\n    \"go-kart\",\n    \"golf ball\",\n    \"golf cart\",\n    \"gondola\",\n    \"gong\",\n    \"gown\",\n    \"grand piano\",\n    \"greenhouse\",\n    \"radiator grille\",\n    \"grocery store\",\n    \"guillotine\",\n    \"hair clip\",\n    \"hair spray\",\n    \"half-track\",\n    \"hammer\",\n    \"hamper\",\n    \"hair dryer\",\n    \"hand-held computer\",\n    \"handkerchief\",\n    \"hard disk drive\",\n    \"harmonica\",\n    \"harp\",\n    \"combine harvester\",\n    \"hatchet\",\n    \"holster\",\n    \"home theater\",\n    \"honeycomb\",\n    \"hook\",\n    \"hoop skirt\",\n    \"gymnastic horizontal bar\",\n    \"horse-drawn vehicle\",\n    \"hourglass\",\n    \"iPod\",\n    \"clothes iron\",\n    \"carved pumpkin\",\n    \"jeans\",\n    \"jeep\",\n    \"T-shirt\",\n    \"jigsaw puzzle\",\n    \"rickshaw\",\n    \"joystick\",\n    \"kimono\",\n    \"knee pad\",\n    \"knot\",\n    \"lab coat\",\n    \"ladle\",\n    \"lampshade\",\n    \"laptop computer\",\n    \"lawn mower\",\n    \"lens cap\",\n    \"letter opener\",\n    \"library\",\n    \"lifeboat\",\n    \"lighter\",\n    \"limousine\",\n    \"ocean liner\",\n    \"lipstick\",\n    \"slip-on shoe\",\n    \"lotion\",\n    \"music speaker\",\n    \"loupe magnifying glass\",\n    \"sawmill\",\n    \"magnetic compass\",\n    \"messenger bag\",\n    \"mailbox\",\n    \"tights\",\n    \"one-piece bathing suit\",\n    \"manhole cover\",\n    \"maraca\",\n    \"marimba\",\n    \"mask\",\n    \"matchstick\",\n    \"maypole\",\n    \"maze\",\n    \"measuring cup\",\n    \"medicine cabinet\",\n    \"megalith\",\n    \"microphone\",\n    \"microwave oven\",\n    \"military uniform\",\n    \"milk can\",\n    \"minibus\",\n    \"miniskirt\",\n    \"minivan\",\n    \"missile\",\n    \"mitten\",\n    \"mixing bowl\",\n    \"mobile home\",\n    \"ford model t\",\n    \"modem\",\n    \"monastery\",\n    \"monitor\",\n    \"moped\",\n    \"mortar and pestle\",\n    \"graduation cap\",\n    \"mosque\",\n    \"mosquito net\",\n    \"vespa\",\n    \"mountain bike\",\n    \"tent\",\n    \"computer mouse\",\n    \"mousetrap\",\n    \"moving van\",\n    \"muzzle\",\n    \"metal nail\",\n    \"neck brace\",\n    \"necklace\",\n    \"baby pacifier\",\n    \"notebook computer\",\n    \"obelisk\",\n    \"oboe\",\n    \"ocarina\",\n    \"odometer\",\n    \"oil filter\",\n    \"pipe organ\",\n    \"oscilloscope\",\n    \"overskirt\",\n    \"bullock cart\",\n    \"oxygen mask\",\n    \"product packet / packaging\",\n    \"paddle\",\n    \"paddle wheel\",\n    \"padlock\",\n    \"paintbrush\",\n    \"pajamas\",\n    \"palace\",\n    \"pan flute\",\n    \"paper towel\",\n    \"parachute\",\n    \"parallel bars\",\n    \"park bench\",\n    \"parking meter\",\n    \"railroad car\",\n    \"patio\",\n    \"payphone\",\n    \"pedestal\",\n    \"pencil case\",\n    \"pencil sharpener\",\n    \"perfume\",\n    \"Petri dish\",\n    \"photocopier\",\n    \"plectrum\",\n    \"Pickelhaube\",\n    \"picket fence\",\n    
\"pickup truck\",\n    \"pier\",\n    \"piggy bank\",\n    \"pill bottle\",\n    \"pillow\",\n    \"ping-pong ball\",\n    \"pinwheel\",\n    \"pirate ship\",\n    \"drink pitcher\",\n    \"block plane\",\n    \"planetarium\",\n    \"plastic bag\",\n    \"plate rack\",\n    \"farm plow\",\n    \"plunger\",\n    \"Polaroid camera\",\n    \"pole\",\n    \"police van\",\n    \"poncho\",\n    \"pool table\",\n    \"soda bottle\",\n    \"plant pot\",\n    \"potter's wheel\",\n    \"power drill\",\n    \"prayer rug\",\n    \"printer\",\n    \"prison\",\n    \"missile\",\n    \"projector\",\n    \"hockey puck\",\n    \"punching bag\",\n    \"purse\",\n    \"quill\",\n    \"quilt\",\n    \"race car\",\n    \"racket\",\n    \"radiator\",\n    \"radio\",\n    \"radio telescope\",\n    \"rain barrel\",\n    \"recreational vehicle\",\n    \"fishing casting reel\",\n    \"reflex camera\",\n    \"refrigerator\",\n    \"remote control\",\n    \"restaurant\",\n    \"revolver\",\n    \"rifle\",\n    \"rocking chair\",\n    \"rotisserie\",\n    \"eraser\",\n    \"rugby ball\",\n    \"ruler measuring stick\",\n    \"sneaker\",\n    \"safe\",\n    \"safety pin\",\n    \"salt shaker\",\n    \"sandal\",\n    \"sarong\",\n    \"saxophone\",\n    \"scabbard\",\n    \"weighing scale\",\n    \"school bus\",\n    \"schooner\",\n    \"scoreboard\",\n    \"CRT monitor\",\n    \"screw\",\n    \"screwdriver\",\n    \"seat belt\",\n    \"sewing machine\",\n    \"shield\",\n    \"shoe store\",\n    \"shoji screen / room divider\",\n    \"shopping basket\",\n    \"shopping cart\",\n    \"shovel\",\n    \"shower cap\",\n    \"shower curtain\",\n    \"ski\",\n    \"balaclava ski mask\",\n    \"sleeping bag\",\n    \"slide rule\",\n    \"sliding door\",\n    \"slot machine\",\n    \"snorkel\",\n    \"snowmobile\",\n    \"snowplow\",\n    \"soap dispenser\",\n    \"soccer ball\",\n    \"sock\",\n    \"solar thermal collector\",\n    \"sombrero\",\n    \"soup bowl\",\n    \"keyboard space bar\",\n    \"space heater\",\n    \"space shuttle\",\n    \"spatula\",\n    \"motorboat\",\n    \"spider web\",\n    \"spindle\",\n    \"sports car\",\n    \"spotlight\",\n    \"stage\",\n    \"steam locomotive\",\n    \"through arch bridge\",\n    \"steel drum\",\n    \"stethoscope\",\n    \"scarf\",\n    \"stone wall\",\n    \"stopwatch\",\n    \"stove\",\n    \"strainer\",\n    \"tram\",\n    \"stretcher\",\n    \"couch\",\n    \"stupa\",\n    \"submarine\",\n    \"suit\",\n    \"sundial\",\n    \"sunglasses\",\n    \"sunglasses\",\n    \"sunscreen\",\n    \"suspension bridge\",\n    \"mop\",\n    \"sweatshirt\",\n    \"swim trunks / shorts\",\n    \"swing\",\n    \"electrical switch\",\n    \"syringe\",\n    \"table lamp\",\n    \"tank\",\n    \"tape player\",\n    \"teapot\",\n    \"teddy bear\",\n    \"television\",\n    \"tennis ball\",\n    \"thatched roof\",\n    \"front curtain\",\n    \"thimble\",\n    \"threshing machine\",\n    \"throne\",\n    \"tile roof\",\n    \"toaster\",\n    \"tobacco shop\",\n    \"toilet seat\",\n    \"torch\",\n    \"totem pole\",\n    \"tow truck\",\n    \"toy store\",\n    \"tractor\",\n    \"semi-trailer truck\",\n    \"tray\",\n    \"trench coat\",\n    \"tricycle\",\n    \"trimaran\",\n    \"tripod\",\n    \"triumphal arch\",\n    \"trolleybus\",\n    \"trombone\",\n    \"hot tub\",\n    \"turnstile\",\n    \"typewriter keyboard\",\n    \"umbrella\",\n    \"unicycle\",\n    \"upright piano\",\n    \"vacuum cleaner\",\n    \"vase\",\n    \"vaulted or arched ceiling\",\n    \"velvet fabric\",\n    
\"vending machine\",\n    \"vestment\",\n    \"viaduct\",\n    \"violin\",\n    \"volleyball\",\n    \"waffle iron\",\n    \"wall clock\",\n    \"wallet\",\n    \"wardrobe\",\n    \"military aircraft\",\n    \"sink\",\n    \"washing machine\",\n    \"water bottle\",\n    \"water jug\",\n    \"water tower\",\n    \"whiskey jug\",\n    \"whistle\",\n    \"hair wig\",\n    \"window screen\",\n    \"window shade\",\n    \"Windsor tie\",\n    \"wine bottle\",\n    \"airplane wing\",\n    \"wok\",\n    \"wooden spoon\",\n    \"wool\",\n    \"split-rail fence\",\n    \"shipwreck\",\n    \"sailboat\",\n    \"yurt\",\n    \"website\",\n    \"comic book\",\n    \"crossword\",\n    \"traffic or street sign\",\n    \"traffic light\",\n    \"dust jacket\",\n    \"menu\",\n    \"plate\",\n    \"guacamole\",\n    \"consomme\",\n    \"hot pot\",\n    \"trifle\",\n    \"ice cream\",\n    \"popsicle\",\n    \"baguette\",\n    \"bagel\",\n    \"pretzel\",\n    \"cheeseburger\",\n    \"hot dog\",\n    \"mashed potatoes\",\n    \"cabbage\",\n    \"broccoli\",\n    \"cauliflower\",\n    \"zucchini\",\n    \"spaghetti squash\",\n    \"acorn squash\",\n    \"butternut squash\",\n    \"cucumber\",\n    \"artichoke\",\n    \"bell pepper\",\n    \"cardoon\",\n    \"mushroom\",\n    \"Granny Smith apple\",\n    \"strawberry\",\n    \"orange\",\n    \"lemon\",\n    \"fig\",\n    \"pineapple\",\n    \"banana\",\n    \"jackfruit\",\n    \"cherimoya (custard apple)\",\n    \"pomegranate\",\n    \"hay\",\n    \"carbonara\",\n    \"chocolate syrup\",\n    \"dough\",\n    \"meatloaf\",\n    \"pizza\",\n    \"pot pie\",\n    \"burrito\",\n    \"red wine\",\n    \"espresso\",\n    \"tea cup\",\n    \"eggnog\",\n    \"mountain\",\n    \"bubble\",\n    \"cliff\",\n    \"coral reef\",\n    \"geyser\",\n    \"lakeshore\",\n    \"promontory\",\n    \"sandbar\",\n    \"beach\",\n    \"valley\",\n    \"volcano\",\n    \"baseball player\",\n    \"bridegroom\",\n    \"scuba diver\",\n    \"rapeseed\",\n    \"daisy\",\n    \"yellow lady's slipper\",\n    \"corn\",\n    \"acorn\",\n    \"rose hip\",\n    \"horse chestnut seed\",\n    \"coral fungus\",\n    \"agaric\",\n    \"gyromitra\",\n    \"stinkhorn mushroom\",\n    \"earth star fungus\",\n    \"hen of the woods mushroom\",\n    \"bolete\",\n    \"corn cob\",\n    \"toilet paper\",\n]\n"
  },
  {
    "path": "lavis/datasets/builders/retrieval_builder.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nfrom lavis.datasets.builders.base_dataset_builder import BaseDatasetBuilder\nfrom lavis.datasets.datasets.retrieval_datasets import (\n    RetrievalDataset,\n    RetrievalEvalDataset,\n    VideoRetrievalDataset,\n    VideoRetrievalEvalDataset,\n)\n\nfrom lavis.common.registry import registry\n\n\n@registry.register_builder(\"msrvtt_retrieval\")\nclass MSRVTTRetrievalBuilder(BaseDatasetBuilder):\n    train_dataset_cls = VideoRetrievalDataset\n    eval_dataset_cls = VideoRetrievalEvalDataset\n\n    DATASET_CONFIG_DICT = {\"default\": \"configs/datasets/msrvtt/defaults_ret.yaml\"}\n\n\n@registry.register_builder(\"didemo_retrieval\")\nclass DiDeMoRetrievalBuilder(BaseDatasetBuilder):\n    train_dataset_cls = VideoRetrievalDataset\n    eval_dataset_cls = VideoRetrievalEvalDataset\n\n    DATASET_CONFIG_DICT = {\"default\": \"configs/datasets/didemo/defaults_ret.yaml\"}\n\n\n@registry.register_builder(\"coco_retrieval\")\nclass COCORetrievalBuilder(BaseDatasetBuilder):\n    train_dataset_cls = RetrievalDataset\n    eval_dataset_cls = RetrievalEvalDataset\n\n    DATASET_CONFIG_DICT = {\"default\": \"configs/datasets/coco/defaults_ret.yaml\"}\n\n\n@registry.register_builder(\"flickr30k\")\nclass Flickr30kBuilder(BaseDatasetBuilder):\n    train_dataset_cls = RetrievalDataset\n    eval_dataset_cls = RetrievalEvalDataset\n\n    DATASET_CONFIG_DICT = {\"default\": \"configs/datasets/flickr30k/defaults.yaml\"}\n"
  },
  {
    "path": "lavis/datasets/builders/video_qa_builder.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nfrom lavis.common.registry import registry\nfrom lavis.common.utils import get_cache_path\nfrom lavis.datasets.builders.base_dataset_builder import BaseDatasetBuilder\nfrom lavis.datasets.datasets.video_vqa_datasets import VideoQADataset\nfrom lavis.datasets.datasets.mc_video_vqa_datasets import MCVideoQADataset\n\nclass VideoQABuilder(BaseDatasetBuilder):\n    train_dataset_cls = VideoQADataset\n    eval_dataset_cls = VideoQADataset\n\n    def build(self):\n        datasets = super().build()\n\n        ans2label = self.config.build_info.annotations.get(\"ans2label\")\n        if ans2label is None:\n            raise ValueError(\"ans2label is not specified in build_info.\")\n\n        ans2label = get_cache_path(ans2label.storage)\n\n        for split in datasets:\n            datasets[split]._build_class_labels(ans2label)\n\n        return datasets\n\nclass MCVideoQABuilder(BaseDatasetBuilder):\n    train_dataset_cls = MCVideoQADataset\n    eval_dataset_cls = MCVideoQADataset\n\n    def build(self):\n        datasets = super().build()\n\n        for split in datasets:\n            datasets[split]._load_auxiliary_mappings()\n\n        return datasets\n\n@registry.register_builder(\"msrvtt_qa\")\nclass MSRVTTQABuilder(VideoQABuilder):\n    DATASET_CONFIG_DICT = {\n        \"default\": \"configs/datasets/msrvtt/defaults_qa.yaml\",\n    }\n\n\n@registry.register_builder(\"msvd_qa\")\nclass MSVDQABuilder(VideoQABuilder):\n    DATASET_CONFIG_DICT = {\n        \"default\": \"configs/datasets/msvd/defaults_qa.yaml\",\n    }\n\n# multi-choice videoqa\n@registry.register_builder(\"nextqa\")\nclass NextQABuilder(MCVideoQABuilder):\n    DATASET_CONFIG_DICT = {\n        \"default\": \"configs/datasets/nextqa/defaults_qa.yaml\",\n    }\n@registry.register_builder(\"star\")\nclass STARBuilder(MCVideoQABuilder):\n    DATASET_CONFIG_DICT = {\n        \"default\": \"configs/datasets/star/defaults_qa.yaml\",\n    }\n\n@registry.register_builder(\"tvqa\")\nclass TVQABuilder(MCVideoQABuilder):\n    DATASET_CONFIG_DICT = {\n        \"default\": \"configs/datasets/tvqa/defaults_qa.yaml\",\n    }\n    \n@registry.register_builder(\"how2qa\")\nclass How2QABuilder(MCVideoQABuilder):\n    DATASET_CONFIG_DICT = {\n        \"default\": \"configs/datasets/how2qa/defaults_qa.yaml\",\n    }\n\n@registry.register_builder(\"vlep\")\nclass VLEPBuilder(MCVideoQABuilder):\n    DATASET_CONFIG_DICT = {\n        \"default\": \"configs/datasets/vlep/defaults_qa.yaml\",\n    }\n     \n@registry.register_builder(\"qvh\")\nclass QVHBuilder(MCVideoQABuilder):\n    DATASET_CONFIG_DICT = {\n        \"default\": \"configs/datasets/qvh/defaults.yaml\",\n    }\n    \n# open-ended QA"
  },
  {
    "path": "lavis/datasets/builders/vqa_builder.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nfrom lavis.datasets.builders.base_dataset_builder import BaseDatasetBuilder\n\nfrom lavis.common.registry import registry\nfrom lavis.datasets.datasets.aok_vqa_datasets import AOKVQADataset, AOKVQAEvalDataset\nfrom lavis.datasets.datasets.coco_vqa_datasets import COCOVQADataset, COCOVQAEvalDataset\nfrom lavis.datasets.datasets.vg_vqa_datasets import VGVQADataset\nfrom lavis.datasets.datasets.gqa_datasets import GQADataset, GQAEvalDataset\n\n\n@registry.register_builder(\"coco_vqa\")\nclass COCOVQABuilder(BaseDatasetBuilder):\n    train_dataset_cls = COCOVQADataset\n    eval_dataset_cls = COCOVQAEvalDataset\n\n    DATASET_CONFIG_DICT = {\n        \"default\": \"configs/datasets/coco/defaults_vqa.yaml\",\n        \"eval\": \"configs/datasets/coco/eval_vqa.yaml\",\n    }\n\n\n@registry.register_builder(\"vg_vqa\")\nclass VGVQABuilder(BaseDatasetBuilder):\n    train_dataset_cls = VGVQADataset\n    DATASET_CONFIG_DICT = {\"default\": \"configs/datasets/vg/defaults_vqa.yaml\"}\n\n\n@registry.register_builder(\"ok_vqa\")\nclass OKVQABuilder(COCOVQABuilder):\n    DATASET_CONFIG_DICT = {\n        \"default\": \"configs/datasets/okvqa/defaults.yaml\",\n    }\n\n\n@registry.register_builder(\"aok_vqa\")\nclass AOKVQABuilder(BaseDatasetBuilder):\n    train_dataset_cls = AOKVQADataset\n    eval_dataset_cls = AOKVQAEvalDataset\n\n    DATASET_CONFIG_DICT = {\"default\": \"configs/datasets/aokvqa/defaults.yaml\"}\n\n\n@registry.register_builder(\"gqa\")\nclass GQABuilder(BaseDatasetBuilder):\n    train_dataset_cls = GQADataset\n    eval_dataset_cls = GQAEvalDataset\n\n    DATASET_CONFIG_DICT = {\n        \"default\": \"configs/datasets/gqa/defaults.yaml\",\n        \"balanced_val\": \"configs/datasets/gqa/balanced_val.yaml\",\n        \"balanced_testdev\": \"configs/datasets/gqa/balanced_testdev.yaml\",\n    }"
  },
  {
    "path": "lavis/datasets/data_utils.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport gzip\nimport logging\nimport os\nimport random as rnd\nimport tarfile\nimport zipfile\n\nimport decord\nimport webdataset as wds\nimport numpy as np\nimport torch\nfrom torch.utils.data.dataset import IterableDataset, ChainDataset\nfrom decord import VideoReader\nfrom lavis.common.registry import registry\nfrom lavis.datasets.datasets.base_dataset import ConcatDataset\nfrom tqdm import tqdm\n\ndecord.bridge.set_bridge(\"torch\")\nMAX_INT = registry.get(\"MAX_INT\")\n\n# add for loading video\ndef load_video(video_path, n_frms=MAX_INT, height=-1, width=-1, sampling=\"uniform\", clip_proposal=None):\n    vr = VideoReader(uri=video_path, height=height, width=width)\n    vlen = len(vr)\n    n_frms = min(n_frms, vlen)\n    fps = vr.get_avg_fps() \n    if clip_proposal is None:\n        start, end = 0, vlen\n    else:\n        start, end = int(clip_proposal[0]*fps), int(clip_proposal[1]*fps)\n        if start < 0:\n            start = 0\n        if end > vlen:\n            end = vlen\n\n    intervals = np.linspace(start=start, stop=end, num=n_frms + 1).astype(int)\n    ranges = []\n    for idx, interv in enumerate(intervals[:-1]):\n        ranges.append((interv, intervals[idx + 1]))\n\n    if sampling == 'random':\n        indices = []\n        for x in ranges:\n            if x[0] == x[1]:\n                indices.append(x[0])\n            else:\n                indices.append(rnd.choice(range(x[0], x[1])))\n    elif sampling == 'uniform':\n        \n        indices = [(x[0] + x[1]) // 2 for x in ranges]\n\n    elif sampling == \"headtail\":\n        indices_h = sorted(rnd.sample(range(vlen // 2), n_frms // 2))\n        indices_t = sorted(rnd.sample(range(vlen // 2, vlen), n_frms // 2))\n        indices = indices_h + indices_t\n    else:\n        raise NotImplementedError\n    \n    if len(indices) < n_frms:\n        rest = [indices[-1] for i in range(n_frms - len(indices))]\n        indices = indices + rest \n    # get_batch -> T, H, W, C\n    frms = vr.get_batch(indices).permute(3, 0, 1, 2).float()  # (C, T, H, W)\n\n    return frms, indices, fps\n\ndef load_video_demo(video_path, n_frms=MAX_INT, height=-1, width=-1, sampling=\"uniform\", clip_proposal=None):\n    vr = VideoReader(uri=video_path, height=height, width=width)\n    vlen = len(vr)\n    n_frms = min(n_frms, vlen)\n    fps = vr.get_avg_fps() \n    if clip_proposal is None:\n        start, end = 0, vlen\n    else:\n        start, end = int(clip_proposal[0]*fps), int(clip_proposal[1]*fps)\n        if start < 0:\n            start = 0\n        if end > vlen:\n            end = vlen\n\n    intervals = np.linspace(start=start, stop=end, num=n_frms + 1).astype(int)\n    ranges = []\n    for idx, interv in enumerate(intervals[:-1]):\n        ranges.append((interv, intervals[idx + 1]))\n\n    if sampling == 'random':\n        indices = []\n        for x in ranges:\n            if x[0] == x[1]:\n                indices.append(x[0])\n            else:\n                indices.append(rnd.choice(range(x[0], x[1])))\n    elif sampling == 'uniform':\n        \n        indices = [(x[0] + x[1]) // 2 for x in ranges]\n\n    elif sampling == \"headtail\":\n        indices_h = sorted(rnd.sample(range(vlen // 2), n_frms // 2))\n        indices_t = sorted(rnd.sample(range(vlen // 2, vlen), n_frms // 
2))\n        indices = indices_h + indices_t\n    else:\n        raise NotImplementedError\n    \n    if len(indices) < n_frms:\n        rest = [indices[-1] for i in range(n_frms - len(indices))]\n        indices = indices + rest\n    # get_batch -> T, H, W, C\n    frms = vr.get_batch(indices)\n    # decord returns a torch.Tensor when the torch bridge is set, otherwise an NDArray\n    if not torch.is_tensor(frms):\n        frms = torch.from_numpy(frms.asnumpy())\n    frms = frms.permute(3, 0, 1, 2).float()  # (C, T, H, W)\n\n    return frms, indices, fps, vlen\n\ndef apply_to_sample(f, sample):\n    if len(sample) == 0:\n        return {}\n\n    def _apply(x):\n        if torch.is_tensor(x):\n            return f(x)\n        elif isinstance(x, dict):\n            return {key: _apply(value) for key, value in x.items()}\n        elif isinstance(x, list):\n            return [_apply(x) for x in x]\n        else:\n            return x\n\n    return _apply(sample)\n\n\ndef move_to_cuda(sample):\n    def _move_to_cuda(tensor):\n        return tensor.cuda()\n\n    return apply_to_sample(_move_to_cuda, sample)\n\n\ndef prepare_sample(samples, cuda_enabled=True):\n    if cuda_enabled:\n        samples = move_to_cuda(samples)\n\n    # TODO fp16 support\n\n    return samples\n\n\ndef reorg_datasets_by_split(datasets):\n    \"\"\"\n    Organizes datasets by split.\n\n    Args:\n        datasets: dict of torch.utils.data.Dataset objects by name.\n\n    Returns:\n        Dict of datasets by split {split_name: List[Datasets]}.\n    \"\"\"\n    # if len(datasets) == 1:\n    #     return datasets[list(datasets.keys())[0]]\n    # else:\n    reorg_datasets = dict()\n\n    # reorganize by split\n    for _, dataset in datasets.items():\n        for split_name, dataset_split in dataset.items():\n            if split_name not in reorg_datasets:\n                reorg_datasets[split_name] = [dataset_split]\n            else:\n                reorg_datasets[split_name].append(dataset_split)\n\n    return reorg_datasets\n\n\ndef concat_datasets(datasets):\n    \"\"\"\n    Concatenates multiple datasets into a single dataset.\n\n    It supports map-style datasets and DataPipeline from WebDataset. Currently, it does not support\n    generic IterableDataset because it requires creating separate samplers.\n\n    Only training datasets are concatenated; validation and testing splits are assumed to\n    have only a single dataset. 
This is because metrics should not be computed on the concatenated\n    datasets.\n\n    Args:\n        datasets: dict of torch.utils.data.Dataset objects by split.\n\n    Returns:\n        Dict of concatenated datasets by split, \"train\" is the concatenation of multiple datasets,\n        \"val\" and \"test\" remain the same.\n\n        If the input training datasets contain both map-style and DataPipeline datasets, returns\n        a tuple, where the first element is a concatenated map-style dataset and the second\n        element is a chained DataPipeline dataset.\n\n    \"\"\"\n    # concatenate datasets in the same split\n    for split_name in datasets:\n        if split_name != \"train\":\n            assert (\n                len(datasets[split_name]) == 1\n            ), \"Do not support multiple {} datasets.\".format(split_name)\n            datasets[split_name] = datasets[split_name][0]\n        else:\n            iterable_datasets, map_datasets = [], []\n            for dataset in datasets[split_name]:\n                if isinstance(dataset, wds.DataPipeline):\n                    logging.info(\n                        \"Dataset {} is IterableDataset, can't be concatenated.\".format(\n                            dataset\n                        )\n                    )\n                    iterable_datasets.append(dataset)\n                elif isinstance(dataset, IterableDataset):\n                    raise NotImplementedError(\n                        \"Do not support concatenation of generic IterableDataset.\"\n                    )\n                else:\n                    map_datasets.append(dataset)\n\n            # if len(iterable_datasets) > 0:\n            # concatenate map-style datasets and iterable-style datasets separately\n            chained_datasets = (\n                ChainDataset(iterable_datasets) if len(iterable_datasets) > 0 else None\n            )\n            concat_datasets = (\n                ConcatDataset(map_datasets) if len(map_datasets) > 0 else None\n            )\n\n            train_datasets = concat_datasets, chained_datasets\n            train_datasets = tuple([x for x in train_datasets if x is not None])\n            train_datasets = (\n                train_datasets[0] if len(train_datasets) == 1 else train_datasets\n            )\n\n            datasets[split_name] = train_datasets\n\n    return datasets\n\n\ndef extract_archive(from_path, to_path=None, overwrite=False):\n    \"\"\"Extract archive.\n\n    Args:\n        from_path: the path of the archive.\n        to_path: the root path of the extracted files (directory of from_path)\n        overwrite: overwrite existing files (False)\n\n    Returns:\n        List of paths to extracted files even if not overwritten.\n\n    Examples:\n        >>> url = 'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz'\n        >>> from_path = './validation.tar.gz'\n        >>> to_path = './'\n        >>> torchtext.utils.download_from_url(url, from_path)\n        >>> torchtext.utils.extract_archive(from_path, to_path)\n        >>> ['.data/val.de', '.data/val.en']\n        >>> torchtext.utils.download_from_url(url, from_path)\n        >>> torchtext.utils.extract_archive(from_path, to_path)\n        >>> ['.data/val.de', '.data/val.en']\n\n    \"\"\"\n\n    if to_path is None:\n        to_path = os.path.dirname(from_path)\n\n    if from_path.endswith((\".tar.gz\", \".tgz\")):\n        logging.info(\"Opening tar file {} to {}.\".format(from_path, to_path))\n        with 
tarfile.open(from_path, \"r\") as tar:\n            files = []\n            for file_ in tqdm(tar):\n                file_path = os.path.join(to_path, file_.name)\n                if file_.isfile():\n                    files.append(file_path)\n                    if os.path.exists(file_path):\n                        logging.info(\"{} already extracted.\".format(file_path))\n                        if not overwrite:\n                            continue\n                tar.extract(file_, to_path)\n            logging.info(\"Finished extracting tar file {}.\".format(from_path))\n            return files\n\n    elif from_path.endswith(\".zip\"):\n        assert zipfile.is_zipfile(from_path), from_path\n        logging.info(\"Opening zip file {} to {}.\".format(from_path, to_path))\n        with zipfile.ZipFile(from_path, \"r\") as zfile:\n            files = []\n            for file_ in tqdm(zfile.namelist()):\n                file_path = os.path.join(to_path, file_)\n                files.append(file_path)\n                if os.path.exists(file_path):\n                    logging.info(\"{} already extracted.\".format(file_path))\n                    if not overwrite:\n                        continue\n                zfile.extract(file_, to_path)\n        files = [f for f in files if os.path.isfile(f)]\n        logging.info(\"Finished extracting zip file {}.\".format(from_path))\n        return files\n\n    elif from_path.endswith(\".gz\"):\n        logging.info(\"Opening gz file {} to {}.\".format(from_path, to_path))\n        default_block_size = 65536\n        filename = from_path[:-3]\n        files = [filename]\n        with gzip.open(from_path, \"rb\") as gzfile, open(filename, \"wb\") as d_file:\n            while True:\n                block = gzfile.read(default_block_size)\n                if not block:\n                    break\n                else:\n                    d_file.write(block)\n            d_file.write(block)\n        logging.info(\"Finished extracting gz file {}.\".format(from_path))\n        return files\n\n    else:\n        raise NotImplementedError(\n            \"We currently only support tar.gz, .tgz, .gz and zip achives.\"\n        )\n\n\ndef save_frames_grid(img_array, out_path):\n    import torch\n    from PIL import Image\n    from torchvision.utils import make_grid\n\n    if len(img_array.shape) == 3:\n        img_array = img_array.unsqueeze(0)\n    elif len(img_array.shape) == 5:\n        b, t, c, h, w = img_array.shape\n        img_array = img_array.view(-1, c, h, w)\n    elif len(img_array.shape) == 4:\n        pass\n    else:\n        raise NotImplementedError(\n            \"Supports only (b,t,c,h,w)-shaped inputs. First two dimensions can be ignored.\"\n        )\n\n    assert img_array.shape[1] == 3, \"Exepcting input shape of (H, W, 3), i.e. RGB-only.\"\n\n    grid = make_grid(img_array)\n    ndarr = grid.permute(1, 2, 0).to(\"cpu\", torch.uint8).numpy()\n\n    img = Image.fromarray(ndarr)\n\n    img.save(out_path)\n"
  },
  {
    "path": "lavis/datasets/datasets/aok_vqa_datasets.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nfrom collections import OrderedDict\nimport json\nimport os\nimport torch\n\nfrom PIL import Image\n\nfrom lavis.datasets.datasets.vqa_datasets import VQADataset, VQAEvalDataset\n\n\nclass __DisplMixin:\n    def displ_item(self, index):\n        sample, ann = self.__getitem__(index), self.annotation[index]\n        return OrderedDict(\n            {\n                \"file\": ann[\"image\"],\n                \"question\": ann[\"question\"],\n                \"question_id\": ann[\"question_id\"],\n                \"direct_answers\": \"; \".join(ann[\"direct_answers\"]),\n                \"choices\": \"; \".join(ann[\"choices\"]),\n                \"correct_choice\": ann[\"choices\"][ann[\"correct_choice_idx\"]],\n                \"image\": sample[\"image\"],\n            }\n        )\n\n\nclass AOKVQADataset(VQADataset, __DisplMixin):\n    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):\n        super().__init__(vis_processor, text_processor, vis_root, ann_paths)\n\n    def __getitem__(self, index):\n        ann = self.annotation[index]\n\n        image_path = os.path.join(self.vis_root, ann[\"image\"])\n        image = Image.open(image_path).convert(\"RGB\")\n\n        image = self.vis_processor(image)\n        question = self.text_processor(ann[\"question\"])\n\n        answer_key = \"direct_answers\"\n\n        answer_weight = {}\n        for answer in ann[answer_key]:\n            if answer in answer_weight.keys():\n                answer_weight[answer] += 1 / len(ann[answer_key])\n            else:\n                answer_weight[answer] = 1 / len(ann[answer_key])\n\n        answers = list(answer_weight.keys())\n        weights = list(answer_weight.values())\n\n        return {\n            \"image\": image,\n            \"text_input\": question,\n            \"answers\": answers,\n            \"weights\": weights,\n        }\n\n\nclass AOKVQAEvalDataset(VQAEvalDataset, __DisplMixin):\n    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):\n        \"\"\"\n        vis_root (string): Root directory of images (e.g. 
coco/images/)\n        ann_root (string): directory to store the annotation file\n        \"\"\"\n\n        self.vis_root = vis_root\n\n        self.annotation = json.load(open(ann_paths[0]))\n\n        answer_list_path = ann_paths[1]\n        if os.path.exists(answer_list_path):\n            self.answer_list = json.load(open(answer_list_path))\n        else:\n            self.answer_list = None\n\n        try:\n            self.coco_fmt_qust_file = ann_paths[2]\n            self.coco_fmt_anno_file = ann_paths[3]\n        except IndexError:\n            self.coco_fmt_qust_file = None\n            self.coco_fmt_anno_file = None\n\n        self.vis_processor = vis_processor\n        self.text_processor = text_processor\n\n        self._add_instance_ids()\n\n    def collater(self, samples):\n        (\n            image_list,\n            question_list,\n            question_id_list,\n            instance_id_list,\n            choices_list,\n            correct_choice_idx_list,\n            direct_answers_list,\n        ) = ([], [], [], [], [], [], [])\n\n        for sample in samples:\n            image_list.append(sample[\"image\"])\n            question_list.append(sample[\"text_input\"])\n            question_id_list.append(sample[\"question_id\"])\n            instance_id_list.append(sample[\"instance_id\"])\n            choices_list.append(sample[\"choices\"])\n            correct_choice_idx_list.append(sample[\"correct_choice_idx\"])\n            direct_answers_list.append(sample[\"direct_answers\"])\n\n        return {\n            \"image\": torch.stack(image_list, dim=0),\n            \"text_input\": question_list,\n            \"question_id\": question_id_list,\n            \"instance_id\": instance_id_list,\n            \"choices\": choices_list,\n            \"correct_choice_idx\": correct_choice_idx_list,\n            \"direct_answers\": direct_answers_list,\n        }\n\n    def __getitem__(self, index):\n        ann = self.annotation[index]\n\n        image_path = os.path.join(self.vis_root, ann[\"image\"])\n        image = Image.open(image_path).convert(\"RGB\")\n\n        image = self.vis_processor(image)\n        question = self.text_processor(ann[\"question\"])\n\n        choices = ann[\"choices\"]\n        if \"correct_choice_idx\" in ann:\n            correct_choice_idx = ann[\"correct_choice_idx\"]\n        else:\n            correct_choice_idx = None\n\n        if \"direct_answers\" in ann:\n            direct_answers = ann[\"direct_answers\"]\n        else:\n            direct_answers = None\n\n        return {\n            \"image\": image,\n            \"text_input\": question,\n            \"question_id\": ann[\"question_id\"],\n            \"instance_id\": ann[\"instance_id\"],\n            \"choices\": choices,\n            \"correct_choice_idx\": correct_choice_idx,\n            \"direct_answers\": direct_answers,\n        }\n"
  },
  {
    "path": "lavis/datasets/datasets/avsd_dialogue_datasets.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport torch\nfrom lavis.datasets.datasets.dialogue_datasets import (\n    DialogueDataset,\n    DialogueEvalDataset,\n)\n\n\nclass AVSDDialDataset(DialogueDataset):\n    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):\n        \"\"\"\n        vis_root (string): Root directory of images (e.g. coco/images/)\n        ann_root (string): directory to store the annotation file\n        split (string): val or test\n        \"\"\"\n        super().__init__(vis_processor, text_processor, vis_root, ann_paths)\n\n    def __getitem__(self, index):\n\n        ann = self.annotation[index]\n\n        vname = ann[\"image_id\"]\n\n        video = self.vis_processor(self.vis_root, vname)\n\n        dialogue = self.text_processor(ann)\n\n        # \"image_id\" is kept to stay compatible with the COCO evaluation format\n        return {\n            \"video_fts\": video[\"video_fts\"],\n            \"video_token_type_ids\": video[\"token_type_ids\"],\n            \"input_ids\": dialogue[\"input_ids\"],\n            \"token_type_ids\": dialogue[\"token_type_ids\"],\n            \"labels\": dialogue[\"labels\"],\n            \"image_id\": ann[\"image_id\"],\n            \"instance_id\": ann[\"instance_id\"],\n        }\n\n    def collater(self, samples):\n\n        input_ids, token_type_ids, labels, video_fts, video_token_type_ids = (\n            [],\n            [],\n            [],\n            [],\n            [],\n        )\n\n        for i in samples:\n            input_ids.append(i[\"input_ids\"])\n            token_type_ids.append(i[\"token_type_ids\"])\n            labels.append(i[\"labels\"])\n            video_fts.append(i[\"video_fts\"])\n            video_token_type_ids.append(i[\"video_token_type_ids\"])\n\n        input_ids = self.text_processor.padding(input_ids)\n\n        labels = self.text_processor.padding(\n            labels, -1\n        )  # ignore token indice -1 by default\n        video_fts = self.vis_processor.padding(video_fts)\n\n        token_type_ids = self.text_processor.padding(token_type_ids)\n        video_token_type_ids = self.text_processor.padding(video_token_type_ids)\n        token_type_ids = torch.cat([video_token_type_ids, token_type_ids], dim=1)\n\n        attn_mask = self.text_processor.get_attention_mask(input_ids)\n        video_mask = self.vis_processor.get_attention_mask(video_fts)\n        attn_mask = torch.cat([video_mask, attn_mask], dim=1)\n\n        video_labels = (\n            torch.ones((video_fts.size(0), video_fts.size(1))).long() * -1\n        )  # ignore token indice -1 by default\n        labels = torch.cat([video_labels, labels], dim=1)\n\n        samples = {}\n        samples[\"input_ids\"] = input_ids\n        samples[\"token_type_ids\"] = token_type_ids\n        samples[\"labels\"] = labels\n        samples[\"video_fts\"] = video_fts\n        samples[\"attn_mask\"] = attn_mask\n\n        return samples\n\n\nclass AVSDDialEvalDataset(DialogueEvalDataset):\n    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):\n        \"\"\"\n        vis_root (string): Root directory of images (e.g. 
coco/images/)\n        ann_root (string): directory to store the annotation file\n        split (string): val or test\n        \"\"\"\n        super().__init__(vis_processor, text_processor, vis_root, ann_paths)\n\n    def __getitem__(self, index):\n\n        ann = self.annotation[index]\n\n        vname = ann[\"image_id\"]\n\n        video = self.vis_processor(self.vis_root, vname)\n\n        dialogue = self.text_processor(ann)\n\n        # \"image_id\" is kept to stay compatible with the COCO evaluation format\n        return {\n            \"video_fts\": video[\"video_fts\"],\n            \"video_token_type_ids\": video[\"token_type_ids\"],\n            \"input_ids\": dialogue[\"input_ids\"],\n            \"token_type_ids\": dialogue[\"token_type_ids\"],\n            \"labels\": dialogue[\"labels\"],\n            \"image_id\": ann[\"image_id\"],\n            \"instance_id\": ann[\"instance_id\"],\n        }\n\n    def collater(self, samples):\n\n        input_ids, token_type_ids, labels, video_fts, video_token_type_ids = (\n            [],\n            [],\n            [],\n            [],\n            [],\n        )\n\n        for i in samples:\n            input_ids.append(i[\"input_ids\"])\n            token_type_ids.append(i[\"token_type_ids\"])\n            labels.append(i[\"labels\"])\n            video_fts.append(i[\"video_fts\"])\n            video_token_type_ids.append(i[\"video_token_type_ids\"])\n\n        input_ids = self.text_processor.padding(input_ids)\n\n        labels = self.text_processor.padding(\n            labels, -1\n        )  # ignore token indice -1 by default\n        video_fts = self.vis_processor.padding(video_fts)\n\n        token_type_ids = self.text_processor.padding(token_type_ids)\n        video_token_type_ids = self.text_processor.padding(video_token_type_ids)\n        token_type_ids = torch.cat([video_token_type_ids, token_type_ids], dim=1)\n\n        attn_mask = self.text_processor.get_attention_mask(input_ids)\n        video_mask = self.vis_processor.get_attention_mask(video_fts)\n        attn_mask = torch.cat([video_mask, attn_mask], dim=1)\n\n        video_labels = (\n            torch.ones((video_fts.size(0), video_fts.size(1))).long() * -1\n        )  # ignore token indice -1 by default\n        labels = torch.cat([video_labels, labels], dim=1)\n\n        samples = {}\n        samples[\"input_ids\"] = input_ids\n        samples[\"token_type_ids\"] = token_type_ids\n        samples[\"labels\"] = labels\n        samples[\"video_fts\"] = video_fts\n        samples[\"attn_mask\"] = attn_mask\n\n        return samples\n"
  },
  {
    "path": "lavis/datasets/datasets/base_dataset.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport json\nimport pandas as pd\n\nfrom typing import Iterable\nfrom torch.utils.data import Dataset, ConcatDataset\nfrom torch.utils.data.dataloader import default_collate\n\n\nclass BaseDataset(Dataset):\n    def __init__(\n        self, vis_processor=None, text_processor=None, vis_root=None, ann_paths=[]\n    ):\n        \"\"\"\n        vis_root (string): Root directory of images (e.g. coco/images/)\n        ann_root (string): directory to store the annotation file\n        \"\"\"\n        self.vis_root = vis_root\n\n        self.annotation = []\n\n        for ann_path in ann_paths:\n            if '.json' in ann_path:\n                self.annotation.extend(json.load(open(ann_path, \"r\")))\n                if 'train' in ann_path: \n                    self.data_type = 'train'\n                else:\n                    self.data_type = 'val'\n            else:\n                raise AttributeError('Undefined data type')\n            \n        #self.annotation = self.annotation[:100] \n        self.vis_processor = vis_processor\n        self.text_processor = text_processor\n\n        self._add_instance_ids()\n\n    def __len__(self):\n        return len(self.annotation)\n\n    def collater(self, samples):\n        return default_collate(samples)\n\n    def set_processors(self, vis_processor, text_processor):\n        self.vis_processor = vis_processor\n        self.text_processor = text_processor\n\n    def _add_instance_ids(self, key=\"instance_id\"):\n        for idx, ann in enumerate(self.annotation): \n            if isinstance(ann, str):\n                pass\n            else:\n                ann[key] = str(idx)\n\n\nclass ConcatDataset(ConcatDataset):\n    def __init__(self, datasets: Iterable[Dataset]) -> None:\n        super().__init__(datasets)\n\n    def collater(self, samples):\n        # TODO For now only supports datasets with same underlying collater implementations\n\n        all_keys = set()\n        for s in samples:\n            all_keys.update(s)\n\n        shared_keys = all_keys\n        for s in samples:\n            shared_keys = shared_keys & set(s.keys())\n\n        samples_shared_keys = []\n        for s in samples:\n            samples_shared_keys.append({k: s[k] for k in s.keys() if k in shared_keys})\n\n        return self.datasets[0].collater(samples_shared_keys)\n"
  },
  {
    "path": "lavis/datasets/datasets/caption_datasets.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport os\nfrom collections import OrderedDict\n\nfrom lavis.datasets.datasets.base_dataset import BaseDataset\nfrom PIL import Image\n\n\nclass __DisplMixin:\n    def displ_item(self, index):\n        sample, ann = self.__getitem__(index), self.annotation[index]\n\n        return OrderedDict(\n            {\n                \"file\": ann[\"image\"],\n                \"caption\": ann[\"caption\"],\n                \"image\": sample[\"image\"],\n            }\n        )\n\n\nclass CaptionDataset(BaseDataset, __DisplMixin):\n    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):\n        \"\"\"\n        vis_root (string): Root directory of images (e.g. coco/images/)\n        ann_root (string): directory to store the annotation file\n        \"\"\"\n        super().__init__(vis_processor, text_processor, vis_root, ann_paths)\n\n        self.img_ids = {}\n        n = 0\n        for ann in self.annotation:\n            img_id = ann[\"image_id\"]\n            if img_id not in self.img_ids.keys():\n                self.img_ids[img_id] = n\n                n += 1\n\n    def __getitem__(self, index):\n\n        # TODO this assumes image input, not general enough\n        ann = self.annotation[index]\n\n        image_path = os.path.join(self.vis_root, ann[\"image\"])\n        image = Image.open(image_path).convert(\"RGB\")\n\n        image = self.vis_processor(image)\n        caption = self.text_processor(ann[\"caption\"])\n\n        return {\n            \"image\": image,\n            \"text_input\": caption,\n            \"image_id\": self.img_ids[ann[\"image_id\"]],\n        }\n\n\nclass CaptionEvalDataset(BaseDataset, __DisplMixin):\n    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):\n        \"\"\"\n        vis_root (string): Root directory of images (e.g. coco/images/)\n        ann_root (string): directory to store the annotation file\n        split (string): val or test\n        \"\"\"\n        super().__init__(vis_processor, text_processor, vis_root, ann_paths)\n\n    def __getitem__(self, index):\n\n        ann = self.annotation[index]\n\n        image_path = os.path.join(self.vis_root, ann[\"image\"])\n        image = Image.open(image_path).convert(\"RGB\")\n\n        image = self.vis_processor(image)\n\n        return {\n            \"image\": image,\n            \"image_id\": ann[\"image_id\"],\n            \"instance_id\": ann[\"instance_id\"],\n        }\n"
  },
  {
    "path": "lavis/datasets/datasets/coco_caption_datasets.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport os\nimport json\n\nfrom PIL import Image\nfrom PIL import ImageFile\n\nImageFile.LOAD_TRUNCATED_IMAGES = True\n\nfrom lavis.datasets.datasets.caption_datasets import CaptionDataset, CaptionEvalDataset\n\nCOCOCapDataset = CaptionDataset\n\n\nclass COCOCapEvalDataset(CaptionEvalDataset):\n    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):\n        \"\"\"\n        vis_root (string): Root directory of images (e.g. coco/images/)\n        ann_root (string): directory to store the annotation file\n        split (string): val or test\n        \"\"\"\n        super().__init__(vis_processor, text_processor, vis_root, ann_paths)\n\n    def __getitem__(self, index):\n        ann = self.annotation[index]\n\n        image_path = os.path.join(self.vis_root, ann[\"image\"])\n        image = Image.open(image_path).convert(\"RGB\")\n\n        image = self.vis_processor(image)\n\n        img_id = ann[\"image\"].split(\"/\")[-1].strip(\".jpg\").split(\"_\")[-1]\n\n        return {\n            \"image\": image,\n            \"image_id\": img_id,\n            \"instance_id\": ann[\"instance_id\"],\n        }\n\n\nclass NoCapsEvalDataset(CaptionEvalDataset):\n    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):\n        \"\"\"\n        vis_root (string): Root directory of images (e.g. coco/images/)\n        ann_root (string): directory to store the annotation file\n        split (string): val or test\n        \"\"\"\n        super().__init__(vis_processor, text_processor, vis_root, ann_paths)\n\n    def __getitem__(self, index):\n        ann = self.annotation[index]\n\n        image_path = os.path.join(self.vis_root, ann[\"image\"])\n        image = Image.open(image_path).convert(\"RGB\")\n\n        image = self.vis_processor(image)\n\n        img_id = ann[\"img_id\"]\n\n        return {\n            \"image\": image,\n            \"image_id\": img_id,\n            \"instance_id\": ann[\"instance_id\"],\n        }\n"
  },
  {
    "path": "lavis/datasets/datasets/coco_vqa_datasets.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport os\nimport json\n\nfrom PIL import Image\n\nfrom lavis.datasets.datasets.vqa_datasets import VQADataset, VQAEvalDataset\n\nfrom collections import OrderedDict\n\n\nclass __DisplMixin:\n    def displ_item(self, index):\n        sample, ann = self.__getitem__(index), self.annotation[index]\n\n        return OrderedDict(\n            {\n                \"file\": ann[\"image\"],\n                \"question\": ann[\"question\"],\n                \"question_id\": ann[\"question_id\"],\n                \"answers\": \"; \".join(ann[\"answer\"]),\n                \"image\": sample[\"image\"],\n            }\n        )\n\n\nclass COCOVQADataset(VQADataset, __DisplMixin):\n    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):\n        super().__init__(vis_processor, text_processor, vis_root, ann_paths)\n\n    def __getitem__(self, index):\n        ann = self.annotation[index]\n\n        image_path = os.path.join(self.vis_root, ann[\"image\"])\n        image = Image.open(image_path).convert(\"RGB\")\n\n        image = self.vis_processor(image)\n        question = self.text_processor(ann[\"question\"])\n\n        answer_weight = {}\n        for answer in ann[\"answer\"]:\n            if answer in answer_weight.keys():\n                answer_weight[answer] += 1 / len(ann[\"answer\"])\n            else:\n                answer_weight[answer] = 1 / len(ann[\"answer\"])\n\n        answers = list(answer_weight.keys())\n        weights = list(answer_weight.values())\n\n        return {\n            \"image\": image,\n            \"text_input\": question,\n            \"answers\": answers,\n            \"weights\": weights,\n        }\n\n\nclass COCOVQAEvalDataset(VQAEvalDataset, __DisplMixin):\n    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):\n        \"\"\"\n        vis_root (string): Root directory of images (e.g. coco/images/)\n        ann_root (string): directory to store the annotation file\n        \"\"\"\n\n        self.vis_root = vis_root\n\n        self.annotation = json.load(open(ann_paths[0]))\n\n        answer_list_path = ann_paths[1]\n        if os.path.exists(answer_list_path):\n            self.answer_list = json.load(open(answer_list_path))\n        else:\n            self.answer_list = None\n\n        try:\n            self.coco_fmt_qust_file = ann_paths[2]\n            self.coco_fmt_anno_file = ann_paths[3]\n        except IndexError:\n            self.coco_fmt_qust_file = None\n            self.coco_fmt_anno_file = None\n\n        self.vis_processor = vis_processor\n        self.text_processor = text_processor\n\n        self._add_instance_ids()\n\n    def __getitem__(self, index):\n        ann = self.annotation[index]\n\n        image_path = os.path.join(self.vis_root, ann[\"image\"])\n        image = Image.open(image_path).convert(\"RGB\")\n\n        image = self.vis_processor(image)\n        question = self.text_processor(ann[\"question\"])\n\n        return {\n            \"image\": image,\n            \"text_input\": question,\n            \"question_id\": ann[\"question_id\"],\n            \"instance_id\": ann[\"instance_id\"],\n        }\n"
  },
  {
    "path": "lavis/datasets/datasets/dataloader_utils.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport time\nimport random\nimport torch\nfrom lavis.datasets.data_utils import move_to_cuda\nfrom torch.utils.data import DataLoader\n\n\nclass MultiIterLoader:\n    \"\"\"\n    A simple wrapper for iterating over multiple iterators.\n\n    Args:\n        loaders (List[Loader]): List of Iterator loaders.\n        ratios (List[float]): List of ratios to sample from each loader. If None, all loaders are sampled uniformly.\n    \"\"\"\n\n    def __init__(self, loaders, ratios=None):\n        # assert all loaders has __next__ method\n        for loader in loaders:\n            assert hasattr(\n                loader, \"__next__\"\n            ), \"Loader {} has no __next__ method.\".format(loader)\n\n        if ratios is None:\n            ratios = [1.0] * len(loaders)\n        else:\n            assert len(ratios) == len(loaders)\n            ratios = [float(ratio) / sum(ratios) for ratio in ratios]\n\n        self.loaders = loaders\n        self.ratios = ratios\n\n    def __next__(self):\n        # random sample from each loader by ratio\n        loader_idx = random.choices(range(len(self.loaders)), self.ratios, k=1)[0]\n        return next(self.loaders[loader_idx])\n\n\nclass PrefetchLoader(object):\n    \"\"\"\n    Modified from https://github.com/ChenRocks/UNITER.\n\n    overlap compute and cuda data transfer\n    (copied and then modified from nvidia apex)\n    \"\"\"\n\n    def __init__(self, loader):\n        self.loader = loader\n        self.stream = torch.cuda.Stream()\n\n    def __iter__(self):\n        loader_it = iter(self.loader)\n        self.preload(loader_it)\n        batch = self.next(loader_it)\n        while batch is not None:\n            is_tuple = isinstance(batch, tuple)\n            if is_tuple:\n                task, batch = batch\n\n            if is_tuple:\n                yield task, batch\n            else:\n                yield batch\n            batch = self.next(loader_it)\n\n    def __len__(self):\n        return len(self.loader)\n\n    def preload(self, it):\n        try:\n            self.batch = next(it)\n        except StopIteration:\n            self.batch = None\n            return\n        # if record_stream() doesn't work, another option is to make sure\n        # device inputs are created on the main stream.\n        # self.next_input_gpu = torch.empty_like(self.next_input,\n        #                                        device='cuda')\n        # self.next_target_gpu = torch.empty_like(self.next_target,\n        #                                         device='cuda')\n        # Need to make sure the memory allocated for next_* is not still in use\n        # by the main stream at the time we start copying to next_*:\n        # self.stream.wait_stream(torch.cuda.current_stream())\n        with torch.cuda.stream(self.stream):\n            self.batch = move_to_cuda(self.batch)\n            # more code for the alternative if record_stream() doesn't work:\n            # copy_ will record the use of the pinned source tensor in this\n            # side stream.\n            # self.next_input_gpu.copy_(self.next_input, non_blocking=True)\n            # self.next_target_gpu.copy_(self.next_target, non_blocking=True)\n            # self.next_input = self.next_input_gpu\n            # self.next_target = 
self.next_target_gpu\n\n    def next(self, it):\n        torch.cuda.current_stream().wait_stream(self.stream)\n        batch = self.batch\n        if batch is not None:\n            record_cuda_stream(batch)\n        self.preload(it)\n        return batch\n\n    def __getattr__(self, name):\n        method = self.loader.__getattribute__(name)\n        return method\n\n\ndef record_cuda_stream(batch):\n    if isinstance(batch, torch.Tensor):\n        batch.record_stream(torch.cuda.current_stream())\n    elif isinstance(batch, list) or isinstance(batch, tuple):\n        for t in batch:\n            record_cuda_stream(t)\n    elif isinstance(batch, dict):\n        for t in batch.values():\n            record_cuda_stream(t)\n    else:\n        pass\n\n\nclass IterLoader:\n    \"\"\"\n    A wrapper to convert DataLoader as an infinite iterator.\n\n    Modified from:\n        https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/iter_based_runner.py\n    \"\"\"\n\n    def __init__(self, dataloader: DataLoader, use_distributed: bool = False):\n        self._dataloader = dataloader\n        self.iter_loader = iter(self._dataloader)\n        self._use_distributed = use_distributed\n        self._epoch = 0\n\n    @property\n    def epoch(self) -> int:\n        return self._epoch\n\n    def __next__(self):\n        try:\n            data = next(self.iter_loader)\n        except StopIteration:\n            self._epoch += 1\n            if hasattr(self._dataloader.sampler, \"set_epoch\") and self._use_distributed:\n                self._dataloader.sampler.set_epoch(self._epoch)\n            time.sleep(2)  # Prevent possible deadlock during epoch transition\n            self.iter_loader = iter(self._dataloader)\n            data = next(self.iter_loader)\n\n        return data\n\n    def __iter__(self):\n        return self\n\n    def __len__(self):\n        return len(self._dataloader)\n"
  },
  {
    "path": "lavis/datasets/datasets/dialogue_datasets.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport os\nfrom collections import OrderedDict\n\nfrom PIL import Image\n\nfrom lavis.datasets.datasets.base_dataset import BaseDataset\n\nimport json\nimport copy\n\n\nclass __DisplMixin:\n    def displ_item(self, index):\n        sample, ann = self.__getitem__(index), self.annotation[index]\n\n        return OrderedDict(\n            {\n                \"file\": ann[\"image\"],\n                \"dialogue\": ann[\"dialogue\"],\n                \"image\": sample[\"image\"],\n            }\n        )\n\n\nclass DialogueDataset(BaseDataset, __DisplMixin):\n    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):\n        \"\"\"\n        vis_root (string): Root directory of images (e.g. coco/images/)\n        ann_root (string): directory to store the annotation file\n        \"\"\"\n\n        self.vis_root = vis_root\n\n        self.annotation = []\n        for ann_path in ann_paths:\n            dialogs = json.load(open(ann_path, \"r\"))[\"dialogs\"]\n            for dialog in dialogs:\n                all_turns = dialog[\"dialog\"]\n                dialogue_context = []\n                for turn in all_turns:\n                    dialog_instance = copy.deepcopy(dialog)\n                    question = turn[\"question\"]\n                    answer = turn[\"answer\"]\n\n                    dialog_instance[\"dialog\"] = copy.deepcopy(dialogue_context)\n                    dialog_instance[\"question\"] = question\n                    dialog_instance[\"answer\"] = answer\n                    self.annotation.append(dialog_instance)\n                    dialogue_context.append(turn)\n\n        self.vis_processor = vis_processor\n        self.text_processor = text_processor\n\n        self._add_instance_ids()\n\n        self.img_ids = {}\n        n = 0\n        for ann in self.annotation:\n            img_id = ann[\"image_id\"]\n            if img_id not in self.img_ids.keys():\n                self.img_ids[img_id] = n\n                n += 1\n\n    def __getitem__(self, index):\n\n        ann = self.annotation[index]\n\n        image_path = os.path.join(self.vis_root, ann[\"image\"])\n        image = Image.open(image_path).convert(\"RGB\")\n\n        image = self.vis_processor(image)\n        caption = self.text_processor(ann[\"caption\"])\n\n        return {\n            \"image\": image,\n            \"text_input\": caption,\n            \"image_id\": self.img_ids[ann[\"image_id\"]],\n        }\n\n\nclass DialogueEvalDataset(BaseDataset, __DisplMixin):\n    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):\n        \"\"\"\n        vis_root (string): Root directory of images (e.g. 
coco/images/)\n        ann_root (string): directory to store the annotation file\n        split (string): val or test\n        \"\"\"\n\n        self.vis_root = vis_root\n\n        self.annotation = []\n        for ann_path in ann_paths:\n            dialogs = json.load(open(ann_path, \"r\"))[\"dialogs\"]\n            for dialog in dialogs:\n                all_turns = dialog[\"dialog\"]\n                dialogue_context = all_turns[:-1]\n                last_turn = all_turns[-1]\n\n                question = last_turn[\"question\"]\n                answer = last_turn[\"answer\"]\n\n                dialog[\"dialog\"] = dialogue_context\n                dialog[\"question\"] = question\n                dialog[\"answer\"] = answer\n\n                self.annotation.append(dialog)\n\n        self.vis_processor = vis_processor\n        self.text_processor = text_processor\n\n        self._add_instance_ids()\n\n        self.img_ids = {}\n        n = 0\n        for ann in self.annotation:\n            img_id = ann[\"image_id\"]\n            if img_id not in self.img_ids.keys():\n                self.img_ids[img_id] = n\n                n += 1\n\n    def __getitem__(self, index):\n\n        ann = self.annotation[index]\n\n        image_path = os.path.join(self.vis_root, ann[\"image\"])\n        image = Image.open(image_path).convert(\"RGB\")\n\n        image = self.vis_processor(image)\n\n        return {\n            \"image\": image,\n            \"image_id\": ann[\"image_id\"],\n            \"instance_id\": ann[\"instance_id\"],\n        }\n"
  },
  {
    "path": "lavis/datasets/datasets/gqa_datasets.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport os\nimport json\n\nfrom PIL import Image\n\nfrom lavis.datasets.datasets.vqa_datasets import VQADataset, VQAEvalDataset\n\nfrom collections import OrderedDict\n\n\nclass __DisplMixin:\n    def displ_item(self, index):\n        sample, ann = self.__getitem__(index), self.annotation[index]\n\n        return OrderedDict(\n            {\n                \"file\": ann[\"image\"],\n                \"question\": ann[\"question\"],\n                \"question_id\": ann[\"question_id\"],\n                \"answers\": \"; \".join(ann[\"answer\"]),\n                \"image\": sample[\"image\"],\n            }\n        )\n\n\nclass GQADataset(VQADataset, __DisplMixin):\n    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):\n        super().__init__(vis_processor, text_processor, vis_root, ann_paths)\n\n    def __getitem__(self, index):\n        ann = self.annotation[index]\n\n        image_path = os.path.join(self.vis_root, ann[\"image\"])\n        image = Image.open(image_path).convert(\"RGB\")\n\n        image = self.vis_processor(image)\n        question = self.text_processor(ann[\"question\"])\n\n        answers = [ann[\"answer\"]]\n        weights = [1]\n\n        return {\n            \"image\": image,\n            \"text_input\": question,\n            \"answers\": answers,\n            \"weights\": weights,\n        }\n\n\nclass GQAEvalDataset(VQAEvalDataset, __DisplMixin):\n    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):\n        \"\"\"\n        vis_root (string): Root directory of images (e.g. gqa/images/)\n        ann_root (string): directory to store the annotation file\n        \"\"\"\n\n        self.vis_root = vis_root\n\n        self.annotation = json.load(open(ann_paths[0]))\n\n        ## TODO: support inference method == 'ranking'\n        answer_list_path = ann_paths[1] if len(ann_paths) > 1 else ''\n        if os.path.exists(answer_list_path):\n            self.answer_list = json.load(open(answer_list_path))\n        else:\n            self.answer_list = None\n\n        self.vis_processor = vis_processor\n        self.text_processor = text_processor\n\n        self._add_instance_ids()\n\n    def __getitem__(self, index):\n        ann = self.annotation[index]\n\n        image_path = os.path.join(self.vis_root, ann[\"image\"])\n        image = Image.open(image_path).convert(\"RGB\")\n\n        image = self.vis_processor(image)\n        question = self.text_processor(ann[\"question\"])\n\n        if \"answer\" in ann:\n            # answer is a string\n            answer = ann[\"answer\"]\n        else:\n            answer = None\n\n        return {\n            \"image\": image,\n            \"text_input\": question,\n            \"answer\": answer,\n            \"question_id\": ann[\"question_id\"],\n            \"instance_id\": ann[\"instance_id\"],\n        }\n"
  },
  {
    "path": "lavis/datasets/datasets/image_text_pair_datasets.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport os\nfrom collections import OrderedDict\n\nfrom lavis.datasets.datasets.base_dataset import BaseDataset\nfrom PIL import Image\n\n\nclass __DisplMixin:\n    def displ_item(self, index):\n        sample, ann = self.__getitem__(index), self.annotation[index]\n\n        return OrderedDict(\n            {\n                \"file\": os.path.basename(ann[\"image\"]),\n                \"caption\": ann[\"caption\"],\n                \"image\": sample[\"image\"],\n            }\n        )\n\n\nclass ImageTextPairDataset(BaseDataset, __DisplMixin):\n    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):\n        \"\"\"\n        vis_root (string): Root directory of images (e.g. coco/images/)\n        ann_root (string): directory to store the annotation file\n        \"\"\"\n        super().__init__(vis_processor, text_processor, vis_root, ann_paths)\n\n    def __getitem__(self, index):\n\n        # TODO this assumes image input, not general enough\n        ann = self.annotation[index]\n\n        image_path = os.path.join(self.vis_root, ann[\"image\"])\n        image = Image.open(image_path).convert(\"RGB\")\n\n        image = self.vis_processor(image)\n        caption = self.text_processor(ann[\"caption\"])\n\n        return {\"image\": image, \"text_input\": caption}\n"
  },
  {
    "path": "lavis/datasets/datasets/imagefolder_dataset.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport os\nfrom collections import OrderedDict\n\nfrom lavis.datasets.datasets.base_dataset import BaseDataset\nfrom PIL import Image\nfrom torchvision import datasets\n\n\nclass ImageFolderDataset(BaseDataset):\n    def __init__(self, vis_processor, vis_root, classnames=[], **kwargs):\n        super().__init__(vis_processor=vis_processor, vis_root=vis_root)\n\n        self.inner_dataset = datasets.ImageFolder(vis_root)\n\n        self.annotation = [\n            {\"image\": elem[0], \"label\": elem[1], \"image_id\": elem[0]}\n            for elem in self.inner_dataset.imgs\n        ]\n\n        self.classnames = classnames\n\n        self._add_instance_ids()\n\n    def __len__(self):\n        return len(self.inner_dataset)\n\n    def __getitem__(self, index):\n        ann = self.annotation[index]\n\n        img_fn = ann[\"image\"]\n        image_path = os.path.join(self.vis_root, img_fn)\n        image = Image.open(image_path).convert(\"RGB\")\n\n        image = self.vis_processor(image)\n\n        return {\n            \"image\": image,\n            \"label\": ann[\"label\"],\n            \"image_id\": ann[\"image_id\"],\n            \"instance_id\": ann[\"instance_id\"],\n        }\n\n    def displ_item(self, index):\n        sample, ann = self.__getitem__(index), self.annotation[index]\n\n        return OrderedDict(\n            {\n                \"file\": ann[\"image\"],\n                \"label\": self.classnames[ann[\"label\"]],\n                \"image\": sample[\"image\"],\n            }\n        )\n"
  },
  {
    "path": "lavis/datasets/datasets/laion_dataset.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport webdataset as wds\nfrom lavis.datasets.datasets.base_dataset import BaseDataset\n\n\nclass LaionDataset(BaseDataset):\n    def __init__(self, vis_processor, text_processor, location):\n        super().__init__(vis_processor=vis_processor, text_processor=text_processor)\n\n        self.inner_dataset = wds.DataPipeline(\n            wds.ResampledShards(location),\n            wds.tarfile_to_samples(handler=wds.warn_and_continue),\n            wds.shuffle(1000, handler=wds.warn_and_continue),\n            wds.decode(\"pilrgb\", handler=wds.warn_and_continue),\n            wds.to_tuple(\"jpg\", \"json\", handler=wds.warn_and_continue),\n            wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue),\n            wds.map(self.to_dict, handler=wds.warn_and_continue),\n        )\n\n    def to_dict(self, sample):\n        return {\n            \"image\": sample[0],\n            \"text_input\": self.text_processor(sample[1][\"caption\"]),\n        }\n\n\nif __name__ == \"__main__\":\n    from torchvision import transforms\n\n    def to_image_text_pair(sample):\n        return sample[0], sample[1][\"caption\"]\n\n    normalize = transforms.Normalize(\n        (0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)\n    )\n\n    transform_train = transforms.Compose(\n        [\n            transforms.RandomResizedCrop(256, scale=(0.2, 1.0)),\n            transforms.RandomHorizontalFlip(),\n            transforms.ToTensor(),\n            normalize,\n        ]\n    )\n\n    dataset = LaionDataset(\n        vis_processor=transform_train,\n        text_processor=lambda x: x,\n        location=\"/export/laion/laion2B-multi/part-00000/{00000..01743}.tar\",\n    )\n\n    import torch\n\n    loader = torch.utils.data.DataLoader(dataset.inner_dataset, batch_size=2)\n\n    print(next(iter(loader))[\"text_input\"])\n"
  },
  {
    "path": "lavis/datasets/datasets/mc_video_vqa_datasets.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport json\nimport os\nimport torch\nfrom collections import OrderedDict\n\nfrom lavis.datasets.datasets.multimodal_classification_datasets import (\n    MultimodalClassificationDataset,\n)\nimport random\n\nclass __DisplMixin:\n    def displ_item(self, index):\n        ann = self.annotation[index]\n        vname = ann[\"video\"]\n        vpath = os.path.join(self.vis_root, vname)\n\n        return OrderedDict(\n            {\"file\": vpath, \"question\": ann[\"question\"], \"answer\": ann[\"answer\"]}\n        )\n\nANS_MAPPING = {0:'A',1:'B',2:'C',3:'D',4:'E'}\n# NextQA\nclass MCVideoQADataset(MultimodalClassificationDataset, __DisplMixin):\n    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):\n        super().__init__(vis_processor, text_processor, vis_root, ann_paths)\n\n    def _load_auxiliary_mappings(self):\n        pass\n    \n    def _get_answer_label(self, answer):\n        if answer in self.class_labels:\n            return self.class_labels[answer]\n        else:\n            return len(self.class_labels)\n\n    def __getitem__(self, index):\n        \n        result = None\n        while result is None:\n\n            ann = self.annotation[index]\n            qid = ann['qid'] \n\n            if 'QVHighlight' in qid:\n                q = ann['query']\n            else:\n                q = ann['question']\n            \n            # set video clip if 'start'&'end' timestamp in data\n            if 'start' in ann:\n                start, end = float(ann['start']), float(ann['end'])\n                clip = [start, end]\n            else:\n                clip = None       \n            \n            if 'VLEP' in qid:\n                qa_prompt = 'Upon observing the provided frames, what is the most probable subsequent event?'\n                events = 'Option A: ' + ann['a0'] + ' Option B: ' + ann['a1']\n                qa_prompt = qa_prompt + ' ' + events\n                loc_prompt = 'Does the information within the frame provide the necessary details to predict next event?'\n                loc_prompt = qa_prompt + ' ' + loc_prompt\n                answers = 'Option ' + ANS_MAPPING[int(ann['answer'])]\n                duration = 1\n\n            elif 'QVHighlight' in qid:\n                duration = ann['duration']\n                if 'relevant_windows' in ann: \n                    relevant_windows = ann['relevant_windows']\n                else:\n                    relevant_windows = None # for test\n                pseudo_options = 'Option A: yes. Option B: no.'\n                if q[-1] != '.':\n                    q += '.'      
\n                loc_prompt = 'Question: ' + q +  ' ' + pseudo_options + ' Does the information within the frame provide the necessary details to accurately answer the given question?'\n                qa_prompt = 'Considering the information presented in the frame, select the correct answer from the options.'\n\n            else:\n                prompt = 'Question: ' + q\n                for j in range(ann['num_option']):\n                    a = ann['a{}'.format(j)]\n                    prompt += ' Option {}: '.format(ANS_MAPPING[j])\n                    prompt += a\n                hints = 'Options: ('\n                #hints = 'Captions: ('\n                for j in range(ann['num_option']):\n                    ans = ann['a{}'.format(str(j))]\n                    hints += ans\n                    hints += ' '\n                hints += ')'\n                qa_prompt = prompt + ' Considering the information presented in the frame, select the correct answer from the options.'\n                loc_prompt = 'Question: ' + q +  ' ' + hints + ' Does the information within the frame provide the necessary details to accurately answer the given question?'\n                answers = 'Option ' + ANS_MAPPING[int(ann['answer'])]\n                duration = 1\n\n            try:\n                if 'VLEP' in qid:\n                    video_id = ann['video']\n                    if ':' in video_id:\n                        # we use absolute paths for VLEP since it draws on multiple video sources\n                        # you may change the paths below to your own paths\n                        video_path = '/nas-hdd/shoubin/vlep_ytb_clips_tars/videos/vlep_ytb_clips/'\n                    else:\n                        video_id = video_id[:-3]\n                        video_path = '/nas-hdd/shoubin/videos/tvqa/videos_3fps_with_audio/'\n                    vpath = os.path.join(video_path, video_id + '.mp4')\n                else:\n                    vpath = os.path.join(self.vis_root, str(ann['video']) + '.mp4')\n\n                frms, indices, fps = self.vis_processor(vpath, clip_proposal=clip)\n                frms = frms.permute(1, 0, 2, 3)\n                assert len(frms) == self.vis_processor.n_frms\n\n                if 'QVHighlight' in qid:\n                    time_stamp = [float(idx / fps) for idx in indices]\n                    answers = []\n                    if relevant_windows is not None:\n                        for t in time_stamp:\n                            flag = False\n                            for span in relevant_windows:\n                                if t >= float(span[0]) and t <= float(span[1]):\n                                    answers.append('yes')\n                                    flag = True\n                                    break\n                            if not flag:\n                                answers.append('no')\n                    else:\n                        for t in time_stamp:\n                            answers.append('no')  # for test\n\n                    answers = '_'.join(answers)\n\n                result = True\n            except Exception as e:\n                print(\"Error while reading video for sample {}: {}\".format(index, e))\n                print(\"video is: {}\".format(ann['video']))\n                index = random.randint(0, len(self.annotation) - 1)\n\n        return 
{\n            \"video\": frms,\n            \"qa_input\": qa_prompt,\n            \"loc_input\": loc_prompt,\n            \"qa_output\": answers,\n            \"question_id\": qid,\n            \"duration\": duration\n        }\n"
  },
  {
    "path": "lavis/datasets/datasets/multimodal_classification_datasets.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nfrom abc import abstractmethod\nfrom lavis.datasets.datasets.base_dataset import BaseDataset\n\n\nclass MultimodalClassificationDataset(BaseDataset):\n    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):\n        super().__init__(vis_processor, text_processor, vis_root, ann_paths)\n\n        self.class_labels = None\n\n    @abstractmethod\n    def _build_class_labels(self):\n        pass\n\n    @abstractmethod\n    def _load_auxiliary_mappings(self):\n        pass\n\n"
  },
  {
    "path": "lavis/datasets/datasets/nlvr_datasets.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport os\nimport random\n\nfrom collections import OrderedDict\n\nfrom lavis.datasets.datasets.multimodal_classification_datasets import (\n    MultimodalClassificationDataset,\n)\nfrom PIL import Image\n\n\nclass __DisplMixin:\n    def displ_item(self, index):\n        sample, ann = self.__getitem__(index), self.annotation[index]\n\n        return OrderedDict(\n            {\n                \"file_L\": ann[\"images\"][0],\n                \"file_R\": ann[\"images\"][1],\n                \"sentence\": ann[\"sentence\"],\n                \"label\": ann[\"label\"],\n                \"image\": [sample[\"image0\"], sample[\"image1\"]],\n            }\n        )\n\n\nclass NLVRDataset(MultimodalClassificationDataset, __DisplMixin):\n    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):\n        super().__init__(vis_processor, text_processor, vis_root, ann_paths)\n\n        self.class_labels = self._build_class_labels()\n\n    def _build_class_labels(self):\n        return {\"False\": 0, \"True\": 1}\n\n    @staticmethod\n    def _flip(samples):\n        sentence = samples[\"text_input\"]\n        image0, image1 = samples[\"image0\"], samples[\"image1\"]\n\n        if \"left\" not in sentence and \"right\" not in sentence:\n            if random.random() < 0.5:\n                image0, image1 = image1, image0\n        else:\n            if random.random() < 0.5:\n                sentence = sentence.replace(\"left\", \"[TEMP_TOKEN]\")\n                sentence = sentence.replace(\"right\", \"left\")\n                sentence = sentence.replace(\"[TEMP_TOKEN]\", \"right\")\n\n                image0, image1 = image1, image0\n\n        samples[\"text_input\"] = sentence\n        samples[\"image0\"] = image0\n        samples[\"image1\"] = image1\n\n        return samples\n\n    def __getitem__(self, index):\n        ann = self.annotation[index]\n\n        image0_path = os.path.join(self.vis_root, ann[\"images\"][0])\n        image0 = Image.open(image0_path).convert(\"RGB\")\n        image0 = self.vis_processor(image0)\n\n        image1_path = os.path.join(self.vis_root, ann[\"images\"][1])\n        image1 = Image.open(image1_path).convert(\"RGB\")\n        image1 = self.vis_processor(image1)\n\n        sentence = self.text_processor(ann[\"sentence\"])\n        label = self.class_labels[ann[\"label\"]]\n\n        return self._flip(\n            {\n                \"image0\": image0,\n                \"image1\": image1,\n                \"text_input\": sentence,\n                \"label\": label,\n                # \"image_id\": ann[\"image_id\"],\n                \"instance_id\": ann[\"instance_id\"],\n            }\n        )\n\n\nclass NLVREvalDataset(NLVRDataset):\n    @staticmethod\n    def _flip(samples):\n        return samples\n"
  },
  {
    "path": "lavis/datasets/datasets/retrieval_datasets.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport os\nfrom collections import OrderedDict\n\nfrom lavis.datasets.datasets.base_dataset import BaseDataset\nfrom PIL import Image\n\n\nclass __DisplMixin:\n    def displ_item(self, index):\n        sample, ann = self.__getitem__(index), self.annotation[index]\n        visual_key = \"image\" if \"image\" in ann else \"video\"\n\n        return OrderedDict(\n            {\n                \"file\": ann[visual_key],\n                \"caption\": ann[\"caption\"],\n                visual_key: sample[visual_key],\n            }\n        )\n\n\nclass RetrievalDataset(BaseDataset, __DisplMixin):\n    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):\n        \"\"\"\n        vis_root (string): Root directory of images (e.g. coco/images/)\n        ann_root (string): directory to store the annotation file\n        \"\"\"\n        super().__init__(vis_processor, text_processor, vis_root, ann_paths)\n\n        self.img_ids = {}\n        n = 0\n        for ann in self.annotation:\n            img_id = ann[\"image_id\"]\n            if img_id not in self.img_ids.keys():\n                self.img_ids[img_id] = n\n                n += 1\n\n    def __getitem__(self, index):\n\n        ann = self.annotation[index]\n\n        image_path = os.path.join(self.vis_root, ann[\"image\"])\n        image = Image.open(image_path).convert(\"RGB\")\n\n        image = self.vis_processor(image)\n        caption = self.text_processor(ann[\"caption\"])\n\n        return {\n            \"image\": image,\n            \"text_input\": caption,\n            \"image_id\": self.img_ids[ann[\"image_id\"]],\n            \"instance_id\": ann[\"instance_id\"],\n        }\n\n\nclass RetrievalEvalDataset(BaseDataset, __DisplMixin):\n    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):\n        \"\"\"\n        vis_root (string): Root directory of images (e.g. 
coco/images/)\n        ann_root (string): directory to store the annotation file\n        split (string): val or test\n        \"\"\"\n\n        super().__init__(vis_processor, text_processor, vis_root, ann_paths)\n\n        self.text = []\n        self.image = []\n        self.txt2img = {}\n        self.img2txt = {}\n\n        txt_id = 0\n        for img_id, ann in enumerate(self.annotation):\n            self.image.append(ann[\"image\"])\n            self.img2txt[img_id] = []\n            for i, caption in enumerate(ann[\"caption\"]):\n                self.text.append(self.text_processor(caption))\n                self.img2txt[img_id].append(txt_id)\n                self.txt2img[txt_id] = img_id\n                txt_id += 1\n\n    def __getitem__(self, index):\n\n        image_path = os.path.join(self.vis_root, self.annotation[index][\"image\"])\n        image = Image.open(image_path).convert(\"RGB\")\n\n        image = self.vis_processor(image)\n\n        return {\"image\": image, \"index\": index}\n\n\nclass VideoRetrievalDataset(BaseDataset, __DisplMixin):\n    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):\n        \"\"\"\n        vis_root (string): Root directory of videos.\n        ann_root (string): directory to store the annotation file\n        \"\"\"\n        super().__init__(vis_processor, text_processor, vis_root, ann_paths)\n\n        self.img_ids = {}\n        n = 0\n        for ann in self.annotation:\n            img_id = ann[\"video\"]\n            if img_id not in self.img_ids.keys():\n                self.img_ids[img_id] = n\n                n += 1\n\n    def __getitem__(self, index):\n\n        ann = self.annotation[index]\n\n        vpath = os.path.join(self.vis_root, ann[\"video\"])\n\n        video = self.vis_processor(vpath)\n        caption = self.text_processor(ann[\"caption\"])\n\n        # return image, caption, self.img_ids[ann['image_id']]\n        return {\n            \"video\": video,\n            \"text_input\": caption,\n            \"image_id\": self.img_ids[ann[\"video\"]],\n        }\n\n\nclass VideoRetrievalEvalDataset(BaseDataset, __DisplMixin):\n    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):\n        \"\"\"\n        vis_root (string): Root directory of videos.\n        ann_root (string): directory to store the annotation file\n        split (string): val or test\n        \"\"\"\n\n        super().__init__(vis_processor, text_processor, vis_root, ann_paths)\n\n        self.text = []\n        self.image = []\n        self.txt2img = {}\n        self.img2txt = {}\n\n        txt_id = 0\n        for img_id, ann in enumerate(self.annotation):\n            self.image.append(ann[\"video\"])\n            self.img2txt[img_id] = []\n            for i, caption in enumerate(ann[\"caption\"]):\n                self.text.append(self.text_processor(caption))\n                self.img2txt[img_id].append(txt_id)\n                self.txt2img[txt_id] = img_id\n                txt_id += 1\n\n    def __getitem__(self, index):\n        ann = self.annotation[index]\n\n        vpath = os.path.join(self.vis_root, ann[\"video\"])\n        video = self.vis_processor(vpath)\n\n        return {\"video\": video, \"index\": index}\n"
  },
  {
    "path": "lavis/datasets/datasets/snli_ve_datasets.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport os\nfrom collections import OrderedDict\n\nfrom lavis.datasets.datasets.multimodal_classification_datasets import (\n    MultimodalClassificationDataset,\n)\nfrom PIL import Image\n\n\nclass __DisplMixin:\n    def displ_item(self, index):\n        sample, ann = self.__getitem__(index), self.annotation[index]\n\n        return OrderedDict(\n            {\n                \"file\": os.path.basename(ann[\"image\"]),\n                \"sentence\": ann[\"sentence\"],\n                \"label\": ann[\"label\"],\n                \"image\": sample[\"image\"],\n            }\n        )\n\n\nclass SNLIVisualEntialmentDataset(MultimodalClassificationDataset, __DisplMixin):\n    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):\n        super().__init__(vis_processor, text_processor, vis_root, ann_paths)\n\n        self.class_labels = self._build_class_labels()\n\n    def _build_class_labels(self):\n        return {\"contradiction\": 0, \"neutral\": 1, \"entailment\": 2}\n\n    def __getitem__(self, index):\n        ann = self.annotation[index]\n\n        image_id = ann[\"image\"]\n        image_path = os.path.join(self.vis_root, \"%s.jpg\" % image_id)\n        image = Image.open(image_path).convert(\"RGB\")\n\n        image = self.vis_processor(image)\n        sentence = self.text_processor(ann[\"sentence\"])\n\n        return {\n            \"image\": image,\n            \"text_input\": sentence,\n            \"label\": self.class_labels[ann[\"label\"]],\n            \"image_id\": image_id,\n            \"instance_id\": ann[\"instance_id\"],\n        }\n"
  },
  {
    "path": "lavis/datasets/datasets/vg_vqa_datasets.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport os\n\nfrom PIL import Image\n\nfrom lavis.datasets.datasets.vqa_datasets import VQADataset\n\n\nclass VGVQADataset(VQADataset):\n    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):\n        super().__init__(vis_processor, text_processor, vis_root, ann_paths)\n\n    def __getitem__(self, index):\n        ann = self.annotation[index]\n\n        image_path = os.path.join(self.vis_root, ann[\"image\"])\n        image = Image.open(image_path).convert(\"RGB\")\n\n        image = self.vis_processor(image)\n        question = self.text_processor(ann[\"question\"])\n\n        answers = [ann[\"answer\"]]\n        # TODO this should be configured better\n        weights = [0.2]\n\n        return {\n            \"image\": image,\n            \"text_input\": question,\n            \"answers\": answers,\n            \"weights\": weights,\n        }\n"
  },
  {
    "path": "lavis/datasets/datasets/video_caption_datasets.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport os\nfrom lavis.datasets.datasets.base_dataset import BaseDataset\n\nfrom lavis.datasets.datasets.caption_datasets import CaptionDataset\n\n\nclass VideoCaptionDataset(CaptionDataset):\n    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):\n        \"\"\"\n        vis_root (string): Root directory of images (e.g. coco/images/)\n        ann_root (string): directory to store the annotation file\n        split (string): val or test\n        \"\"\"\n        super().__init__(vis_processor, text_processor, vis_root, ann_paths)\n\n    def __getitem__(self, index):\n\n        ann = self.annotation[index]\n\n        vname = ann[\"video\"]\n        video_path = os.path.join(self.vis_root, vname)\n\n        video = self.vis_processor(video_path)\n        caption = self.text_processor(ann[\"caption\"])\n\n        # \"image_id\" is kept to stay compatible with the COCO evaluation format\n        return {\n            \"video\": video,\n            \"text_input\": caption,\n            \"image_id\": self.img_ids[ann[\"image_id\"]],\n        }\n\n\nclass VideoCaptionEvalDataset(BaseDataset):\n    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):\n        \"\"\"\n        vis_root (string): Root directory of images (e.g. coco/images/)\n        ann_root (string): directory to store the annotation file\n        split (string): val or test\n        \"\"\"\n        super().__init__(vis_processor, text_processor, vis_root, ann_paths)\n\n    def __getitem__(self, index):\n\n        ann = self.annotation[index]\n\n        vname = ann[\"video\"]\n        video_path = os.path.join(self.vis_root, vname)\n\n        video = self.vis_processor(video_path)\n\n        return {\n            \"video\": video,\n            \"image_id\": ann[\"image_id\"],\n            \"instance_id\": ann[\"instance_id\"],\n        }\n"
  },
  {
    "path": "lavis/datasets/datasets/video_vqa_datasets.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport json\nimport os\nfrom collections import OrderedDict\n\nfrom lavis.datasets.datasets.multimodal_classification_datasets import (\n    MultimodalClassificationDataset,\n)\n\n\nclass __DisplMixin:\n    def displ_item(self, index):\n        ann = self.annotation[index]\n        vname = ann[\"video\"]\n        vpath = os.path.join(self.vis_root, vname)\n\n        return OrderedDict(\n            {\"file\": vpath, \"question\": ann[\"question\"], \"answer\": ann[\"answer\"]}\n        )\n\n\nclass VideoQADataset(MultimodalClassificationDataset, __DisplMixin):\n    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):\n        super().__init__(vis_processor, text_processor, vis_root, ann_paths)\n\n    def _build_class_labels(self, ans_path):\n        ans2label = json.load(open(ans_path))\n\n        self.class_labels = ans2label\n\n    def _get_answer_label(self, answer):\n        if answer in self.class_labels:\n            return self.class_labels[answer]\n        else:\n            return len(self.class_labels)\n\n    def __getitem__(self, index):\n        assert (\n            self.class_labels\n        ), f\"class_labels of {__class__.__name__} is not built yet.\"\n\n        ann = self.annotation[index]\n\n        vname = ann[\"video\"]\n        vpath = os.path.join(self.vis_root, vname)\n\n        frms = self.vis_processor(vpath)\n        question = self.text_processor(ann[\"question\"])\n\n        return {\n            \"video\": frms,\n            \"text_input\": question,\n            \"answers\": self._get_answer_label(ann[\"answer\"]),\n            \"question_id\": ann[\"question_id\"],\n            \"instance_id\": ann[\"instance_id\"],\n        }\n"
  },
  {
    "path": "lavis/datasets/datasets/vqa_datasets.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport torch\n\nfrom lavis.datasets.datasets.base_dataset import BaseDataset\n\n\nclass VQADataset(BaseDataset):\n    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):\n        super().__init__(vis_processor, text_processor, vis_root, ann_paths)\n\n    def collater(self, samples):\n        image_list, question_list, answer_list, weight_list = [], [], [], []\n\n        num_answers = []\n\n        for sample in samples:\n            image_list.append(sample[\"image\"])\n            question_list.append(sample[\"text_input\"])\n\n            weight_list.extend(sample[\"weights\"])\n\n            answers = sample[\"answers\"]\n\n            answer_list.extend(answers)\n            num_answers.append(len(answers))\n\n        return {\n            \"image\": torch.stack(image_list, dim=0),\n            \"text_input\": question_list,\n            \"answer\": answer_list,\n            \"weight\": torch.Tensor(weight_list),\n            \"n_answers\": torch.LongTensor(num_answers),\n        }\n\n\nclass VQAEvalDataset(BaseDataset):\n    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):\n        super().__init__(vis_processor, text_processor, vis_root, ann_paths)\n"
  },
  {
    "path": "lavis/datasets/download_scripts/DownloadConceptualCaptions/LICENSE",
    "content": "// Copyright 2022 Dongxu Li, Junnan Li, Hung Le, Guangsen Wang, Silvio Savarese, Steven Hoi. All rights reserved.\n// Use of this source code is governed by a BSD-style\n// license that can be found in the LICENSE file.\n\nMIT License\n\nCopyright (c) 2019 Igor Brigadir\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "lavis/datasets/download_scripts/DownloadConceptualCaptions/README.md",
    "content": "<!--\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n-->\n\n# Download Conceptual Captions Data\n\nPlace data from: https://ai.google.com/research/ConceptualCaptions/download in this folder\n\n`Train_GCC-training.tsv / cc3m.tsv` Training Split (3,318,333)\n\nrun `download_data_cc3m.py` or `download_data_cc12m.py`.\n\nImages will be in default LAVIS cache folders. You can stop and resume, the settings for splitting downloads into chunks / threads are not optimal, but it maxed out my connection so i kept them as is.\n\nNote: A previous version of this script used a different file naming scheme, this changed and if you are resuming a previously started download, you will get duplicates.\n\nA bunch of them will fail to download, and return web pages instead. These will need to be cleaned up later. See `downloaded_validation_report.tsv` after it downloads for HTTP errors. Around 8% of images are gone, based on validation set results. Setting the user agent could fix some errors too maybe - not sure if any requests are rejected by sites based on this.\n\nIt should take about a day or two to download the training data, keep an eye on disk space.\n"
  },
  {
    "path": "lavis/datasets/download_scripts/DownloadConceptualCaptions/create_annotation_12m.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 15,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import os\\n\",\n    \"import json\\n\",\n    \"\\n\",\n    \"import pandas as pd\\n\",\n    \"from tqdm import tqdm\\n\",\n    \"from lavis.common.utils import get_abs_path, get_cache_path\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 2,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"cc12m = pd.read_csv(\\\"downloaded_cc12m_report.tsv.gz\\\", compression=\\\"gzip\\\", sep=\\\"\\\\t\\\", names=[\\\"caption\\\", \\\"path\\\", \\\"dataset\\\", \\\"mimetype\\\", \\\"size\\\", \\\"status\\\", \\\"url\\\"])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 7,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"caption                            a very typical bus station\\n\",\n       \"path        /export/home/.cache/lavis/conceptual_caption/i...\\n\",\n       \"dataset                                                  cc3m\\n\",\n       \"mimetype                                           image/jpeg\\n\",\n       \"size                                                    36078\\n\",\n       \"status                                                    200\\n\",\n       \"url         http://lh6.ggpht.com/-IvRtNLNcG8o/TpFyrudaT6I/...\\n\",\n       \"Name: 0, dtype: object\"\n      ]\n     },\n     \"execution_count\": 7,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"cc12m.iloc[0]\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"3318333\"\n      ]\n     },\n     \"execution_count\": 3,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"len(cc12m)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 21,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"100%|██████████| 3130587/3130587 [17:28<00:00, 2986.08it/s]\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Found 2759017 valid records\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"cnt = 0\\n\",\n    \"\\n\",\n    \"valid_records = []\\n\",\n    \"\\n\",\n    \"for i, path in tqdm(enumerate(cc12m.path.unique()), total=len(cc12m.path.unique())):\\n\",\n    \"    path = str(path)\\n\",\n    \"    if os.path.exists(path):\\n\",\n    \"        record = cc12m.iloc[i]\\n\",\n    \"        valid_records.append({\\\"image\\\": record[\\\"path\\\"], \\\"caption\\\": record[\\\"caption\\\"]})\\n\",\n    \"\\n\",\n    \"        cnt += 1\\n\",\n    \"\\n\",\n    \"print(\\\"Found {} valid records\\\".format(cnt))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 22,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"2759017\"\n      ]\n     },\n     \"execution_count\": 22,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"len(valid_records)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 24,\n   
\"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'image': '/export/home/.cache/lavis/conceptual_caption/images/1_3239086386.jpg',\\n\",\n       \" 'caption': 'sierra looked stunning in this top and this skirt while performing with person at their former university'}\"\n      ]\n     },\n     \"execution_count\": 24,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"valid_records[1]\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 28,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"/export/home/.cache/lavis/conceptual_caption/annotations/cc3m.json already exists\\n\"\n     ]\n    },\n    {\n     \"ename\": \"\",\n     \"evalue\": \"\",\n     \"output_type\": \"error\",\n     \"traceback\": [\n      \"\\u001b[1;31mThe Kernel crashed while executing code in the the current cell or a previous cell. Please review the code in the cell(s) to identify a possible cause of the failure. Click <a href='https://aka.ms/vscodeJupyterKernelCrash'>here</a> for more info. View Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details.\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"from omegaconf import OmegaConf\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"config_path = get_abs_path(\\\"configs/datasets/conceptual_caption/defaults_12m.yaml\\\")\\n\",\n    \"\\n\",\n    \"ann_path = OmegaConf.load(\\n\",\n    \"    config_path\\n\",\n    \").datasets.conceptual_caption_12m.build_info.annotations.train.storage[0]\\n\",\n    \"\\n\",\n    \"ann_path = get_cache_path(ann_path)\\n\",\n    \"\\n\",\n    \"if os.path.exists(ann_path):\\n\",\n    \"    # abort\\n\",\n    \"    print(\\\"{} already exists\\\".format(ann_path))\\n\",\n    \"else:\\n\",\n    \"    # Save the valid records to a json file\\n\",\n    \"    with open(ann_path, \\\"w\\\") as f:\\n\",\n    \"        f.write(json.dumps(valid_records))\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3.8.10 ('base')\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.8.10\"\n  },\n  \"orig_nbformat\": 4,\n  \"vscode\": {\n   \"interpreter\": {\n    \"hash\": \"d4d1e4263499bec80672ea0156c357c1ee493ec2b1c70f0acce89fc37c4a6abe\"\n   }\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}\n"
  },
  {
    "path": "lavis/datasets/download_scripts/DownloadConceptualCaptions/create_annotation_3m.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 15,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import os\\n\",\n    \"import json\\n\",\n    \"\\n\",\n    \"import pandas as pd\\n\",\n    \"from tqdm import tqdm\\n\",\n    \"from lavis.common.utils import get_abs_path, get_cache_path\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 2,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"cc3m = pd.read_csv(\\\"downloaded_cc3m_report.tsv.gz\\\", compression=\\\"gzip\\\", sep=\\\"\\\\t\\\", names=[\\\"caption\\\", \\\"path\\\", \\\"dataset\\\", \\\"mimetype\\\", \\\"size\\\", \\\"status\\\", \\\"url\\\"])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 7,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"caption                            a very typical bus station\\n\",\n       \"path        /export/home/.cache/lavis/conceptual_caption/i...\\n\",\n       \"dataset                                                  cc3m\\n\",\n       \"mimetype                                           image/jpeg\\n\",\n       \"size                                                    36078\\n\",\n       \"status                                                    200\\n\",\n       \"url         http://lh6.ggpht.com/-IvRtNLNcG8o/TpFyrudaT6I/...\\n\",\n       \"Name: 0, dtype: object\"\n      ]\n     },\n     \"execution_count\": 7,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"cc3m.iloc[0]\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"3318333\"\n      ]\n     },\n     \"execution_count\": 3,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"len(cc3m)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 21,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"100%|██████████| 3130587/3130587 [17:28<00:00, 2986.08it/s]\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Found 2759017 valid records\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"cnt = 0\\n\",\n    \"\\n\",\n    \"valid_records = []\\n\",\n    \"\\n\",\n    \"for i, path in tqdm(enumerate(cc3m.path.unique()), total=len(cc3m.path.unique())):\\n\",\n    \"    path = str(path)\\n\",\n    \"    if os.path.exists(path):\\n\",\n    \"        record = cc3m.iloc[i]\\n\",\n    \"        valid_records.append({\\\"image\\\": record[\\\"path\\\"], \\\"caption\\\": record[\\\"caption\\\"]})\\n\",\n    \"\\n\",\n    \"        cnt += 1\\n\",\n    \"\\n\",\n    \"print(\\\"Found {} valid records\\\".format(cnt))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 22,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"2759017\"\n      ]\n     },\n     \"execution_count\": 22,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"len(valid_records)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 24,\n   \"metadata\": 
{},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'image': '/export/home/.cache/lavis/conceptual_caption/images/1_3239086386.jpg',\\n\",\n       \" 'caption': 'sierra looked stunning in this top and this skirt while performing with person at their former university'}\"\n      ]\n     },\n     \"execution_count\": 24,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"valid_records[1]\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 28,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"/export/home/.cache/lavis/conceptual_caption/annotations/cc3m.json already exists\\n\"\n     ]\n    },\n    {\n     \"ename\": \"\",\n     \"evalue\": \"\",\n     \"output_type\": \"error\",\n     \"traceback\": [\n      \"\\u001b[1;31mThe Kernel crashed while executing code in the the current cell or a previous cell. Please review the code in the cell(s) to identify a possible cause of the failure. Click <a href='https://aka.ms/vscodeJupyterKernelCrash'>here</a> for more info. View Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details.\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"from omegaconf import OmegaConf\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"config_path = get_abs_path(\\\"configs/datasets/conceptual_caption/defaults_3m.yaml\\\")\\n\",\n    \"\\n\",\n    \"ann_path = OmegaConf.load(\\n\",\n    \"    config_path\\n\",\n    \").datasets.conceptual_caption_3m.build_info.annotations.train.storage[0]\\n\",\n    \"\\n\",\n    \"ann_path = get_cache_path(ann_path)\\n\",\n    \"\\n\",\n    \"if os.path.exists(ann_path):\\n\",\n    \"    # abort\\n\",\n    \"    print(\\\"{} already exists\\\".format(ann_path))\\n\",\n    \"else:\\n\",\n    \"    # Save the valid records to a json file\\n\",\n    \"    with open(ann_path, \\\"w\\\") as f:\\n\",\n    \"        f.write(json.dumps(valid_records))\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3.8.10 ('base')\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.8.10\"\n  },\n  \"orig_nbformat\": 4,\n  \"vscode\": {\n   \"interpreter\": {\n    \"hash\": \"d4d1e4263499bec80672ea0156c357c1ee493ec2b1c70f0acce89fc37c4a6abe\"\n   }\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}\n"
  },
  {
    "path": "lavis/datasets/download_scripts/DownloadConceptualCaptions/download_data_cc12m.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport time\nfrom PIL import Image\nfrom lavis.common.utils import get_abs_path, get_cache_path\nfrom multiprocessing import Pool\nfrom omegaconf import OmegaConf\nfrom pathlib import Path\nfrom torchvision.transforms import functional as TF\nfrom tqdm import tqdm\nimport glob\nimport io\nimport json\nimport magic  # pip install python-magic\nimport numpy as np\nimport os\nimport pandas as pd\nimport requests\nimport shelve\nimport zlib\n\nheaders = {\n    #'User-Agent': 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/71.0.3578.98 Safari/537.36',\n    \"User-Agent\": \"Googlebot-Image/1.0\",  # Pretend to be googlebot\n    \"X-Forwarded-For\": \"64.18.15.200\",\n}\n\n\ndef _df_split_apply(tup_arg):\n    split_ind, subset, func = tup_arg\n    r = subset.apply(func, axis=1)\n    return (split_ind, r)\n\n\ndef df_multiprocess(df, processes, chunk_size, func, dataset_name):\n    print(\"Generating parts...\")\n    with shelve.open(\n        \"%s_%s_%s_results.tmp\" % (dataset_name, func.__name__, chunk_size)\n    ) as results:\n\n        pbar = tqdm(total=len(df), position=0)\n        # Resume:\n        finished_chunks = set([int(k) for k in results.keys()])\n        pbar.desc = \"Resuming\"\n        for k in results.keys():\n            pbar.update(len(results[str(k)][1]))\n\n        pool_data = (\n            (index, df[i : i + chunk_size], func)\n            for index, i in enumerate(range(0, len(df), chunk_size))\n            if index not in finished_chunks\n        )\n        print(\n            int(len(df) / chunk_size),\n            \"parts.\",\n            chunk_size,\n            \"per part.\",\n            \"Using\",\n            processes,\n            \"processes\",\n        )\n\n        pbar.desc = \"Downloading\"\n        with Pool(processes) as pool:\n            for i, result in enumerate(\n                pool.imap_unordered(_df_split_apply, pool_data, 2)\n            ):\n                results[str(result[0])] = result\n                pbar.update(len(result[1]))\n        pbar.close()\n\n    print(\"Finished Downloading.\")\n    return\n\n\n# Unique name based on url\ndef _file_name(row):\n    name = (\n        \"%s/%s_%s\"\n        % (\n            # row[\"folder\"],\n            storage_dir,\n            row.name,\n            (zlib.crc32(row[\"url\"].encode(\"utf-8\")) & 0xFFFFFFFF),\n        )\n        + \".jpg\"\n    )\n    return name\n\n\n# For checking mimetypes separately without download\ndef check_mimetype(row):\n    if os.path.isfile(str(row[\"file\"])):\n        row[\"mimetype\"] = magic.from_file(row[\"file\"], mime=True)\n        row[\"size\"] = os.stat(row[\"file\"]).st_size\n    return row\n\n\n# Don't download image, just check with a HEAD request, can't resume.\n# Can use this instead of download_image to get HTTP status codes.\ndef check_download(row):\n    fname = _file_name(row)\n    try:\n        # not all sites will support HEAD\n        response = requests.head(\n            row[\"url\"], stream=False, timeout=5, allow_redirects=True, headers=headers\n        )\n        row[\"status\"] = response.status_code\n        row[\"headers\"] = dict(response.headers)\n    except:\n        # log errors later, set error as 408 timeout\n        row[\"status\"] = 408\n        return 
row\n    if response.ok:\n        row[\"file\"] = fname\n    return row\n\n\ndef resize_img(req):\n    image = Image.open(req).convert(\"RGB\")\n    image = TF.resize(\n        # image, size=(resize_size, resize_size)\n        image,\n        size=resize_size,\n    )  # , interpolation=Image.LANCZOS)\n    return image\n\n\ndef download_image(row):\n    fname = _file_name(row)\n    # Skip Already downloaded, retry others later\n    if os.path.isfile(fname):\n        row[\"status\"] = 200\n        row[\"file\"] = fname\n        row[\"mimetype\"] = magic.from_file(row[\"file\"], mime=True)\n        row[\"size\"] = os.stat(row[\"file\"]).st_size\n        return row\n\n    try:\n        # use smaller timeout to skip errors, but can result in failed downloads\n        response = requests.get(\n            row[\"url\"], stream=False, timeout=5, allow_redirects=True, headers=headers\n        )\n        row[\"status\"] = response.status_code\n        # row['headers'] = dict(response.headers)\n    except Exception as e:\n        # log errors later, set error as 408 timeout\n        row[\"status\"] = 408\n        return row\n\n    if response.ok:\n        try:\n            # some sites respond with gzip transport encoding\n            response.raw.decode_content = True\n            img = resize_img(io.BytesIO(response.content))\n            img.save(fname)\n\n            row[\"mimetype\"] = magic.from_file(fname, mime=True)\n            row[\"size\"] = os.stat(fname).st_size\n\n        except Exception as e:\n            #     # This is if it times out during a download or decode\n            row[\"status\"] = 408\n\n    row[\"file\"] = fname\n    return row\n\n\ndef open_tsv(fname, folder):\n    print(\"Opening %s Data File...\" % fname)\n    df = pd.read_csv(\n        fname, sep=\"\\t\", names=[\"url\", \"caption\"]\n    )  # , usecols=range(1, 2))\n    df[\"folder\"] = folder\n    print(\"Processing\", len(df), \" Images:\")\n    return df\n\n\ndef df_from_shelve(chunk_size, func, dataset_name):\n    print(\"Generating Dataframe from results...\")\n    with shelve.open(\n        \"%s_%s_%s_results.tmp\" % (dataset_name, func.__name__, chunk_size)\n    ) as results:\n        keylist = sorted([int(k) for k in results.keys()])\n        df = pd.concat([results[str(k)][1] for k in keylist], sort=True)\n    return df\n\n\nresize_size = 384\n\nconfig_path = get_abs_path(\"configs/datasets/conceptual_caption/defaults_12m.yaml\")\n\nstorage_dir = OmegaConf.load(\n    config_path\n).datasets.conceptual_caption_12m.build_info.images.storage\nstorage_dir = Path(get_cache_path(storage_dir))\n\nos.makedirs(storage_dir, exist_ok=True)\n\n# number of processes in the pool can be larger than cores\nnum_processes = 96\n# num_processes = 1\n# chunk_size is how many images per chunk per process - changing this resets progress when restarting.\nimages_per_part = 100\n\ndata_name = \"cc12m\"\n# os.makedirs(data_name, exist_ok=True)\n\ndf = open_tsv(\"cc12m.tsv\", data_name)\ndf_multiprocess(\n    df=df,\n    processes=num_processes,\n    chunk_size=images_per_part,\n    func=download_image,\n    dataset_name=data_name,\n)\ndf = df_from_shelve(\n    chunk_size=images_per_part, func=download_image, dataset_name=data_name\n)\ndf.to_csv(\n    \"downloaded_%s_report.tsv.gz\" % data_name,\n    compression=\"gzip\",\n    sep=\"\\t\",\n    header=False,\n    index=False,\n)\nprint(\"Saved.\")\n"
  },
  {
    "path": "lavis/datasets/download_scripts/DownloadConceptualCaptions/download_data_cc3m.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport glob\nfrom pathlib import Path\nimport time\nfrom omegaconf import OmegaConf\nimport pandas as pd\nimport numpy as np\nimport requests\nimport zlib\nimport os\nimport io\nimport shelve\nfrom lavis.common.utils import get_abs_path, get_cache_path\nimport magic  # pip install python-magic\nimport json\nfrom multiprocessing import Pool\nfrom tqdm import tqdm\nfrom PIL import Image\nfrom torchvision.transforms import functional as TF\n\nheaders = {\n    #'User-Agent': 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/71.0.3578.98 Safari/537.36',\n    \"User-Agent\": \"Googlebot-Image/1.0\",  # Pretend to be googlebot\n    \"X-Forwarded-For\": \"64.18.15.200\",\n}\n\n\ndef _df_split_apply(tup_arg):\n    split_ind, subset, func = tup_arg\n    r = subset.apply(func, axis=1)\n    return (split_ind, r)\n\n\ndef df_multiprocess(df, processes, chunk_size, func, dataset_name):\n    print(\"Generating parts...\")\n    with shelve.open(\n        \"%s_%s_%s_results.tmp\" % (dataset_name, func.__name__, chunk_size)\n    ) as results:\n\n        pbar = tqdm(total=len(df), position=0)\n        # Resume:\n        finished_chunks = set([int(k) for k in results.keys()])\n        pbar.desc = \"Resuming\"\n        for k in results.keys():\n            pbar.update(len(results[str(k)][1]))\n\n        pool_data = (\n            (index, df[i : i + chunk_size], func)\n            for index, i in enumerate(range(0, len(df), chunk_size))\n            if index not in finished_chunks\n        )\n        print(\n            int(len(df) / chunk_size),\n            \"parts.\",\n            chunk_size,\n            \"per part.\",\n            \"Using\",\n            processes,\n            \"processes\",\n        )\n\n        pbar.desc = \"Downloading\"\n        with Pool(processes) as pool:\n            for i, result in enumerate(\n                pool.imap_unordered(_df_split_apply, pool_data, 2)\n            ):\n                results[str(result[0])] = result\n                pbar.update(len(result[1]))\n        pbar.close()\n\n    print(\"Finished Downloading.\")\n    return\n\n\n# Unique name based on url\ndef _file_name(row):\n    name = (\n        \"%s/%s_%s\"\n        % (\n            # row[\"folder\"],\n            storage_dir,\n            row.name,\n            (zlib.crc32(row[\"url\"].encode(\"utf-8\")) & 0xFFFFFFFF),\n        )\n        + \".jpg\"\n    )\n    return name\n\n\n# For checking mimetypes separately without download\ndef check_mimetype(row):\n    if os.path.isfile(str(row[\"file\"])):\n        row[\"mimetype\"] = magic.from_file(row[\"file\"], mime=True)\n        row[\"size\"] = os.stat(row[\"file\"]).st_size\n    return row\n\n\n# Don't download image, just check with a HEAD request, can't resume.\n# Can use this instead of download_image to get HTTP status codes.\ndef check_download(row):\n    fname = _file_name(row)\n    try:\n        # not all sites will support HEAD\n        response = requests.head(\n            row[\"url\"], stream=False, timeout=5, allow_redirects=True, headers=headers\n        )\n        row[\"status\"] = response.status_code\n        row[\"headers\"] = dict(response.headers)\n    except:\n        # log errors later, set error as 408 timeout\n        row[\"status\"] = 408\n        return 
row\n    if response.ok:\n        row[\"file\"] = fname\n    return row\n\n\ndef resize_img(req):\n    image = Image.open(req).convert(\"RGB\")\n    image = TF.resize(\n        # image, size=(resize_size, resize_size)\n        image,\n        size=resize_size,\n    )  # , interpolation=Image.LANCZOS)\n    return image\n\n\ndef download_image(row):\n    fname = _file_name(row)\n    # Skip Already downloaded, retry others later\n    if os.path.isfile(fname):\n        row[\"status\"] = 200\n        row[\"file\"] = fname\n        row[\"mimetype\"] = magic.from_file(row[\"file\"], mime=True)\n        row[\"size\"] = os.stat(row[\"file\"]).st_size\n        return row\n\n    try:\n        # use smaller timeout to skip errors, but can result in failed downloads\n        response = requests.get(\n            row[\"url\"], stream=False, timeout=5, allow_redirects=True, headers=headers\n        )\n        row[\"status\"] = response.status_code\n        # row['headers'] = dict(response.headers)\n    except Exception as e:\n        # log errors later, set error as 408 timeout\n        row[\"status\"] = 408\n        return row\n\n    if response.ok:\n        try:\n            # some sites respond with gzip transport encoding\n            response.raw.decode_content = True\n            img = resize_img(io.BytesIO(response.content))\n            img.save(fname)\n\n            row[\"mimetype\"] = magic.from_file(fname, mime=True)\n            row[\"size\"] = os.stat(fname).st_size\n\n        except Exception as e:\n            #     # This is if it times out during a download or decode\n            row[\"status\"] = 408\n\n    row[\"file\"] = fname\n    return row\n\n\ndef open_tsv(fname, folder):\n    print(\"Opening %s Data File...\" % fname)\n    df = pd.read_csv(\n        fname, sep=\"\\t\", names=[\"caption\", \"url\"]\n    )  # , usecols=range(1, 2))\n    df[\"folder\"] = folder\n    print(\"Processing\", len(df), \" Images:\")\n    return df\n\n\ndef df_from_shelve(chunk_size, func, dataset_name):\n    print(\"Generating Dataframe from results...\")\n    with shelve.open(\n        \"%s_%s_%s_results.tmp\" % (dataset_name, func.__name__, chunk_size)\n    ) as results:\n        keylist = sorted([int(k) for k in results.keys()])\n        df = pd.concat([results[str(k)][1] for k in keylist], sort=True)\n    return df\n\n\nresize_size = 384\n\nconfig_path = get_abs_path(\"configs/datasets/conceptual_caption/defaults_3m.yaml\")\n\nstorage_dir = OmegaConf.load(\n    config_path\n).datasets.conceptual_caption_3m.build_info.images.storage\nstorage_dir = Path(get_cache_path(storage_dir))\n\nos.makedirs(storage_dir, exist_ok=True)\n\n# number of processes in the pool can be larger than cores\nnum_processes = 32\n# chunk_size is how many images per chunk per process - changing this resets progress when restarting.\nimages_per_part = 100\n\ndata_name = \"cc3m\"\ndf = open_tsv(\"Train_GCC-training.tsv\", data_name)\ndf_multiprocess(\n    df=df,\n    processes=num_processes,\n    chunk_size=images_per_part,\n    func=download_image,\n    dataset_name=data_name,\n)\ndf = df_from_shelve(\n    chunk_size=images_per_part, func=download_image, dataset_name=data_name\n)\ndf.to_csv(\n    \"downloaded_%s_report.tsv.gz\" % data_name,\n    compression=\"gzip\",\n    sep=\"\\t\",\n    header=False,\n    index=False,\n)\nprint(\"Saved.\")\n"
  },
  {
    "path": "lavis/datasets/download_scripts/download_coco.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport os\nfrom pathlib import Path\n\nfrom omegaconf import OmegaConf\n\nfrom lavis.common.utils import (\n    cleanup_dir,\n    download_and_extract_archive,\n    get_abs_path,\n    get_cache_path,\n)\n\n\nDATA_URL = {\n    \"train\": \"http://images.cocodataset.org/zips/train2014.zip\",  # md5: 0da8c0bd3d6becc4dcb32757491aca88\n    \"val\": \"http://images.cocodataset.org/zips/val2014.zip\",  # md5: a3d79f5ed8d289b7a7554ce06a5782b3\n    \"test\": \"http://images.cocodataset.org/zips/test2014.zip\",  # md5: 04127eef689ceac55e3a572c2c92f264\n    \"test2015\": \"http://images.cocodataset.org/zips/test2015.zip\",  # md5: 04127eef689ceac55e3a572c2c92f264\n}\n\n\ndef download_datasets(root, url):\n    download_and_extract_archive(url=url, download_root=root, extract_root=storage_dir)\n\n\nif __name__ == \"__main__\":\n\n    config_path = get_abs_path(\"configs/datasets/coco/defaults_cap.yaml\")\n\n    storage_dir = OmegaConf.load(\n        config_path\n    ).datasets.coco_caption.build_info.images.storage\n\n    download_dir = Path(get_cache_path(storage_dir)).parent / \"download\"\n    storage_dir = Path(get_cache_path(storage_dir))\n\n    if storage_dir.exists():\n        print(f\"Dataset already exists at {storage_dir}. Aborting.\")\n        exit(0)\n\n    try:\n        for k, v in DATA_URL.items():\n            print(\"Downloading {} to {}\".format(v, k))\n            download_datasets(download_dir, v)\n    except Exception as e:\n        # remove download dir if failed\n        cleanup_dir(download_dir)\n        print(\"Failed to download or extracting datasets. Aborting.\")\n\n    cleanup_dir(download_dir)\n"
  },
  {
    "path": "lavis/datasets/download_scripts/download_didemo.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport os\nfrom pathlib import Path\n\nfrom omegaconf import OmegaConf\n\nfrom lavis.common.utils import (\n    cleanup_dir,\n    download_and_extract_archive,\n    get_abs_path,\n    get_cache_path,\n)\n\nDATA_URL = \"https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/didemo/didemo_videos.tar.gz\"\n\n\ndef download_datasets(root, url):\n    \"\"\"\n    Download the Imagenet-R dataset archives and expand them\n    in the folder provided as parameter\n    \"\"\"\n    download_and_extract_archive(url=url, download_root=root)\n\n\ndef move_files(download_path, storage_path):\n    \"\"\"\n    Move files from download_path to storage_path\n    \"\"\"\n    print(\"Moving to {}\".format(storage_path))\n\n    os.makedirs(storage_path, exist_ok=True)\n\n    for file_name in os.listdir(download_path):\n        os.rename(\n            os.path.join(download_path, file_name),\n            os.path.join(storage_path, file_name),\n        )\n\n\nif __name__ == \"__main__\":\n\n    config_path = get_abs_path(\"configs/datasets/didemo/defaults_ret.yaml\")\n\n    storage_dir = OmegaConf.load(\n        config_path\n    ).datasets.didemo_retrieval.build_info.videos.storage\n\n    download_dir = Path(get_cache_path(storage_dir)).parent / \"download\"\n    storage_dir = Path(get_cache_path(storage_dir))\n\n    if storage_dir.exists():\n        print(f\"Dataset already exists at {storage_dir}. Aborting.\")\n        exit(0)\n\n    try:\n        print(\"Downloading {} to {}\".format(DATA_URL, download_dir))\n        download_datasets(download_dir, DATA_URL)\n    except Exception as e:\n        # remove download dir if failed\n        cleanup_dir(download_dir)\n        print(\"Failed to download or extracting datasets. Aborting.\")\n\n    move_files(download_dir / \"videos\", storage_dir)\n    cleanup_dir(download_dir)\n"
  },
  {
    "path": "lavis/datasets/download_scripts/download_flickr.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport os\nfrom pathlib import Path\n\nfrom omegaconf import OmegaConf\n\nfrom lavis.common.utils import (\n    cleanup_dir,\n    get_abs_path,\n    get_cache_path,\n)\n\nimport opendatasets as od\n\n\nDATA_URL = \"https://www.kaggle.com/datasets/hsankesara/flickr-image-dataset\"\n\nprint(\n    \"\"\"\n    To download the dataset, you need to have a Kaggle account and the associated key.\n    See https://www.kaggle.com/docs/api to create account and a new API token.\n    \"\"\"\n)\n\n\ndef move_directory(src_dir, dst_dir):\n    \"\"\"\n    Move files from download_path to storage_path\n    \"\"\"\n    print(\"Moving to {}\".format(dst_dir))\n\n    os.makedirs(dst_dir, exist_ok=True)\n\n    for file_name in os.listdir(src_dir):\n        os.rename(\n            os.path.join(src_dir, file_name),\n            os.path.join(dst_dir, file_name),\n        )\n\n\nif __name__ == \"__main__\":\n\n    config_path = get_abs_path(\"configs/datasets/flickr30k/defaults.yaml\")\n\n    storage_dir = OmegaConf.load(\n        config_path\n    ).datasets.flickr30k.build_info.images.storage\n\n    storage_dir = Path(get_cache_path(storage_dir))\n    download_dir = storage_dir.parent / \"download\"\n\n    if storage_dir.exists():\n        print(f\"Dataset already exists at {storage_dir}. Aborting.\")\n        exit(0)\n\n    os.makedirs(download_dir)\n\n    try:\n        print(\"Downloading {} to {}\".format(DATA_URL, download_dir))\n        od.download(DATA_URL, download_dir)\n    except Exception as e:\n        print(e)\n        # remove download dir if failed\n        cleanup_dir(download_dir)\n        exit(1)\n\n    move_directory(\n        download_dir / \"flickr-image-dataset\" / \"flickr30k_images\" / \"flickr30k_images\",\n        storage_dir / \"flickr30k-images\",\n    )\n\n    cleanup_dir(download_dir)\n"
  },
  {
    "path": "lavis/datasets/download_scripts/download_gqa.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport os\nfrom pathlib import Path\n\nfrom omegaconf import OmegaConf\n\nfrom lavis.common.utils import (\n    cleanup_dir,\n    download_and_extract_archive,\n    get_abs_path,\n    get_cache_path,\n)\n\n\nDATA_URL = \"https://downloads.cs.stanford.edu/nlp/data/gqa/images.zip\"\n\n\ndef download_datasets(root, url):\n    download_and_extract_archive(url=url, download_root=root, extract_root=storage_dir.parent)\n\n\nif __name__ == \"__main__\":\n\n    config_path = get_abs_path(\"configs/datasets/gqa/defaults.yaml\")\n\n    storage_dir = OmegaConf.load(\n        config_path\n    ).datasets.gqa.build_info.images.storage\n\n    download_dir = Path(get_cache_path(storage_dir)).parent / \"download\"\n    storage_dir = Path(get_cache_path(storage_dir))\n\n    if storage_dir.exists():\n        print(f\"Dataset already exists at {storage_dir}. Aborting.\")\n        exit(0)\n\n    try:\n        print(\"Downloading {}\".format(DATA_URL))\n        download_datasets(download_dir, DATA_URL)\n    except Exception as e:\n        # remove download dir if failed\n        cleanup_dir(download_dir)\n        print(\"Failed to download or extracting datasets. Aborting.\")\n\n    cleanup_dir(download_dir)\n"
  },
  {
    "path": "lavis/datasets/download_scripts/download_msrvtt.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport os\nfrom pathlib import Path\n\nfrom omegaconf import OmegaConf\n\nfrom lavis.common.utils import (\n    cleanup_dir,\n    download_and_extract_archive,\n    get_abs_path,\n    get_cache_path,\n)\n\n\n# TODO\n# 1. Go to https://www.mediafire.com/file/czh8sezbo9s4692/test_videos.zip/file\n#      and https://www.mediafire.com/file/x3rrbe4hwp04e6w/train_val_videos.zip/file\n# 2. Right-click the Download button and copy the link address\n#      e.g.\n#    DATA_URL = {\n#        \"train\": \"https://download1602.mediafire.com/xxxxxxxxxxxx/x3rrbe4hwp04e6w/train_val_videos.zip\",\n#        \"test\": \"https://download2390.mediafire.com/xxxxxxxxxxxx/czh8sezbo9s4692/test_videos.zip\",\n#    }\n# 3. Paste the link address to DATA_URL\n\nDATA_URL = {\n    \"train\": \"https://download2295.mediafire.com/4bb7p74xrbgg/x3rrbe4hwp04e6w/train_val_videos.zip\",\n    \"test\": \"https://download2390.mediafire.com/79hfq3592lqg/czh8sezbo9s4692/test_videos.zip\",\n}\n\n\ndef download_datasets(root, url):\n    \"\"\"\n    Download the Imagenet-R dataset archives and expand them\n    in the folder provided as parameter\n    \"\"\"\n    download_and_extract_archive(url=url, download_root=root)\n\n\ndef merge_datasets(download_path, storage_path):\n    \"\"\"\n    Merge datasets in download_path to storage_path\n    \"\"\"\n\n    # Merge train and test datasets\n    train_path = os.path.join(download_path, \"TrainValVideo\")\n    test_path = os.path.join(download_path, \"TestVideo\")\n    train_test_path = storage_path\n\n    print(\"Merging to {}\".format(train_test_path))\n\n    os.makedirs(train_test_path, exist_ok=True)\n\n    for file_name in os.listdir(train_path):\n        os.rename(\n            os.path.join(train_path, file_name),\n            os.path.join(train_test_path, file_name),\n        )\n\n    for file_name in os.listdir(test_path):\n        os.rename(\n            os.path.join(test_path, file_name),\n            os.path.join(train_test_path, file_name),\n        )\n\n\nif __name__ == \"__main__\":\n\n    config_path = get_abs_path(\"configs/datasets/msrvtt/defaults_cap.yaml\")\n\n    storage_dir = OmegaConf.load(\n        config_path\n    ).datasets.msrvtt_cap.build_info.videos.storage\n\n    download_dir = Path(get_cache_path(storage_dir)).parent / \"download\"\n    storage_dir = Path(get_cache_path(storage_dir))\n\n    if storage_dir.exists():\n        print(f\"Dataset already exists at {storage_dir}. Aborting.\")\n        exit(0)\n\n    try:\n        for k, v in DATA_URL.items():\n            print(\"Downloading {} to {}\".format(v, k))\n            download_datasets(download_dir, v)\n    except Exception as e:\n        # remove download dir if failed\n        cleanup_dir(download_dir)\n        print(\"Failed to download or extracting datasets. Aborting.\")\n\n    try:\n        merge_datasets(download_dir, storage_dir)\n    except Exception as e:\n        # remove storage dir if failed\n        cleanup_dir(download_dir)\n        cleanup_dir(storage_dir)\n        print(\"Failed to merging datasets. Aborting.\")\n\n    cleanup_dir(download_dir)\n"
  },
  {
    "path": "lavis/datasets/download_scripts/download_msvd.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport os\nfrom pathlib import Path\n\nfrom omegaconf import OmegaConf\n\nfrom lavis.common.utils import (\n    cleanup_dir,\n    download_and_extract_archive,\n    get_abs_path,\n    get_cache_path,\n)\n\n\nDATA_URL = \"https://www.cs.utexas.edu/users/ml/clamp/videoDescription/YouTubeClips.tar\"\n\n\ndef download_datasets(root, url):\n    download_and_extract_archive(url=url, download_root=root)\n\n\ndef move_files(download_path, storage_path):\n    \"\"\"\n    Move files from download_path to storage_path\n    \"\"\"\n    print(\"Moving to {}\".format(storage_path))\n\n    os.makedirs(storage_path, exist_ok=True)\n\n    for file_name in os.listdir(download_path):\n        os.rename(\n            os.path.join(download_path, file_name),\n            os.path.join(storage_path, file_name),\n        )\n\n\nif __name__ == \"__main__\":\n\n    config_path = get_abs_path(\"configs/datasets/msvd/defaults_cap.yaml\")\n\n    storage_dir = OmegaConf.load(\n        config_path\n    ).datasets.msvd_cap.build_info.videos.storage\n\n    download_dir = Path(get_cache_path(storage_dir)).parent / \"download\"\n    storage_dir = Path(get_cache_path(storage_dir))\n\n    if storage_dir.exists():\n        print(f\"Dataset already exists at {storage_dir}. Aborting.\")\n        exit(0)\n\n    try:\n        print(\"Downloading {}\".format(DATA_URL))\n        download_datasets(download_dir, DATA_URL)\n    except Exception as e:\n        # remove download dir if failed\n        cleanup_dir(download_dir)\n        print(\"Failed to download or extracting datasets. Aborting.\")\n\n    move_files(download_dir / \"YouTubeClips\", storage_dir)\n    cleanup_dir(download_dir)\n"
  },
  {
    "path": "lavis/datasets/download_scripts/download_nocaps.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport json\nimport logging\nimport os\nimport time\nfrom multiprocessing import Pool\n\nimport numpy as np\nimport requests\nimport tqdm\nfrom lavis.common.utils import cleanup_dir, get_abs_path, get_cache_path\nfrom omegaconf import OmegaConf\n\nheader_mzl = {\n    \"User-Agent\": \"Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/71.0.3578.98 Safari/537.36\",\n    # \"User-Agent\": \"Googlebot-Image/1.0\",  # Pretend to be googlebot\n    # \"X-Forwarded-For\": \"64.18.15.200\",\n}\n\nheader_gbot = {\n    \"User-Agent\": \"Googlebot-Image/1.0\",  # Pretend to be googlebot\n}\n\nheaders = [header_mzl, header_gbot]\n\n# Setup\nlogging.basicConfig(filename=\"download_nocaps.log\", filemode=\"w\", level=logging.INFO)\nrequests.packages.urllib3.disable_warnings(\n    requests.packages.urllib3.exceptions.InsecureRequestWarning\n)\n\n\ndef download_file(url, filename):\n    max_retries = 20\n    cur_retries = 0\n\n    header = headers[0]\n\n    while cur_retries < max_retries:\n        try:\n            r = requests.get(url, headers=header, timeout=10)\n            with open(filename, \"wb\") as f:\n                f.write(r.content)\n\n            break\n        except Exception as e:\n            logging.info(\" \".join(repr(e).splitlines()))\n            logging.error(url)\n            cur_retries += 1\n\n            # random sample a header from headers\n            header = headers[np.random.randint(0, len(headers))]\n\n    time.sleep(3 + cur_retries * 2)\n\n\ndef download_image_from_url_val(url):\n    basename = os.path.basename(url)\n    filename = os.path.join(storage_dir, \"val\", basename)\n\n    download_file(url, filename)\n\n\ndef download_image_from_url_test(url):\n    basename = os.path.basename(url)\n    filename = os.path.join(storage_dir, \"test\", basename)\n\n    download_file(url, filename)\n\n\nif __name__ == \"__main__\":\n    os.makedirs(\"tmp\", exist_ok=True)\n\n    # storage dir\n    config_path = get_abs_path(\"configs/datasets/nocaps/defaults.yaml\")\n\n    storage_dir = OmegaConf.load(config_path).datasets.nocaps.build_info.images.storage\n    storage_dir = get_cache_path(storage_dir)\n    # make sure the storage dir exists\n    os.makedirs(storage_dir, exist_ok=True)\n    print(\"Storage dir:\", storage_dir)\n\n    # make sure the storage dir for val and test exists\n    os.makedirs(os.path.join(storage_dir, \"val\"), exist_ok=True)\n    os.makedirs(os.path.join(storage_dir, \"test\"), exist_ok=True)\n\n    # download annotations\n    val_url = \"https://nocaps.s3.amazonaws.com/nocaps_val_4500_captions.json\"\n    tst_url = \"https://s3.amazonaws.com/nocaps/nocaps_test_image_info.json\"\n\n    print(\"Downloading validation annotations from %s\" % val_url)\n    download_file(val_url, \"tmp/nocaps_val_ann.json\")\n    print(\"Downloading testing annotations from %s\" % tst_url)\n    download_file(tst_url, \"tmp/nocaps_tst_ann.json\")\n\n    # open annotations\n    val_ann = json.load(open(\"tmp/nocaps_val_ann.json\"))\n    tst_ann = json.load(open(\"tmp/nocaps_tst_ann.json\"))\n\n    # collect image urls\n    val_info = val_ann[\"images\"]\n    tst_info = tst_ann[\"images\"]\n\n    val_urls = [info[\"coco_url\"] for info in val_info]\n    tst_urls = [info[\"coco_url\"] for info in 
tst_info]\n\n    # setup multiprocessing\n    # large n_procs possibly causes server to reject requests\n    n_procs = 16\n\n    with Pool(n_procs) as pool:\n        print(\"Downloading validation images...\")\n        list(\n            tqdm.tqdm(\n                pool.imap(download_image_from_url_val, val_urls), total=len(val_urls)\n            )\n        )\n\n    with Pool(n_procs) as pool:\n        print(\"Downloading test images...\")\n        list(\n            tqdm.tqdm(\n                pool.imap(download_image_from_url_test, tst_urls), total=len(tst_urls)\n            )\n        )\n\n    # clean tmp\n    cleanup_dir(\"tmp\")\n"
  },
  {
    "path": "lavis/datasets/download_scripts/download_sbu.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport io\nimport os\nimport pathlib\nimport urllib\nimport tqdm\n\nfrom concurrent.futures import ThreadPoolExecutor\n\nfrom lavis.common.utils import get_abs_path, get_cache_path\nfrom lavis.datasets.builders import load_dataset\nfrom omegaconf import OmegaConf\nfrom PIL import Image\n\n# DATA_URL = {\"train\": \"http://www.cs.rice.edu/~vo9/sbucaptions/sbu_images.tar\"}\n\nUSER_AGENT = (\n    \"Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:15.0) Gecko/20100101 Firefox/15.0.1\"\n)\n\n\ndef fetch_single_image(image_url, timeout=None, retries=0):\n    for _ in range(retries + 1):\n        try:\n            request = urllib.request.Request(\n                image_url,\n                data=None,\n                headers={\"user-agent\": USER_AGENT},\n            )\n            with urllib.request.urlopen(request, timeout=timeout) as req:\n                image = Image.open(io.BytesIO(req.read()))\n            break\n        except Exception:\n            image = None\n    return image\n\n\ndef download_and_save_image(ann, save_dir, timeout=None, retries=0):\n    image = fetch_single_image(ann[\"url\"], timeout=timeout, retries=retries)\n\n    if image is not None:\n        image_path = os.path.join(save_dir, ann[\"image\"])\n        print(image_path)\n        image.save(image_path)\n\n\nif __name__ == \"__main__\":\n\n    config_path = get_abs_path(\"configs/datasets/sbu_caption/defaults.yaml\")\n\n    storage_dir = OmegaConf.load(\n        config_path\n    ).datasets.sbu_caption.build_info.images.storage\n\n    storage_dir = pathlib.Path(get_cache_path(storage_dir))\n\n    if storage_dir.exists():\n        print(f\"Dataset already exists at {storage_dir}. Aborting.\")\n        exit(0)\n\n    storage_dir.mkdir(parents=True, exist_ok=True)\n\n    num_threads = 20\n    dset = load_dataset(\"sbu_caption\")[\"train\"].annotation\n\n    print(\"Downloading dataset...\")\n    # multiprocessing\n    with ThreadPoolExecutor(max_workers=num_threads) as executor:\n        for ann in tqdm.tqdm(dset):\n            executor.submit(\n                download_and_save_image,\n                ann,\n                storage_dir,\n                timeout=30,\n                retries=10,\n            )\n"
  },
  {
    "path": "lavis/datasets/download_scripts/download_vg.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport os\nfrom pathlib import Path\n\nfrom omegaconf import OmegaConf\n\nfrom lavis.common.utils import (\n    cleanup_dir,\n    download_and_extract_archive,\n    get_abs_path,\n    get_cache_path,\n)\n\n\nDATA_URL = {\n    \"train\": \"https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip\",\n    \"train2\": \"https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip\",\n}\n\n\ndef download_datasets(root, url):\n    download_and_extract_archive(url=url, download_root=root, extract_root=storage_dir)\n\n\nif __name__ == \"__main__\":\n\n    config_path = get_abs_path(\"configs/datasets/vg/defaults_caption.yaml\")\n\n    storage_dir = OmegaConf.load(\n        config_path\n    ).datasets.vg_caption.build_info.images.storage\n\n    download_dir = Path(get_cache_path(storage_dir)).parent / \"download\"\n    storage_dir = Path(get_cache_path(storage_dir))\n\n    if storage_dir.exists():\n        print(f\"Dataset already exists at {storage_dir}. Aborting.\")\n        exit(0)\n\n    try:\n        for k, v in DATA_URL.items():\n            print(\"Downloading {} to {}\".format(v, k))\n            download_datasets(download_dir, v)\n    except Exception as e:\n        # remove download dir if failed\n        cleanup_dir(download_dir)\n        print(\"Failed to download or extracting datasets. Aborting.\")\n\n    cleanup_dir(download_dir)\n"
  },
  {
    "path": "lavis/models/__init__.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport logging\nfrom omegaconf import OmegaConf\nfrom lavis.common.registry import registry\n\nfrom lavis.models.base_model import BaseModel\n\nfrom lavis.models.albef_models.albef_classification import AlbefClassification\nfrom lavis.models.albef_models.albef_feature_extractor import AlbefFeatureExtractor\nfrom lavis.models.albef_models.albef_nlvr import AlbefNLVR\nfrom lavis.models.albef_models.albef_pretrain import AlbefPretrain\nfrom lavis.models.albef_models.albef_retrieval import AlbefRetrieval\nfrom lavis.models.albef_models.albef_vqa import AlbefVQA\nfrom lavis.models.alpro_models.alpro_qa import AlproQA\nfrom lavis.models.alpro_models.alpro_retrieval import AlproRetrieval\n\nfrom lavis.models.blip_models.blip import BlipBase\nfrom lavis.models.blip_models.blip_caption import BlipCaption\nfrom lavis.models.blip_models.blip_classification import BlipClassification\nfrom lavis.models.blip_models.blip_feature_extractor import BlipFeatureExtractor\nfrom lavis.models.blip_models.blip_image_text_matching import BlipITM\nfrom lavis.models.blip_models.blip_nlvr import BlipNLVR\nfrom lavis.models.blip_models.blip_pretrain import BlipPretrain\nfrom lavis.models.blip_models.blip_retrieval import BlipRetrieval\nfrom lavis.models.blip_models.blip_vqa import BlipVQA\n\nfrom lavis.models.blip2_models.blip2 import Blip2Base\nfrom lavis.models.blip2_models.blip2_opt import Blip2OPT\nfrom lavis.models.blip2_models.blip2_t5 import Blip2T5\nfrom lavis.models.blip2_models.blip2_fmr import Blip2FMR\nfrom lavis.models.sevila_models.sevila import SeViLA\n\nfrom lavis.models.blip2_models.blip2_qformer import Blip2Qformer\nfrom lavis.models.blip2_models.blip2_image_text_matching import Blip2ITM\n\nfrom lavis.models.pnp_vqa_models.pnp_vqa import PNPVQA\nfrom lavis.models.pnp_vqa_models.pnp_unifiedqav2_fid import PNPUnifiedQAv2FiD\nfrom lavis.models.img2prompt_models.img2prompt_vqa import Img2PromptVQA\nfrom lavis.models.med import XBertLMHeadDecoder\nfrom lavis.models.vit import VisionTransformerEncoder\nfrom lavis.models.clip_models.model import CLIP\n\nfrom lavis.models.gpt_models.gpt_dialogue import GPTDialogue\n\nfrom lavis.processors.base_processor import BaseProcessor\n\n\n__all__ = [\n    \"load_model\",\n    \"AlbefClassification\",\n    \"AlbefFeatureExtractor\",\n    \"AlbefNLVR\",\n    \"AlbefVQA\",\n    \"AlbefPretrain\",\n    \"AlbefRetrieval\",\n    \"AlproQA\",\n    \"AlproRetrieval\",\n    \"BaseModel\",\n    \"BlipBase\",\n    \"BlipFeatureExtractor\",\n    \"BlipCaption\",\n    \"BlipClassification\",\n    \"BlipITM\",\n    \"BlipNLVR\",\n    \"BlipPretrain\",\n    \"BlipRetrieval\",\n    \"BlipVQA\",\n    \"Blip2Qformer\",\n    \"Blip2Base\",\n    \"Blip2ITM\",\n    \"Blip2OPT\",\n    \"Blip2T5\",\n    \"PNPVQA\",\n    \"Img2PromptVQA\",\n    \"PNPUnifiedQAv2FiD\",\n    \"CLIP\",\n    \"VisionTransformerEncoder\",\n    \"XBertLMHeadDecoder\",\n    \"GPTDialogue\",\n    \"Blip2FMR\",\n    \"SeViLA\",\n]\n\n\ndef load_model(name, model_type, is_eval=False, device=\"cpu\", checkpoint=None):\n    \"\"\"\n    Load supported models.\n\n    To list all available models and types in registry:\n    >>> from lavis.models import model_zoo\n    >>> print(model_zoo)\n\n    Args:\n        name (str): name of the model.\n        model_type (str): type of the 
model.\n        is_eval (bool): whether the model is in eval mode. Default: False.\n        device (str): device to use. Default: \"cpu\".\n        checkpoint (str): path or URL to the checkpoint. Default: None.\n            Note that the checkpoint is expected to have the same keys in state_dict as the model.\n\n    Returns:\n        model (torch.nn.Module): model.\n    \"\"\"\n\n    model = registry.get_model_class(name).from_pretrained(model_type=model_type)\n\n    if checkpoint is not None:\n        model.load_checkpoint(checkpoint)\n\n    if is_eval:\n        model.eval()\n\n    if device == \"cpu\":\n        model = model.float()\n\n    return model.to(device)\n\n\ndef load_preprocess(config):\n    \"\"\"\n    Load preprocessor configs and construct preprocessors.\n\n    If no preprocessor is specified, return BaseProcessor, which does not do any preprocessing.\n\n    Args:\n        config (dict): preprocessor configs.\n\n    Returns:\n        vis_processors (dict): preprocessors for visual inputs.\n        txt_processors (dict): preprocessors for text inputs.\n\n        Key is \"train\" or \"eval\" for processors used in training and evaluation respectively.\n    \"\"\"\n\n    def _build_proc_from_cfg(cfg):\n        return (\n            registry.get_processor_class(cfg.name).from_config(cfg)\n            if cfg is not None\n            else BaseProcessor()\n        )\n\n    vis_processors = dict()\n    txt_processors = dict()\n\n    vis_proc_cfg = config.get(\"vis_processor\")\n    txt_proc_cfg = config.get(\"text_processor\")\n\n    if vis_proc_cfg is not None:\n        vis_train_cfg = vis_proc_cfg.get(\"train\")\n        vis_eval_cfg = vis_proc_cfg.get(\"eval\")\n    else:\n        vis_train_cfg = None\n        vis_eval_cfg = None\n\n    vis_processors[\"train\"] = _build_proc_from_cfg(vis_train_cfg)\n    vis_processors[\"eval\"] = _build_proc_from_cfg(vis_eval_cfg)\n\n    if txt_proc_cfg is not None:\n        txt_train_cfg = txt_proc_cfg.get(\"train\")\n        txt_eval_cfg = txt_proc_cfg.get(\"eval\")\n    else:\n        txt_train_cfg = None\n        txt_eval_cfg = None\n\n    txt_processors[\"train\"] = _build_proc_from_cfg(txt_train_cfg)\n    txt_processors[\"eval\"] = _build_proc_from_cfg(txt_eval_cfg)\n\n    return vis_processors, txt_processors\n\n\ndef load_model_and_preprocess(name, model_type, is_eval=False, device=\"cpu\"):\n    \"\"\"\n    Load model and its related preprocessors.\n\n    List all available models and types in registry:\n    >>> from lavis.models import model_zoo\n    >>> print(model_zoo)\n\n    Args:\n        name (str): name of the model.\n        model_type (str): type of the model.\n        is_eval (bool): whether the model is in eval mode. Default: False.\n        device (str): device to use. 
Default: \"cpu\".\n\n    Returns:\n        model (torch.nn.Module): model.\n        vis_processors (dict): preprocessors for visual inputs.\n        txt_processors (dict): preprocessors for text inputs.\n    \"\"\"\n    model_cls = registry.get_model_class(name)\n\n    # load model\n    model = model_cls.from_pretrained(model_type=model_type)\n\n    if is_eval:\n        model.eval()\n\n    # load preprocess\n    cfg = OmegaConf.load(model_cls.default_config_path(model_type))\n    # print(cfg)\n    if cfg is not None:\n        preprocess_cfg = cfg.preprocess\n\n        vis_processors, txt_processors = load_preprocess(preprocess_cfg)\n    else:\n        vis_processors, txt_processors = None, None\n        logging.info(\n            f\"\"\"No default preprocess for model {name} ({model_type}).\n                This can happen if the model is not finetuned on downstream datasets,\n                or it is not intended for direct use without finetuning.\n            \"\"\"\n        )\n\n    if device == \"cpu\":\n        model = model.float()\n\n    return model.to(device), vis_processors, txt_processors\n\n\nclass ModelZoo:\n    \"\"\"\n    A utility class to create string representation of available model architectures and types.\n\n    >>> from lavis.models import model_zoo\n    >>> # list all available models\n    >>> print(model_zoo)\n    >>> # show total number of models\n    >>> print(len(model_zoo))\n    \"\"\"\n\n    def __init__(self) -> None:\n        self.model_zoo = {\n            k: list(v.PRETRAINED_MODEL_CONFIG_DICT.keys())\n            for k, v in registry.mapping[\"model_name_mapping\"].items()\n        }\n\n    def __str__(self) -> str:\n        return (\n            \"=\" * 50\n            + \"\\n\"\n            + f\"{'Architectures':<30} {'Types'}\\n\"\n            + \"=\" * 50\n            + \"\\n\"\n            + \"\\n\".join(\n                [\n                    f\"{name:<30} {', '.join(types)}\"\n                    for name, types in self.model_zoo.items()\n                ]\n            )\n        )\n\n    def __iter__(self):\n        return iter(self.model_zoo.items())\n\n    def __len__(self):\n        return sum([len(v) for v in self.model_zoo.values()])\n\n\nmodel_zoo = ModelZoo()\n"
  },
  {
    "path": "lavis/models/albef_models/__init__.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport datetime\nimport logging\nimport os\nimport time\n\nimport lavis.common.dist_utils as dist_utils\nimport torch\nimport torch.distributed as dist\nimport torch.nn.functional as F\nfrom lavis.common.dist_utils import download_cached_file\nfrom lavis.common.logger import MetricLogger\nfrom lavis.common.utils import is_url\nfrom lavis.models.base_model import BaseModel\nfrom lavis.models.vit import interpolate_pos_embed\nfrom transformers import BertTokenizer\n\n\nclass AlbefBase(BaseModel):\n    @classmethod\n    def init_tokenizer(cls):\n        return BertTokenizer.from_pretrained(\"bert-base-uncased\")\n\n    def load_from_pretrained(self, url_or_filename, rename_text_keys=True):\n        if is_url(url_or_filename):\n            cached_file = download_cached_file(\n                url_or_filename, check_hash=False, progress=True\n            )\n            checkpoint = torch.load(cached_file, map_location=\"cpu\")\n        elif os.path.isfile(url_or_filename):\n            checkpoint = torch.load(url_or_filename, map_location=\"cpu\")\n        else:\n            raise RuntimeError(\"checkpoint url or path is invalid\")\n\n        if \"model\" in checkpoint:\n            state_dict = checkpoint[\"model\"]\n        else:\n            state_dict = checkpoint\n\n        state_dict[\"visual_encoder.pos_embed\"] = interpolate_pos_embed(\n            state_dict[\"visual_encoder.pos_embed\"], self.visual_encoder\n        )\n        if (\n            \"visual_encoder_m.pos_embed\" in self.state_dict().keys()\n            and \"visual_encoder_m.pos_embed\" in state_dict\n        ):\n            state_dict[\"visual_encoder_m.pos_embed\"] = interpolate_pos_embed(\n                state_dict[\"visual_encoder_m.pos_embed\"], self.visual_encoder_m\n            )\n\n        if rename_text_keys:\n            for key in list(state_dict.keys()):\n                if \"bert\" in key:\n                    new_key = key.replace(\"bert.\", \"\")\n                    state_dict[new_key] = state_dict[key]\n                    del state_dict[key]\n\n        for key in self.state_dict().keys():\n            if key in state_dict.keys():\n                if state_dict[key].shape != self.state_dict()[key].shape:\n                    del state_dict[key]\n\n        msg = self.load_state_dict(state_dict, strict=False)\n\n        logging.info(\"Missing keys {}\".format(msg.missing_keys))\n        logging.info(\"load checkpoint from %s\" % url_or_filename)\n        return msg\n\n\ndef compute_sim_matrix(model, data_loader, **kwargs):\n    k_test = kwargs.pop(\"k_test\")\n\n    metric_logger = MetricLogger(delimiter=\"  \")\n    header = \"Evaluation:\"\n\n    logging.info(\"Computing features for evaluation...\")\n    start_time = time.time()\n\n    texts = data_loader.dataset.text\n    num_text = len(texts)\n    text_bs = 256\n    text_ids = []\n    text_embeds = []\n    text_atts = []\n    for i in range(0, num_text, text_bs):\n        text = texts[i : min(num_text, i + text_bs)]\n        text_input = model.tokenizer(\n            text,\n            padding=\"max_length\",\n            truncation=True,\n            max_length=35,\n            return_tensors=\"pt\",\n        ).to(model.device)\n        text_output = model.text_encoder.forward_text(text_input)\n        
text_embed = F.normalize(\n            model.text_proj(text_output.last_hidden_state[:, 0, :])\n        )\n        text_embeds.append(text_embed)\n        text_ids.append(text_input.input_ids)\n        text_atts.append(text_input.attention_mask)\n\n    text_embeds = torch.cat(text_embeds, dim=0)\n    text_ids = torch.cat(text_ids, dim=0)\n    text_atts = torch.cat(text_atts, dim=0)\n    if hasattr(model.tokenizer, \"enc_token_id\"):\n        text_ids[:, 0] = model.tokenizer.enc_token_id\n\n    image_feats = []\n    image_embeds = []\n    for samples in data_loader:\n        image = samples[\"image\"]\n\n        image = image.to(model.device)\n        image_feat = model.visual_encoder.forward_features(image)\n        image_embed = model.vision_proj(image_feat[:, 0, :])\n        image_embed = F.normalize(image_embed, dim=-1)\n\n        image_feats.append(image_feat.cpu())\n        image_embeds.append(image_embed)\n\n    image_feats = torch.cat(image_feats, dim=0)\n    image_embeds = torch.cat(image_embeds, dim=0)\n\n    sims_matrix = image_embeds @ text_embeds.t()\n    score_matrix_i2t = torch.full(\n        (len(data_loader.dataset.image), len(texts)), -100.0\n    ).to(model.device)\n\n    num_tasks = dist_utils.get_world_size()\n    rank = dist_utils.get_rank()\n    step = sims_matrix.size(0) // num_tasks + 1\n    start = rank * step\n    end = min(sims_matrix.size(0), start + step)\n\n    for i, sims in enumerate(\n        metric_logger.log_every(sims_matrix[start:end], 50, header)\n    ):\n        # topk_sim, topk_idx = sims.topk(k=config[\"k_test\"], dim=0)\n        topk_sim, topk_idx = sims.topk(k=k_test, dim=0)\n\n        encoder_output = image_feats[start + i].repeat(k_test, 1, 1).to(model.device)\n        encoder_att = torch.ones(encoder_output.size()[:-1], dtype=torch.long).to(\n            model.device\n        )\n        output = model.text_encoder(\n            text_ids[topk_idx],\n            attention_mask=text_atts[topk_idx],\n            encoder_hidden_states=encoder_output,\n            encoder_attention_mask=encoder_att,\n            return_dict=True,\n        )\n        score = model.itm_head(output.last_hidden_state[:, 0, :])[:, 1]\n        score_matrix_i2t[start + i, topk_idx] = score + topk_sim\n\n    sims_matrix = sims_matrix.t()\n    score_matrix_t2i = torch.full(\n        (len(texts), len(data_loader.dataset.image)), -100.0\n    ).to(model.device)\n\n    step = sims_matrix.size(0) // num_tasks + 1\n    start = rank * step\n    end = min(sims_matrix.size(0), start + step)\n\n    for i, sims in enumerate(\n        metric_logger.log_every(sims_matrix[start:end], 50, header)\n    ):\n\n        topk_sim, topk_idx = sims.topk(k=k_test, dim=0)\n        encoder_output = image_feats[topk_idx.cpu()].to(model.device)\n        encoder_att = torch.ones(encoder_output.size()[:-1], dtype=torch.long).to(\n            model.device\n        )\n        output = model.text_encoder(\n            text_ids[start + i].repeat(k_test, 1),\n            attention_mask=text_atts[start + i].repeat(k_test, 1),\n            encoder_hidden_states=encoder_output,\n            encoder_attention_mask=encoder_att,\n            return_dict=True,\n        )\n        score = model.itm_head(output.last_hidden_state[:, 0, :])[:, 1]\n        score_matrix_t2i[start + i, topk_idx] = score + topk_sim\n\n    if dist_utils.is_dist_avail_and_initialized():\n        dist.barrier()\n        torch.distributed.all_reduce(\n            score_matrix_i2t, op=torch.distributed.ReduceOp.SUM\n        )\n        
torch.distributed.all_reduce(\n            score_matrix_t2i, op=torch.distributed.ReduceOp.SUM\n        )\n\n    total_time = time.time() - start_time\n    total_time_str = str(datetime.timedelta(seconds=int(total_time)))\n    logging.info(\"Evaluation time {}\".format(total_time_str))\n\n    return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy()\n"
  },
  {
    "path": "lavis/models/albef_models/albef_classification.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport warnings\nfrom copy import deepcopy\n\nimport torch\nimport torch.nn.functional as F\nfrom lavis.common.registry import registry\nfrom lavis.models.albef_models import AlbefBase\nfrom lavis.models.albef_models.albef_outputs import (\n    AlbefIntermediateOutput,\n    AlbefOutputWithLogits,\n)\nfrom lavis.models.base_model import MomentumDistilationMixin\nfrom lavis.models.med import XBertEncoder\nfrom lavis.models.vit import VisionTransformerEncoder\nfrom torch import nn\n\n\n@registry.register_model(\"albef_classification\")\nclass AlbefClassification(AlbefBase, MomentumDistilationMixin):\n    PRETRAINED_MODEL_CONFIG_DICT = {\n        \"ve\": \"configs/models/albef_classification_ve.yaml\",\n    }\n\n    def __init__(\n        self,\n        image_encoder,\n        text_encoder,\n        num_classes,\n        momentum=0.995,\n        alpha=0.4,\n        use_distill=True,\n        max_txt_len=40,\n    ):\n        super().__init__()\n\n        self.tokenizer = self.init_tokenizer()\n        self.max_txt_len = max_txt_len\n\n        self.use_distill = use_distill\n\n        self.visual_encoder = image_encoder\n        self.text_encoder = text_encoder\n\n        hidden_size = text_encoder.config.hidden_size\n\n        if num_classes > 0:\n            self.cls_head = nn.Sequential(\n                nn.Linear(hidden_size, hidden_size),\n                nn.ReLU(),\n                nn.Linear(hidden_size, num_classes),\n            )\n        else:\n            warnings.warn(\n                f\"Found num_classes=0, initializing {type(self)} without classifier.\"\n            )\n\n        if self.use_distill:\n            self.visual_encoder_m = deepcopy(self.visual_encoder)\n            self.text_encoder_m = deepcopy(self.text_encoder)\n            self.cls_head_m = deepcopy(self.cls_head)\n\n            self.momentum = momentum\n            self.alpha = alpha\n\n            self.model_pairs = [\n                [self.visual_encoder, self.visual_encoder_m],\n                [self.text_encoder, self.text_encoder_m],\n                [self.cls_head, self.cls_head_m],\n            ]\n\n            self.copy_params()\n\n    def _rampup_factor(self, epoch, iters, num_iters_per_epoch):\n        return min(1, (epoch * num_iters_per_epoch + iters) / num_iters_per_epoch)\n\n    def forward(self, samples, is_train=True):\n        sentences = samples[\"text_input\"]\n        sentences = self.tokenizer(\n            sentences,\n            padding=\"longest\",\n            truncation=True,\n            max_length=self.max_txt_len,\n            return_tensors=\"pt\",\n        ).to(self.device)\n        samples.update({\"tokenized_text\": sentences})\n\n        targets = samples[\"label\"]\n\n        image_embeds = self.visual_encoder.forward_features(samples[\"image\"])\n        encoder_output = self.text_encoder.forward_automask(\n            samples[\"tokenized_text\"], image_embeds\n        )\n\n        prediction = self.cls_head(encoder_output.last_hidden_state[:, 0, :])\n\n        if is_train:\n            if self.use_distill:\n                with torch.no_grad():\n                    self._momentum_update()\n\n                    image_embeds_m = self.visual_encoder_m(samples[\"image\"])\n                    encoder_output_m = 
self.text_encoder_m.forward_automask(\n                        samples[\"tokenized_text\"], image_embeds_m\n                    )\n\n                    prediction_m = self.cls_head_m(\n                        encoder_output_m.last_hidden_state[:, 0, :]\n                    )\n\n                alpha = self.alpha * self._rampup_factor(\n                    epoch=samples[\"epoch\"],\n                    iters=samples[\"iters\"],\n                    num_iters_per_epoch=samples[\"num_iters_per_epoch\"],\n                )\n\n                loss = (1 - alpha) * F.cross_entropy(\n                    prediction, targets\n                ) - alpha * torch.sum(\n                    F.log_softmax(prediction, dim=1) * F.softmax(prediction_m, dim=1),\n                    dim=1,\n                ).mean()\n            else:\n                loss = F.cross_entropy(prediction, targets)\n\n                image_embeds_m, encoder_output_m, prediction_m = None, None, None\n\n            # return {\"loss\": loss}\n            return AlbefOutputWithLogits(\n                loss=loss,\n                intermediate_output=AlbefIntermediateOutput(\n                    image_embeds=image_embeds,\n                    image_embeds_m=image_embeds_m,\n                    encoder_output=encoder_output,\n                    encoder_output_m=encoder_output_m,\n                ),\n                logits=prediction,\n                logits_m=prediction_m,\n            )\n        else:\n            return {\"predictions\": prediction, \"targets\": targets}\n\n    def predict(self, samples):\n        output = self.forward(samples, is_train=False)\n        return output\n\n    @classmethod\n    def from_config(cls, cfg=None):\n        image_encoder = VisionTransformerEncoder.from_config(cfg)\n\n        # text encoder + multimodal encoder\n        text_encoder = XBertEncoder.from_config(cfg)\n\n        alpha = cfg.get(\"alpha\", 0.4)\n        momentum = cfg.get(\"momentum\", 0.995)\n        use_distill = cfg.get(\"use_distill\", True)\n        num_classes = cfg.get(\"num_classes\", -1)\n        max_txt_len = cfg.get(\"max_txt_len\", 40)\n\n        assert num_classes > 1, \"Invalid number of classes provided, found {}\".format(\n            num_classes\n        )\n\n        model = cls(\n            image_encoder=image_encoder,\n            text_encoder=text_encoder,\n            use_distill=use_distill,\n            alpha=alpha,\n            num_classes=num_classes,\n            momentum=momentum,\n            max_txt_len=max_txt_len,\n        )\n\n        model.load_checkpoint_from_config(cfg)\n\n        return model\n"
  },
  {
    "path": "lavis/models/albef_models/albef_feature_extractor.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport warnings\n\nimport torch\nimport torch.nn.functional as F\nfrom lavis.common.registry import registry\nfrom lavis.common.utils import get_abs_path\nfrom lavis.models.albef_models import AlbefBase\nfrom lavis.models.albef_models.albef_outputs import AlbefOutputFeatures\nfrom lavis.models.med import BertForMaskedLM\nfrom lavis.models.vit import VisionTransformerEncoder\nfrom torch import nn\nfrom transformers import BertConfig\n\n\n@registry.register_model(\"albef_feature_extractor\")\nclass AlbefFeatureExtractor(AlbefBase):\n    PRETRAINED_MODEL_CONFIG_DICT = {\n        \"base\": \"configs/models/albef_feature_extractor.yaml\",\n    }\n\n    def __init__(self, image_encoder, text_encoder, embed_dim=256, max_txt_len=30):\n        super().__init__()\n\n        self.tokenizer = self.init_tokenizer()\n\n        self.visual_encoder = image_encoder\n        self.text_encoder = text_encoder\n\n        text_width = text_encoder.config.hidden_size\n        vision_width = image_encoder.vision_width\n\n        self.embed_dim = embed_dim\n\n        self.vision_proj = nn.Linear(vision_width, embed_dim)\n        self.text_proj = nn.Linear(text_width, embed_dim)\n\n        self.max_txt_len = max_txt_len\n\n        self.temp = nn.Parameter(0.07 * torch.ones([]))\n\n    @torch.no_grad()\n    def extract_features(self, samples, mode=\"multimodal\"):\n        \"\"\"\n        Extract features for multimodal or unimodal samples.\n\n        Args:\n            samples (dict): A dictionary of samples, containing the following keys:\n                - image (torch.Tensor): A tensor of shape (B, C, H, W) containing the image.\n                    Raw images should be preprocessed before being passed to feature extractor.\n                - text_input (list): A list of strings containing the text, length B.\n            mode (str): The mode of feature extraction. 
Can be either \"multimodal\", \"text\" or \"image\".\n                If \"multimodal\", return image features and multimodal features;\n                if \"text\", return text features;\n                if \"image\", return image features.\n                Default: \"multimodal\".\n\n        Returns:\n            An AlbefOutputFeatures object, see lavis/models/albef_models/albef_outputs.py for details.\n\n        Examples:\n        ```python\n            >>> from PIL import Image\n            >>> from lavis.models import load_model_and_preprocess\n            >>> raw_image = Image.open(\"docs/data/merlion.png\").convert(\"RGB\")\n            >>> caption = \"a large fountain spewing water into the air\"\n            >>> model, vis_processors, txt_processors = load_model_and_preprocess(\"albef_feature_extractor\", is_eval=True)\n            >>> image = vis_processors[\"eval\"](raw_image).unsqueeze(0)\n            >>> text_input = txt_processors[\"eval\"](caption)\n\n            >>> sample = {\"image\": image, \"text_input\": [text_input]}\n\n            >>> features_multimodal = model.extract_features(sample)\n            >>> features_multimodal.keys()\n            odict_keys(['image_embeds', 'multimodal_embeds'])\n            >>> features_multimodal.image_embeds.shape\n            torch.Size([1, 197, 768])\n            >>> features_multimodal.multimodal_embeds.shape\n            torch.Size([1, 12, 768])\n\n            >>> features_text = model.extract_features(sample, mode=\"text\")\n            >>> features_text.keys()\n            odict_keys(['text_embeds', 'text_features'])\n            >>> features_text.text_embeds.shape\n            torch.Size([1, 12, 768])\n            >>> features_text.text_features.shape\n            torch.Size([1, 12, 256])\n\n            >>> features_image = model.extract_features(sample, mode=\"image\")\n            >>> features_image.keys()\n            odict_keys(['image_embeds', 'image_features'])\n            >>> features_image.image_embeds.shape\n            torch.Size([1, 197, 768])\n            >>> features_image.image_features.shape\n            torch.Size([1, 197, 256])\n        ```\n        \"\"\"\n        image = samples[\"image\"]\n        caption = samples[\"text_input\"]\n\n        if isinstance(mode, str):\n            mode = [mode]\n\n        for m in mode:\n            assert m in [\n                \"multimodal\",\n                \"image\",\n                \"text\",\n            ], \"mode must be one of [multimodal, image, text], but got {}\".format(m)\n\n        # initalize output\n        image_embeds, text_embeds, multimodal_embeds = None, None, None\n        image_features, text_features = None, None\n\n        if \"image\" in mode or \"multimodal\" in mode:\n            assert (\n                image is not None\n            ), \"image must be provided if mode is 'image' or 'multimodal'\"\n\n            image_embeds = self.visual_encoder.forward_features(image)\n            image_features = F.normalize(self.vision_proj(image_embeds), dim=-1)\n\n        if \"text\" in mode or \"multimodal\" in mode:\n            assert (\n                caption is not None\n            ), \"text must be provided if mode is 'text' or 'multimodal'\"\n\n            text = self.tokenizer(\n                caption,\n                padding=True,\n                return_tensors=\"pt\",\n            ).to(self.device)\n\n            text_output = self.text_encoder.bert(\n                text.input_ids,\n                attention_mask=text.attention_mask,\n     
           return_dict=True,\n                mode=\"text\",\n            )\n            text_embeds = text_output.last_hidden_state\n            text_features = F.normalize(self.text_proj(text_embeds), dim=-1)\n\n        if \"multimodal\" in mode:\n            image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(\n                self.device\n            )\n\n            # forward the positive image-text pair\n            output = self.text_encoder.bert(\n                encoder_embeds=text_embeds,\n                attention_mask=text.attention_mask,\n                encoder_hidden_states=image_embeds,\n                encoder_attention_mask=image_atts,\n                return_dict=True,\n                mode=\"fusion\",\n            )\n\n            multimodal_embeds = output.last_hidden_state\n\n        return AlbefOutputFeatures(\n            image_embeds=image_embeds,\n            image_embeds_proj=image_features,\n            text_embeds=text_embeds,\n            text_embeds_proj=text_features,\n            multimodal_embeds=multimodal_embeds,\n        )\n\n    @classmethod\n    def from_config(cls, cfg=None):\n        image_encoder = VisionTransformerEncoder.from_config(cfg, from_pretrained=True)\n        config_text_encoder = BertConfig.from_json_file(\n            get_abs_path(cfg[\"med_config_path\"])\n        )\n        config_text_encoder.fusion_layer = 6\n        text_encoder = BertForMaskedLM.from_pretrained(\n            \"bert-base-uncased\", config=config_text_encoder\n        )\n\n        embed_dim = cfg.get(\"embed_dim\", 256)\n        max_txt_len = cfg.get(\"max_txt_len\", 30)\n\n        model = cls(\n            image_encoder=image_encoder,\n            text_encoder=text_encoder,\n            embed_dim=embed_dim,\n            max_txt_len=max_txt_len,\n        )\n\n        # load pre-trained weights\n        pretrain_path = cfg.get(\"pretrained\", None)\n        if pretrain_path is not None:\n            msg = model.load_from_pretrained(\n                url_or_filename=pretrain_path, rename_text_keys=False\n            )\n        else:\n            warnings.warn(\"No pretrained weights are loaded.\")\n\n        return model\n"
  },
  {
    "path": "lavis/models/albef_models/albef_nlvr.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nfrom copy import deepcopy\n\nimport torch\nimport torch.nn.functional as F\nfrom lavis.common.registry import registry\nfrom lavis.common.utils import get_abs_path\nfrom lavis.models.albef_models import AlbefBase\nfrom lavis.models.albef_models.albef_outputs import AlbefIntermediateOutput, AlbefOutput\nfrom lavis.models.base_model import MomentumDistilationMixin\nfrom lavis.models.med import BertModel\nfrom lavis.models.vit import VisionTransformerEncoder\nfrom torch import nn\nfrom transformers import BertConfig\n\n\n@registry.register_model(\"albef_nlvr\")\nclass AlbefNLVR(AlbefBase, MomentumDistilationMixin):\n    PRETRAINED_MODEL_CONFIG_DICT = {\n        \"nlvr\": \"configs/models/albef_nlvr.yaml\",\n    }\n\n    def __init__(\n        self,\n        image_encoder,\n        text_encoder,\n        num_classes,\n        momentum=0.995,\n        alpha=0.4,\n        use_distill=True,\n        max_txt_len=40,\n    ):\n        super().__init__()\n\n        self.tokenizer = self.init_tokenizer()\n        self.max_txt_len = max_txt_len\n\n        self.use_distill = use_distill\n\n        self.visual_encoder = image_encoder\n        self.text_encoder = text_encoder\n\n        hidden_size = text_encoder.config.hidden_size\n        self.cls_head = nn.Sequential(\n            nn.Linear(hidden_size, hidden_size),\n            nn.ReLU(),\n            nn.Linear(hidden_size, num_classes),\n        )\n\n        self.share_cross_attention(self.text_encoder.encoder)\n\n        if self.use_distill:\n            self.visual_encoder_m = deepcopy(self.visual_encoder)\n            self.text_encoder_m = deepcopy(self.text_encoder)\n            self.cls_head_m = deepcopy(self.cls_head)\n\n            self.share_cross_attention(self.text_encoder_m.encoder)\n\n            self.momentum = momentum\n            self.alpha = alpha\n\n            self.model_pairs = [\n                [self.visual_encoder, self.visual_encoder_m],\n                [self.text_encoder, self.text_encoder_m],\n                [self.cls_head, self.cls_head_m],\n            ]\n\n            self.copy_params()\n\n    def _rampup_factor(self, epoch, iters, num_iters_per_epoch):\n        return min(1, (epoch * num_iters_per_epoch + iters) / (2 * num_iters_per_epoch))\n\n    def forward(self, samples, is_train=True):\n        \"\"\"\n        Forward function for training and evaluation.\n\n        Args:\n            samples (dict): a dict of input samples, which contains the following keys:\n                - image0 (torch.Tensor): input image 0, shape (batch_size, 3, H, W), default H=384, W=384.\n                - image1 (torch.Tensor): input image 1, shape (batch_size, 3, H, W), default H=384, W=384.\n                - text_input (list): list of strings, each string is a natural language sentence.\n                - label (torch.LongTensor): ground truth label with shape (batch_size,).\n            is_train (bool): whether the model is in training mode.\n                If True, the model will return the loss;\n                If False, the model will return the prediction.\n\n        Examples:\n            >>> import torch\n            >>> from lavis.models import load_model\n            >>> model = load_model(\"albef_nlvr\")\n            >>> samples = {\n            ...     
\"image0\": torch.randn(2, 3, 384, 384),\n            ...     \"image1\": torch.randn(2, 3, 384, 384),\n            ...     \"text_input\": [\"there is a ferret in tall grass\", \"there are lips in one of the images\"],\n            ...     \"label\": torch.tensor([0, 1]),\n            ... }\n            >>> output = model(samples)\n            >>> output.keys()\n            odict_keys(['intermediate_output', 'loss'])\n        \"\"\"\n        text = samples[\"text_input\"]\n        text = self.tokenizer(\n            text,\n            padding=\"longest\",\n            truncation=True,\n            max_length=self.max_txt_len,\n            return_tensors=\"pt\",\n        ).to(self.device)\n\n        targets = samples[\"label\"]\n\n        image0 = samples[\"image0\"]\n        image1 = samples[\"image1\"]\n        images = torch.cat([image0, image1], dim=0)\n\n        image_embeds = self.visual_encoder.forward_features(images)\n        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(\n            self.device\n        )\n        image0_embeds, image1_embeds = torch.split(image_embeds, targets.size(0))\n\n        encoder_output = self.text_encoder(\n            text.input_ids,\n            attention_mask=text.attention_mask,\n            encoder_hidden_states=[image0_embeds, image1_embeds],\n            encoder_attention_mask=[\n                image_atts[: image0_embeds.size(0)],\n                image_atts[image0_embeds.size(0) :],\n            ],\n            return_dict=True,\n        )\n\n        prediction = self.cls_head(encoder_output.last_hidden_state[:, 0, :])\n\n        if is_train:\n            if self.use_distill:\n                with torch.no_grad():\n                    self._momentum_update()\n\n                    image_embeds_m = self.visual_encoder_m(images)\n                    image0_embeds_m, image1_embeds_m = torch.split(\n                        image_embeds_m, targets.size(0)\n                    )\n                    encoder_output_m = self.text_encoder(\n                        text.input_ids,\n                        attention_mask=text.attention_mask,\n                        encoder_hidden_states=[image0_embeds_m, image1_embeds_m],\n                        encoder_attention_mask=[\n                            image_atts[: image0_embeds_m.size(0)],\n                            image_atts[image0_embeds_m.size(0) :],\n                        ],\n                        return_dict=True,\n                    )\n\n                    prediction_m = self.cls_head_m(\n                        encoder_output_m.last_hidden_state[:, 0, :]\n                    )\n\n                alpha = self.alpha * self._rampup_factor(\n                    epoch=samples[\"epoch\"],\n                    iters=samples[\"iters\"],\n                    num_iters_per_epoch=samples[\"num_iters_per_epoch\"],\n                )\n\n                loss = (1 - alpha) * F.cross_entropy(\n                    prediction, targets\n                ) - alpha * torch.sum(\n                    F.log_softmax(prediction, dim=1) * F.softmax(prediction_m, dim=1),\n                    dim=1,\n                ).mean()\n            else:\n                loss = F.cross_entropy(prediction, targets)\n\n                encoder_output_m = None\n                image0_embeds_m, image1_embeds_m = None, None\n\n            # return {\"loss\": loss}\n            return AlbefOutput(\n                loss=loss,\n                intermediate_output=AlbefIntermediateOutput(\n                    
image_embeds=torch.stack([image0_embeds, image1_embeds], dim=0),\n                    image_embeds_m=torch.stack(\n                        [image0_embeds_m, image1_embeds_m], dim=0\n                    ),\n                    encoder_output=encoder_output,\n                    encoder_output_m=encoder_output_m,\n                ),\n            )\n        else:\n            return {\"predictions\": prediction, \"targets\": targets}\n\n    def share_cross_attention(self, model):\n        for i in range(6):\n            layer_num = 6 + i * 2\n            modules_0 = model.layer[layer_num].crossattention.self._modules\n            modules_1 = model.layer[layer_num + 1].crossattention.self._modules\n\n            for name in modules_0.keys():\n                if \"key\" in name or \"value\" in name:\n                    module_0 = modules_0[name]\n                    module_1 = modules_1[name]\n                    if hasattr(module_0, \"weight\"):\n                        module_0.weight = module_1.weight\n                        if hasattr(module_0, \"bias\"):\n                            module_0.bias = module_1.bias\n\n    def predict(self, samples):\n        output = self.forward(samples, is_train=False)\n        return output\n\n    def load_from_pretrained(self, url_or_filename, use_distill=True):\n        _, msg = super().load_from_pretrained(url_or_filename)\n\n        if use_distill and any([\"_m\" in k for k in msg.missing_keys]):\n            # this is required when initializing the model from TA pre-trained weights\n            self.copy_params()\n\n        return msg\n\n    @classmethod\n    def from_config(cls, cfg=None):\n        image_encoder = VisionTransformerEncoder.from_config(cfg)\n\n        # text encoder + multimodal encoder\n        bert_config = BertConfig.from_json_file(get_abs_path(cfg[\"med_config_path\"]))\n        bert_config.num_hidden_layers = 18\n\n        text_encoder = BertModel.from_pretrained(\n            \"bert-base-uncased\", config=bert_config, add_pooling_layer=False\n        )\n\n        alpha = cfg.get(\"alpha\", 0.4)\n        momentum = cfg.get(\"momentum\", 0.995)\n        use_distill = cfg.get(\"use_distill\", True)\n        num_classes = cfg.get(\"num_classes\", -1)\n        max_txt_len = cfg.get(\"max_txt_len\", 40)\n\n        assert num_classes > 1, \"Invalid number of classes provided, found {}\".format(\n            num_classes\n        )\n\n        model = cls(\n            image_encoder=image_encoder,\n            text_encoder=text_encoder,\n            use_distill=use_distill,\n            alpha=alpha,\n            num_classes=num_classes,\n            momentum=momentum,\n            max_txt_len=max_txt_len,\n        )\n\n        model.load_checkpoint_from_config(cfg)\n\n        return model\n"
  },
  {
    "path": "lavis/models/albef_models/albef_outputs.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nfrom dataclasses import dataclass\nfrom typing import Optional\n\nimport torch\nfrom transformers.modeling_outputs import (\n    BaseModelOutputWithPoolingAndCrossAttentions,\n    CausalLMOutputWithCrossAttentions,\n    ModelOutput,\n)\n\n\n@dataclass\nclass AlbefSimilarity(ModelOutput):\n    sim_i2t: torch.FloatTensor = None\n    sim_t2i: torch.FloatTensor = None\n\n    sim_i2t_m: Optional[torch.FloatTensor] = None\n    sim_t2i_m: Optional[torch.FloatTensor] = None\n\n    sim_i2t_targets: Optional[torch.FloatTensor] = None\n    sim_t2i_targets: Optional[torch.FloatTensor] = None\n\n\n@dataclass\nclass AlbefIntermediateOutput(ModelOutput):\n    # uni-modal features\n    image_embeds: torch.FloatTensor = None\n    text_embeds: Optional[torch.FloatTensor] = None\n\n    image_embeds_m: Optional[torch.FloatTensor] = None\n    text_embeds_m: Optional[torch.FloatTensor] = None\n\n    # intermediate outputs of multimodal encoder\n    encoder_output: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None\n    encoder_output_m: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None\n    encoder_output_neg: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None\n\n    itm_logits: Optional[torch.FloatTensor] = None\n    itm_labels: Optional[torch.LongTensor] = None\n\n    # intermediate outputs of multimodal decoder\n    decoder_output: Optional[CausalLMOutputWithCrossAttentions] = None\n    decoder_labels: Optional[torch.LongTensor] = None\n\n\n@dataclass\nclass AlbefOutput(ModelOutput):\n    # some finetuned models (e.g. BlipVQA) do not compute similarity, thus optional.\n    sims: Optional[AlbefSimilarity] = None\n\n    intermediate_output: AlbefIntermediateOutput = None\n\n    loss: Optional[torch.FloatTensor] = None\n\n    loss_itc: Optional[torch.FloatTensor] = None\n\n    loss_itm: Optional[torch.FloatTensor] = None\n\n    loss_mlm: Optional[torch.FloatTensor] = None\n\n\n@dataclass\nclass AlbefOutputWithLogits(AlbefOutput):\n    logits: torch.FloatTensor = None\n    logits_m: torch.FloatTensor = None\n\n\n@dataclass\nclass AlbefOutputFeatures(ModelOutput):\n    \"\"\"\n    Data class of features from AlbefFeatureExtractor.\n\n    Args:\n        image_embeds: `torch.FloatTensor` of shape `(batch_size, num_patches+1, embed_dim)`, `optional`\n        image_features: `torch.FloatTensor` of shape `(batch_size, num_patches+1, feature_dim)`, `optional`\n        text_embeds: `torch.FloatTensor` of shape `(batch_size, sequence_length+1, embed_dim)`, `optional`\n        text_features: `torch.FloatTensor` of shape `(batch_size, sequence_length+1, feature_dim)`, `optional`\n\n        The first embedding or feature is for the [CLS] token.\n\n        Features are obtained by projecting the corresponding embedding into a normalized low-dimensional space.\n    \"\"\"\n\n    image_embeds: Optional[torch.FloatTensor] = None\n    image_embeds_proj: Optional[torch.FloatTensor] = None\n\n    text_embeds: Optional[torch.FloatTensor] = None\n    text_embeds_proj: Optional[torch.FloatTensor] = None\n\n    multimodal_embeds: Optional[torch.FloatTensor] = None\n"
  },
  {
    "path": "lavis/models/albef_models/albef_pretrain.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nfrom copy import deepcopy\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom lavis.common.registry import registry\nfrom lavis.common.utils import get_abs_path\nfrom lavis.models.albef_models import AlbefBase\nfrom lavis.models.albef_models.albef_outputs import (\n    AlbefIntermediateOutput,\n    AlbefOutput,\n    AlbefSimilarity,\n)\nfrom lavis.models.base_model import MomentumDistilationMixin, SharedQueueMixin\nfrom lavis.models.med import BertForMaskedLM\nfrom lavis.models.vit import VisionTransformerEncoder\nfrom torch import nn\nfrom transformers import BertConfig\n\n\n@registry.register_model(\"albef_pretrain\")\nclass AlbefPretrain(AlbefBase, MomentumDistilationMixin, SharedQueueMixin):\n    \"\"\"\n    ALBEF pretrain model.\n\n    Supported model types:\n        - base: ALBEF base model used for pretraining.\n    \"\"\"\n\n    PRETRAINED_MODEL_CONFIG_DICT = {\n        \"base\": \"configs/models/albef_pretrain_base.yaml\",\n    }\n\n    def __init__(\n        self,\n        image_encoder,\n        text_encoder,\n        queue_size,\n        embed_dim=256,\n        mlm_mask_prob=0.15,\n        temp=0.07,\n        momentum=0.995,\n        alpha=0.4,\n        max_txt_len=30,\n    ):\n        super().__init__()\n\n        self.tokenizer = self.init_tokenizer()\n\n        self.visual_encoder = image_encoder\n        self.text_encoder = text_encoder\n\n        text_width = text_encoder.config.hidden_size\n        vision_width = image_encoder.vision_width\n\n        self.embed_dim = embed_dim\n\n        self.vision_proj = nn.Linear(vision_width, embed_dim)\n        self.text_proj = nn.Linear(text_width, embed_dim)\n\n        self.itm_head = nn.Linear(text_width, 2)\n\n        # create the momentum encoder\n        self.visual_encoder_m = deepcopy(self.visual_encoder)\n        self.text_encoder_m = deepcopy(self.text_encoder)\n\n        self.vision_proj_m = deepcopy(self.vision_proj)\n        self.text_proj_m = deepcopy(self.text_proj)\n\n        self.model_pairs = [\n            [self.visual_encoder, self.visual_encoder_m],\n            [self.text_encoder, self.text_encoder_m],\n            [self.vision_proj, self.vision_proj_m],\n            [self.text_proj, self.text_proj_m],\n        ]\n        self.copy_params()\n\n        # create the queue\n        self.register_buffer(\"image_queue\", torch.randn(embed_dim, queue_size))\n        self.register_buffer(\"text_queue\", torch.randn(embed_dim, queue_size))\n        self.register_buffer(\"queue_ptr\", torch.zeros(1, dtype=torch.long))\n\n        self.image_queue = nn.functional.normalize(self.image_queue, dim=0)\n        self.text_queue = nn.functional.normalize(self.text_queue, dim=0)\n\n        self.queue_size = queue_size\n        self.momentum = momentum\n        self.temp = nn.Parameter(temp * torch.ones([]))\n\n        self.alpha = alpha\n        self.max_txt_len = max_txt_len\n\n        self.mlm_probability = mlm_mask_prob\n\n    def _rampup_factor(self, epoch, iters, num_iters_per_epoch):\n        return min(1, (epoch * num_iters_per_epoch + iters) / (2 * num_iters_per_epoch))\n\n    def forward(self, samples):\n        \"\"\"\n        Args:\n            samples (dict): A dictionary containing the following keys:\n                - image (torch.Tensor): A 
tensor of shape (batch_size, 3, H, W). The input images. Default: H=224, W=224.\n                - text_input (list): A list of length batch_size, each element is a string of text/caption.\n                - epoch (int): The current epoch.\n                - iters (int): The current iteration.\n                - num_iters_per_epoch (int): The number of iterations per epoch.\n\n        Returns:\n            AlbefOutput: An AlbefOutput object containing loss and intermediate output. See ``lavis.models.albef_models.albef_outputs.AlbefOutput`` for more details.\n\n        Examples:\n            >>> import torch\n            >>> from lavis.models import load_model\n            >>> model = load_model(\"albef_pretrain\")\n            >>> images = torch.randn(4, 3, 224, 224)\n            >>> text_input = [\"caption of image 1\", \"another caption of image 1\", \"caption of image 2\", \"caption of image 3\"]\n            >>> samples = {\"image\": images, \"text_input\": text_input, \"epoch\": 0, \"iters\": 0, \"num_iters_per_epoch\": 100}\n            >>> output = model(samples)\n            >>> output.keys()\n            odict_keys(['sims', 'intermediate_output', 'loss', 'loss_itc', 'loss_itm', 'loss_mlm'])\n        \"\"\"\n        image = samples[\"image\"]\n        caption = samples[\"text_input\"]\n\n        alpha = self.alpha * self._rampup_factor(\n            epoch=samples[\"epoch\"],\n            iters=samples[\"iters\"],\n            num_iters_per_epoch=samples[\"num_iters_per_epoch\"],\n        )\n\n        with torch.no_grad():\n            self.temp.clamp_(0.001, 0.5)\n\n        image_embeds = self.visual_encoder.forward_features(image)\n        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(\n            self.device\n        )\n\n        text = self.tokenizer(\n            caption,\n            padding=\"max_length\",\n            truncation=True,\n            max_length=self.max_txt_len,\n            return_tensors=\"pt\",\n        ).to(self.device)\n\n        image_feat = F.normalize(self.vision_proj(image_embeds[:, 0, :]), dim=-1)\n\n        text_output = self.text_encoder.bert(\n            text.input_ids,\n            attention_mask=text.attention_mask,\n            return_dict=True,\n            mode=\"text\",\n        )\n        text_embeds = text_output.last_hidden_state\n        text_feat = F.normalize(self.text_proj(text_embeds[:, 0, :]), dim=-1)\n\n        # get momentum features\n        with torch.no_grad():\n            self._momentum_update()\n            image_embeds_m = self.visual_encoder_m(image)\n            image_feat_m = F.normalize(\n                self.vision_proj_m(image_embeds_m[:, 0, :]), dim=-1\n            )\n            image_feat_all = torch.cat(\n                [image_feat_m.t(), self.image_queue.clone().detach()], dim=1\n            )\n            text_output_m = self.text_encoder_m.bert(\n                text.input_ids,\n                attention_mask=text.attention_mask,\n                return_dict=True,\n                mode=\"text\",\n            )\n            text_embeds_m = text_output_m.last_hidden_state\n            text_feat_m = F.normalize(self.text_proj_m(text_embeds_m[:, 0, :]), dim=-1)\n            text_feat_all = torch.cat(\n                [text_feat_m.t(), self.text_queue.clone().detach()], dim=1\n            )\n\n            sim_i2t_m = image_feat_m @ text_feat_all / self.temp\n            sim_t2i_m = text_feat_m @ image_feat_all / self.temp\n\n            sim_targets = 
torch.zeros(sim_i2t_m.size()).to(image.device)\n            sim_targets.fill_diagonal_(1)\n\n            sim_i2t_targets = (\n                alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets\n            )\n            sim_t2i_targets = (\n                alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets\n            )\n\n        sim_i2t = image_feat @ text_feat_all / self.temp\n        sim_t2i = text_feat @ image_feat_all / self.temp\n\n        loss_i2t = -torch.sum(\n            F.log_softmax(sim_i2t, dim=1) * sim_i2t_targets, dim=1\n        ).mean()\n        loss_t2i = -torch.sum(\n            F.log_softmax(sim_t2i, dim=1) * sim_t2i_targets, dim=1\n        ).mean()\n\n        loss_itc = (loss_i2t + loss_t2i) / 2\n\n        self._dequeue_and_enqueue(image_feat_m, text_feat_m)\n\n        # forward the positive image-text pair\n        encoder_output_pos = self.text_encoder.bert(\n            encoder_embeds=text_embeds,\n            attention_mask=text.attention_mask,\n            encoder_hidden_states=image_embeds,\n            encoder_attention_mask=image_atts,\n            return_dict=True,\n            mode=\"fusion\",\n        )\n        with torch.no_grad():\n            bs = image.size(0)\n\n            weights_i2t = sim_i2t[:, :bs].clone()\n            weights_t2i = sim_t2i[:, :bs].clone()\n\n            weights_i2t.fill_diagonal_(-np.Inf)\n            weights_t2i.fill_diagonal_(-np.Inf)\n\n            weights_i2t = F.softmax(weights_i2t, dim=1)\n            weights_t2i = F.softmax(weights_t2i, dim=1)\n\n        # select a negative image for each text\n        image_embeds_neg = []\n        for b in range(bs):\n            neg_idx = torch.multinomial(weights_t2i[b], 1).item()\n            image_embeds_neg.append(image_embeds[neg_idx])\n        image_embeds_neg = torch.stack(image_embeds_neg, dim=0)\n\n        # select a negative text for each image\n        text_embeds_neg = []\n        text_atts_neg = []\n        for b in range(bs):\n            neg_idx = torch.multinomial(weights_i2t[b], 1).item()\n            text_embeds_neg.append(text_embeds[neg_idx])\n            text_atts_neg.append(text.attention_mask[neg_idx])\n        text_embeds_neg = torch.stack(text_embeds_neg, dim=0)\n        text_atts_neg = torch.stack(text_atts_neg, dim=0)\n\n        text_embeds_all = torch.cat([text_embeds, text_embeds_neg], dim=0)\n        text_atts_all = torch.cat([text.attention_mask, text_atts_neg], dim=0)\n\n        image_embeds_all = torch.cat([image_embeds_neg, image_embeds], dim=0)\n        image_atts_all = torch.cat([image_atts, image_atts], dim=0)\n\n        encoder_output_neg = self.text_encoder.bert(\n            encoder_embeds=text_embeds_all,\n            attention_mask=text_atts_all,\n            encoder_hidden_states=image_embeds_all,\n            encoder_attention_mask=image_atts_all,\n            return_dict=True,\n            mode=\"fusion\",\n        )\n\n        vl_embeddings = torch.cat(\n            [\n                encoder_output_pos.last_hidden_state[:, 0, :],\n                encoder_output_neg.last_hidden_state[:, 0, :],\n            ],\n            dim=0,\n        )\n        itm_logits = self.itm_head(vl_embeddings)\n\n        itm_labels = torch.cat(\n            [torch.ones(bs, dtype=torch.long), torch.zeros(2 * bs, dtype=torch.long)],\n            dim=0,\n        ).to(self.device)\n        loss_itm = F.cross_entropy(itm_logits, itm_labels)\n\n        # MLM\n        input_ids = text.input_ids.clone()\n        labels = input_ids.clone()\n\n  
      probability_matrix = torch.full(labels.shape, self.mlm_probability)\n        input_ids, labels = self.mask(\n            input_ids,\n            self.text_encoder.config.vocab_size,\n            self.device,\n            targets=labels,\n            probability_matrix=probability_matrix,\n        )\n\n        with torch.no_grad():\n            logits_m = self.text_encoder_m(\n                input_ids,\n                attention_mask=text.attention_mask,\n                encoder_hidden_states=image_embeds_m,\n                encoder_attention_mask=image_atts,\n                return_dict=True,\n                return_logits=True,\n            )\n        mlm_output = self.text_encoder(\n            input_ids,\n            attention_mask=text.attention_mask,\n            encoder_hidden_states=image_embeds,\n            encoder_attention_mask=image_atts,\n            return_dict=True,\n            labels=labels,\n            soft_labels=F.softmax(logits_m, dim=-1),\n            alpha=alpha,\n        )\n        loss_mlm = mlm_output.loss\n\n        return AlbefOutput(\n            loss=loss_itc + loss_itm + loss_mlm,\n            loss_itc=loss_itc,\n            loss_itm=loss_itm,\n            loss_mlm=loss_mlm,\n            sims=AlbefSimilarity(\n                sim_i2t=sim_i2t,\n                sim_t2i=sim_t2i,\n                sim_i2t_m=sim_i2t_m,\n                sim_t2i_m=sim_t2i_m,\n                sim_i2t_targets=sim_i2t_targets,\n                sim_t2i_targets=sim_t2i_targets,\n            ),\n            intermediate_output=AlbefIntermediateOutput(\n                image_embeds=image_embeds,\n                image_embeds_m=image_embeds_m,\n                text_embeds=text_embeds,\n                text_embeds_m=text_embeds_m,\n                encoder_output=encoder_output_pos,\n                encoder_output_neg=encoder_output_neg,\n                itm_logits=itm_logits,\n                itm_labels=itm_labels,\n            ),\n        )\n\n    def mask(\n        self,\n        input_ids,\n        vocab_size,\n        device,\n        targets=None,\n        masked_indices=None,\n        probability_matrix=None,\n    ):\n        \"\"\"\n        Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.\n        \"\"\"\n        if masked_indices is None:\n            masked_indices = torch.bernoulli(probability_matrix).bool()\n\n        masked_indices[input_ids == self.tokenizer.pad_token_id] = False\n        masked_indices[input_ids == self.tokenizer.cls_token_id] = False\n\n        if targets is not None:\n            targets[~masked_indices] = -100  # We only compute loss on masked tokens\n\n        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])\n        indices_replaced = (\n            torch.bernoulli(torch.full(input_ids.shape, 0.8)).bool() & masked_indices\n        )\n        input_ids[indices_replaced] = self.tokenizer.mask_token_id\n\n        # 10% of the time, we replace masked input tokens with random word\n        indices_random = (\n            torch.bernoulli(torch.full(input_ids.shape, 0.5)).bool()\n            & masked_indices\n            & ~indices_replaced\n        )\n        random_words = torch.randint(vocab_size, input_ids.shape, dtype=torch.long).to(\n            device\n        )\n        input_ids[indices_random] = random_words[indices_random]\n        # The rest of the time (10% of the time) we keep the masked input tokens unchanged\n\n        if targets is not None:\n          
  return input_ids, targets\n        else:\n            return input_ids\n\n    @classmethod\n    def from_config(cls, cfg=None):\n        image_encoder = VisionTransformerEncoder.from_config(cfg, from_pretrained=True)\n        config_text_encoder = BertConfig.from_json_file(\n            get_abs_path(cfg[\"med_config_path\"])\n        )\n        config_text_encoder.fusion_layer = 6\n        text_encoder = BertForMaskedLM.from_pretrained(\n            \"bert-base-uncased\", config=config_text_encoder\n        )\n\n        embed_dim = cfg.get(\"embed_dim\", 256)\n        momentum = cfg.get(\"momentum\", 0.995)\n        alpha = cfg.get(\"alpha\", 0.4)\n        mlm_mask_prob = cfg.get(\"mlm_mask_prob\", 0.15)\n        temp = cfg.get(\"temp\", 0.07)\n        max_txt_len = cfg.get(\"max_txt_len\", 30)\n        queue_size = cfg.get(\"queue_size\", 65536)\n\n        model = cls(\n            image_encoder=image_encoder,\n            text_encoder=text_encoder,\n            queue_size=queue_size,\n            embed_dim=embed_dim,\n            mlm_mask_prob=mlm_mask_prob,\n            temp=temp,\n            momentum=momentum,\n            alpha=alpha,\n            max_txt_len=max_txt_len,\n        )\n\n        return model\n"
  },
  {
    "path": "lavis/models/albef_models/albef_retrieval.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nfrom copy import deepcopy\n\nimport torch\nimport torch.nn.functional as F\nfrom lavis.common.registry import registry\nfrom lavis.models.albef_models import AlbefBase, compute_sim_matrix\nfrom lavis.models.albef_models.albef_outputs import (\n    AlbefIntermediateOutput,\n    AlbefOutput,\n    AlbefSimilarity,\n)\nfrom lavis.models.base_model import MomentumDistilationMixin, SharedQueueMixin\nfrom lavis.models.med import XBertEncoder\nfrom lavis.models.vit import VisionTransformerEncoder\nfrom torch import nn\n\n\n@registry.register_model(\"albef_retrieval\")\nclass AlbefRetrieval(AlbefBase, MomentumDistilationMixin, SharedQueueMixin):\n    \"\"\"\n    ALBEF retrieval model.\n\n    Supported model types:\n        - coco: fine-tuned ALBEF base model on COCO dataset (Karparthy split).\n        - flickr: fine-tuned ALBEF base model on Flickr30k dataset.\n\n    Usage:\n        >>> from lavis.models import load_model\n        >>> model = load_model(\"albef_retrieval\", \"coco\")\n        >>> model = load_model(\"albef_retrieval\", \"flickr\")\n    \"\"\"\n\n    PRETRAINED_MODEL_CONFIG_DICT = {\n        \"coco\": \"configs/models/albef_retrieval_coco.yaml\",\n        \"flickr\": \"configs/models/albef_retrieval_flickr.yaml\",\n    }\n\n    def __init__(\n        self,\n        image_encoder,\n        text_encoder,\n        queue_size,\n        embed_dim=256,\n        temp=0.07,\n        use_distill=True,\n        momentum=0.995,\n        alpha=0.4,\n        max_txt_len=30,\n    ):\n        super().__init__()\n\n        self.tokenizer = self.init_tokenizer()\n\n        self.visual_encoder = image_encoder\n        self.text_encoder = text_encoder\n\n        text_width = text_encoder.config.hidden_size\n        vision_width = image_encoder.vision_width\n\n        self.vision_proj = nn.Linear(vision_width, embed_dim)\n        self.text_proj = nn.Linear(text_width, embed_dim)\n\n        self.itm_head = nn.Linear(text_width, 2)\n\n        # create the momentum encoder\n        self.visual_encoder_m = deepcopy(self.visual_encoder)\n        self.text_encoder_m = deepcopy(self.text_encoder)\n\n        self.vision_proj_m = deepcopy(self.vision_proj)\n        self.text_proj_m = deepcopy(self.text_proj)\n\n        self.model_pairs = [\n            [self.visual_encoder, self.visual_encoder_m],\n            [self.text_encoder, self.text_encoder_m],\n            [self.vision_proj, self.vision_proj_m],\n            [self.text_proj, self.text_proj_m],\n        ]\n        self.copy_params()\n\n        # create the queue\n        self.register_buffer(\"image_queue\", torch.randn(embed_dim, queue_size))\n        self.register_buffer(\"text_queue\", torch.randn(embed_dim, queue_size))\n        self.register_buffer(\"idx_queue\", torch.full((1, queue_size), -100))\n        self.register_buffer(\"queue_ptr\", torch.zeros(1, dtype=torch.long))\n\n        self.image_queue = nn.functional.normalize(self.image_queue, dim=0)\n        self.text_queue = nn.functional.normalize(self.text_queue, dim=0)\n\n        self.queue_size = queue_size\n        self.momentum = momentum\n        self.temp = nn.Parameter(temp * torch.ones([]))\n\n        self.alpha = alpha\n        self.max_txt_len = max_txt_len\n        self.use_distill = use_distill\n\n    def _rampup_factor(self, epoch, 
iters, num_iters_per_epoch):\n        return min(1, (epoch * num_iters_per_epoch + iters) / (2 * num_iters_per_epoch))\n\n    def forward(self, samples):\n        \"\"\"\n        Args:\n            samples (dict): A dictionary containing the following keys:\n                - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W). The input images.\n                - text_input (list): A list of length batch_size, each element is a string of text/caption.\n                - image_id (torch.Tensor): A tensor of shape (batch_size, ). The image ids, used to identify same images in batch.\n                - epoch (int): The current epoch.\n                - iters (int): The current iteration.\n                - num_iters_per_epoch (int): The number of iterations per epoch.\n\n        Returns:\n            AlbefOutput: An AlbefOutput object. See ``lavis.models.albef_models.albef_outputs.AlbefOutput`` for more details.\n\n        Examples:\n            >>> import torch\n            >>> from lavis.models import load_model\n            >>> model = load_model(\"albef_retrieval\", \"coco\")\n            >>> images = torch.randn(4, 3, 384, 384)\n            >>> text_input = [\"caption of image 1\", \"another caption of image 1\", \"caption of image 2\", \"caption of image 3\"]\n            >>> image_id = torch.tensor([1, 1, 2, 3])\n            >>> samples = {\"image\": images, \"text_input\": text_input, \"image_id\": image_id, \"epoch\": 0, \"iters\": 0, \"num_iters_per_epoch\": 100}\n            >>> output = model(samples)\n            >>> output.keys()\n            odict_keys(['sims', 'intermediate_output', 'loss', 'loss_itc', 'loss_itm'])\n        \"\"\"\n        image = samples[\"image\"]\n        caption = samples[\"text_input\"]\n        idx = samples[\"image_id\"]\n\n        alpha = self.alpha * self._rampup_factor(\n            epoch=samples[\"epoch\"],\n            iters=samples[\"iters\"],\n            num_iters_per_epoch=samples[\"num_iters_per_epoch\"],\n        )\n\n        with torch.no_grad():\n            self.temp.clamp_(0.001, 0.5)\n\n        image_embeds = self.visual_encoder.forward_features(image)\n        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(\n            self.device\n        )\n\n        image_feat = F.normalize(self.vision_proj(image_embeds[:, 0, :]), dim=-1)\n\n        text = self.tokenizer(\n            caption,\n            padding=\"max_length\",\n            truncation=True,\n            max_length=self.max_txt_len,\n            return_tensors=\"pt\",\n        ).to(self.device)\n\n        text_output = self.text_encoder.forward_text(text)\n\n        text_embeds = text_output.last_hidden_state\n        text_feat = F.normalize(self.text_proj(text_embeds[:, 0, :]), dim=-1)\n\n        idx = idx.view(-1, 1)\n        idx_all = torch.cat([idx.t(), self.idx_queue.clone().detach()], dim=1)\n        pos_idx = torch.eq(idx, idx_all).float()\n        sim_targets = pos_idx / pos_idx.sum(1, keepdim=True)\n\n        with torch.no_grad():\n            self._momentum_update()\n            image_embeds_m = self.visual_encoder_m(image)\n            image_feat_m = F.normalize(\n                self.vision_proj_m(image_embeds_m[:, 0, :]), dim=-1\n            )\n            image_feat_all = torch.cat(\n                [image_feat_m.t(), self.image_queue.clone().detach()], dim=1\n            )\n            text_output_m = self.text_encoder_m.forward_text(text)\n            text_embeds_m = text_output_m.last_hidden_state\n            text_feat_m = 
F.normalize(self.text_proj_m(text_embeds_m[:, 0, :]), dim=-1)\n            text_feat_all = torch.cat(\n                [text_feat_m.t(), self.text_queue.clone().detach()], dim=1\n            )\n\n            if self.use_distill:\n                sim_i2t_m = image_feat_m @ text_feat_all / self.temp\n                sim_t2i_m = text_feat_m @ image_feat_all / self.temp\n\n                sim_i2t_targets = (\n                    alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets\n                )\n                sim_t2i_targets = (\n                    alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets\n                )\n\n        sim_i2t = image_feat @ text_feat_all / self.temp\n        sim_t2i = text_feat @ image_feat_all / self.temp\n\n        if self.use_distill:\n            loss_i2t = -torch.sum(\n                F.log_softmax(sim_i2t, dim=1) * sim_i2t_targets, dim=1\n            ).mean()\n            loss_t2i = -torch.sum(\n                F.log_softmax(sim_t2i, dim=1) * sim_t2i_targets, dim=1\n            ).mean()\n        else:\n            loss_i2t = -torch.sum(\n                F.log_softmax(sim_i2t, dim=1) * sim_targets, dim=1\n            ).mean()\n            loss_t2i = -torch.sum(\n                F.log_softmax(sim_t2i, dim=1) * sim_targets, dim=1\n            ).mean()\n\n        loss_itc = (loss_i2t + loss_t2i) / 2\n\n        self._dequeue_and_enqueue(image_feat_m, text_feat_m, idx)\n\n        encoder_output_pos = self.text_encoder(\n            encoder_embeds=text_embeds,\n            attention_mask=text.attention_mask,\n            encoder_hidden_states=image_embeds,\n            encoder_attention_mask=image_atts,\n            return_dict=True,\n            mode=\"fusion\",\n        )\n\n        with torch.no_grad():\n            bs = image.size(0)\n            weights_i2t = F.softmax(sim_i2t[:, :bs] + 1e-4, dim=1)\n            weights_t2i = F.softmax(sim_t2i[:, :bs] + 1e-4, dim=1)\n\n            mask = torch.eq(idx, idx.T)\n            weights_i2t.masked_fill_(mask, 0)\n            weights_t2i.masked_fill_(mask, 0)\n\n        # select a negative image for each text\n        image_embeds_neg = []\n        for b in range(bs):\n            neg_idx = torch.multinomial(weights_t2i[b], 1).item()\n            image_embeds_neg.append(image_embeds[neg_idx])\n        image_embeds_neg = torch.stack(image_embeds_neg, dim=0)\n\n        # select a negative text for each image\n        text_embeds_neg = []\n        text_atts_neg = []\n        for b in range(bs):\n            neg_idx = torch.multinomial(weights_i2t[b], 1).item()\n            text_embeds_neg.append(text_embeds[neg_idx])\n            text_atts_neg.append(text.attention_mask[neg_idx])\n        text_embeds_neg = torch.stack(text_embeds_neg, dim=0)\n        text_atts_neg = torch.stack(text_atts_neg, dim=0)\n\n        text_embeds_all = torch.cat([text_embeds, text_embeds_neg], dim=0)\n        text_atts_all = torch.cat([text.attention_mask, text_atts_neg], dim=0)\n\n        image_embeds_all = torch.cat([image_embeds_neg, image_embeds], dim=0)\n        image_atts_all = torch.cat([image_atts, image_atts], dim=0)\n\n        encoder_output_neg = self.text_encoder(\n            encoder_embeds=text_embeds_all,\n            attention_mask=text_atts_all,\n            encoder_hidden_states=image_embeds_all,\n            encoder_attention_mask=image_atts_all,\n            return_dict=True,\n            mode=\"fusion\",\n        )\n\n        vl_embeddings = torch.cat(\n            [\n                
encoder_output_pos.last_hidden_state[:, 0, :],\n                encoder_output_neg.last_hidden_state[:, 0, :],\n            ],\n            dim=0,\n        )\n        itm_logits = self.itm_head(vl_embeddings)\n\n        itm_labels = torch.cat(\n            [torch.ones(bs, dtype=torch.long), torch.zeros(2 * bs, dtype=torch.long)],\n            dim=0,\n        ).to(self.device)\n        loss_itm = F.cross_entropy(itm_logits, itm_labels)\n\n        return AlbefOutput(\n            loss=loss_itc + loss_itm,\n            loss_itc=loss_itc,\n            loss_itm=loss_itm,\n            sims=AlbefSimilarity(\n                sim_i2t=sim_i2t,\n                sim_t2i=sim_t2i,\n                sim_i2t_m=sim_i2t_m,\n                sim_t2i_m=sim_t2i_m,\n                sim_i2t_targets=sim_i2t_targets,\n                sim_t2i_targets=sim_t2i_targets,\n            ),\n            intermediate_output=AlbefIntermediateOutput(\n                image_embeds=image_embeds,\n                image_embeds_m=image_embeds_m,\n                text_embeds=text_embeds,\n                text_embeds_m=text_embeds_m,\n                encoder_output=encoder_output_pos,\n                encoder_output_neg=encoder_output_neg,\n                itm_logits=itm_logits,\n                itm_labels=itm_labels,\n            ),\n        )\n\n    @classmethod\n    def from_config(cls, cfg=None):\n        image_encoder = VisionTransformerEncoder.from_config(cfg, from_pretrained=False)\n        text_encoder = XBertEncoder.from_config(cfg)\n\n        embed_dim = cfg.get(\"embed_dim\", 256)\n        momentum = cfg.get(\"momentum\", 0.995)\n        alpha = cfg.get(\"alpha\", 0.4)\n        temp = cfg.get(\"temp\", 0.07)\n        max_txt_len = cfg.get(\"max_txt_len\", 30)\n        queue_size = cfg.get(\"queue_size\", 0)\n        use_distill = cfg.get(\"use_distill\", True)\n\n        model = cls(\n            image_encoder=image_encoder,\n            text_encoder=text_encoder,\n            queue_size=queue_size,\n            embed_dim=embed_dim,\n            temp=temp,\n            momentum=momentum,\n            alpha=alpha,\n            max_txt_len=max_txt_len,\n            use_distill=use_distill,\n        )\n\n        model.load_checkpoint_from_config(cfg)\n\n        return model\n\n    def compute_sim_matrix(self, data_loader, task_cfg):\n        \"\"\"\n        Compute similarity i2t, t2i matrix for the given data loader.\n        \"\"\"\n        k_test = task_cfg.k_test\n\n        return compute_sim_matrix(model=self, data_loader=data_loader, k_test=k_test)\n"
  },
  {
    "path": "lavis/models/albef_models/albef_vqa.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport logging\nimport os\nfrom copy import deepcopy\n\nimport torch\nimport torch.nn.functional as F\nfrom lavis.common.registry import registry\nfrom lavis.common.utils import get_abs_path, is_url\nfrom lavis.models.albef_models import AlbefBase\nfrom lavis.models.albef_models.albef_outputs import AlbefIntermediateOutput, AlbefOutput\nfrom lavis.models.base_model import MomentumDistilationMixin, tile\nfrom lavis.models.med import BertConfig, BertLMHeadModel, XBertEncoder\nfrom lavis.models.vit import VisionTransformerEncoder, interpolate_pos_embed\nfrom lavis.common.dist_utils import download_cached_file\n\n\n@registry.register_model(\"albef_vqa\")\nclass AlbefVQA(AlbefBase, MomentumDistilationMixin):\n    \"\"\"\n    ALBEF VQA models.\n\n    Supported model types:\n        - base: vqa model initialized with pre-trained ALBEF base model on 115M image-text pairs after CapFilt; not fine-tuned.\n        - vqav2: fine-tuned ALBEF base model on VQA v2.0 dataset.\n\n    Usage:\n        >>> from lavis.models import load_model\n        >>> model = load_model(\"albef_vqa\", \"vqav2\")\n    \"\"\"\n\n    PRETRAINED_MODEL_CONFIG_DICT = {\n        \"vqav2\": \"configs/models/albef_vqav2.yaml\",\n    }\n\n    def __init__(\n        self,\n        image_encoder,\n        text_encoder,\n        text_decoder,\n        use_distill=True,\n        momentum=0.995,\n        alpha=0.4,\n        max_txt_len=35,\n    ):\n        super().__init__()\n\n        self.tokenizer = self.init_tokenizer()\n        self.max_txt_len = max_txt_len\n\n        self.use_distill = use_distill\n\n        self.visual_encoder = image_encoder\n\n        self.text_encoder = text_encoder\n        self.text_decoder = text_decoder\n\n        if self.use_distill:\n            self.visual_encoder_m = deepcopy(self.visual_encoder)\n            self.text_encoder_m = deepcopy(self.text_encoder)\n            self.text_decoder_m = deepcopy(self.text_decoder)\n\n            self.momentum = momentum\n            self.alpha = alpha\n\n            self.model_pairs = [\n                [self.visual_encoder, self.visual_encoder_m],\n                [self.text_encoder, self.text_encoder_m],\n                [self.text_decoder, self.text_decoder_m],\n            ]\n\n            self.copy_params()\n\n    def _rampup_factor(self, epoch, iters, num_iters_per_epoch):\n        return min(1, (epoch * num_iters_per_epoch + iters) / num_iters_per_epoch)\n\n    def forward(self, samples):\n        \"\"\"\n        Args:\n            samples (dict): A dictionary containing the following keys:\n                - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W). 
Default H=480, W=480.\n                - text_input (list): A list of strings, each string is a question\n                - answer (list): A list of strings, each string is an answer\n                - weight (torch.Tensor): A tensor used to weigh each answer in the loss computation.\n                   The shape of the tensor is (sum(n_answers),)\n                - n_answers (torch.Tensor): A tensor shape (batch_size,) containing the number of answers\n                     for each question in the batch.\n\n        Returns:\n            An AlbefOutput object containing loss and intermediate outputs;\n            see lavis/models/albef_models/albef_outputs.py for more details.\n\n        Examples:\n            >>> import torch\n            >>> from lavis.models import load_model\n            >>> model = load_model(\"albef_vqa\")\n            >>> samples = {\n            ...     \"image\": torch.rand(2, 3, 384, 384),\n            ...     \"text_input\": [\"What is this?\", \"What is that?\"],\n            ...     \"answer\": [\"cat\", \"cat\", \"dog\"],\n            ...     \"weight\": torch.tensor([1.0, 1.0, 1.0]),\n            ...     \"n_answers\": torch.tensor([2, 1]),\n            ...     \"epoch\": 0, \"iters\": 0, \"num_iters_per_epoch\": 1000,\n            ... }\n            >>> output = model(samples)\n            >>> output.keys()\n            odict_keys(['intermediate_output', 'loss'])\n        \"\"\"\n        (\n            encoder_output,\n            encoder_output_m,\n            image_embeds,\n            image_embeds_m,\n        ) = self.forward_encoder(samples)\n        loss, decoder_output, decoder_targets = self.forward_decoder(\n            samples, encoder_out=(encoder_output, encoder_output_m)\n        )\n\n        return AlbefOutput(\n            loss=loss,\n            intermediate_output=AlbefIntermediateOutput(\n                image_embeds=image_embeds,\n                image_embeds_m=image_embeds_m,\n                encoder_output=encoder_output,\n                encoder_output_m=encoder_output_m,\n                decoder_output=decoder_output,\n                decoder_labels=decoder_targets,\n            ),\n        )\n\n    def forward_encoder(self, samples):\n        questions = samples[\"text_input\"]\n        questions = self.tokenizer(\n            questions,\n            padding=\"longest\",\n            truncation=True,\n            max_length=self.max_txt_len,\n            return_tensors=\"pt\",\n        ).to(self.device)\n        samples.update({\"tokenized_text\": questions})\n\n        image_embeds = self.visual_encoder.forward_features(samples[\"image\"])\n        encoder_output = self.text_encoder.forward_automask(\n            tokenized_text=samples[\"tokenized_text\"], visual_embeds=image_embeds\n        )\n\n        if self.use_distill:\n            self._momentum_update()\n            with torch.no_grad():\n                image_embeds_m = self.visual_encoder_m(samples[\"image\"])\n                encoder_output_m = self.text_encoder_m.forward_automask(\n                    tokenized_text=samples[\"tokenized_text\"],\n                    visual_embeds=image_embeds_m,\n                )\n        else:\n            encoder_output_m = None\n            image_embeds_m = None\n\n        return encoder_output, encoder_output_m, image_embeds, image_embeds_m\n\n    def forward_decoder(self, samples, encoder_out, **kwargs):\n        answers = self.tokenizer(\n            samples[\"answer\"], padding=\"longest\", return_tensors=\"pt\"\n        
).to(self.device)\n        answer_targets = answers.input_ids.masked_fill(\n            answers.input_ids == self.tokenizer.pad_token_id, -100\n        )\n\n        question_states = []\n        question_atts = []\n\n        question = samples[\"tokenized_text\"]\n        question_output, question_output_m = encoder_out\n\n        for b, n in enumerate(samples[\"n_answers\"]):\n            question_states += [question_output.last_hidden_state[b]] * n\n            question_atts += [question.attention_mask[b]] * n\n\n        question_states = torch.stack(question_states, dim=0)\n        question_atts = torch.stack(question_atts, dim=0)\n\n        if self.use_distill:\n            with torch.no_grad():\n                question_states_m = []\n                for b, n in enumerate(samples[\"n_answers\"]):\n                    question_states_m += [question_output_m.last_hidden_state[b]] * n\n                question_states_m = torch.stack(question_states_m, 0)\n\n                logits_m = self.text_decoder_m(\n                    answers.input_ids,\n                    attention_mask=answers.attention_mask,\n                    encoder_hidden_states=question_states_m,\n                    encoder_attention_mask=question_atts,\n                    return_logits=True,\n                )\n\n                alpha = self.alpha * self._rampup_factor(\n                    epoch=samples[\"epoch\"],\n                    iters=samples[\"iters\"],\n                    num_iters_per_epoch=samples[\"num_iters_per_epoch\"],\n                )\n\n        answer_output = self.text_decoder(\n            answers.input_ids,\n            attention_mask=answers.attention_mask,\n            encoder_hidden_states=question_states,\n            encoder_attention_mask=question_atts,\n            labels=answer_targets,\n            soft_labels=F.softmax(logits_m, dim=-1),\n            alpha=alpha,\n            return_dict=True,\n            reduction=\"none\",\n        )\n\n        loss = samples[\"weight\"] * answer_output.loss\n        bsz = samples[\"image\"].size(0)\n\n        loss = loss.sum() / bsz\n\n        return loss, answer_output, answer_targets\n\n    def predict_answers(self, samples, answer_list, num_ans_candidates=128, **kwargs):\n        \"\"\"\n        Args:\n            samples (dict): A dictionary containing the following keys:\n                - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W). Default H=480, W=480.\n                - text_input (str or [str]): String or a list of strings, each string is a question.\n                                             The number of questions must be equal to the batch size. 
If a single string is given, it will be converted to a list of length 1 first.\n            num_ans_candidates (int): Number of answer candidates, used to filter out answers with low probability.\n            answer_list (list): A list of strings, each string is an answer.\n\n        Returns:\n            List: A list of strings, each string is an answer.\n\n        Examples:\n            >>> from PIL import Image\n            >>> from lavis.models import load_model_and_preprocess\n            >>> model, vis_processors, txt_processors = load_model_and_preprocess(\"albef_vqa\", \"vqav2\")\n            >>> raw_image = Image.open(\"docs/data/merlion.png\").convert(\"RGB\")\n            >>> question = \"Which city is this photo taken?\"\n            >>> image = vis_processors[\"eval\"](raw_image).unsqueeze(0)\n            >>> question = txt_processors[\"eval\"](question)\n            >>> samples = {\"image\": image, \"text_input\": [question]}\n            >>> answer_list = [\"Singapore\", \"London\", \"Palo Alto\", \"Tokyo\"]\n            >>> answers = model.predict_answers(samples, answer_list=answer_list)\n            >>> answers\n            ['Singapore']\n        \"\"\"\n\n        if isinstance(samples[\"text_input\"], str):\n            samples[\"text_input\"] = [samples[\"text_input\"]]\n\n        assert len(samples[\"text_input\"]) == samples[\"image\"].size(\n            0\n        ), \"The number of questions must be equal to the batch size.\"\n\n        num_ans_candidates = min(num_ans_candidates, len(answer_list))\n\n        return self.rank_answers(\n            samples, answer_list=answer_list, num_ans_candidates=num_ans_candidates\n        )\n\n    def rank_answers(self, samples, answer_list, num_ans_candidates):\n        \"\"\"\n        Generate the first token of answers using the decoder and select ${num_ans_candidates}\n        most probable ones. 
Then select answers from answer list, which start with the probable tokens.\n        Lastly, use the selected answers as the ground-truth labels for decoding and calculating LM loss.\n        Return the answers that minimize the losses as result.\n\n        \"\"\"\n        answer_candidates = self.tokenizer(\n            answer_list, padding=\"longest\", return_tensors=\"pt\"\n        ).to(self.device)\n        # answer_candidates.input_ids[:, 0] = self.tokenizer.bos_token_id\n\n        answer_ids = answer_candidates.input_ids\n        answer_atts = answer_candidates.attention_mask\n\n        question_output, _, _, _ = self.forward_encoder(samples)\n        question_states = question_output.last_hidden_state\n\n        tokenized_question = samples[\"tokenized_text\"]\n        question_atts = tokenized_question.attention_mask\n\n        num_ques = question_states.size(0)\n        start_ids = answer_ids[0, 0].repeat(num_ques, 1)  # bos token\n\n        start_output = self.text_decoder(\n            start_ids,\n            encoder_hidden_states=question_states,\n            encoder_attention_mask=question_atts,\n            return_dict=True,\n            reduction=\"none\",\n        )\n        logits = start_output.logits[:, 0, :]  # first token's logit\n\n        # topk_probs: top-k probability\n        # topk_ids: [num_question, k]\n        answer_first_token = answer_ids[:, 1]\n        prob_first_token = F.softmax(logits, dim=1).index_select(\n            dim=1, index=answer_first_token\n        )\n        topk_probs, topk_ids = prob_first_token.topk(num_ans_candidates, dim=1)\n\n        # answer input: [num_question*k, answer_len]\n        input_ids = []\n        input_atts = []\n        for b, topk_id in enumerate(topk_ids):\n            input_ids.append(answer_ids.index_select(dim=0, index=topk_id))\n            input_atts.append(answer_atts.index_select(dim=0, index=topk_id))\n        input_ids = torch.cat(input_ids, dim=0)\n        input_atts = torch.cat(input_atts, dim=0)\n\n        targets_ids = input_ids.masked_fill(\n            input_ids == self.tokenizer.pad_token_id, -100\n        )\n\n        # repeat encoder's output for top-k answers\n        question_states = tile(question_states, 0, num_ans_candidates)\n        question_atts = tile(question_atts, 0, num_ans_candidates)\n\n        output = self.text_decoder(\n            input_ids,\n            attention_mask=input_atts,\n            encoder_hidden_states=question_states,\n            encoder_attention_mask=question_atts,\n            labels=targets_ids,\n            return_dict=True,\n            reduction=\"none\",\n        )\n\n        log_probs_sum = -output.loss\n        log_probs_sum = log_probs_sum.view(num_ques, num_ans_candidates)\n\n        max_topk_ids = log_probs_sum.argmax(dim=1)\n        max_ids = topk_ids[max_topk_ids >= 0, max_topk_ids]\n\n        answers = [answer_list[max_id] for max_id in max_ids]\n\n        return answers\n\n    @classmethod\n    def from_config(cls, cfg=None):\n        image_encoder = VisionTransformerEncoder.from_config(cfg)\n\n        text_encoder = XBertEncoder.from_config(cfg)\n\n        config_decoder = BertConfig.from_json_file(get_abs_path(cfg[\"med_config_path\"]))\n        config_decoder.fusion_layer = 0\n        config_decoder.num_hidden_layers = 6\n        text_decoder = BertLMHeadModel.from_pretrained(\n            \"bert-base-uncased\", config=config_decoder\n        )\n\n        alpha = cfg.get(\"alpha\", 0.4)\n        momentum = cfg.get(\"momentum\", 0.995)\n        
use_distill = cfg.get(\"use_distill\", True)\n        max_txt_len = cfg.get(\"max_txt_len\", 25)\n\n        model = cls(\n            image_encoder=image_encoder,\n            text_encoder=text_encoder,\n            text_decoder=text_decoder,\n            use_distill=use_distill,\n            momentum=momentum,\n            alpha=alpha,\n            max_txt_len=max_txt_len,\n        )\n\n        # load pre-trained weights\n        model.load_checkpoint_from_config(cfg)\n\n        return model\n\n    def load_from_pretrained(self, url_or_filename):\n        if is_url(url_or_filename):\n            cached_file = download_cached_file(\n                url_or_filename, check_hash=False, progress=True\n            )\n            checkpoint = torch.load(cached_file, map_location=\"cpu\")\n        elif os.path.isfile(url_or_filename):\n            checkpoint = torch.load(url_or_filename, map_location=\"cpu\")\n        else:\n            raise RuntimeError(\"checkpoint url or path is invalid\")\n\n        if \"model\" in checkpoint:\n            state_dict = checkpoint[\"model\"]\n        else:\n            state_dict = checkpoint\n\n        # reshape positional embedding to accommodate image resolution change\n        pos_embed_reshaped = interpolate_pos_embed(\n            state_dict[\"visual_encoder.pos_embed\"], self.visual_encoder\n        )\n        state_dict[\"visual_encoder.pos_embed\"] = pos_embed_reshaped\n\n        m_pos_embed_reshaped = interpolate_pos_embed(\n            state_dict[\"visual_encoder_m.pos_embed\"], self.visual_encoder_m\n        )\n        state_dict[\"visual_encoder_m.pos_embed\"] = m_pos_embed_reshaped\n\n        for key in list(state_dict.keys()):\n            if \"bert\" in key:\n                encoder_key = key.replace(\"bert.\", \"\")\n                state_dict[encoder_key] = state_dict[key]\n\n            # initialize text decoder as multimodal encoder (last 6 layers of model.text_encoder)\n            if \"text_encoder\" in key:\n                if \"layer\" in key:\n                    encoder_keys = key.split(\".\")\n                    layer_num = int(encoder_keys[4])\n\n                    if layer_num < 6:\n                        del state_dict[key]\n                        continue\n                    else:\n                        decoder_layer_num = layer_num - 6\n                        encoder_keys[4] = str(decoder_layer_num)\n                        encoder_key = \".\".join(encoder_keys)\n                else:\n                    encoder_key = key\n                decoder_key = encoder_key.replace(\"text_encoder\", \"text_decoder\")\n                state_dict[decoder_key] = state_dict[key]\n\n                del state_dict[key]\n\n        for key in self.state_dict().keys():\n            if key in state_dict.keys():\n                if state_dict[key].shape != self.state_dict()[key].shape:\n                    del state_dict[key]\n\n        msg = self.load_state_dict(state_dict, strict=False)\n        logging.info(\"load checkpoint from %s\" % url_or_filename)\n        logging.info(f\"missing keys: {msg.missing_keys}\")\n\n        return msg\n"
  },
  {
    "path": "lavis/models/alpro_models/__init__.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport logging\nimport os\n\nimport torch\nimport torch.nn.functional as F\nfrom lavis.common.dist_utils import download_cached_file\nfrom lavis.common.utils import is_url\nfrom lavis.models.base_model import BaseModel\nfrom transformers import BertTokenizer\n\n\nclass AlproBase(BaseModel):\n    @classmethod\n    def init_tokenizer(cls):\n        return BertTokenizer.from_pretrained(\"bert-base-uncased\")\n\n    def load_from_pretrained(self, url_or_filename, num_frames, num_patches):\n        if is_url(url_or_filename):\n            cached_file = download_cached_file(\n                url_or_filename, check_hash=False, progress=True\n            )\n            checkpoint = torch.load(cached_file, map_location=\"cpu\")\n        elif os.path.isfile(url_or_filename):\n            checkpoint = torch.load(url_or_filename, map_location=\"cpu\")\n        else:\n            raise RuntimeError(\"checkpoint url or path is invalid\")\n\n        if \"model\" in checkpoint:\n            state_dict = checkpoint[\"model\"]\n        else:\n            state_dict = checkpoint\n\n        for key in list(state_dict.keys()):\n            if \"bert\" in key:\n                new_key = key.replace(\"bert.\", \"\")\n                state_dict[new_key] = state_dict[key]\n                del state_dict[key]\n\n        spatial_embed_key = \"visual_encoder.model.pos_embed\"\n        temporal_embed_key = \"visual_encoder.model.time_embed\"\n\n        ## Resizing spatial embeddings in case they don't match\n        if num_patches + 1 != state_dict[spatial_embed_key].size(1):\n            state_dict[spatial_embed_key] = resize_spatial_embedding(\n                state_dict, spatial_embed_key, num_patches\n            )\n        else:\n            logging.info(\n                \"The length of spatial position embedding matches. No need to resize.\"\n            )\n\n        ## Resizing time embeddings in case they don't match\n        if temporal_embed_key in state_dict and num_frames != state_dict[\n            temporal_embed_key\n        ].size(1):\n            state_dict[temporal_embed_key] = resize_temporal_embedding(\n                state_dict, temporal_embed_key, num_frames\n            )\n        else:\n            logging.info(\n                \"No temporal encoding found. Or the length of temporal position embedding matches. 
No need to resize.\"\n            )\n\n        msg = self.load_state_dict(state_dict, strict=False)\n        logging.info(\"Missing keys {}\".format(msg.missing_keys))\n        logging.info(\"load checkpoint from %s\" % url_or_filename)\n\n        return msg\n\n\ndef resize_spatial_embedding(state_dict, key, num_patches):\n    logging.info(\n        f\"Resizing spatial position embedding from {state_dict[key].size(1)} to {num_patches + 1}\"\n    )\n\n    pos_embed = state_dict[key]\n\n    cls_pos_embed = pos_embed[0, 0, :].unsqueeze(0).unsqueeze(1)\n    other_pos_embed = pos_embed[0, 1:, :].unsqueeze(0).transpose(1, 2)\n\n    new_pos_embed = F.interpolate(other_pos_embed, size=(num_patches), mode=\"nearest\")\n    new_pos_embed = new_pos_embed.transpose(1, 2)\n    new_pos_embed = torch.cat((cls_pos_embed, new_pos_embed), 1)\n\n    return new_pos_embed\n\n\ndef resize_temporal_embedding(state_dict, key, num_frames):\n    logging.info(\n        f\"Resizing temporal position embedding from {state_dict[key].size(1)} to {num_frames}\"\n    )\n\n    time_embed = state_dict[key].transpose(1, 2)\n    new_time_embed = F.interpolate(time_embed, size=(num_frames), mode=\"nearest\")\n\n    return new_time_embed.transpose(1, 2)\n"
  },
  {
    "path": "lavis/models/alpro_models/alpro_outputs.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nfrom dataclasses import dataclass\nfrom typing import Optional\n\nimport torch\nfrom transformers.modeling_outputs import (\n    BaseModelOutputWithPoolingAndCrossAttentions,\n    ModelOutput,\n)\n\n\n@dataclass\nclass AlproSimilarity(ModelOutput):\n    sim_v2t: torch.FloatTensor = None\n    sim_t2v: torch.FloatTensor = None\n\n    sim_v2t_targets: Optional[torch.FloatTensor] = None\n    sim_t2v_targets: Optional[torch.FloatTensor] = None\n\n\n@dataclass\nclass AlproIntermediateOutput(ModelOutput):\n    # uni-modal features\n    video_embeds: torch.FloatTensor = None\n    text_embeds: Optional[torch.FloatTensor] = None\n\n    # intermediate outputs of multimodal encoder\n    encoder_output: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None\n    encoder_output_neg: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None\n\n    vtm_logits: Optional[torch.FloatTensor] = None\n    vtm_labels: Optional[torch.LongTensor] = None\n\n\n@dataclass\nclass AlproOutput(ModelOutput):\n    # some finetuned models (e.g. BlipVQA) do not compute similarity, thus optional.\n    sims: Optional[AlproSimilarity] = None\n\n    intermediate_output: AlproIntermediateOutput = None\n\n    loss: Optional[torch.FloatTensor] = None\n\n    loss_vtc: Optional[torch.FloatTensor] = None\n\n    loss_vtm: Optional[torch.FloatTensor] = None\n\n    loss_mlm: Optional[torch.FloatTensor] = None\n\n\n@dataclass\nclass AlproOutputWithLogits(AlproOutput):\n    logits: torch.FloatTensor = None\n"
  },
  {
    "path": "lavis/models/alpro_models/alpro_qa.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nfrom warnings import warn\n\nimport torch\nimport torch.nn.functional as F\nfrom lavis.common.config import node_to_dict\nfrom lavis.common.registry import registry\nfrom lavis.models.alpro_models import AlproBase\nfrom lavis.models.alpro_models.alpro_outputs import (\n    AlproIntermediateOutput,\n    AlproOutputWithLogits,\n)\nfrom lavis.models.med import XBertEncoder\nfrom lavis.models.timesformer.vit import TimeSformer\nfrom torch import nn\n\n\n@registry.register_model(\"alpro_qa\")\nclass AlproQA(AlproBase):\n    PRETRAINED_MODEL_CONFIG_DICT = {\n        \"msrvtt\": \"configs/models/alpro_qa_msrvtt.yaml\",\n        \"msvd\": \"configs/models/alpro_qa_msvd.yaml\",\n    }\n\n    def __init__(\n        self, visual_encoder, text_encoder, hidden_size, num_classes, max_txt_len=40\n    ):\n        super().__init__()\n\n        self.tokenizer = self.init_tokenizer()\n\n        self.visual_encoder = visual_encoder\n\n        self.text_encoder = text_encoder\n\n        if num_classes > 0:\n            self.classifier = nn.Sequential(\n                nn.Linear(hidden_size, hidden_size * 2),\n                nn.ReLU(True),\n                nn.Linear(hidden_size * 2, num_classes),\n            )\n        else:\n            warn(f\"num_classes is 0. Initialized {type(self)} without classifier.\")\n\n        self.max_txt_len = max_txt_len\n\n    def forward(self, samples, is_train=True):\n        visual_inputs = samples[\"video\"]\n        question = samples[\"text_input\"]\n        targets = samples[\"answers\"]\n\n        # forward text\n        text = self.tokenizer(\n            question,\n            padding=\"max_length\",\n            truncation=True,\n            max_length=self.max_txt_len,\n            return_tensors=\"pt\",\n        ).to(self.device)\n\n        text_output = self.text_encoder.forward_text(\n            text,\n            token_type_ids=torch.zeros(\n                text.input_ids.shape, dtype=torch.long, device=self.device\n            ),\n        )\n        text_embeds = text_output.last_hidden_state\n\n        # forward visual\n        # timeSformer asks for (b, c, t, h, w) as input.\n        video_embeds = self.visual_encoder.forward_features(visual_inputs)\n        video_atts = torch.ones(video_embeds.size()[:-1], dtype=torch.long).to(\n            self.device\n        )\n\n        # forward cross-encoder\n        attention_mask = torch.cat([text.attention_mask, video_atts], dim=1)\n        embedding_output = torch.cat([text_embeds, video_embeds], dim=1)\n\n        encoder_output = self.text_encoder(\n            encoder_embeds=embedding_output,\n            attention_mask=attention_mask,\n            return_dict=True,\n            mode=\"fusion\",\n        )\n\n        prediction = self.classifier(encoder_output.last_hidden_state[:, 0, :])\n        if is_train:\n            loss = F.cross_entropy(prediction, targets)\n            # return {\"loss\": loss}\n            return AlproOutputWithLogits(\n                loss=loss,\n                intermediate_output=AlproIntermediateOutput(\n                    video_embeds=video_embeds,\n                    text_embeds=text_embeds,\n                    encoder_output=encoder_output,\n                ),\n                logits=prediction,\n            )\n        else:\n    
        return {\"predictions\": prediction, \"targets\": targets}\n\n    def predict(self, samples):\n        output = self.forward(samples, is_train=False)\n        return output\n\n    @classmethod\n    def from_config(cls, cfg):\n        # vision encoder\n        visual_encoder_config = node_to_dict(cfg.timesformer)\n        visual_encoder = TimeSformer(**visual_encoder_config)\n\n        # text encoder\n        text_encoder = XBertEncoder.from_config(cfg)\n\n        num_classes = cfg.get(\"num_classes\", -1)\n        hidden_size = cfg.get(\"hidden_size\", 768)\n\n        model = cls(\n            visual_encoder=visual_encoder,\n            text_encoder=text_encoder,\n            hidden_size=hidden_size,\n            num_classes=num_classes,\n        )\n\n        num_patches = (\n            visual_encoder_config[\"image_size\"] // visual_encoder_config[\"patch_size\"]\n        ) ** 2\n        num_frames = visual_encoder_config[\"n_frms\"]\n\n        model.load_checkpoint_from_config(\n            cfg, num_frames=num_frames, num_patches=num_patches\n        )\n\n        return model\n"
  },
  {
    "path": "lavis/models/alpro_models/alpro_retrieval.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport datetime\nimport logging\nimport time\n\nimport lavis.common.dist_utils as dist_utils\nimport numpy as np\nimport torch\nimport torch.distributed as dist\nimport torch.nn.functional as F\nfrom lavis.common.config import node_to_dict\nfrom lavis.common.dist_utils import get_rank\nfrom lavis.common.logger import MetricLogger\nfrom lavis.common.registry import registry\nfrom lavis.models.alpro_models import AlproBase\nfrom lavis.models.alpro_models.alpro_outputs import AlproIntermediateOutput, AlproOutput\nfrom lavis.models.base_model import all_gather_with_grad\nfrom lavis.models.med import XBertEncoder\nfrom lavis.models.timesformer.vit import TimeSformer\nfrom torch import nn\n\n\n@registry.register_model(\"alpro_retrieval\")\nclass AlproRetrieval(AlproBase):\n    PRETRAINED_MODEL_CONFIG_DICT = {\n        \"msrvtt\": \"configs/models/alpro_retrieval_msrvtt.yaml\",\n        \"didemo\": \"configs/models/alpro_retrieval_didemo.yaml\",\n    }\n\n    def __init__(\n        self,\n        visual_encoder,\n        text_encoder,\n        vision_width=768,\n        text_width=768,\n        embed_dim=256,\n        max_txt_len=35,\n        temp=0.07,\n    ):\n        super().__init__()\n\n        self.temp = nn.Parameter(torch.ones([]) * temp)\n\n        self.tokenizer = self.init_tokenizer()\n\n        self.visual_encoder = visual_encoder\n        self.text_encoder = text_encoder\n\n        vision_width = vision_width\n        text_width = text_width\n\n        self.vision_proj = nn.Linear(vision_width, embed_dim)\n        self.text_proj = nn.Linear(text_width, embed_dim)\n\n        self.itm_head = nn.Linear(text_width, 2)\n\n        self.max_txt_len = max_txt_len\n\n    def forward(self, samples):\n        with torch.no_grad():\n            self.temp.clamp_(0.001, 0.5)\n\n        visual_inputs = samples[\"video\"]\n        caption = samples[\"text_input\"]\n\n        b, t, c, h, w = visual_inputs.shape\n\n        # forward text\n        text = self.tokenizer(\n            caption,\n            padding=\"max_length\",\n            truncation=True,\n            max_length=self.max_txt_len,\n            return_tensors=\"pt\",\n        ).to(self.device)\n\n        text_output = self.text_encoder.forward_text(\n            text,\n            token_type_ids=torch.zeros(\n                text.input_ids.shape, dtype=torch.long, device=self.device\n            ),\n        )\n        text_embeds = text_output.last_hidden_state\n        text_feat = F.normalize(self.text_proj(text_embeds[:, 0, :]), dim=-1)\n\n        # forward visual\n        # timeSformer asks for (b, c, t, h, w) as input.\n        video_embeds = self.visual_encoder.forward_features(visual_inputs)\n        video_feat = F.normalize(self.vision_proj(video_embeds[:, 0, :]), dim=-1)\n        video_atts = torch.ones(video_embeds.size()[:-1], dtype=torch.long).to(\n            self.device\n        )\n\n        # ========== (in-batch) ITC loss ==========\n        gathered_video_feats = all_gather_with_grad(video_feat)\n        gathered_text_feats = all_gather_with_grad(text_feat)\n\n        sim_v2t = video_feat @ gathered_text_feats.t() / self.temp\n        sim_t2v = text_feat @ gathered_video_feats.t() / self.temp\n\n        sim_targets = torch.zeros_like(sim_v2t)\n\n        local_rank = 
get_rank()\n        b_start, b_end = b * local_rank, b * (local_rank + 1)\n        sim_targets[:, b_start:b_end] = torch.eye(b)\n\n        loss_v2t = -torch.sum(F.log_softmax(sim_v2t, dim=1) * sim_targets, dim=1).mean()\n        loss_t2v = -torch.sum(F.log_softmax(sim_t2v, dim=1) * sim_targets, dim=1).mean()\n\n        vtc_loss = (loss_v2t + loss_t2v) / 2\n\n        (\n            vtm_loss,\n            vtm_logits,\n            vtm_labels,\n            encoder_output,\n            encoder_output_neg,\n        ) = self.compute_vtm(\n            text_embeds=text_embeds,\n            text_atts=text.attention_mask,\n            image_embeds=video_embeds,\n            image_atts=video_atts,\n            sim_i2t=sim_v2t.clone(),  # for hard mining\n            sim_t2i=sim_t2v.clone(),  # for hard mining\n        )\n\n        loss = vtc_loss + vtm_loss\n\n        # return {\"loss\": loss}\n        return AlproOutput(\n            loss=loss,\n            loss_vtc=vtc_loss,\n            loss_vtm=vtm_loss,\n            intermediate_output=AlproIntermediateOutput(\n                video_embeds=video_embeds,\n                text_embeds=text_embeds,\n                encoder_output=encoder_output,\n                encoder_output_neg=encoder_output_neg,\n                vtm_logits=vtm_logits,\n                vtm_labels=vtm_labels,\n            ),\n        )\n\n    def compute_vtm(\n        self, text_embeds, text_atts, image_embeds, image_atts, sim_i2t, sim_t2i\n    ):\n        device = self.device\n\n        # ====== positive pairs =======\n        attention_mask = torch.cat([text_atts, image_atts], dim=1)\n        embedding_output_pos = torch.cat([text_embeds, image_embeds], dim=1)\n\n        encoder_outputs_pos = self.text_encoder(\n            encoder_embeds=embedding_output_pos,\n            attention_mask=attention_mask,\n            return_dict=True,\n            mode=\"fusion\",\n        )\n\n        # ====== negative pairs =======\n        bs = text_embeds.shape[0]\n\n        local_rank = get_rank()\n        b_start, b_end = bs * local_rank, bs * (local_rank + 1)\n\n        with torch.no_grad():\n            weights_v2t = sim_i2t[:, b_start:b_end]\n            weights_t2v = sim_t2i[:, b_start:b_end]\n\n            # never select self as negative\n            weights_v2t.fill_diagonal_(-np.Inf)\n            weights_t2v.fill_diagonal_(-np.Inf)\n\n            weights_v2t = F.softmax(weights_v2t, dim=1)\n            weights_t2v = F.softmax(weights_t2v, dim=1)\n\n        # select a negative image for each text\n        # FIXME to optimize using indexing operations\n        image_embeds_neg = []\n        for b in range(bs):\n            neg_idx = torch.multinomial(weights_t2v[b], 1).item()\n            image_embeds_neg.append(image_embeds[neg_idx])\n        image_embeds_neg = torch.stack(image_embeds_neg, dim=0)\n\n        # select a negative text for each image\n        text_embeds_neg = []\n        text_atts_neg = []\n        for b in range(bs):\n            neg_idx = torch.multinomial(weights_v2t[b], 1).item()\n            text_embeds_neg.append(text_embeds[neg_idx])\n            text_atts_neg.append(text_atts[neg_idx])\n\n        text_embeds_neg = torch.stack(text_embeds_neg, dim=0)\n        text_atts_neg = torch.stack(text_atts_neg, dim=0)\n\n        text_embeds_all = torch.cat([text_embeds, text_embeds_neg], dim=0)\n        text_atts_all = torch.cat([text_atts, text_atts_neg], dim=0)\n\n        video_embeds_all = torch.cat([image_embeds_neg, image_embeds], dim=0)\n        video_atts_all = 
torch.cat([image_atts, image_atts], dim=0)\n\n        attention_mask_all = torch.cat([text_atts_all, video_atts_all], dim=1)\n        embedding_output_all = torch.cat([text_embeds_all, video_embeds_all], dim=1)\n\n        # forward negative pairs via cross encoder\n        encoder_outputs_neg = self.text_encoder(\n            encoder_embeds=embedding_output_all,\n            attention_mask=attention_mask_all,\n            return_dict=True,\n            mode=\"fusion\",\n        )\n\n        vl_embeddings = torch.cat(\n            [\n                encoder_outputs_pos.last_hidden_state[:, 0, :],\n                encoder_outputs_neg.last_hidden_state[:, 0, :],\n            ],\n            dim=0,\n        )\n        vtm_logits = self.itm_head(vl_embeddings)\n\n        vtm_labels = torch.cat(\n            [torch.ones(bs, dtype=torch.long), torch.zeros(2 * bs, dtype=torch.long)],\n            dim=0,\n        ).to(device)\n        vtm_loss = F.cross_entropy(vtm_logits, vtm_labels)\n\n        return (\n            vtm_loss,\n            vtm_logits,\n            vtm_labels,\n            encoder_outputs_pos,\n            encoder_outputs_neg,\n        )\n\n    def compute_sim_matrix(self, data_loader, task_cfg):\n        k_test = task_cfg.get(\"k_test\")\n\n        metric_logger = MetricLogger(delimiter=\"  \")\n        header = \"Evaluation:\"\n\n        logging.info(\"Computing features for evaluation...\")\n        start_time = time.time()\n\n        texts = data_loader.dataset.text\n        num_text = len(texts)\n        text_bs = 256\n        text_ids = []\n        text_embeds = []\n        text_feats = []\n        text_atts = []\n        for i in range(0, num_text, text_bs):\n            text = texts[i : min(num_text, i + text_bs)]\n            text_input = self.tokenizer(\n                text,\n                padding=\"max_length\",\n                truncation=True,\n                max_length=self.max_txt_len,\n                return_tensors=\"pt\",\n            ).to(self.device)\n            text_output = self.text_encoder.forward_text(\n                text_input,\n                token_type_ids=torch.zeros(\n                    text_input.input_ids.shape, dtype=torch.long, device=self.device\n                ),\n            )\n            text_feats.append(text_output.last_hidden_state.cpu())\n            text_embed = F.normalize(\n                self.text_proj(text_output.last_hidden_state[:, 0, :])\n            )\n            text_embeds.append(text_embed)\n            text_ids.append(text_input.input_ids)\n            text_atts.append(text_input.attention_mask)\n\n        text_embeds = torch.cat(text_embeds, dim=0)\n        text_ids = torch.cat(text_ids, dim=0)\n        text_atts = torch.cat(text_atts, dim=0)\n        text_feats = torch.cat(text_feats, dim=0)\n\n        video_feats = []\n        video_embeds = []\n        for samples in data_loader:\n            video = samples[\"video\"]\n\n            video = video.to(self.device)\n            video_feat = self.visual_encoder.forward_features(video)\n            video_embed = self.vision_proj(video_feat[:, 0, :])\n            video_embed = F.normalize(video_embed, dim=-1)\n\n            video_feats.append(video_feat.cpu())\n            video_embeds.append(video_embed)\n\n        video_feats = torch.cat(video_feats, dim=0)\n        video_embeds = torch.cat(video_embeds, dim=0)\n\n        sims_matrix = video_embeds @ text_embeds.t()\n        score_matrix_v2t = torch.full(\n            (len(data_loader.dataset.image), len(texts)), 
-100.0\n        ).to(self.device)\n\n        num_tasks = dist_utils.get_world_size()\n        rank = dist_utils.get_rank()\n        step = sims_matrix.size(0) // num_tasks + 1\n        start = rank * step\n        end = min(sims_matrix.size(0), start + step)\n\n        # video-to-text\n        for i, sims in enumerate(\n            metric_logger.log_every(sims_matrix[start:end], 50, header)\n        ):\n            topk_sim, topk_idx = sims.topk(k=k_test, dim=0)\n\n            video_feats_repeat = (\n                video_feats[start + i].repeat(k_test, 1, 1).to(self.device)\n            )\n            video_atts_repeat = torch.ones(\n                video_feats_repeat.size()[:-1], dtype=torch.long\n            ).to(self.device)\n\n            attention_mask = torch.cat([text_atts[topk_idx], video_atts_repeat], dim=1)\n            embedding_output = torch.cat(\n                [text_feats[topk_idx].to(self.device), video_feats_repeat], dim=1\n            )\n\n            output = self.text_encoder(\n                encoder_embeds=embedding_output,\n                attention_mask=attention_mask,\n                return_dict=True,\n                mode=\"fusion\",\n            )\n\n            score = self.itm_head(output.last_hidden_state[:, 0, :])[:, 1]\n            score_matrix_v2t[start + i, topk_idx] = score + topk_sim\n\n        # text-to-video\n        sims_matrix = sims_matrix.t()\n        score_matrix_t2v = torch.full(\n            (len(texts), len(data_loader.dataset.image)), -100.0\n        ).to(self.device)\n\n        step = sims_matrix.size(0) // num_tasks + 1\n        start = rank * step\n        end = min(sims_matrix.size(0), start + step)\n\n        for i, sims in enumerate(\n            metric_logger.log_every(sims_matrix[start:end], 50, header)\n        ):\n\n            topk_sim, topk_idx = sims.topk(k=k_test, dim=0)\n\n            text_feats_repeat = (\n                text_feats[start + i].repeat(k_test, 1, 1).to(self.device)\n            )\n            text_atts_repeat = text_atts[start + i].repeat(k_test, 1).to(self.device)\n\n            video_atts = torch.ones(\n                video_feats[topk_idx].size()[:-1], dtype=torch.long\n            ).to(self.device)\n\n            embedding_output = torch.cat(\n                [text_feats_repeat, video_feats[topk_idx].to(self.device)], dim=1\n            )\n            attention_mask = torch.cat([text_atts_repeat, video_atts], dim=1)\n\n            output = self.text_encoder(\n                encoder_embeds=embedding_output,\n                attention_mask=attention_mask,\n                return_dict=True,\n                mode=\"fusion\",\n            )\n\n            score = self.itm_head(output.last_hidden_state[:, 0, :])[:, 1]\n            score_matrix_t2v[start + i, topk_idx] = score + topk_sim\n\n        if dist_utils.is_dist_avail_and_initialized():\n            dist.barrier()\n            torch.distributed.all_reduce(\n                score_matrix_v2t, op=torch.distributed.ReduceOp.SUM\n            )\n            torch.distributed.all_reduce(\n                score_matrix_t2v, op=torch.distributed.ReduceOp.SUM\n            )\n\n        total_time = time.time() - start_time\n        total_time_str = str(datetime.timedelta(seconds=int(total_time)))\n        logging.info(\"Evaluation time {}\".format(total_time_str))\n\n        return score_matrix_v2t.cpu().numpy(), score_matrix_t2v.cpu().numpy()\n\n    @classmethod\n    def from_config(cls, cfg):\n        # vision encoder\n        visual_encoder_config = 
node_to_dict(cfg.timesformer)\n        visual_encoder = TimeSformer(**visual_encoder_config)\n\n        # text encoder\n        text_encoder = XBertEncoder.from_config(cfg)\n\n        max_txt_len = cfg.get(\"max_txt_len\", 35)\n\n        model = cls(\n            visual_encoder=visual_encoder,\n            text_encoder=text_encoder,\n            max_txt_len=max_txt_len,\n        )\n\n        num_patches = (\n            visual_encoder_config[\"image_size\"] // visual_encoder_config[\"patch_size\"]\n        ) ** 2\n        num_frames = visual_encoder_config[\"n_frms\"]\n\n        model.load_checkpoint_from_config(\n            cfg, num_frames=num_frames, num_patches=num_patches\n        )\n\n        return model\n"
  },
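  {
    "path": "docs/snippets/alpro_itc_targets_sketch.py",
    "content": "\"\"\"\nIllustrative sketch only -- this file is not part of the original repository.\n\nIt reproduces, in isolation, how AlproRetrieval.forward lays out the in-batch\nvideo-text contrastive (ITC) targets once features have been gathered across\nranks: each rank writes an identity block at its own offset into the gathered\nsimilarity matrix. The batch size, world size and rank below are made-up\nvalues chosen purely for the printout.\n\"\"\"\n\nimport torch\n\nb = 3            # local batch size (hypothetical)\nworld_size = 2   # number of processes (hypothetical)\nlocal_rank = 1   # rank of this process (hypothetical)\n\n# sim_v2t / sim_t2v have shape (b, b * world_size) after all_gather_with_grad\nsim_targets = torch.zeros(b, b * world_size)\nb_start, b_end = b * local_rank, b * (local_rank + 1)\nsim_targets[:, b_start:b_end] = torch.eye(b)\n\nprint(sim_targets)\n# tensor([[0., 0., 0., 1., 0., 0.],\n#         [0., 0., 0., 0., 1., 0.],\n#         [0., 0., 0., 0., 0., 1.]])\n"
  },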
  {
    "path": "lavis/models/base_model.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport logging\nimport os\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nfrom lavis.common.dist_utils import download_cached_file, is_dist_avail_and_initialized\nfrom lavis.common.utils import get_abs_path, is_url\nfrom omegaconf import OmegaConf\n\n\nclass BaseModel(nn.Module):\n    \"\"\"Base class for models.\"\"\"\n\n    def __init__(self):\n        super().__init__()\n\n    @property\n    def device(self):\n        return list(self.parameters())[0].device\n\n    def load_checkpoint(self, url_or_filename):\n        \"\"\"\n        Load from a finetuned checkpoint.\n\n        This should expect no mismatch in the model keys and the checkpoint keys.\n        \"\"\"\n\n        if is_url(url_or_filename):\n            cached_file = download_cached_file(\n                url_or_filename, check_hash=False, progress=True\n            )\n            checkpoint = torch.load(cached_file, map_location=\"cpu\")\n        elif os.path.isfile(url_or_filename):\n            checkpoint = torch.load(url_or_filename, map_location=\"cpu\")\n        else:\n            raise RuntimeError(\"checkpoint url or path is invalid\")\n\n        if \"model\" in checkpoint.keys():\n            state_dict = checkpoint[\"model\"]\n        else:\n            state_dict = checkpoint\n\n        msg = self.load_state_dict(state_dict, strict=False)\n\n        logging.info(\"Missing keys {}\".format(msg.missing_keys))\n        logging.info(\"load checkpoint from %s\" % url_or_filename)\n\n        return msg\n\n    @classmethod\n    def from_pretrained(cls, model_type):\n        \"\"\"\n        Build a pretrained model from default configuration file, specified by model_type.\n\n        Args:\n            - model_type (str): model type, specifying architecture and checkpoints.\n\n        Returns:\n            - model (nn.Module): pretrained or finetuned model, depending on the configuration.\n        \"\"\"\n        model_cfg = OmegaConf.load(cls.default_config_path(model_type)).model\n        model = cls.from_config(model_cfg)\n\n        return model\n\n    @classmethod\n    def default_config_path(cls, model_type):\n        assert (\n            model_type in cls.PRETRAINED_MODEL_CONFIG_DICT\n        ), \"Unknown model type {}\".format(model_type)\n        return get_abs_path(cls.PRETRAINED_MODEL_CONFIG_DICT[model_type])\n\n    def load_checkpoint_from_config(self, cfg, **kwargs):\n        \"\"\"\n        Load checkpoint as specified in the config file.\n\n        If load_finetuned is True, load the finetuned model; otherwise, load the pretrained model.\n        When loading the pretrained model, each task-specific architecture may define their\n        own load_from_pretrained() method.\n        \"\"\"\n        load_finetuned = cfg.get(\"load_finetuned\", True)\n        if load_finetuned:\n            finetune_path = cfg.get(\"finetuned\", None)\n            assert (\n                finetune_path is not None\n            ), \"Found load_finetuned is True, but finetune_path is None.\"\n            self.load_checkpoint(url_or_filename=finetune_path)\n        else:\n            # load pre-trained weights\n            pretrain_path = cfg.get(\"pretrained\", None)\n            assert \"Found load_finetuned is False, but pretrain_path is None.\"\n            
self.load_from_pretrained(url_or_filename=pretrain_path, **kwargs)\n\n    def before_evaluation(self, **kwargs):\n        pass\n\n    def show_n_params(self, return_str=True):\n        tot = 0\n        for p in self.parameters():\n            w = 1\n            for x in p.shape:\n                w *= x\n            tot += w\n        if return_str:\n            if tot >= 1e6:\n                return \"{:.1f}M\".format(tot / 1e6)\n            else:\n                return \"{:.1f}K\".format(tot / 1e3)\n        else:\n            return tot\n\n\nclass BaseEncoder(nn.Module):\n    \"\"\"\n    Base class for primitive encoders, such as ViT, TimeSformer, etc.\n    \"\"\"\n\n    def __init__(self):\n        super().__init__()\n\n    def forward_features(self, samples, **kwargs):\n        raise NotImplementedError\n\n    @property\n    def device(self):\n        return list(self.parameters())[0].device\n\n\nclass SharedQueueMixin:\n    @torch.no_grad()\n    def _dequeue_and_enqueue(self, image_feat, text_feat, idxs=None):\n        # gather keys before updating queue\n        image_feats = concat_all_gather(image_feat)\n        text_feats = concat_all_gather(text_feat)\n\n        batch_size = image_feats.shape[0]\n\n        ptr = int(self.queue_ptr)\n        assert self.queue_size % batch_size == 0  # for simplicity\n\n        # replace the keys at ptr (dequeue and enqueue)\n        self.image_queue[:, ptr : ptr + batch_size] = image_feats.T\n        self.text_queue[:, ptr : ptr + batch_size] = text_feats.T\n\n        if idxs is not None:\n            idxs = concat_all_gather(idxs)\n            self.idx_queue[:, ptr : ptr + batch_size] = idxs.T\n\n        ptr = (ptr + batch_size) % self.queue_size  # move pointer\n        self.queue_ptr[0] = ptr\n\n\nclass MomentumDistilationMixin:\n    @torch.no_grad()\n    def copy_params(self):\n        for model_pair in self.model_pairs:\n            for param, param_m in zip(\n                model_pair[0].parameters(), model_pair[1].parameters()\n            ):\n                param_m.data.copy_(param.data)  # initialize\n                param_m.requires_grad = False  # not update by gradient\n\n    @torch.no_grad()\n    def _momentum_update(self):\n        for model_pair in self.model_pairs:\n            for param, param_m in zip(\n                model_pair[0].parameters(), model_pair[1].parameters()\n            ):\n                param_m.data = param_m.data * self.momentum + param.data * (\n                    1.0 - self.momentum\n                )\n\n\nclass GatherLayer(torch.autograd.Function):\n    \"\"\"\n    Gather tensors from all workers with support for backward propagation:\n    This implementation does not cut the gradients as torch.distributed.all_gather does.\n    \"\"\"\n\n    @staticmethod\n    def forward(ctx, x):\n        output = [\n            torch.zeros_like(x) for _ in range(torch.distributed.get_world_size())\n        ]\n        torch.distributed.all_gather(output, x)\n        return tuple(output)\n\n    @staticmethod\n    def backward(ctx, *grads):\n        all_gradients = torch.stack(grads)\n        torch.distributed.all_reduce(all_gradients)\n        return all_gradients[torch.distributed.get_rank()]\n\n\ndef all_gather_with_grad(tensors):\n    \"\"\"\n    Performs all_gather operation on the provided tensors.\n    Graph remains connected for backward grad computation.\n    \"\"\"\n    # Queue the gathered tensors\n    world_size = torch.distributed.get_world_size()\n    # There is no need for reduction in the single-proc case\n 
   if world_size == 1:\n        return tensors\n\n    # tensor_all = GatherLayer.apply(tensors)\n    tensor_all = GatherLayer.apply(tensors)\n\n    return torch.cat(tensor_all, dim=0)\n\n\n@torch.no_grad()\ndef concat_all_gather(tensor):\n    \"\"\"\n    Performs all_gather operation on the provided tensors.\n    *** Warning ***: torch.distributed.all_gather has no gradient.\n    \"\"\"\n    # if use distributed training\n    if not is_dist_avail_and_initialized():\n        return tensor\n\n    tensors_gather = [\n        torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())\n    ]\n    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)\n\n    output = torch.cat(tensors_gather, dim=0)\n    return output\n\n\ndef tile(x, dim, n_tile):\n    init_dim = x.size(dim)\n    repeat_idx = [1] * x.dim()\n    repeat_idx[dim] = n_tile\n    x = x.repeat(*(repeat_idx))\n    order_index = torch.LongTensor(\n        np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])\n    )\n    return torch.index_select(x, dim, order_index.to(x.device))\n"
  },
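  {
    "path": "docs/snippets/base_model_helpers_sketch.py",
    "content": "\"\"\"\nIllustrative sketch only -- this file is not part of the original repository.\n\nA minimal, single-process demonstration of two helpers from\nlavis/models/base_model.py: tile(), which repeats every slice along a given\ndimension n_tile times while keeping the original order interleaved, and\nconcat_all_gather(), which falls back to returning its input unchanged when\nno distributed process group has been initialized.\n\"\"\"\n\nimport torch\n\nfrom lavis.models.base_model import concat_all_gather, tile\n\nx = torch.arange(6).reshape(2, 3)\n\n# rows are repeated in place: row0, row0, row1, row1\nprint(tile(x, dim=0, n_tile=2))\n\n# without torch.distributed initialized, the input comes back as-is\nprint(concat_all_gather(x).shape)  # torch.Size([2, 3])\n"
  },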
  {
    "path": "lavis/models/blip2_models/Qformer.py",
    "content": "\"\"\"\n * Copyright (c) 2023, salesforce.com, inc.\n * All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n * By Junnan Li\n * Based on huggingface code base\n * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert\n\"\"\"\n\nimport math\nimport os\nimport warnings\nfrom dataclasses import dataclass\nfrom typing import Optional, Tuple, Dict, Any\n\nimport torch\nfrom torch import Tensor, device, dtype, nn\nimport torch.utils.checkpoint\nfrom torch import nn\nfrom torch.nn import CrossEntropyLoss\nimport torch.nn.functional as F\n\nfrom transformers.activations import ACT2FN\nfrom transformers.file_utils import (\n    ModelOutput,\n)\nfrom transformers.modeling_outputs import (\n    BaseModelOutputWithPastAndCrossAttentions,\n    BaseModelOutputWithPoolingAndCrossAttentions,\n    CausalLMOutputWithCrossAttentions,\n    MaskedLMOutput,\n    MultipleChoiceModelOutput,\n    NextSentencePredictorOutput,\n    QuestionAnsweringModelOutput,\n    SequenceClassifierOutput,\n    TokenClassifierOutput,\n)\nfrom transformers.modeling_utils import (\n    PreTrainedModel,\n    apply_chunking_to_forward,\n    find_pruneable_heads_and_indices,\n    prune_linear_layer,\n)\nfrom transformers.utils import logging\nfrom transformers.models.bert.configuration_bert import BertConfig\n\nlogger = logging.get_logger(__name__)\n\n\nclass BertEmbeddings(nn.Module):\n    \"\"\"Construct the embeddings from word and position embeddings.\"\"\"\n\n    def __init__(self, config):\n        super().__init__()\n        self.word_embeddings = nn.Embedding(\n            config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id\n        )\n        self.position_embeddings = nn.Embedding(\n            config.max_position_embeddings, config.hidden_size\n        )\n\n        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load\n        # any TensorFlow checkpoint file\n        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)\n        self.dropout = nn.Dropout(config.hidden_dropout_prob)\n\n        # position_ids (1, len position emb) is contiguous in memory and exported when serialized\n        self.register_buffer(\n            \"position_ids\", torch.arange(config.max_position_embeddings).expand((1, -1))\n        )\n        self.position_embedding_type = getattr(\n            config, \"position_embedding_type\", \"absolute\"\n        )\n\n        self.config = config\n\n    def forward(\n        self,\n        input_ids=None,\n        position_ids=None,\n        query_embeds=None,\n        past_key_values_length=0,\n    ):\n        if input_ids is not None:\n            seq_length = input_ids.size()[1]\n        else:\n            seq_length = 0\n\n        if position_ids is None:\n            position_ids = self.position_ids[\n                :, past_key_values_length : seq_length + past_key_values_length\n            ].clone()\n\n        if input_ids is not None:\n            embeddings = self.word_embeddings(input_ids)\n            if self.position_embedding_type == \"absolute\":\n                position_embeddings = self.position_embeddings(position_ids)\n                embeddings = embeddings + position_embeddings\n\n            if query_embeds is not None:\n                embeddings = torch.cat((query_embeds, embeddings), dim=1)\n        else:\n            
embeddings = query_embeds\n\n        embeddings = self.LayerNorm(embeddings)\n        embeddings = self.dropout(embeddings)\n        return embeddings\n\n\nclass BertSelfAttention(nn.Module):\n    def __init__(self, config, is_cross_attention):\n        super().__init__()\n        self.config = config\n        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(\n            config, \"embedding_size\"\n        ):\n            raise ValueError(\n                \"The hidden size (%d) is not a multiple of the number of attention \"\n                \"heads (%d)\" % (config.hidden_size, config.num_attention_heads)\n            )\n\n        self.num_attention_heads = config.num_attention_heads\n        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)\n        self.all_head_size = self.num_attention_heads * self.attention_head_size\n\n        self.query = nn.Linear(config.hidden_size, self.all_head_size)\n        if is_cross_attention:\n            self.key = nn.Linear(config.encoder_width, self.all_head_size)\n            self.value = nn.Linear(config.encoder_width, self.all_head_size)\n        else:\n            self.key = nn.Linear(config.hidden_size, self.all_head_size)\n            self.value = nn.Linear(config.hidden_size, self.all_head_size)\n\n        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)\n        self.position_embedding_type = getattr(\n            config, \"position_embedding_type\", \"absolute\"\n        )\n        if (\n            self.position_embedding_type == \"relative_key\"\n            or self.position_embedding_type == \"relative_key_query\"\n        ):\n            self.max_position_embeddings = config.max_position_embeddings\n            self.distance_embedding = nn.Embedding(\n                2 * config.max_position_embeddings - 1, self.attention_head_size\n            )\n        self.save_attention = False\n\n    def save_attn_gradients(self, attn_gradients):\n        self.attn_gradients = attn_gradients\n\n    def get_attn_gradients(self):\n        return self.attn_gradients\n\n    def save_attention_map(self, attention_map):\n        self.attention_map = attention_map\n\n    def get_attention_map(self):\n        return self.attention_map\n\n    def transpose_for_scores(self, x):\n        new_x_shape = x.size()[:-1] + (\n            self.num_attention_heads,\n            self.attention_head_size,\n        )\n        x = x.view(*new_x_shape)\n        return x.permute(0, 2, 1, 3)\n\n    def forward(\n        self,\n        hidden_states,\n        attention_mask=None,\n        head_mask=None,\n        encoder_hidden_states=None,\n        encoder_attention_mask=None,\n        past_key_value=None,\n        output_attentions=False,\n    ):\n\n        # If this is instantiated as a cross-attention module, the keys\n        # and values come from an encoder; the attention mask needs to be\n        # such that the encoder's padding tokens are not attended to.\n        is_cross_attention = encoder_hidden_states is not None\n\n        if is_cross_attention:\n            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))\n            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))\n            attention_mask = encoder_attention_mask\n        elif past_key_value is not None:\n            key_layer = self.transpose_for_scores(self.key(hidden_states))\n            value_layer = self.transpose_for_scores(self.value(hidden_states))\n            key_layer = 
torch.cat([past_key_value[0], key_layer], dim=2)\n            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)\n        else:\n            key_layer = self.transpose_for_scores(self.key(hidden_states))\n            value_layer = self.transpose_for_scores(self.value(hidden_states))\n\n        mixed_query_layer = self.query(hidden_states)\n\n        query_layer = self.transpose_for_scores(mixed_query_layer)\n\n        past_key_value = (key_layer, value_layer)\n\n        # Take the dot product between \"query\" and \"key\" to get the raw attention scores.\n        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))\n\n        if (\n            self.position_embedding_type == \"relative_key\"\n            or self.position_embedding_type == \"relative_key_query\"\n        ):\n            seq_length = hidden_states.size()[1]\n            position_ids_l = torch.arange(\n                seq_length, dtype=torch.long, device=hidden_states.device\n            ).view(-1, 1)\n            position_ids_r = torch.arange(\n                seq_length, dtype=torch.long, device=hidden_states.device\n            ).view(1, -1)\n            distance = position_ids_l - position_ids_r\n            positional_embedding = self.distance_embedding(\n                distance + self.max_position_embeddings - 1\n            )\n            positional_embedding = positional_embedding.to(\n                dtype=query_layer.dtype\n            )  # fp16 compatibility\n\n            if self.position_embedding_type == \"relative_key\":\n                relative_position_scores = torch.einsum(\n                    \"bhld,lrd->bhlr\", query_layer, positional_embedding\n                )\n                attention_scores = attention_scores + relative_position_scores\n            elif self.position_embedding_type == \"relative_key_query\":\n                relative_position_scores_query = torch.einsum(\n                    \"bhld,lrd->bhlr\", query_layer, positional_embedding\n                )\n                relative_position_scores_key = torch.einsum(\n                    \"bhrd,lrd->bhlr\", key_layer, positional_embedding\n                )\n                attention_scores = (\n                    attention_scores\n                    + relative_position_scores_query\n                    + relative_position_scores_key\n                )\n\n        attention_scores = attention_scores / math.sqrt(self.attention_head_size)\n        if attention_mask is not None:\n            # Apply the attention mask is (precomputed for all layers in BertModel forward() function)\n            attention_scores = attention_scores + attention_mask\n\n        # Normalize the attention scores to probabilities.\n        attention_probs = nn.Softmax(dim=-1)(attention_scores)\n\n        if is_cross_attention and self.save_attention:\n            self.save_attention_map(attention_probs)\n            attention_probs.register_hook(self.save_attn_gradients)\n\n        # This is actually dropping out entire tokens to attend to, which might\n        # seem a bit unusual, but is taken from the original Transformer paper.\n        attention_probs_dropped = self.dropout(attention_probs)\n\n        # Mask heads if we want to\n        if head_mask is not None:\n            attention_probs_dropped = attention_probs_dropped * head_mask\n\n        context_layer = torch.matmul(attention_probs_dropped, value_layer)\n\n        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()\n        new_context_layer_shape = 
context_layer.size()[:-2] + (self.all_head_size,)\n        context_layer = context_layer.view(*new_context_layer_shape)\n\n        outputs = (\n            (context_layer, attention_probs) if output_attentions else (context_layer,)\n        )\n\n        outputs = outputs + (past_key_value,)\n        return outputs\n\n\nclass BertSelfOutput(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.dense = nn.Linear(config.hidden_size, config.hidden_size)\n        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)\n        self.dropout = nn.Dropout(config.hidden_dropout_prob)\n\n    def forward(self, hidden_states, input_tensor):\n        hidden_states = self.dense(hidden_states)\n        hidden_states = self.dropout(hidden_states)\n        hidden_states = self.LayerNorm(hidden_states + input_tensor)\n        return hidden_states\n\n\nclass BertAttention(nn.Module):\n    def __init__(self, config, is_cross_attention=False):\n        super().__init__()\n        self.self = BertSelfAttention(config, is_cross_attention)\n        self.output = BertSelfOutput(config)\n        self.pruned_heads = set()\n\n    def prune_heads(self, heads):\n        if len(heads) == 0:\n            return\n        heads, index = find_pruneable_heads_and_indices(\n            heads,\n            self.self.num_attention_heads,\n            self.self.attention_head_size,\n            self.pruned_heads,\n        )\n\n        # Prune linear layers\n        self.self.query = prune_linear_layer(self.self.query, index)\n        self.self.key = prune_linear_layer(self.self.key, index)\n        self.self.value = prune_linear_layer(self.self.value, index)\n        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)\n\n        # Update hyper params and store pruned heads\n        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)\n        self.self.all_head_size = (\n            self.self.attention_head_size * self.self.num_attention_heads\n        )\n        self.pruned_heads = self.pruned_heads.union(heads)\n\n    def forward(\n        self,\n        hidden_states,\n        attention_mask=None,\n        head_mask=None,\n        encoder_hidden_states=None,\n        encoder_attention_mask=None,\n        past_key_value=None,\n        output_attentions=False,\n    ):\n        self_outputs = self.self(\n            hidden_states,\n            attention_mask,\n            head_mask,\n            encoder_hidden_states,\n            encoder_attention_mask,\n            past_key_value,\n            output_attentions,\n        )\n        attention_output = self.output(self_outputs[0], hidden_states)\n\n        outputs = (attention_output,) + self_outputs[\n            1:\n        ]  # add attentions if we output them\n        return outputs\n\n\nclass BertIntermediate(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)\n        if isinstance(config.hidden_act, str):\n            self.intermediate_act_fn = ACT2FN[config.hidden_act]\n        else:\n            self.intermediate_act_fn = config.hidden_act\n\n    def forward(self, hidden_states):\n        hidden_states = self.dense(hidden_states)\n        hidden_states = self.intermediate_act_fn(hidden_states)\n        return hidden_states\n\n\nclass BertOutput(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.dense = nn.Linear(config.intermediate_size, 
config.hidden_size)\n        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)\n        self.dropout = nn.Dropout(config.hidden_dropout_prob)\n\n    def forward(self, hidden_states, input_tensor):\n        hidden_states = self.dense(hidden_states)\n        hidden_states = self.dropout(hidden_states)\n        hidden_states = self.LayerNorm(hidden_states + input_tensor)\n        return hidden_states\n\n\nclass BertLayer(nn.Module):\n    def __init__(self, config, layer_num):\n        super().__init__()\n        self.config = config\n        self.chunk_size_feed_forward = config.chunk_size_feed_forward\n        self.seq_len_dim = 1\n        self.attention = BertAttention(config)\n        self.layer_num = layer_num\n        if (\n            self.config.add_cross_attention\n            and layer_num % self.config.cross_attention_freq == 0\n        ):\n            self.crossattention = BertAttention(\n                config, is_cross_attention=self.config.add_cross_attention\n            )\n            self.has_cross_attention = True\n        else:\n            self.has_cross_attention = False\n        self.intermediate = BertIntermediate(config)\n        self.output = BertOutput(config)\n\n        self.intermediate_query = BertIntermediate(config)\n        self.output_query = BertOutput(config)\n\n    def forward(\n        self,\n        hidden_states,\n        attention_mask=None,\n        head_mask=None,\n        encoder_hidden_states=None,\n        encoder_attention_mask=None,\n        past_key_value=None,\n        output_attentions=False,\n        query_length=0,\n    ):\n        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2\n        self_attn_past_key_value = (\n            past_key_value[:2] if past_key_value is not None else None\n        )\n        self_attention_outputs = self.attention(\n            hidden_states,\n            attention_mask,\n            head_mask,\n            output_attentions=output_attentions,\n            past_key_value=self_attn_past_key_value,\n        )\n        attention_output = self_attention_outputs[0]\n        outputs = self_attention_outputs[1:-1]\n\n        present_key_value = self_attention_outputs[-1]\n\n        if query_length > 0:\n            query_attention_output = attention_output[:, :query_length, :]\n\n            if self.has_cross_attention:\n                assert (\n                    encoder_hidden_states is not None\n                ), \"encoder_hidden_states must be given for cross-attention layers\"\n                cross_attention_outputs = self.crossattention(\n                    query_attention_output,\n                    attention_mask,\n                    head_mask,\n                    encoder_hidden_states,\n                    encoder_attention_mask,\n                    output_attentions=output_attentions,\n                )\n                query_attention_output = cross_attention_outputs[0]\n                outputs = (\n                    outputs + cross_attention_outputs[1:-1]\n                )  # add cross attentions if we output attention weights\n\n            layer_output = apply_chunking_to_forward(\n                self.feed_forward_chunk_query,\n                self.chunk_size_feed_forward,\n                self.seq_len_dim,\n                query_attention_output,\n            )\n            if attention_output.shape[1] > query_length:\n                layer_output_text = apply_chunking_to_forward(\n                    
self.feed_forward_chunk,\n                    self.chunk_size_feed_forward,\n                    self.seq_len_dim,\n                    attention_output[:, query_length:, :],\n                )\n                layer_output = torch.cat([layer_output, layer_output_text], dim=1)\n        else:\n            layer_output = apply_chunking_to_forward(\n                self.feed_forward_chunk,\n                self.chunk_size_feed_forward,\n                self.seq_len_dim,\n                attention_output,\n            )\n        outputs = (layer_output,) + outputs\n\n        outputs = outputs + (present_key_value,)\n\n        return outputs\n\n    def feed_forward_chunk(self, attention_output):\n        intermediate_output = self.intermediate(attention_output)\n        layer_output = self.output(intermediate_output, attention_output)\n        return layer_output\n\n    def feed_forward_chunk_query(self, attention_output):\n        intermediate_output = self.intermediate_query(attention_output)\n        layer_output = self.output_query(intermediate_output, attention_output)\n        return layer_output\n\n\nclass BertEncoder(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.config = config\n        self.layer = nn.ModuleList(\n            [BertLayer(config, i) for i in range(config.num_hidden_layers)]\n        )\n\n    def forward(\n        self,\n        hidden_states,\n        attention_mask=None,\n        head_mask=None,\n        encoder_hidden_states=None,\n        encoder_attention_mask=None,\n        past_key_values=None,\n        use_cache=None,\n        output_attentions=False,\n        output_hidden_states=False,\n        return_dict=True,\n        query_length=0,\n    ):\n        all_hidden_states = () if output_hidden_states else None\n        all_self_attentions = () if output_attentions else None\n        all_cross_attentions = (\n            () if output_attentions and self.config.add_cross_attention else None\n        )\n\n        next_decoder_cache = () if use_cache else None\n\n        for i in range(self.config.num_hidden_layers):\n            layer_module = self.layer[i]\n            if output_hidden_states:\n                all_hidden_states = all_hidden_states + (hidden_states,)\n\n            layer_head_mask = head_mask[i] if head_mask is not None else None\n            past_key_value = past_key_values[i] if past_key_values is not None else None\n\n            if getattr(self.config, \"gradient_checkpointing\", False) and self.training:\n\n                if use_cache:\n                    logger.warn(\n                        \"`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`...\"\n                    )\n                    use_cache = False\n\n                def create_custom_forward(module):\n                    def custom_forward(*inputs):\n                        return module(\n                            *inputs, past_key_value, output_attentions, query_length\n                        )\n\n                    return custom_forward\n\n                layer_outputs = torch.utils.checkpoint.checkpoint(\n                    create_custom_forward(layer_module),\n                    hidden_states,\n                    attention_mask,\n                    layer_head_mask,\n                    encoder_hidden_states,\n                    encoder_attention_mask,\n                )\n            else:\n                layer_outputs = layer_module(\n                    hidden_states,\n                    attention_mask,\n                    layer_head_mask,\n                    encoder_hidden_states,\n                    encoder_attention_mask,\n                    past_key_value,\n                    output_attentions,\n                    query_length,\n                )\n\n            hidden_states = layer_outputs[0]\n            if use_cache:\n                next_decoder_cache += (layer_outputs[-1],)\n            if output_attentions:\n                all_self_attentions = all_self_attentions + (layer_outputs[1],)\n                all_cross_attentions = all_cross_attentions + (layer_outputs[2],)\n\n        if output_hidden_states:\n            all_hidden_states = all_hidden_states + (hidden_states,)\n\n        if not return_dict:\n            return tuple(\n                v\n                for v in [\n                    hidden_states,\n                    next_decoder_cache,\n                    all_hidden_states,\n                    all_self_attentions,\n                    all_cross_attentions,\n                ]\n                if v is not None\n            )\n        return BaseModelOutputWithPastAndCrossAttentions(\n            last_hidden_state=hidden_states,\n            past_key_values=next_decoder_cache,\n            hidden_states=all_hidden_states,\n            attentions=all_self_attentions,\n            cross_attentions=all_cross_attentions,\n        )\n\n\nclass BertPooler(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.dense = nn.Linear(config.hidden_size, config.hidden_size)\n        self.activation = nn.Tanh()\n\n    def forward(self, hidden_states):\n        # We \"pool\" the model by simply taking the hidden state corresponding\n        # to the first token.\n        first_token_tensor = hidden_states[:, 0]\n        pooled_output = self.dense(first_token_tensor)\n        pooled_output = self.activation(pooled_output)\n        return pooled_output\n\n\nclass BertPredictionHeadTransform(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.dense = nn.Linear(config.hidden_size, config.hidden_size)\n        if isinstance(config.hidden_act, str):\n            self.transform_act_fn = ACT2FN[config.hidden_act]\n        else:\n            self.transform_act_fn = config.hidden_act\n        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)\n\n    def forward(self, hidden_states):\n        hidden_states = self.dense(hidden_states)\n        hidden_states = self.transform_act_fn(hidden_states)\n        hidden_states = self.LayerNorm(hidden_states)\n        return hidden_states\n\n\nclass BertLMPredictionHead(nn.Module):\n    def 
__init__(self, config):\n        super().__init__()\n        self.transform = BertPredictionHeadTransform(config)\n\n        # The output weights are the same as the input embeddings, but there is\n        # an output-only bias for each token.\n        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n\n        self.bias = nn.Parameter(torch.zeros(config.vocab_size))\n\n        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`\n        self.decoder.bias = self.bias\n\n    def forward(self, hidden_states):\n        hidden_states = self.transform(hidden_states)\n        hidden_states = self.decoder(hidden_states)\n        return hidden_states\n\n\nclass BertOnlyMLMHead(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.predictions = BertLMPredictionHead(config)\n\n    def forward(self, sequence_output):\n        prediction_scores = self.predictions(sequence_output)\n        return prediction_scores\n\n\nclass BertPreTrainedModel(PreTrainedModel):\n    \"\"\"\n    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained\n    models.\n    \"\"\"\n\n    config_class = BertConfig\n    base_model_prefix = \"bert\"\n    _keys_to_ignore_on_load_missing = [r\"position_ids\"]\n\n    def _init_weights(self, module):\n        \"\"\"Initialize the weights\"\"\"\n        if isinstance(module, (nn.Linear, nn.Embedding)):\n            # Slightly different from the TF version which uses truncated_normal for initialization\n            # cf https://github.com/pytorch/pytorch/pull/5617\n            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)\n        elif isinstance(module, nn.LayerNorm):\n            module.bias.data.zero_()\n            module.weight.data.fill_(1.0)\n        if isinstance(module, nn.Linear) and module.bias is not None:\n            module.bias.data.zero_()\n\n\nclass BertModel(BertPreTrainedModel):\n    \"\"\"\n    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of\n    cross-attention is added between the self-attention layers, following the architecture described in `Attention is\n    all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,\n    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.\n    argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an\n    input to the forward pass.\n    \"\"\"\n\n    def __init__(self, config, add_pooling_layer=False):\n        super().__init__(config)\n        self.config = config\n\n        self.embeddings = BertEmbeddings(config)\n\n        self.encoder = BertEncoder(config)\n\n        self.pooler = BertPooler(config) if add_pooling_layer else None\n\n        self.init_weights()\n\n    def get_input_embeddings(self):\n        return self.embeddings.word_embeddings\n\n    def set_input_embeddings(self, value):\n        self.embeddings.word_embeddings = value\n\n    def _prune_heads(self, heads_to_prune):\n        \"\"\"\n        Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base\n        class PreTrainedModel\n        \"\"\"\n        for layer, heads in heads_to_prune.items():\n            self.encoder.layer[layer].attention.prune_heads(heads)\n\n    def get_extended_attention_mask(\n        self,\n        attention_mask: Tensor,\n        input_shape: Tuple[int],\n        device: device,\n        is_decoder: bool,\n        has_query: bool = False,\n    ) -> Tensor:\n        \"\"\"\n        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.\n\n        Arguments:\n            attention_mask (:obj:`torch.Tensor`):\n                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.\n            input_shape (:obj:`Tuple[int]`):\n                The shape of the input to the model.\n            device: (:obj:`torch.device`):\n                The device of the input to the model.\n\n        Returns:\n            :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.\n        \"\"\"\n        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]\n        # ourselves in which case we just need to make it broadcastable to all heads.\n        if attention_mask.dim() == 3:\n            extended_attention_mask = attention_mask[:, None, :, :]\n        elif attention_mask.dim() == 2:\n            # Provided a padding mask of dimensions [batch_size, seq_length]\n            # - if the model is a decoder, apply a causal mask in addition to the padding mask\n            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]\n            if is_decoder:\n                batch_size, seq_length = input_shape\n\n                seq_ids = torch.arange(seq_length, device=device)\n                causal_mask = (\n                    seq_ids[None, None, :].repeat(batch_size, seq_length, 1)\n                    <= seq_ids[None, :, None]\n                )\n\n                # add a prefix ones mask to the causal mask\n                # causal and attention masks must have same type with pytorch version < 1.3\n                causal_mask = causal_mask.to(attention_mask.dtype)\n\n                if causal_mask.shape[1] < attention_mask.shape[1]:\n                    prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]\n                    if has_query:  # UniLM style attention mask\n                        causal_mask = torch.cat(\n                            [\n                                torch.zeros(\n                                    (batch_size, prefix_seq_len, seq_length),\n                                    device=device,\n                                    dtype=causal_mask.dtype,\n                                ),\n                                causal_mask,\n                            ],\n                            axis=1,\n                        )\n                    causal_mask = torch.cat(\n                        [\n                            torch.ones(\n                                (batch_size, causal_mask.shape[1], prefix_seq_len),\n                                device=device,\n                                dtype=causal_mask.dtype,\n                            ),\n                            causal_mask,\n                        ],\n                        axis=-1,\n                    )\n                extended_attention_mask = (\n                    
causal_mask[:, None, :, :] * attention_mask[:, None, None, :]\n                )\n            else:\n                extended_attention_mask = attention_mask[:, None, None, :]\n        else:\n            raise ValueError(\n                \"Wrong shape for input_ids (shape {}) or attention_mask (shape {})\".format(\n                    input_shape, attention_mask.shape\n                )\n            )\n\n        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for\n        # masked positions, this operation will create a tensor which is 0.0 for\n        # positions we want to attend and -10000.0 for masked positions.\n        # Since we are adding it to the raw scores before the softmax, this is\n        # effectively the same as removing these entirely.\n        extended_attention_mask = extended_attention_mask.to(\n            dtype=self.dtype\n        )  # fp16 compatibility\n        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0\n        return extended_attention_mask\n\n    def forward(\n        self,\n        input_ids=None,\n        attention_mask=None,\n        position_ids=None,\n        head_mask=None,\n        query_embeds=None,\n        encoder_hidden_states=None,\n        encoder_attention_mask=None,\n        past_key_values=None,\n        use_cache=None,\n        output_attentions=None,\n        output_hidden_states=None,\n        return_dict=None,\n        is_decoder=False,\n    ):\n        r\"\"\"\n        encoder_hidden_states  (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):\n            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if\n            the model is configured as a decoder.\n        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):\n            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in\n            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:\n            - 1 for tokens that are **not masked**,\n            - 0 for tokens that are **masked**.\n        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):\n            Contains precomputed key and value hidden states of the attention blocks. 
Can be used to speed up decoding.\n            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`\n            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`\n            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.\n        use_cache (:obj:`bool`, `optional`):\n            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up\n            decoding (see :obj:`past_key_values`).\n        \"\"\"\n        output_attentions = (\n            output_attentions\n            if output_attentions is not None\n            else self.config.output_attentions\n        )\n        output_hidden_states = (\n            output_hidden_states\n            if output_hidden_states is not None\n            else self.config.output_hidden_states\n        )\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        # use_cache = use_cache if use_cache is not None else self.config.use_cache\n\n        if input_ids is None:\n            assert (\n                query_embeds is not None\n            ), \"You have to specify query_embeds when input_ids is None\"\n\n        # past_key_values_length\n        past_key_values_length = (\n            past_key_values[0][0].shape[2] - self.config.query_length\n            if past_key_values is not None\n            else 0\n        )\n\n        query_length = query_embeds.shape[1] if query_embeds is not None else 0\n\n        embedding_output = self.embeddings(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            query_embeds=query_embeds,\n            past_key_values_length=past_key_values_length,\n        )\n\n        input_shape = embedding_output.size()[:-1]\n        batch_size, seq_length = input_shape\n        device = embedding_output.device\n\n        if attention_mask is None:\n            attention_mask = torch.ones(\n                ((batch_size, seq_length + past_key_values_length)), device=device\n            )\n\n        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]\n        # ourselves in which case we just need to make it broadcastable to all heads.\n        if is_decoder:\n            extended_attention_mask = self.get_extended_attention_mask(\n                attention_mask,\n                input_ids.shape,\n                device,\n                is_decoder,\n                has_query=(query_embeds is not None),\n            )\n        else:\n            extended_attention_mask = self.get_extended_attention_mask(\n                attention_mask, input_shape, device, is_decoder\n            )\n\n        # If a 2D or 3D attention mask is provided for the cross-attention\n        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]\n        if encoder_hidden_states is not None:\n            if type(encoder_hidden_states) == list:\n                encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[\n                    0\n                ].size()\n            else:\n                (\n                    encoder_batch_size,\n                    encoder_sequence_length,\n                    _,\n                ) = encoder_hidden_states.size()\n            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)\n\n            if 
type(encoder_attention_mask) == list:\n                encoder_extended_attention_mask = [\n                    self.invert_attention_mask(mask) for mask in encoder_attention_mask\n                ]\n            elif encoder_attention_mask is None:\n                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)\n                encoder_extended_attention_mask = self.invert_attention_mask(\n                    encoder_attention_mask\n                )\n            else:\n                encoder_extended_attention_mask = self.invert_attention_mask(\n                    encoder_attention_mask\n                )\n        else:\n            encoder_extended_attention_mask = None\n\n        # Prepare head mask if needed\n        # 1.0 in head_mask indicate we keep the head\n        # attention_probs has shape bsz x n_heads x N x N\n        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]\n        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]\n        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)\n\n        encoder_outputs = self.encoder(\n            embedding_output,\n            attention_mask=extended_attention_mask,\n            head_mask=head_mask,\n            encoder_hidden_states=encoder_hidden_states,\n            encoder_attention_mask=encoder_extended_attention_mask,\n            past_key_values=past_key_values,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n            query_length=query_length,\n        )\n        sequence_output = encoder_outputs[0]\n        pooled_output = (\n            self.pooler(sequence_output) if self.pooler is not None else None\n        )\n\n        if not return_dict:\n            return (sequence_output, pooled_output) + encoder_outputs[1:]\n\n        return BaseModelOutputWithPoolingAndCrossAttentions(\n            last_hidden_state=sequence_output,\n            pooler_output=pooled_output,\n            past_key_values=encoder_outputs.past_key_values,\n            hidden_states=encoder_outputs.hidden_states,\n            attentions=encoder_outputs.attentions,\n            cross_attentions=encoder_outputs.cross_attentions,\n        )\n\n\nclass BertLMHeadModel(BertPreTrainedModel):\n\n    _keys_to_ignore_on_load_unexpected = [r\"pooler\"]\n    _keys_to_ignore_on_load_missing = [r\"position_ids\", r\"predictions.decoder.bias\"]\n\n    def __init__(self, config):\n        super().__init__(config)\n\n        self.bert = BertModel(config, add_pooling_layer=False)\n        self.cls = BertOnlyMLMHead(config)\n\n        self.init_weights()\n\n    def get_output_embeddings(self):\n        return self.cls.predictions.decoder\n\n    def set_output_embeddings(self, new_embeddings):\n        self.cls.predictions.decoder = new_embeddings\n\n    def forward(\n        self,\n        input_ids=None,\n        attention_mask=None,\n        position_ids=None,\n        head_mask=None,\n        query_embeds=None,\n        encoder_hidden_states=None,\n        encoder_attention_mask=None,\n        labels=None,\n        past_key_values=None,\n        use_cache=True,\n        output_attentions=None,\n        output_hidden_states=None,\n        return_dict=None,\n        return_logits=False,\n        is_decoder=True,\n        reduction=\"mean\",\n    ):\n        r\"\"\"\n        encoder_hidden_states  
(:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):\n            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if\n            the model is configured as a decoder.\n        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):\n            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in\n            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:\n            - 1 for tokens that are **not masked**,\n            - 0 for tokens that are **masked**.\n        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):\n            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in\n            ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are\n            ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``\n        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):\n            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.\n            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`\n            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`\n            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.\n        use_cache (:obj:`bool`, `optional`):\n            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up\n            decoding (see :obj:`past_key_values`).\n        Returns:\n        Example::\n            >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig\n            >>> import torch\n            >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')\n            >>> config = BertConfig.from_pretrained(\"bert-base-cased\")\n            >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)\n            >>> inputs = tokenizer(\"Hello, my dog is cute\", return_tensors=\"pt\")\n            >>> outputs = model(**inputs)\n            >>> prediction_logits = outputs.logits\n        \"\"\"\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n        if labels is not None:\n            use_cache = False\n        if past_key_values is not None:\n            query_embeds = None\n\n        outputs = self.bert(\n            input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            head_mask=head_mask,\n            query_embeds=query_embeds,\n            encoder_hidden_states=encoder_hidden_states,\n            encoder_attention_mask=encoder_attention_mask,\n            past_key_values=past_key_values,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n            is_decoder=is_decoder,\n        )\n\n        sequence_output = 
outputs[0]\n        if query_embeds is not None:\n            sequence_output = outputs[0][:, query_embeds.shape[1] :, :]\n\n        prediction_scores = self.cls(sequence_output)\n\n        if return_logits:\n            return prediction_scores[:, :-1, :].contiguous()\n\n        lm_loss = None\n        if labels is not None:\n            # we are doing next-token prediction; shift prediction scores and input ids by one\n            shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()\n            labels = labels[:, 1:].contiguous()\n            loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)\n            lm_loss = loss_fct(\n                shifted_prediction_scores.view(-1, self.config.vocab_size),\n                labels.view(-1),\n            )\n            if reduction == \"none\":\n                lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)\n\n        if not return_dict:\n            output = (prediction_scores,) + outputs[2:]\n            return ((lm_loss,) + output) if lm_loss is not None else output\n\n        return CausalLMOutputWithCrossAttentions(\n            loss=lm_loss,\n            logits=prediction_scores,\n            past_key_values=outputs.past_key_values,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n            cross_attentions=outputs.cross_attentions,\n        )\n\n    def prepare_inputs_for_generation(\n        self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs\n    ):\n        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly\n        if attention_mask is None:\n            attention_mask = input_ids.new_ones(input_ids.shape)\n        query_mask = input_ids.new_ones(query_embeds.shape[:-1])\n        attention_mask = torch.cat([query_mask, attention_mask], dim=-1)\n\n        # cut decoder_input_ids if past is used\n        if past is not None:\n            input_ids = input_ids[:, -1:]\n\n        return {\n            \"input_ids\": input_ids,\n            \"query_embeds\": query_embeds,\n            \"attention_mask\": attention_mask,\n            \"past_key_values\": past,\n            \"encoder_hidden_states\": model_kwargs.get(\"encoder_hidden_states\", None),\n            \"encoder_attention_mask\": model_kwargs.get(\"encoder_attention_mask\", None),\n            \"is_decoder\": True,\n        }\n\n    def _reorder_cache(self, past, beam_idx):\n        reordered_past = ()\n        for layer_past in past:\n            reordered_past += (\n                tuple(\n                    past_state.index_select(0, beam_idx) for past_state in layer_past\n                ),\n            )\n        return reordered_past\n\n\nclass BertForMaskedLM(BertPreTrainedModel):\n\n    _keys_to_ignore_on_load_unexpected = [r\"pooler\"]\n    _keys_to_ignore_on_load_missing = [r\"position_ids\", r\"predictions.decoder.bias\"]\n\n    def __init__(self, config):\n        super().__init__(config)\n\n        self.bert = BertModel(config, add_pooling_layer=False)\n        self.cls = BertOnlyMLMHead(config)\n\n        self.init_weights()\n\n    def get_output_embeddings(self):\n        return self.cls.predictions.decoder\n\n    def set_output_embeddings(self, new_embeddings):\n        self.cls.predictions.decoder = new_embeddings\n\n    def forward(\n        self,\n        input_ids=None,\n        attention_mask=None,\n        position_ids=None,\n        head_mask=None,\n        query_embeds=None,\n   
     encoder_hidden_states=None,\n        encoder_attention_mask=None,\n        labels=None,\n        output_attentions=None,\n        output_hidden_states=None,\n        return_dict=None,\n        return_logits=False,\n        is_decoder=False,\n    ):\n        r\"\"\"\n        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):\n            Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,\n            config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored\n            (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``\n        \"\"\"\n\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        outputs = self.bert(\n            input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            head_mask=head_mask,\n            query_embeds=query_embeds,\n            encoder_hidden_states=encoder_hidden_states,\n            encoder_attention_mask=encoder_attention_mask,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n            is_decoder=is_decoder,\n        )\n\n        sequence_output = outputs[0]\n        if query_embeds is not None:\n            # keep only the text positions when query embeddings are prepended\n            sequence_output = outputs[0][:, query_embeds.shape[1] :, :]\n        prediction_scores = self.cls(sequence_output)\n\n        if return_logits:\n            return prediction_scores\n\n        masked_lm_loss = None\n        if labels is not None:\n            loss_fct = CrossEntropyLoss()  # -100 index = padding token\n            masked_lm_loss = loss_fct(\n                prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)\n            )\n\n        if not return_dict:\n            output = (prediction_scores,) + outputs[2:]\n            return (\n                ((masked_lm_loss,) + output) if masked_lm_loss is not None else output\n            )\n\n        return MaskedLMOutput(\n            loss=masked_lm_loss,\n            logits=prediction_scores,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n
  },
  {
    "path": "lavis/models/blip2_models/__init__.py",
    "content": ""
  },
  {
    "path": "lavis/models/blip2_models/blip2.py",
    "content": "\"\"\"\n Copyright (c) 2023, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\nimport logging\nimport os\nimport time\nimport datetime\n\nimport torch\nimport torch.nn as nn\nimport torch.distributed as dist\nimport torch.nn.functional as F\n\nimport lavis.common.dist_utils as dist_utils\nfrom lavis.common.dist_utils import download_cached_file\nfrom lavis.common.utils import is_url\nfrom lavis.common.logger import MetricLogger\nfrom lavis.models.base_model import BaseModel\nfrom lavis.models.blip2_models.Qformer import BertConfig, BertLMHeadModel\nfrom lavis.models.eva_vit import create_eva_vit_g\nfrom transformers import BertTokenizer\n\n\nclass Blip2Base(BaseModel):\n    @classmethod\n    def init_tokenizer(cls):\n        tokenizer = BertTokenizer.from_pretrained(\"bert-base-uncased\")\n        tokenizer.add_special_tokens({\"bos_token\": \"[DEC]\"})\n        return tokenizer\n\n    @classmethod\n    def init_Qformer(cls, num_query_token, vision_width):\n        encoder_config = BertConfig.from_pretrained(\"bert-base-uncased\")\n        encoder_config.encoder_width = vision_width\n        # insert cross-attention layer every other block\n        encoder_config.add_cross_attention = True\n        encoder_config.cross_attention_freq = 2\n        encoder_config.query_length = num_query_token\n        Qformer = BertLMHeadModel.from_pretrained(\n            \"bert-base-uncased\", config=encoder_config\n        )                 \n        query_tokens = nn.Parameter(\n            torch.zeros(1, num_query_token, encoder_config.hidden_size)\n        )\n        query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)\n        return Qformer, query_tokens\n    \n    @classmethod\n    def init_TemporalQFormer(cls, num_of_frame):\n        encoder_config = BertConfig.from_pretrained(\"bert-base-uncased\")\n        encoder_config.query_length = num_of_frame\n        Qformer = BertLMHeadModel.from_pretrained(\n        \"bert-base-uncased\", config=encoder_config\n        )                 \n        query_tokens = nn.Parameter(\n            torch.zeros(1, num_of_frame, 1, encoder_config.hidden_size)\n        )\n        query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)\n        return Qformer, query_tokens\n\n    @classmethod\n    def init_vision_encoder(\n        cls, img_size, drop_path_rate, use_grad_checkpoint, precision\n    ):\n        visual_encoder = create_eva_vit_g(\n            img_size, drop_path_rate, use_grad_checkpoint, precision\n        )\n        ln_vision = LayerNorm(visual_encoder.num_features)\n        return visual_encoder, ln_vision\n\n    @classmethod\n    def init_vision_encoder_sevila(\n        cls, img_size, drop_path_rate, use_grad_checkpoint, precision\n    ):\n        visual_encoder = create_eva_vit_g(\n            img_size, drop_path_rate, use_grad_checkpoint, precision\n        )\n        ln_vision = LayerNorm(visual_encoder.num_features)\n        ln_vision2 = LayerNorm(visual_encoder.num_features) \n        return visual_encoder, ln_vision, ln_vision2\n\n    def load_from_pretrained(self, url_or_filename):\n        if is_url(url_or_filename):\n            cached_file = download_cached_file(\n                url_or_filename, check_hash=False, progress=True\n            )\n            checkpoint = torch.load(cached_file, map_location=\"cpu\")\n        elif 
os.path.isfile(url_or_filename):\n            checkpoint = torch.load(url_or_filename, map_location=\"cpu\")\n        else:\n            raise RuntimeError(\"checkpoint url or path is invalid\")\n\n        state_dict = checkpoint[\"model\"]\n        #print('state_dict',state_dict.keys())\n        msg = self.load_state_dict(state_dict, strict=False)\n\n        logging.info(\"Missing keys {}\".format(msg.missing_keys))\n        logging.info(\"load checkpoint from %s\" % url_or_filename)\n\n        return msg\n    \n    def load_qformer_loc(self):\n        url_or_filename = '/nas-hdd/shoubin/pretrained_model/hub/checkpoints/qformer_loc.pth'\n        checkpoint = torch.load(url_or_filename, map_location=\"cpu\")\n        state_dict = checkpoint[\"model\"]\n        msg = self.load_state_dict(state_dict, strict=False)\n        logging.info(\"load checkpoint from %s\" % url_or_filename)\n\ndef disabled_train(self, mode=True):\n    \"\"\"Overwrite model.train with this function to make sure train/eval mode\n    does not change anymore.\"\"\"\n    return self\n\n\nclass LayerNorm(nn.LayerNorm):\n    \"\"\"Subclass torch's LayerNorm to handle fp16.\"\"\"\n\n    def forward(self, x: torch.Tensor):\n        orig_type = x.dtype\n        ret = super().forward(x.type(torch.float32))\n        return ret.type(orig_type)\n\n\ndef compute_sim_matrix(model, data_loader, **kwargs):\n    k_test = kwargs.pop(\"k_test\")\n\n    metric_logger = MetricLogger(delimiter=\"  \")\n    header = \"Evaluation:\"\n\n    logging.info(\"Computing features for evaluation...\")\n    start_time = time.time()\n\n    texts = data_loader.dataset.text\n    num_text = len(texts)\n    text_bs = 256\n    text_ids = []\n    text_embeds = []\n    text_atts = []\n    for i in range(0, num_text, text_bs):\n        text = texts[i : min(num_text, i + text_bs)]\n        text_input = model.tokenizer(\n            text,\n            padding=\"max_length\",\n            truncation=True,\n            max_length=35,\n            return_tensors=\"pt\",\n        ).to(model.device)\n        text_feat = model.forward_text(text_input)\n        text_embed = F.normalize(model.text_proj(text_feat))\n        text_embeds.append(text_embed)\n        text_ids.append(text_input.input_ids)\n        text_atts.append(text_input.attention_mask)\n\n    text_embeds = torch.cat(text_embeds, dim=0)\n    text_ids = torch.cat(text_ids, dim=0)\n    text_atts = torch.cat(text_atts, dim=0)\n\n    vit_feats = []\n    image_embeds = []\n    for samples in data_loader:\n        image = samples[\"image\"]\n\n        image = image.to(model.device)\n        image_feat, vit_feat = model.forward_image(image)\n        image_embed = model.vision_proj(image_feat)\n        image_embed = F.normalize(image_embed, dim=-1)\n\n        vit_feats.append(vit_feat.cpu())\n        image_embeds.append(image_embed)\n\n    vit_feats = torch.cat(vit_feats, dim=0)\n    image_embeds = torch.cat(image_embeds, dim=0)\n\n    sims_matrix = []\n    for image_embed in image_embeds:\n        sim_q2t = image_embed @ text_embeds.t()\n        sim_i2t, _ = sim_q2t.max(0)\n        sims_matrix.append(sim_i2t)\n    sims_matrix = torch.stack(sims_matrix, dim=0)\n\n    score_matrix_i2t = torch.full(\n        (len(data_loader.dataset.image), len(texts)), -100.0\n    ).to(model.device)\n\n    num_tasks = dist_utils.get_world_size()\n    rank = dist_utils.get_rank()\n    step = sims_matrix.size(0) // num_tasks + 1\n    start = rank * step\n    end = min(sims_matrix.size(0), start + step)\n\n    for i, sims in 
enumerate(\n        metric_logger.log_every(sims_matrix[start:end], 50, header)\n    ):\n        topk_sim, topk_idx = sims.topk(k=k_test, dim=0)\n        image_inputs = vit_feats[start + i].repeat(k_test, 1, 1).to(model.device)\n        score = model.compute_itm(\n            image_inputs=image_inputs,\n            text_ids=text_ids[topk_idx],\n            text_atts=text_atts[topk_idx],\n        ).float()\n        score_matrix_i2t[start + i, topk_idx] = score + topk_sim\n\n    sims_matrix = sims_matrix.t()\n    score_matrix_t2i = torch.full(\n        (len(texts), len(data_loader.dataset.image)), -100.0\n    ).to(model.device)\n\n    step = sims_matrix.size(0) // num_tasks + 1\n    start = rank * step\n    end = min(sims_matrix.size(0), start + step)\n\n    for i, sims in enumerate(\n        metric_logger.log_every(sims_matrix[start:end], 50, header)\n    ):\n        topk_sim, topk_idx = sims.topk(k=k_test, dim=0)\n        image_inputs = vit_feats[topk_idx.cpu()].to(model.device)\n        score = model.compute_itm(\n            image_inputs=image_inputs,\n            text_ids=text_ids[start + i].repeat(k_test, 1),\n            text_atts=text_atts[start + i].repeat(k_test, 1),\n        ).float()\n        score_matrix_t2i[start + i, topk_idx] = score + topk_sim\n\n    if dist_utils.is_dist_avail_and_initialized():\n        dist.barrier()\n        torch.distributed.all_reduce(\n            score_matrix_i2t, op=torch.distributed.ReduceOp.SUM\n        )\n        torch.distributed.all_reduce(\n            score_matrix_t2i, op=torch.distributed.ReduceOp.SUM\n        )\n\n    total_time = time.time() - start_time\n    total_time_str = str(datetime.timedelta(seconds=int(total_time)))\n    logging.info(\"Evaluation time {}\".format(total_time_str))\n\n    return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy()\n"
  },
  {
    "path": "lavis/models/blip2_models/blip2_fmr.py",
    "content": "\"\"\"\n Copyright (c) 2023, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\nimport logging\n\nimport copy\nimport torch\nimport torch.nn as nn\nfrom torch.cuda.amp import autocast as autocast\nfrom transformers import T5TokenizerFast, BertTokenizer\n\nfrom lavis.common.registry import registry\nfrom lavis.models.blip2_models.blip2 import Blip2Base, disabled_train\nfrom lavis.models.blip2_models.modeling_t5 import T5Config, T5ForConditionalGeneration\n\n@registry.register_model(\"blip2_fmr\") # frame-level moment retrieval\nclass Blip2FMR(Blip2Base):\n    \"\"\"\n    BLIP2 T5 model.\n    Supported model types:\n        - pretrain_flant5xl: pretrained model with FlanT5-XL\n        - pretrain_flant5xxl: pretrained model with FlanT5-XXL\n        - caption_coco_flant5xl: fintuned image captioning model with FlanT5-XL\n    Usage:\n        >>> from lavis.models import load_model\n        >>> model = load_model(\"blip2_t5\", \"pretrain_flant5xl\")\n    \"\"\"\n\n    PRETRAINED_MODEL_CONFIG_DICT = {\n        \"pretrain_flant5xl\": \"configs/models/blip2/blip2_pretrain_flant5xl.yaml\",\n        \"pretrain_flant5xxl\": \"configs/models/blip2/blip2_pretrain_flant5xxl.yaml\",\n        \"caption_coco_flant5xl\": \"configs/models/blip2/blip2_caption_flant5xl.yaml\",\n    }\n\n    def __init__( self, img_size=224, drop_path_rate=0,\n        use_grad_checkpoint=False, vit_precision=\"fp16\", freeze_vit=True,\n        num_query_token=32, t5_model=\"google/flan-t5-xl\", prompt=\"\",\n        max_txt_len=32, frame_num=8, answer_num=5, apply_lemmatizer=False, task='qa'):\n        \"\"\"\n        apply_lemmatizer: when set to True, postprocess predict_answers() result with lemmas.\n        \"\"\"\n        super().__init__()\n        \n        self.task = task\n        \n        # vision backbone\n        self.visual_encoder, self.ln_vision_loc = self.init_vision_encoder(\n        img_size, drop_path_rate, use_grad_checkpoint, vit_precision)\n        # Freeze ViT\n        if freeze_vit:\n            for name, param in self.visual_encoder.named_parameters():\n                param.requires_grad = False         \n            self.visual_encoder = self.visual_encoder.eval()\n            self.visual_encoder.train = disabled_train\n            logging.info(\"freeze vision encoder\")\n            \n        # text backbone\n        self.t5_tokenizer = T5TokenizerFast.from_pretrained(t5_model)\n        t5_config = T5Config.from_pretrained(t5_model)\n        t5_config.dense_act_fn = \"gelu\"\n        self.t5_model = T5ForConditionalGeneration.from_pretrained(\n        t5_model, config=t5_config)\n        # Freeze T5\n        for name, param in self.t5_model.named_parameters():\n            param.requires_grad = False\n            param.data = param.data.bfloat16() \n        \n        # Q-Former for Frame Localization\n        self.Qformer_loc, self.query_tokens_loc = self.init_Qformer(\n        num_query_token, self.visual_encoder.num_features)\n\n        self.Qformer_loc.cls = None\n        self.Qformer_loc.bert.embeddings.word_embeddings = None\n        self.Qformer_loc.bert.embeddings.position_embeddings = None\n        for layer in self.Qformer_loc.bert.encoder.layer:\n            layer.output = None\n            layer.intermediate = None\n        self.t5_proj_loc = nn.Linear(\n        self.Qformer_loc.config.hidden_size, 
self.t5_model.config.hidden_size\n        )\n            \n        self.max_txt_len = 77\n        #self.prompt = prompt\n        answer_id = [71, 272, 205, 309, 262] # A B C D E\n        self.answer_id = answer_id[:answer_num]\n        # self.answer_id = [71, 272]\n        self.yes_id, self.no_id = 4273, 150\n        \n        self._apply_lemmatizer = apply_lemmatizer\n        self._lemmatizer = None\n        \n        self.frame_num = frame_num\n        self.ANS_MAP = {'A':0, 'B':1, 'C':2, 'D':3, 'E':4}\n        self.frame_prefix = ['Frame: ']\n            \n    def forward(self, samples):\n\n        image = samples[\"video\"]\n        text_input = samples['loc_input'] # query + options + Prompt\n        bs_answer = samples['qa_output'] # yes or no\n        flat_answer = []\n        for answer in bs_answer:\n            answer = answer.split('_')\n            for a in answer:\n                flat_answer.append(a)\n        \n        b, t, c, w, h = image.shape \n        image = image.reshape(-1, c, w, h)\n        image_embeds = self.ln_vision_loc(self.visual_encoder(image)) # bt, n, c\n        _, n, _ = image_embeds.shape\n        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device) # bt n c\n        \n        #pass\n        query_tokens = self.query_tokens_loc.expand(image_embeds.shape[0], -1, -1)\n        query_output = self.Qformer_loc.bert(\n            query_embeds=query_tokens, encoder_hidden_states=image_embeds,\n            encoder_attention_mask=image_atts, return_dict=True)\n        inputs_t5 = self.t5_proj_loc(query_output.last_hidden_state)\n        atts_t5 = torch.ones(inputs_t5.size()[:-1], dtype=torch.long).to(image.device) \n\n        with torch.cuda.amp.autocast(dtype=torch.bfloat16):\n            # Frame Prefix\n            frame_prefix = self.t5_tokenizer(\n                self.frame_prefix, padding=\"longest\", add_special_tokens=False,\n                truncation=True, max_length=self.max_txt_len, return_tensors=\"pt\",\n            ).to(image.device) # \n            # print('frame_prefix 1', frame_prefix.input_ids.shape) 8, 4\n            frame_prefix_id = torch.repeat_interleave(frame_prefix.input_ids, b*t, 0)\n            frame_prefix_mask = torch.repeat_interleave(frame_prefix.attention_mask, b*t, 0)\n            # Question, Options input\n            input_tokens = self.t5_tokenizer(\n                text_input, padding=\"longest\", truncation=True,\n                max_length=self.max_txt_len, return_tensors=\"pt\").to(image.device)\n            input_ids = torch.repeat_interleave(input_tokens.input_ids, t, 0)\n            input_attention_mask = torch.repeat_interleave(input_tokens.attention_mask, t, 0)\n\n            # Output target\n            output_tokens = self.t5_tokenizer(\n                flat_answer, padding=\"longest\", truncation=True,\n                max_length=self.max_txt_len, return_tensors=\"pt\").to(image.device)\n            targets = output_tokens.input_ids.masked_fill(\n                output_tokens.input_ids == self.t5_tokenizer.pad_token_id, -100)\n            output_tokens_mask = output_tokens.attention_mask #torch.repeat_interleave(output_tokens.attention_mask, t, dim=0)\n            #targets = torch.repeat_interleave(targets, t, dim=0)\n            # input for QA\n            frame_predix_embed = self.t5_model.encoder.embed_tokens(frame_prefix_id)\n            inputs_embeds = self.t5_model.encoder.embed_tokens(input_ids)\n            inputs_embeds = torch.cat([frame_predix_embed, inputs_t5, 
inputs_embeds], dim=1)\n            encoder_atts = torch.cat([frame_prefix_mask, atts_t5, input_attention_mask], dim=1)\n\n            outputs = self.t5_model(\n                inputs_embeds=inputs_embeds, attention_mask=encoder_atts,\n                decoder_attention_mask=output_tokens_mask, return_dict=True, labels=targets)\n            loss = outputs.loss\n                \n        return {\"loss\": loss}\n        \n    @torch.no_grad()\n    def generate(self,\n        samples,\n        use_nucleus_sampling=False,\n        num_beams=5, max_length=30,\n        min_length=1, top_p=0.9,\n        repetition_penalty=1.0, length_penalty=1.0,\n        num_captions=1, temperature=1,):\n        \"\"\"\n        Args:\n            samples (dict): A dictionary containing the following keys:\n                - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W)\n            use_nucleus_sampling (bool): Whether to use nucleus sampling. If False, use top-k sampling.\n            num_beams (int): Number of beams for beam search. 1 means no beam search.\n            max_length (int): The maximum length of the sequence to be generated.\n            min_length (int): The minimum length of the sequence to be generated.\n            top_p (float): The cumulative probability for nucleus sampling.\n            repetition_penalty (float): The parameter for repetition penalty. 1.0 means no penalty.\n            num_captions (int): Number of captions to be generated for each image.\n        Returns:\n            captions (list): A list of strings of length batch_size * num_captions.\n        \"\"\"\n        out = {}\n        image, qid = samples[\"video\"], samples['question_id']\n        text_input, bs_answer = samples['loc_input'], samples['qa_output']  # Q + Options + Prompt: Choose an answer from options based on the frame.\n        # print('text_input', text_input)\n        flat_answer = []\n        # print('bs_answer', bs_answer)\n        for answer in bs_answer:\n            answer = answer.split('_')\n            for a in answer:\n                flat_answer.append(a)\n        # print('flat_answer', flat_answer)\n        \n        b, t, c, w, h = image.shape       \n        image = image.reshape(-1, c, w, h)\n        with torch.cuda.amp.autocast(enabled=(self.device != torch.device(\"cpu\"))):\n            image_embeds = self.ln_vision_loc(self.visual_encoder(image)) # bt, n, c\n        _, n, _ = image_embeds.shape\n        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device) # bt n c\n        \n        query_tokens = self.query_tokens_loc.expand(image_embeds.shape[0], -1, -1)\n        query_output = self.Qformer_loc.bert(\n            query_embeds=query_tokens, encoder_hidden_states=image_embeds,\n            encoder_attention_mask=image_atts, return_dict=True)\n        inputs_t5 = self.t5_proj_loc(query_output.last_hidden_state)\n        atts_t5 = torch.ones(inputs_t5.size()[:-1], dtype=torch.long).to(image.device) \n            \n        with torch.cuda.amp.autocast(dtype=torch.bfloat16):\n            \n            frame_prefix = self.t5_tokenizer(\n                self.frame_prefix, padding=\"longest\", add_special_tokens=False,\n                truncation=True, max_length=self.max_txt_len, return_tensors=\"pt\",\n                ).to(image.device) # \n            #print('frame_prefix 1', frame_prefix.input_ids.shape) 8, 4\n            frame_prefix_id = torch.repeat_interleave(frame_prefix.input_ids, b*t, 0)\n            frame_prefix_mask = 
torch.repeat_interleave(frame_prefix.attention_mask, b*t, 0)\n            # Question, Options input\n            input_tokens = self.t5_tokenizer(\n                text_input, padding=\"longest\", truncation=True,\n                max_length=self.max_txt_len, return_tensors=\"pt\").to(image.device)\n            input_ids = torch.repeat_interleave(input_tokens.input_ids, t, 0)\n            input_attention_mask = torch.repeat_interleave(input_tokens.attention_mask, t, 0)\n\n            frame_predix_embed = self.t5_model.encoder.embed_tokens(frame_prefix_id)\n            inputs_embeds = self.t5_model.encoder.embed_tokens(input_ids)\n            inputs_embeds = torch.cat([frame_predix_embed, inputs_t5, inputs_embeds], dim=1)\n            encoder_atts = torch.cat([frame_prefix_mask, atts_t5, input_attention_mask], dim=1)\n\n            outputs = self.t5_model.generate(\n                inputs_embeds=inputs_embeds, attention_mask=encoder_atts,\n                do_sample=use_nucleus_sampling, top_p=top_p,\n                temperature=temperature, num_beams=1,\n                max_new_tokens=max_length, min_length=min_length,\n                repetition_penalty=repetition_penalty, length_penalty=length_penalty,\n                num_return_sequences=num_captions, return_dict_in_generate=True,\n                output_hidden_states=True, output_scores=True)\n                # print('answer', answer)\n            pred_logits = outputs.scores[0] #outputs_embed_qa.logits.detach()\n            pred_logits = pred_logits[:, [self.no_id, self.yes_id]] # b, 5\n            pred_yes_score = pred_logits[:, 1].cpu().tolist() \n            pred_ans = torch.argmax(pred_logits, dim=-1).cpu().tolist()\n                     \n        out['answer'] = flat_answer\n        multiframe_qid = []\n        for q in qid:\n            for i in range(t):\n                multiframe_qid.append(q)\n                \n        out['qid'] = multiframe_qid\n        out['yes_score'] = pred_yes_score\n        out['pred_ans'] = pred_ans \n        \n        return out\n\n    def predict_answers(\n        self,\n        samples,\n        num_beams=5,\n        inference_method=\"generate\",\n        max_len=10,\n        min_len=1,\n        num_ans_candidates=128,\n        answer_list=None,\n        prompt=\"\",\n        length_penalty=-1,\n        **kwargs\n    ):\n        image = samples[\"image\"]\n        with torch.cuda.amp.autocast(enabled=(self.device != torch.device(\"cpu\"))):\n            image_embeds = self.ln_vision_loc(self.visual_encoder(image))\n        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(\n            image.device\n        )\n\n        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)\n        query_output = self.Qformer.bert(\n            query_embeds=query_tokens,\n            encoder_hidden_states=image_embeds,\n            encoder_attention_mask=image_atts,\n            return_dict=True,\n        )\n\n        inputs_t5 = self.t5_proj(query_output.last_hidden_state)\n        atts_t5 = torch.ones(inputs_t5.size()[:-1], dtype=torch.long).to(image.device)\n\n        if isinstance(samples[\"text_input\"], str):\n            samples[\"text_input\"] = [samples[\"text_input\"]]\n        if prompt:\n            text_input = [prompt.format(question) for question in samples[\"text_input\"]]\n        else:\n            text_input = samples[\"text_input\"]\n\n        input_tokens = self.t5_tokenizer(\n            text_input, padding=\"longest\", return_tensors=\"pt\"\n        
).to(image.device)\n\n        encoder_atts = torch.cat([atts_t5, input_tokens.attention_mask], dim=1)\n\n        device_type = \"cuda\" if \"cuda\" in str(self.device) else \"cpu\"\n        with torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16):\n            inputs_embeds = self.t5_model.encoder.embed_tokens(input_tokens.input_ids)\n            inputs_embeds = torch.cat([inputs_t5, inputs_embeds], dim=1)\n\n            outputs = self.t5_model.generate(\n                inputs_embeds=inputs_embeds,\n                attention_mask=encoder_atts,\n                do_sample=False,\n                num_beams=num_beams,\n                max_new_tokens=max_len,\n                min_length=min_len,\n                length_penalty=length_penalty,\n            )\n            output_text = self.t5_tokenizer.batch_decode(\n                outputs, skip_special_tokens=True\n            )\n\n        if self._apply_lemmatizer:\n            output_text = self._lemmatize(output_text)\n\n        return output_text\n\n    def _lemmatize(self, answers):\n        def apply(answer):\n            doc = self.lemmatizer(answer)\n\n            words = []\n            for token in doc:\n                if token.pos_ in [\"NOUN\", \"VERB\"]:\n                    words.append(token.lemma_)\n                else:\n                    words.append(token.text)\n            answer = \" \".join(words)\n\n            return answer\n\n        return [apply(answer) for answer in answers]\n\n    @property\n    def lemmatizer(self):\n        if self._lemmatizer is None:\n            try:\n                import spacy\n\n                self._lemmatizer = spacy.load(\"en_core_web_sm\")\n            except ImportError:\n                logging.error(\n                    \"\"\"\n                    Please install spacy and en_core_web_sm model to apply lemmatization.\n                    python -m spacy download en_core_web_sm\n                    OR\n                    import spacy.cli\n                    spacy.cli.download(\"en_core_web_sm\")\n                    \"\"\"\n                )\n                exit(1)\n\n        return self._lemmatizer\n\n    @classmethod\n    def from_config(cls, cfg):\n        img_size = cfg.get(\"image_size\")\n        num_query_token = cfg.get(\"num_query_token\")\n        t5_model = cfg.get(\"t5_model\")\n\n        drop_path_rate = cfg.get(\"drop_path_rate\", 0)\n        use_grad_checkpoint = cfg.get(\"use_grad_checkpoint\", False)\n        vit_precision = cfg.get(\"vit_precision\", \"fp16\")\n        freeze_vit = cfg.get(\"freeze_vit\", True)\n\n        prompt = cfg.get(\"prompt\", \"\")\n        max_txt_len = cfg.get(\"max_txt_len\", 32)\n        frame_num = cfg.get(\"frame_num\", 8)\n        answer_num = cfg.get(\"answer_num\", 5) \n        apply_lemmatizer = cfg.get(\"apply_lemmatizer\", False)\n        task = cfg.get(\"task\", 'train_loc_freeze_qa')\n\n        model = cls(\n            img_size=img_size,\n            drop_path_rate=drop_path_rate,\n            use_grad_checkpoint=use_grad_checkpoint,\n            vit_precision=vit_precision,\n            freeze_vit=freeze_vit,\n            num_query_token=num_query_token,\n            t5_model=t5_model,\n            prompt=prompt,\n            max_txt_len=max_txt_len,\n            apply_lemmatizer=apply_lemmatizer,\n            frame_num=frame_num,\n            answer_num=answer_num,\n            task=task,\n        )\n        model.load_checkpoint_from_config(cfg)\n        # if 'pretrain_loc' in task:\n        # 
model.load_qformer_loc()\n\n        return model"
  },
  {
    "path": "lavis/models/blip2_models/blip2_image_text_matching.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport torch\nimport torch.nn.functional as F\nfrom lavis.common.registry import registry\nfrom lavis.models.blip2_models.blip2_qformer import Blip2Qformer\n\n\n@registry.register_model(\"blip2_image_text_matching\")\nclass Blip2ITM(Blip2Qformer):\n    \"\"\"\n    BLIP Image-Text Matching (ITM) model.\n    Supported model types:\n        - pretrained: pretrained model\n        - coco: fintuned model on coco\n    Usage:\n        >>> from lavis.models import load_model\n        >>> model = load_model(\"blip2_image_text_matching\", \"pretrained\")\n        >>> model = load_model(\"blip2_image_text_matching\", \"coco\")\n    \"\"\"\n\n    def __init__(\n        self,\n        img_size=224,\n        drop_path_rate=0,\n        use_grad_checkpoint=False,\n        vit_precision=\"fp16\",\n        freeze_vit=True,\n        num_query_token=32,\n        embed_dim=256,\n        max_txt_len=32,\n    ):\n        super().__init__(\n            img_size=img_size,\n            drop_path_rate=drop_path_rate,\n            use_grad_checkpoint=use_grad_checkpoint,\n            vit_precision=vit_precision,\n            freeze_vit=freeze_vit,\n            num_query_token=num_query_token,\n            embed_dim=embed_dim,\n            max_txt_len=max_txt_len,\n        )\n\n    def forward(self, samples, match_head=\"itm\"):\n        image = samples[\"image\"]\n        caption = samples[\"text_input\"]\n\n        with torch.cuda.amp.autocast(enabled=(self.device != torch.device(\"cpu\"))):\n            image_embeds = self.ln_vision(self.visual_encoder(image))\n        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(\n            image.device\n        )\n\n        text = self.tokenizer(\n            caption,\n            truncation=True,\n            max_length=self.max_txt_len,\n            return_tensors=\"pt\",\n        ).to(image.device)\n\n        if match_head == \"itm\":\n            query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)\n            query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(\n                image.device\n            )\n            attention_mask = torch.cat([query_atts, text.attention_mask], dim=1)\n            output_itm = self.Qformer.bert(\n                text.input_ids,\n                query_embeds=query_tokens,\n                attention_mask=attention_mask,\n                encoder_hidden_states=image_embeds,\n                encoder_attention_mask=image_atts,\n                return_dict=True,\n            )\n            itm_embeddings = output_itm.last_hidden_state[:, : query_tokens.size(1), :]\n            itm_logit = self.itm_head(itm_embeddings)\n            itm_logit = itm_logit.mean(dim=1)\n\n            return itm_logit\n\n        elif match_head == \"itc\":\n            query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)\n\n            query_output = self.Qformer.bert(\n                query_embeds=query_tokens,\n                encoder_hidden_states=image_embeds,\n                encoder_attention_mask=image_atts,\n                return_dict=True,\n            )\n            image_feats = F.normalize(\n                self.vision_proj(query_output.last_hidden_state), dim=-1\n            )\n\n            text_output = self.Qformer.bert(\n   
             text.input_ids,\n                attention_mask=text.attention_mask,\n                return_dict=True,\n            )\n            text_feat = F.normalize(\n                self.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1\n            )\n\n            sims = torch.bmm(image_feats, text_feat.unsqueeze(-1))\n            sim, _ = torch.max(sims, dim=1)\n\n            return sim\n"
  },
  {
    "path": "lavis/models/blip2_models/blip2_opt.py",
    "content": "\"\"\"\n Copyright (c) 2023, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\nimport logging\n\nimport torch\nfrom torch.cuda.amp import autocast as autocast\nimport torch.nn as nn\n\nfrom lavis.common.registry import registry\nfrom lavis.models.blip2_models.blip2 import Blip2Base, disabled_train\nfrom lavis.models.blip2_models.modeling_opt import OPTForCausalLM, OPTConfig\nfrom transformers import AutoTokenizer\n\n\n@registry.register_model(\"blip2_opt\")\nclass Blip2OPT(Blip2Base):\n    \"\"\"\n    BLIP2 OPT model.\n    Supported model types:\n        - pretrained_opt2.7b: pretrained model with OPT2.7b\n        - pretrained_opt6.7b: pretrained model with OPT6.7b\n        - caption_coco_opt2.7b: fintuned image captioning model with OPT2.7b\n        - caption_coco_opt6.7b: fintuned image captioning model with OPT6.7b\n    Usage:\n        >>> from lavis.models import load_model\n        >>> model = load_model(\"blip2_opt\", \"caption_coco_opt2.7b\")\n    \"\"\"\n\n    PRETRAINED_MODEL_CONFIG_DICT = {\n        \"pretrain_opt2.7b\": \"configs/models/blip2/blip2_pretrain_opt2.7b.yaml\",\n        \"pretrain_opt6.7b\": \"configs/models/blip2/blip2_pretrain_opt6.7b.yaml\",\n        \"caption_coco_opt2.7b\": \"configs/models/blip2/blip2_caption_opt2.7b.yaml\",\n        \"caption_coco_opt6.7b\": \"configs/models/blip2/blip2_caption_opt6.7b.yaml\",\n    }\n\n    def __init__(\n        self,\n        img_size=224,\n        drop_path_rate=0,\n        use_grad_checkpoint=False,\n        vit_precision=\"fp16\",\n        freeze_vit=True,\n        num_query_token=32,\n        opt_model=\"facebook/opt-2.7b\",\n        prompt=\"\",\n        max_txt_len=32,\n    ):\n        super().__init__()\n\n        self.tokenizer = self.init_tokenizer()\n\n        self.visual_encoder, self.ln_vision = self.init_vision_encoder(\n            img_size, drop_path_rate, use_grad_checkpoint, vit_precision\n        )\n        if freeze_vit:\n            for name, param in self.visual_encoder.named_parameters():\n                param.requires_grad = False               \n            self.visual_encoder = self.visual_encoder.eval()\n            self.visual_encoder.train = disabled_train\n            logging.info(\"freeze vision encoder\")\n\n        self.Qformer, self.query_tokens = self.init_Qformer(\n            num_query_token, self.visual_encoder.num_features\n        )\n        self.Qformer.cls = None\n        self.Qformer.bert.embeddings.word_embeddings = None\n        self.Qformer.bert.embeddings.position_embeddings = None\n        for layer in self.Qformer.bert.encoder.layer:\n            layer.output = None\n            layer.intermediate = None\n\n        self.opt_tokenizer = AutoTokenizer.from_pretrained(opt_model, use_fast=False)\n        self.opt_model = OPTForCausalLM.from_pretrained(\n            opt_model, torch_dtype=torch.float16\n        )\n        for name, param in self.opt_model.named_parameters():\n            param.requires_grad = False\n        self.eos_token_id = self.opt_tokenizer(\n            \"\\n\", add_special_tokens=False\n        ).input_ids[0]\n\n        self.opt_proj = nn.Linear(\n            self.Qformer.config.hidden_size, self.opt_model.config.hidden_size\n        )\n\n        self.max_txt_len = max_txt_len\n        self.prompt = prompt\n        prompt_tokens = self.opt_tokenizer(self.prompt, return_tensors=\"pt\")\n      
  self.prompt_length = prompt_tokens.attention_mask.sum(1)\n\n    def forward(self, samples):\n        image = samples[\"image\"]\n        image_embeds = self.ln_vision(self.visual_encoder(image))\n        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(\n            image.device\n        )\n\n        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)\n        query_output = self.Qformer.bert(\n            query_embeds=query_tokens,\n            encoder_hidden_states=image_embeds,\n            encoder_attention_mask=image_atts,\n            return_dict=True,\n        )\n\n        inputs_opt = self.opt_proj(query_output.last_hidden_state)\n        atts_opt = torch.ones(inputs_opt.size()[:-1], dtype=torch.long).to(image.device)\n\n        self.opt_tokenizer.padding_side = \"right\"\n\n        text = [t + \"\\n\" for t in samples[\"text_input\"]]\n\n        opt_tokens = self.opt_tokenizer(\n            text,\n            return_tensors=\"pt\",\n            padding=\"longest\",\n            truncation=True,\n            max_length=self.max_txt_len,\n        ).to(image.device)\n\n        targets = opt_tokens.input_ids.masked_fill(\n            opt_tokens.input_ids == self.opt_tokenizer.pad_token_id, -100\n        )\n        if self.prompt:\n            targets[:, : self.prompt_length] = -100  # do not apply loss to the prompt\n\n        empty_targets = (\n            torch.ones(atts_opt.size(), dtype=torch.long).to(image.device).fill_(-100)\n        )\n        targets = torch.cat([empty_targets, targets], dim=1)\n\n        inputs_embeds = self.opt_model.model.decoder.embed_tokens(opt_tokens.input_ids)\n        inputs_embeds = torch.cat([inputs_opt, inputs_embeds], dim=1)\n        attention_mask = torch.cat([atts_opt, opt_tokens.attention_mask], dim=1)\n\n        outputs = self.opt_model(\n            inputs_embeds=inputs_embeds,\n            attention_mask=attention_mask,\n            return_dict=True,\n            labels=targets,\n        )\n        loss = outputs.loss\n\n        return {\"loss\": loss}\n\n    @torch.no_grad()\n    def generate(\n        self,\n        samples,\n        use_nucleus_sampling=False,\n        num_beams=5,\n        max_length=30,\n        min_length=1,\n        top_p=0.9,\n        repetition_penalty=1.0,\n        length_penalty=1.0,\n        num_captions=1,\n        temperature=1,\n    ):\n        \"\"\"\n        Args:\n            samples (dict): A dictionary containing the following keys:\n                - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W)\n            use_nucleus_sampling (bool): Whether to use nucleus sampling. If False, use top-k sampling.\n            num_beams (int): Number of beams for beam search. 1 means no beam search.\n            max_length (int): The maximum length of the sequence to be generated.\n            min_length (int): The minimum length of the sequence to be generated.\n            top_p (float): The cumulative probability for nucleus sampling.\n            repetition_penalty (float): The parameter for repetition penalty. 
1.0 means no penalty.\n            num_captions (int): Number of captions to be generated for each image.\n        Returns:\n            captions (list): A list of strings of length batch_size * num_captions.\n        \"\"\"\n        if 'video' in samples:\n            image = samples[\"video\"]\n            vid = samples['vid']\n            fids = samples['fids']\n            out = {}\n            #print('vid', vid)\n            #print('fids', fids)\n            b, t, c, w, h = image.shape\n            image = image.reshape(-1, c, w, h)\n        else:\n            image = samples[\"image\"]\n            \n        with torch.cuda.amp.autocast(\n            enabled=(self.device != torch.device(\"cpu\"))\n        ):          \n            image_embeds = self.ln_vision(self.visual_encoder(image))\n            image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(\n                image.device\n            )\n\n            query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)\n            query_output = self.Qformer.bert(\n                query_embeds=query_tokens,\n                encoder_hidden_states=image_embeds,\n                encoder_attention_mask=image_atts,\n                return_dict=True,\n            )\n\n            inputs_opt = self.opt_proj(query_output.last_hidden_state)\n            atts_opt = torch.ones(inputs_opt.size()[:-1], dtype=torch.long).to(image.device)\n\n            if \"prompt\" in samples.keys():\n                prompt = samples[\"prompt\"]\n            else:\n                prompt = self.prompt\n\n            prompt = [prompt] * image.size(0)\n\n            opt_tokens = self.opt_tokenizer(prompt, return_tensors=\"pt\").to(image.device)\n            input_ids = opt_tokens.input_ids\n            attention_mask = torch.cat([atts_opt, opt_tokens.attention_mask], dim=1)\n\n            if use_nucleus_sampling:\n                query_embeds = inputs_opt.repeat_interleave(num_captions, dim=0)\n                num_beams = 1\n            else:\n                query_embeds = inputs_opt.repeat_interleave(num_beams, dim=0)\n\n            outputs = self.opt_model.generate(\n                input_ids=input_ids,\n                query_embeds=query_embeds,\n                attention_mask=attention_mask,\n                do_sample=use_nucleus_sampling,\n                top_p=top_p,\n                temperature=temperature,\n                num_beams=num_beams,\n                max_new_tokens=max_length,\n                min_length=min_length,\n                eos_token_id=self.eos_token_id,\n                repetition_penalty=repetition_penalty,\n                length_penalty=length_penalty,\n                num_return_sequences=num_captions,\n            )\n\n            prompt_length = opt_tokens.input_ids.shape[1]\n            output_text = self.opt_tokenizer.batch_decode(\n                outputs[:, prompt_length:], skip_special_tokens=True\n            )\n            output_text = [text.strip() for text in output_text]\n            \n            if 'video' in samples:\n                out['vid'] = vid\n                out['fids'] = fids\n                caption_by_batch = []\n                for i in range(b):\n                    caption_by_batch.append(output_text[i*t : (i+1)*t])\n                out['output_text'] = caption_by_batch\n                return out\n            else:\n                return output_text\n\n    @classmethod\n    def from_config(cls, cfg):\n\n        img_size = cfg.get(\"image_size\")\n        
num_query_token = cfg.get(\"num_query_token\")\n        opt_model = cfg.get(\"opt_model\")\n\n        drop_path_rate = cfg.get(\"drop_path_rate\", 0)\n        use_grad_checkpoint = cfg.get(\"use_grad_checkpoint\", False)\n        vit_precision = cfg.get(\"vit_precision\", \"fp16\")\n        freeze_vit = cfg.get(\"freeze_vit\", True)\n\n        prompt = cfg.get(\"prompt\", \"\")\n        max_txt_len = cfg.get(\"max_txt_len\", 32)\n\n        model = cls(\n            img_size=img_size,\n            drop_path_rate=drop_path_rate,\n            use_grad_checkpoint=use_grad_checkpoint,\n            vit_precision=vit_precision,\n            freeze_vit=freeze_vit,\n            num_query_token=num_query_token,\n            opt_model=opt_model,\n            prompt=prompt,\n            max_txt_len=max_txt_len,\n        )\n        model.load_checkpoint_from_config(cfg)\n\n        return model\n"
  },
  {
    "path": "lavis/models/blip2_models/blip2_qformer.py",
    "content": "\"\"\"\n Copyright (c) 2023, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\nimport logging\n\nimport torch\nimport torch.distributed as dist\nimport torch.nn as nn\nfrom torch.cuda.amp import autocast as autocast\nfrom torch.nn import functional as F\n\nfrom lavis.common.registry import registry\nfrom lavis.models.base_model import all_gather_with_grad, concat_all_gather\nfrom lavis.models.blip2_models.blip2 import (\n    Blip2Base,\n    compute_sim_matrix,\n    disabled_train,\n)\nfrom lavis.models.blip_models.blip_outputs import BlipOutput, BlipOutputFeatures\n\n\n@registry.register_model(\"blip2\")\n@registry.register_model(\"blip2_feature_extractor\")\nclass Blip2Qformer(Blip2Base):\n    \"\"\"\n    BLIP2 first-stage model with Q-former and ViT.\n    Supported model types:\n        - pretrained: pretrained model\n        - coco: fintuned model on coco\n    Usage:\n        >>> from lavis.models import load_model\n        >>> model = load_model(\"blip2\", \"pretrain\")\n    \"\"\"\n\n    PRETRAINED_MODEL_CONFIG_DICT = {\n        \"pretrain\": \"configs/models/blip2/blip2_pretrain.yaml\",\n        \"coco\": \"configs/models/blip2/blip2_coco.yaml\",\n    }\n\n    def __init__(\n        self,\n        img_size=224,\n        drop_path_rate=0,\n        use_grad_checkpoint=False,\n        vit_precision=\"fp16\",\n        freeze_vit=True,\n        num_query_token=32,\n        embed_dim=256,\n        max_txt_len=32,\n    ):\n        super().__init__()\n\n        self.tokenizer = self.init_tokenizer()\n\n        self.visual_encoder, self.ln_vision = self.init_vision_encoder(\n            img_size, drop_path_rate, use_grad_checkpoint, vit_precision\n        )\n        if freeze_vit:\n            for name, param in self.visual_encoder.named_parameters():\n                param.requires_grad = False                \n            self.visual_encoder = self.visual_encoder.eval()\n            self.visual_encoder.train = disabled_train            \n            logging.info(\"freeze vision encoder\")\n        self.Qformer, self.query_tokens = self.init_Qformer(\n            num_query_token, self.visual_encoder.num_features\n        )\n        self.Qformer.resize_token_embeddings(len(self.tokenizer))\n        state_dict = self.Qformer.state_dict()\n        for name, param in self.Qformer.named_parameters():\n            if \"_query\" in name:\n                key_orig = name.replace(\"_query\", \"\")\n                param.data.copy_(state_dict[key_orig])\n\n        self.vision_proj = nn.Linear(self.Qformer.config.hidden_size, embed_dim)\n        self.text_proj = nn.Linear(self.Qformer.config.hidden_size, embed_dim)\n\n        self.itm_head = nn.Linear(self.Qformer.config.hidden_size, 2)\n\n        self.temp = nn.Parameter(0.07 * torch.ones([]))\n\n        self.max_txt_len = max_txt_len\n\n    def forward(self, samples):\n        image = samples[\"image\"]\n        text = samples[\"text_input\"]\n        \n        image_embeds = self.ln_vision(self.visual_encoder(image))\n        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(\n            image.device\n        )\n\n        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)\n\n        query_output = self.Qformer.bert(\n            query_embeds=query_tokens,\n            encoder_hidden_states=image_embeds,\n            
encoder_attention_mask=image_atts,\n            use_cache=True,\n            return_dict=True,\n        )\n\n        image_feats = F.normalize(\n            self.vision_proj(query_output.last_hidden_state), dim=-1\n        )\n\n        text_tokens = self.tokenizer(\n            text,\n            padding=\"max_length\",\n            truncation=True,\n            max_length=self.max_txt_len,\n            return_tensors=\"pt\",\n        ).to(image.device)\n        text_output = self.Qformer.bert(\n            text_tokens.input_ids,\n            attention_mask=text_tokens.attention_mask,\n            return_dict=True,\n        )\n        text_feat = F.normalize(\n            self.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1\n        )\n\n        ###============== Image-text Contrastive ===================###\n        image_feats_all = concat_all_gather(\n            image_feats\n        )  # [batch_size*num_gpu, num_query_tokens, embed_dim]\n        text_feat_all = concat_all_gather(text_feat)  # [batch_size*num_gpu, embed_dim]\n\n        sim_q2t = torch.matmul(\n            image_feats.unsqueeze(1), text_feat_all.unsqueeze(-1)\n        ).squeeze()\n        # [batch_size, batch_size*num_gpu, num_query_tokens]\n\n        # image-text similarity: aggregate across all query tokens\n        sim_i2t, _ = sim_q2t.max(-1)\n        sim_i2t = sim_i2t / self.temp\n\n        # text-query similarity: [batch_size, batch_size*num_gpu, num_query_tokens]\n        sim_t2q = torch.matmul(\n            text_feat.unsqueeze(1).unsqueeze(1), image_feats_all.permute(0, 2, 1)\n        ).squeeze()\n\n        # text-image similarity: aggregate across all query tokens\n        sim_t2i, _ = sim_t2q.max(-1)\n        sim_t2i = sim_t2i / self.temp  # [batch_size, batch_size*num_gpu]\n\n        rank = dist.get_rank()\n        bs = image.size(0)\n        targets = torch.linspace(rank * bs, rank * bs + bs - 1, bs, dtype=int).to(\n            image.device\n        )\n\n        loss_itc = (\n            F.cross_entropy(sim_i2t, targets, label_smoothing=0.1)\n            + F.cross_entropy(sim_t2i, targets, label_smoothing=0.1)\n        ) / 2\n\n        ###============== Image-text Matching ===================###\n        text_input_ids_world = concat_all_gather(text_tokens.input_ids)\n        text_attention_mask_world = concat_all_gather(text_tokens.attention_mask)\n        image_embeds_world = all_gather_with_grad(image_embeds)\n        with torch.no_grad():\n            weights_t2i = F.softmax(sim_t2i, dim=1) + 1e-4\n            weights_t2i[:, rank * bs : rank * bs + bs].fill_diagonal_(0)\n            weights_i2t = F.softmax(sim_i2t, dim=1) + 1e-4\n            weights_i2t[:, rank * bs : rank * bs + bs].fill_diagonal_(0)\n\n        # select a negative image for each text\n        image_embeds_neg = []\n        for b in range(bs):\n            neg_idx = torch.multinomial(weights_t2i[b], 1).item()\n            image_embeds_neg.append(image_embeds_world[neg_idx])\n        image_embeds_neg = torch.stack(image_embeds_neg, dim=0)\n\n        # select a negative text for each image\n        text_ids_neg = []\n        text_atts_neg = []\n        for b in range(bs):\n            neg_idx = torch.multinomial(weights_i2t[b], 1).item()\n            text_ids_neg.append(text_input_ids_world[neg_idx])\n            text_atts_neg.append(text_attention_mask_world[neg_idx])\n\n        text_ids_neg = torch.stack(text_ids_neg, dim=0)\n        text_atts_neg = torch.stack(text_atts_neg, dim=0)\n\n        text_ids_all = torch.cat(\n         
   [text_tokens.input_ids, text_tokens.input_ids, text_ids_neg], dim=0\n        )  # pos, pos, neg\n        text_atts_all = torch.cat(\n            [text_tokens.attention_mask, text_tokens.attention_mask, text_atts_neg],\n            dim=0,\n        )\n\n        query_tokens_itm = self.query_tokens.expand(text_ids_all.shape[0], -1, -1)\n        query_atts_itm = torch.ones(query_tokens_itm.size()[:-1], dtype=torch.long).to(\n            image.device\n        )\n        attention_mask_all = torch.cat([query_atts_itm, text_atts_all], dim=1)\n\n        image_embeds_all = torch.cat(\n            [image_embeds, image_embeds_neg, image_embeds], dim=0\n        )  # pos, neg, pos\n        image_atts_all = torch.ones(image_embeds_all.size()[:-1], dtype=torch.long).to(\n            image.device\n        )\n\n        output_itm = self.Qformer.bert(\n            text_ids_all,\n            query_embeds=query_tokens_itm,\n            attention_mask=attention_mask_all,\n            encoder_hidden_states=image_embeds_all,\n            encoder_attention_mask=image_atts_all,\n            return_dict=True,\n        )\n\n        vl_embeddings = output_itm.last_hidden_state[:, : query_tokens_itm.size(1), :]\n        vl_output = self.itm_head(vl_embeddings)\n        logits = vl_output.mean(dim=1)\n\n        itm_labels = torch.cat(\n            [torch.ones(bs, dtype=torch.long), torch.zeros(2 * bs, dtype=torch.long)],\n            dim=0,\n        ).to(image.device)\n        loss_itm = F.cross_entropy(logits, itm_labels)\n\n        ##================= Image Captioning ========================##\n        decoder_input_ids = text_tokens.input_ids.clone()\n        decoder_input_ids[:, 0] = self.tokenizer.bos_token_id\n        labels = decoder_input_ids.masked_fill(\n            decoder_input_ids == self.tokenizer.pad_token_id, -100\n        )\n\n        query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(\n            image.device\n        )\n        attention_mask = torch.cat([query_atts, text_tokens.attention_mask], dim=1)\n        lm_output = self.Qformer(\n            decoder_input_ids,\n            attention_mask=attention_mask,\n            past_key_values=query_output.past_key_values,\n            return_dict=True,\n            labels=labels,\n        )\n        \n        loss_lm = lm_output.loss\n\n        return BlipOutput(\n            loss=loss_itc + loss_itm + loss_lm,\n            loss_itc=loss_itc,\n            loss_itm=loss_itm,\n            loss_lm=loss_lm,\n        )\n\n    @torch.no_grad()\n    def generate(\n        self,\n        samples,\n        use_nucleus_sampling=False,\n        num_beams=3,\n        max_length=30,\n        min_length=10,\n        top_p=0.9,\n        repetition_penalty=1.0,\n    ):\n        \"\"\"\n        Args:\n            samples (dict): A dictionary containing the following keys:\n                - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W)\n            use_nucleus_sampling (bool): Whether to use nucleus sampling. If False, use top-k sampling.\n            num_beams (int): Number of beams for beam search. 1 means no beam search.\n            max_length (int): The maximum length of the sequence to be generated.\n            min_length (int): The minimum length of the sequence to be generated.\n            top_p (float): The cumulative probability for nucleus sampling.\n            repetition_penalty (float): The parameter for repetition penalty. 
1.0 means no penalty.\n            num_captions (int): Number of captions to be generated for each image.\n        Returns:\n            captions (list): A list of strings of length batch_size * num_captions.\n        \"\"\"\n        image = samples[\"image\"]\n        image_embeds = self.ln_vision(self.visual_encoder(image))\n\n        if not use_nucleus_sampling:\n            image_embeds = image_embeds.repeat_interleave(num_beams, dim=0)\n        else:\n            num_beams = 1\n        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(\n            image.device\n        )\n\n        model_kwargs = {\n            \"encoder_hidden_states\": image_embeds,\n            \"encoder_attention_mask\": image_atts,\n        }\n\n        input_ids = (\n            torch.LongTensor(image.size(0), 1)\n            .fill_(self.tokenizer.bos_token_id)\n            .to(image.device)\n        )\n        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)\n\n        outputs = self.Qformer.generate(\n            input_ids=input_ids,\n            query_embeds=query_tokens,\n            max_length=max_length,\n            min_length=min_length,\n            num_beams=num_beams,\n            do_sample=use_nucleus_sampling,\n            top_p=top_p,\n            eos_token_id=self.tokenizer.sep_token_id,\n            pad_token_id=self.tokenizer.pad_token_id,\n            **model_kwargs\n        )\n        captions = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)\n        return captions\n\n    def forward_image(self, image):\n        image_embeds = self.ln_vision(self.visual_encoder(image))\n        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(\n            image.device\n        )\n\n        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)\n\n        query_output = self.Qformer.bert(\n            query_embeds=query_tokens,\n            encoder_hidden_states=image_embeds,\n            encoder_attention_mask=image_atts,\n            return_dict=True,\n        )\n        return query_output.last_hidden_state, image_embeds\n\n    def forward_text(self, text_tokens):\n        text_output = self.Qformer.bert(\n            text_tokens.input_ids,\n            attention_mask=text_tokens.attention_mask,\n            return_dict=True,\n        )\n        return text_output.last_hidden_state[:, 0, :]\n\n    def compute_itm(self, image_inputs, text_ids, text_atts):\n        image_atts = torch.ones(image_inputs.size()[:-1], dtype=torch.long).to(\n            image_inputs.device\n        )\n        query_tokens = self.query_tokens.expand(image_inputs.shape[0], -1, -1)\n        query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(\n            image_inputs.device\n        )\n        attention_mask = torch.cat([query_atts, text_atts], dim=1)\n        output_itm = self.Qformer.bert(\n            text_ids,\n            query_embeds=query_tokens,\n            attention_mask=attention_mask,\n            encoder_hidden_states=image_inputs,\n            encoder_attention_mask=image_atts,\n            return_dict=True,\n        )\n        vl_embeddings = output_itm.last_hidden_state[:, : query_tokens.size(1), :]\n        itm_logit = self.itm_head(vl_embeddings)\n        itm_logit = itm_logit[:, :, 1].mean(dim=1)\n        return itm_logit\n\n    @torch.no_grad()\n    def extract_features(self, samples, mode=\"multimodal\"):\n        \"\"\"\n        Extract features for multimodal or unimodal samples.\n        
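Example (a minimal illustrative sketch; assumes \"model\" is a loaded instance of this class and \"image\" is an already-preprocessed image tensor on the model's device):\n\n            >>> sample = {\"image\": image, \"text_input\": [\"a photo of a dog\"]}\n            >>> features = model.extract_features(sample, mode=\"multimodal\")\n            >>> features.multimodal_embeds.shape  # (batch_size, num_query_tokens, hidden_size)\n\n        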
Args:\n            samples (dict): A dictionary of samples, containing the following keys:\n                - image (torch.Tensor): A tensor of shape (B, C, H, W) containing the image.\n                    Raw images should be preprocessed before being passed to feature extractor.\n                - text_input (list): A list of strings containing the text, length B.\n            mode (str): The mode of feature extraction. Can be either \"multimodal\", \"text\" or \"image\".\n                If \"multimodal\", return image features and multimodal features;\n                if \"text\", return text features;\n                if \"image\", return image features.\n                Default: \"multimodal\".\n        Returns:\n            BlipOutputFeatures: A BlipOutputFeatures object containing the features.\n                See lavis/models/blip_models/blip_outputs.py for more details.\n        \"\"\"\n        image = samples.get(\"image\")\n        caption = samples.get(\"text_input\")\n\n        # assert mode is one of \"image\", \"text\", \"multimodal\"\n        assert mode in [\n            \"image\",\n            \"text\",\n            \"multimodal\",\n        ], \"mode must be one of 'image', 'text', 'multimodal'\"\n\n        # initialize output\n        image_embeds, text_embeds, multimodal_embeds = None, None, None\n        image_features, text_features = None, None\n\n        if mode == \"image\":\n            assert (\n                image is not None\n            ), \"Image is not provided for mode 'image' or 'multimodal'\"\n            # return query features\n            image_embeds_frozen = self.ln_vision(self.visual_encoder(image))\n            image_atts = torch.ones(\n                image_embeds_frozen.size()[:-1], dtype=torch.long\n            ).to(self.device)\n            query_tokens = self.query_tokens.expand(\n                image_embeds_frozen.shape[0], -1, -1\n            )\n\n            query_output = self.Qformer.bert(\n                query_embeds=query_tokens,\n                encoder_hidden_states=image_embeds_frozen,\n                encoder_attention_mask=image_atts,\n                return_dict=True,\n            )\n            image_embeds = query_output.last_hidden_state\n            image_features = F.normalize(self.vision_proj(image_embeds), dim=-1)\n\n        elif mode == \"text\":\n            assert (\n                caption is not None\n            ), \"text input is None for mode 'text' or 'multimodal'\"\n\n            # return text features\n            text = self.tokenizer(caption, return_tensors=\"pt\", padding=True).to(\n                self.device\n            )\n\n            text_output = self.Qformer.bert(\n                text.input_ids,\n                attention_mask=text.attention_mask,\n                return_dict=True,\n            )\n            text_embeds = text_output.last_hidden_state\n            text_features = self.text_proj(text_embeds)\n            text_features = F.normalize(text_features, dim=-1)\n\n        elif mode == \"multimodal\":\n            # return multimodal query features\n            image_embeds_frozen = self.ln_vision(self.visual_encoder(image))\n            image_atts = torch.ones(\n                image_embeds_frozen.size()[:-1], dtype=torch.long\n            ).to(self.device)\n            query_tokens = self.query_tokens.expand(\n                image_embeds_frozen.shape[0], -1, -1\n            )\n            query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(\n                
self.device\n            )\n\n            text = self.tokenizer(caption, return_tensors=\"pt\", padding=True).to(\n                self.device\n            )\n            attention_mask = torch.cat([query_atts, text.attention_mask], dim=1)\n\n            output = self.Qformer.bert(\n                text.input_ids,\n                query_embeds=query_tokens,\n                attention_mask=attention_mask,\n                encoder_hidden_states=image_embeds_frozen,\n                encoder_attention_mask=image_atts,\n                return_dict=True,\n            )\n\n            multimodal_embeds = output.last_hidden_state[:, : query_tokens.size(1), :]\n\n        return BlipOutputFeatures(\n            image_embeds=image_embeds,\n            image_embeds_proj=image_features,\n            text_embeds=text_embeds,\n            text_embeds_proj=text_features,\n            multimodal_embeds=multimodal_embeds,\n        )\n\n    @classmethod\n    def from_config(cls, cfg):\n        img_size = cfg.get(\"image_size\")\n        num_query_token = cfg.get(\"num_query_token\")\n\n        drop_path_rate = cfg.get(\"drop_path_rate\", 0)\n        use_grad_checkpoint = cfg.get(\"use_grad_checkpoint\", False)\n        vit_precision = cfg.get(\"vit_precision\", \"fp16\")\n        freeze_vit = cfg.get(\"freeze_vit\", True)\n\n        max_txt_len = cfg.get(\"max_txt_len\", 32)\n\n        model = cls(\n            img_size=img_size,\n            drop_path_rate=drop_path_rate,\n            use_grad_checkpoint=use_grad_checkpoint,\n            vit_precision=vit_precision,\n            freeze_vit=freeze_vit,\n            num_query_token=num_query_token,\n            max_txt_len=max_txt_len,\n        )\n        model.load_checkpoint_from_config(cfg)\n\n        return model\n\n    def compute_sim_matrix(self, data_loader, task_cfg):\n        \"\"\"\n        Compute similarity i2t, t2i matrix for the given data loader.\n        \"\"\"\n        k_test = task_cfg.k_test\n\n        return compute_sim_matrix(model=self, data_loader=data_loader, k_test=k_test)\n"
  },
  {
    "path": "lavis/models/blip2_models/blip2_t5.py",
    "content": "\"\"\"\n Copyright (c) 2023, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\nimport logging\n\nimport torch\nimport torch.nn as nn\nfrom torch.cuda.amp import autocast as autocast\nfrom transformers import T5TokenizerFast\n\nfrom lavis.common.registry import registry\nfrom lavis.models.blip2_models.blip2 import Blip2Base, disabled_train\nfrom lavis.models.blip2_models.modeling_t5 import T5Config, T5ForConditionalGeneration\n\n\n@registry.register_model(\"blip2_t5\")\nclass Blip2T5(Blip2Base):\n    \"\"\"\n    BLIP2 T5 model.\n    Supported model types:\n        - pretrain_flant5xl: pretrained model with FlanT5-XL\n        - pretrain_flant5xxl: pretrained model with FlanT5-XXL\n        - caption_coco_flant5xl: fintuned image captioning model with FlanT5-XL\n    Usage:\n        >>> from lavis.models import load_model\n        >>> model = load_model(\"blip2_t5\", \"pretrain_flant5xl\")\n    \"\"\"\n\n    PRETRAINED_MODEL_CONFIG_DICT = {\n        \"pretrain_flant5xl\": \"configs/models/blip2/blip2_pretrain_flant5xl.yaml\",\n        \"pretrain_flant5xxl\": \"configs/models/blip2/blip2_pretrain_flant5xxl.yaml\",\n        \"caption_coco_flant5xl\": \"configs/models/blip2/blip2_caption_flant5xl.yaml\",\n    }\n\n    def __init__(\n        self,\n        img_size=224,\n        drop_path_rate=0,\n        use_grad_checkpoint=False,\n        vit_precision=\"fp16\",\n        freeze_vit=True,\n        num_query_token=32,\n        t5_model=\"google/flan-t5-xl\",\n        prompt=\"\",\n        max_txt_len=32,\n        apply_lemmatizer=False,\n    ):\n        \"\"\"\n        apply_lemmatizer: when set to True, postprocess predict_answers() result with lemmas.\n        \"\"\"\n        super().__init__()\n\n        self.tokenizer = self.init_tokenizer()\n\n        self.visual_encoder, self.ln_vision = self.init_vision_encoder(\n            img_size, drop_path_rate, use_grad_checkpoint, vit_precision\n        )\n        if freeze_vit:\n            for name, param in self.visual_encoder.named_parameters():\n                param.requires_grad = False         \n            self.visual_encoder = self.visual_encoder.eval()\n            self.visual_encoder.train = disabled_train\n            logging.info(\"freeze vision encoder\")\n\n        self.Qformer, self.query_tokens = self.init_Qformer(\n            num_query_token, self.visual_encoder.num_features\n        )\n        self.Qformer.cls = None\n        self.Qformer.bert.embeddings.word_embeddings = None\n        self.Qformer.bert.embeddings.position_embeddings = None\n        for layer in self.Qformer.bert.encoder.layer:\n            layer.output = None\n            layer.intermediate = None\n\n        self.t5_tokenizer = T5TokenizerFast.from_pretrained(t5_model)\n        t5_config = T5Config.from_pretrained(t5_model)\n        t5_config.dense_act_fn = \"gelu\"\n        self.t5_model = T5ForConditionalGeneration.from_pretrained(\n            t5_model, config=t5_config\n        )\n\n        for name, param in self.t5_model.named_parameters():\n            param.requires_grad = False\n            param.data = param.data.bfloat16()\n\n        self.t5_proj = nn.Linear(\n            self.Qformer.config.hidden_size, self.t5_model.config.hidden_size\n        )\n\n        self.max_txt_len = max_txt_len\n        self.prompt = prompt\n\n        self._apply_lemmatizer = apply_lemmatizer\n        
self._lemmatizer = None\n\n    def forward(self, samples):\n        image = samples[\"image\"]\n        image_embeds = self.ln_vision(self.visual_encoder(image))\n        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(\n            image.device\n        )\n\n        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)\n        query_output = self.Qformer.bert(\n            query_embeds=query_tokens,\n            encoder_hidden_states=image_embeds,\n            encoder_attention_mask=image_atts,\n            return_dict=True,\n        )\n\n        inputs_t5 = self.t5_proj(query_output.last_hidden_state)\n        atts_t5 = torch.ones(inputs_t5.size()[:-1], dtype=torch.long).to(image.device)\n\n        with torch.cuda.amp.autocast(dtype=torch.bfloat16):\n            input_tokens = self.t5_tokenizer(\n                samples[\"text_input\"],\n                padding=\"longest\",\n                truncation=True,\n                max_length=self.max_txt_len,\n                return_tensors=\"pt\",\n            ).to(image.device)\n            output_tokens = self.t5_tokenizer(\n                samples[\"text_output\"],\n                padding=\"longest\",\n                truncation=True,\n                max_length=self.max_txt_len,\n                return_tensors=\"pt\",\n            ).to(image.device)\n\n            encoder_atts = torch.cat([atts_t5, input_tokens.attention_mask], dim=1)\n\n            targets = output_tokens.input_ids.masked_fill(\n                output_tokens.input_ids == self.t5_tokenizer.pad_token_id, -100\n            )\n\n            inputs_embeds = self.t5_model.encoder.embed_tokens(input_tokens.input_ids)\n            inputs_embeds = torch.cat([inputs_t5, inputs_embeds], dim=1)\n\n            outputs = self.t5_model(\n                inputs_embeds=inputs_embeds,\n                attention_mask=encoder_atts,\n                decoder_attention_mask=output_tokens.attention_mask,\n                return_dict=True,\n                labels=targets,\n            )\n            loss = outputs.loss\n\n            return {\"loss\": loss}\n\n    @torch.no_grad()\n    def generate(\n        self,\n        samples,\n        use_nucleus_sampling=False,\n        num_beams=5,\n        max_length=30,\n        min_length=1,\n        top_p=0.9,\n        repetition_penalty=1.0,\n        length_penalty=1.0,\n        num_captions=1,\n        temperature=1,\n    ):\n        \"\"\"\n        Args:\n            samples (dict): A dictionary containing the following keys:\n                - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W)\n            use_nucleus_sampling (bool): Whether to use nucleus sampling. If False, use top-k sampling.\n            num_beams (int): Number of beams for beam search. 1 means no beam search.\n            max_length (int): The maximum length of the sequence to be generated.\n            min_length (int): The minimum length of the sequence to be generated.\n            top_p (float): The cumulative probability for nucleus sampling.\n            repetition_penalty (float): The parameter for repetition penalty. 
1.0 means no penalty.\n            num_captions (int): Number of captions to be generated for each image.\n        Returns:\n            captions (list): A list of strings of length batch_size * num_captions.\n        \"\"\"\n        if 'video' in samples:\n            image = samples[\"video\"]\n            vid = samples['vid']\n            fids = samples['fids']\n            out = {}\n            #print('vid', vid)\n            #print('fids', fids)\n            b, t, c, w, h = image.shape\n            image = image.reshape(-1, c, w, h)\n            #print('prompt', self.prompt)\n        else:\n            image = samples[\"image\"]\n            \n        with torch.cuda.amp.autocast(enabled=(self.device != torch.device(\"cpu\"))):\n            image_embeds = self.ln_vision(self.visual_encoder(image))\n        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(\n            image.device\n        )\n\n        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)\n        query_output = self.Qformer.bert(\n            query_embeds=query_tokens,\n            encoder_hidden_states=image_embeds,\n            encoder_attention_mask=image_atts,\n            return_dict=True,\n        )\n\n        inputs_t5 = self.t5_proj(query_output.last_hidden_state)\n        atts_t5 = torch.ones(inputs_t5.size()[:-1], dtype=torch.long).to(image.device)\n\n        if \"prompt\" in samples.keys():\n            prompt = samples[\"prompt\"]\n        else:\n            prompt = self.prompt\n\n        if isinstance(prompt, str):\n            prompt = [prompt] * image.size(0)\n        else:\n            assert len(prompt) == image.size(\n                0\n            ), \"The number of prompts must be equal to the batch size.\"\n\n        input_tokens = self.t5_tokenizer(\n            prompt, padding=\"longest\", return_tensors=\"pt\"\n        ).to(image.device)\n\n        encoder_atts = torch.cat([atts_t5, input_tokens.attention_mask], dim=1)\n\n        device_type = \"cuda\" if \"cuda\" in str(self.device) else \"cpu\"\n        with torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16):\n            inputs_embeds = self.t5_model.encoder.embed_tokens(input_tokens.input_ids)\n            inputs_embeds = torch.cat([inputs_t5, inputs_embeds], dim=1)\n\n            outputs = self.t5_model.generate(\n                inputs_embeds=inputs_embeds,\n                attention_mask=encoder_atts,\n                do_sample=use_nucleus_sampling,\n                top_p=top_p,\n                temperature=temperature,\n                num_beams=num_beams,\n                max_new_tokens=max_length,\n                min_length=min_length,\n                repetition_penalty=repetition_penalty,\n                length_penalty=length_penalty,\n                num_return_sequences=num_captions,\n            )\n            output_text = self.t5_tokenizer.batch_decode(\n                outputs, skip_special_tokens=True\n            )\n            \n        if 'video' in samples:\n            out['vid'] = vid\n            out['fids'] = fids\n            caption_by_batch = []\n            for i in range(b):\n                caption_by_batch.append(output_text[i*t : (i+1)*t])\n            out['output_text'] = caption_by_batch\n            return out\n        else:\n            return output_text\n\n    def predict_answers(\n        self,\n        samples,\n        num_beams=5,\n        inference_method=\"generate\",\n        max_len=10,\n        min_len=1,\n        
num_ans_candidates=128,\n        answer_list=None,\n        prompt=\"\",\n        length_penalty=-1,\n        **kwargs\n    ):\n        image = samples[\"image\"]\n        with torch.cuda.amp.autocast(enabled=(self.device != torch.device(\"cpu\"))):\n            image_embeds = self.ln_vision(self.visual_encoder(image))\n        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(\n            image.device\n        )\n\n        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)\n        query_output = self.Qformer.bert(\n            query_embeds=query_tokens,\n            encoder_hidden_states=image_embeds,\n            encoder_attention_mask=image_atts,\n            return_dict=True,\n        )\n\n        inputs_t5 = self.t5_proj(query_output.last_hidden_state)\n        atts_t5 = torch.ones(inputs_t5.size()[:-1], dtype=torch.long).to(image.device)\n\n        if isinstance(samples[\"text_input\"], str):\n            samples[\"text_input\"] = [samples[\"text_input\"]]\n        if prompt:\n            text_input = [prompt.format(question) for question in samples[\"text_input\"]]\n        else:\n            text_input = samples[\"text_input\"]\n\n        input_tokens = self.t5_tokenizer(\n            text_input, padding=\"longest\", return_tensors=\"pt\"\n        ).to(image.device)\n\n        encoder_atts = torch.cat([atts_t5, input_tokens.attention_mask], dim=1)\n\n        device_type = \"cuda\" if \"cuda\" in str(self.device) else \"cpu\"\n        with torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16):\n            inputs_embeds = self.t5_model.encoder.embed_tokens(input_tokens.input_ids)\n            inputs_embeds = torch.cat([inputs_t5, inputs_embeds], dim=1)\n\n            outputs = self.t5_model.generate(\n                inputs_embeds=inputs_embeds,\n                attention_mask=encoder_atts,\n                do_sample=False,\n                num_beams=num_beams,\n                max_new_tokens=max_len,\n                min_length=min_len,\n                length_penalty=length_penalty,\n            )\n            output_text = self.t5_tokenizer.batch_decode(\n                outputs, skip_special_tokens=True\n            )\n\n        if self._apply_lemmatizer:\n            output_text = self._lemmatize(output_text)\n\n        return output_text\n\n    def _lemmatize(self, answers):\n        def apply(answer):\n            doc = self.lemmatizer(answer)\n\n            words = []\n            for token in doc:\n                if token.pos_ in [\"NOUN\", \"VERB\"]:\n                    words.append(token.lemma_)\n                else:\n                    words.append(token.text)\n            answer = \" \".join(words)\n\n            return answer\n\n        return [apply(answer) for answer in answers]\n\n    @property\n    def lemmatizer(self):\n        if self._lemmatizer is None:\n            try:\n                import spacy\n\n                self._lemmatizer = spacy.load(\"en_core_web_sm\")\n            except ImportError:\n                logging.error(\n                    \"\"\"\n                    Please install spacy and en_core_web_sm model to apply lemmatization.\n                    python -m spacy download en_core_web_sm\n                    OR\n                    import spacy.cli\n                    spacy.cli.download(\"en_core_web_sm\")\n                    \"\"\"\n                )\n                exit(1)\n\n        return self._lemmatizer\n\n    @classmethod\n    def from_config(cls, cfg):\n        img_size = 
cfg.get(\"image_size\")\n        num_query_token = cfg.get(\"num_query_token\")\n        t5_model = cfg.get(\"t5_model\")\n\n        drop_path_rate = cfg.get(\"drop_path_rate\", 0)\n        use_grad_checkpoint = cfg.get(\"use_grad_checkpoint\", False)\n        vit_precision = cfg.get(\"vit_precision\", \"fp16\")\n        freeze_vit = cfg.get(\"freeze_vit\", True)\n\n        prompt = cfg.get(\"prompt\", \"\")\n        max_txt_len = cfg.get(\"max_txt_len\", 32)\n\n        apply_lemmatizer = cfg.get(\"apply_lemmatizer\", False)\n\n        model = cls(\n            img_size=img_size,\n            drop_path_rate=drop_path_rate,\n            use_grad_checkpoint=use_grad_checkpoint,\n            vit_precision=vit_precision,\n            freeze_vit=freeze_vit,\n            num_query_token=num_query_token,\n            t5_model=t5_model,\n            prompt=prompt,\n            max_txt_len=max_txt_len,\n            apply_lemmatizer=apply_lemmatizer,\n        )\n        model.load_checkpoint_from_config(cfg)\n\n        return model\n"
  },
  {
    "path": "lavis/models/blip2_models/modeling_opt.py",
    "content": "# coding=utf-8\n# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\" PyTorch OPT model.\"\"\"\nimport random\nfrom typing import List, Optional, Tuple, Union\n\nimport torch\nimport torch.utils.checkpoint\nfrom torch import nn\nfrom torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss\n\nfrom transformers.activations import ACT2FN\nfrom transformers.modeling_outputs import (\n    BaseModelOutputWithPast,\n    CausalLMOutputWithPast,\n)\nfrom transformers.modeling_utils import PreTrainedModel\nfrom transformers.utils import (\n    add_code_sample_docstrings,\n    add_start_docstrings,\n    add_start_docstrings_to_model_forward,\n    logging,\n    replace_return_docstrings,\n)\nfrom transformers.models.opt.configuration_opt import OPTConfig\n\n\nlogger = logging.get_logger(__name__)\n\n_CHECKPOINT_FOR_DOC = \"facebook/opt-350m\"\n_CONFIG_FOR_DOC = \"OPTConfig\"\n_TOKENIZER_FOR_DOC = \"GPT2Tokenizer\"\n\n# Base model docstring\n_EXPECTED_OUTPUT_SHAPE = [1, 8, 1024]\n\n# SequenceClassification docstring\n_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = \"ArthurZ/opt-350m-dummy-sc\"\n_SEQ_CLASS_EXPECTED_LOSS = 1.71\n_SEQ_CLASS_EXPECTED_OUTPUT = \"'LABEL_0'\"\n\n# QuestionAnswering docstring\n_QA_EXPECTED_OUTPUT = \"'a nice puppet'\"\n_QA_EXPECTED_LOSS = 7.41\n_QA_TARGET_START_INDEX = 14\n_QA_TARGET_END_INDEX = 15\n\nOPT_PRETRAINED_MODEL_ARCHIVE_LIST = [\n    \"facebook/opt-125m\",\n    \"facebook/opt-350m\",\n    \"facebook/opt-1.3b\",\n    \"facebook/opt-2.7b\",\n    \"facebook/opt-6.7b\",\n    \"facebook/opt-13b\",\n    \"facebook/opt-30b\",\n    # See all OPT models at https://huggingface.co/models?filter=opt\n]\n\n\ndef _make_causal_mask(\n    input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0\n):\n    \"\"\"\n    Make causal mask used for bi-directional self-attention.\n    \"\"\"\n    bsz, tgt_len = input_ids_shape\n    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))\n    mask_cond = torch.arange(mask.size(-1))\n    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)\n    mask = mask.to(dtype)\n\n    if past_key_values_length > 0:\n        mask = torch.cat(\n            [torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1\n        )\n    return mask[None, None, :, :].expand(\n        bsz, 1, tgt_len, tgt_len + past_key_values_length\n    )\n\n\ndef _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):\n    \"\"\"\n    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.\n    \"\"\"\n    bsz, src_len = mask.size()\n    tgt_len = tgt_len if tgt_len is not None else src_len\n\n    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)\n\n    inverted_mask = 1.0 - expanded_mask\n\n    return inverted_mask.masked_fill(\n        inverted_mask.to(torch.bool), torch.finfo(dtype).min\n    
)\n\n\nclass OPTLearnedPositionalEmbedding(nn.Embedding):\n    \"\"\"\n    This module learns positional embeddings up to a fixed maximum size.\n    \"\"\"\n\n    def __init__(self, num_embeddings: int, embedding_dim: int):\n        # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2\n        # and adjust num_embeddings appropriately. Other models don't have this hack\n        self.offset = 2\n        super().__init__(num_embeddings + self.offset, embedding_dim)\n\n    def forward(\n        self, attention_mask: torch.LongTensor, past_key_values_length: int = 0\n    ):\n        \"\"\"`input_ids_shape` is expected to be [bsz x seqlen].\"\"\"\n        attention_mask = attention_mask.long()\n\n        # create positions depending on attention_mask\n        positions = (\n            torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask\n        ).long() - 1\n\n        # cut positions if `past_key_values_length` is > 0\n        positions = positions[:, past_key_values_length:]\n\n        return super().forward(positions + self.offset)\n\n\nclass OPTAttention(nn.Module):\n    \"\"\"Multi-headed attention from 'Attention Is All You Need' paper\"\"\"\n\n    def __init__(\n        self,\n        embed_dim: int,\n        num_heads: int,\n        dropout: float = 0.0,\n        is_decoder: bool = False,\n        bias: bool = True,\n    ):\n        super().__init__()\n        self.embed_dim = embed_dim\n        self.num_heads = num_heads\n        self.dropout = dropout\n        self.head_dim = embed_dim // num_heads\n\n        if (self.head_dim * num_heads) != self.embed_dim:\n            raise ValueError(\n                f\"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}\"\n                f\" and `num_heads`: {num_heads}).\"\n            )\n        self.scaling = self.head_dim**-0.5\n        self.is_decoder = is_decoder\n\n        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)\n        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)\n        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)\n        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)\n\n    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):\n        return (\n            tensor.view(bsz, seq_len, self.num_heads, self.head_dim)\n            .transpose(1, 2)\n            .contiguous()\n        )\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        key_value_states: Optional[torch.Tensor] = None,\n        past_key_value: Optional[Tuple[torch.Tensor]] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        layer_head_mask: Optional[torch.Tensor] = None,\n        output_attentions: bool = False,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        \"\"\"Input shape: Batch x Time x Channel\"\"\"\n\n        # if key_value_states are provided this layer is used as a cross-attention layer\n        # for the decoder\n        is_cross_attention = key_value_states is not None\n\n        bsz, tgt_len, _ = hidden_states.size()\n\n        # get query proj\n        query_states = self.q_proj(hidden_states) * self.scaling\n        # get key, value proj\n        if is_cross_attention and past_key_value is not None:\n            # reuse k,v, cross_attentions\n            key_states = past_key_value[0]\n            value_states = past_key_value[1]\n        elif is_cross_attention:\n            # cross_attentions\n            key_states 
= self._shape(self.k_proj(key_value_states), -1, bsz)\n            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)\n        elif past_key_value is not None:\n            # reuse k, v, self_attention\n            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)\n            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)\n            key_states = torch.cat([past_key_value[0], key_states], dim=2)\n            value_states = torch.cat([past_key_value[1], value_states], dim=2)\n        else:\n            # self_attention\n            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)\n            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)\n\n        if self.is_decoder:\n            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.\n            # Further calls to cross_attention layer can then reuse all cross-attention\n            # key/value_states (first \"if\" case)\n            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of\n            # all previous decoder key/value_states. Further calls to uni-directional self-attention\n            # can concat previous decoder key/value_states to current projected key/value_states (third \"elif\" case)\n            # if encoder bi-directional self-attention `past_key_value` is always `None`\n            past_key_value = (key_states, value_states)\n\n        proj_shape = (bsz * self.num_heads, -1, self.head_dim)\n        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)\n        key_states = key_states.view(*proj_shape)\n        value_states = value_states.view(*proj_shape)\n\n        src_len = key_states.size(1)\n        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))\n\n        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):\n            raise ValueError(\n                f\"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is\"\n                f\" {attn_weights.size()}\"\n            )\n\n        if attention_mask is not None:\n            if attention_mask.size() != (bsz, 1, tgt_len, src_len):\n                raise ValueError(\n                    f\"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}\"\n                )\n            attn_weights = (\n                attn_weights.view(bsz, self.num_heads, tgt_len, src_len)\n                + attention_mask\n            )\n            attn_weights = torch.max(\n                attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)\n            )\n            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)\n\n        # upcast to fp32 if the weights are in fp16. 
Please see https://github.com/huggingface/transformers/pull/17437\n        if attn_weights.dtype == torch.float16:\n            attn_weights = nn.functional.softmax(\n                attn_weights, dim=-1, dtype=torch.float32\n            ).to(torch.float16)\n        else:\n            attn_weights = nn.functional.softmax(attn_weights, dim=-1)\n\n        if layer_head_mask is not None:\n            if layer_head_mask.size() != (self.num_heads,):\n                raise ValueError(\n                    f\"Head mask for a single layer should be of size {(self.num_heads,)}, but is\"\n                    f\" {layer_head_mask.size()}\"\n                )\n            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(\n                bsz, self.num_heads, tgt_len, src_len\n            )\n            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)\n\n        if output_attentions:\n            # this operation is a bit awkward, but it's required to\n            # make sure that attn_weights keeps its gradient.\n            # In order to do so, attn_weights have to be reshaped\n            # twice and have to be reused in the following\n            attn_weights_reshaped = attn_weights.view(\n                bsz, self.num_heads, tgt_len, src_len\n            )\n            attn_weights = attn_weights_reshaped.view(\n                bsz * self.num_heads, tgt_len, src_len\n            )\n        else:\n            attn_weights_reshaped = None\n\n        attn_probs = nn.functional.dropout(\n            attn_weights, p=self.dropout, training=self.training\n        )\n\n        attn_output = torch.bmm(attn_probs, value_states)\n\n        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):\n            raise ValueError(\n                f\"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is\"\n                f\" {attn_output.size()}\"\n            )\n\n        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)\n        attn_output = attn_output.transpose(1, 2)\n\n        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be\n        # partitioned aross GPUs when using tensor-parallelism.\n        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)\n\n        attn_output = self.out_proj(attn_output)\n\n        return attn_output, attn_weights_reshaped, past_key_value\n\n\nclass OPTDecoderLayer(nn.Module):\n    def __init__(self, config: OPTConfig):\n        super().__init__()\n        self.embed_dim = config.hidden_size\n        self.self_attn = OPTAttention(\n            embed_dim=self.embed_dim,\n            num_heads=config.num_attention_heads,\n            dropout=config.attention_dropout,\n            is_decoder=True,\n        )\n        self.do_layer_norm_before = config.do_layer_norm_before\n        self.dropout = config.dropout\n        self.activation_fn = ACT2FN[config.activation_function]\n\n        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)\n        self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim)\n        self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim)\n        self.final_layer_norm = nn.LayerNorm(self.embed_dim)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        layer_head_mask: Optional[torch.Tensor] = None,\n        output_attentions: Optional[bool] = False,\n        use_cache: 
Optional[bool] = False,\n        past_key_value: Optional[Tuple[torch.Tensor]] = None,\n    ) -> Tuple[\n        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]\n    ]:\n        \"\"\"\n        Args:\n            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`\n            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size\n                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.\n            layer_head_mask (`torch.FloatTensor`, *optional*): mask for attention heads in a given layer of size\n                `(encoder_attention_heads,)`.\n            output_attentions (`bool`, *optional*):\n                Whether or not to return the attentions tensors of all attention layers. See `attentions` under\n                returned tensors for more detail.\n            use_cache (`bool`, *optional*):\n                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding\n                (see `past_key_values`).\n            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states\n        \"\"\"\n\n        residual = hidden_states\n\n        # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention\n        if self.do_layer_norm_before:\n            hidden_states = self.self_attn_layer_norm(hidden_states)\n\n        # Self Attention\n        hidden_states, self_attn_weights, present_key_value = self.self_attn(\n            hidden_states=hidden_states,\n            past_key_value=past_key_value,\n            attention_mask=attention_mask,\n            layer_head_mask=layer_head_mask,\n            output_attentions=output_attentions,\n        )\n        hidden_states = nn.functional.dropout(\n            hidden_states, p=self.dropout, training=self.training\n        )\n        hidden_states = residual + hidden_states\n\n        # 350m applies layer norm AFTER attention\n        if not self.do_layer_norm_before:\n            hidden_states = self.self_attn_layer_norm(hidden_states)\n\n        # Fully Connected\n        hidden_states_shape = hidden_states.shape\n        hidden_states = hidden_states.reshape(-1, hidden_states.size(-1))\n        residual = hidden_states\n\n        # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention\n        if self.do_layer_norm_before:\n            hidden_states = self.final_layer_norm(hidden_states)\n\n        hidden_states = self.fc1(hidden_states)\n        hidden_states = self.activation_fn(hidden_states)\n\n        hidden_states = self.fc2(hidden_states)\n        hidden_states = nn.functional.dropout(\n            hidden_states, p=self.dropout, training=self.training\n        )\n\n        hidden_states = (residual + hidden_states).view(hidden_states_shape)\n\n        # 350m applies layer norm AFTER attention\n        if not self.do_layer_norm_before:\n            hidden_states = self.final_layer_norm(hidden_states)\n\n        outputs = (hidden_states,)\n\n        if output_attentions:\n            outputs += (self_attn_weights,)\n\n        if use_cache:\n            outputs += (present_key_value,)\n\n        return outputs\n\n\nOPT_START_DOCSTRING = r\"\"\"\n    This model inherits from [`PreTrainedModel`]. 
Check the superclass documentation for the generic methods the\n    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads\n    etc.)\n\n    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.\n    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage\n    and behavior.\n\n    Parameters:\n        config ([`OPTConfig`]):\n            Model configuration class with all the parameters of the model. Initializing with a config file does not\n            load the weights associated with the model, only the configuration. Check out the\n            [`~PreTrainedModel.from_pretrained`] method to load the model weights.\n\"\"\"\n\n\n@add_start_docstrings(\n    \"The bare OPT Model outputting raw hidden-states without any specific head on top.\",\n    OPT_START_DOCSTRING,\n)\nclass OPTPreTrainedModel(PreTrainedModel):\n\n    config_class = OPTConfig\n    base_model_prefix = \"model\"\n    supports_gradient_checkpointing = True\n    _no_split_modules = [\"OPTDecoderLayer\"]\n    _keys_to_ignore_on_load_unexpected = [r\"decoder\\.version\"]\n\n    def _init_weights(self, module):\n        std = self.config.init_std\n        if isinstance(module, nn.Linear):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.bias is not None:\n                module.bias.data.zero_()\n        elif isinstance(module, nn.Embedding):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.padding_idx is not None:\n                module.weight.data[module.padding_idx].zero_()\n\n    def _set_gradient_checkpointing(self, module, value=False):\n        if isinstance(module, (OPTDecoder)):\n            module.gradient_checkpointing = value\n\n\nOPT_INPUTS_DOCSTRING = r\"\"\"\n    Args:\n        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):\n            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide\n            it.\n\n            Indices can be obtained using [`GPT2Tokenizer`]. See [`PreTrainedTokenizer.encode`] and\n            [`PreTrainedTokenizer.__call__`] for details.\n\n            [What are input IDs?](../glossary#input-ids)\n        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:\n\n            - 1 for tokens that are **not masked**,\n            - 0 for tokens that are **masked**.\n\n            [What are attention masks?](../glossary#attention-mask)\n\n            Indices can be obtained using [`OPTTokenizer`]. See [`PreTrainedTokenizer.encode`] and\n            [`PreTrainedTokenizer.__call__`] for details.\n\n            If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see\n            `past_key_values`).\n\n            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]\n            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more\n            information on the default strategy.\n        head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):\n            Mask to nullify selected heads of the attention modules in the encoder. 
Mask values selected in `[0, 1]`:\n\n            - 1 indicates the head is **not masked**,\n            - 0 indicates the head is **masked**.\n\n        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):\n            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape\n            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape\n            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.\n\n            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention\n            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.\n\n            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that\n            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all\n            `decoder_input_ids` of shape `(batch_size, sequence_length)`.\n        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):\n            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This\n            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the\n            model's internal embedding lookup matrix.\n        use_cache (`bool`, *optional*):\n            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see\n            `past_key_values`).\n        output_attentions (`bool`, *optional*):\n            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned\n            tensors for more detail.\n        output_hidden_states (`bool`, *optional*):\n            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for\n            more detail.\n        return_dict (`bool`, *optional*):\n            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.\n\"\"\"\n\n\nclass OPTDecoder(OPTPreTrainedModel):\n    \"\"\"\n    Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`OPTDecoderLayer`]\n\n    Args:\n        config: OPTConfig\n    \"\"\"\n\n    def __init__(self, config: OPTConfig):\n        super().__init__(config)\n        self.dropout = config.dropout\n        self.layerdrop = config.layerdrop\n        self.padding_idx = config.pad_token_id\n        self.max_target_positions = config.max_position_embeddings\n        self.vocab_size = config.vocab_size\n\n        self.embed_tokens = nn.Embedding(\n            config.vocab_size, config.word_embed_proj_dim, self.padding_idx\n        )\n        self.embed_positions = OPTLearnedPositionalEmbedding(\n            config.max_position_embeddings, config.hidden_size\n        )\n\n        if config.word_embed_proj_dim != config.hidden_size:\n            self.project_out = nn.Linear(\n                config.hidden_size, config.word_embed_proj_dim, bias=False\n            )\n        else:\n            self.project_out = None\n\n        if config.word_embed_proj_dim != config.hidden_size:\n            self.project_in = nn.Linear(\n                config.word_embed_proj_dim, config.hidden_size, bias=False\n            )\n        else:\n            self.project_in = None\n\n        # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility\n        # with checkpoints that have been fine-tuned before transformers v4.20.1\n        # see https://github.com/facebookresearch/metaseq/pull/164\n        if config.do_layer_norm_before and not config._remove_final_layer_norm:\n            self.final_layer_norm = nn.LayerNorm(config.hidden_size)\n        else:\n            self.final_layer_norm = None\n\n        self.layers = nn.ModuleList(\n            [OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)]\n        )\n\n        self.gradient_checkpointing = False\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.embed_tokens = value\n\n    # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask\n    def _prepare_decoder_attention_mask(\n        self, attention_mask, input_shape, inputs_embeds, past_key_values_length\n    ):\n        # create causal mask\n        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]\n        combined_attention_mask = None\n        if input_shape[-1] > 1:\n            combined_attention_mask = _make_causal_mask(\n                input_shape,\n                inputs_embeds.dtype,\n                past_key_values_length=past_key_values_length,\n            ).to(inputs_embeds.device)\n\n        if attention_mask is not None:\n            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]\n            expanded_attn_mask = _expand_mask(\n                attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]\n            ).to(inputs_embeds.device)\n            combined_attention_mask = (\n                expanded_attn_mask\n                if combined_attention_mask is None\n                else expanded_attn_mask + combined_attention_mask\n            )\n\n        return combined_attention_mask\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        head_mask: Optional[torch.Tensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        
query_embeds: Optional[torch.FloatTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n    ) -> Union[Tuple, BaseModelOutputWithPast]:\n        r\"\"\"\n        Args:\n            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):\n                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you\n                provide it.\n\n                Indices can be obtained using [`OPTTokenizer`]. See [`PreTrainedTokenizer.encode`] and\n                [`PreTrainedTokenizer.__call__`] for details.\n\n                [What are input IDs?](../glossary#input-ids)\n            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):\n                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:\n\n                - 1 for tokens that are **not masked**,\n                - 0 for tokens that are **masked**.\n\n                [What are attention masks?](../glossary#attention-mask)\n            head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):\n                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:\n\n                - 1 indicates the head is **not masked**,\n                - 0 indicates the head is **masked**.\n\n            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):\n                Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of\n                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of\n\n                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the\n                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.\n\n                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those\n                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of\n                all `decoder_input_ids` of shape `(batch_size, sequence_length)`.\n\n            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):\n                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.\n                This is useful if you want more control over how to convert `input_ids` indices into associated vectors\n                than the model's internal embedding lookup matrix.\n            output_attentions (`bool`, *optional*):\n                Whether or not to return the attentions tensors of all attention layers. See `attentions` under\n                returned tensors for more detail.\n            output_hidden_states (`bool`, *optional*):\n                Whether or not to return the hidden states of all layers. 
See `hidden_states` under returned tensors\n                for more detail.\n            return_dict (`bool`, *optional*):\n                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.\n        \"\"\"\n        output_attentions = (\n            output_attentions\n            if output_attentions is not None\n            else self.config.output_attentions\n        )\n        output_hidden_states = (\n            output_hidden_states\n            if output_hidden_states is not None\n            else self.config.output_hidden_states\n        )\n        use_cache = use_cache if use_cache is not None else self.config.use_cache\n\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        # retrieve input_ids and inputs_embeds\n        if input_ids is not None and inputs_embeds is not None:\n            raise ValueError(\n                \"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time\"\n            )\n        elif input_ids is not None:\n            input_shape = input_ids.size()\n            input_ids = input_ids.view(-1, input_shape[-1])\n        elif inputs_embeds is not None:\n            input_shape = inputs_embeds.size()[:-1]\n        else:\n            raise ValueError(\n                \"You have to specify either decoder_input_ids or decoder_inputs_embeds\"\n            )\n\n        past_key_values_length = (\n            past_key_values[0][0].shape[2] if past_key_values is not None else 0\n        )\n\n        if inputs_embeds is None:\n            inputs_embeds = self.embed_tokens(input_ids)\n\n        if query_embeds is not None:\n            inputs_embeds = torch.cat([query_embeds, inputs_embeds], dim=1)\n            input_shape = inputs_embeds.size()[:-1]\n\n        # embed positions\n        if attention_mask is None:\n            attention_mask = torch.ones(\n                inputs_embeds.shape[:2], dtype=torch.bool, device=inputs_embeds.device\n            )\n        pos_embeds = self.embed_positions(attention_mask, past_key_values_length)\n\n        attention_mask = self._prepare_decoder_attention_mask(\n            attention_mask, input_shape, inputs_embeds, past_key_values_length\n        )\n\n        if self.project_in is not None:\n            inputs_embeds = self.project_in(inputs_embeds)\n\n        hidden_states = inputs_embeds + pos_embeds\n\n        # decoder layers\n        all_hidden_states = () if output_hidden_states else None\n        all_self_attns = () if output_attentions else None\n        next_decoder_cache = () if use_cache else None\n\n        # check if head_mask has a correct number of layers specified if desired\n        for attn_mask, mask_name in zip([head_mask], [\"head_mask\"]):\n            if attn_mask is not None:\n                if attn_mask.size()[0] != (len(self.layers)):\n                    raise ValueError(\n                        f\"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for\"\n                        f\" {head_mask.size()[0]}.\"\n                    )\n\n        for idx, decoder_layer in enumerate(self.layers):\n            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)\n            if output_hidden_states:\n                all_hidden_states += (hidden_states,)\n\n            dropout_probability = random.uniform(0, 1)\n            if self.training and (dropout_probability < self.layerdrop):\n                continue\n\n            
past_key_value = (\n                past_key_values[idx] if past_key_values is not None else None\n            )\n\n            if self.gradient_checkpointing and self.training:\n\n                if use_cache:\n                    logger.warning(\n                        \"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...\"\n                    )\n                    use_cache = False\n\n                def create_custom_forward(module):\n                    def custom_forward(*inputs):\n                        # None for past_key_value\n                        return module(*inputs, output_attentions, None)\n\n                    return custom_forward\n\n                layer_outputs = torch.utils.checkpoint.checkpoint(\n                    create_custom_forward(decoder_layer),\n                    hidden_states,\n                    attention_mask,\n                    head_mask[idx] if head_mask is not None else None,\n                    None,\n                )\n            else:\n\n                layer_outputs = decoder_layer(\n                    hidden_states,\n                    attention_mask=attention_mask,\n                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),\n                    past_key_value=past_key_value,\n                    output_attentions=output_attentions,\n                    use_cache=use_cache,\n                )\n\n            hidden_states = layer_outputs[0]\n\n            if use_cache:\n                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)\n\n            if output_attentions:\n                all_self_attns += (layer_outputs[1],)\n\n        if self.final_layer_norm is not None:\n            hidden_states = self.final_layer_norm(hidden_states)\n\n        if self.project_out is not None:\n            hidden_states = self.project_out(hidden_states)\n\n        # add hidden states from the last decoder layer\n        if output_hidden_states:\n            all_hidden_states += (hidden_states,)\n\n        next_cache = next_decoder_cache if use_cache else None\n        if not return_dict:\n            return tuple(\n                v\n                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]\n                if v is not None\n            )\n        return BaseModelOutputWithPast(\n            last_hidden_state=hidden_states,\n            past_key_values=next_cache,\n            hidden_states=all_hidden_states,\n            attentions=all_self_attns,\n        )\n\n\n@add_start_docstrings(\n    \"The bare OPT Model outputting raw hidden-states without any specific head on top.\",\n    OPT_START_DOCSTRING,\n)\nclass OPTModel(OPTPreTrainedModel):\n    def __init__(self, config: OPTConfig):\n        super().__init__(config)\n        self.decoder = OPTDecoder(config)\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.decoder.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.decoder.embed_tokens = value\n\n    def get_decoder(self):\n        return self.decoder\n\n    @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING)\n    @add_code_sample_docstrings(\n        processor_class=_TOKENIZER_FOR_DOC,\n        checkpoint=_CHECKPOINT_FOR_DOC,\n        output_type=BaseModelOutputWithPast,\n        config_class=_CONFIG_FOR_DOC,\n        expected_output=_EXPECTED_OUTPUT_SHAPE,\n    )\n    def forward(\n        self,\n        
input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        head_mask: Optional[torch.Tensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        query_embeds: Optional[torch.FloatTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n    ) -> Union[Tuple, BaseModelOutputWithPast]:\n\n        output_attentions = (\n            output_attentions\n            if output_attentions is not None\n            else self.config.output_attentions\n        )\n        output_hidden_states = (\n            output_hidden_states\n            if output_hidden_states is not None\n            else self.config.output_hidden_states\n        )\n        use_cache = use_cache if use_cache is not None else self.config.use_cache\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)\n        decoder_outputs = self.decoder(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            head_mask=head_mask,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            query_embeds=query_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n\n        if not return_dict:\n            return decoder_outputs\n\n        return BaseModelOutputWithPast(\n            last_hidden_state=decoder_outputs.last_hidden_state,\n            past_key_values=decoder_outputs.past_key_values,\n            hidden_states=decoder_outputs.hidden_states,\n            attentions=decoder_outputs.attentions,\n        )\n\n\nclass OPTForCausalLM(OPTPreTrainedModel):\n    _keys_to_ignore_on_load_missing = [r\"lm_head.weight\"]\n\n    def __init__(self, config):\n        super().__init__(config)\n        self.model = OPTModel(config)\n\n        # the lm_head weight is automatically tied to the embed tokens weight\n        self.lm_head = nn.Linear(\n            config.word_embed_proj_dim, config.vocab_size, bias=False\n        )\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.model.decoder.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.model.decoder.embed_tokens = value\n\n    def get_output_embeddings(self):\n        return self.lm_head\n\n    def set_output_embeddings(self, new_embeddings):\n        self.lm_head = new_embeddings\n\n    def set_decoder(self, decoder):\n        self.model.decoder = decoder\n\n    def get_decoder(self):\n        return self.model.decoder\n\n    @replace_return_docstrings(\n        output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC\n    )\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        head_mask: Optional[torch.Tensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        query_embeds: Optional[torch.FloatTensor] = None,\n        labels: 
Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        reduction: Optional[str] = \"mean\",\n    ) -> Union[Tuple, CausalLMOutputWithPast]:\n        r\"\"\"\n        Args:\n            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):\n                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you\n                provide it.\n\n                Indices can be obtained using [`OPTTokenizer`]. See [`PreTrainedTokenizer.encode`] and\n                [`PreTrainedTokenizer.__call__`] for details.\n\n                [What are input IDs?](../glossary#input-ids)\n            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):\n                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:\n\n                - 1 for tokens that are **not masked**,\n                - 0 for tokens that are **masked**.\n\n                [What are attention masks?](../glossary#attention-mask)\n            head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):\n                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:\n\n                - 1 indicates the head is **not masked**,\n                - 0 indicates the head is **masked**.\n\n            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):\n                Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of\n                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of\n                shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional\n                tensors are only required when the model is used as a decoder in a Sequence to Sequence model.\n\n                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the\n                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.\n\n                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those\n                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of\n                all `decoder_input_ids` of shape `(batch_size, sequence_length)`.\n            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):\n                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.\n                This is useful if you want more control over how to convert `input_ids` indices into associated vectors\n                than the model's internal embedding lookup matrix.\n            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,\n                config.vocab_size]` or -100 (see `input_ids` docstring). 
Tokens with indices set to `-100` are ignored\n                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.\n            use_cache (`bool`, *optional*):\n                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding\n                (see `past_key_values`).\n            output_attentions (`bool`, *optional*):\n                Whether or not to return the attentions tensors of all attention layers. See `attentions` under\n                returned tensors for more detail.\n            output_hidden_states (`bool`, *optional*):\n                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors\n                for more detail.\n            return_dict (`bool`, *optional*):\n                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.\n\n        Returns:\n\n        Example:\n\n        ```python\n        >>> from transformers import GPT2Tokenizer, OPTForCausalLM\n\n        >>> model = OPTForCausalLM.from_pretrained(\"facebook/opt-350m\")\n        >>> tokenizer = GPT2Tokenizer.from_pretrained(\"facebook/opt-350m\")\n\n        >>> prompt = \"Hey, are you consciours? Can you talk to me?\"\n        >>> inputs = tokenizer(prompt, return_tensors=\"pt\")\n\n        >>> # Generate\n        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)\n        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]\n        \"Hey, are you consciours? Can you talk to me?\\nI'm not consciours, but I can talk to you.\"\n        ```\"\"\"\n\n        output_attentions = (\n            output_attentions\n            if output_attentions is not None\n            else self.config.output_attentions\n        )\n        output_hidden_states = (\n            output_hidden_states\n            if output_hidden_states is not None\n            else self.config.output_hidden_states\n        )\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\n        outputs = self.model.decoder(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            head_mask=head_mask,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            query_embeds=query_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n\n        logits = self.lm_head(outputs[0]).contiguous()\n\n        loss = None\n        if labels is not None:\n            logits = logits[:, -labels.size(1) :, :]\n\n            # Shift so that tokens < n predict n\n            shift_logits = logits[..., :-1, :].contiguous()\n            shift_labels = labels[..., 1:].contiguous()\n            # Flatten the tokens\n            loss_fct = CrossEntropyLoss(reduction=reduction)\n            loss = loss_fct(\n                shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1)\n            )\n            if reduction == \"none\":\n                loss = loss.view(shift_logits.size(0), -1).sum(1)\n\n        if not return_dict:\n            output = (logits,) + outputs[1:]\n            return (loss,) + output if loss is not None else output\n\n        return 
CausalLMOutputWithPast(\n            loss=loss,\n            logits=logits,\n            past_key_values=outputs.past_key_values,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n\n    def prepare_inputs_for_generation(\n        self,\n        input_ids=None,\n        query_embeds=None,\n        past=None,\n        attention_mask=None,\n        use_cache=None,\n        **kwargs,\n    ):\n        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly\n        if attention_mask is None:\n            if input_ids is not None:\n                attention_mask = input_ids.new_ones(input_ids.shape)\n        if past:\n            input_ids = input_ids[:, -1:]\n            query_embeds = None\n        # first step, decoder_cached_states are empty\n        return {\n            \"input_ids\": input_ids,\n            \"query_embeds\": query_embeds,\n            \"attention_mask\": attention_mask,\n            \"past_key_values\": past,\n            \"use_cache\": use_cache,\n        }\n\n    @staticmethod\n    def _reorder_cache(past, beam_idx):\n        reordered_past = ()\n        for layer_past in past:\n            reordered_past += (\n                tuple(\n                    past_state.index_select(0, beam_idx) for past_state in layer_past\n                ),\n            )\n        return reordered_past\n"
  },
  {
    "path": "lavis/models/blip2_models/modeling_t5.py",
    "content": "# coding=utf-8\n# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\" PyTorch T5 model.\"\"\"\n\n\nimport copy\nimport math\nimport os\nimport warnings\nfrom typing import Optional, Tuple, Union\n\nimport torch\nfrom torch import nn\nfrom torch.nn import CrossEntropyLoss\nfrom torch.utils.checkpoint import checkpoint\n\nfrom transformers.activations import ACT2FN\nfrom transformers.modeling_outputs import (\n    BaseModelOutput,\n    BaseModelOutputWithPastAndCrossAttentions,\n    Seq2SeqLMOutput,\n    Seq2SeqModelOutput,\n)\nfrom transformers.modeling_utils import PreTrainedModel\nfrom transformers.pytorch_utils import (\n    ALL_LAYERNORM_LAYERS,\n    find_pruneable_heads_and_indices,\n    prune_linear_layer,\n)\nfrom transformers.utils import (\n    DUMMY_INPUTS,\n    DUMMY_MASK,\n    add_start_docstrings,\n    add_start_docstrings_to_model_forward,\n    is_torch_fx_proxy,\n    logging,\n    replace_return_docstrings,\n)\nfrom transformers.utils.model_parallel_utils import assert_device_map, get_device_map\nfrom transformers.models.t5.configuration_t5 import T5Config\n\n\nlogger = logging.get_logger(__name__)\n\n_CONFIG_FOR_DOC = \"T5Config\"\n_TOKENIZER_FOR_DOC = \"T5Tokenizer\"\n_CHECKPOINT_FOR_DOC = \"t5-small\"\n\n####################################################\n# This dict contains ids and associated url\n# for the pretrained weights provided with the models\n####################################################\nT5_PRETRAINED_MODEL_ARCHIVE_LIST = [\n    \"t5-small\",\n    \"t5-base\",\n    \"t5-large\",\n    \"t5-3b\",\n    \"t5-11b\",\n    # See all T5 models at https://huggingface.co/models?filter=t5\n]\n\n\n####################################################\n# This is a conversion method from TF 1.0 to PyTorch\n# More details: https://medium.com/huggingface/from-tensorflow-to-pytorch-265f40ef2a28\n####################################################\ndef load_tf_weights_in_t5(model, config, tf_checkpoint_path):\n    \"\"\"Load tf checkpoints in a pytorch model.\"\"\"\n    try:\n        import re\n\n        import numpy as np\n        import tensorflow as tf\n    except ImportError:\n        logger.error(\n            \"Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. 
Please see \"\n            \"https://www.tensorflow.org/install/ for installation instructions.\"\n        )\n        raise\n    tf_path = os.path.abspath(tf_checkpoint_path)\n    logger.info(f\"Converting TensorFlow checkpoint from {tf_path}\")\n    # Load weights from TF model\n    init_vars = tf.train.list_variables(tf_path)\n    names = []\n    tf_weights = {}\n    for name, shape in init_vars:\n        logger.info(f\"Loading TF weight {name} with shape {shape}\")\n        array = tf.train.load_variable(tf_path, name)\n        names.append(name)\n        tf_weights[name] = array\n\n    for txt_name in names:\n        name = txt_name.split(\"/\")\n        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v\n        # which are not required for using pretrained model\n        if any(\n            n\n            in [\n                \"adam_v\",\n                \"adam_m\",\n                \"AdamWeightDecayOptimizer\",\n                \"AdamWeightDecayOptimizer_1\",\n                \"global_step\",\n            ]\n            for n in name\n        ):\n            logger.info(f\"Skipping {'/'.join(name)}\")\n            tf_weights.pop(txt_name, None)\n            continue\n        if \"_slot_\" in name[-1]:\n            logger.info(f\"Skipping {'/'.join(name)}\")\n            tf_weights.pop(txt_name, None)\n            continue\n        pointer = model\n        array = tf_weights[txt_name]\n\n        for m_name in name:\n            if re.fullmatch(r\"[A-Za-z]+_\\d+\", m_name):\n                scope_names = re.split(r\"_(\\d+)\", m_name)\n            else:\n                scope_names = [m_name]\n            if scope_names[0] in [\"kernel\", \"scale\", \"embedding\"]:\n                pointer = getattr(pointer, \"weight\")\n            elif scope_names[0] == \"self_attention\":\n                pointer = getattr(pointer, \"layer\")\n                pointer = pointer[0]\n            elif scope_names[0] == \"enc_dec_attention\":\n                pointer = getattr(pointer, \"layer\")\n                pointer = pointer[1]\n            elif scope_names[0] == \"dense_relu_dense\":\n                pointer = getattr(pointer, \"layer\")\n                pointer = pointer[2]\n            elif scope_names[0] == \"rms_norm\":\n                if hasattr(pointer, \"layer_norm\"):\n                    pointer = getattr(pointer, \"layer_norm\")\n                elif hasattr(pointer, \"final_layer_norm\"):\n                    pointer = getattr(pointer, \"final_layer_norm\")\n            elif scope_names[0] == \"scale\":\n                pointer = getattr(pointer, \"weight\")\n            elif scope_names[0] == \"output_bias\" or scope_names[0] == \"beta\":\n                pointer = getattr(pointer, \"bias\")\n            elif scope_names[0] == \"squad\":\n                pointer = getattr(pointer, \"classifier\")\n            elif scope_names[0] == \"decoder\" and name[1] == \"logits\":\n                continue\n            elif scope_names[0] == \"logits\":\n                pointer = getattr(pointer, \"lm_head\")\n            elif (\n                scope_names[0] == \"wi\"\n                and len(scope_names) > 1\n                and scope_names[1].isdigit()\n            ):\n                pointer = getattr(pointer, f\"wi_{scope_names[1]}\")\n                continue\n            else:\n                try:\n                    pointer = getattr(pointer, scope_names[0])\n                except AttributeError:\n                    logger.info(f\"Skipping 
{'/'.join(name)}\")\n                    continue\n            if len(scope_names) >= 2:\n                num = int(scope_names[1])\n                pointer = pointer[num]\n        if scope_names[0] not in [\"kernel\", \"scale\", \"embedding\"]:\n            pointer = getattr(pointer, \"weight\")\n        if scope_names[0] != \"embedding\":\n            logger.info(f\"Transposing numpy weight of shape {array.shape} for {name}\")\n            array = np.transpose(array)\n        try:\n            assert (\n                pointer.shape == array.shape\n            ), f\"Pointer shape {pointer.shape} and array shape {array.shape} mismatched\"\n        except AssertionError as e:\n            e.args += (pointer.shape, array.shape)\n            raise\n        logger.info(f\"Initialize PyTorch weight {name}\")\n        pointer.data = torch.from_numpy(array.astype(np.float32))\n        tf_weights.pop(txt_name, None)\n\n    logger.info(f\"Weights not copied to PyTorch model: {', '.join(tf_weights.keys())}.\")\n    return model\n\n\n####################################################\n# PyTorch Models are constructed by sub-classing\n# - torch.nn.Module for the layers and\n# - PreTrainedModel for the models (it-self a sub-class of nn.Module)\n####################################################\nPARALLELIZE_DOCSTRING = r\"\"\"\n    This is an experimental feature and is a subject to change at a moment's notice.\n\n    Uses a device map to distribute attention modules of the model across several devices. If no device map is given,\n    it will evenly distribute blocks across all devices.\n\n    Args:\n        device_map (`Dict[int, list]`, optional, defaults to None):\n            A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always\n            automatically mapped to the first device (for esoteric reasons). That means that the first device should\n            have fewer attention modules mapped to it than other devices. For reference, the t5 models have the\n            following number of attention modules:\n\n                - t5-small: 6\n                - t5-base: 12\n                - t5-large: 24\n                - t5-3b: 24\n                - t5-11b: 24\n\n    Example:\n\n    ```python\n    # Here is an example of a device map on a machine with 4 GPUs using t5-3b, which has a total of 24 attention modules:\n    model = T5ForConditionalGeneration.from_pretrained(\"t5-3b\")\n    device_map = {\n        0: [0, 1, 2],\n        1: [3, 4, 5, 6, 7, 8, 9],\n        2: [10, 11, 12, 13, 14, 15, 16],\n        3: [17, 18, 19, 20, 21, 22, 23],\n    }\n    model.parallelize(device_map)\n    ```\n\"\"\"\nDEPARALLELIZE_DOCSTRING = r\"\"\"\n    Moves the model to cpu from a model parallel state.\n\n    Example:\n\n    ```python\n    # On a 4 GPU machine with t5-3b:\n    model = T5ForConditionalGeneration.from_pretrained(\"t5-3b\")\n    device_map = {\n        0: [0, 1, 2],\n        1: [3, 4, 5, 6, 7, 8, 9],\n        2: [10, 11, 12, 13, 14, 15, 16],\n        3: [17, 18, 19, 20, 21, 22, 23],\n    }\n    model.parallelize(device_map)  # Splits the model across several devices\n    model.deparallelize()  # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache()\n    ```\n\"\"\"\n\n\nclass T5LayerNorm(nn.Module):\n    def __init__(self, hidden_size, eps=1e-6):\n        \"\"\"\n        Construct a layernorm module in the T5 style. 
No bias and no subtraction of mean.\n        \"\"\"\n        super().__init__()\n        self.weight = nn.Parameter(torch.ones(hidden_size))\n        self.variance_epsilon = eps\n\n    def forward(self, hidden_states):\n\n        # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean\n        # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated\n        # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for\n        # half-precision inputs is done in fp32\n\n        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)\n        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)\n\n        # convert into half-precision if necessary\n        if self.weight.dtype in [torch.float16, torch.bfloat16]:\n            hidden_states = hidden_states.to(self.weight.dtype)\n\n        return self.weight * hidden_states\n\n\ntry:\n    from apex.normalization import FusedRMSNorm\n\n    T5LayerNorm = FusedRMSNorm  # noqa\n\n    logger.info(\n        \"Discovered apex.normalization.FusedRMSNorm - will use it instead of T5LayerNorm\"\n    )\nexcept ImportError:\n    # using the normal T5LayerNorm\n    pass\nexcept Exception:\n    logger.warning(\"discovered apex but it failed to load, falling back to T5LayerNorm\")\n    pass\n\nALL_LAYERNORM_LAYERS.append(T5LayerNorm)\n\n\nclass T5DenseActDense(nn.Module):\n    def __init__(self, config: T5Config):\n        super().__init__()\n        self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)\n        self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)\n        self.dropout = nn.Dropout(config.dropout_rate)\n        self.act = ACT2FN[config.dense_act_fn]\n\n    def forward(self, hidden_states):\n        hidden_states = self.wi(hidden_states)\n        hidden_states = self.act(hidden_states)\n        hidden_states = self.dropout(hidden_states)\n        hidden_states = self.wo(hidden_states)\n        return hidden_states\n\n\nclass T5DenseGatedActDense(nn.Module):\n    def __init__(self, config: T5Config):\n        super().__init__()\n        self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)\n        self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)\n        self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)\n        self.dropout = nn.Dropout(config.dropout_rate)\n        self.act = ACT2FN[config.dense_act_fn]\n\n    def forward(self, hidden_states):\n        hidden_gelu = self.act(self.wi_0(hidden_states))\n        hidden_linear = self.wi_1(hidden_states)\n        hidden_states = hidden_gelu * hidden_linear\n        hidden_states = self.dropout(hidden_states)\n        hidden_states = self.wo(hidden_states)\n        return hidden_states\n\n\nclass T5LayerFF(nn.Module):\n    def __init__(self, config: T5Config):\n        super().__init__()\n        if config.is_gated_act:\n            self.DenseReluDense = T5DenseGatedActDense(config)\n        else:\n            self.DenseReluDense = T5DenseActDense(config)\n\n        self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)\n        self.dropout = nn.Dropout(config.dropout_rate)\n\n    def forward(self, hidden_states):\n        forwarded_states = self.layer_norm(hidden_states)\n        forwarded_states = self.DenseReluDense(forwarded_states)\n        hidden_states = hidden_states + self.dropout(forwarded_states)\n        return hidden_states\n\n\nclass T5Attention(nn.Module):\n    def 
__init__(self, config: T5Config, has_relative_attention_bias=False):\n        super().__init__()\n        self.is_decoder = config.is_decoder\n        self.has_relative_attention_bias = has_relative_attention_bias\n        self.relative_attention_num_buckets = config.relative_attention_num_buckets\n        self.relative_attention_max_distance = config.relative_attention_max_distance\n        self.d_model = config.d_model\n        self.key_value_proj_dim = config.d_kv\n        self.n_heads = config.num_heads\n        self.dropout = config.dropout_rate\n        self.inner_dim = self.n_heads * self.key_value_proj_dim\n\n        # Mesh TensorFlow initialization to avoid scaling before softmax\n        self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)\n        self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)\n        self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)\n        self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)\n\n        if self.has_relative_attention_bias:\n            self.relative_attention_bias = nn.Embedding(\n                self.relative_attention_num_buckets, self.n_heads\n            )\n        self.pruned_heads = set()\n        self.gradient_checkpointing = False\n\n    def prune_heads(self, heads):\n        if len(heads) == 0:\n            return\n        heads, index = find_pruneable_heads_and_indices(\n            heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads\n        )\n        # Prune linear layers\n        self.q = prune_linear_layer(self.q, index)\n        self.k = prune_linear_layer(self.k, index)\n        self.v = prune_linear_layer(self.v, index)\n        self.o = prune_linear_layer(self.o, index, dim=1)\n        # Update hyper params\n        self.n_heads = self.n_heads - len(heads)\n        self.inner_dim = self.key_value_proj_dim * self.n_heads\n        self.pruned_heads = self.pruned_heads.union(heads)\n\n    @staticmethod\n    def _relative_position_bucket(\n        relative_position, bidirectional=True, num_buckets=32, max_distance=128\n    ):\n        \"\"\"\n        Adapted from Mesh Tensorflow:\n        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593\n\n        Translate relative position to a bucket number for relative attention. The relative position is defined as\n        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to\n        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for\n        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative\n        positions >=max_distance map to the same bucket. 
All relative positions <=-max_distance map to the same bucket.\n        This should allow for more graceful generalization to longer sequences than the model has been trained on\n\n        Args:\n            relative_position: an int32 Tensor\n            bidirectional: a boolean - whether the attention is bidirectional\n            num_buckets: an integer\n            max_distance: an integer\n\n        Returns:\n            a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)\n        \"\"\"\n        relative_buckets = 0\n        if bidirectional:\n            num_buckets //= 2\n            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets\n            relative_position = torch.abs(relative_position)\n        else:\n            relative_position = -torch.min(\n                relative_position, torch.zeros_like(relative_position)\n            )\n        # now relative_position is in the range [0, inf)\n\n        # half of the buckets are for exact increments in positions\n        max_exact = num_buckets // 2\n        is_small = relative_position < max_exact\n\n        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance\n        relative_position_if_large = max_exact + (\n            torch.log(relative_position.float() / max_exact)\n            / math.log(max_distance / max_exact)\n            * (num_buckets - max_exact)\n        ).to(torch.long)\n        relative_position_if_large = torch.min(\n            relative_position_if_large,\n            torch.full_like(relative_position_if_large, num_buckets - 1),\n        )\n\n        relative_buckets += torch.where(\n            is_small, relative_position, relative_position_if_large\n        )\n        return relative_buckets\n\n    def compute_bias(self, query_length, key_length, device=None):\n        \"\"\"Compute binned relative position bias\"\"\"\n        if device is None:\n            device = self.relative_attention_bias.weight.device\n        context_position = torch.arange(query_length, dtype=torch.long, device=device)[\n            :, None\n        ]\n        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[\n            None, :\n        ]\n        relative_position = (\n            memory_position - context_position\n        )  # shape (query_length, key_length)\n        relative_position_bucket = self._relative_position_bucket(\n            relative_position,  # shape (query_length, key_length)\n            bidirectional=(not self.is_decoder),\n            num_buckets=self.relative_attention_num_buckets,\n            max_distance=self.relative_attention_max_distance,\n        )\n        values = self.relative_attention_bias(\n            relative_position_bucket\n        )  # shape (query_length, key_length, num_heads)\n        values = values.permute([2, 0, 1]).unsqueeze(\n            0\n        )  # shape (1, num_heads, query_length, key_length)\n        return values\n\n    def forward(\n        self,\n        hidden_states,\n        mask=None,\n        key_value_states=None,\n        position_bias=None,\n        past_key_value=None,\n        layer_head_mask=None,\n        query_length=None,\n        use_cache=False,\n        output_attentions=False,\n    ):\n        \"\"\"\n        Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).\n        \"\"\"\n        # Input is (batch_size, seq_length, dim)\n        # Mask is 
(batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)\n        # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)\n        batch_size, seq_length = hidden_states.shape[:2]\n\n        real_seq_length = seq_length\n\n        if past_key_value is not None:\n            assert (\n                len(past_key_value) == 2\n            ), f\"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states\"\n            real_seq_length += (\n                past_key_value[0].shape[2] if query_length is None else query_length\n            )\n\n        key_length = (\n            real_seq_length if key_value_states is None else key_value_states.shape[1]\n        )\n\n        def shape(states):\n            \"\"\"projection\"\"\"\n            return states.view(\n                batch_size, -1, self.n_heads, self.key_value_proj_dim\n            ).transpose(1, 2)\n\n        def unshape(states):\n            \"\"\"reshape\"\"\"\n            return (\n                states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)\n            )\n\n        def project(hidden_states, proj_layer, key_value_states, past_key_value):\n            \"\"\"projects hidden states correctly to key/query states\"\"\"\n            if key_value_states is None:\n                # self-attn\n                # (batch_size, n_heads, seq_length, dim_per_head)\n                hidden_states = shape(proj_layer(hidden_states))\n            elif past_key_value is None:\n                # cross-attn\n                # (batch_size, n_heads, seq_length, dim_per_head)\n                hidden_states = shape(proj_layer(key_value_states))\n\n            if past_key_value is not None:\n                if key_value_states is None:\n                    # self-attn\n                    # (batch_size, n_heads, key_length, dim_per_head)\n                    hidden_states = torch.cat([past_key_value, hidden_states], dim=2)\n                else:\n                    # cross-attn\n                    hidden_states = past_key_value\n            return hidden_states\n\n        # get query states\n        query_states = shape(\n            self.q(hidden_states)\n        )  # (batch_size, n_heads, seq_length, dim_per_head)\n\n        # get key/value states\n        key_states = project(\n            hidden_states,\n            self.k,\n            key_value_states,\n            past_key_value[0] if past_key_value is not None else None,\n        )\n        value_states = project(\n            hidden_states,\n            self.v,\n            key_value_states,\n            past_key_value[1] if past_key_value is not None else None,\n        )\n\n        # compute scores\n        scores = torch.matmul(\n            query_states, key_states.transpose(3, 2)\n        )  # equivalent of torch.einsum(\"bnqd,bnkd->bnqk\", query_states, key_states), compatible with onnx op>9\n\n        if position_bias is None:\n            if not self.has_relative_attention_bias:\n                position_bias = torch.zeros(\n                    (1, self.n_heads, real_seq_length, key_length),\n                    device=scores.device,\n                    dtype=scores.dtype,\n                )\n                if self.gradient_checkpointing and self.training:\n                    position_bias.requires_grad = True\n            else:\n                position_bias = self.compute_bias(\n                    real_seq_length, key_length, device=scores.device\n                )\n\n            # 
if key and values are already calculated\n            # we want only the last query position bias\n            if past_key_value is not None:\n                position_bias = position_bias[:, :, -hidden_states.size(1) :, :]\n\n            if mask is not None:\n                position_bias = (\n                    position_bias + mask\n                )  # (batch_size, n_heads, seq_length, key_length)\n\n        if self.pruned_heads:\n            mask = torch.ones(position_bias.shape[1])\n            mask[list(self.pruned_heads)] = 0\n            position_bias_masked = position_bias[:, mask.bool()]\n        else:\n            position_bias_masked = position_bias\n\n        scores += position_bias_masked\n        attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(\n            scores\n        )  # (batch_size, n_heads, seq_length, key_length)\n        attn_weights = nn.functional.dropout(\n            attn_weights, p=self.dropout, training=self.training\n        )  # (batch_size, n_heads, seq_length, key_length)\n\n        # Mask heads if we want to\n        if layer_head_mask is not None:\n            attn_weights = attn_weights * layer_head_mask\n\n        attn_output = unshape(\n            torch.matmul(attn_weights, value_states)\n        )  # (batch_size, seq_length, dim)\n        attn_output = self.o(attn_output)\n\n        present_key_value_state = (\n            (key_states, value_states) if (self.is_decoder and use_cache) else None\n        )\n        outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)\n\n        if output_attentions:\n            outputs = outputs + (attn_weights,)\n        return outputs\n\n\nclass T5LayerSelfAttention(nn.Module):\n    def __init__(self, config, has_relative_attention_bias=False):\n        super().__init__()\n        self.SelfAttention = T5Attention(\n            config, has_relative_attention_bias=has_relative_attention_bias\n        )\n        self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)\n        self.dropout = nn.Dropout(config.dropout_rate)\n\n    def forward(\n        self,\n        hidden_states,\n        attention_mask=None,\n        position_bias=None,\n        layer_head_mask=None,\n        past_key_value=None,\n        use_cache=False,\n        output_attentions=False,\n    ):\n        normed_hidden_states = self.layer_norm(hidden_states)\n        attention_output = self.SelfAttention(\n            normed_hidden_states,\n            mask=attention_mask,\n            position_bias=position_bias,\n            layer_head_mask=layer_head_mask,\n            past_key_value=past_key_value,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n        )\n        hidden_states = hidden_states + self.dropout(attention_output[0])\n        outputs = (hidden_states,) + attention_output[\n            1:\n        ]  # add attentions if we output them\n        return outputs\n\n\nclass T5LayerCrossAttention(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.EncDecAttention = T5Attention(config, has_relative_attention_bias=False)\n        self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)\n        self.dropout = nn.Dropout(config.dropout_rate)\n\n    def forward(\n        self,\n        hidden_states,\n        key_value_states,\n        attention_mask=None,\n        position_bias=None,\n        layer_head_mask=None,\n        past_key_value=None,\n        use_cache=False,\n        
query_length=None,\n        output_attentions=False,\n    ):\n        normed_hidden_states = self.layer_norm(hidden_states)\n        attention_output = self.EncDecAttention(\n            normed_hidden_states,\n            mask=attention_mask,\n            key_value_states=key_value_states,\n            position_bias=position_bias,\n            layer_head_mask=layer_head_mask,\n            past_key_value=past_key_value,\n            use_cache=use_cache,\n            query_length=query_length,\n            output_attentions=output_attentions,\n        )\n        layer_output = hidden_states + self.dropout(attention_output[0])\n        outputs = (layer_output,) + attention_output[\n            1:\n        ]  # add attentions if we output them\n        return outputs\n\n\nclass T5Block(nn.Module):\n    def __init__(self, config, has_relative_attention_bias=False):\n        super().__init__()\n        self.is_decoder = config.is_decoder\n        self.layer = nn.ModuleList()\n        self.layer.append(\n            T5LayerSelfAttention(\n                config, has_relative_attention_bias=has_relative_attention_bias\n            )\n        )\n        if self.is_decoder:\n            self.layer.append(T5LayerCrossAttention(config))\n\n        self.layer.append(T5LayerFF(config))\n\n    def forward(\n        self,\n        hidden_states,\n        attention_mask=None,\n        position_bias=None,\n        encoder_hidden_states=None,\n        encoder_attention_mask=None,\n        encoder_decoder_position_bias=None,\n        layer_head_mask=None,\n        cross_attn_layer_head_mask=None,\n        past_key_value=None,\n        use_cache=False,\n        output_attentions=False,\n        return_dict=True,\n    ):\n\n        if past_key_value is not None:\n            if not self.is_decoder:\n                logger.warning(\n                    \"`past_key_values` is passed to the encoder. Please make sure this is intended.\"\n                )\n            expected_num_past_key_values = 2 if encoder_hidden_states is None else 4\n\n            if len(past_key_value) != expected_num_past_key_values:\n                raise ValueError(\n                    f\"There should be {expected_num_past_key_values} past states. \"\n                    f\"{'2 (past / key) for cross attention. 
' if expected_num_past_key_values == 4 else ''}\"\n                    f\"Got {len(past_key_value)} past key / value states\"\n                )\n\n            self_attn_past_key_value = past_key_value[:2]\n            cross_attn_past_key_value = past_key_value[2:]\n        else:\n            self_attn_past_key_value, cross_attn_past_key_value = None, None\n\n        self_attention_outputs = self.layer[0](\n            hidden_states,\n            attention_mask=attention_mask,\n            position_bias=position_bias,\n            layer_head_mask=layer_head_mask,\n            past_key_value=self_attn_past_key_value,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n        )\n        hidden_states, present_key_value_state = self_attention_outputs[:2]\n        attention_outputs = self_attention_outputs[\n            2:\n        ]  # Keep self-attention outputs and relative position weights\n\n        # clamp inf values to enable fp16 training\n        if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():\n            clamp_value = torch.finfo(hidden_states.dtype).max - 1000\n            hidden_states = torch.clamp(\n                hidden_states, min=-clamp_value, max=clamp_value\n            )\n\n        do_cross_attention = self.is_decoder and encoder_hidden_states is not None\n        if do_cross_attention:\n            # the actual query length is unknown for cross attention\n            # if using past key value states. Need to inject it here\n            if present_key_value_state is not None:\n                query_length = present_key_value_state[0].shape[2]\n            else:\n                query_length = None\n\n            cross_attention_outputs = self.layer[1](\n                hidden_states,\n                key_value_states=encoder_hidden_states,\n                attention_mask=encoder_attention_mask,\n                position_bias=encoder_decoder_position_bias,\n                layer_head_mask=cross_attn_layer_head_mask,\n                past_key_value=cross_attn_past_key_value,\n                query_length=query_length,\n                use_cache=use_cache,\n                output_attentions=output_attentions,\n            )\n            hidden_states = cross_attention_outputs[0]\n\n            # clamp inf values to enable fp16 training\n            if (\n                hidden_states.dtype == torch.float16\n                and torch.isinf(hidden_states).any()\n            ):\n                clamp_value = torch.finfo(hidden_states.dtype).max - 1000\n                hidden_states = torch.clamp(\n                    hidden_states, min=-clamp_value, max=clamp_value\n                )\n\n            # Combine self attn and cross attn key value states\n            if present_key_value_state is not None:\n                present_key_value_state = (\n                    present_key_value_state + cross_attention_outputs[1]\n                )\n\n            # Keep cross-attention outputs and relative position weights\n            attention_outputs = attention_outputs + cross_attention_outputs[2:]\n\n        # Apply Feed Forward layer\n        hidden_states = self.layer[-1](hidden_states)\n\n        # clamp inf values to enable fp16 training\n        if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():\n            clamp_value = torch.finfo(hidden_states.dtype).max - 1000\n            hidden_states = torch.clamp(\n                hidden_states, min=-clamp_value, max=clamp_value\n            )\n\n  
      outputs = (hidden_states,)\n\n        if use_cache:\n            outputs = outputs + (present_key_value_state,) + attention_outputs\n        else:\n            outputs = outputs + attention_outputs\n\n        return outputs  # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)\n\n\nclass T5PreTrainedModel(PreTrainedModel):\n    \"\"\"\n    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained\n    models.\n    \"\"\"\n\n    config_class = T5Config\n    load_tf_weights = load_tf_weights_in_t5\n    base_model_prefix = \"transformer\"\n    is_parallelizable = True\n    supports_gradient_checkpointing = True\n    _no_split_modules = [\"T5Block\"]\n\n    @property\n    def dummy_inputs(self):\n        input_ids = torch.tensor(DUMMY_INPUTS)\n        input_mask = torch.tensor(DUMMY_MASK)\n        dummy_inputs = {\n            \"decoder_input_ids\": input_ids,\n            \"input_ids\": input_ids,\n            \"decoder_attention_mask\": input_mask,\n        }\n        return dummy_inputs\n\n    def _init_weights(self, module):\n        \"\"\"Initialize the weights\"\"\"\n        factor = (\n            self.config.initializer_factor\n        )  # Used for testing weights initialization\n        if isinstance(module, T5LayerNorm):\n            module.weight.data.fill_(factor * 1.0)\n        elif isinstance(module, (T5Model, T5ForConditionalGeneration, T5EncoderModel)):\n            # Mesh TensorFlow embeddings initialization\n            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624\n            module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)\n            if hasattr(module, \"lm_head\") and not self.config.tie_word_embeddings:\n                module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0)\n        elif isinstance(module, T5DenseActDense):\n            # Mesh TensorFlow FF initialization\n            # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56\n            # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89\n            module.wi.weight.data.normal_(\n                mean=0.0, std=factor * ((self.config.d_model) ** -0.5)\n            )\n            if hasattr(module.wi, \"bias\") and module.wi.bias is not None:\n                module.wi.bias.data.zero_()\n            module.wo.weight.data.normal_(\n                mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)\n            )\n            if hasattr(module.wo, \"bias\") and module.wo.bias is not None:\n                module.wo.bias.data.zero_()\n        elif isinstance(module, T5DenseGatedActDense):\n            module.wi_0.weight.data.normal_(\n                mean=0.0, std=factor * ((self.config.d_model) ** -0.5)\n            )\n            if hasattr(module.wi_0, \"bias\") and module.wi_0.bias is not None:\n                module.wi_0.bias.data.zero_()\n            module.wi_1.weight.data.normal_(\n                mean=0.0, std=factor * ((self.config.d_model) ** -0.5)\n            )\n            if hasattr(module.wi_1, \"bias\") and module.wi_1.bias is not None:\n                module.wi_1.bias.data.zero_()\n            module.wo.weight.data.normal_(\n                mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)\n            )\n        
    if hasattr(module.wo, \"bias\") and module.wo.bias is not None:\n                module.wo.bias.data.zero_()\n        elif isinstance(module, T5Attention):\n            # Mesh TensorFlow attention initialization to avoid scaling before softmax\n            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136\n            d_model = self.config.d_model\n            key_value_proj_dim = self.config.d_kv\n            n_heads = self.config.num_heads\n            module.q.weight.data.normal_(\n                mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)\n            )\n            module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))\n            module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))\n            module.o.weight.data.normal_(\n                mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)\n            )\n            if module.has_relative_attention_bias:\n                module.relative_attention_bias.weight.data.normal_(\n                    mean=0.0, std=factor * ((d_model) ** -0.5)\n                )\n\n    def _set_gradient_checkpointing(self, module, value=False):\n        if isinstance(module, (T5Attention, T5Stack)):\n            module.gradient_checkpointing = value\n\n    def _shift_right(self, input_ids):\n        decoder_start_token_id = self.config.decoder_start_token_id\n        pad_token_id = self.config.pad_token_id\n\n        assert decoder_start_token_id is not None, (\n            \"self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id.\"\n            \" See T5 docs for more information\"\n        )\n\n        # shift inputs to the right\n        if is_torch_fx_proxy(input_ids):\n            # Item assignment is not supported natively for proxies.\n            shifted_input_ids = torch.full(\n                input_ids.shape[:-1] + (1,), decoder_start_token_id\n            )\n            shifted_input_ids = torch.cat(\n                [shifted_input_ids, input_ids[..., :-1]], dim=-1\n            )\n        else:\n            shifted_input_ids = input_ids.new_zeros(input_ids.shape)\n            shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()\n            shifted_input_ids[..., 0] = decoder_start_token_id\n\n        assert (\n            pad_token_id is not None\n        ), \"self.model.config.pad_token_id has to be defined.\"\n        # replace possible -100 values in labels by `pad_token_id`\n        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)\n\n        return shifted_input_ids\n\n\nclass T5Stack(T5PreTrainedModel):\n    def __init__(self, config, embed_tokens=None):\n        super().__init__(config)\n\n        self.embed_tokens = embed_tokens\n        self.is_decoder = config.is_decoder\n\n        self.block = nn.ModuleList(\n            [\n                T5Block(config, has_relative_attention_bias=bool(i == 0))\n                for i in range(config.num_layers)\n            ]\n        )\n        self.final_layer_norm = T5LayerNorm(\n            config.d_model, eps=config.layer_norm_epsilon\n        )\n        self.dropout = nn.Dropout(config.dropout_rate)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n        # Model parallel\n        self.model_parallel = False\n        self.device_map = None\n        self.gradient_checkpointing = False\n\n    
@add_start_docstrings(PARALLELIZE_DOCSTRING)\n    def parallelize(self, device_map=None):\n        # Check validity of device_map\n        self.device_map = (\n            get_device_map(len(self.block), range(torch.cuda.device_count()))\n            if device_map is None\n            else device_map\n        )\n        assert_device_map(self.device_map, len(self.block))\n        self.model_parallel = True\n        self.first_device = (\n            \"cpu\"\n            if \"cpu\" in self.device_map.keys()\n            else \"cuda:\" + str(min(self.device_map.keys()))\n        )\n        self.last_device = \"cuda:\" + str(max(self.device_map.keys()))\n        # Load onto devices\n        for k, v in self.device_map.items():\n            for layer in v:\n                cuda_device = \"cuda:\" + str(k)\n                self.block[layer] = self.block[layer].to(cuda_device)\n\n        # Set embed_tokens to first layer\n        self.embed_tokens = self.embed_tokens.to(self.first_device)\n        # Set final layer norm to last device\n        self.final_layer_norm = self.final_layer_norm.to(self.last_device)\n\n    @add_start_docstrings(PARALLELIZE_DOCSTRING)\n    def deparallelize(self):\n        self.model_parallel = False\n        self.device_map = None\n        self.first_device = \"cpu\"\n        self.last_device = \"cpu\"\n        for i in range(len(self.block)):\n            self.block[i] = self.block[i].to(\"cpu\")\n        self.embed_tokens = self.embed_tokens.to(\"cpu\")\n        self.final_layer_norm = self.final_layer_norm.to(\"cpu\")\n        torch.cuda.empty_cache()\n\n    def get_input_embeddings(self):\n        return self.embed_tokens\n\n    def set_input_embeddings(self, new_embeddings):\n        self.embed_tokens = new_embeddings\n\n    def forward(\n        self,\n        input_ids=None,\n        attention_mask=None,\n        encoder_hidden_states=None,\n        encoder_attention_mask=None,\n        inputs_embeds=None,\n        head_mask=None,\n        cross_attn_head_mask=None,\n        past_key_values=None,\n        use_cache=None,\n        output_attentions=None,\n        output_hidden_states=None,\n        return_dict=None,\n    ):\n        # Model parallel\n        if self.model_parallel:\n            torch.cuda.set_device(self.first_device)\n            self.embed_tokens = self.embed_tokens.to(self.first_device)\n        use_cache = use_cache if use_cache is not None else self.config.use_cache\n        output_attentions = (\n            output_attentions\n            if output_attentions is not None\n            else self.config.output_attentions\n        )\n        output_hidden_states = (\n            output_hidden_states\n            if output_hidden_states is not None\n            else self.config.output_hidden_states\n        )\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        if input_ids is not None and inputs_embeds is not None:\n            err_msg_prefix = \"decoder_\" if self.is_decoder else \"\"\n            raise ValueError(\n                f\"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time\"\n            )\n        elif input_ids is not None:\n            input_shape = input_ids.size()\n            input_ids = input_ids.view(-1, input_shape[-1])\n        elif inputs_embeds is not None:\n            input_shape = inputs_embeds.size()[:-1]\n        else:\n            err_msg_prefix = \"decoder_\" if self.is_decoder else 
\"\"\n            raise ValueError(\n                f\"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds\"\n            )\n\n        if inputs_embeds is None:\n            assert (\n                self.embed_tokens is not None\n            ), \"You have to initialize the model with valid token embeddings\"\n            inputs_embeds = self.embed_tokens(input_ids)\n\n        batch_size, seq_length = input_shape\n\n        # required mask seq length can be calculated via length of past\n        mask_seq_length = (\n            past_key_values[0][0].shape[2] + seq_length\n            if past_key_values is not None\n            else seq_length\n        )\n\n        if use_cache is True:\n            assert (\n                self.is_decoder\n            ), f\"`use_cache` can only be set to `True` if {self} is used as a decoder\"\n\n        if attention_mask is None:\n            attention_mask = torch.ones(\n                batch_size, mask_seq_length, device=inputs_embeds.device\n            )\n        if (\n            self.is_decoder\n            and encoder_attention_mask is None\n            and encoder_hidden_states is not None\n        ):\n            encoder_seq_length = encoder_hidden_states.shape[1]\n            encoder_attention_mask = torch.ones(\n                batch_size,\n                encoder_seq_length,\n                device=inputs_embeds.device,\n                dtype=torch.long,\n            )\n\n        # initialize past_key_values with `None` if past does not exist\n        if past_key_values is None:\n            past_key_values = [None] * len(self.block)\n\n        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]\n        # ourselves in which case we just need to make it broadcastable to all heads.\n        extended_attention_mask = self.get_extended_attention_mask(\n            attention_mask, input_shape\n        )\n\n        # If a 2D or 3D attention mask is provided for the cross-attention\n        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]\n        if self.is_decoder and encoder_hidden_states is not None:\n            (\n                encoder_batch_size,\n                encoder_sequence_length,\n                _,\n            ) = encoder_hidden_states.size()\n            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)\n            if encoder_attention_mask is None:\n                encoder_attention_mask = torch.ones(\n                    encoder_hidden_shape, device=inputs_embeds.device\n                )\n            encoder_extended_attention_mask = self.invert_attention_mask(\n                encoder_attention_mask\n            )\n        else:\n            encoder_extended_attention_mask = None\n\n        # Prepare head mask if needed\n        head_mask = self.get_head_mask(head_mask, self.config.num_layers)\n        cross_attn_head_mask = self.get_head_mask(\n            cross_attn_head_mask, self.config.num_layers\n        )\n        present_key_value_states = () if use_cache else None\n        all_hidden_states = () if output_hidden_states else None\n        all_attentions = () if output_attentions else None\n        all_cross_attentions = () if (output_attentions and self.is_decoder) else None\n        position_bias = None\n        encoder_decoder_position_bias = None\n\n        hidden_states = self.dropout(inputs_embeds)\n\n        for i, (layer_module, past_key_value) in enumerate(\n            
zip(self.block, past_key_values)\n        ):\n            layer_head_mask = head_mask[i]\n            cross_attn_layer_head_mask = cross_attn_head_mask[i]\n            # Model parallel\n            if self.model_parallel:\n                torch.cuda.set_device(hidden_states.device)\n                # Ensure that attention_mask is always on the same device as hidden_states\n                if attention_mask is not None:\n                    attention_mask = attention_mask.to(hidden_states.device)\n                if position_bias is not None:\n                    position_bias = position_bias.to(hidden_states.device)\n                if encoder_hidden_states is not None:\n                    encoder_hidden_states = encoder_hidden_states.to(\n                        hidden_states.device\n                    )\n                if encoder_extended_attention_mask is not None:\n                    encoder_extended_attention_mask = (\n                        encoder_extended_attention_mask.to(hidden_states.device)\n                    )\n                if encoder_decoder_position_bias is not None:\n                    encoder_decoder_position_bias = encoder_decoder_position_bias.to(\n                        hidden_states.device\n                    )\n                if layer_head_mask is not None:\n                    layer_head_mask = layer_head_mask.to(hidden_states.device)\n                if cross_attn_layer_head_mask is not None:\n                    cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(\n                        hidden_states.device\n                    )\n            if output_hidden_states:\n                all_hidden_states = all_hidden_states + (hidden_states,)\n\n            if self.gradient_checkpointing and self.training:\n                if use_cache:\n                    logger.warning(\n                        \"`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`...\"\n                    )\n                    use_cache = False\n\n                def create_custom_forward(module):\n                    def custom_forward(*inputs):\n                        return tuple(module(*inputs, use_cache, output_attentions))\n\n                    return custom_forward\n\n                layer_outputs = checkpoint(\n                    create_custom_forward(layer_module),\n                    hidden_states,\n                    extended_attention_mask,\n                    position_bias,\n                    encoder_hidden_states,\n                    encoder_extended_attention_mask,\n                    encoder_decoder_position_bias,\n                    layer_head_mask,\n                    cross_attn_layer_head_mask,\n                    None,  # past_key_value is always None with gradient checkpointing\n                )\n            else:\n                layer_outputs = layer_module(\n                    hidden_states,\n                    attention_mask=extended_attention_mask,\n                    position_bias=position_bias,\n                    encoder_hidden_states=encoder_hidden_states,\n                    encoder_attention_mask=encoder_extended_attention_mask,\n                    encoder_decoder_position_bias=encoder_decoder_position_bias,\n                    layer_head_mask=layer_head_mask,\n                    cross_attn_layer_head_mask=cross_attn_layer_head_mask,\n                    past_key_value=past_key_value,\n                    use_cache=use_cache,\n                    output_attentions=output_attentions,\n                )\n\n            # layer_outputs is a tuple with:\n            # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)\n            if use_cache is False:\n                layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]\n\n            hidden_states, present_key_value_state = layer_outputs[:2]\n\n            # We share the position biases between the layers - the first layer store them\n            # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),\n            # (cross-attention position bias), (cross-attention weights)\n            position_bias = layer_outputs[2]\n            if self.is_decoder and encoder_hidden_states is not None:\n                encoder_decoder_position_bias = layer_outputs[\n                    4 if output_attentions else 3\n                ]\n            # append next layer key value states\n            if use_cache:\n                present_key_value_states = present_key_value_states + (\n                    present_key_value_state,\n                )\n\n            if output_attentions:\n                all_attentions = all_attentions + (layer_outputs[3],)\n                if self.is_decoder:\n                    all_cross_attentions = all_cross_attentions + (layer_outputs[5],)\n\n            # Model Parallel: If it's the last layer for that device, put things on the next device\n            if self.model_parallel:\n                for k, v in self.device_map.items():\n                    if i == v[-1] and \"cuda:\" + str(k) != self.last_device:\n                        hidden_states = hidden_states.to(\"cuda:\" + str(k + 1))\n\n        hidden_states = self.final_layer_norm(hidden_states)\n        hidden_states = self.dropout(hidden_states)\n\n        # Add last layer\n        if 
output_hidden_states:\n            all_hidden_states = all_hidden_states + (hidden_states,)\n\n        if not return_dict:\n            return tuple(\n                v\n                for v in [\n                    hidden_states,\n                    present_key_value_states,\n                    all_hidden_states,\n                    all_attentions,\n                    all_cross_attentions,\n                ]\n                if v is not None\n            )\n        return BaseModelOutputWithPastAndCrossAttentions(\n            last_hidden_state=hidden_states,\n            past_key_values=present_key_value_states,\n            hidden_states=all_hidden_states,\n            attentions=all_attentions,\n            cross_attentions=all_cross_attentions,\n        )\n\n\nT5_START_DOCSTRING = r\"\"\"\n\n    The T5 model was proposed in [Exploring the Limits of Transfer Learning with a Unified Text-to-Text\n    Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan\n    Narang, Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder decoder transformer pre-trained in a\n    text-to-text denoising generative setting.\n\n    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the\n    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads\n    etc.)\n\n    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.\n    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage\n    and behavior.\n\n    Parameters:\n        config ([`T5Config`]): Model configuration class with all the parameters of the model.\n            Initializing with a config file does not load the weights associated with the model, only the\n            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.\n\"\"\"\n\nT5_INPUTS_DOCSTRING = r\"\"\"\n    Args:\n        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):\n            Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you\n            should be able to pad the inputs on both the right and the left.\n\n            Indices can be obtained using [`T5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and\n            [`PreTrainedTokenizer.__call__`] for detail.\n\n            [What are input IDs?](../glossary#input-ids)\n\n            To know more on how to prepare `input_ids` for pretraining take a look a [T5 Training](./t5#training).\n        attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:\n\n            - 1 for tokens that are **not masked**,\n            - 0 for tokens that are **masked**.\n\n            [What are attention masks?](../glossary#attention-mask)\n        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):\n            Indices of decoder input sequence tokens in the vocabulary.\n\n            Indices can be obtained using [`T5Tokenizer`]. 
See [`PreTrainedTokenizer.encode`] and\n            [`PreTrainedTokenizer.__call__`] for details.\n\n            [What are decoder input IDs?](../glossary#decoder-input-ids)\n\n            T5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`\n            is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).\n\n            To know more on how to prepare `decoder_input_ids` for pretraining take a look at [T5\n            Training](./t5#training).\n        decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):\n            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also\n            be used by default.\n        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):\n            Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in `[0,\n            1]`:\n\n            - 1 indicates the head is **not masked**,\n            - 0 indicates the head is **masked**.\n\n        decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):\n            Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0,\n            1]`:\n\n            - 1 indicates the head is **not masked**,\n            - 0 indicates the head is **masked**.\n\n        cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):\n                Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in\n                `[0, 1]`:\n\n                - 1 indicates the head is **not masked**,\n                - 0 indicates the head is **masked**.\n\n        encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):\n            Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, `optional`: *attentions*)\n            `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden states at\n            the output of the last layer of the encoder. Used in the cross-attention of the decoder.\n        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):\n            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.\n\n            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that\n            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all\n            `decoder_input_ids` of shape `(batch_size, sequence_length)`.\n        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):\n            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. 
This\n            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the\n            model's internal embedding lookup matrix.\n        decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):\n            Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded\n            representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be\n            input (see `past_key_values`). This is useful if you want more control over how to convert\n            `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.\n\n            If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value\n            of `inputs_embeds`.\n\n        use_cache (`bool`, *optional*):\n            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see\n            `past_key_values`).\n\n        output_attentions (`bool`, *optional*):\n            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned\n            tensors for more detail.\n        output_hidden_states (`bool`, *optional*):\n            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for\n            more detail.\n        return_dict (`bool`, *optional*):\n            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.\n\"\"\"\n\nT5_ENCODER_INPUTS_DOCSTRING = r\"\"\"\n    Args:\n        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):\n            Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you\n            should be able to pad the inputs on both the right and the left.\n\n            Indices can be obtained using [`T5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and\n            [`PreTrainedTokenizer.__call__`] for detail.\n\n            To know more on how to prepare `input_ids` for pretraining take a look a [T5 Training](./t5#training).\n        attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:\n\n            - 1 for tokens that are **not masked**,\n            - 0 for tokens that are **masked**.\n\n            [What are attention masks?](../glossary#attention-mask)\n        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):\n            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:\n\n            - 1 indicates the head is **not masked**,\n            - 0 indicates the head is **masked**.\n\n        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):\n            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This\n            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the\n            model's internal embedding lookup matrix.\n        output_attentions (`bool`, *optional*):\n            Whether or not to return the attentions tensors of all attention layers. 
See `attentions` under returned\n            tensors for more detail.\n        output_hidden_states (`bool`, *optional*):\n            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for\n            more detail.\n        return_dict (`bool`, *optional*):\n            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.\n\"\"\"\n\n# Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask\n__HEAD_MASK_WARNING_MSG = \"\"\"\nThe input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently,\n`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.\nIf you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers,\nnum_heads)`.\n\"\"\"\n\n\n@add_start_docstrings(\n    \"The bare T5 Model transformer outputting raw hidden-states without any specific head on top.\",\n    T5_START_DOCSTRING,\n)\nclass T5Model(T5PreTrainedModel):\n    _keys_to_ignore_on_load_missing = [\n        r\"encoder.embed_tokens.weight\",\n        r\"decoder.embed_tokens.weight\",\n    ]\n    _keys_to_ignore_on_load_unexpected = [\n        r\"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight\",\n    ]\n\n    def __init__(self, config: T5Config):\n        super().__init__(config)\n        self.shared = nn.Embedding(config.vocab_size, config.d_model)\n\n        encoder_config = copy.deepcopy(config)\n        encoder_config.is_decoder = False\n        encoder_config.use_cache = False\n        encoder_config.is_encoder_decoder = False\n        self.encoder = T5Stack(encoder_config, self.shared)\n\n        decoder_config = copy.deepcopy(config)\n        decoder_config.is_decoder = True\n        decoder_config.is_encoder_decoder = False\n        decoder_config.num_layers = config.num_decoder_layers\n        self.decoder = T5Stack(decoder_config, self.shared)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n        # Model parallel\n        self.model_parallel = False\n        self.device_map = None\n\n    @add_start_docstrings(PARALLELIZE_DOCSTRING)\n    def parallelize(self, device_map=None):\n        self.device_map = (\n            get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))\n            if device_map is None\n            else device_map\n        )\n        assert_device_map(self.device_map, len(self.encoder.block))\n        self.encoder.parallelize(self.device_map)\n        self.decoder.parallelize(self.device_map)\n        self.model_parallel = True\n\n    @add_start_docstrings(DEPARALLELIZE_DOCSTRING)\n    def deparallelize(self):\n        self.encoder.deparallelize()\n        self.decoder.deparallelize()\n        self.encoder = self.encoder.to(\"cpu\")\n        self.decoder = self.decoder.to(\"cpu\")\n        self.model_parallel = False\n        self.device_map = None\n        torch.cuda.empty_cache()\n\n    def get_input_embeddings(self):\n        return self.shared\n\n    def set_input_embeddings(self, new_embeddings):\n        self.shared = new_embeddings\n        self.encoder.set_input_embeddings(new_embeddings)\n        self.decoder.set_input_embeddings(new_embeddings)\n\n    def get_encoder(self):\n        return self.encoder\n\n    def get_decoder(self):\n        return self.decoder\n\n    def _prune_heads(self, heads_to_prune):\n        \"\"\"\n        
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base\n        class PreTrainedModel\n        \"\"\"\n        for layer, heads in heads_to_prune.items():\n            # T5Stack keeps its layers in `self.block`; self-attention is the first sub-layer of each block\n            self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads)\n\n    @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)\n    @replace_return_docstrings(\n        output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC\n    )\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        decoder_input_ids: Optional[torch.LongTensor] = None,\n        decoder_attention_mask: Optional[torch.BoolTensor] = None,\n        head_mask: Optional[torch.FloatTensor] = None,\n        decoder_head_mask: Optional[torch.FloatTensor] = None,\n        cross_attn_head_mask: Optional[torch.Tensor] = None,\n        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,\n        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,\n        inputs_embeds: Optional[torch.Tensor] = None,\n        decoder_inputs_embeds: Optional[torch.Tensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]:\n        r\"\"\"\n        Returns:\n\n        Example:\n\n        ```python\n        >>> from transformers import T5Tokenizer, T5Model\n\n        >>> tokenizer = T5Tokenizer.from_pretrained(\"t5-small\")\n        >>> model = T5Model.from_pretrained(\"t5-small\")\n\n        >>> input_ids = tokenizer(\n        ...     \"Studies have been shown that owning a dog is good for you\", return_tensors=\"pt\"\n        ... 
).input_ids  # Batch size 1\n        >>> decoder_input_ids = tokenizer(\"Studies show that\", return_tensors=\"pt\").input_ids  # Batch size 1\n\n        >>> # preprocess: Prepend decoder_input_ids with start token which is pad token for T5Model.\n        >>> # This is not needed for torch's T5ForConditionalGeneration as it does this internally using labels arg.\n        >>> decoder_input_ids = model._shift_right(decoder_input_ids)\n\n        >>> # forward pass\n        >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)\n        >>> last_hidden_states = outputs.last_hidden_state\n        ```\"\"\"\n        use_cache = use_cache if use_cache is not None else self.config.use_cache\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask\n        if head_mask is not None and decoder_head_mask is None:\n            if self.config.num_layers == self.config.num_decoder_layers:\n                warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)\n                decoder_head_mask = head_mask\n\n        # Encode if needed (training, first prediction pass)\n        if encoder_outputs is None:\n            encoder_outputs = self.encoder(\n                input_ids=input_ids,\n                attention_mask=attention_mask,\n                inputs_embeds=inputs_embeds,\n                head_mask=head_mask,\n                output_attentions=output_attentions,\n                output_hidden_states=output_hidden_states,\n                return_dict=return_dict,\n            )\n        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):\n            encoder_outputs = BaseModelOutput(\n                last_hidden_state=encoder_outputs[0],\n                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,\n                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,\n            )\n\n        hidden_states = encoder_outputs[0]\n\n        # Set device for model parallelism\n        if self.model_parallel:\n            torch.cuda.set_device(self.decoder.first_device)\n            hidden_states = hidden_states.to(self.decoder.first_device)\n            if decoder_input_ids is not None:\n                decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)\n            if attention_mask is not None:\n                attention_mask = attention_mask.to(self.decoder.first_device)\n            if decoder_attention_mask is not None:\n                decoder_attention_mask = decoder_attention_mask.to(\n                    self.decoder.first_device\n                )\n\n        # Decode\n        decoder_outputs = self.decoder(\n            input_ids=decoder_input_ids,\n            attention_mask=decoder_attention_mask,\n            inputs_embeds=decoder_inputs_embeds,\n            past_key_values=past_key_values,\n            encoder_hidden_states=hidden_states,\n            encoder_attention_mask=attention_mask,\n            head_mask=decoder_head_mask,\n            cross_attn_head_mask=cross_attn_head_mask,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n\n        if not return_dict:\n            return decoder_outputs + encoder_outputs\n\n        return Seq2SeqModelOutput(\n            
last_hidden_state=decoder_outputs.last_hidden_state,\n            past_key_values=decoder_outputs.past_key_values,\n            decoder_hidden_states=decoder_outputs.hidden_states,\n            decoder_attentions=decoder_outputs.attentions,\n            cross_attentions=decoder_outputs.cross_attentions,\n            encoder_last_hidden_state=encoder_outputs.last_hidden_state,\n            encoder_hidden_states=encoder_outputs.hidden_states,\n            encoder_attentions=encoder_outputs.attentions,\n        )\n\n\n@add_start_docstrings(\n    \"\"\"T5 Model with a `language modeling` head on top.\"\"\", T5_START_DOCSTRING\n)\nclass T5ForConditionalGeneration(T5PreTrainedModel):\n    _keys_to_ignore_on_load_missing = [\n        r\"encoder.embed_tokens.weight\",\n        r\"decoder.embed_tokens.weight\",\n        r\"lm_head.weight\",\n    ]\n    _keys_to_ignore_on_load_unexpected = [\n        r\"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight\",\n    ]\n\n    def __init__(self, config: T5Config):\n        super().__init__(config)\n        self.model_dim = config.d_model\n\n        self.shared = nn.Embedding(config.vocab_size, config.d_model)\n\n        encoder_config = copy.deepcopy(config)\n        encoder_config.is_decoder = False\n        encoder_config.use_cache = False\n        encoder_config.is_encoder_decoder = False\n        self.encoder = T5Stack(encoder_config, self.shared)\n\n        decoder_config = copy.deepcopy(config)\n        decoder_config.is_decoder = True\n        decoder_config.is_encoder_decoder = False\n        decoder_config.num_layers = config.num_decoder_layers\n        self.decoder = T5Stack(decoder_config, self.shared)\n\n        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n        # Model parallel\n        self.model_parallel = False\n        self.device_map = None\n\n    @add_start_docstrings(PARALLELIZE_DOCSTRING)\n    def parallelize(self, device_map=None):\n        self.device_map = (\n            get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))\n            if device_map is None\n            else device_map\n        )\n        assert_device_map(self.device_map, len(self.encoder.block))\n        self.encoder.parallelize(self.device_map)\n        self.decoder.parallelize(self.device_map)\n        self.lm_head = self.lm_head.to(self.decoder.first_device)\n        self.model_parallel = True\n\n    @add_start_docstrings(DEPARALLELIZE_DOCSTRING)\n    def deparallelize(self):\n        self.encoder.deparallelize()\n        self.decoder.deparallelize()\n        self.encoder = self.encoder.to(\"cpu\")\n        self.decoder = self.decoder.to(\"cpu\")\n        self.lm_head = self.lm_head.to(\"cpu\")\n        self.model_parallel = False\n        self.device_map = None\n        torch.cuda.empty_cache()\n\n    def get_input_embeddings(self):\n        return self.shared\n\n    def set_input_embeddings(self, new_embeddings):\n        self.shared = new_embeddings\n        self.encoder.set_input_embeddings(new_embeddings)\n        self.decoder.set_input_embeddings(new_embeddings)\n\n    def set_output_embeddings(self, new_embeddings):\n        self.lm_head = new_embeddings\n\n    def get_output_embeddings(self):\n        return self.lm_head\n\n    def get_encoder(self):\n        return self.encoder\n\n    def get_decoder(self):\n        return self.decoder\n\n    
@add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)\n    @replace_return_docstrings(\n        output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC\n    )\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        decoder_input_ids: Optional[torch.LongTensor] = None,\n        decoder_attention_mask: Optional[torch.BoolTensor] = None,\n        head_mask: Optional[torch.FloatTensor] = None,\n        decoder_head_mask: Optional[torch.FloatTensor] = None,\n        cross_attn_head_mask: Optional[torch.Tensor] = None,\n        encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,\n        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        reduction: Optional[str] = \"mean\",\n    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:\n        r\"\"\"\n        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Labels for computing the sequence-to-sequence language modeling loss. Indices should be in `[-100, 0, ...,\n            config.vocab_size - 1]`. All labels set to `-100` are ignored (masked); the loss is only computed for\n            labels in `[0, ..., config.vocab_size - 1]`.\n        reduction (`str`, *optional*, defaults to `\"mean\"`):\n            Reduction applied to the language modeling loss. If set to `\"none\"`, a per-sequence loss (summed over\n            tokens) is returned instead of a single scalar.\n\n        Returns:\n\n        Examples:\n\n        ```python\n        >>> from transformers import T5Tokenizer, T5ForConditionalGeneration\n\n        >>> tokenizer = T5Tokenizer.from_pretrained(\"t5-small\")\n        >>> model = T5ForConditionalGeneration.from_pretrained(\"t5-small\")\n\n        >>> # training\n        >>> input_ids = tokenizer(\"The <extra_id_0> walks in <extra_id_1> park\", return_tensors=\"pt\").input_ids\n        >>> labels = tokenizer(\"<extra_id_0> cute dog <extra_id_1> the <extra_id_2>\", return_tensors=\"pt\").input_ids\n        >>> outputs = model(input_ids=input_ids, labels=labels)\n        >>> loss = outputs.loss\n        >>> logits = outputs.logits\n\n        >>> # inference\n        >>> input_ids = tokenizer(\n        ...     \"summarize: studies have shown that owning a dog is good for you\", return_tensors=\"pt\"\n        ... 
).input_ids  # Batch size 1\n        >>> outputs = model.generate(input_ids)\n        >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))\n        >>> # studies have shown that owning a dog is good for you.\n        ```\"\"\"\n        use_cache = use_cache if use_cache is not None else self.config.use_cache\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask\n        if head_mask is not None and decoder_head_mask is None:\n            if self.config.num_layers == self.config.num_decoder_layers:\n                warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)\n                decoder_head_mask = head_mask\n\n        # Encode if needed (training, first prediction pass)\n        if encoder_outputs is None:\n            # Convert encoder inputs in embeddings if needed\n            encoder_outputs = self.encoder(\n                input_ids=input_ids,\n                attention_mask=attention_mask,\n                inputs_embeds=inputs_embeds,\n                head_mask=head_mask,\n                output_attentions=output_attentions,\n                output_hidden_states=output_hidden_states,\n                return_dict=return_dict,\n            )\n        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):\n            encoder_outputs = BaseModelOutput(\n                last_hidden_state=encoder_outputs[0],\n                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,\n                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,\n            )\n\n        hidden_states = encoder_outputs[0]\n\n        if self.model_parallel:\n            torch.cuda.set_device(self.decoder.first_device)\n\n        if (\n            labels is not None\n            and decoder_input_ids is None\n            and decoder_inputs_embeds is None\n        ):\n            # get decoder inputs from shifting lm labels to the right\n            decoder_input_ids = self._shift_right(labels)\n\n        # Set device for model parallelism\n        if self.model_parallel:\n            torch.cuda.set_device(self.decoder.first_device)\n            hidden_states = hidden_states.to(self.decoder.first_device)\n            if decoder_input_ids is not None:\n                decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)\n            if attention_mask is not None:\n                attention_mask = attention_mask.to(self.decoder.first_device)\n            if decoder_attention_mask is not None:\n                decoder_attention_mask = decoder_attention_mask.to(\n                    self.decoder.first_device\n                )\n\n        # Decode\n        decoder_outputs = self.decoder(\n            input_ids=decoder_input_ids,\n            attention_mask=decoder_attention_mask,\n            inputs_embeds=decoder_inputs_embeds,\n            past_key_values=past_key_values,\n            encoder_hidden_states=hidden_states,\n            encoder_attention_mask=attention_mask,\n            head_mask=decoder_head_mask,\n            cross_attn_head_mask=cross_attn_head_mask,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n\n        sequence_output = decoder_outputs[0]\n\n        # Set device for model parallelism\n        if 
self.model_parallel:\n            torch.cuda.set_device(self.encoder.first_device)\n            self.lm_head = self.lm_head.to(self.encoder.first_device)\n            sequence_output = sequence_output.to(self.lm_head.weight.device)\n\n        if self.config.tie_word_embeddings:\n            # Rescale output before projecting on vocab\n            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586\n            sequence_output = sequence_output * (self.model_dim**-0.5)\n\n        lm_logits = self.lm_head(sequence_output)\n\n        loss = None\n        if labels is not None:\n            loss_fct = CrossEntropyLoss(ignore_index=-100, reduction=reduction)\n            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))\n            if reduction == \"none\":\n                loss = loss.view(lm_logits.size(0), -1).sum(1)\n\n        if not return_dict:\n            output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs\n            return ((loss,) + output) if loss is not None else output\n\n        return Seq2SeqLMOutput(\n            loss=loss,\n            logits=lm_logits,\n            past_key_values=decoder_outputs.past_key_values,\n            decoder_hidden_states=decoder_outputs.hidden_states,\n            decoder_attentions=decoder_outputs.attentions,\n            cross_attentions=decoder_outputs.cross_attentions,\n            encoder_last_hidden_state=encoder_outputs.last_hidden_state,\n            encoder_hidden_states=encoder_outputs.hidden_states,\n            encoder_attentions=encoder_outputs.attentions,\n        )\n\n    def prepare_inputs_for_generation(\n        self,\n        input_ids,\n        past=None,\n        attention_mask=None,\n        head_mask=None,\n        decoder_head_mask=None,\n        cross_attn_head_mask=None,\n        use_cache=None,\n        encoder_outputs=None,\n        **kwargs,\n    ):\n\n        # cut decoder_input_ids if past is used\n        if past is not None:\n            input_ids = input_ids[:, -1:]\n\n        return {\n            \"decoder_input_ids\": input_ids,\n            \"past_key_values\": past,\n            \"encoder_outputs\": encoder_outputs,\n            \"attention_mask\": attention_mask,\n            \"head_mask\": head_mask,\n            \"decoder_head_mask\": decoder_head_mask,\n            \"cross_attn_head_mask\": cross_attn_head_mask,\n            \"use_cache\": use_cache,\n        }\n\n    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):\n        return self._shift_right(labels)\n\n    def _reorder_cache(self, past, beam_idx):\n        # if decoder past is not included in output\n        # speedy decoding is disabled and no need to reorder\n        if past is None:\n            logger.warning(\n                \"You might want to consider setting `use_cache=True` to speed up decoding\"\n            )\n            return past\n\n        reordered_decoder_past = ()\n        for layer_past_states in past:\n            # get the correct batch idx from layer past batch dim\n            # batch dim of `past` is at 2nd position\n            reordered_layer_past_states = ()\n            for layer_past_state in layer_past_states:\n                # need to set correct `past` for each of the four key / value states\n                reordered_layer_past_states = reordered_layer_past_states + (\n                    layer_past_state.index_select(\n                        0, 
beam_idx.to(layer_past_state.device)\n                    ),\n                )\n\n            assert reordered_layer_past_states[0].shape == layer_past_states[0].shape\n            assert len(reordered_layer_past_states) == len(layer_past_states)\n\n            reordered_decoder_past = reordered_decoder_past + (\n                reordered_layer_past_states,\n            )\n        return reordered_decoder_past\n\n\n@add_start_docstrings(\n    \"The bare T5 Model transformer outputting encoder's raw hidden-states without any specific head on top.\",\n    T5_START_DOCSTRING,\n)\nclass T5EncoderModel(T5PreTrainedModel):\n    authorized_missing_keys = [\n        r\"encoder.embed_tokens.weight\",\n    ]\n\n    def __init__(self, config: T5Config):\n        super().__init__(config)\n        self.shared = nn.Embedding(config.vocab_size, config.d_model)\n\n        encoder_config = copy.deepcopy(config)\n        encoder_config.use_cache = False\n        encoder_config.is_encoder_decoder = False\n        self.encoder = T5Stack(encoder_config, self.shared)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n        # Model parallel\n        self.model_parallel = False\n        self.device_map = None\n\n    @add_start_docstrings(PARALLELIZE_DOCSTRING)\n    def parallelize(self, device_map=None):\n        self.device_map = (\n            get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))\n            if device_map is None\n            else device_map\n        )\n        assert_device_map(self.device_map, len(self.encoder.block))\n        self.encoder.parallelize(self.device_map)\n        self.model_parallel = True\n\n    @add_start_docstrings(DEPARALLELIZE_DOCSTRING)\n    def deparallelize(self):\n        self.encoder.deparallelize()\n        self.encoder = self.encoder.to(\"cpu\")\n        self.model_parallel = False\n        self.device_map = None\n        torch.cuda.empty_cache()\n\n    def get_input_embeddings(self):\n        return self.shared\n\n    def set_input_embeddings(self, new_embeddings):\n        self.shared = new_embeddings\n        self.encoder.set_input_embeddings(new_embeddings)\n\n    def get_encoder(self):\n        return self.encoder\n\n    def _prune_heads(self, heads_to_prune):\n        \"\"\"\n        Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base\n        class PreTrainedModel\n        \"\"\"\n        for layer, heads in heads_to_prune.items():\n            self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads)\n\n    @add_start_docstrings_to_model_forward(T5_ENCODER_INPUTS_DOCSTRING)\n    @replace_return_docstrings(\n        output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC\n    )\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        head_mask: Optional[torch.FloatTensor] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n    ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:\n        r\"\"\"\n        Returns:\n\n        Example:\n\n        ```python\n        >>> from transformers import T5Tokenizer, T5EncoderModel\n\n        >>> tokenizer = T5Tokenizer.from_pretrained(\"t5-small\")\n        >>> model = T5EncoderModel.from_pretrained(\"t5-small\")\n        >>> input_ids = tokenizer(\n        ...     \"Studies have been shown that owning a dog is good for you\", return_tensors=\"pt\"\n        ... ).input_ids  # Batch size 1\n        >>> outputs = model(input_ids=input_ids)\n        >>> last_hidden_states = outputs.last_hidden_state\n        ```\"\"\"\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        encoder_outputs = self.encoder(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            inputs_embeds=inputs_embeds,\n            head_mask=head_mask,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n\n        return encoder_outputs\n"
  },
  {
    "path": "lavis/models/blip_models/__init__.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport logging\nfrom typing import List\n\nfrom torch import nn\n\n\ndef tie_encoder_decoder_weights(\n    encoder: nn.Module, decoder: nn.Module, base_model_prefix: str, skip_key: str\n):\n    uninitialized_encoder_weights: List[str] = []\n    if decoder.__class__ != encoder.__class__:\n        logging.info(\n            f\"{decoder.__class__} and {encoder.__class__} are not equal. In this case make sure that all encoder weights are correctly initialized.\"\n        )\n\n    def tie_encoder_to_decoder_recursively(\n        decoder_pointer: nn.Module,\n        encoder_pointer: nn.Module,\n        module_name: str,\n        uninitialized_encoder_weights: List[str],\n        skip_key: str,\n        depth=0,\n    ):\n        assert isinstance(decoder_pointer, nn.Module) and isinstance(\n            encoder_pointer, nn.Module\n        ), f\"{decoder_pointer} and {encoder_pointer} have to be of type torch.nn.Module\"\n        if hasattr(decoder_pointer, \"weight\") and skip_key not in module_name:\n            assert hasattr(encoder_pointer, \"weight\")\n            encoder_pointer.weight = decoder_pointer.weight\n            if hasattr(decoder_pointer, \"bias\"):\n                assert hasattr(encoder_pointer, \"bias\")\n                encoder_pointer.bias = decoder_pointer.bias\n            print(module_name + \" is tied\")\n            return\n\n        encoder_modules = encoder_pointer._modules\n        decoder_modules = decoder_pointer._modules\n        if len(decoder_modules) > 0:\n            assert (\n                len(encoder_modules) > 0\n            ), f\"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}\"\n\n            all_encoder_weights = set(\n                [module_name + \"/\" + sub_name for sub_name in encoder_modules.keys()]\n            )\n            encoder_layer_pos = 0\n            for name, module in decoder_modules.items():\n                if name.isdigit():\n                    encoder_name = str(int(name) + encoder_layer_pos)\n                    decoder_name = name\n                    if not isinstance(\n                        decoder_modules[decoder_name],\n                        type(encoder_modules[encoder_name]),\n                    ) and len(encoder_modules) != len(decoder_modules):\n                        # this can happen if the name corresponds to the position in a list module list of layers\n                        # in this case the decoder has added a cross-attention that the encoder does not have\n                        # thus skip this step and subtract one layer pos from encoder\n                        encoder_layer_pos -= 1\n                        continue\n                elif name not in encoder_modules:\n                    continue\n                elif depth > 500:\n                    raise ValueError(\n                        \"Max depth of recursive function `tie_encoder_to_decoder` reached. 
It seems that there is a circular dependency between two or more `nn.Modules` of your model.\"\n                    )\n                else:\n                    decoder_name = encoder_name = name\n                tie_encoder_to_decoder_recursively(\n                    decoder_modules[decoder_name],\n                    encoder_modules[encoder_name],\n                    module_name + \"/\" + name,\n                    uninitialized_encoder_weights,\n                    skip_key,\n                    depth=depth + 1,\n                )\n                all_encoder_weights.remove(module_name + \"/\" + encoder_name)\n\n            uninitialized_encoder_weights += list(all_encoder_weights)\n\n    # tie weights recursively\n    tie_encoder_to_decoder_recursively(\n        decoder, encoder, base_model_prefix, uninitialized_encoder_weights, skip_key\n    )\n"
  },
  {
    "path": "lavis/models/blip_models/blip.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport logging\nimport os\n\nimport torch\nfrom lavis.common.dist_utils import download_cached_file\nfrom lavis.common.utils import is_url\nfrom lavis.models.base_model import BaseModel\nfrom lavis.models.vit import interpolate_pos_embed\nfrom transformers import BertTokenizer\n\n\nclass BlipBase(BaseModel):\n    @classmethod\n    def init_tokenizer(cls):\n        tokenizer = BertTokenizer.from_pretrained(\"bert-base-uncased\")\n        tokenizer.add_special_tokens({\"bos_token\": \"[DEC]\"})\n        tokenizer.add_special_tokens({\"additional_special_tokens\": [\"[ENC]\"]})\n        tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]\n        return tokenizer\n\n    def load_from_pretrained(self, url_or_filename):\n        if is_url(url_or_filename):\n            cached_file = download_cached_file(\n                url_or_filename, check_hash=False, progress=True\n            )\n            checkpoint = torch.load(cached_file, map_location=\"cpu\")\n        elif os.path.isfile(url_or_filename):\n            checkpoint = torch.load(url_or_filename, map_location=\"cpu\")\n        else:\n            raise RuntimeError(\"checkpoint url or path is invalid\")\n\n        state_dict = checkpoint[\"model\"]\n\n        state_dict[\"visual_encoder.pos_embed\"] = interpolate_pos_embed(\n            state_dict[\"visual_encoder.pos_embed\"], self.visual_encoder\n        )\n        if \"visual_encoder_m.pos_embed\" in self.state_dict().keys():\n            state_dict[\"visual_encoder_m.pos_embed\"] = interpolate_pos_embed(\n                state_dict[\"visual_encoder_m.pos_embed\"], self.visual_encoder_m\n            )\n\n        for key in self.state_dict().keys():\n            if key in state_dict.keys():\n                if state_dict[key].shape != self.state_dict()[key].shape:\n                    del state_dict[key]\n\n        msg = self.load_state_dict(state_dict, strict=False)\n\n        logging.info(\"Missing keys {}\".format(msg.missing_keys))\n        logging.info(\"load checkpoint from %s\" % url_or_filename)\n\n        return msg\n"
  },
  {
    "path": "lavis/models/blip_models/blip_caption.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport torch\nfrom lavis.common.registry import registry\n\nfrom lavis.models.blip_models.blip import BlipBase\nfrom lavis.models.blip_models.blip_outputs import (\n    BlipOutput,\n    BlipIntermediateOutput,\n)\nfrom lavis.models.med import XBertLMHeadDecoder\nfrom lavis.models.vit import VisionTransformerEncoder\n\n\n@registry.register_model(\"blip_caption\")\nclass BlipCaption(BlipBase):\n    \"\"\"\n    BLIP captioning model.\n\n    Supported model types:\n        - base_coco: fine-tuned BLIP base model on COCO caption dataset (Karparthy split).\n        - large_coco: fine-tuned BLIP large model on COCO caption dataset (Karparthy split).\n\n    Usage:\n        >>> from lavis.models import load_model\n        >>> model = load_model(\"blip_caption\", \"base_coco\")\n        >>> model = load_model(\"blip_caption\", \"large_coco\")\n    \"\"\"\n\n    PRETRAINED_MODEL_CONFIG_DICT = {\n        \"base_coco\": \"configs/models/blip_caption_base_coco.yaml\",\n        \"large_coco\": \"configs/models/blip_caption_large_coco.yaml\",\n    }\n\n    def __init__(self, image_encoder, text_decoder, prompt=None, max_txt_len=40):\n        super().__init__()\n\n        self.tokenizer = self.init_tokenizer()\n\n        self.visual_encoder = image_encoder\n        self.text_decoder = text_decoder\n\n        self.prompt = prompt\n        self.prompt_length = len(self.tokenizer(self.prompt).input_ids) - 1\n\n        self.max_txt_len = max_txt_len\n\n    def forward_encoder(self, samples):\n        image_embeds = self.visual_encoder.forward_features(samples[\"image\"])\n        return image_embeds\n\n    def forward_decoder(self, samples, image_embeds):\n        # prepare inputs for forwarding decoder\n        raw_text = samples[\"text_input\"]\n        text = self.tokenizer(\n            raw_text,\n            padding=\"longest\",\n            truncation=True,\n            max_length=self.max_txt_len,\n            return_tensors=\"pt\",\n        ).to(self.device)\n        text.input_ids[:, 0] = self.tokenizer.bos_token_id\n\n        # prepare targets for forwarding decoder\n        decoder_targets = text.input_ids.masked_fill(\n            text.input_ids == self.tokenizer.pad_token_id, -100\n        )\n        decoder_targets[:, : self.prompt_length] = -100\n\n        # forward decoder\n        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(\n            self.device\n        )\n        decoder_output = self.text_decoder(\n            input_ids=text.input_ids,\n            attention_mask=text.attention_mask,\n            encoder_hidden_states=image_embeds,\n            encoder_attention_mask=image_atts,\n            labels=decoder_targets,\n            return_dict=True,\n        )\n\n        return decoder_output, decoder_targets\n\n    def forward(self, samples):\n        r\"\"\"\n        Args:\n            samples (dict): A dictionary containing the following keys:\n                - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W)\n                - text_input (list): A list of strings of length batch_size.\n        Returns:\n            output (BlipOutput): A BlipOutput object containing the following\n                attributes:\n                - loss (torch.Tensor): A scalar tensor containing the total loss. 
For BlipCaption, this is the same as the LM loss.\n                - loss_lm (torch.Tensor): A scalar tensor containing the LM loss.\n                - intermediate_outputs (BlipIntermediateOutput): A BlipIntermediateOutput object containing intermediate outputs.\n                  see :class:`lavis.models.blip_models.blip_outputs.BlipOutput` for more details.\n\n        Example:\n        ```python\n        >>> from PIL import Image\n        >>> from lavis.models import load_model_and_preprocess\n        >>> model, vis_processors, txt_processors = load_model_and_preprocess(\"blip_caption\")\n        >>> raw_image = Image.open(\"docs/data/merlion.png\").convert(\"RGB\")\n        >>> image = vis_processors[\"eval\"](raw_image).unsqueeze(0)\n        >>> text_input = [\"a large statue of a person spraying water from a fountain\"]\n        >>> samples = {\"image\": image, \"text_input\": text_input}\n        >>> output = model(samples)\n        >>> output.keys()\n        odict_keys(['intermediate_output', 'loss', 'loss_lm'])\n        >>> output.intermediate_output.image_embeds.shape\n        torch.Size([1, 577, 768])\n        >>> output.intermediate_output.decoder_labels.shape\n        torch.Size([1, 13])\n        ```\"\"\"\n\n        image_embeds = self.forward_encoder(samples)\n        decoder_output, decoder_targets = self.forward_decoder(samples, image_embeds)\n\n        # return decoder_out\n        return BlipOutput(\n            loss=decoder_output.loss,\n            loss_lm=decoder_output.loss,\n            intermediate_output=BlipIntermediateOutput(\n                image_embeds=image_embeds,\n                decoder_output=decoder_output,\n                decoder_labels=decoder_targets,\n            ),\n        )\n\n    def generate(\n        self,\n        samples,\n        use_nucleus_sampling=False,\n        num_beams=3,\n        max_length=30,\n        min_length=10,\n        top_p=0.9,\n        repetition_penalty=1.0,\n        num_captions=1,\n    ):\n        \"\"\"\n        Args:\n            samples (dict): A dictionary containing the following keys:\n                - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W)\n            use_nucleus_sampling (bool): Whether to use nucleus sampling. If False, use top-k sampling.\n            num_beams (int): Number of beams for beam search. 1 means no beam search.\n            max_length (int): The maximum length of the sequence to be generated.\n            min_length (int): The minimum length of the sequence to be generated.\n            top_p (float): The cumulative probability for nucleus sampling.\n            repetition_penalty (float): The parameter for repetition penalty. 
1.0 means no penalty.\n            num_captions (int): Number of captions to be generated for each image.\n        Returns:\n            captions (list): A list of strings of length batch_size * num_captions.\n\n        Example:\n        ```python\n        >>> from PIL import Image\n        >>> from lavis.models import load_model_and_preprocess\n        >>> model, vis_processors, txt_processors = load_model_and_preprocess(\"blip_caption\")\n        >>> raw_image = Image.open(\"docs/data/merlion.png\").convert(\"RGB\")\n        >>> image = vis_processors[\"eval\"](raw_image).unsqueeze(0)\n        >>> samples = {\"image\": image}\n        >>> captions = model.generate(samples)\n        >>> captions\n        ['a large statue of a person spraying water from a fountain']\n        >>> captions = model.generate(samples, use_nucleus_sampling=True, num_captions=3)\n        >>> captions # example output, results may vary due to randomness\n        ['singapore showing the view of some building',\n        'the singapore harbor in twilight, as the weather is going down',\n        'the famous singapore fountain at sunset']\n        \"\"\"\n        # prepare inputs for decoder generation.\n        encoder_out = self.forward_encoder(samples)\n        image_embeds = torch.repeat_interleave(encoder_out, num_captions, 0)\n\n        prompt = [self.prompt] * image_embeds.size(0)\n        prompt = self.tokenizer(prompt, return_tensors=\"pt\").to(self.device)\n        prompt.input_ids[:, 0] = self.tokenizer.bos_token_id\n        prompt.input_ids = prompt.input_ids[:, :-1]\n\n        # get decoded text\n        decoder_out = self.text_decoder.generate_from_encoder(\n            tokenized_prompt=prompt,\n            visual_embeds=image_embeds,\n            sep_token_id=self.tokenizer.sep_token_id,\n            pad_token_id=self.tokenizer.pad_token_id,\n            use_nucleus_sampling=use_nucleus_sampling,\n            num_beams=num_beams,\n            max_length=max_length,\n            min_length=min_length,\n            top_p=top_p,\n            repetition_penalty=repetition_penalty,\n        )\n\n        outputs = self.tokenizer.batch_decode(decoder_out, skip_special_tokens=True)\n        captions = [output[len(self.prompt) :] for output in outputs]\n\n        return captions\n\n    @classmethod\n    def from_config(cls, cfg):\n        # vision encoder\n        image_encoder = VisionTransformerEncoder.from_config(cfg)\n        # text encoder + multimodal decoder\n        text_decoder = XBertLMHeadDecoder.from_config(cfg)\n\n        prompt = cfg.get(\"prompt\", None)\n        max_txt_len = cfg.get(\"max_txt_len\", 40)\n\n        model = cls(image_encoder, text_decoder, prompt=prompt, max_txt_len=max_txt_len)\n        model.load_checkpoint_from_config(cfg)\n\n        return model\n"
  },
  {
    "path": "lavis/models/blip_models/blip_classification.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nfrom copy import deepcopy\n\nimport torch\nimport torch.nn.functional as F\nfrom lavis.common.registry import registry\nfrom lavis.models.base_model import MomentumDistilationMixin\nfrom lavis.models.blip_models.blip import BlipBase\nfrom lavis.models.blip_models.blip_outputs import (\n    BlipIntermediateOutput,\n    BlipOutputWithLogits,\n)\nfrom lavis.models.med import XBertEncoder\nfrom lavis.models.vit import VisionTransformerEncoder\nfrom torch import nn\n\n\n@registry.register_model(\"blip_classification\")\nclass BlipClassification(BlipBase, MomentumDistilationMixin):\n    PRETRAINED_MODEL_CONFIG_DICT = {\n        \"base\": \"configs/models/blip_classification_base.yaml\",\n    }\n\n    def __init__(\n        self,\n        image_encoder,\n        text_encoder,\n        num_classes,\n        momentum=0.995,\n        alpha=0.4,\n        max_txt_len=40,\n        use_distill=True,\n    ):\n        super().__init__()\n\n        self.tokenizer = self.init_tokenizer()\n\n        self.use_distill = use_distill\n\n        self.visual_encoder = image_encoder\n        self.text_encoder = text_encoder\n\n        hidden_size = text_encoder.config.hidden_size\n        self.cls_head = nn.Sequential(\n            nn.Linear(hidden_size, hidden_size),\n            nn.ReLU(),\n            nn.Linear(hidden_size, num_classes),\n        )\n\n        if self.use_distill:\n            self.visual_encoder_m = deepcopy(self.visual_encoder)\n            self.text_encoder_m = deepcopy(self.text_encoder)\n            self.cls_head_m = deepcopy(self.cls_head)\n\n            self.momentum = momentum\n            self.alpha = alpha\n\n            self.model_pairs = [\n                [self.visual_encoder, self.visual_encoder_m],\n                [self.text_encoder, self.text_encoder_m],\n                [self.cls_head, self.cls_head_m],\n            ]\n\n            self.copy_params()\n\n        self.max_txt_len = max_txt_len\n\n    def _rampup_factor(self, epoch, iters, num_iters_per_epoch):\n        return min(1, (epoch * num_iters_per_epoch + iters) / num_iters_per_epoch)\n\n    def forward(self, samples, is_train=True):\n        sentences = samples[\"text_input\"]\n        sentences = self.tokenizer(\n            sentences,\n            padding=\"longest\",\n            truncation=True,\n            max_length=self.max_txt_len,\n            return_tensors=\"pt\",\n        ).to(self.device)\n        samples.update({\"tokenized_text\": sentences})\n\n        targets = samples[\"label\"]\n\n        image_embeds = self.visual_encoder.forward_features(samples[\"image\"])\n        encoder_output = self.text_encoder.forward_automask(\n            samples[\"tokenized_text\"], image_embeds\n        )\n\n        prediction = self.cls_head(encoder_output.last_hidden_state[:, 0, :])\n\n        if is_train:\n            if self.use_distill:\n                with torch.no_grad():\n                    self._momentum_update()\n\n                    image_embeds_m = self.visual_encoder_m(samples[\"image\"])\n                    encoder_output_m = self.text_encoder_m.forward_automask(\n                        samples[\"tokenized_text\"], image_embeds_m\n                    )\n\n                    prediction_m = self.cls_head_m(\n                        
encoder_output_m.last_hidden_state[:, 0, :]\n                    )\n\n                alpha = self.alpha * self._rampup_factor(\n                    epoch=samples[\"epoch\"],\n                    iters=samples[\"iters\"],\n                    num_iters_per_epoch=samples[\"num_iters_per_epoch\"],\n                )\n\n                loss = (1 - alpha) * F.cross_entropy(\n                    prediction, targets\n                ) - alpha * torch.sum(\n                    F.log_softmax(prediction, dim=1) * F.softmax(prediction_m, dim=1),\n                    dim=1,\n                ).mean()\n            else:\n                loss = F.cross_entropy(prediction, targets)\n\n            # return {\"loss\": loss}\n            return BlipOutputWithLogits(\n                loss=loss,\n                intermediate_output=BlipIntermediateOutput(\n                    image_embeds=image_embeds,\n                    image_embeds_m=image_embeds_m,\n                    encoder_output=encoder_output,\n                    encoder_output_m=encoder_output_m,\n                ),\n                logits=prediction,\n                logits_m=prediction_m,\n            )\n\n        else:\n            return {\"predictions\": prediction, \"targets\": targets}\n\n    def predict(self, samples):\n        output = self.forward(samples, is_train=False)\n        return output\n\n    @classmethod\n    def from_config(cls, cfg=None):\n        image_encoder = VisionTransformerEncoder.from_config(cfg)\n\n        # text encoder + multimodal encoder\n        text_encoder = XBertEncoder.from_config(cfg)\n        use_distill = cfg.get(\"use_distill\", True)\n        momentum = cfg.get(\"momentum\", 0.995)\n        num_classes = cfg.get(\"num_classes\", -1)\n        alpha = cfg.get(\"alpha\", 0.4)\n        max_txt_len = cfg.get(\"max_txt_len\", 40)\n\n        assert num_classes > 1, \"Invalid number of classes provided, found {}\".format(\n            num_classes\n        )\n\n        model = cls(\n            image_encoder=image_encoder,\n            text_encoder=text_encoder,\n            use_distill=use_distill,\n            alpha=alpha,\n            num_classes=num_classes,\n            momentum=momentum,\n            max_txt_len=max_txt_len,\n        )\n\n        # load pre-trained weights\n        pretrain_path = cfg.get(\"pretrained\", None)\n        if pretrain_path is not None:\n            msg = model.load_from_pretrained(url_or_filename=pretrain_path)\n\n        return model\n"
  },
  {
    "path": "lavis/models/blip_models/blip_feature_extractor.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport warnings\n\nimport torch\nimport torch.nn.functional as F\nfrom lavis.common.registry import registry\nfrom lavis.models.blip_models.blip import BlipBase\nfrom lavis.models.blip_models.blip_outputs import BlipOutputFeatures\nfrom lavis.models.med import XBertEncoder\nfrom lavis.models.vit import VisionTransformerEncoder\nfrom torch import nn\n\n\n@registry.register_model(\"blip_feature_extractor\")\nclass BlipFeatureExtractor(BlipBase):\n    \"\"\"\n    Class for BLIP feature extractor.\n\n    Supported model types:\n        - base: BLIP base model with pre-trained weights from capfilt by BLIP large model.\n\n    Usage:\n        >>> from lavis.models import load_model\n        >>> model = load_model(\"blip_feature_extractor\", \"base\")\n    \"\"\"\n\n    PRETRAINED_MODEL_CONFIG_DICT = {\n        \"base\": \"configs/models/blip_feature_extractor_base.yaml\",\n        # \"large\": \"configs/models/blip_feature_extractor_large.yaml\",\n    }\n\n    def __init__(self, image_encoder, text_encoder, embed_dim, max_txt_len=40):\n        super().__init__()\n\n        self.tokenizer = self.init_tokenizer()\n\n        self.visual_encoder = image_encoder\n        self.text_encoder = text_encoder\n\n        # creating projection layers for ITC\n        text_width = text_encoder.config.hidden_size\n        vision_width = image_encoder.vision_width\n\n        self.vision_proj = nn.Linear(vision_width, embed_dim)\n        self.text_proj = nn.Linear(text_width, embed_dim)\n\n        self.max_txt_len = max_txt_len\n\n        self.temp = nn.Parameter(0.07 * torch.ones([]))\n\n    @torch.no_grad()\n    def extract_features(self, samples, mode=\"multimodal\"):\n        \"\"\"\n        Extract features for multimodal or unimodal samples.\n\n        Args:\n            samples (dict): A dictionary of samples, containing the following keys:\n                - image (torch.Tensor): A tensor of shape (B, C, H, W) containing the image.\n                    Raw images should be preprocessed before being passed to feature extractor.\n                - text_input (list): A list of strings containing the text, length B.\n            mode (str): The mode of feature extraction. 
Can be either \"multimodal\", \"text\" or \"image\".\n                If \"multimodal\", return image features and multimodal features;\n                if \"text\", return text features;\n                if \"image\", return image features.\n                Default: \"multimodal\".\n\n        Returns:\n            BlipOutputFeatures: A BlipOutputFeatures object containing the features.\n                See lavis/models/blip_models/blip_outputs.py for more details.\n\n        Examples:\n        ```python\n            >>> from PIL import Image\n            >>> from lavis.models import load_model_and_preprocess\n            >>> raw_image = Image.open(\"docs/data/merlion.png\").convert(\"RGB\")\n            >>> caption = \"a large fountain spewing water into the air\"\n            >>> model, vis_processors, txt_processors = load_model_and_preprocess(\"blip_feature_extractor\", is_eval=True)\n            >>> image = vis_processors[\"eval\"](raw_image).unsqueeze(0)\n            >>> text_input = txt_processors[\"eval\"](caption)\n\n            >>> sample = {\"image\": image, \"text_input\": [text_input]}\n\n            >>> features_multimodal = model.extract_features(sample)\n            >>> features_multimodal.keys()\n            odict_keys(['image_embeds', 'multimodal_embeds'])\n            >>> features_multimodal.image_embeds.shape\n            torch.Size([1, 197, 768])\n            >>> features_multimodal.multimodal_embeds.shape\n            torch.Size([1, 12, 768])\n\n            >>> features_text = model.extract_features(sample, mode=\"text\")\n            >>> features_text.keys()\n            odict_keys(['text_embeds', 'text_features'])\n            >>> features_text.text_embeds.shape\n            torch.Size([1, 12, 768])\n            >>> features_text.text_features.shape\n            torch.Size([1, 12, 256])\n\n            >>> features_image = model.extract_features(sample, mode=\"image\")\n            >>> features_image.keys()\n            odict_keys(['image_embeds', 'image_features'])\n            >>> features_image.image_embeds.shape\n            torch.Size([1, 197, 768])\n            >>> features_image.image_features.shape\n            torch.Size([1, 197, 256])\n        ```\n        \"\"\"\n        image = samples.get(\"image\")\n        caption = samples.get(\"text_input\")\n\n        # assert mode is one of \"image\", \"text\", \"multimodal\"\n        assert mode in [\n            \"image\",\n            \"text\",\n            \"multimodal\",\n        ], \"mode must be one of 'image', 'text', 'multimodal'\"\n\n        # initialize output\n        image_embeds, text_embeds, multimodal_embeds = None, None, None\n        image_features, text_features = None, None\n\n        if mode == \"image\":\n            assert (\n                image is not None\n            ), \"Image is not provided for mode 'image' or 'multimodal'\"\n            # return image features\n            image_embeds = self.visual_encoder.forward_features(image)\n\n            image_features = self.vision_proj(image_embeds)\n            image_features = F.normalize(image_features, dim=-1)\n\n        elif mode == \"text\":\n            assert (\n                caption is not None\n            ), \"text input is None for mode 'text' or 'multimodal'\"\n\n            text = self.tokenizer(caption, return_tensors=\"pt\", padding=True).to(\n                self.device\n            )\n\n            # return text features\n            text_output = self.text_encoder(\n                text.input_ids,\n                
attention_mask=text.attention_mask,\n                return_dict=True,\n                mode=\"text\",\n            )\n            text_embeds = text_output.last_hidden_state\n\n            text_features = self.text_proj(text_embeds)\n            text_features = F.normalize(text_features, dim=-1)\n\n        elif mode == \"multimodal\":\n            # return multimodal features\n            image_embeds = self.visual_encoder.forward_features(image)\n            image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(\n                self.device\n            )\n\n            text = self.tokenizer(caption, return_tensors=\"pt\", padding=True).to(\n                self.device\n            )\n            text.input_ids[:, 0] = self.tokenizer.enc_token_id\n\n            output = self.text_encoder(\n                text.input_ids,\n                attention_mask=text.attention_mask,\n                encoder_hidden_states=image_embeds,\n                encoder_attention_mask=image_atts,\n                return_dict=True,\n            )\n\n            multimodal_embeds = output.last_hidden_state\n\n        return BlipOutputFeatures(\n            image_embeds=image_embeds,\n            image_embeds_proj=image_features,\n            text_embeds=text_embeds,\n            text_embeds_proj=text_features,\n            multimodal_embeds=multimodal_embeds,\n        )\n\n    @classmethod\n    def from_config(cls, cfg=None):\n        # set from_pretrained=True to load weights for 'bert-base-uncased'\n        image_encoder = VisionTransformerEncoder.from_config(cfg)\n        text_encoder = XBertEncoder.from_config(cfg)\n\n        embed_dim = cfg.get(\"embed_dim\", 256)\n        max_txt_len = cfg.get(\"max_txt_len\", 30)\n\n        model = cls(\n            image_encoder=image_encoder,\n            text_encoder=text_encoder,\n            embed_dim=embed_dim,\n            max_txt_len=max_txt_len,\n        )\n\n        # load pre-trained weights\n        pretrain_path = cfg.get(\"pretrained\", None)\n        if pretrain_path is not None:\n            msg = model.load_from_pretrained(url_or_filename=pretrain_path)\n        else:\n            warnings.warn(\"No pretrained weights are loaded.\")\n\n        return model\n"
  },
  {
    "path": "lavis/models/blip_models/blip_image_text_matching.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport torch\nimport torch.nn.functional as F\nfrom lavis.common.registry import registry\nfrom lavis.models.blip_models.blip import BlipBase\nfrom torch import nn\nfrom lavis.models.med import XBertEncoder\n\nfrom lavis.models.vit import VisionTransformerEncoder\n\n\n@registry.register_model(\"blip_image_text_matching\")\nclass BlipITM(BlipBase):\n    \"\"\"\n    BLIP Image-Text Matching (ITM) model.\n\n    Supported model types:\n        - base: fine-tuned BLIP retrieval weights on COCO dataset (Karpathy split).\n        - large: fine-tuned BLIP retrieval weights on COCO dataset (Karpathy split).\n\n    Usage:\n        >>> from lavis.models import load_model\n        >>> model = load_model(\"blip_image_text_matching\", \"base\")\n        >>> model = load_model(\"blip_image_text_matching\", \"large\")\n    \"\"\"\n\n    PRETRAINED_MODEL_CONFIG_DICT = {\n        \"base\": \"configs/models/blip_itm_base.yaml\",\n        \"large\": \"configs/models/blip_itm_large.yaml\",\n    }\n\n    def __init__(self, image_encoder, text_encoder, embed_dim=256, max_txt_len=35):\n        super().__init__()\n\n        self.tokenizer = self.init_tokenizer()\n\n        self.text_encoder = text_encoder\n\n        self.visual_encoder = image_encoder\n\n        self.max_txt_len = max_txt_len\n\n        # creating projection layers for ITC\n        text_width = text_encoder.config.hidden_size\n        vision_width = image_encoder.vision_width\n\n        self.vision_proj = nn.Linear(vision_width, embed_dim)\n        self.text_proj = nn.Linear(text_width, embed_dim)\n\n        self.itm_head = nn.Linear(text_width, 2)\n\n    def forward(self, samples, match_head=\"itm\"):\n        image = samples[\"image\"]\n        caption = samples[\"text_input\"]\n\n        image_embeds = self.visual_encoder.forward_features(image)\n        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(\n            image.device\n        )\n\n        text = self.tokenizer(\n            caption,\n            padding=\"longest\",\n            truncation=True,\n            max_length=self.max_txt_len,\n            return_tensors=\"pt\",\n        ).to(image.device)\n        if match_head == \"itm\":\n            encoder_input_ids = text.input_ids.clone()\n            encoder_input_ids[:, 0] = self.tokenizer.enc_token_id  # extra code\n            output = self.text_encoder(\n                encoder_input_ids,\n                attention_mask=text.attention_mask,\n                encoder_hidden_states=image_embeds,\n                encoder_attention_mask=image_atts,\n                return_dict=True,\n            )\n            itm_output = self.itm_head(output.last_hidden_state[:, 0, :])\n            return itm_output\n\n        elif match_head == \"itc\":\n            text_output = self.text_encoder(\n                text.input_ids,\n                attention_mask=text.attention_mask,\n                return_dict=True,\n                mode=\"text\",\n            )\n            image_feat = F.normalize(self.vision_proj(image_embeds[:, 0, :]), dim=-1)\n            text_feat = F.normalize(\n                self.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1\n            )\n\n            sim = image_feat @ text_feat.t()\n            return sim\n    def itm_rank(self, 
image_embeds, image_atts, encoder_input_ids, match_head='itm'):\n        # breakpoint()\n        encoder_input_ids = encoder_input_ids.clone()\n        encoder_input_ids = encoder_input_ids[:, 3:]\n        text_attention_mask = (encoder_input_ids != self.tokenizer.pad_token_id).long()\n\n        if match_head == 'itm':\n            # encoder_input_ids = encoder_input_ids.clone()\n            encoder_input_ids[:, 0] = self.tokenizer.enc_token_id\n            output = self.text_encoder(encoder_input_ids,\n                                       attention_mask=text_attention_mask,\n                                       encoder_hidden_states=image_embeds,\n                                       encoder_attention_mask=image_atts,\n                                       return_dict=True,\n                                       )\n            # print(output.last_hidden_state.shape)\n            itm_output = self.itm_head(output.last_hidden_state[:, 0, :])\n            itm_output = F.softmax(itm_output, dim=1)[:,1]\n            return itm_output #, mask, token_length\n\n        elif match_head == 'itc':\n            encoder_input_ids[:, 0] = self.tokenizer.cls_token_id\n            text_output = self.text_encoder(encoder_input_ids, attention_mask=text_attention_mask,\n                                            return_dict=True, mode='text')\n            image_feat = F.normalize(self.vision_proj(image_embeds[:, 0, :]), dim=-1)\n            text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1)\n\n            sim = image_feat @ text_feat.t()\n            return sim\n\n    @classmethod\n    def from_config(cls, cfg=None):\n        image_encoder = VisionTransformerEncoder.from_config(cfg)\n        text_encoder = XBertEncoder.from_config(cfg)\n\n        embed_dim = cfg.get(\"embed_dim\", 256)\n        max_txt_len = cfg.get(\"max_txt_len\", 35)\n\n        model = cls(\n            image_encoder=image_encoder,\n            text_encoder=text_encoder,\n            embed_dim=embed_dim,\n            max_txt_len=max_txt_len,\n        )\n\n        model.load_checkpoint_from_config(cfg)\n\n        return model\n\n\ndef compute_gradcam(model, visual_input, text_input, tokenized_text, block_num=6):\n    model.text_encoder.base_model.base_model.encoder.layer[\n        block_num\n    ].crossattention.self.save_attention = True\n\n    output = model({\"image\": visual_input, \"text_input\": text_input}, match_head=\"itm\")\n    loss = output[:, 1].sum()\n\n    model.zero_grad()\n    loss.backward()\n    with torch.no_grad():\n        mask = tokenized_text.attention_mask.view(\n            tokenized_text.attention_mask.size(0), 1, -1, 1, 1\n        )  # (bsz,1,token_len, 1,1)\n        token_length = tokenized_text.attention_mask.sum(dim=-1) - 2\n        token_length = token_length.cpu()\n        # grads and cams [bsz, num_head, seq_len, image_patch]\n        grads = model.text_encoder.base_model.base_model.encoder.layer[\n            block_num\n        ].crossattention.self.get_attn_gradients()\n        cams = model.text_encoder.base_model.base_model.encoder.layer[\n            block_num\n        ].crossattention.self.get_attention_map()\n\n        # assume using vit with 576 num image patch\n        cams = cams[:, :, :, 1:].reshape(visual_input.size(0), 12, -1, 24, 24) * mask\n        grads = (\n            grads[:, :, :, 1:].clamp(0).reshape(visual_input.size(0), 12, -1, 24, 24)\n            * mask\n        )\n\n        gradcams = cams * grads\n        gradcam_list = []\n\n        
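# for each sample: average the GradCAM maps over attention heads, then collect the enc-token map, the token-averaged map, and the per-token maps\n        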
for ind in range(visual_input.size(0)):\n            token_length_ = token_length[ind]\n            gradcam = gradcams[ind].mean(0).cpu().detach()\n            # [enc token gradcam, average gradcam across token, gradcam for individual token]\n            gradcam = torch.cat(\n                (\n                    gradcam[0:1, :],\n                    gradcam[1 : token_length_ + 1, :].sum(dim=0, keepdim=True)\n                    / token_length_,\n                    gradcam[1:, :],\n                )\n            )\n            gradcam_list.append(gradcam)\n            \n    return gradcam_list, output\n"
  },
  {
    "path": "lavis/models/blip_models/blip_nlvr.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport os\n\nimport torch\nimport torch.nn.functional as F\nfrom lavis.common.dist_utils import download_cached_file\nfrom lavis.common.registry import registry\nfrom lavis.common.utils import get_abs_path, is_url\nfrom lavis.models.base_model import MomentumDistilationMixin\nfrom lavis.models.blip_models.blip import BlipBase\nfrom lavis.models.blip_models.blip_outputs import BlipIntermediateOutput, BlipOutput\nfrom lavis.models.blip_models.nlvr_encoder import BertModel\nfrom lavis.models.vit import VisionTransformerEncoder, interpolate_pos_embed\nfrom torch import nn\nfrom transformers import BertConfig\n\n\n@registry.register_model(\"blip_nlvr\")\nclass BlipNLVR(BlipBase, MomentumDistilationMixin):\n    \"\"\"\n    Class for BLIP NLVR model.\n\n    Supported model types:\n        - base: model with pre-trained BLIP weights, used as initialization for fine-tuning.\n        - nlvr: finetuned model on NLVR2 dataset.\n\n    Usage:\n        >>> from lavis.models import load_model\n        >>> model = load_model(\"blip_nlvr\", \"nlvr\")\n    \"\"\"\n\n    PRETRAINED_MODEL_CONFIG_DICT = {\n        \"nlvr\": \"configs/models/blip_nlvr.yaml\",\n    }\n\n    def __init__(self, image_encoder, text_encoder, num_classes):\n        super().__init__()\n\n        self.tokenizer = self.init_tokenizer()\n        self.visual_encoder = image_encoder\n        self.text_encoder = text_encoder\n\n        hidden_size = text_encoder.config.hidden_size\n        self.cls_head = nn.Sequential(\n            nn.Linear(hidden_size, hidden_size),\n            nn.ReLU(),\n            nn.Linear(hidden_size, num_classes),\n        )\n\n    def forward(self, samples, is_train=True):\n        \"\"\"\n        Forward function for training and evaluation.\n\n        Args:\n            samples (dict): a dict of input samples, which contains the following keys:\n                - image0 (torch.Tensor): input image 0, shape (batch_size, 3, H, W), default H=384, W=384.\n                - image1 (torch.Tensor): input image 1, shape (batch_size, 3, H, W), default H=384, W=384.\n                - text_input (list): list of strings, each string is a natural language sentence.\n                - label (torch.LongTensor): ground truth label with shape (batch_size,).\n            is_train (bool): whether the model is in training mode.\n                If True, the model will return the loss;\n                If False, the model will return the prediction.\n\n        Examples:\n            >>> import torch\n            >>> from lavis.models import load_model\n            >>> model = load_model(\"blip_nlvr\", \"nlvr\")\n            >>> samples = {\n            ...     \"image0\": torch.randn(2, 3, 384, 384),\n            ...     \"image1\": torch.randn(2, 3, 384, 384),\n            ...     \"text_input\": [\"there is a ferret in tall grass\", \"there are lips in one of the images\"],\n            ...     \"label\": torch.tensor([0, 1]),\n            ... 
}\n            >>> output = model(samples)\n            >>> output.keys()\n            odict_keys(['intermediate_output', 'loss'])\n        \"\"\"\n        text = samples[\"text_input\"]\n        text = self.tokenizer(text, padding=\"longest\", return_tensors=\"pt\").to(\n            self.device\n        )\n        text.input_ids[:, 0] = self.tokenizer.enc_token_id\n\n        targets = samples[\"label\"]\n\n        image0 = samples[\"image0\"]\n        image1 = samples[\"image1\"]\n        images = torch.cat([image0, image1], dim=0)\n\n        image_embeds = self.visual_encoder.forward_features(images)\n        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(\n            self.device\n        )\n        image0_embeds, image1_embeds = torch.split(image_embeds, targets.size(0))\n\n        encoder_output = self.text_encoder(\n            text.input_ids,\n            attention_mask=text.attention_mask,\n            encoder_hidden_states=[image0_embeds, image1_embeds],\n            encoder_attention_mask=[\n                image_atts[: image0_embeds.size(0)],\n                image_atts[image0_embeds.size(0) :],\n            ],\n            return_dict=True,\n        )\n\n        prediction = self.cls_head(encoder_output.last_hidden_state[:, 0, :])\n\n        if is_train:\n            loss = F.cross_entropy(prediction, targets)\n            # return {\"loss\": loss}\n            return BlipOutput(\n                loss=loss,\n                intermediate_output=BlipIntermediateOutput(\n                    image_embeds=torch.stack([image0_embeds, image1_embeds], dim=0),\n                    encoder_output=encoder_output,\n                ),\n            )\n        else:\n            return {\"predictions\": prediction, \"targets\": targets}\n\n    def predict(self, samples):\n        output = self.forward(samples, is_train=False)\n        return output\n\n    @classmethod\n    def from_config(cls, cfg=None):\n        image_encoder = VisionTransformerEncoder.from_config(cfg)\n\n        # text encoder + multimodal encoder\n        bert_config = BertConfig.from_json_file(get_abs_path(cfg[\"med_config_path\"]))\n        text_encoder = BertModel(config=bert_config, add_pooling_layer=False)\n\n        num_classes = cfg.get(\"num_classes\", 3)\n\n        assert num_classes > 1, \"Invalid number of classes provided, found {}\".format(\n            num_classes\n        )\n\n        model = cls(\n            image_encoder=image_encoder,\n            text_encoder=text_encoder,\n            num_classes=num_classes,\n        )\n\n        model.load_checkpoint_from_config(cfg)\n\n        return model\n\n    def load_from_pretrained(self, url_or_filename):\n        if is_url(url_or_filename):\n            cached_file = download_cached_file(\n                url_or_filename, check_hash=False, progress=True\n            )\n            checkpoint = torch.load(cached_file, map_location=\"cpu\")\n        elif os.path.isfile(url_or_filename):\n            checkpoint = torch.load(url_or_filename, map_location=\"cpu\")\n        else:\n            raise RuntimeError(\"checkpoint url or path is invalid\")\n        state_dict = checkpoint[\"model\"]\n\n        state_dict[\"visual_encoder.pos_embed\"] = interpolate_pos_embed(\n            state_dict[\"visual_encoder.pos_embed\"], self.visual_encoder\n        )\n\n        for key in list(state_dict.keys()):\n            if \"crossattention.self.\" in key:\n                new_key0 = key.replace(\"self\", \"self0\")\n                new_key1 = 
key.replace(\"self\", \"self1\")\n                state_dict[new_key0] = state_dict[key]\n                state_dict[new_key1] = state_dict[key]\n            elif \"crossattention.output.dense.\" in key:\n                new_key0 = key.replace(\"dense\", \"dense0\")\n                new_key1 = key.replace(\"dense\", \"dense1\")\n                state_dict[new_key0] = state_dict[key]\n                state_dict[new_key1] = state_dict[key]\n\n        msg = self.load_state_dict(state_dict, strict=False)\n        print(\"load checkpoint from %s\" % url_or_filename)\n        print(f\"missing keys {msg.missing_keys}\")\n        return msg\n"
  },
  {
    "path": "lavis/models/blip_models/blip_outputs.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nfrom dataclasses import dataclass\nfrom typing import Optional\n\nimport torch\nfrom transformers.modeling_outputs import (\n    ModelOutput,\n    BaseModelOutputWithPoolingAndCrossAttentions,\n    CausalLMOutputWithCrossAttentions,\n)\n\n\n@dataclass\nclass BlipSimilarity(ModelOutput):\n    sim_i2t: torch.FloatTensor = None\n    sim_t2i: torch.FloatTensor = None\n\n    sim_i2t_m: Optional[torch.FloatTensor] = None\n    sim_t2i_m: Optional[torch.FloatTensor] = None\n\n    sim_i2t_targets: Optional[torch.FloatTensor] = None\n    sim_t2i_targets: Optional[torch.FloatTensor] = None\n\n\n@dataclass\nclass BlipIntermediateOutput(ModelOutput):\n    \"\"\"\n    Data class for intermediate outputs of BLIP models.\n\n    image_embeds (torch.FloatTensor): Image embeddings, shape (batch_size, num_patches, embed_dim).\n    text_embeds (torch.FloatTensor): Text embeddings, shape (batch_size, seq_len, embed_dim).\n\n    image_embeds_m (torch.FloatTensor): Image embeddings from momentum visual encoder, shape (batch_size, num_patches, embed_dim).\n    text_embeds_m (torch.FloatTensor): Text embeddings from momentum text encoder, shape (batch_size, seq_len, embed_dim).\n\n    encoder_output (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder.\n    encoder_output_neg (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder for negative pairs.\n\n    decoder_output (CausalLMOutputWithCrossAttentions): output from the image-grounded text decoder.\n    decoder_labels (torch.LongTensor): labels for the captioning loss.\n\n    itm_logits (torch.FloatTensor): logits for the image-text matching loss, shape (batch_size * 3, 2).\n    itm_labels (torch.LongTensor): labels for the image-text matching loss, shape (batch_size * 3,)\n\n    \"\"\"\n\n    # uni-modal features\n    image_embeds: torch.FloatTensor = None\n    text_embeds: Optional[torch.FloatTensor] = None\n\n    image_embeds_m: Optional[torch.FloatTensor] = None\n    text_embeds_m: Optional[torch.FloatTensor] = None\n\n    # intermediate outputs of multimodal encoder\n    encoder_output: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None\n    encoder_output_neg: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None\n\n    itm_logits: Optional[torch.FloatTensor] = None\n    itm_labels: Optional[torch.LongTensor] = None\n\n    # intermediate outputs of multimodal decoder\n    decoder_output: Optional[CausalLMOutputWithCrossAttentions] = None\n    decoder_labels: Optional[torch.LongTensor] = None\n\n\n@dataclass\nclass BlipOutput(ModelOutput):\n    # some finetuned models (e.g. 
BlipVQA) do not compute similarity, thus optional.\n    sims: Optional[BlipSimilarity] = None\n\n    intermediate_output: BlipIntermediateOutput = None\n\n    loss: Optional[torch.FloatTensor] = None\n\n    loss_itc: Optional[torch.FloatTensor] = None\n\n    loss_itm: Optional[torch.FloatTensor] = None\n\n    loss_lm: Optional[torch.FloatTensor] = None\n\n\n@dataclass\nclass BlipOutputWithLogits(BlipOutput):\n    logits: torch.FloatTensor = None\n    logits_m: torch.FloatTensor = None\n\n\n@dataclass\nclass BlipOutputFeatures(ModelOutput):\n    \"\"\"\n    Data class of features from BlipFeatureExtractor.\n\n    Args:\n        image_embeds: (torch.FloatTensor) of shape (batch_size, num_patches+1, embed_dim), optional\n        image_features: (torch.FloatTensor) of shape (batch_size, num_patches+1, feature_dim), optional\n        text_embeds: (torch.FloatTensor) of shape (batch_size, sequence_length+1, embed_dim), optional\n        text_features: (torch.FloatTensor) of shape (batch_size, sequence_length+1, feature_dim), optional\n\n        The first embedding or feature is for the [CLS] token.\n\n        Features are obtained by projecting the corresponding embedding into a normalized low-dimensional space.\n    \"\"\"\n\n    image_embeds: Optional[torch.FloatTensor] = None\n    image_embeds_proj: Optional[torch.FloatTensor] = None\n\n    text_embeds: Optional[torch.FloatTensor] = None\n    text_embeds_proj: Optional[torch.FloatTensor] = None\n\n    multimodal_embeds: Optional[torch.FloatTensor] = None\n"
  },
  {
    "path": "lavis/models/blip_models/blip_pretrain.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nfrom copy import deepcopy\n\nimport torch\nimport torch.nn.functional as F\nfrom lavis.common.registry import registry\nfrom lavis.models.base_model import MomentumDistilationMixin, SharedQueueMixin\nfrom lavis.models.blip_models import tie_encoder_decoder_weights\nfrom lavis.models.blip_models.blip import BlipBase\nfrom lavis.models.blip_models.blip_outputs import (\n    BlipOutput,\n    BlipSimilarity,\n    BlipIntermediateOutput,\n)\nfrom lavis.models.med import XBertEncoder, XBertLMHeadDecoder\nfrom lavis.models.vit import VisionTransformerEncoder\nfrom torch import nn\n\n\n@registry.register_model(\"blip_pretrain\")\nclass BlipPretrain(BlipBase, SharedQueueMixin, MomentumDistilationMixin):\n    \"\"\"\n    BLIP pretrain model.\n\n    Supported model types:\n        - base: BLIP base model before pretraining.\n    \"\"\"\n\n    PRETRAINED_MODEL_CONFIG_DICT = {\n        \"base\": \"configs/models/blip_pretrain_base.yaml\",\n        # \"large\": \"configs/models/blip_pretrain_large.yaml\",\n    }\n\n    def __init__(\n        self,\n        image_encoder,\n        text_encoder,\n        text_decoder,\n        queue_size,\n        alpha=0.4,\n        embed_dim=256,\n        momentum=0.995,\n        tie_enc_dec_weights=True,\n        max_txt_len=30,\n    ):\n        super().__init__()\n\n        self.tokenizer = self.init_tokenizer()\n\n        text_encoder.resize_token_embeddings(len(self.tokenizer))\n        text_decoder.resize_token_embeddings(len(self.tokenizer))\n\n        if tie_enc_dec_weights:\n            tie_encoder_decoder_weights(\n                encoder=text_encoder,\n                decoder=text_decoder.bert,\n                base_model_prefix=\"\",\n                skip_key=\"/attention\",\n            )\n\n        self.visual_encoder = image_encoder\n\n        self.text_encoder = text_encoder\n        self.text_decoder = text_decoder\n\n        # creating projection layers for ITC\n        text_width = text_encoder.config.hidden_size\n        vision_width = image_encoder.vision_width\n\n        self.vision_proj = nn.Linear(vision_width, embed_dim)\n        self.text_proj = nn.Linear(text_width, embed_dim)\n\n        self.itm_head = nn.Linear(text_width, 2)\n\n        # create the momentum encoder\n        self.visual_encoder_m = deepcopy(self.visual_encoder)\n        self.text_encoder_m = deepcopy(self.text_encoder)\n\n        self.vision_proj_m = deepcopy(self.vision_proj)\n        self.text_proj_m = deepcopy(self.text_proj)\n\n        self.model_pairs = [\n            [self.visual_encoder, self.visual_encoder_m],\n            [self.text_encoder, self.text_encoder_m],\n            [self.vision_proj, self.vision_proj_m],\n            [self.text_proj, self.text_proj_m],\n        ]\n        self.copy_params()\n\n        # create the queue\n        self.register_buffer(\"image_queue\", torch.randn(embed_dim, queue_size))\n        self.register_buffer(\"text_queue\", torch.randn(embed_dim, queue_size))\n        self.register_buffer(\"queue_ptr\", torch.zeros(1, dtype=torch.long))\n\n        self.image_queue = nn.functional.normalize(self.image_queue, dim=0)\n        self.text_queue = nn.functional.normalize(self.text_queue, dim=0)\n\n        self.queue_size = queue_size\n        self.momentum = momentum\n        self.temp = 
nn.Parameter(0.07 * torch.ones([]))\n\n        self.alpha = alpha\n        self.max_txt_len = max_txt_len\n\n    def _rampup_factor(self, epoch, iters, num_iters_per_epoch):\n        return min(1, (epoch * num_iters_per_epoch + iters) / (2 * num_iters_per_epoch))\n\n    def forward(self, samples):\n\n        \"\"\"\n        Args:\n            samples (dict): A dictionary containing the following keys:\n                - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W). The input images. Default: H=224, W=224.\n                - text_input (list): A list of length batch_size, each element is a string of text/caption.\n                - epoch (int): The current epoch.\n                - iters (int): The current iteration.\n                - num_iters_per_epoch (int): The number of iterations per epoch.\n\n        Returns:\n            BlipOutput: A BlipOutput object containing loss and intermediate output. See ``lavis.models.blip_models.blip_outputs.BlipOutput`` for more details.\n\n        Examples:\n            >>> import torch\n            >>> from lavis.models import load_model\n            >>> model = load_model(\"blip_pretrain\", \"base\")\n            >>> images = torch.randn(4, 3, 224, 224)\n            >>> text_input = [\"caption of image 1\", \"another caption of image 1\", \"caption of image 2\", \"caption of image 3\"]\n            >>> samples = {\"image\": images, \"text_input\": text_input, \"epoch\": 0, \"iters\": 0, \"num_iters_per_epoch\": 100}\n            >>> output = model(samples)\n            >>> output.keys()\n            odict_keys(['sims', 'intermediate_output', 'loss', 'loss_itc', 'loss_itm', 'loss_lm'])\n\n            >>> output.intermediate_output.keys()\n            odict_keys(['image_embeds', 'text_embeds', 'image_embeds_m', 'text_embeds_m', 'encoder_output', 'encoder_output_neg', 'itm_logits', 'itm_labels', 'decoder_output', 'decoder_labels'])\n            >>> output.intermediate_output.image_embeds.shape\n            >>> # shape: (batch_size, num_patches, embed_dim)\n            torch.Size([4, 197, 768])\n            >>> output.intermediate_output.text_embeds.shape\n            >>> # shape: (batch_size, max_txt_len, embed_dim)\n            torch.Size([4, 30, 768])\n            >>> output.intermediate_output.image_embeds_m.shape\n            >>> # shape: (batch_size, num_patches, embed_dim)\n            torch.Size([4, 197, 768])\n            >>> output.intermediate_output.text_embeds_m.shape\n            >>> # shape: (batch_size, max_txt_len, embed_dim)\n            torch.Size([4, 30, 768])\n            >>> output.intermediate_output.itm_logits.shape\n            >>> # shape: (batch_size * 3, 2)\n            torch.Size([12, 2])\n            >>> output.intermediate_output.itm_labels.shape\n            >>> # shape: (batch_size * 3,)\n            torch.Size([12])\n            >>> output.intermediate_output.encoder_output.last_hidden_state.shape\n            >>> # shape: (batch_size, max_txt_len, embed_dim)\n            torch.Size([4, 30, 768])\n            >>> output.intermediate_output.encoder_output_m.last_hidden_state.shape\n            >>> # shape: (batch_size, max_txt_len, embed_dim)\n            torch.Size([4, 30, 768])\n            >>> output.intermediate_output.decoder_output.logits.shape\n            >>> # shape: (batch_size, max_txt_len, vocab_size)\n            torch.Size([4, 30, 30524])\n            >>> output.intermediate_output.decoder_labels.shape\n            >>> # shape: (batch_size, max_txt_len)\n            torch.Size([4, 30])\n     
   \"\"\"\n\n        image = samples[\"image\"]\n        caption = samples[\"text_input\"]\n\n        alpha = self.alpha * self._rampup_factor(\n            epoch=samples[\"epoch\"],\n            iters=samples[\"iters\"],\n            num_iters_per_epoch=samples[\"num_iters_per_epoch\"],\n        )\n\n        with torch.no_grad():\n            self.temp.clamp_(0.001, 0.5)\n\n        # image embeddings and features\n        image_embeds = self.visual_encoder.forward_features(image)\n        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(\n            image.device\n        )\n        image_feat = F.normalize(self.vision_proj(image_embeds[:, 0, :]), dim=-1)\n\n        text = self.tokenizer(\n            caption,\n            padding=\"max_length\",\n            truncation=True,\n            max_length=self.max_txt_len,\n            return_tensors=\"pt\",\n        ).to(image.device)\n\n        # text embeddings and features\n        text_output = self.text_encoder.forward_text(text)\n        text_embeds = text_output.last_hidden_state\n        text_feat = F.normalize(self.text_proj(text_embeds[:, 0, :]), dim=-1)\n\n        # get momentum features\n        with torch.no_grad():\n            self._momentum_update()\n            image_embeds_m = self.visual_encoder_m(image)\n            image_feat_m = F.normalize(\n                self.vision_proj_m(image_embeds_m[:, 0, :]), dim=-1\n            )\n            image_feat_all = torch.cat(\n                [image_feat_m.t(), self.image_queue.clone().detach()], dim=1\n            )\n\n            text_output_m = self.text_encoder_m.forward_text(text)\n            text_embeds_m = text_output_m.last_hidden_state\n            text_feat_m = F.normalize(self.text_proj_m(text_embeds_m[:, 0, :]), dim=-1)\n            text_feat_all = torch.cat(\n                [text_feat_m.t(), self.text_queue.clone().detach()], dim=1\n            )\n\n            sim_i2t_m = image_feat_m @ text_feat_all / self.temp\n            sim_t2i_m = text_feat_m @ image_feat_all / self.temp\n\n            sim_targets = torch.zeros(sim_i2t_m.size()).to(image.device)\n            sim_targets.fill_diagonal_(1)\n\n            sim_i2t_targets = (\n                alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets\n            )\n            sim_t2i_targets = (\n                alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets\n            )\n\n        sim_i2t = image_feat @ text_feat_all / self.temp\n        sim_t2i = text_feat @ image_feat_all / self.temp\n\n        loss_i2t = -torch.sum(\n            F.log_softmax(sim_i2t, dim=1) * sim_i2t_targets, dim=1\n        ).mean()\n        loss_t2i = -torch.sum(\n            F.log_softmax(sim_t2i, dim=1) * sim_t2i_targets, dim=1\n        ).mean()\n\n        loss_itc = (loss_i2t + loss_t2i) / 2\n\n        self._dequeue_and_enqueue(image_feat_m, text_feat_m)\n\n        # Image-text Matching\n        encoder_input_ids = text.input_ids.clone()\n        encoder_input_ids[:, 0] = self.tokenizer.enc_token_id\n\n        # forward the positve image-text pair\n        bs = image.size(0)\n        output_pos = self.text_encoder(\n            encoder_input_ids,\n            attention_mask=text.attention_mask,\n            encoder_hidden_states=image_embeds,\n            encoder_attention_mask=image_atts,\n            return_dict=True,\n        )\n\n        with torch.no_grad():\n            weights_t2i = F.softmax(sim_t2i[:, :bs], dim=1) + 1e-4\n            weights_t2i.fill_diagonal_(0)\n            
weights_i2t = F.softmax(sim_i2t[:, :bs], dim=1) + 1e-4\n            weights_i2t.fill_diagonal_(0)\n\n        # select a negative image for each text\n        image_embeds_neg = []\n        for b in range(bs):\n            neg_idx = torch.multinomial(weights_t2i[b], 1).item()\n            image_embeds_neg.append(image_embeds[neg_idx])\n        image_embeds_neg = torch.stack(image_embeds_neg, dim=0)\n\n        # select a negative text for each image\n        text_ids_neg = []\n        text_atts_neg = []\n        for b in range(bs):\n            neg_idx = torch.multinomial(weights_i2t[b], 1).item()\n            text_ids_neg.append(encoder_input_ids[neg_idx])\n            text_atts_neg.append(text.attention_mask[neg_idx])\n\n        text_ids_neg = torch.stack(text_ids_neg, dim=0)\n        text_atts_neg = torch.stack(text_atts_neg, dim=0)\n\n        text_ids_all = torch.cat([encoder_input_ids, text_ids_neg], dim=0)\n        text_atts_all = torch.cat([text.attention_mask, text_atts_neg], dim=0)\n\n        image_embeds_all = torch.cat([image_embeds_neg, image_embeds], dim=0)\n        image_atts_all = torch.cat([image_atts, image_atts], dim=0)\n\n        output_neg = self.text_encoder(\n            text_ids_all,\n            attention_mask=text_atts_all,\n            encoder_hidden_states=image_embeds_all,\n            encoder_attention_mask=image_atts_all,\n            return_dict=True,\n        )\n\n        vl_embeddings = torch.cat(\n            [\n                output_pos.last_hidden_state[:, 0, :],\n                output_neg.last_hidden_state[:, 0, :],\n            ],\n            dim=0,\n        )\n        itm_logits = self.itm_head(vl_embeddings)\n\n        itm_labels = torch.cat(\n            [torch.ones(bs, dtype=torch.long), torch.zeros(2 * bs, dtype=torch.long)],\n            dim=0,\n        ).to(image.device)\n        loss_itm = F.cross_entropy(itm_logits, itm_labels)\n\n        # LM\n        decoder_input_ids = text.input_ids.clone()\n        decoder_input_ids[:, 0] = self.tokenizer.bos_token_id\n        decoder_targets = decoder_input_ids.masked_fill(\n            decoder_input_ids == self.tokenizer.pad_token_id, -100\n        )\n\n        decoder_output = self.text_decoder(\n            decoder_input_ids,\n            attention_mask=text.attention_mask,\n            encoder_hidden_states=image_embeds,\n            encoder_attention_mask=image_atts,\n            labels=decoder_targets,\n            return_dict=True,\n        )\n\n        loss_lm = decoder_output.loss\n\n        return BlipOutput(\n            loss=loss_itc + loss_itm + loss_lm,\n            loss_itc=loss_itc,\n            loss_itm=loss_itm,\n            loss_lm=loss_lm,\n            sims=BlipSimilarity(\n                sim_i2t=sim_i2t,\n                sim_t2i=sim_t2i,\n                sim_i2t_m=sim_i2t_m,\n                sim_t2i_m=sim_t2i_m,\n                sim_i2t_targets=sim_i2t_targets,\n                sim_t2i_targets=sim_t2i_targets,\n            ),\n            intermediate_output=BlipIntermediateOutput(\n                image_embeds=image_embeds,\n                text_embeds=text_embeds,\n                image_embeds_m=image_embeds_m,\n                text_embeds_m=text_embeds_m,\n                encoder_output=output_pos,\n                encoder_output_neg=output_neg,\n                itm_logits=itm_logits,\n                itm_labels=itm_labels,\n                decoder_output=decoder_output,\n                decoder_labels=decoder_targets,\n            ),\n        )\n\n    def 
reset_queue_ptr(self):\n        self.queue_ptr = torch.zeros(1, dtype=torch.long)\n\n    @classmethod\n    def from_config(cls, cfg=None):\n        # set from_pretrained=True to load weights for 'bert-base-uncased'\n        image_encoder = VisionTransformerEncoder.from_config(cfg, from_pretrained=True)\n        text_encoder = XBertEncoder.from_config(cfg, from_pretrained=True)\n        text_decoder = XBertLMHeadDecoder.from_config(cfg, from_pretrained=True)\n\n        embed_dim = cfg.get(\"embed_dim\", 256)\n        momentum = cfg.get(\"momentum\", 0.995)\n        alpha = cfg.get(\"alpha\", 0.4)\n        max_txt_len = cfg.get(\"max_txt_len\", 30)\n        queue_size = cfg.get(\"queue_size\", 57600)\n\n        model = cls(\n            image_encoder=image_encoder,\n            text_encoder=text_encoder,\n            text_decoder=text_decoder,\n            embed_dim=embed_dim,\n            queue_size=queue_size,\n            momentum=momentum,\n            alpha=alpha,\n            tie_enc_dec_weights=True,\n            max_txt_len=max_txt_len,\n        )\n\n        # [IMPORTANT] reset the queue pointer to 0.\n        # Otherwise, when updating the last batch in the queue, the batch size and the remaining queue length may be unequal.\n        model.reset_queue_ptr()\n\n        return model\n"
  },
  {
    "path": "lavis/models/blip_models/blip_retrieval.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nfrom copy import deepcopy\n\nimport torch\nimport torch.nn.functional as F\nfrom lavis.common.registry import registry\nfrom lavis.models.albef_models import compute_sim_matrix\nfrom lavis.models.base_model import (\n    MomentumDistilationMixin,\n    SharedQueueMixin,\n    all_gather_with_grad,\n    concat_all_gather,\n)\nfrom lavis.models.blip_models.blip import BlipBase\nfrom lavis.models.blip_models.blip_outputs import (\n    BlipOutput,\n    BlipSimilarity,\n    BlipIntermediateOutput,\n)\nfrom lavis.models.med import XBertEncoder\nfrom lavis.models.vit import VisionTransformerEncoder\nfrom torch import nn\n\n\n@registry.register_model(\"blip_retrieval\")\nclass BlipRetrieval(BlipBase, MomentumDistilationMixin, SharedQueueMixin):\n    \"\"\"\n    BLIP retrieval model.\n\n    Supported model types:\n        - coco: fine-tuned BLIP base model on COCO dataset (Karpathy split).\n        - flickr: fine-tuned BLIP base model on Flickr30k dataset.\n\n    Usage:\n        >>> from lavis.models import load_model\n        >>> model = load_model(\"blip_retrieval\", \"coco\")\n        >>> model = load_model(\"blip_retrieval\", \"flickr\")\n    \"\"\"\n\n    PRETRAINED_MODEL_CONFIG_DICT = {\n        \"coco\": \"configs/models/blip_retrieval_coco.yaml\",\n        \"flickr\": \"configs/models/blip_retrieval_flickr.yaml\",\n    }\n\n    def __init__(\n        self,\n        image_encoder,\n        text_encoder,\n        queue_size,\n        alpha=0.4,\n        embed_dim=256,\n        momentum=0.995,\n        negative_all_rank=False,\n        max_txt_len=35,\n    ):\n        \"\"\" \"\"\"\n        super().__init__()\n\n        self.tokenizer = self.init_tokenizer()\n\n        self.visual_encoder = image_encoder\n\n        self.text_encoder = text_encoder\n\n        # creating projection layers for ITC\n        text_width = text_encoder.config.hidden_size\n        vision_width = image_encoder.vision_width\n\n        self.vision_proj = nn.Linear(vision_width, embed_dim)\n        self.text_proj = nn.Linear(text_width, embed_dim)\n\n        self.itm_head = nn.Linear(text_width, 2)\n\n        # create the momentum encoder\n        self.visual_encoder_m = deepcopy(self.visual_encoder)\n        self.text_encoder_m = deepcopy(self.text_encoder)\n\n        self.vision_proj_m = deepcopy(self.vision_proj)\n        self.text_proj_m = deepcopy(self.text_proj)\n\n        self.model_pairs = [\n            [self.visual_encoder, self.visual_encoder_m],\n            [self.text_encoder, self.text_encoder_m],\n            [self.vision_proj, self.vision_proj_m],\n            [self.text_proj, self.text_proj_m],\n        ]\n        self.copy_params()\n\n        # create the queue\n        self.register_buffer(\"image_queue\", torch.randn(embed_dim, queue_size))\n        self.register_buffer(\"text_queue\", torch.randn(embed_dim, queue_size))\n        self.register_buffer(\"idx_queue\", torch.full((1, queue_size), -100))\n        self.register_buffer(\"queue_ptr\", torch.zeros(1, dtype=torch.long))\n\n        self.image_queue = nn.functional.normalize(self.image_queue, dim=0)\n        self.text_queue = nn.functional.normalize(self.text_queue, dim=0)\n\n        self.queue_size = queue_size\n        self.momentum = momentum\n        self.temp = nn.Parameter(0.07 * 
torch.ones([]))\n\n        self.alpha = alpha\n        self.max_txt_len = max_txt_len\n\n        self.negative_all_rank = negative_all_rank\n\n    def _rampup_factor(self, epoch, iters, num_iters_per_epoch):\n        return min(1, (epoch * num_iters_per_epoch + iters) / (2 * num_iters_per_epoch))\n\n    def forward(self, samples):\n        \"\"\"\n        Args:\n            samples (dict): A dictionary containing the following keys:\n                - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W). The input images.\n                - text_input (list): A list of length batch_size, each element is a string of text/caption.\n                - image_id (torch.Tensor): A tensor of shape (batch_size, ). The image ids, used to identify same images in batch.\n                - epoch (int): The current epoch.\n                - iters (int): The current iteration.\n                - num_iters_per_epoch (int): The number of iterations per epoch.\n\n        Returns:\n            BlipOutput: A BlipOutput object. See ``lavis.models.blip_models.blip_outputs.BlipOutput`` for more details.\n\n        Examples:\n            >>> import torch\n            >>> from lavis.models import load_model\n            >>> model = load_model(\"blip_retrieval\", \"coco\")\n            >>> images = torch.randn(4, 3, 384, 384)\n            >>> text_input = [\"caption of image 1\", \"another caption of image 1\", \"caption of image 2\", \"caption of image 3\"]\n            >>> image_id = torch.tensor([1, 1, 2, 3])\n            >>> samples = {\"image\": images, \"text_input\": text_input, \"image_id\": image_id, \"epoch\": 0, \"iters\": 0, \"num_iters_per_epoch\": 100}\n            >>> output = model(samples)\n            >>> output.keys()\n            odict_keys(['sims', 'intermediate_output', 'loss', 'loss_itc', 'loss_itm'])\n        \"\"\"\n        image = samples[\"image\"]\n        caption = samples[\"text_input\"]\n        idx = samples[\"image_id\"]\n\n        alpha = self.alpha * self._rampup_factor(\n            epoch=samples[\"epoch\"],\n            iters=samples[\"iters\"],\n            num_iters_per_epoch=samples[\"num_iters_per_epoch\"],\n        )\n\n        with torch.no_grad():\n            self.temp.clamp_(0.001, 0.5)\n\n        image_embeds = self.visual_encoder.forward_features(image)\n        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(\n            image.device\n        )\n        image_feat = F.normalize(self.vision_proj(image_embeds[:, 0, :]), dim=-1)\n\n        text = self.tokenizer(\n            caption,\n            padding=\"max_length\",\n            truncation=True,\n            max_length=self.max_txt_len,\n            return_tensors=\"pt\",\n        ).to(image.device)\n\n        text_output = self.text_encoder.forward_text(text)\n        text_embeds = text_output.last_hidden_state\n        text_feat = F.normalize(self.text_proj(text_embeds[:, 0, :]), dim=-1)\n\n        # Image-text Contrastive Learning\n        idx = idx.view(-1, 1)\n        idx_all = torch.cat([idx.t(), self.idx_queue.clone().detach()], dim=1)\n        pos_idx = torch.eq(idx, idx_all).float()\n        sim_targets = pos_idx / pos_idx.sum(1, keepdim=True)\n\n        # get momentum features\n        with torch.no_grad():\n            self._momentum_update()\n            image_embeds_m = self.visual_encoder_m(image)\n            image_feat_m = F.normalize(\n                self.vision_proj_m(image_embeds_m[:, 0, :]), dim=-1\n            )\n            image_feat_m_all = torch.cat(\n       
         [image_feat_m.t(), self.image_queue.clone().detach()], dim=1\n            )\n\n            text_output_m = self.text_encoder_m.forward_text(text)\n            text_embeds_m = text_output_m.last_hidden_state\n            text_feat_m = F.normalize(self.text_proj_m(text_embeds_m[:, 0, :]), dim=-1)\n            text_feat_m_all = torch.cat(\n                [text_feat_m.t(), self.text_queue.clone().detach()], dim=1\n            )\n\n            sim_i2t_m = image_feat_m @ text_feat_m_all / self.temp\n            sim_t2i_m = text_feat_m @ image_feat_m_all / self.temp\n\n            sim_i2t_targets = (\n                alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets\n            )\n            sim_t2i_targets = (\n                alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets\n            )\n\n        sim_i2t = image_feat @ text_feat_m_all / self.temp\n        sim_t2i = text_feat @ image_feat_m_all / self.temp\n\n        loss_i2t = -torch.sum(\n            F.log_softmax(sim_i2t, dim=1) * sim_i2t_targets, dim=1\n        ).mean()\n        loss_t2i = -torch.sum(\n            F.log_softmax(sim_t2i, dim=1) * sim_t2i_targets, dim=1\n        ).mean()\n\n        loss_itc = (loss_i2t + loss_t2i) / 2\n\n        self._dequeue_and_enqueue(image_feat_m, text_feat_m, idx)\n\n        # Image-text Matching\n        encoder_input_ids = text.input_ids.clone()\n        encoder_input_ids[:, 0] = self.tokenizer.enc_token_id\n\n        # forward the positive image-text pair\n        bs = image.size(0)\n        output_pos = self.text_encoder(\n            encoder_input_ids,\n            attention_mask=text.attention_mask,\n            encoder_hidden_states=image_embeds,\n            encoder_attention_mask=image_atts,\n            return_dict=True,\n        )\n\n        idxs = concat_all_gather(idx)\n        if self.negative_all_rank:\n            # compute sample similarity\n            with torch.no_grad():\n                mask = torch.eq(idx, idxs.t())\n\n                image_feat_world = concat_all_gather(image_feat)\n                text_feat_world = concat_all_gather(text_feat)\n\n                sim_i2t = image_feat @ text_feat_world.t() / self.temp\n                sim_t2i = text_feat @ image_feat_world.t() / self.temp\n\n                weights_i2t = F.softmax(sim_i2t, dim=1)\n                weights_i2t.masked_fill_(mask, 0)\n\n                weights_t2i = F.softmax(sim_t2i, dim=1)\n                weights_t2i.masked_fill_(mask, 0)\n\n            image_embeds_world = all_gather_with_grad(image_embeds)\n\n            # select a negative image (from all ranks) for each text\n            image_embeds_neg = []\n            for b in range(bs):\n                neg_idx = torch.multinomial(weights_t2i[b], 1).item()\n                image_embeds_neg.append(image_embeds_world[neg_idx])\n            image_embeds_neg = torch.stack(image_embeds_neg, dim=0)\n\n            # select a negative text (from all ranks) for each image\n            input_ids_world = concat_all_gather(encoder_input_ids)\n            att_mask_world = concat_all_gather(text.attention_mask)\n\n            text_ids_neg = []\n            text_atts_neg = []\n            for b in range(bs):\n                neg_idx = torch.multinomial(weights_i2t[b], 1).item()\n                text_ids_neg.append(input_ids_world[neg_idx])\n                text_atts_neg.append(att_mask_world[neg_idx])\n\n        else:\n            with torch.no_grad():\n                mask = torch.eq(idx, idx.t())\n\n                sim_i2t = 
image_feat @ text_feat.t() / self.temp\n                sim_t2i = text_feat @ image_feat.t() / self.temp\n\n                weights_i2t = F.softmax(sim_i2t, dim=1)\n                weights_i2t.masked_fill_(mask, 0)\n\n                weights_t2i = F.softmax(sim_t2i, dim=1)\n                weights_t2i.masked_fill_(mask, 0)\n\n            # select a negative image (from same rank) for each text\n            image_embeds_neg = []\n            for b in range(bs):\n                neg_idx = torch.multinomial(weights_t2i[b], 1).item()\n                image_embeds_neg.append(image_embeds[neg_idx])\n            image_embeds_neg = torch.stack(image_embeds_neg, dim=0)\n\n            # select a negative text (from same rank) for each image\n            text_ids_neg = []\n            text_atts_neg = []\n            for b in range(bs):\n                neg_idx = torch.multinomial(weights_i2t[b], 1).item()\n                text_ids_neg.append(encoder_input_ids[neg_idx])\n                text_atts_neg.append(text.attention_mask[neg_idx])\n\n        text_ids_neg = torch.stack(text_ids_neg, dim=0)\n        text_atts_neg = torch.stack(text_atts_neg, dim=0)\n\n        text_ids_all = torch.cat([encoder_input_ids, text_ids_neg], dim=0)\n        text_atts_all = torch.cat([text.attention_mask, text_atts_neg], dim=0)\n\n        image_embeds_all = torch.cat([image_embeds_neg, image_embeds], dim=0)\n        image_atts_all = torch.cat([image_atts, image_atts], dim=0)\n\n        output_neg = self.text_encoder(\n            text_ids_all,\n            attention_mask=text_atts_all,\n            encoder_hidden_states=image_embeds_all,\n            encoder_attention_mask=image_atts_all,\n            return_dict=True,\n        )\n\n        vl_embeddings = torch.cat(\n            [\n                output_pos.last_hidden_state[:, 0, :],\n                output_neg.last_hidden_state[:, 0, :],\n            ],\n            dim=0,\n        )\n        itm_logits = self.itm_head(vl_embeddings)\n\n        itm_labels = torch.cat(\n            [torch.ones(bs, dtype=torch.long), torch.zeros(2 * bs, dtype=torch.long)],\n            dim=0,\n        ).to(self.device)\n        loss_itm = F.cross_entropy(itm_logits, itm_labels)\n\n        return BlipOutput(\n            loss=loss_itc + loss_itm,\n            loss_itc=loss_itc,\n            loss_itm=loss_itm,\n            sims=BlipSimilarity(\n                sim_i2t=sim_i2t,\n                sim_t2i=sim_t2i,\n                sim_i2t_m=sim_i2t_m,\n                sim_t2i_m=sim_t2i_m,\n                sim_i2t_targets=sim_i2t_targets,\n                sim_t2i_targets=sim_t2i_targets,\n            ),\n            intermediate_output=BlipIntermediateOutput(\n                image_embeds=image_embeds,\n                image_embeds_m=image_embeds_m,\n                text_embeds=text_embeds,\n                text_embeds_m=text_embeds_m,\n                encoder_output=output_pos,\n                encoder_output_neg=output_neg,\n                itm_logits=itm_logits,\n                itm_labels=itm_labels,\n            ),\n        )\n\n    def reset_queue_ptr(self):\n        self.queue_ptr = torch.zeros(1, dtype=torch.long)\n\n    @classmethod\n    def from_config(cls, cfg=None):\n        # set from_pretrained=True to load weights for 'bert-base-uncased'\n        image_encoder = VisionTransformerEncoder.from_config(cfg)\n        text_encoder = XBertEncoder.from_config(cfg)\n\n        embed_dim = cfg.get(\"embed_dim\", 256)\n        momentum = cfg.get(\"momentum\", 0.995)\n        alpha = 
cfg.get(\"alpha\", 0.4)\n        negative_all_rank = cfg.get(\"negative_all_rank\", False)\n\n        queue_size = cfg.get(\"queue_size\", 0)\n        max_txt_len = cfg.get(\"max_txt_len\", 35)\n\n        model = cls(\n            image_encoder=image_encoder,\n            text_encoder=text_encoder,\n            queue_size=queue_size,\n            alpha=alpha,\n            embed_dim=embed_dim,\n            momentum=momentum,\n            negative_all_rank=negative_all_rank,\n            max_txt_len=max_txt_len,\n        )\n\n        model.load_checkpoint_from_config(cfg)\n        model.reset_queue_ptr()\n\n        return model\n\n    def compute_sim_matrix(self, data_loader, task_cfg):\n        \"\"\"\n        Compute similarity i2t, t2i matrix for the given data loader.\n        \"\"\"\n        k_test = task_cfg.k_test\n\n        return compute_sim_matrix(model=self, data_loader=data_loader, k_test=k_test)\n"
  },
  {
    "path": "lavis/models/blip_models/blip_vqa.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport torch\nimport torch.nn.functional as F\nfrom lavis.common.registry import registry\nfrom lavis.models.base_model import tile\nfrom lavis.models.blip_models.blip import BlipBase\nfrom lavis.models.blip_models.blip_outputs import (\n    BlipOutput,\n    BlipIntermediateOutput,\n)\nfrom lavis.models.med import XBertEncoder, XBertLMHeadDecoder\nfrom lavis.models.vit import VisionTransformerEncoder\n\n\n@registry.register_model(\"blip_vqa\")\nclass BlipVQA(BlipBase):\n    \"\"\"\n    BLIP VQA models.\n\n    Supported model types:\n        - base: vqa model initialized with pre-trained BLIP base model on 115M image-text pairs after CapFilt; not fine-tuned.\n        - vqav2: fine-tuned BLIP base model on VQA v2.0 dataset.\n\n    Usage:\n        >>> from lavis.models import load_model\n        >>> model = load_model(\"blip_vqa\", \"vqav2\")\n        >>> model = load_model(\"blip_vqa\", \"okvqa\")\n        >>> model = load_model(\"blip_vqa\", \"aokvqa\")\n    \"\"\"\n\n    PRETRAINED_MODEL_CONFIG_DICT = {\n        \"vqav2\": \"configs/models/blip_vqav2.yaml\",\n        \"okvqa\": \"configs/models/blip_vqa_okvqa.yaml\",\n        \"aokvqa\": \"configs/models/blip_vqa_aokvqa.yaml\",\n    }\n\n    def __init__(self, image_encoder, text_encoder, text_decoder, max_txt_len=35):\n        super().__init__()\n        self.tokenizer = self.init_tokenizer()\n\n        self.visual_encoder = image_encoder\n\n        self.text_encoder = text_encoder\n        self.text_decoder = text_decoder\n\n        self.max_txt_len = max_txt_len\n\n    def forward(self, samples):\n        \"\"\"\n        Args:\n            samples (dict): A dictionary containing the following keys:\n                - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W). Default H=480, W=480.\n                - text_input (list): A list of strings, each string is a question\n                - answer (list): A list of strings, each string is an answer\n                - weight (torch.Tensor): A tensor used to weigh each answer in the loss computation.\n                   The shape of the tensor is (sum(n_answers),)\n                - n_answers (torch.Tensor): A tensor shape (batch_size,) containing the number of answers\n                     for each question in the batch.\n\n        Returns:\n            A BlipOutput object containing loss and intermediate outputs,\n            see :class:`lavis.models.blip_outputs.BlipOutput` for more details.\n\n        Examples:\n        ```python\n            >>> import torch\n            >>> from lavis.models import load_model\n            >>> model = load_model(\"blip_vqa\")\n            >>> samples = {\n            ...     \"image\": torch.rand(2, 3, 480, 480),\n            ...     \"text_input\": [\"What is this?\", \"What is that?\"],\n            ...     \"answer\": [\"cat\", \"cat\", \"dog\"],\n            ...     \"weight\": torch.tensor([1.0, 1.0, 1.0]),\n            ...     \"n_answers\": torch.tensor([2, 1]),\n            ... 
}\n            >>> output = model(samples)\n            >>> output.keys()\n            odict_keys(['intermediate_output', 'loss'])\n            >>> output.intermediate_output.keys()\n            odict_keys(['image_embeds', 'encoder_output', 'decoder_output', 'decoder_labels'])\n        ```\n        \"\"\"\n        encoder_output, image_embeds = self.forward_encoder(samples)\n        loss, decoder_output, decoder_targets = self.forward_decoder(\n            samples=samples, encoder_out=encoder_output\n        )\n\n        return BlipOutput(\n            loss=loss,\n            intermediate_output=BlipIntermediateOutput(\n                image_embeds=image_embeds,\n                encoder_output=encoder_output,\n                decoder_output=decoder_output,\n                decoder_labels=decoder_targets,\n            ),\n        )\n\n    def forward_encoder(self, samples):\n        questions = samples[\"text_input\"]\n        questions = self.tokenizer(\n            questions,\n            padding=\"longest\",\n            truncation=True,\n            max_length=self.max_txt_len,\n            return_tensors=\"pt\",\n        ).to(self.device)\n        questions.input_ids[:, 0] = self.tokenizer.enc_token_id\n        samples.update({\"tokenized_text\": questions})\n\n        image_embeds = self.visual_encoder.forward_features(samples[\"image\"])\n        encoder_output = self.text_encoder.forward_automask(\n            tokenized_text=samples[\"tokenized_text\"], visual_embeds=image_embeds\n        )\n\n        return encoder_output, image_embeds\n\n    def forward_decoder(self, samples, encoder_out, **kwargs):\n        answers = self.tokenizer(\n            samples[\"answer\"], padding=\"longest\", return_tensors=\"pt\"\n        ).to(self.device)\n        answers.input_ids[:, 0] = self.tokenizer.bos_token_id\n        answer_targets = answers.input_ids.masked_fill(\n            answers.input_ids == self.tokenizer.pad_token_id, -100\n        )\n\n        question_states = []\n        question_atts = []\n\n        question = samples[\"tokenized_text\"]\n        question_output = encoder_out\n\n        for b, n in enumerate(samples[\"n_answers\"]):\n            question_states += [question_output.last_hidden_state[b]] * n\n            question_atts += [question.attention_mask[b]] * n\n\n        question_states = torch.stack(question_states, dim=0)\n        question_atts = torch.stack(question_atts, dim=0)\n\n        answer_output = self.text_decoder(\n            answers.input_ids,\n            attention_mask=answers.attention_mask,\n            encoder_hidden_states=question_states,\n            encoder_attention_mask=question_atts,\n            labels=answer_targets,\n            return_dict=True,\n            reduction=\"none\",\n        )\n\n        loss = samples[\"weight\"] * answer_output.loss\n        bsz = samples[\"image\"].size(0)\n\n        loss = loss.sum() / bsz\n\n        return loss, answer_output, answer_targets\n\n    def predict_answers(\n        self,\n        samples,\n        num_beams=3,\n        inference_method=\"rank\",\n        max_len=10,\n        min_len=1,\n        num_ans_candidates=128,\n        answer_list=None,\n        **kwargs\n    ):\n        \"\"\"\n        Args:\n            samples (dict): A dictionary containing the following keys:\n                - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W). 
Default H=480, W=480.\n                - text_input (str or [str]): String or a list of strings, each string is a question.\n                                             The number of questions must be equal to the batch size. If a single string is provided, it will first be converted to a list of length 1.\n            num_beams (int): Number of beams for beam search. 1 means no beam search.\n            inference_method (str): Inference method. One of \"rank\", \"generate\".\n                - If \"rank\", the model will return answers with the highest probability from the answer list.\n                - If \"generate\", the model will generate answers.\n            max_len (int): Maximum length of generated answers.\n            min_len (int): Minimum length of generated answers.\n            num_ans_candidates (int): Number of answer candidates, used to filter out answers with low probability.\n            answer_list (list): A list of strings, each string is an answer.\n\n        Returns:\n            List: A list of strings, each string is an answer.\n\n        Examples:\n        ```python\n            >>> from PIL import Image\n            >>> from lavis.models import load_model_and_preprocess\n            >>> model, vis_processors, txt_processors = load_model_and_preprocess(\"blip_vqa\", \"vqav2\")\n            >>> raw_image = Image.open(\"docs/data/merlion.png\").convert(\"RGB\")\n            >>> question = \"Which city is this photo taken?\"\n            >>> image = vis_processors[\"eval\"](raw_image).unsqueeze(0)\n            >>> question = txt_processors[\"eval\"](question)\n            >>> samples = {\"image\": image, \"text_input\": [question]}\n            >>> answers = model.predict_answers(samples)\n            >>> answers\n            ['singapore']\n            >>> answer_list = [\"Singapore\", \"London\", \"Palo Alto\", \"Tokyo\"]\n            >>> answers = model.predict_answers(samples, answer_list=answer_list)\n            >>> answers\n            ['Singapore']\n        ```\n        \"\"\"\n        assert inference_method in [\n            \"rank\",\n            \"generate\",\n        ], \"Inference method must be one of 'rank' or 'generate', got {}.\".format(\n            inference_method\n        )\n\n        if isinstance(samples[\"text_input\"], str):\n            samples[\"text_input\"] = [samples[\"text_input\"]]\n\n        assert len(samples[\"text_input\"]) == samples[\"image\"].size(\n            0\n        ), \"The number of questions must be equal to the batch size.\"\n\n        if inference_method == \"generate\":\n            return self._generate_answers(\n                samples, num_beams=num_beams, max_length=max_len, min_length=min_len\n            )\n        elif inference_method == \"rank\":\n            assert answer_list is not None, \"answer_list must be provided for ranking\"\n\n            num_ans_candidates = min(num_ans_candidates, len(answer_list))\n\n            return self._rank_answers(\n                samples, answer_list=answer_list, num_ans_candidates=num_ans_candidates\n            )\n\n    def _generate_answers(self, samples, num_beams=3, max_length=10, min_length=1):\n        encoder_out, _ = self.forward_encoder(samples)\n\n        question_output = encoder_out\n\n        question_states = question_output.last_hidden_state.repeat_interleave(\n            num_beams, dim=0\n        )\n        question_atts = torch.ones(question_states.size()[:-1], dtype=torch.long).to(\n            self.device\n        )\n\n        model_kwargs = {\n  
          \"encoder_hidden_states\": question_states,\n            \"encoder_attention_mask\": question_atts,\n        }\n\n        bsz = samples[\"image\"].size(0)\n        bos_ids = torch.full(\n            (bsz, 1), fill_value=self.tokenizer.bos_token_id, device=self.device\n        )\n\n        outputs = self.text_decoder.generate(\n            input_ids=bos_ids,\n            max_length=max_length,\n            min_length=min_length,\n            num_beams=num_beams,\n            eos_token_id=self.tokenizer.sep_token_id,\n            pad_token_id=self.tokenizer.pad_token_id,\n            **model_kwargs\n        )\n\n        # collect answers\n        answers = []\n        for output in outputs:\n            answer = self.tokenizer.decode(output, skip_special_tokens=True)\n            answers.append(answer)\n\n        return answers\n\n    def _rank_answers(self, samples, answer_list, num_ans_candidates):\n        \"\"\"\n        Generate the first token of answers using decoder and select ${num_ans_candidates}\n        most probable ones. Then select answers from answer list, which start with the probable tokens.\n        Lastly, use the selected answers as the ground-truth labels for decoding and calculating LM loss.\n        Return the answers that minimize the losses as result.\n\n        \"\"\"\n        answer_candidates = self.tokenizer(\n            answer_list, padding=\"longest\", return_tensors=\"pt\"\n        ).to(self.device)\n        answer_candidates.input_ids[:, 0] = self.tokenizer.bos_token_id\n\n        answer_ids = answer_candidates.input_ids\n        answer_atts = answer_candidates.attention_mask\n\n        question_output, _ = self.forward_encoder(samples)\n        question_states = question_output.last_hidden_state\n\n        tokenized_question = samples[\"tokenized_text\"]\n        question_atts = tokenized_question.attention_mask\n\n        num_ques = question_states.size(0)\n        start_ids = answer_ids[0, 0].repeat(num_ques, 1)  # bos token\n\n        start_output = self.text_decoder(\n            start_ids,\n            encoder_hidden_states=question_states,\n            encoder_attention_mask=question_atts,\n            return_dict=True,\n            reduction=\"none\",\n        )\n        logits = start_output.logits[:, 0, :]  # first token's logit\n\n        # topk_probs: top-k probability\n        # topk_ids: [num_question, k]\n        answer_first_token = answer_ids[:, 1]\n        prob_first_token = F.softmax(logits, dim=1).index_select(\n            dim=1, index=answer_first_token\n        )\n        topk_probs, topk_ids = prob_first_token.topk(num_ans_candidates, dim=1)\n\n        # answer input: [num_question*k, answer_len]\n        input_ids = []\n        input_atts = []\n        for b, topk_id in enumerate(topk_ids):\n            input_ids.append(answer_ids.index_select(dim=0, index=topk_id))\n            input_atts.append(answer_atts.index_select(dim=0, index=topk_id))\n        input_ids = torch.cat(input_ids, dim=0)\n        input_atts = torch.cat(input_atts, dim=0)\n\n        targets_ids = input_ids.masked_fill(\n            input_ids == self.tokenizer.pad_token_id, -100\n        )\n\n        # repeat encoder's output for top-k answers\n        question_states = tile(question_states, 0, num_ans_candidates)\n        question_atts = tile(question_atts, 0, num_ans_candidates)\n\n        output = self.text_decoder(\n            input_ids,\n            attention_mask=input_atts,\n            encoder_hidden_states=question_states,\n            
encoder_attention_mask=question_atts,\n            labels=targets_ids,\n            return_dict=True,\n            reduction=\"none\",\n        )\n\n        log_probs_sum = -output.loss\n        log_probs_sum = log_probs_sum.view(num_ques, num_ans_candidates)\n\n        max_topk_ids = log_probs_sum.argmax(dim=1)\n        max_ids = topk_ids[max_topk_ids >= 0, max_topk_ids]\n\n        answers = [answer_list[max_id] for max_id in max_ids]\n\n        return answers\n\n    @classmethod\n    def from_config(cls, cfg=None):\n        image_encoder = VisionTransformerEncoder.from_config(cfg)\n\n        # text encoder + multimodal encoder\n        text_encoder = XBertEncoder.from_config(cfg)\n        text_decoder = XBertLMHeadDecoder.from_config(cfg)\n\n        max_txt_len = cfg.get(\"max_txt_len\", 35)\n\n        model = cls(\n            image_encoder=image_encoder,\n            text_encoder=text_encoder,\n            text_decoder=text_decoder,\n            max_txt_len=max_txt_len,\n        )\n\n        model.load_checkpoint_from_config(cfg)\n\n        return model\n"
  },
  {
    "path": "lavis/models/blip_models/nlvr_encoder.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport math\nfrom typing import Tuple\n\nimport torch\nimport torch.utils.checkpoint\nfrom torch import Tensor, device, nn\nfrom transformers.activations import ACT2FN\nfrom transformers.modeling_outputs import (\n    BaseModelOutputWithPastAndCrossAttentions,\n    BaseModelOutputWithPoolingAndCrossAttentions,\n)\nfrom transformers.modeling_utils import (\n    PreTrainedModel,\n    apply_chunking_to_forward,\n    find_pruneable_heads_and_indices,\n    prune_linear_layer,\n)\nfrom transformers.models.bert.configuration_bert import BertConfig\nfrom transformers.utils import logging\n\nlogger = logging.get_logger(__name__)\n\n\nclass BertEmbeddings(nn.Module):\n    \"\"\"Construct the embeddings from word and position embeddings.\"\"\"\n\n    def __init__(self, config):\n        super().__init__()\n        self.word_embeddings = nn.Embedding(\n            config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id\n        )\n        self.position_embeddings = nn.Embedding(\n            config.max_position_embeddings, config.hidden_size\n        )\n\n        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load\n        # any TensorFlow checkpoint file\n        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)\n        self.dropout = nn.Dropout(config.hidden_dropout_prob)\n\n        # position_ids (1, len position emb) is contiguous in memory and exported when serialized\n        self.register_buffer(\n            \"position_ids\", torch.arange(config.max_position_embeddings).expand((1, -1))\n        )\n        self.position_embedding_type = getattr(\n            config, \"position_embedding_type\", \"absolute\"\n        )\n\n        self.config = config\n\n    def forward(\n        self,\n        input_ids=None,\n        position_ids=None,\n        inputs_embeds=None,\n        past_key_values_length=0,\n    ):\n        if input_ids is not None:\n            input_shape = input_ids.size()\n        else:\n            input_shape = inputs_embeds.size()[:-1]\n\n        seq_length = input_shape[1]\n\n        if position_ids is None:\n            position_ids = self.position_ids[\n                :, past_key_values_length : seq_length + past_key_values_length\n            ]\n\n        if inputs_embeds is None:\n            inputs_embeds = self.word_embeddings(input_ids)\n\n        embeddings = inputs_embeds\n\n        if self.position_embedding_type == \"absolute\":\n            position_embeddings = self.position_embeddings(position_ids)\n            embeddings += position_embeddings\n        embeddings = self.LayerNorm(embeddings)\n        embeddings = self.dropout(embeddings)\n        return embeddings\n\n\nclass BertSelfAttention(nn.Module):\n    def __init__(self, config, is_cross_attention):\n        super().__init__()\n        self.config = config\n        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(\n            config, \"embedding_size\"\n        ):\n            raise ValueError(\n                \"The hidden size (%d) is not a multiple of the number of attention \"\n                \"heads (%d)\" % (config.hidden_size, config.num_attention_heads)\n            )\n\n        self.num_attention_heads = config.num_attention_heads\n     
   self.attention_head_size = int(config.hidden_size / config.num_attention_heads)\n        self.all_head_size = self.num_attention_heads * self.attention_head_size\n\n        self.query = nn.Linear(config.hidden_size, self.all_head_size)\n        if is_cross_attention:\n            self.key = nn.Linear(config.encoder_width, self.all_head_size)\n            self.value = nn.Linear(config.encoder_width, self.all_head_size)\n        else:\n            self.key = nn.Linear(config.hidden_size, self.all_head_size)\n            self.value = nn.Linear(config.hidden_size, self.all_head_size)\n\n        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)\n        self.position_embedding_type = getattr(\n            config, \"position_embedding_type\", \"absolute\"\n        )\n        if (\n            self.position_embedding_type == \"relative_key\"\n            or self.position_embedding_type == \"relative_key_query\"\n        ):\n            self.max_position_embeddings = config.max_position_embeddings\n            self.distance_embedding = nn.Embedding(\n                2 * config.max_position_embeddings - 1, self.attention_head_size\n            )\n        self.save_attention = False\n\n    def save_attn_gradients(self, attn_gradients):\n        self.attn_gradients = attn_gradients\n\n    def get_attn_gradients(self):\n        return self.attn_gradients\n\n    def save_attention_map(self, attention_map):\n        self.attention_map = attention_map\n\n    def get_attention_map(self):\n        return self.attention_map\n\n    def transpose_for_scores(self, x):\n        new_x_shape = x.size()[:-1] + (\n            self.num_attention_heads,\n            self.attention_head_size,\n        )\n        x = x.view(*new_x_shape)\n        return x.permute(0, 2, 1, 3)\n\n    def forward(\n        self,\n        hidden_states,\n        attention_mask=None,\n        head_mask=None,\n        encoder_hidden_states=None,\n        encoder_attention_mask=None,\n        past_key_value=None,\n        output_attentions=False,\n    ):\n        mixed_query_layer = self.query(hidden_states)\n\n        # If this is instantiated as a cross-attention module, the keys\n        # and values come from an encoder; the attention mask needs to be\n        # such that the encoder's padding tokens are not attended to.\n        is_cross_attention = encoder_hidden_states is not None\n\n        if is_cross_attention:\n            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))\n            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))\n            attention_mask = encoder_attention_mask\n        elif past_key_value is not None:\n            key_layer = self.transpose_for_scores(self.key(hidden_states))\n            value_layer = self.transpose_for_scores(self.value(hidden_states))\n            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)\n            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)\n        else:\n            key_layer = self.transpose_for_scores(self.key(hidden_states))\n            value_layer = self.transpose_for_scores(self.value(hidden_states))\n\n        query_layer = self.transpose_for_scores(mixed_query_layer)\n\n        past_key_value = (key_layer, value_layer)\n\n        # Take the dot product between \"query\" and \"key\" to get the raw attention scores.\n        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))\n\n        if (\n            self.position_embedding_type == 
\"relative_key\"\n            or self.position_embedding_type == \"relative_key_query\"\n        ):\n            seq_length = hidden_states.size()[1]\n            position_ids_l = torch.arange(\n                seq_length, dtype=torch.long, device=hidden_states.device\n            ).view(-1, 1)\n            position_ids_r = torch.arange(\n                seq_length, dtype=torch.long, device=hidden_states.device\n            ).view(1, -1)\n            distance = position_ids_l - position_ids_r\n            positional_embedding = self.distance_embedding(\n                distance + self.max_position_embeddings - 1\n            )\n            positional_embedding = positional_embedding.to(\n                dtype=query_layer.dtype\n            )  # fp16 compatibility\n\n            if self.position_embedding_type == \"relative_key\":\n                relative_position_scores = torch.einsum(\n                    \"bhld,lrd->bhlr\", query_layer, positional_embedding\n                )\n                attention_scores = attention_scores + relative_position_scores\n            elif self.position_embedding_type == \"relative_key_query\":\n                relative_position_scores_query = torch.einsum(\n                    \"bhld,lrd->bhlr\", query_layer, positional_embedding\n                )\n                relative_position_scores_key = torch.einsum(\n                    \"bhrd,lrd->bhlr\", key_layer, positional_embedding\n                )\n                attention_scores = (\n                    attention_scores\n                    + relative_position_scores_query\n                    + relative_position_scores_key\n                )\n\n        attention_scores = attention_scores / math.sqrt(self.attention_head_size)\n        if attention_mask is not None:\n            # Apply the attention mask is (precomputed for all layers in BertModel forward() function)\n            attention_scores = attention_scores + attention_mask\n\n        # Normalize the attention scores to probabilities.\n        attention_probs = nn.Softmax(dim=-1)(attention_scores)\n\n        if is_cross_attention and self.save_attention:\n            self.save_attention_map(attention_probs)\n            attention_probs.register_hook(self.save_attn_gradients)\n\n        # This is actually dropping out entire tokens to attend to, which might\n        # seem a bit unusual, but is taken from the original Transformer paper.\n        attention_probs_dropped = self.dropout(attention_probs)\n\n        # Mask heads if we want to\n        if head_mask is not None:\n            attention_probs_dropped = attention_probs_dropped * head_mask\n\n        context_layer = torch.matmul(attention_probs_dropped, value_layer)\n\n        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()\n        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)\n        context_layer = context_layer.view(*new_context_layer_shape)\n\n        outputs = (\n            (context_layer, attention_probs) if output_attentions else (context_layer,)\n        )\n\n        outputs = outputs + (past_key_value,)\n        return outputs\n\n\nclass BertSelfOutput(nn.Module):\n    def __init__(self, config, twin=False, merge=False):\n        super().__init__()\n        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)\n        self.dropout = nn.Dropout(config.hidden_dropout_prob)\n        if twin:\n            self.dense0 = nn.Linear(config.hidden_size, config.hidden_size)\n            self.dense1 = 
nn.Linear(config.hidden_size, config.hidden_size)\n        else:\n            self.dense = nn.Linear(config.hidden_size, config.hidden_size)\n        if merge:\n            self.act = ACT2FN[config.hidden_act]\n            self.merge_layer = nn.Linear(config.hidden_size * 2, config.hidden_size)\n            self.merge = True\n        else:\n            self.merge = False\n\n    def forward(self, hidden_states, input_tensor):\n        if type(hidden_states) == list:\n            hidden_states0 = self.dense0(hidden_states[0])\n            hidden_states1 = self.dense1(hidden_states[1])\n            if self.merge:\n                # hidden_states = self.merge_layer(self.act(torch.cat([hidden_states0,hidden_states1],dim=-1)))\n                hidden_states = self.merge_layer(\n                    torch.cat([hidden_states0, hidden_states1], dim=-1)\n                )\n            else:\n                hidden_states = (hidden_states0 + hidden_states1) / 2\n        else:\n            hidden_states = self.dense(hidden_states)\n        hidden_states = self.dropout(hidden_states)\n        hidden_states = self.LayerNorm(hidden_states + input_tensor)\n        return hidden_states\n\n\nclass BertAttention(nn.Module):\n    def __init__(self, config, is_cross_attention=False, layer_num=-1):\n        super().__init__()\n        if is_cross_attention:\n            self.self0 = BertSelfAttention(config, is_cross_attention)\n            self.self1 = BertSelfAttention(config, is_cross_attention)\n        else:\n            self.self = BertSelfAttention(config, is_cross_attention)\n        self.output = BertSelfOutput(\n            config,\n            twin=is_cross_attention,\n            merge=(is_cross_attention and layer_num >= 6),\n        )\n        self.pruned_heads = set()\n\n    def prune_heads(self, heads):\n        if len(heads) == 0:\n            return\n        heads, index = find_pruneable_heads_and_indices(\n            heads,\n            self.self.num_attention_heads,\n            self.self.attention_head_size,\n            self.pruned_heads,\n        )\n\n        # Prune linear layers\n        self.self.query = prune_linear_layer(self.self.query, index)\n        self.self.key = prune_linear_layer(self.self.key, index)\n        self.self.value = prune_linear_layer(self.self.value, index)\n        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)\n\n        # Update hyper params and store pruned heads\n        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)\n        self.self.all_head_size = (\n            self.self.attention_head_size * self.self.num_attention_heads\n        )\n        self.pruned_heads = self.pruned_heads.union(heads)\n\n    def forward(\n        self,\n        hidden_states,\n        attention_mask=None,\n        head_mask=None,\n        encoder_hidden_states=None,\n        encoder_attention_mask=None,\n        past_key_value=None,\n        output_attentions=False,\n    ):\n        if type(encoder_hidden_states) == list:\n            self_outputs0 = self.self0(\n                hidden_states,\n                attention_mask,\n                head_mask,\n                encoder_hidden_states[0],\n                encoder_attention_mask[0],\n                past_key_value,\n                output_attentions,\n            )\n            self_outputs1 = self.self1(\n                hidden_states,\n                attention_mask,\n                head_mask,\n                encoder_hidden_states[1],\n                
encoder_attention_mask[1],\n                past_key_value,\n                output_attentions,\n            )\n            attention_output = self.output(\n                [self_outputs0[0], self_outputs1[0]], hidden_states\n            )\n\n            outputs = (attention_output,) + self_outputs0[\n                1:\n            ]  # add attentions if we output them\n        else:\n            self_outputs = self.self(\n                hidden_states,\n                attention_mask,\n                head_mask,\n                encoder_hidden_states,\n                encoder_attention_mask,\n                past_key_value,\n                output_attentions,\n            )\n            attention_output = self.output(self_outputs[0], hidden_states)\n            outputs = (attention_output,) + self_outputs[\n                1:\n            ]  # add attentions if we output them\n        return outputs\n\n\nclass BertIntermediate(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)\n        if isinstance(config.hidden_act, str):\n            self.intermediate_act_fn = ACT2FN[config.hidden_act]\n        else:\n            self.intermediate_act_fn = config.hidden_act\n\n    def forward(self, hidden_states):\n        hidden_states = self.dense(hidden_states)\n        hidden_states = self.intermediate_act_fn(hidden_states)\n        return hidden_states\n\n\nclass BertOutput(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)\n        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)\n        self.dropout = nn.Dropout(config.hidden_dropout_prob)\n\n    def forward(self, hidden_states, input_tensor):\n        hidden_states = self.dense(hidden_states)\n        hidden_states = self.dropout(hidden_states)\n        hidden_states = self.LayerNorm(hidden_states + input_tensor)\n        return hidden_states\n\n\nclass BertLayer(nn.Module):\n    def __init__(self, config, layer_num):\n        super().__init__()\n        self.config = config\n        self.chunk_size_feed_forward = config.chunk_size_feed_forward\n        self.seq_len_dim = 1\n        self.attention = BertAttention(config)\n        self.layer_num = layer_num\n        if self.config.add_cross_attention:\n            self.crossattention = BertAttention(\n                config,\n                is_cross_attention=self.config.add_cross_attention,\n                layer_num=layer_num,\n            )\n        self.intermediate = BertIntermediate(config)\n        self.output = BertOutput(config)\n\n    def forward(\n        self,\n        hidden_states,\n        attention_mask=None,\n        head_mask=None,\n        encoder_hidden_states=None,\n        encoder_attention_mask=None,\n        past_key_value=None,\n        output_attentions=False,\n        mode=None,\n    ):\n        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2\n        self_attn_past_key_value = (\n            past_key_value[:2] if past_key_value is not None else None\n        )\n        self_attention_outputs = self.attention(\n            hidden_states,\n            attention_mask,\n            head_mask,\n            output_attentions=output_attentions,\n            past_key_value=self_attn_past_key_value,\n        )\n        attention_output = self_attention_outputs[0]\n\n        outputs = self_attention_outputs[1:-1]\n 
       present_key_value = self_attention_outputs[-1]\n\n        if mode == \"multimodal\":\n            assert (\n                encoder_hidden_states is not None\n            ), \"encoder_hidden_states must be given for cross-attention layers\"\n            cross_attention_outputs = self.crossattention(\n                attention_output,\n                attention_mask,\n                head_mask,\n                encoder_hidden_states,\n                encoder_attention_mask,\n                output_attentions=output_attentions,\n            )\n            attention_output = cross_attention_outputs[0]\n            outputs = (\n                outputs + cross_attention_outputs[1:-1]\n            )  # add cross attentions if we output attention weights\n        layer_output = apply_chunking_to_forward(\n            self.feed_forward_chunk,\n            self.chunk_size_feed_forward,\n            self.seq_len_dim,\n            attention_output,\n        )\n        outputs = (layer_output,) + outputs\n\n        outputs = outputs + (present_key_value,)\n\n        return outputs\n\n    def feed_forward_chunk(self, attention_output):\n        intermediate_output = self.intermediate(attention_output)\n        layer_output = self.output(intermediate_output, attention_output)\n        return layer_output\n\n\nclass BertEncoder(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.config = config\n        self.layer = nn.ModuleList(\n            [BertLayer(config, i) for i in range(config.num_hidden_layers)]\n        )\n        self.gradient_checkpointing = False\n\n    def forward(\n        self,\n        hidden_states,\n        attention_mask=None,\n        head_mask=None,\n        encoder_hidden_states=None,\n        encoder_attention_mask=None,\n        past_key_values=None,\n        use_cache=None,\n        output_attentions=False,\n        output_hidden_states=False,\n        return_dict=True,\n        mode=\"multimodal\",\n    ):\n        all_hidden_states = () if output_hidden_states else None\n        all_self_attentions = () if output_attentions else None\n        all_cross_attentions = (\n            () if output_attentions and self.config.add_cross_attention else None\n        )\n\n        next_decoder_cache = () if use_cache else None\n\n        for i in range(self.config.num_hidden_layers):\n            layer_module = self.layer[i]\n            if output_hidden_states:\n                all_hidden_states = all_hidden_states + (hidden_states,)\n\n            layer_head_mask = head_mask[i] if head_mask is not None else None\n            past_key_value = past_key_values[i] if past_key_values is not None else None\n\n            if self.gradient_checkpointing and self.training:\n\n                if use_cache:\n                    logger.warn(\n                        \"`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`...\"\n                    )\n                    use_cache = False\n\n                def create_custom_forward(module):\n                    def custom_forward(*inputs):\n                        return module(*inputs, past_key_value, output_attentions)\n\n                    return custom_forward\n\n                layer_outputs = torch.utils.checkpoint.checkpoint(\n                    create_custom_forward(layer_module),\n                    hidden_states,\n                    attention_mask,\n                    layer_head_mask,\n                    encoder_hidden_states,\n                    encoder_attention_mask,\n                    mode=mode,\n                )\n            else:\n                layer_outputs = layer_module(\n                    hidden_states,\n                    attention_mask,\n                    layer_head_mask,\n                    encoder_hidden_states,\n                    encoder_attention_mask,\n                    past_key_value,\n                    output_attentions,\n                    mode=mode,\n                )\n\n            hidden_states = layer_outputs[0]\n            if use_cache:\n                next_decoder_cache += (layer_outputs[-1],)\n            if output_attentions:\n                all_self_attentions = all_self_attentions + (layer_outputs[1],)\n\n        if output_hidden_states:\n            all_hidden_states = all_hidden_states + (hidden_states,)\n\n        if not return_dict:\n            return tuple(\n                v\n                for v in [\n                    hidden_states,\n                    next_decoder_cache,\n                    all_hidden_states,\n                    all_self_attentions,\n                    all_cross_attentions,\n                ]\n                if v is not None\n            )\n        return BaseModelOutputWithPastAndCrossAttentions(\n            last_hidden_state=hidden_states,\n            past_key_values=next_decoder_cache,\n            hidden_states=all_hidden_states,\n            attentions=all_self_attentions,\n            cross_attentions=all_cross_attentions,\n        )\n\n\nclass BertPooler(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.dense = nn.Linear(config.hidden_size, config.hidden_size)\n        self.activation = nn.Tanh()\n\n    def forward(self, hidden_states):\n        # We \"pool\" the model by simply taking the hidden state corresponding\n        # to the first token.\n        first_token_tensor = hidden_states[:, 0]\n        pooled_output = self.dense(first_token_tensor)\n        pooled_output = self.activation(pooled_output)\n        return pooled_output\n\n\nclass BertPredictionHeadTransform(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.dense = nn.Linear(config.hidden_size, config.hidden_size)\n        if isinstance(config.hidden_act, str):\n            self.transform_act_fn = ACT2FN[config.hidden_act]\n        else:\n            self.transform_act_fn = config.hidden_act\n        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)\n\n    def forward(self, hidden_states):\n        hidden_states = self.dense(hidden_states)\n        hidden_states = self.transform_act_fn(hidden_states)\n        hidden_states = self.LayerNorm(hidden_states)\n        return hidden_states\n\n\nclass BertLMPredictionHead(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.transform = BertPredictionHeadTransform(config)\n\n        
# The output weights are the same as the input embeddings, but there is\n        # an output-only bias for each token.\n        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n\n        self.bias = nn.Parameter(torch.zeros(config.vocab_size))\n\n        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`\n        self.decoder.bias = self.bias\n\n    def forward(self, hidden_states):\n        hidden_states = self.transform(hidden_states)\n        hidden_states = self.decoder(hidden_states)\n        return hidden_states\n\n\nclass BertOnlyMLMHead(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.predictions = BertLMPredictionHead(config)\n\n    def forward(self, sequence_output):\n        prediction_scores = self.predictions(sequence_output)\n        return prediction_scores\n\n\nclass BertPreTrainedModel(PreTrainedModel):\n    \"\"\"\n    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained\n    models.\n    \"\"\"\n\n    config_class = BertConfig\n    base_model_prefix = \"bert\"\n    _keys_to_ignore_on_load_missing = [r\"position_ids\"]\n\n    def _init_weights(self, module):\n        \"\"\"Initialize the weights\"\"\"\n        if isinstance(module, (nn.Linear, nn.Embedding)):\n            # Slightly different from the TF version which uses truncated_normal for initialization\n            # cf https://github.com/pytorch/pytorch/pull/5617\n            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)\n        elif isinstance(module, nn.LayerNorm):\n            module.bias.data.zero_()\n            module.weight.data.fill_(1.0)\n        if isinstance(module, nn.Linear) and module.bias is not None:\n            module.bias.data.zero_()\n\n\nclass BertModel(BertPreTrainedModel):\n    \"\"\"\n    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of\n    cross-attention is added between the self-attention layers, following the architecture described in `Attention is\n    all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,\n    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.\n    argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an\n    input to the forward pass.\n    \"\"\"\n\n    def __init__(self, config, add_pooling_layer=True):\n        super().__init__(config)\n        self.config = config\n\n        self.embeddings = BertEmbeddings(config)\n\n        self.encoder = BertEncoder(config)\n\n        self.pooler = BertPooler(config) if add_pooling_layer else None\n\n        self.init_weights()\n\n    def get_input_embeddings(self):\n        return self.embeddings.word_embeddings\n\n    def set_input_embeddings(self, value):\n        self.embeddings.word_embeddings = value\n\n    def _prune_heads(self, heads_to_prune):\n        \"\"\"\n        Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base\n        class PreTrainedModel\n        \"\"\"\n        for layer, heads in heads_to_prune.items():\n            self.encoder.layer[layer].attention.prune_heads(heads)\n\n    def get_extended_attention_mask(\n        self,\n        attention_mask: Tensor,\n        input_shape: Tuple[int],\n        device: device,\n        is_decoder: bool,\n    ) -> Tensor:\n        \"\"\"\n        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.\n\n        Arguments:\n            attention_mask (:obj:`torch.Tensor`):\n                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.\n            input_shape (:obj:`Tuple[int]`):\n                The shape of the input to the model.\n            device: (:obj:`torch.device`):\n                The device of the input to the model.\n\n        Returns:\n            :obj:`torch.Tensor` The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`.\n        \"\"\"\n        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]\n        # ourselves in which case we just need to make it broadcastable to all heads.\n        if attention_mask.dim() == 3:\n            extended_attention_mask = attention_mask[:, None, :, :]\n        elif attention_mask.dim() == 2:\n            # Provided a padding mask of dimensions [batch_size, seq_length]\n            # - if the model is a decoder, apply a causal mask in addition to the padding mask\n            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]\n            if is_decoder:\n                batch_size, seq_length = input_shape\n\n                seq_ids = torch.arange(seq_length, device=device)\n                causal_mask = (\n                    seq_ids[None, None, :].repeat(batch_size, seq_length, 1)\n                    <= seq_ids[None, :, None]\n                )\n                # in case past_key_values are used we need to add a prefix ones mask to the causal mask\n                # causal and attention masks must have same type with pytorch version < 1.3\n                causal_mask = causal_mask.to(attention_mask.dtype)\n\n                if causal_mask.shape[1] < attention_mask.shape[1]:\n                    prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]\n                    causal_mask = torch.cat(\n                        [\n                            torch.ones(\n                                (batch_size, seq_length, prefix_seq_len),\n                                device=device,\n                                dtype=causal_mask.dtype,\n                            ),\n                            causal_mask,\n                        ],\n                        axis=-1,\n                    )\n\n                extended_attention_mask = (\n                    causal_mask[:, None, :, :] * attention_mask[:, None, None, :]\n                )\n            else:\n                extended_attention_mask = attention_mask[:, None, None, :]\n        else:\n            raise ValueError(\n                \"Wrong shape for input_ids (shape {}) or attention_mask (shape {})\".format(\n                    input_shape, attention_mask.shape\n                )\n            )\n\n        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for\n        # masked positions, this operation will create a tensor which is 
0.0 for\n        # positions we want to attend and -10000.0 for masked positions.\n        # Since we are adding it to the raw scores before the softmax, this is\n        # effectively the same as removing these entirely.\n        extended_attention_mask = extended_attention_mask.to(\n            dtype=self.dtype\n        )  # fp16 compatibility\n        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0\n        return extended_attention_mask\n\n    def forward(\n        self,\n        input_ids=None,\n        attention_mask=None,\n        position_ids=None,\n        head_mask=None,\n        inputs_embeds=None,\n        encoder_embeds=None,\n        encoder_hidden_states=None,\n        encoder_attention_mask=None,\n        past_key_values=None,\n        use_cache=None,\n        output_attentions=None,\n        output_hidden_states=None,\n        return_dict=None,\n        is_decoder=False,\n        mode=\"multimodal\",\n    ):\n        r\"\"\"\n        encoder_hidden_states  (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):\n            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if\n            the model is configured as a decoder.\n        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):\n            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in\n            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:\n            - 1 for tokens that are **not masked**,\n            - 0 for tokens that are **masked**.\n        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):\n            Contains precomputed key and value hidden states of the attention blocks. 
Can be used to speed up decoding.\n            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`\n            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`\n            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.\n        use_cache (:obj:`bool`, `optional`):\n            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up\n            decoding (see :obj:`past_key_values`).\n        \"\"\"\n        output_attentions = (\n            output_attentions\n            if output_attentions is not None\n            else self.config.output_attentions\n        )\n        output_hidden_states = (\n            output_hidden_states\n            if output_hidden_states is not None\n            else self.config.output_hidden_states\n        )\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        if is_decoder:\n            use_cache = use_cache if use_cache is not None else self.config.use_cache\n        else:\n            use_cache = False\n\n        if input_ids is not None and inputs_embeds is not None:\n            raise ValueError(\n                \"You cannot specify both input_ids and inputs_embeds at the same time\"\n            )\n        elif input_ids is not None:\n            input_shape = input_ids.size()\n            batch_size, seq_length = input_shape\n            device = input_ids.device\n        elif inputs_embeds is not None:\n            input_shape = inputs_embeds.size()[:-1]\n            batch_size, seq_length = input_shape\n            device = inputs_embeds.device\n        elif encoder_embeds is not None:\n            input_shape = encoder_embeds.size()[:-1]\n            batch_size, seq_length = input_shape\n            device = encoder_embeds.device\n        else:\n            raise ValueError(\n                \"You have to specify either input_ids or inputs_embeds or encoder_embeds\"\n            )\n\n        # past_key_values_length\n        past_key_values_length = (\n            past_key_values[0][0].shape[2] if past_key_values is not None else 0\n        )\n\n        if attention_mask is None:\n            attention_mask = torch.ones(\n                ((batch_size, seq_length + past_key_values_length)), device=device\n            )\n\n        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]\n        # ourselves in which case we just need to make it broadcastable to all heads.\n        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(\n            attention_mask, input_shape, device, is_decoder\n        )\n\n        # If a 2D or 3D attention mask is provided for the cross-attention\n        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]\n        if encoder_hidden_states is not None:\n            if type(encoder_hidden_states) == list:\n                encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[\n                    0\n                ].size()\n            else:\n                (\n                    encoder_batch_size,\n                    encoder_sequence_length,\n                    _,\n                ) = encoder_hidden_states.size()\n            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)\n\n            if 
type(encoder_attention_mask) == list:\n                encoder_extended_attention_mask = [\n                    self.invert_attention_mask(mask) for mask in encoder_attention_mask\n                ]\n            elif encoder_attention_mask is None:\n                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)\n                encoder_extended_attention_mask = self.invert_attention_mask(\n                    encoder_attention_mask\n                )\n            else:\n                encoder_extended_attention_mask = self.invert_attention_mask(\n                    encoder_attention_mask\n                )\n        else:\n            encoder_extended_attention_mask = None\n\n        # Prepare head mask if needed\n        # 1.0 in head_mask indicate we keep the head\n        # attention_probs has shape bsz x n_heads x N x N\n        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]\n        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]\n        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)\n\n        if encoder_embeds is None:\n            embedding_output = self.embeddings(\n                input_ids=input_ids,\n                position_ids=position_ids,\n                inputs_embeds=inputs_embeds,\n                past_key_values_length=past_key_values_length,\n            )\n        else:\n            embedding_output = encoder_embeds\n\n        encoder_outputs = self.encoder(\n            embedding_output,\n            attention_mask=extended_attention_mask,\n            head_mask=head_mask,\n            encoder_hidden_states=encoder_hidden_states,\n            encoder_attention_mask=encoder_extended_attention_mask,\n            past_key_values=past_key_values,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n            mode=mode,\n        )\n        sequence_output = encoder_outputs[0]\n        pooled_output = (\n            self.pooler(sequence_output) if self.pooler is not None else None\n        )\n\n        if not return_dict:\n            return (sequence_output, pooled_output) + encoder_outputs[1:]\n\n        return BaseModelOutputWithPoolingAndCrossAttentions(\n            last_hidden_state=sequence_output,\n            pooler_output=pooled_output,\n            past_key_values=encoder_outputs.past_key_values,\n            hidden_states=encoder_outputs.hidden_states,\n            attentions=encoder_outputs.attentions,\n            cross_attentions=encoder_outputs.cross_attentions,\n        )\n"
  },
  {
    "path": "lavis/models/clip_models/__init__.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\n Based on https://github.com/mlfoundations/open_clip\n\"\"\"\n\n\"\"\" OpenAI pretrained model functions\nAdapted from https://github.com/mlfoundations/open_clip and https://github.com/openai/CLIP.\n\nOriginally MIT License, Copyright (c) 2021 OpenAI.\n\"\"\"\n"
  },
  {
    "path": "lavis/models/clip_models/clip_outputs.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\n Based on https://github.com/mlfoundations/open_clip\n\"\"\"\n\nfrom dataclasses import dataclass\n\nfrom typing import Optional\n\nimport torch\nfrom transformers.modeling_outputs import ModelOutput\n\n\n@dataclass\nclass ClipOutputFeatures(ModelOutput):\n    \"\"\"\n    Data class of features from AlbefFeatureExtractor.\n\n    Args:\n        image_embeds: `torch.FloatTensor` of shape `(batch_size, 1, embed_dim)`, `optional`\n        image_features: `torch.FloatTensor` of shape `(batch_size, 1, feature_dim)`, `optional`\n        text_embeds: `torch.FloatTensor` of shape `(batch_size, 1, embed_dim)`, `optional`\n        text_features: `torch.FloatTensor` of shape `(batch_size, 1, feature_dim)`, `optional`\n    \"\"\"\n\n    image_embeds: Optional[torch.FloatTensor] = None\n    image_embeds_proj: Optional[torch.FloatTensor] = None\n\n    text_embeds: Optional[torch.FloatTensor] = None\n    text_embeds_proj: Optional[torch.FloatTensor] = None\n\n\n@dataclass\nclass ClipOutput(ModelOutput):\n    intermediate_output: Optional[ClipOutputFeatures] = None\n\n    logit_scale_exp: Optional[torch.FloatTensor] = None\n\n    loss: Optional[torch.FloatTensor] = None\n"
  },
  {
    "path": "lavis/models/clip_models/loss.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport logging\nimport torch\nimport torch.distributed.nn\nfrom torch import distributed as dist, nn as nn\nfrom torch.nn import functional as F\n\ntry:\n    import horovod.torch as hvd\nexcept ImportError:\n    hvd = None\n\n\ndef gather_features(\n    image_features,\n    text_features,\n    local_loss=False,\n    gather_with_grad=False,\n    rank=0,\n    world_size=1,\n    use_horovod=False,\n):\n    if use_horovod:\n        assert hvd is not None, \"Please install horovod\"\n        if gather_with_grad:\n            all_image_features = hvd.allgather(image_features)\n            all_text_features = hvd.allgather(text_features)\n        else:\n            with torch.no_grad():\n                all_image_features = hvd.allgather(image_features)\n                all_text_features = hvd.allgather(text_features)\n            if not local_loss:\n                # ensure grads for local rank when all_* features don't have a gradient\n                gathered_image_features = list(\n                    all_image_features.chunk(world_size, dim=0)\n                )\n                gathered_text_features = list(\n                    all_text_features.chunk(world_size, dim=0)\n                )\n                gathered_image_features[rank] = image_features\n                gathered_text_features[rank] = text_features\n                all_image_features = torch.cat(gathered_image_features, dim=0)\n                all_text_features = torch.cat(gathered_text_features, dim=0)\n    else:\n        # We gather tensors from all gpus\n        if gather_with_grad:\n            all_image_features = torch.cat(\n                torch.distributed.nn.all_gather(image_features), dim=0\n            )\n            all_text_features = torch.cat(\n                torch.distributed.nn.all_gather(text_features), dim=0\n            )\n        else:\n            gathered_image_features = [\n                torch.zeros_like(image_features) for _ in range(world_size)\n            ]\n            gathered_text_features = [\n                torch.zeros_like(text_features) for _ in range(world_size)\n            ]\n            dist.all_gather(gathered_image_features, image_features)\n            dist.all_gather(gathered_text_features, text_features)\n            if not local_loss:\n                # ensure grads for local rank when all_* features don't have a gradient\n                gathered_image_features[rank] = image_features\n                gathered_text_features[rank] = text_features\n            all_image_features = torch.cat(gathered_image_features, dim=0)\n            all_text_features = torch.cat(gathered_text_features, dim=0)\n\n    return all_image_features, all_text_features\n\n\nclass ClipLoss(nn.Module):\n    def __init__(\n        self,\n        local_loss=False,\n        gather_with_grad=False,\n        cache_labels=False,\n        rank=0,\n        world_size=1,\n        use_horovod=False,\n    ):\n        super().__init__()\n        self.local_loss = local_loss\n        self.gather_with_grad = gather_with_grad\n        self.cache_labels = cache_labels\n        self.rank = rank\n        self.world_size = world_size\n        self.use_horovod = use_horovod\n\n        # cache state\n        self.prev_num_logits = 0\n        self.labels = {}\n\n    def 
forward(self, image_features, text_features, logit_scale):\n        device = image_features.device\n        if self.world_size > 1:\n            all_image_features, all_text_features = gather_features(\n                image_features,\n                text_features,\n                self.local_loss,\n                self.gather_with_grad,\n                self.rank,\n                self.world_size,\n                self.use_horovod,\n            )\n\n            if self.local_loss:\n                logits_per_image = logit_scale * image_features @ all_text_features.T\n                logits_per_text = logit_scale * text_features @ all_image_features.T\n            else:\n                logits_per_image = (\n                    logit_scale * all_image_features @ all_text_features.T\n                )\n                logits_per_text = logits_per_image.T\n        else:\n            logits_per_image = logit_scale * image_features @ text_features.T\n            logits_per_text = logit_scale * text_features @ image_features.T\n\n        # calculate ground-truth labels and cache them if enabled\n        num_logits = logits_per_image.shape[0]\n        if self.prev_num_logits != num_logits or device not in self.labels:\n            labels = torch.arange(num_logits, device=device, dtype=torch.long)\n            if self.world_size > 1 and self.local_loss:\n                labels = labels + num_logits * self.rank\n            if self.cache_labels:\n                self.labels[device] = labels\n                self.prev_num_logits = num_logits\n        else:\n            labels = self.labels[device]\n\n        total_loss = (\n            F.cross_entropy(logits_per_image, labels)\n            + F.cross_entropy(logits_per_text, labels)\n        ) / 2\n        return total_loss\n"
  },
  {
    "path": "lavis/models/clip_models/model.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\n Based on https://github.com/mlfoundations/open_clip\n\"\"\"\n\n\"\"\" CLIP Model\nAdapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.\n\"\"\"\n\nimport datetime\nimport json\nimport logging\nimport os\nimport re\nimport time\nimport warnings\nfrom collections import OrderedDict\nfrom copy import deepcopy\nfrom dataclasses import dataclass\nfrom pathlib import Path\nfrom typing import Callable, List, Optional, Tuple, Union\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom lavis.common.registry import registry\nfrom lavis.common.utils import get_abs_path\nfrom lavis.models.base_model import BaseModel\nfrom lavis.models.clip_models.clip_outputs import ClipOutput, ClipOutputFeatures\nfrom lavis.models.clip_models.timm_model import TimmModel\nfrom lavis.models.clip_models.transform import image_transform\nfrom lavis.models.clip_models.utils import freeze_batch_norm_2d\nfrom lavis.tasks.multimodal_classification import MultimodalClassificationTask\nfrom torch import nn\n\nfrom .pretrained import (\n    download_pretrained,\n    get_pretrained_url,\n    list_pretrained_tag_models,\n)\n\n_MODEL_CONFIG_PATHS = [Path(__file__).parent.parent.parent / f\"configs/models/clip/\"]\n_MODEL_CONFIGS = {}  # directory (model_name: config) of model architecture configs\n\n\nclass Bottleneck(nn.Module):\n    expansion = 4\n\n    def __init__(self, inplanes, planes, stride=1):\n        super().__init__()\n\n        # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1\n        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)\n        self.bn1 = nn.BatchNorm2d(planes)\n\n        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)\n        self.bn2 = nn.BatchNorm2d(planes)\n\n        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()\n\n        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)\n        self.bn3 = nn.BatchNorm2d(planes * self.expansion)\n\n        self.relu = nn.ReLU(inplace=True)\n        self.downsample = None\n        self.stride = stride\n\n        if stride > 1 or inplanes != planes * Bottleneck.expansion:\n            # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1\n            self.downsample = nn.Sequential(\n                OrderedDict(\n                    [\n                        (\"-1\", nn.AvgPool2d(stride)),\n                        (\n                            \"0\",\n                            nn.Conv2d(\n                                inplanes,\n                                planes * self.expansion,\n                                1,\n                                stride=1,\n                                bias=False,\n                            ),\n                        ),\n                        (\"1\", nn.BatchNorm2d(planes * self.expansion)),\n                    ]\n                )\n            )\n\n    def forward(self, x: torch.Tensor):\n        identity = x\n\n        out = self.relu(self.bn1(self.conv1(x)))\n        out = self.relu(self.bn2(self.conv2(out)))\n        out = self.avgpool(out)\n        out = self.bn3(self.conv3(out))\n\n        if self.downsample is not None:\n            identity = 
self.downsample(x)\n\n        out += identity\n        out = self.relu(out)\n        return out\n\n\nclass AttentionPool2d(nn.Module):\n    def __init__(\n        self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None\n    ):\n        super().__init__()\n        self.positional_embedding = nn.Parameter(\n            torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5\n        )\n        self.k_proj = nn.Linear(embed_dim, embed_dim)\n        self.q_proj = nn.Linear(embed_dim, embed_dim)\n        self.v_proj = nn.Linear(embed_dim, embed_dim)\n        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)\n        self.num_heads = num_heads\n\n    def forward(self, x):\n        x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(\n            2, 0, 1\n        )  # NCHW -> (HW)NC\n        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC\n        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC\n        x, _ = F.multi_head_attention_forward(\n            query=x,\n            key=x,\n            value=x,\n            embed_dim_to_check=x.shape[-1],\n            num_heads=self.num_heads,\n            q_proj_weight=self.q_proj.weight,\n            k_proj_weight=self.k_proj.weight,\n            v_proj_weight=self.v_proj.weight,\n            in_proj_weight=None,\n            in_proj_bias=torch.cat(\n                [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]\n            ),\n            bias_k=None,\n            bias_v=None,\n            add_zero_attn=False,\n            dropout_p=0,\n            out_proj_weight=self.c_proj.weight,\n            out_proj_bias=self.c_proj.bias,\n            use_separate_proj_weight=True,\n            training=self.training,\n            need_weights=False,\n        )\n\n        return x[0]\n\n\nclass ModifiedResNet(nn.Module):\n    \"\"\"\n    A ResNet class that is similar to torchvision's but contains the following changes:\n    - There are now 3 \"stem\" convolutions as opposed to 1, with an average pool instead of a max pool.\n    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1\n    - The final pooling layer is a QKV attention instead of an average pool\n    \"\"\"\n\n    def __init__(self, layers, output_dim, heads, image_size=224, width=64):\n        super().__init__()\n        self.output_dim = output_dim\n        self.image_size = image_size\n\n        # the 3-layer stem\n        self.conv1 = nn.Conv2d(\n            3, width // 2, kernel_size=3, stride=2, padding=1, bias=False\n        )\n        self.bn1 = nn.BatchNorm2d(width // 2)\n        self.conv2 = nn.Conv2d(\n            width // 2, width // 2, kernel_size=3, padding=1, bias=False\n        )\n        self.bn2 = nn.BatchNorm2d(width // 2)\n        self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)\n        self.bn3 = nn.BatchNorm2d(width)\n        self.avgpool = nn.AvgPool2d(2)\n        self.relu = nn.ReLU(inplace=True)\n\n        # residual layers\n        self._inplanes = width  # this is a *mutable* variable used during construction\n        self.layer1 = self._make_layer(width, layers[0])\n        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)\n        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)\n        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)\n\n        embed_dim = width * 32  # the ResNet feature dimension\n        self.attnpool = 
AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)\n\n        self.init_parameters()\n\n    def _make_layer(self, planes, blocks, stride=1):\n        layers = [Bottleneck(self._inplanes, planes, stride)]\n\n        self._inplanes = planes * Bottleneck.expansion\n        for _ in range(1, blocks):\n            layers.append(Bottleneck(self._inplanes, planes))\n\n        return nn.Sequential(*layers)\n\n    def init_parameters(self):\n        if self.attnpool is not None:\n            std = self.attnpool.c_proj.in_features**-0.5\n            nn.init.normal_(self.attnpool.q_proj.weight, std=std)\n            nn.init.normal_(self.attnpool.k_proj.weight, std=std)\n            nn.init.normal_(self.attnpool.v_proj.weight, std=std)\n            nn.init.normal_(self.attnpool.c_proj.weight, std=std)\n\n        for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:\n            for name, param in resnet_block.named_parameters():\n                if name.endswith(\"bn3.weight\"):\n                    nn.init.zeros_(param)\n\n    def lock(self, unlocked_groups=0, freeze_bn_stats=False):\n        assert (\n            unlocked_groups == 0\n        ), \"partial locking not currently supported for this model\"\n        for param in self.parameters():\n            param.requires_grad = False\n        if freeze_bn_stats:\n            freeze_batch_norm_2d(self)\n\n    def stem(self, x):\n        for conv, bn in [\n            (self.conv1, self.bn1),\n            (self.conv2, self.bn2),\n            (self.conv3, self.bn3),\n        ]:\n            x = self.relu(bn(conv(x)))\n        x = self.avgpool(x)\n        return x\n\n    def forward(self, x):\n        x = self.stem(x)\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n        x = self.layer4(x)\n        x = self.attnpool(x)\n\n        return x\n\n\nclass LayerNorm(nn.LayerNorm):\n    \"\"\"Subclass torch's LayerNorm to handle fp16.\"\"\"\n\n    def forward(self, x: torch.Tensor):\n        orig_type = x.dtype\n        x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)\n        return x.to(orig_type)\n\n\nclass QuickGELU(nn.Module):\n    # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory\n    def forward(self, x: torch.Tensor):\n        return x * torch.sigmoid(1.702 * x)\n\n\nclass ResidualAttentionBlock(nn.Module):\n    def __init__(self, d_model: int, n_head: int, act_layer: Callable = nn.GELU):\n        super().__init__()\n\n        self.attn = nn.MultiheadAttention(d_model, n_head)\n        self.ln_1 = LayerNorm(d_model)\n        self.mlp = nn.Sequential(\n            OrderedDict(\n                [\n                    (\"c_fc\", nn.Linear(d_model, d_model * 4)),\n                    (\"gelu\", act_layer()),\n                    (\"c_proj\", nn.Linear(d_model * 4, d_model)),\n                ]\n            )\n        )\n        self.ln_2 = LayerNorm(d_model)\n\n    def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):\n        return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0]\n\n    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):\n        x = x + self.attention(self.ln_1(x), attn_mask=attn_mask)\n        x = x + self.mlp(self.ln_2(x))\n        return x\n\n\nclass Transformer(nn.Module):\n    def __init__(\n        self, width: int, layers: int, heads: int, act_layer: Callable = nn.GELU\n    ):\n        super().__init__()\n        self.width = width\n        
self.layers = layers\n        self.resblocks = nn.ModuleList(\n            [\n                ResidualAttentionBlock(width, heads, act_layer=act_layer)\n                for _ in range(layers)\n            ]\n        )\n\n    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):\n        for r in self.resblocks:\n            x = r(x, attn_mask=attn_mask)\n        return x\n\n\nclass VisualTransformer(nn.Module):\n    def __init__(\n        self,\n        image_size: int,\n        patch_size: int,\n        width: int,\n        layers: int,\n        heads: int,\n        output_dim: int,\n        act_layer: Callable = nn.GELU,\n    ):\n        super().__init__()\n        self.image_size = image_size\n        self.output_dim = output_dim\n        self.conv1 = nn.Conv2d(\n            in_channels=3,\n            out_channels=width,\n            kernel_size=patch_size,\n            stride=patch_size,\n            bias=False,\n        )\n\n        scale = width**-0.5\n        self.class_embedding = nn.Parameter(scale * torch.randn(width))\n        self.positional_embedding = nn.Parameter(\n            scale * torch.randn((image_size // patch_size) ** 2 + 1, width)\n        )\n        self.ln_pre = LayerNorm(width)\n\n        self.transformer = Transformer(width, layers, heads, act_layer=act_layer)\n\n        self.ln_post = LayerNorm(width)\n        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))\n\n    def lock(self, unlocked_groups=0, freeze_bn_stats=False):\n        assert (\n            unlocked_groups == 0\n        ), \"partial locking not currently supported for this model\"\n        for param in self.parameters():\n            param.requires_grad = False\n\n    def forward(self, x: torch.Tensor):\n        x = self.conv1(x)  # shape = [*, width, grid, grid]\n        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]\n        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]\n        x = torch.cat(\n            [\n                self.class_embedding.to(x.dtype)\n                + torch.zeros(\n                    x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device\n                ),\n                x,\n            ],\n            dim=1,\n        )  # shape = [*, grid ** 2 + 1, width]\n        x = x + self.positional_embedding.to(x.dtype)\n        x = self.ln_pre(x)\n\n        x = x.permute(1, 0, 2)  # NLD -> LND\n        x = self.transformer(x)\n        x = x.permute(1, 0, 2)  # LND -> NLD\n\n        x = self.ln_post(x[:, 0, :])\n\n        if self.proj is not None:\n            x = x @ self.proj\n\n        return x\n\n\n@dataclass\nclass CLIPVisionCfg:\n    layers: Union[Tuple[int, int, int, int], int] = 12\n    width: int = 768\n    patch_size: int = 16\n    image_size: Union[Tuple[int, int], int] = 224\n    timm_model_name: str = (\n        None  # a valid model name overrides layers, width, patch_size\n    )\n    timm_model_pretrained: bool = (\n        False  # use (imagenet) pretrained weights for named model\n    )\n    timm_pool: str = (\n        \"avg\"  # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')\n    )\n    timm_proj: str = (\n        \"linear\"  # linear projection for timm model output ('linear', 'mlp', '')\n    )\n\n\n@dataclass\nclass CLIPTextCfg:\n    context_length: int\n    vocab_size: int\n    width: int\n    heads: int\n    layers: int\n\n\n@registry.register_model(\"clip\")\n@registry.register_model(\"clip_feature_extractor\")\nclass CLIP(BaseModel):\n    
PRETRAINED_MODEL_CONFIG_DICT = {\n        \"ViT-B-32\": \"configs/models/clip_vit_base32.yaml\",\n        \"ViT-B-16\": \"configs/models/clip_vit_base16.yaml\",\n        \"ViT-L-14\": \"configs/models/clip_vit_large14.yaml\",\n        \"ViT-L-14-336\": \"configs/models/clip_vit_large14_336.yaml\",\n        \"RN50\": \"configs/models/clip_resnet50.yaml\",\n    }\n\n    def __init__(\n        self,\n        embed_dim: int,\n        vision_cfg: CLIPVisionCfg,\n        text_cfg: CLIPTextCfg,\n        quick_gelu: bool = False,\n    ):\n        from .tokenizer import tokenize\n\n        super().__init__()\n\n        self.tokenizer = tokenize\n        self._loss = None\n\n        if isinstance(vision_cfg, dict):\n            vision_cfg = CLIPVisionCfg(**vision_cfg)\n        if isinstance(text_cfg, dict):\n            text_cfg = CLIPTextCfg(**text_cfg)\n\n        self.context_length = text_cfg.context_length\n\n        # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more\n        # memory efficient in recent PyTorch releases (>= 1.10).\n        # NOTE: timm models always use native GELU regardless of quick_gelu flag.\n        act_layer = QuickGELU if quick_gelu else nn.GELU\n\n        if vision_cfg.timm_model_name:\n            self.visual = TimmModel(\n                vision_cfg.timm_model_name,\n                pretrained=vision_cfg.timm_model_pretrained,\n                pool=vision_cfg.timm_pool,\n                proj=vision_cfg.timm_proj,\n                embed_dim=embed_dim,\n                image_size=vision_cfg.image_size,\n            )\n            act_layer = (\n                nn.GELU\n            )  # so that text transformer doesn't use QuickGELU w/ timm models\n        elif isinstance(vision_cfg.layers, (tuple, list)):\n            vision_heads = vision_cfg.width * 32 // 64\n            self.visual = ModifiedResNet(\n                layers=vision_cfg.layers,\n                output_dim=embed_dim,\n                heads=vision_heads,\n                image_size=vision_cfg.image_size,\n                width=vision_cfg.width,\n            )\n        else:\n            vision_heads = vision_cfg.width // 64\n            self.visual = VisualTransformer(\n                image_size=vision_cfg.image_size,\n                patch_size=vision_cfg.patch_size,\n                width=vision_cfg.width,\n                layers=vision_cfg.layers,\n                heads=vision_heads,\n                output_dim=embed_dim,\n                act_layer=act_layer,\n            )\n\n        self.transformer = Transformer(\n            width=text_cfg.width,\n            layers=text_cfg.layers,\n            heads=text_cfg.heads,\n            act_layer=act_layer,\n        )\n\n        self.vocab_size = text_cfg.vocab_size\n        self.token_embedding = nn.Embedding(text_cfg.vocab_size, text_cfg.width)\n        self.positional_embedding = nn.Parameter(\n            torch.empty(self.context_length, text_cfg.width)\n        )\n        self.ln_final = LayerNorm(text_cfg.width)\n\n        self.text_projection = nn.Parameter(torch.empty(text_cfg.width, embed_dim))\n        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))\n        self.register_buffer(\"attn_mask\", self.build_attention_mask(), persistent=False)\n\n        self.prompt_templates = openai_imagenet_template\n        self.classifier = None\n\n        self.init_parameters()\n\n    @property\n    def loss(self):\n        if self._loss is None:\n            from lavis.models.clip_models.loss import 
ClipLoss\n            from torch import distributed as dist\n\n            self._loss = ClipLoss(\n                world_size=dist.get_world_size(),\n                rank=dist.get_rank(),\n                local_loss=False,\n                gather_with_grad=False,\n                use_horovod=False,\n            )\n\n        return self._loss\n\n    def init_parameters(self):\n        nn.init.normal_(self.token_embedding.weight, std=0.02)\n        nn.init.normal_(self.positional_embedding, std=0.01)\n        nn.init.constant_(self.logit_scale, np.log(1 / 0.07))\n\n        if hasattr(self.visual, \"init_parameters\"):\n            self.visual.init_parameters()\n\n        proj_std = (self.transformer.width**-0.5) * (\n            (2 * self.transformer.layers) ** -0.5\n        )\n        attn_std = self.transformer.width**-0.5\n        fc_std = (2 * self.transformer.width) ** -0.5\n        for block in self.transformer.resblocks:\n            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)\n            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)\n            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)\n            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)\n\n        if self.text_projection is not None:\n            nn.init.normal_(self.text_projection, std=self.transformer.width**-0.5)\n\n    def build_attention_mask(self):\n        # lazily create causal attention mask, with full attention between the vision tokens\n        # pytorch uses additive attention mask; fill with -inf\n        mask = torch.empty(self.context_length, self.context_length)\n        mask.fill_(float(\"-inf\"))\n        mask.triu_(1)  # zero out the lower diagonal\n        return mask\n\n    def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):\n        # lock image tower as per LiT - https://arxiv.org/abs/2111.07991\n        self.visual.lock(\n            unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats\n        )\n\n    def encode_image(self, image):\n        return self.visual(image)\n\n    def encode_text(self, text):\n        x = self.token_embedding(text)  # [batch_size, n_ctx, d_model]\n\n        x = x + self.positional_embedding\n        x = x.permute(1, 0, 2)  # NLD -> LND\n        x = self.transformer(x, attn_mask=self.attn_mask)\n        x = x.permute(1, 0, 2)  # LND -> NLD\n        x = self.ln_final(x)\n\n        # x.shape = [batch_size, n_ctx, transformer.width]\n        # take features from the eot embedding (eot_token is the highest number in each sequence)\n        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection\n\n        return x\n\n    # def forward(self, image, text):\n    def forward(self, samples):\n        image = samples.get(\"image\")\n        text = samples.get(\"text_input\")\n\n        if text is not None:\n            text = self.tokenizer(text).to(self.device)\n\n        if image is None:\n            return self.encode_text(text)\n        elif text is None:\n            return self.encode_image(image)\n        image_embeds = self.encode_image(image)\n        image_features = F.normalize(image_embeds, dim=-1)\n\n        text_embeds = self.encode_text(text)\n        text_features = F.normalize(text_embeds, dim=-1)\n\n        loss = self.loss(image_features, text_features, self.logit_scale.exp())\n\n        # return image_features, text_features, self.logit_scale.exp()\n        # return {\"loss\": loss}\n        return ClipOutput(\n            intermediate_output=ClipOutputFeatures(\n         
       image_embeds=image_embeds,\n                image_embeds_proj=image_features,\n                text_embeds=text_embeds,\n                text_embeds_proj=text_features,\n            ),\n            loss=loss,\n            logit_scale_exp=self.logit_scale.exp(),\n        )\n\n    def extract_features(self, samples):\n        \"\"\"\n        Extract features from the model for samples.\n\n        Keys allowed are \"image\" and \"text_input\" in samples.\n        If either key is missing, the corresponding features are not extracted.\n\n        Args:\n            samples: dict of samples to extract features from.\n\n        Returns:\n            ClipOutputFeatures object with features for the samples.\n        \"\"\"\n        image = samples.get(\"image\")\n        text = samples.get(\"text_input\")\n\n        if text is not None:\n            text = self.tokenizer(text).to(self.device)\n\n        if image is None:\n            return self.encode_text(text)\n        elif text is None:\n            return self.encode_image(image)\n\n        image_embeds = self.encode_image(image)\n        image_features = F.normalize(image_embeds, dim=-1)\n\n        text_embeds = self.encode_text(text)\n        text_features = F.normalize(text_embeds, dim=-1)\n\n        return ClipOutputFeatures(\n            image_embeds=image_embeds,\n            image_embeds_proj=image_features,\n            text_embeds=text_embeds,\n            text_embeds_proj=text_features,\n        )\n\n    def predict(self, samples):\n        image = samples[\"image\"]\n        targets = samples[\"label\"]\n\n        image_features = self.encode_image(image)\n        image_features = F.normalize(image_features, dim=-1)\n\n        logits = 100.0 * image_features @ self.classifier\n\n        return {\"predictions\": logits, \"targets\": targets}\n\n    def before_evaluation(self, dataset, task_type, **kwargs):\n        if task_type == MultimodalClassificationTask:\n            self.classifier = self.zero_shot_classifier(\n                classnames=dataset.classnames,\n                templates=self.prompt_templates,\n            )\n\n    def zero_shot_classifier(self, classnames, templates):\n        with torch.no_grad():\n            zeroshot_weights = []\n            for classname in classnames:\n                texts = [\n                    template(classname) for template in templates\n                ]  # format with class\n                texts = self.tokenizer(texts).to(self.device)  # tokenize\n\n                class_embeddings = self.encode_text(texts)\n                class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0)\n                class_embedding /= class_embedding.norm()\n                zeroshot_weights.append(class_embedding)\n            zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(self.device)\n        return zeroshot_weights\n\n    @classmethod\n    def default_config_path(cls, model_type=\"base\"):\n        model_type = \"ViT-B-32\" if model_type == \"base\" else model_type\n\n        assert (\n            model_type in cls.PRETRAINED_MODEL_CONFIG_DICT\n        ), \"Unknown model type {}. 
\\n Available types: {}\".format(\n            model_type, cls.PRETRAINED_MODEL_CONFIG_DICT.keys()\n        )\n        return get_abs_path(cls.PRETRAINED_MODEL_CONFIG_DICT[model_type])\n\n    @classmethod\n    def from_config(cls, cfg=None):\n        model_name = cfg.model_type\n        pretrained = cfg.pretrained\n\n        precision = cfg.get(\"precision\", \"fp32\")\n\n        return create_model(\n            model_name=model_name, pretrained=pretrained, precision=precision\n        )\n\n    def zero_shot_predict(self, image_path, categories):\n        assert isinstance(\n            categories, list\n        ), f\"categories must be a list, got {type(categories)}.\"\n        assert os.path.exists(image_path), f\"File {image_path} does not exist.\"\n\n        from lavis.processors.clip_processors import ClipImageEvalProcessor\n        from PIL import Image\n\n        image_preprocess = ClipImageEvalProcessor()\n        image = image_preprocess(Image.open(image_path)).unsqueeze(0)\n\n        text = self.tokenizer(categories)\n\n        with torch.no_grad():\n            image_features = self.encode_image(image)\n            text_features = self.encode_text(text)\n            image_features /= image_features.norm(dim=-1, keepdim=True)\n            text_features /= text_features.norm(dim=-1, keepdim=True)\n\n            text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)\n\n            print(\"Label probs:\", text_probs)  # prints: [[1., 0., 0.]]\n\n    def compute_sim_matrix(self, data_loader, **kwargs):\n        logging.info(\"Computing features for evaluation...\")\n        start_time = time.time()\n\n        texts = data_loader.dataset.text\n        num_text = len(texts)\n        text_bs = 256\n        text_features = []\n\n        for i in range(0, num_text, text_bs):\n\n            text = texts[i : min(num_text, i + text_bs)]\n            text_input = self.tokenizer(text).to(self.device)\n\n            text_feat = self.encode_text(text_input)\n            text_feat = F.normalize(text_feat, dim=-1)\n\n            text_features.append(text_feat)\n\n        text_features = torch.cat(text_features, dim=0)\n\n        image_features = []\n        for samples in data_loader:\n            image = samples[\"image\"]\n\n            image = image.to(self.device)\n            image_feat = self.encode_image(image)\n            image_feat = F.normalize(image_feat, dim=-1)\n\n            image_features.append(image_feat)\n\n        image_features = torch.cat(image_features, dim=0)\n\n        sims_matrix_i2t = image_features @ text_features.t()\n        sims_matrix_t2i = sims_matrix_i2t.t()\n\n        total_time = time.time() - start_time\n        total_time_str = str(datetime.timedelta(seconds=int(total_time)))\n        logging.info(\"Evaluation time {}\".format(total_time_str))\n\n        return sims_matrix_i2t.cpu().numpy(), sims_matrix_t2i.cpu().numpy()\n\n\ndef convert_weights_to_fp16(model: nn.Module):\n    \"\"\"Convert applicable model parameters to fp16\"\"\"\n\n    def _convert_weights_to_fp16(l):\n        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):\n            l.weight.data = l.weight.data.half()\n            if l.bias is not None:\n                l.bias.data = l.bias.data.half()\n\n        if isinstance(l, nn.MultiheadAttention):\n            for attr in [\n                *[f\"{s}_proj_weight\" for s in [\"in\", \"q\", \"k\", \"v\"]],\n                \"in_proj_bias\",\n                \"bias_k\",\n                \"bias_v\",\n            ]:\n             
   tensor = getattr(l, attr)\n                if tensor is not None:\n                    tensor.data = tensor.data.half()\n\n        for name in [\"text_projection\", \"proj\"]:\n            if hasattr(l, name):\n                attr = getattr(l, name)\n                if attr is not None:\n                    attr.data = attr.data.half()\n\n    model.apply(_convert_weights_to_fp16)\n\n\ndef build_model_from_openai_state_dict(state_dict: dict):\n    vit = \"visual.proj\" in state_dict\n\n    if vit:\n        vision_width = state_dict[\"visual.conv1.weight\"].shape[0]\n        vision_layers = len(\n            [\n                k\n                for k in state_dict.keys()\n                if k.startswith(\"visual.\") and k.endswith(\".attn.in_proj_weight\")\n            ]\n        )\n        vision_patch_size = state_dict[\"visual.conv1.weight\"].shape[-1]\n        grid_size = round(\n            (state_dict[\"visual.positional_embedding\"].shape[0] - 1) ** 0.5\n        )\n        image_size = vision_patch_size * grid_size\n    else:\n        counts: list = [\n            len(\n                set(\n                    k.split(\".\")[2]\n                    for k in state_dict\n                    if k.startswith(f\"visual.layer{b}\")\n                )\n            )\n            for b in [1, 2, 3, 4]\n        ]\n        vision_layers = tuple(counts)\n        vision_width = state_dict[\"visual.layer1.0.conv1.weight\"].shape[0]\n        output_width = round(\n            (state_dict[\"visual.attnpool.positional_embedding\"].shape[0] - 1) ** 0.5\n        )\n        vision_patch_size = None\n        assert (\n            output_width**2 + 1\n            == state_dict[\"visual.attnpool.positional_embedding\"].shape[0]\n        )\n        image_size = output_width * 32\n\n    embed_dim = state_dict[\"text_projection\"].shape[1]\n    context_length = state_dict[\"positional_embedding\"].shape[0]\n    vocab_size = state_dict[\"token_embedding.weight\"].shape[0]\n    transformer_width = state_dict[\"ln_final.weight\"].shape[0]\n    transformer_heads = transformer_width // 64\n    transformer_layers = len(\n        set(\n            k.split(\".\")[2]\n            for k in state_dict\n            if k.startswith(f\"transformer.resblocks\")\n        )\n    )\n\n    vision_cfg = CLIPVisionCfg(\n        layers=vision_layers,\n        width=vision_width,\n        patch_size=vision_patch_size,\n        image_size=image_size,\n    )\n    text_cfg = CLIPTextCfg(\n        context_length=context_length,\n        vocab_size=vocab_size,\n        width=transformer_width,\n        heads=transformer_heads,\n        layers=transformer_layers,\n    )\n    model = CLIP(\n        embed_dim,\n        vision_cfg=vision_cfg,\n        text_cfg=text_cfg,\n        quick_gelu=True,  # OpenAI models were trained with QuickGELU\n    )\n\n    for key in [\"input_resolution\", \"context_length\", \"vocab_size\"]:\n        state_dict.pop(key, None)\n\n    convert_weights_to_fp16(model)\n    model.load_state_dict(state_dict)\n    return model.eval()\n\n\ndef trace_model(model, batch_size=256, device=torch.device(\"cpu\")):\n    model.eval()\n    image_size = model.visual.image_size\n    example_images = torch.ones((batch_size, 3, image_size, image_size), device=device)\n    example_text = torch.zeros(\n        (batch_size, model.context_length), dtype=torch.int, device=device\n    )\n    model = torch.jit.trace_module(\n        model,\n        inputs=dict(\n            forward=(example_images, example_text),\n            
encode_text=(example_text,),\n            encode_image=(example_images,),\n        ),\n    )\n    model.visual.image_size = image_size\n    return\n\n\ndef _natural_key(string_):\n    return [int(s) if s.isdigit() else s for s in re.split(r\"(\\d+)\", string_.lower())]\n\n\ndef _rescan_model_configs():\n    global _MODEL_CONFIGS\n\n    config_ext = (\".json\",)\n    config_files = []\n    for config_path in _MODEL_CONFIG_PATHS:\n        if config_path.is_file() and config_path.suffix in config_ext:\n            config_files.append(config_path)\n        elif config_path.is_dir():\n            for ext in config_ext:\n                config_files.extend(config_path.glob(f\"*{ext}\"))\n\n    for cf in config_files:\n        with open(cf, \"r\") as f:\n            model_cfg = json.load(f)\n            if all(a in model_cfg for a in (\"embed_dim\", \"vision_cfg\", \"text_cfg\")):\n                _MODEL_CONFIGS[cf.stem] = model_cfg\n\n    _MODEL_CONFIGS = {\n        k: v\n        for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))\n    }\n\n\n_rescan_model_configs()  # initial populate of model config registry\n\n\ndef load_state_dict(checkpoint_path: str, map_location=\"cpu\"):\n    checkpoint = torch.load(checkpoint_path, map_location=map_location)\n    if isinstance(checkpoint, dict) and \"state_dict\" in checkpoint:\n        state_dict = checkpoint[\"state_dict\"]\n    else:\n        state_dict = checkpoint\n    if next(iter(state_dict.items()))[0].startswith(\"module\"):\n        state_dict = {k[7:]: v for k, v in state_dict.items()}\n    return state_dict\n\n\ndef create_model(\n    model_name: str,\n    pretrained: str = \"\",\n    precision: str = \"fp32\",\n    device: torch.device = torch.device(\"cpu\"),\n    jit: bool = False,\n    force_quick_gelu: bool = False,\n    pretrained_image: bool = False,\n):\n    model_name = model_name.replace(\n        \"/\", \"-\"\n    )  # for callers using old naming with / in ViT names\n\n    if pretrained.lower() == \"openai\":\n        logging.info(f\"Loading pretrained {model_name} from OpenAI.\")\n        model = load_openai_model(model_name, device=device, jit=jit)\n        # See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372\n        if precision == \"amp\" or precision == \"fp32\":\n            model = model.float()\n    else:\n        logging.info(f\"No pretrained weights loaded for {model_name} model.\")\n        if model_name in _MODEL_CONFIGS:\n            logging.info(f\"Loading {model_name} model config.\")\n            model_cfg = deepcopy(_MODEL_CONFIGS[model_name])\n        else:\n            logging.error(\n                f\"Model config for {model_name} not found; available models {list_models()}.\"\n            )\n            raise RuntimeError(f\"Model config for {model_name} not found.\")\n\n        if force_quick_gelu:\n            # override for use of QuickGELU on non-OpenAI transformer models\n            model_cfg[\"quick_gelu\"] = True\n\n        if pretrained_image:\n            if \"timm_model_name\" in model_cfg.get(\"vision_cfg\", {}):\n                # pretrained weight loading for timm models set via vision_cfg\n                model_cfg[\"vision_cfg\"][\"timm_model_pretrained\"] = True\n            else:\n                assert (\n                    False\n                ), \"pretrained image towers currently only supported for timm models\"\n\n        model = CLIP(**model_cfg)\n\n        if pretrained:\n            checkpoint_path = \"\"\n            
url = get_pretrained_url(model_name, pretrained)\n            if url:\n                checkpoint_path = download_pretrained(url)\n            elif os.path.exists(pretrained):\n                checkpoint_path = pretrained\n\n            if checkpoint_path:\n                logging.info(f\"Loading pretrained {model_name} weights ({pretrained}).\")\n                model.load_state_dict(load_state_dict(checkpoint_path))\n            else:\n                logging.warning(\n                    f\"Pretrained weights ({pretrained}) not found for model {model_name}.\"\n                )\n                raise RuntimeError(\n                    f\"Pretrained weights ({pretrained}) not found for model {model_name}.\"\n                )\n\n        model.to(device=device)\n        if precision == \"fp16\":\n            assert device.type != \"cpu\"\n            convert_weights_to_fp16(model)\n\n        if jit:\n            model = torch.jit.script(model)\n\n    return model\n\n\ndef create_model_and_transforms(\n    model_name: str,\n    pretrained: str = \"\",\n    precision: str = \"fp32\",\n    device: torch.device = torch.device(\"cpu\"),\n    jit: bool = False,\n    force_quick_gelu: bool = False,\n    pretrained_image: bool = False,\n):\n    model = create_model(\n        model_name,\n        pretrained,\n        precision,\n        device,\n        jit,\n        force_quick_gelu=force_quick_gelu,\n        pretrained_image=pretrained_image,\n    )\n    preprocess_train = image_transform(model.visual.image_size, is_train=True)\n    preprocess_val = image_transform(model.visual.image_size, is_train=False)\n    return model, preprocess_train, preprocess_val\n\n\ndef list_models():\n    \"\"\"enumerate available model architectures based on config files\"\"\"\n    return list(_MODEL_CONFIGS.keys())\n\n\ndef add_model_config(path):\n    \"\"\"add model config path or file and update registry\"\"\"\n    if not isinstance(path, Path):\n        path = Path(path)\n    _MODEL_CONFIG_PATHS.append(path)\n    _rescan_model_configs()\n\n\ndef list_openai_models() -> List[str]:\n    \"\"\"Returns the names of available CLIP models\"\"\"\n    return list_pretrained_tag_models(\"openai\")\n\n\ndef load_openai_model(\n    name: str,\n    device: Union[str, torch.device] = \"cuda\" if torch.cuda.is_available() else \"cpu\",\n    jit=True,\n):\n    \"\"\"Load a CLIP model\n    Parameters\n    ----------\n    name : str\n        A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict\n    device : Union[str, torch.device]\n        The device to put the loaded model\n    jit : bool\n        Whether to load the optimized JIT model (default) or more hackable non-JIT model.\n    Returns\n    -------\n    model : torch.nn.Module\n        The CLIP model\n    preprocess : Callable[[PIL.Image], torch.Tensor]\n        A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input\n    \"\"\"\n    if get_pretrained_url(name, \"openai\"):\n        model_path = download_pretrained(get_pretrained_url(name, \"openai\"))\n    elif os.path.isfile(name):\n        model_path = name\n    else:\n        raise RuntimeError(\n            f\"Model {name} not found; available models = {list_openai_models()}\"\n        )\n\n    try:\n        # loading JIT archive\n        model = torch.jit.load(model_path, map_location=device if jit else \"cpu\").eval()\n        state_dict = None\n    except RuntimeError:\n        # loading saved state 
dict\n        if jit:\n            warnings.warn(\n                f\"File {model_path} is not a JIT archive. Loading as a state dict instead\"\n            )\n            jit = False\n        state_dict = torch.load(model_path, map_location=\"cpu\")\n\n    if not jit:\n        try:\n            model = build_model_from_openai_state_dict(\n                state_dict or model.state_dict()\n            ).to(device)\n        except KeyError:\n            sd = {k[7:]: v for k, v in state_dict[\"state_dict\"].items()}\n            model = build_model_from_openai_state_dict(sd).to(device)\n\n        if str(device) == \"cpu\":\n            model.float()\n        return model\n\n    # patch the device names\n    device_holder = torch.jit.trace(\n        lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]\n    )\n    device_node = [\n        n\n        for n in device_holder.graph.findAllNodes(\"prim::Constant\")\n        if \"Device\" in repr(n)\n    ][-1]\n\n    def patch_device(module):\n        try:\n            graphs = [module.graph] if hasattr(module, \"graph\") else []\n        except RuntimeError:\n            graphs = []\n\n        if hasattr(module, \"forward1\"):\n            graphs.append(module.forward1.graph)\n\n        for graph in graphs:\n            for node in graph.findAllNodes(\"prim::Constant\"):\n                if \"value\" in node.attributeNames() and str(node[\"value\"]).startswith(\n                    \"cuda\"\n                ):\n                    node.copyAttributes(device_node)\n\n    model.apply(patch_device)\n    patch_device(model.encode_image)\n    patch_device(model.encode_text)\n\n    # patch dtype to float32 on CPU\n    if str(device) == \"cpu\":\n        float_holder = torch.jit.trace(\n            lambda: torch.ones([]).float(), example_inputs=[]\n        )\n        float_input = list(float_holder.graph.findNode(\"aten::to\").inputs())[1]\n        float_node = float_input.node()\n\n        def patch_float(module):\n            try:\n                graphs = [module.graph] if hasattr(module, \"graph\") else []\n            except RuntimeError:\n                graphs = []\n\n            if hasattr(module, \"forward1\"):\n                graphs.append(module.forward1.graph)\n\n            for graph in graphs:\n                for node in graph.findAllNodes(\"aten::to\"):\n                    inputs = list(node.inputs())\n                    for i in [\n                        1,\n                        2,\n                    ]:  # dtype can be the second or third argument to aten::to()\n                        if inputs[i].node()[\"value\"] == 5:\n                            inputs[i].node().copyAttributes(float_node)\n\n        model.apply(patch_float)\n        patch_float(model.encode_image)\n        patch_float(model.encode_text)\n        model.float()\n\n    # ensure image_size attr available at consistent location for both jit and non-jit\n    model.visual.image_size = model.input_resolution.item()\n    return model\n\n\nopenai_imagenet_template = [\n    lambda c: f\"a bad photo of a {c}.\",\n    lambda c: f\"a photo of many {c}.\",\n    lambda c: f\"a sculpture of a {c}.\",\n    lambda c: f\"a photo of the hard to see {c}.\",\n    lambda c: f\"a low resolution photo of the {c}.\",\n    lambda c: f\"a rendering of a {c}.\",\n    lambda c: f\"graffiti of a {c}.\",\n    lambda c: f\"a bad photo of the {c}.\",\n    lambda c: f\"a cropped photo of the {c}.\",\n    lambda c: f\"a tattoo of a {c}.\",\n    lambda c: f\"the embroidered 
{c}.\",\n    lambda c: f\"a photo of a hard to see {c}.\",\n    lambda c: f\"a bright photo of a {c}.\",\n    lambda c: f\"a photo of a clean {c}.\",\n    lambda c: f\"a photo of a dirty {c}.\",\n    lambda c: f\"a dark photo of the {c}.\",\n    lambda c: f\"a drawing of a {c}.\",\n    lambda c: f\"a photo of my {c}.\",\n    lambda c: f\"the plastic {c}.\",\n    lambda c: f\"a photo of the cool {c}.\",\n    lambda c: f\"a close-up photo of a {c}.\",\n    lambda c: f\"a black and white photo of the {c}.\",\n    lambda c: f\"a painting of the {c}.\",\n    lambda c: f\"a painting of a {c}.\",\n    lambda c: f\"a pixelated photo of the {c}.\",\n    lambda c: f\"a sculpture of the {c}.\",\n    lambda c: f\"a bright photo of the {c}.\",\n    lambda c: f\"a cropped photo of a {c}.\",\n    lambda c: f\"a plastic {c}.\",\n    lambda c: f\"a photo of the dirty {c}.\",\n    lambda c: f\"a jpeg corrupted photo of a {c}.\",\n    lambda c: f\"a blurry photo of the {c}.\",\n    lambda c: f\"a photo of the {c}.\",\n    lambda c: f\"a good photo of the {c}.\",\n    lambda c: f\"a rendering of the {c}.\",\n    lambda c: f\"a {c} in a video game.\",\n    lambda c: f\"a photo of one {c}.\",\n    lambda c: f\"a doodle of a {c}.\",\n    lambda c: f\"a close-up photo of the {c}.\",\n    lambda c: f\"a photo of a {c}.\",\n    lambda c: f\"the origami {c}.\",\n    lambda c: f\"the {c} in a video game.\",\n    lambda c: f\"a sketch of a {c}.\",\n    lambda c: f\"a doodle of the {c}.\",\n    lambda c: f\"a origami {c}.\",\n    lambda c: f\"a low resolution photo of a {c}.\",\n    lambda c: f\"the toy {c}.\",\n    lambda c: f\"a rendition of the {c}.\",\n    lambda c: f\"a photo of the clean {c}.\",\n    lambda c: f\"a photo of a large {c}.\",\n    lambda c: f\"a rendition of a {c}.\",\n    lambda c: f\"a photo of a nice {c}.\",\n    lambda c: f\"a photo of a weird {c}.\",\n    lambda c: f\"a blurry photo of a {c}.\",\n    lambda c: f\"a cartoon {c}.\",\n    lambda c: f\"art of a {c}.\",\n    lambda c: f\"a sketch of the {c}.\",\n    lambda c: f\"a embroidered {c}.\",\n    lambda c: f\"a pixelated photo of a {c}.\",\n    lambda c: f\"itap of the {c}.\",\n    lambda c: f\"a jpeg corrupted photo of the {c}.\",\n    lambda c: f\"a good photo of a {c}.\",\n    lambda c: f\"a plushie {c}.\",\n    lambda c: f\"a photo of the nice {c}.\",\n    lambda c: f\"a photo of the small {c}.\",\n    lambda c: f\"a photo of the weird {c}.\",\n    lambda c: f\"the cartoon {c}.\",\n    lambda c: f\"art of the {c}.\",\n    lambda c: f\"a drawing of the {c}.\",\n    lambda c: f\"a photo of the large {c}.\",\n    lambda c: f\"a black and white photo of a {c}.\",\n    lambda c: f\"the plushie {c}.\",\n    lambda c: f\"a dark photo of a {c}.\",\n    lambda c: f\"itap of a {c}.\",\n    lambda c: f\"graffiti of the {c}.\",\n    lambda c: f\"a toy {c}.\",\n    lambda c: f\"itap of my {c}.\",\n    lambda c: f\"a photo of a cool {c}.\",\n    lambda c: f\"a photo of a small {c}.\",\n    lambda c: f\"a tattoo of the {c}.\",\n]\n"
  },
  {
    "path": "lavis/models/clip_models/pretrained.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\n Based on https://github.com/mlfoundations/open_clip\n\"\"\"\n\nimport hashlib\nimport os\nimport urllib\nimport warnings\n\nfrom tqdm import tqdm\n\n_RN50 = dict(\n    openai=\"https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt\",\n    yfcc15m=\"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt\",\n    cc12m=\"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt\",\n)\n\n_RN50_quickgelu = dict(\n    openai=\"https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt\",\n    yfcc15m=\"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt\",\n    cc12m=\"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt\",\n)\n\n_RN101 = dict(\n    openai=\"https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt\",\n    yfcc15m=\"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt\",\n)\n\n_RN101_quickgelu = dict(\n    openai=\"https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt\",\n    yfcc15m=\"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt\",\n)\n\n_RN50x4 = dict(\n    openai=\"https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt\",\n)\n\n_RN50x16 = dict(\n    openai=\"https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt\",\n)\n\n_RN50x64 = dict(\n    openai=\"https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt\",\n)\n\n_VITB32 = dict(\n    openai=\"https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt\",\n    laion400m_e31=\"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt\",\n    laion400m_e32=\"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt\",\n    laion400m_avg=\"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt\",\n)\n\n_VITB32_quickgelu = dict(\n    openai=\"https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt\",\n    laion400m_e31=\"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt\",\n    laion400m_e32=\"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt\",\n    laion400m_avg=\"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt\",\n)\n\n_VITB16 = dict(\n    
openai=\"https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt\",\n)\n\n_VITL14 = dict(\n    openai=\"https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt\",\n)\n\n_VITL14_336 = dict(\n    openai=\"https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt\"\n)\n\n_PRETRAINED = {\n    \"RN50\": _RN50,\n    \"RN50-quickgelu\": _RN50_quickgelu,\n    \"RN101\": _RN101,\n    \"RN101-quickgelu\": _RN101_quickgelu,\n    \"RN50x4\": _RN50x4,\n    \"RN50x16\": _RN50x16,\n    \"ViT-B-32\": _VITB32,\n    \"ViT-B-32-quickgelu\": _VITB32_quickgelu,\n    \"ViT-B-16\": _VITB16,\n    \"ViT-L-14\": _VITL14,\n    \"ViT-L-14-336\": _VITL14_336,\n}\n\n\ndef list_pretrained(as_str: bool = False):\n    \"\"\"returns list of pretrained models\n    Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True\n    \"\"\"\n    return [\n        \":\".join([k, t]) if as_str else (k, t)\n        for k in _PRETRAINED.keys()\n        for t in _PRETRAINED[k].keys()\n    ]\n\n\ndef list_pretrained_tag_models(tag: str):\n    \"\"\"return all models having the specified pretrain tag\"\"\"\n    models = []\n    for k in _PRETRAINED.keys():\n        if tag in _PRETRAINED[k]:\n            models.append(k)\n    return models\n\n\ndef list_pretrained_model_tags(model: str):\n    \"\"\"return all pretrain tags for the specified model architecture\"\"\"\n    tags = []\n    if model in _PRETRAINED:\n        tags.extend(_PRETRAINED[model].keys())\n    return tags\n\n\ndef get_pretrained_url(model: str, tag: str):\n    if model not in _PRETRAINED:\n        return \"\"\n    model_pretrained = _PRETRAINED[model]\n    tag = tag.lower()\n    if tag not in model_pretrained:\n        return \"\"\n    return model_pretrained[tag]\n\n\ndef download_pretrained(url: str, root: str = os.path.expanduser(\"~/.cache/clip\")):\n    os.makedirs(root, exist_ok=True)\n    filename = os.path.basename(url)\n\n    if \"openaipublic\" in url:\n        expected_sha256 = url.split(\"/\")[-2]\n    else:\n        expected_sha256 = \"\"\n\n    download_target = os.path.join(root, filename)\n\n    if os.path.exists(download_target) and not os.path.isfile(download_target):\n        raise RuntimeError(f\"{download_target} exists and is not a regular file\")\n\n    if os.path.isfile(download_target):\n        if expected_sha256:\n            if (\n                hashlib.sha256(open(download_target, \"rb\").read()).hexdigest()\n                == expected_sha256\n            ):\n                return download_target\n            else:\n                warnings.warn(\n                    f\"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file\"\n                )\n        else:\n            return download_target\n\n    with urllib.request.urlopen(url) as source, open(download_target, \"wb\") as output:\n        with tqdm(\n            total=int(source.info().get(\"Content-Length\")),\n            ncols=80,\n            unit=\"iB\",\n            unit_scale=True,\n        ) as loop:\n            while True:\n                buffer = source.read(8192)\n                if not buffer:\n                    break\n\n                output.write(buffer)\n                loop.update(len(buffer))\n\n    if (\n        expected_sha256\n        and hashlib.sha256(open(download_target, 
\"rb\").read()).hexdigest()\n        != expected_sha256\n    ):\n        raise RuntimeError(\n            f\"Model has been downloaded but the SHA256 checksum does not not match\"\n        )\n\n    return download_target\n"
  },
  {
    "path": "lavis/models/clip_models/timm_model.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\n Based on https://github.com/mlfoundations/open_clip\n\"\"\"\n\n\"\"\" timm model adapter\nWraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model.\n\"\"\"\nimport math\nimport warnings\nfrom collections import OrderedDict\nfrom typing import List, Optional, Tuple, Union\n\nimport torch\nimport torch.nn as nn\nfrom torch import nn as nn\n\ntry:\n    import timm\n    from timm.models.layers import Mlp, to_2tuple\n\n    # from timm.models.layers.attention_pool2d import RotAttentionPool2d\n    # from timm.models.layers.attention_pool2d import (\n    #     AttentionPool2d as AbsAttentionPool2d,\n    # )\n\nexcept ImportError as e:\n    timm = None\n\nfrom lavis.models.clip_models.utils import freeze_batch_norm_2d\n\n\nclass TimmModel(nn.Module):\n    \"\"\"timm model adapter\n    # FIXME this adapter is a work in progress, may change in ways that break weight compat\n    \"\"\"\n\n    def __init__(\n        self,\n        model_name,\n        embed_dim,\n        image_size=224,\n        pool=\"avg\",\n        proj=\"linear\",\n        drop=0.0,\n        pretrained=False,\n    ):\n        super().__init__()\n        if timm is None:\n            raise RuntimeError(\"Please `pip install timm` to use timm models.\")\n\n        self.image_size = to_2tuple(image_size)\n        self.trunk = timm.create_model(model_name, pretrained=pretrained)\n        feat_size = self.trunk.default_cfg.get(\"pool_size\", None)\n        feature_ndim = 1 if not feat_size else 2\n        if pool in (\"abs_attn\", \"rot_attn\"):\n            assert feature_ndim == 2\n            # if attn pooling used, remove both classifier and default pool\n            self.trunk.reset_classifier(0, global_pool=\"\")\n        else:\n            # reset global pool if pool config set, otherwise leave as network default\n            reset_kwargs = dict(global_pool=pool) if pool else {}\n            self.trunk.reset_classifier(0, **reset_kwargs)\n        prev_chs = self.trunk.num_features\n\n        head_layers = OrderedDict()\n        if pool == \"abs_attn\":\n            head_layers[\"pool\"] = AttentionPool2d(\n                prev_chs, feat_size=feat_size, out_features=embed_dim\n            )\n            prev_chs = embed_dim\n        elif pool == \"rot_attn\":\n            head_layers[\"pool\"] = RotAttentionPool2d(prev_chs, out_features=embed_dim)\n            prev_chs = embed_dim\n        else:\n            assert proj, \"projection layer needed if non-attention pooling is used.\"\n\n        # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used\n        if proj == \"linear\":\n            head_layers[\"drop\"] = nn.Dropout(drop)\n            head_layers[\"proj\"] = nn.Linear(prev_chs, embed_dim)\n        elif proj == \"mlp\":\n            head_layers[\"mlp\"] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop)\n\n        self.head = nn.Sequential(head_layers)\n\n    def lock(self, unlocked_groups=0, freeze_bn_stats=False):\n        \"\"\"lock modules\n        Args:\n            unlocked_groups (int): leave last n layer groups unlocked (default: 0)\n        \"\"\"\n        if not unlocked_groups:\n            # lock full model\n            for param in self.trunk.parameters():\n    
            param.requires_grad = False\n            if freeze_bn_stats:\n                freeze_batch_norm_2d(self.trunk)\n        else:\n            # NOTE: partial freeze requires latest timm (master) branch and is subject to change\n            try:\n                # FIXME import here until API stable and in an official release\n                from timm.models.helpers import group_modules, group_parameters\n            except ImportError:\n                raise RuntimeError(\n                    \"Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`\"\n                )\n            matcher = self.trunk.group_matcher()\n            gparams = group_parameters(self.trunk, matcher)\n            max_layer_id = max(gparams.keys())\n            max_layer_id = max_layer_id - unlocked_groups\n            for group_idx in range(max_layer_id + 1):\n                group = gparams[group_idx]\n                for param in group:\n                    self.trunk.get_parameter(param).requires_grad = False\n            if freeze_bn_stats:\n                gmodules = group_modules(self.trunk, matcher, reverse=True)\n                gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}\n                freeze_batch_norm_2d(self.trunk, gmodules)\n\n    def forward(self, x):\n        x = self.trunk(x)\n        x = self.head(x)\n        return x\n\n\nclass RotAttentionPool2d(nn.Module):\n    \"\"\"Attention based 2D feature pooling w/ rotary (relative) pos embedding.\n    This is a multi-head attention based replacement for (spatial) average pooling in NN architectures.\n    Adapted from the AttentionPool2d in CLIP w/ rotary embedding instead of learned embed.\n    https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py\n    NOTE: While this impl does not require a fixed feature size, performance at differing resolutions from\n    train varies widely and falls off dramatically. I'm not sure if there is a way around this... 
-RW\n    \"\"\"\n\n    def __init__(\n        self,\n        in_features: int,\n        out_features: int = None,\n        embed_dim: int = None,\n        num_heads: int = 4,\n        qkv_bias: bool = True,\n    ):\n        super().__init__()\n        embed_dim = embed_dim or in_features\n        out_features = out_features or in_features\n        self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)\n        self.proj = nn.Linear(embed_dim, out_features)\n        self.num_heads = num_heads\n        assert embed_dim % num_heads == 0\n        self.head_dim = embed_dim // num_heads\n        self.scale = self.head_dim**-0.5\n        self.pos_embed = RotaryEmbedding(self.head_dim)\n\n        trunc_normal_(self.qkv.weight, std=in_features**-0.5)\n        nn.init.zeros_(self.qkv.bias)\n\n    def forward(self, x):\n        B, _, H, W = x.shape\n        N = H * W\n        x = x.reshape(B, -1, N).permute(0, 2, 1)\n\n        x = torch.cat([x.mean(1, keepdim=True), x], dim=1)\n\n        x = (\n            self.qkv(x)\n            .reshape(B, N + 1, 3, self.num_heads, self.head_dim)\n            .permute(2, 0, 3, 1, 4)\n        )\n        q, k, v = x[0], x[1], x[2]\n\n        qc, q = q[:, :, :1], q[:, :, 1:]\n        sin_emb, cos_emb = self.pos_embed.get_embed((H, W))\n        q = apply_rot_embed(q, sin_emb, cos_emb)\n        q = torch.cat([qc, q], dim=2)\n\n        kc, k = k[:, :, :1], k[:, :, 1:]\n        k = apply_rot_embed(k, sin_emb, cos_emb)\n        k = torch.cat([kc, k], dim=2)\n\n        attn = (q @ k.transpose(-2, -1)) * self.scale\n        attn = attn.softmax(dim=-1)\n\n        x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1)\n        x = self.proj(x)\n        return x[:, 0]\n\n\nclass AttentionPool2d(nn.Module):\n    \"\"\"Attention based 2D feature pooling w/ learned (absolute) pos embedding.\n    This is a multi-head attention based replacement for (spatial) average pooling in NN architectures.\n    It was based on impl in CLIP by OpenAI\n    https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py\n    NOTE: This requires feature size upon construction and will prevent adaptive sizing of the network.\n    \"\"\"\n\n    def __init__(\n        self,\n        in_features: int,\n        feat_size: Union[int, Tuple[int, int]],\n        out_features: int = None,\n        embed_dim: int = None,\n        num_heads: int = 4,\n        qkv_bias: bool = True,\n    ):\n        super().__init__()\n\n        embed_dim = embed_dim or in_features\n        out_features = out_features or in_features\n        assert embed_dim % num_heads == 0\n        self.feat_size = to_2tuple(feat_size)\n        self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)\n        self.proj = nn.Linear(embed_dim, out_features)\n        self.num_heads = num_heads\n        self.head_dim = embed_dim // num_heads\n        self.scale = self.head_dim**-0.5\n\n        spatial_dim = self.feat_size[0] * self.feat_size[1]\n        self.pos_embed = nn.Parameter(torch.zeros(spatial_dim + 1, in_features))\n        trunc_normal_(self.pos_embed, std=in_features**-0.5)\n        trunc_normal_(self.qkv.weight, std=in_features**-0.5)\n        nn.init.zeros_(self.qkv.bias)\n\n    def forward(self, x):\n        B, _, H, W = x.shape\n        N = H * W\n        assert self.feat_size[0] == H\n        assert self.feat_size[1] == W\n        x = x.reshape(B, -1, N).permute(0, 2, 1)\n        x = torch.cat([x.mean(1, keepdim=True), x], dim=1)\n        x = x + 
self.pos_embed.unsqueeze(0).to(x.dtype)\n\n        x = (\n            self.qkv(x)\n            .reshape(B, N + 1, 3, self.num_heads, self.head_dim)\n            .permute(2, 0, 3, 1, 4)\n        )\n        q, k, v = x[0], x[1], x[2]\n        attn = (q @ k.transpose(-2, -1)) * self.scale\n        attn = attn.softmax(dim=-1)\n\n        x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1)\n        x = self.proj(x)\n        return x[:, 0]\n\n\ndef pixel_freq_bands(\n    num_bands: int,\n    max_freq: float = 224.0,\n    linear_bands: bool = True,\n    dtype: torch.dtype = torch.float32,\n    device: Optional[torch.device] = None,\n):\n    if linear_bands:\n        bands = torch.linspace(1.0, max_freq / 2, num_bands, dtype=dtype, device=device)\n    else:\n        bands = 2 ** torch.linspace(\n            0, math.log(max_freq, 2) - 1, num_bands, dtype=dtype, device=device\n        )\n    return bands * torch.pi\n\n\ndef inv_freq_bands(\n    num_bands: int,\n    temperature: float = 100000.0,\n    step: int = 2,\n    dtype: torch.dtype = torch.float32,\n    device: Optional[torch.device] = None,\n) -> torch.Tensor:\n    inv_freq = 1.0 / (\n        temperature\n        ** (torch.arange(0, num_bands, step, dtype=dtype, device=device) / num_bands)\n    )\n    return inv_freq\n\n\ndef build_sincos2d_pos_embed(\n    feat_shape: List[int],\n    dim: int = 64,\n    temperature: float = 10000.0,\n    reverse_coord: bool = False,\n    interleave_sin_cos: bool = False,\n    dtype: torch.dtype = torch.float32,\n    device: Optional[torch.device] = None,\n) -> torch.Tensor:\n    \"\"\"\n    Args:\n        feat_shape:\n        dim:\n        temperature:\n        reverse_coord: stack grid order W, H instead of H, W\n        interleave_sin_cos: sin, cos, sin, cos stack instead of sin, sin, cos, cos\n        dtype:\n        device:\n    Returns:\n    \"\"\"\n    assert (\n        dim % 4 == 0\n    ), \"Embed dimension must be divisible by 4 for sin-cos 2D position embedding\"\n    pos_dim = dim // 4\n    bands = inv_freq_bands(\n        pos_dim, temperature=temperature, step=1, dtype=dtype, device=device\n    )\n\n    if reverse_coord:\n        feat_shape = feat_shape[::-1]  # stack W, H instead of H, W\n    grid = (\n        torch.stack(\n            torch.meshgrid(\n                [torch.arange(s, device=device, dtype=dtype) for s in feat_shape]\n            )\n        )\n        .flatten(1)\n        .transpose(0, 1)\n    )\n    pos2 = grid.unsqueeze(-1) * bands.unsqueeze(0)\n    # FIXME add support for unflattened spatial dim?\n\n    stack_dim = (\n        2 if interleave_sin_cos else 1\n    )  # stack sin, cos, sin, cos  instead of sin sin cos cos\n    pos_emb = torch.stack([torch.sin(pos2), torch.cos(pos2)], dim=stack_dim).flatten(1)\n    return pos_emb\n\n\ndef build_fourier_pos_embed(\n    feat_shape: List[int],\n    bands: Optional[torch.Tensor] = None,\n    num_bands: int = 64,\n    max_res: int = 224,\n    linear_bands: bool = False,\n    include_grid: bool = False,\n    concat_out: bool = True,\n    in_pixels: bool = True,\n    dtype: torch.dtype = torch.float32,\n    device: Optional[torch.device] = None,\n) -> List[torch.Tensor]:\n    if bands is None:\n        if in_pixels:\n            bands = pixel_freq_bands(\n                num_bands,\n                float(max_res),\n                linear_bands=linear_bands,\n                dtype=dtype,\n                device=device,\n            )\n        else:\n            bands = inv_freq_bands(num_bands, step=1, dtype=dtype, device=device)\n    
else:\n        if device is None:\n            device = bands.device\n        if dtype is None:\n            dtype = bands.dtype\n\n    if in_pixels:\n        grid = torch.stack(\n            torch.meshgrid(\n                [\n                    torch.linspace(-1.0, 1.0, steps=s, device=device, dtype=dtype)\n                    for s in feat_shape\n                ]\n            ),\n            dim=-1,\n        )\n    else:\n        grid = torch.stack(\n            torch.meshgrid(\n                [torch.arange(s, device=device, dtype=dtype) for s in feat_shape]\n            ),\n            dim=-1,\n        )\n    grid = grid.unsqueeze(-1)\n    pos = grid * bands\n\n    pos_sin, pos_cos = pos.sin(), pos.cos()\n    out = (grid, pos_sin, pos_cos) if include_grid else (pos_sin, pos_cos)\n    # FIXME torchscript doesn't like multiple return types, probably need to always cat?\n    if concat_out:\n        out = torch.cat(out, dim=-1)\n    return out\n\n\nclass FourierEmbed(nn.Module):\n    def __init__(\n        self,\n        max_res: int = 224,\n        num_bands: int = 64,\n        concat_grid=True,\n        keep_spatial=False,\n    ):\n        super().__init__()\n        self.max_res = max_res\n        self.num_bands = num_bands\n        self.concat_grid = concat_grid\n        self.keep_spatial = keep_spatial\n        self.register_buffer(\n            \"bands\", pixel_freq_bands(max_res, num_bands), persistent=False\n        )\n\n    def forward(self, x):\n        B, C = x.shape[:2]\n        feat_shape = x.shape[2:]\n        emb = build_fourier_pos_embed(\n            feat_shape,\n            self.bands,\n            include_grid=self.concat_grid,\n            dtype=x.dtype,\n            device=x.device,\n        )\n        emb = emb.transpose(-1, -2).flatten(len(feat_shape))\n        batch_expand = (B,) + (-1,) * (x.ndim - 1)\n\n        # FIXME support nD\n        if self.keep_spatial:\n            x = torch.cat(\n                [x, emb.unsqueeze(0).expand(batch_expand).permute(0, 3, 1, 2)], dim=1\n            )\n        else:\n            x = torch.cat(\n                [x.permute(0, 2, 3, 1), emb.unsqueeze(0).expand(batch_expand)], dim=-1\n            )\n            x = x.reshape(B, feat_shape.numel(), -1)\n\n        return x\n\n\ndef rot(x):\n    return torch.stack([-x[..., 1::2], x[..., ::2]], -1).reshape(x.shape)\n\n\ndef apply_rot_embed(x: torch.Tensor, sin_emb, cos_emb):\n    return x * cos_emb + rot(x) * sin_emb\n\n\ndef apply_rot_embed_list(x: List[torch.Tensor], sin_emb, cos_emb):\n    if isinstance(x, torch.Tensor):\n        x = [x]\n    return [t * cos_emb + rot(t) * sin_emb for t in x]\n\n\ndef apply_rot_embed_split(x: torch.Tensor, emb):\n    split = emb.shape[-1] // 2\n    return x * emb[:, :split] + rot(x) * emb[:, split:]\n\n\ndef build_rotary_pos_embed(\n    feat_shape: List[int],\n    bands: Optional[torch.Tensor] = None,\n    dim: int = 64,\n    max_freq: float = 224,\n    linear_bands: bool = False,\n    dtype: torch.dtype = torch.float32,\n    device: Optional[torch.device] = None,\n):\n    \"\"\"\n    NOTE: shape arg should include spatial dim only\n    \"\"\"\n    feat_shape = torch.Size(feat_shape)\n\n    sin_emb, cos_emb = build_fourier_pos_embed(\n        feat_shape,\n        bands=bands,\n        num_bands=dim // 4,\n        max_res=max_freq,\n        linear_bands=linear_bands,\n        concat_out=False,\n        device=device,\n        dtype=dtype,\n    )\n    N = feat_shape.numel()\n    sin_emb = sin_emb.reshape(N, -1).repeat_interleave(2, -1)\n    cos_emb 
= cos_emb.reshape(N, -1).repeat_interleave(2, -1)\n    return sin_emb, cos_emb\n\n\nclass RotaryEmbedding(nn.Module):\n    \"\"\"Rotary position embedding\n    NOTE: This is my initial attempt at impl rotary embedding for spatial use, it has not\n    been well tested, and will likely change. It will be moved to its own file.\n    The following impl/resources were referenced for this impl:\n    * https://github.com/lucidrains/vit-pytorch/blob/6f3a5fcf0bca1c5ec33a35ef48d97213709df4ba/vit_pytorch/rvt.py\n    * https://blog.eleuther.ai/rotary-embeddings/\n    \"\"\"\n\n    def __init__(self, dim, max_res=224, linear_bands: bool = False):\n        super().__init__()\n        self.dim = dim\n        self.register_buffer(\n            \"bands\",\n            pixel_freq_bands(dim // 4, max_res, linear_bands=linear_bands),\n            persistent=False,\n        )\n\n    def get_embed(self, shape: List[int]):\n        return build_rotary_pos_embed(shape, self.bands)\n\n    def forward(self, x):\n        # assuming channel-first tensor where spatial dim are >= 2\n        sin_emb, cos_emb = self.get_embed(x.shape[2:])\n        return apply_rot_embed(x, sin_emb, cos_emb)\n\n\ndef _no_grad_trunc_normal_(tensor, mean, std, a, b):\n    # Cut & paste from PyTorch official master until it's in a few official releases - RW\n    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf\n    def norm_cdf(x):\n        # Computes standard normal cumulative distribution function\n        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0\n\n    if (mean < a - 2 * std) or (mean > b + 2 * std):\n        warnings.warn(\n            \"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. \"\n            \"The distribution of values may be incorrect.\",\n            stacklevel=2,\n        )\n\n    with torch.no_grad():\n        # Values are generated by using a truncated uniform distribution and\n        # then using the inverse CDF for the normal distribution.\n        # Get upper and lower cdf values\n        l = norm_cdf((a - mean) / std)\n        u = norm_cdf((b - mean) / std)\n\n        # Uniformly fill tensor with values from [l, u], then translate to\n        # [2l-1, 2u-1].\n        tensor.uniform_(2 * l - 1, 2 * u - 1)\n\n        # Use inverse cdf transform for normal distribution to get truncated\n        # standard normal\n        tensor.erfinv_()\n\n        # Transform to proper mean, std\n        tensor.mul_(std * math.sqrt(2.0))\n        tensor.add_(mean)\n\n        # Clamp to ensure it's in the proper range\n        tensor.clamp_(min=a, max=b)\n        return tensor\n\n\ndef trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):\n    r\"\"\"Fills the input Tensor with values drawn from a truncated\n    normal distribution. The values are effectively drawn from the\n    normal distribution :math:`\\mathcal{N}(\\text{mean}, \\text{std}^2)`\n    with values outside :math:`[a, b]` redrawn until they are within\n    the bounds. The method used for generating the random values works\n    best when :math:`a \\leq \\text{mean} \\leq b`.\n    Args:\n        tensor: an n-dimensional `torch.Tensor`\n        mean: the mean of the normal distribution\n        std: the standard deviation of the normal distribution\n        a: the minimum cutoff value\n        b: the maximum cutoff value\n    Examples:\n        >>> w = torch.empty(3, 5)\n        >>> nn.init.trunc_normal_(w)\n    \"\"\"\n    return _no_grad_trunc_normal_(tensor, mean, std, a, b)\n"
  },
  {
    "path": "lavis/models/clip_models/tokenizer.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\n Based on https://github.com/mlfoundations/open_clip\n\"\"\"\n\n\"\"\" CLIP tokenizer\nCopied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.\n\"\"\"\nimport gzip\nimport html\nimport os\nfrom functools import lru_cache\nfrom typing import Union, List\n\nimport ftfy\nimport regex as re\nimport torch\n\n\n@lru_cache()\ndef default_bpe():\n    return os.path.join(\n        os.path.dirname(os.path.abspath(__file__)), \"bpe_simple_vocab_16e6.txt.gz\"\n    )\n\n\n@lru_cache()\ndef bytes_to_unicode():\n    \"\"\"\n    Returns list of utf-8 byte and a corresponding list of unicode strings.\n    The reversible bpe codes work on unicode strings.\n    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.\n    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.\n    This is a signficant percentage of your normal, say, 32K bpe vocab.\n    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.\n    And avoids mapping to whitespace/control characters the bpe code barfs on.\n    \"\"\"\n    bs = (\n        list(range(ord(\"!\"), ord(\"~\") + 1))\n        + list(range(ord(\"¡\"), ord(\"¬\") + 1))\n        + list(range(ord(\"®\"), ord(\"ÿ\") + 1))\n    )\n    cs = bs[:]\n    n = 0\n    for b in range(2**8):\n        if b not in bs:\n            bs.append(b)\n            cs.append(2**8 + n)\n            n += 1\n    cs = [chr(n) for n in cs]\n    return dict(zip(bs, cs))\n\n\ndef get_pairs(word):\n    \"\"\"Return set of symbol pairs in a word.\n    Word is represented as tuple of symbols (symbols being variable-length strings).\n    \"\"\"\n    pairs = set()\n    prev_char = word[0]\n    for char in word[1:]:\n        pairs.add((prev_char, char))\n        prev_char = char\n    return pairs\n\n\ndef basic_clean(text):\n    text = ftfy.fix_text(text)\n    text = html.unescape(html.unescape(text))\n    return text.strip()\n\n\ndef whitespace_clean(text):\n    text = re.sub(r\"\\s+\", \" \", text)\n    text = text.strip()\n    return text\n\n\nclass SimpleTokenizer(object):\n    def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):\n        self.byte_encoder = bytes_to_unicode()\n        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}\n        merges = gzip.open(bpe_path).read().decode(\"utf-8\").split(\"\\n\")\n        merges = merges[1 : 49152 - 256 - 2 + 1]\n        merges = [tuple(merge.split()) for merge in merges]\n        vocab = list(bytes_to_unicode().values())\n        vocab = vocab + [v + \"</w>\" for v in vocab]\n        for merge in merges:\n            vocab.append(\"\".join(merge))\n        if not special_tokens:\n            special_tokens = [\"<start_of_text>\", \"<end_of_text>\"]\n        else:\n            special_tokens = [\"<start_of_text>\", \"<end_of_text>\"] + special_tokens\n        vocab.extend(special_tokens)\n        self.encoder = dict(zip(vocab, range(len(vocab))))\n        self.decoder = {v: k for k, v in self.encoder.items()}\n        self.bpe_ranks = dict(zip(merges, range(len(merges))))\n        self.cache = {t: t for t in special_tokens}\n        special = \"|\".join(special_tokens)\n        self.pat = re.compile(\n            special + 
r\"\"\"|'s|'t|'re|'ve|'m|'ll|'d|[\\p{L}]+|[\\p{N}]|[^\\s\\p{L}\\p{N}]+\"\"\",\n            re.IGNORECASE,\n        )\n\n        self.vocab_size = len(self.encoder)\n        self.all_special_ids = [self.encoder[t] for t in special_tokens]\n\n    def bpe(self, token):\n        if token in self.cache:\n            return self.cache[token]\n        word = tuple(token[:-1]) + (token[-1] + \"</w>\",)\n        pairs = get_pairs(word)\n\n        if not pairs:\n            return token + \"</w>\"\n\n        while True:\n            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float(\"inf\")))\n            if bigram not in self.bpe_ranks:\n                break\n            first, second = bigram\n            new_word = []\n            i = 0\n            while i < len(word):\n                try:\n                    j = word.index(first, i)\n                    new_word.extend(word[i:j])\n                    i = j\n                except:\n                    new_word.extend(word[i:])\n                    break\n\n                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:\n                    new_word.append(first + second)\n                    i += 2\n                else:\n                    new_word.append(word[i])\n                    i += 1\n            new_word = tuple(new_word)\n            word = new_word\n            if len(word) == 1:\n                break\n            else:\n                pairs = get_pairs(word)\n        word = \" \".join(word)\n        self.cache[token] = word\n        return word\n\n    def encode(self, text):\n        bpe_tokens = []\n        text = whitespace_clean(basic_clean(text)).lower()\n        for token in re.findall(self.pat, text):\n            token = \"\".join(self.byte_encoder[b] for b in token.encode(\"utf-8\"))\n            bpe_tokens.extend(\n                self.encoder[bpe_token] for bpe_token in self.bpe(token).split(\" \")\n            )\n        return bpe_tokens\n\n    def decode(self, tokens):\n        text = \"\".join([self.decoder[token] for token in tokens])\n        text = (\n            bytearray([self.byte_decoder[c] for c in text])\n            .decode(\"utf-8\", errors=\"replace\")\n            .replace(\"</w>\", \" \")\n        )\n        return text\n\n\n_tokenizer = SimpleTokenizer()\n\n\ndef tokenize(\n    texts: Union[str, List[str]], context_length: int = 77\n) -> torch.LongTensor:\n    \"\"\"\n    Returns the tokenized representation of given input string(s)\n    Parameters\n    ----------\n    texts : Union[str, List[str]]\n        An input string or a list of input strings to tokenize\n    context_length : int\n        The context length to use; all CLIP models use 77 as the context length\n    Returns\n    -------\n    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]\n    \"\"\"\n    if isinstance(texts, str):\n        texts = [texts]\n\n    sot_token = _tokenizer.encoder[\"<start_of_text>\"]\n    eot_token = _tokenizer.encoder[\"<end_of_text>\"]\n    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]\n    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)\n\n    for i, tokens in enumerate(all_tokens):\n        if len(tokens) > context_length:\n            tokens = tokens[:context_length]  # Truncate\n        result[i, : len(tokens)] = torch.tensor(tokens)\n\n    return result\n"
  },
  {
    "path": "lavis/models/clip_models/transform.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\n Based on https://github.com/mlfoundations/open_clip\n\"\"\"\n\nfrom typing import Optional, Sequence, Tuple\n\nimport torch\nimport torch.nn as nn\nimport torchvision.transforms.functional as F\n\n\nfrom torchvision.transforms import (\n    Normalize,\n    Compose,\n    RandomResizedCrop,\n    InterpolationMode,\n    ToTensor,\n    Resize,\n    CenterCrop,\n)\n\n\nclass ResizeMaxSize(nn.Module):\n    def __init__(\n        self, max_size, interpolation=InterpolationMode.BICUBIC, fn=\"max\", fill=0\n    ):\n        super().__init__()\n        if not isinstance(max_size, int):\n            raise TypeError(f\"Size should be int. Got {type(max_size)}\")\n        self.max_size = max_size\n        self.interpolation = interpolation\n        self.fn = min if fn == \"min\" else min\n        self.fill = fill\n\n    def forward(self, img):\n        if isinstance(img, torch.Tensor):\n            height, width = img.shape[:2]\n        else:\n            width, height = img.size\n        scale = self.max_size / float(max(height, width))\n        if scale != 1.0:\n            new_size = tuple(round(dim * scale) for dim in (height, width))\n            img = F.resize(img, new_size, self.interpolation)\n            pad_h = self.max_size - new_size[0]\n            pad_w = self.max_size - new_size[1]\n            img = F.pad(\n                img,\n                padding=[\n                    pad_w // 2,\n                    pad_h // 2,\n                    pad_w - pad_w // 2,\n                    pad_h - pad_h // 2,\n                ],\n                fill=self.fill,\n            )\n        return img\n\n\ndef _convert_to_rgb(image):\n    return image.convert(\"RGB\")\n\n\ndef image_transform(\n    image_size: int,\n    is_train: bool,\n    mean: Optional[Tuple[float, ...]] = None,\n    std: Optional[Tuple[float, ...]] = None,\n    resize_longest_max: bool = False,\n    fill_color: int = 0,\n):\n    mean = mean or (0.48145466, 0.4578275, 0.40821073)  # OpenAI dataset mean\n    std = std or (0.26862954, 0.26130258, 0.27577711)  # OpenAI dataset std\n    if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]:\n        # for square size, pass size as int so that Resize() uses aspect preserving shortest edge\n        image_size = image_size[0]\n\n    normalize = Normalize(mean=mean, std=std)\n    if is_train:\n        return Compose(\n            [\n                RandomResizedCrop(\n                    image_size,\n                    scale=(0.9, 1.0),\n                    interpolation=InterpolationMode.BICUBIC,\n                ),\n                _convert_to_rgb,\n                ToTensor(),\n                normalize,\n            ]\n        )\n    else:\n        if resize_longest_max:\n            transforms = [ResizeMaxSize(image_size, fill=fill_color)]\n        else:\n            transforms = [\n                Resize(image_size, interpolation=InterpolationMode.BICUBIC),\n                CenterCrop(image_size),\n            ]\n        transforms.extend(\n            [\n                _convert_to_rgb,\n                ToTensor(),\n                normalize,\n            ]\n        )\n        return Compose(transforms)\n"
  },
  {
    "path": "lavis/models/clip_models/utils.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\n Based on https://github.com/mlfoundations/open_clip\n\"\"\"\n\nfrom torch import nn as nn\nfrom torchvision.ops.misc import FrozenBatchNorm2d\n\n\ndef freeze_batch_norm_2d(module, module_match={}, name=\"\"):\n    \"\"\"\n    Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is\n    itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and\n    returned. Otherwise, the module is walked recursively and submodules are converted in place.\n    Args:\n        module (torch.nn.Module): Any PyTorch module.\n        module_match (dict): Dictionary of full module names to freeze (all if empty)\n        name (str): Full module name (prefix)\n    Returns:\n        torch.nn.Module: Resulting module\n    Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762\n    \"\"\"\n    res = module\n    is_match = True\n    if module_match:\n        is_match = name in module_match\n    if is_match and isinstance(\n        module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)\n    ):\n        res = FrozenBatchNorm2d(module.num_features)\n        res.num_features = module.num_features\n        res.affine = module.affine\n        if module.affine:\n            res.weight.data = module.weight.data.clone().detach()\n            res.bias.data = module.bias.data.clone().detach()\n        res.running_mean.data = module.running_mean.data\n        res.running_var.data = module.running_var.data\n        res.eps = module.eps\n    else:\n        for child_name, child in module.named_children():\n            full_child_name = \".\".join([name, child_name]) if name else child_name\n            new_child = freeze_batch_norm_2d(child, module_match, full_child_name)\n            if new_child is not child:\n                res.add_module(child_name, new_child)\n    return res\n"
  },
  {
    "path": "lavis/models/eva_vit.py",
    "content": "# Based on EVA, BEIT, timm and DeiT code bases\n# https://github.com/baaivision/EVA\n# https://github.com/rwightman/pytorch-image-models/tree/master/timm\n# https://github.com/microsoft/unilm/tree/master/beit\n# https://github.com/facebookresearch/deit/\n# https://github.com/facebookresearch/dino\n# --------------------------------------------------------'\nimport math\nfrom functools import partial\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.utils.checkpoint as checkpoint\nfrom timm.models.layers import drop_path, to_2tuple, trunc_normal_\nfrom timm.models.registry import register_model\n\nfrom lavis.common.dist_utils import download_cached_file\n\ndef _cfg(url='', **kwargs):\n    return {\n        'url': url,\n        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,\n        'crop_pct': .9, 'interpolation': 'bicubic',\n        'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),\n        **kwargs\n    }\n\n\nclass DropPath(nn.Module):\n    \"\"\"Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).\n    \"\"\"\n    def __init__(self, drop_prob=None):\n        super(DropPath, self).__init__()\n        self.drop_prob = drop_prob\n\n    def forward(self, x):\n        return drop_path(x, self.drop_prob, self.training)\n    \n    def extra_repr(self) -> str:\n        return 'p={}'.format(self.drop_prob)\n\n\nclass Mlp(nn.Module):\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features, out_features)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        # x = self.drop(x)\n        # commit this for the orignal BERT implement \n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\n\nclass Attention(nn.Module):\n    def __init__(\n            self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,\n            proj_drop=0., window_size=None, attn_head_dim=None):\n        super().__init__()\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        if attn_head_dim is not None:\n            head_dim = attn_head_dim\n        all_head_dim = head_dim * self.num_heads\n        self.scale = qk_scale or head_dim ** -0.5\n\n        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)\n        if qkv_bias:\n            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))\n            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))\n        else:\n            self.q_bias = None\n            self.v_bias = None\n\n        if window_size:\n            self.window_size = window_size\n            self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3\n            self.relative_position_bias_table = nn.Parameter(\n                torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH\n            # cls to token & token 2 cls & cls to cls\n\n            # get pair-wise relative position index for each token inside the window\n            coords_h = torch.arange(window_size[0])\n            coords_w = torch.arange(window_size[1])\n            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 
2, Wh, Ww\n            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww\n            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww\n            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2\n            relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0\n            relative_coords[:, :, 1] += window_size[1] - 1\n            relative_coords[:, :, 0] *= 2 * window_size[1] - 1\n            relative_position_index = \\\n                torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype)\n            relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww\n            relative_position_index[0, 0:] = self.num_relative_distance - 3\n            relative_position_index[0:, 0] = self.num_relative_distance - 2\n            relative_position_index[0, 0] = self.num_relative_distance - 1\n\n            self.register_buffer(\"relative_position_index\", relative_position_index)\n        else:\n            self.window_size = None\n            self.relative_position_bias_table = None\n            self.relative_position_index = None\n\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(all_head_dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n\n    def forward(self, x, rel_pos_bias=None):\n        B, N, C = x.shape\n        qkv_bias = None\n        if self.q_bias is not None:\n            qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))\n        # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)\n        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)\n        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)\n        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)\n\n        q = q * self.scale\n        attn = (q @ k.transpose(-2, -1))\n\n        if self.relative_position_bias_table is not None:\n            relative_position_bias = \\\n                self.relative_position_bias_table[self.relative_position_index.view(-1)].view(\n                    self.window_size[0] * self.window_size[1] + 1,\n                    self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH\n            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww\n            attn = attn + relative_position_bias.unsqueeze(0)\n\n        if rel_pos_bias is not None:\n            attn = attn + rel_pos_bias\n        \n        attn = attn.softmax(dim=-1)\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\n\nclass Block(nn.Module):\n\n    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,\n                 drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,\n                 window_size=None, attn_head_dim=None):\n        super().__init__()\n        self.norm1 = norm_layer(dim)\n        self.attn = Attention(\n            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,\n            attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)\n        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here\n        self.drop_path = 
DropPath(drop_path) if drop_path > 0. else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)\n\n        if init_values is not None and init_values > 0:\n            self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)\n            self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)\n        else:\n            self.gamma_1, self.gamma_2 = None, None\n\n    def forward(self, x, rel_pos_bias=None):\n        if self.gamma_1 is None:\n            x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))\n            x = x + self.drop_path(self.mlp(self.norm2(x)))\n        else:\n            x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))\n            x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))\n        return x\n\n\nclass PatchEmbed(nn.Module):\n    \"\"\" Image to Patch Embedding\n    \"\"\"\n    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):\n        super().__init__()\n        img_size = to_2tuple(img_size)\n        patch_size = to_2tuple(patch_size)\n        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])\n        self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])\n        self.img_size = img_size\n        self.patch_size = patch_size\n        self.num_patches = num_patches\n\n        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)\n\n    def forward(self, x, **kwargs):\n        B, C, H, W = x.shape\n        # FIXME look at relaxing size constraints\n        assert H == self.img_size[0] and W == self.img_size[1], \\\n            f\"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).\"\n        x = self.proj(x).flatten(2).transpose(1, 2)\n        return x\n\n\nclass RelativePositionBias(nn.Module):\n\n    def __init__(self, window_size, num_heads):\n        super().__init__()\n        self.window_size = window_size\n        self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3\n        self.relative_position_bias_table = nn.Parameter(\n            torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH\n        # cls to token & token 2 cls & cls to cls\n\n        # get pair-wise relative position index for each token inside the window\n        coords_h = torch.arange(window_size[0])\n        coords_w = torch.arange(window_size[1])\n        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww\n        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww\n        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww\n        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2\n        relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0\n        relative_coords[:, :, 1] += window_size[1] - 1\n        relative_coords[:, :, 0] *= 2 * window_size[1] - 1\n        relative_position_index = \\\n            torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)\n        relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww\n        relative_position_index[0, 0:] = self.num_relative_distance - 3\n        relative_position_index[0:, 0] = 
self.num_relative_distance - 2\n        relative_position_index[0, 0] = self.num_relative_distance - 1\n\n        self.register_buffer(\"relative_position_index\", relative_position_index)\n\n        # trunc_normal_(self.relative_position_bias_table, std=.02)\n\n    def forward(self):\n        relative_position_bias = \\\n            self.relative_position_bias_table[self.relative_position_index.view(-1)].view(\n                self.window_size[0] * self.window_size[1] + 1,\n                self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH\n        return relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww\n\n\nclass VisionTransformer(nn.Module):\n    \"\"\" Vision Transformer with support for patch or hybrid CNN input stage\n    \"\"\"\n    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,\n                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,\n                 drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None,\n                 use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False,\n                 use_mean_pooling=True, init_scale=0.001, use_checkpoint=False):\n        super().__init__()\n        self.image_size = img_size\n        self.num_classes = num_classes\n        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models\n\n        self.patch_embed = PatchEmbed(\n            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)\n        num_patches = self.patch_embed.num_patches\n\n        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))\n        if use_abs_pos_emb:\n            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))\n        else:\n            self.pos_embed = None\n        self.pos_drop = nn.Dropout(p=drop_rate)\n\n        if use_shared_rel_pos_bias:\n            self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)\n        else:\n            self.rel_pos_bias = None\n        self.use_checkpoint = use_checkpoint\n        \n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule\n        self.use_rel_pos_bias = use_rel_pos_bias\n        self.blocks = nn.ModuleList([\n            Block(\n                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,\n                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,\n                init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None)\n            for i in range(depth)])\n#         self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)\n#         self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None\n#         self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()\n\n        if self.pos_embed is not None:\n            trunc_normal_(self.pos_embed, std=.02)\n        trunc_normal_(self.cls_token, std=.02)\n        # trunc_normal_(self.mask_token, std=.02)\n#         if isinstance(self.head, nn.Linear):\n#             trunc_normal_(self.head.weight, std=.02)\n        self.apply(self._init_weights)\n        self.fix_init_weight()\n#         if isinstance(self.head, nn.Linear):\n#             self.head.weight.data.mul_(init_scale)\n#             
self.head.bias.data.mul_(init_scale)\n\n    def fix_init_weight(self):\n        def rescale(param, layer_id):\n            param.div_(math.sqrt(2.0 * layer_id))\n\n        for layer_id, layer in enumerate(self.blocks):\n            rescale(layer.attn.proj.weight.data, layer_id + 1)\n            rescale(layer.mlp.fc2.weight.data, layer_id + 1)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n\n    def get_classifier(self):\n        return self.head\n\n    def reset_classifier(self, num_classes, global_pool=''):\n        self.num_classes = num_classes\n        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()\n\n    def forward_features(self, x):\n        x = self.patch_embed(x)\n        batch_size, seq_len, _ = x.size()\n\n        cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks\n        x = torch.cat((cls_tokens, x), dim=1)\n        if self.pos_embed is not None:\n            x = x + self.pos_embed\n        x = self.pos_drop(x)\n\n        rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None\n        for blk in self.blocks:\n            if self.use_checkpoint:\n                x = checkpoint.checkpoint(blk, x, rel_pos_bias)\n            else:\n                x = blk(x, rel_pos_bias)\n        return x\n#         x = self.norm(x)\n\n#         if self.fc_norm is not None:\n#             t = x[:, 1:, :]\n#             return self.fc_norm(t.mean(1))\n#         else:\n#             return x[:, 0]\n\n    def forward(self, x):\n        x = self.forward_features(x)\n#         x = self.head(x)\n        return x\n\n    def get_intermediate_layers(self, x):\n        x = self.patch_embed(x)\n        batch_size, seq_len, _ = x.size()\n\n        cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks\n        x = torch.cat((cls_tokens, x), dim=1)\n        if self.pos_embed is not None:\n            x = x + self.pos_embed\n        x = self.pos_drop(x)\n\n        features = []\n        rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None\n        for blk in self.blocks:\n            x = blk(x, rel_pos_bias)\n            features.append(x)\n\n        return features\n    \n    \ndef interpolate_pos_embed(model, checkpoint_model):\n    if 'pos_embed' in checkpoint_model:\n        pos_embed_checkpoint = checkpoint_model['pos_embed'].float()\n        embedding_size = pos_embed_checkpoint.shape[-1]\n        num_patches = model.patch_embed.num_patches\n        num_extra_tokens = model.pos_embed.shape[-2] - num_patches\n        # height (== width) for the checkpoint position embedding\n        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)\n        # height (== width) for the new position embedding\n        new_size = int(num_patches ** 0.5)\n        # class_token and dist_token are kept unchanged\n        if orig_size != new_size:\n            print(\"Position interpolate from %dx%d to %dx%d\" % (orig_size, orig_size, new_size, new_size))\n            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]\n            # only the position tokens are interpolated\n            pos_tokens = 
pos_embed_checkpoint[:, num_extra_tokens:]\n            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)\n            pos_tokens = torch.nn.functional.interpolate(\n                pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)\n            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)\n            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)\n            checkpoint_model['pos_embed'] = new_pos_embed\n            \n            \ndef convert_weights_to_fp16(model: nn.Module):\n    \"\"\"Convert applicable model parameters to fp16\"\"\"\n\n    def _convert_weights_to_fp16(l):\n        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):\n            l.weight.data = l.weight.data.half()\n            if l.bias is not None:\n                l.bias.data = l.bias.data.half()\n\n#         if isinstance(l, (nn.MultiheadAttention, Attention)):\n#             for attr in [*[f\"{s}_proj_weight\" for s in [\"in\", \"q\", \"k\", \"v\"]], \"in_proj_bias\", \"bias_k\", \"bias_v\"]:\n#                 tensor = getattr(l, attr)\n#                 if tensor is not None:\n#                     tensor.data = tensor.data.half()\n\n    model.apply(_convert_weights_to_fp16)\n    \n    \ndef create_eva_vit_g(img_size=224,drop_path_rate=0.4,use_checkpoint=False,precision=\"fp16\"):\n    model = VisionTransformer(\n        img_size=img_size,\n        patch_size=14,\n        use_mean_pooling=False,\n        embed_dim=1408,\n        depth=39,\n        num_heads=1408//88,\n        mlp_ratio=4.3637,\n        qkv_bias=True,\n        drop_path_rate=drop_path_rate,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6),\n        use_checkpoint=use_checkpoint,\n    )  \n    url = \"https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/eva_vit_g.pth\"\n    cached_file = download_cached_file(\n       url, check_hash=False, progress=True\n    )\n    state_dict = torch.load(cached_file, map_location=\"cpu\")    \n    interpolate_pos_embed(model,state_dict)\n    \n    incompatible_keys = model.load_state_dict(state_dict, strict=False)\n#     print(incompatible_keys)\n    \n    if precision == \"fp16\":\n#         model.to(\"cuda\") \n        convert_weights_to_fp16(model)\n    return model"
  },
  {
    "path": "lavis/models/gpt_models/gpt_dialogue.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport torch\nimport torch.nn as nn\nfrom lavis.common.registry import registry\nfrom lavis.models.base_model import BaseModel\nfrom torch.nn import CrossEntropyLoss, MSELoss\nfrom transformers import GPT2LMHeadModel\nfrom transformers.modeling_outputs import CausalLMOutputWithCrossAttentions\n\n\n@registry.register_model(\"gpt_dialogue\")\nclass GPTDialogue(BaseModel, GPT2LMHeadModel):\n\n    PRETRAINED_MODEL_CONFIG_DICT = {\"base\": \"configs/models/gpt_dialogue_base.yaml\"}\n\n    def __init__(self, config, len_video_ft=4224):\n\n        super().__init__(config)\n\n        self.video_ff = nn.Linear(len_video_ft, config.n_embd)\n        self.video_ff_out = nn.Linear(config.n_embd, len_video_ft)\n\n        # Model parallel\n        self.model_parallel = False\n        self.device_map = None\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def forward(\n        self,\n        samples,\n        past_key_values=None,\n        position_ids=None,\n        head_mask=None,\n        encoder_hidden_states=None,\n        encoder_attention_mask=None,\n        use_cache=None,\n        output_attentions=None,\n        output_hidden_states=None,\n        return_dict=None,\n    ):\n\n        input_embs = self.transformer.wte(samples[\"input_ids\"])\n        video_embs = self.video_ff(samples[\"video_fts\"])\n        input_embs = torch.cat([video_embs, input_embs], dim=1)\n\n        transformer_outputs = self.transformer(\n            attention_mask=samples[\"attn_mask\"],\n            token_type_ids=samples[\"token_type_ids\"],\n            inputs_embeds=input_embs,\n            position_ids=position_ids,\n            head_mask=head_mask,\n            encoder_hidden_states=encoder_hidden_states,\n            encoder_attention_mask=encoder_attention_mask,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n        hidden_states = transformer_outputs[0]\n\n        lm_logits = self.lm_head(hidden_states)\n\n        loss = None\n        if samples[\"labels\"] is not None:\n            # Shift so that tokens < n predict n\n            shift_logits = lm_logits[..., :-1, :].contiguous()\n            shift_labels = samples[\"labels\"][..., 1:].contiguous()\n            # Flatten the tokens\n            loss_fct = CrossEntropyLoss(ignore_index=-1)\n            loss = loss_fct(\n                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)\n            )\n\n        if samples[\"video_fts\"] is not None:\n            len_video_fts = samples[\"video_fts\"].shape[1]\n            video_logits = self.video_ff_out(hidden_states[:, :len_video_fts, :])\n            # Shift so that tokens < n predict n\n            shift_logits = video_logits[..., :-1, :].contiguous()\n            shift_labels = samples[\"video_fts\"][..., 1:, :].contiguous()\n            # Flatten the tokens\n            loss_fct = MSELoss(reduction=\"mean\")\n            video_loss = loss_fct(shift_logits, shift_labels)\n\n            if loss is not None:\n                loss = loss + video_loss\n            else:\n                loss = video_loss\n\n        return CausalLMOutputWithCrossAttentions(\n      
      loss=loss,\n            logits=lm_logits,\n            past_key_values=transformer_outputs.past_key_values,\n            hidden_states=transformer_outputs.hidden_states,\n            attentions=transformer_outputs.attentions,\n            cross_attentions=transformer_outputs.cross_attentions,\n        )\n\n    @classmethod\n    def from_config(cls, cfg):\n        model = cls.__bases__[1].from_pretrained(\"gpt2\")\n        model.resize_token_embeddings(cfg[\"len_tokenizer\"])\n        return model\n"
  },
  {
    "path": "lavis/models/img2prompt_models/__init__.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport torch\n\n\n\n"
  },
  {
    "path": "lavis/models/img2prompt_models/img2prompt_vqa.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport random\n\nimport spacy\nimport torch\nimport torch.nn.functional as F\nfrom transformers import T5ForConditionalGeneration, T5Tokenizer\n\nfrom lavis.common.dist_utils import download_cached_file\nfrom lavis.common.registry import registry\nfrom lavis.models.base_model import BaseModel\nfrom lavis.models.blip_models.blip_image_text_matching import compute_gradcam\n\nopen_pos = [\"NOUN\", \"VERB\", \"ADJ\", \"ADV\", \"NUM\"]\n\n\n\n@registry.register_model(\"img2prompt_vqa\")\nclass Img2PromptVQA(BaseModel):\n    \"\"\"\n    Img2Prompt_VQA model consists of three submodels for zero-shot VQA:\n        1. Image-questioning matching model\n        2. Image captioning model\n        3. Large Language model\n\n    Supported model types:\n        - base: BLIPITM, BLIPCaption, PNPUnifiedQAv2FiD (t5-base)\n        - large: BLIPITM, BLIPCaption, PNPUnifiedQAv2FiD (t5-large)\n        - 3b: BLIPITM, BLIPCaption, PNPUnifiedQAv2FiD (t5-3b)\n\n    Usage:\n        >>> from lavis.models import load_model\n        >>> model = load_model(\"img2prompt_vqa\", \"base\", is_eval=True)\n    \"\"\"\n\n    PRETRAINED_MODEL_CONFIG_DICT = {\n        \"base\": \"configs/models/img2prompt-vqa/img2prompt_vqa_base.yaml\",\n    }\n\n    def __init__(\n        self,\n        image_question_matching_model,\n        image_captioning_model,\n        question_generation_model,\n        question_generation_tokenizer,\n        offload_model=False,\n    ):\n        super().__init__()\n\n        self.image_question_matching_model = image_question_matching_model\n        self.image_captioning_model = image_captioning_model\n        self.question_generation_model = question_generation_model\n        self.question_generation_tokenizer = question_generation_tokenizer\n        self.offload_model = offload_model\n        self.nlp = spacy.load(\"en_core_web_sm\")\n\n    def forward_itm(self, samples, block_num=7):\n        \"\"\"\n        Args:\n            samples (dict): A dictionary containing the following keys:\n                - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W)\n                - text_input (list): A list of strings of length batch_size\n            block_num (int): The index of cross-attention block for gradcam computation.\n\n        Returns:\n            samples (dict): A dictionary containing the following keys:\n                - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W)\n                - text_input (list): A list of strings of length batch_size\n                - gradcams (torch.Tensor): A tensor of shape (batch_size, H*W)\n        \"\"\"\n        image = samples[\"image\"]\n        question = [text.strip(\"?\") for text in samples[\"text_input\"]]\n        tokenized_text = self.image_question_matching_model.tokenizer(\n            question, padding=\"longest\", truncation=True, return_tensors=\"pt\"\n        ).to(self.image_question_matching_model.device)\n        with torch.set_grad_enabled(True):\n            gradcams, _ = compute_gradcam(\n                model=self.image_question_matching_model,\n                visual_input=image,\n                text_input=question,\n                tokenized_text=tokenized_text,\n                block_num=block_num,\n            )\n\n        gradcams = [gradcam_[1] for gradcam_ in 
gradcams]\n        samples[\"gradcams\"] = torch.stack(gradcams).reshape(\n            samples[\"image\"].size(0), -1\n        )\n\n        return samples\n\n    def itm_rank(self, image_embeds, image_atts, encoder_input_ids, match_head=\"itm\"):\n        # breakpoint()\n        encoder_input_ids = encoder_input_ids.clone()\n        encoder_input_ids = encoder_input_ids[:, self.prompt_length - 1 :]\n        text_attention_mask = (encoder_input_ids != self.tokenizer.pad_token_id).long()\n\n        if match_head == \"itm\":\n            # encoder_input_ids = encoder_input_ids.clone()\n            encoder_input_ids[:, 0] = self.tokenizer.enc_token_id\n            output = self.text_encoder(\n                encoder_input_ids,\n                attention_mask=text_attention_mask,\n                encoder_hidden_states=image_embeds,\n                encoder_attention_mask=image_atts,\n                return_dict=True,\n            )\n            itm_output = self.itm_head(output.last_hidden_state[:, 0, :])\n            return itm_output  # , mask, token_length\n\n        elif match_head == \"itc\":\n            encoder_input_ids[:, 0] = self.tokenizer.cls_token_id\n            text_output = self.text_encoder(\n                encoder_input_ids,\n                attention_mask=text_attention_mask,\n                return_dict=True,\n                mode=\"text\",\n            )\n            image_feat = F.normalize(self.vision_proj(image_embeds[:, 0, :]), dim=-1)\n            text_feat = F.normalize(\n                self.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1\n            )\n\n            sim = image_feat @ text_feat.t()\n            return sim\n\n    def forward_cap(\n        self,\n        samples,\n        cap_max_length=20,\n        cap_min_length=0,\n        top_p=1,\n        top_k=50,\n        repetition_penalty=1.0,\n        num_captions=100,\n        num_patches=20,\n    ):\n        \"\"\"\n        Args:\n            samples (dict): A dictionary containing the following keys:\n                - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W)\n                - text_input (list): A list of strings of length batch_size\n                - gradcams (torch.Tensor): A tensor of shape (batch_size, H*W)\n            cap_max_length (int): The maximum length of the caption to be generated.\n            cap_min_length (int): The minimum length of the caption to be generated.\n            top_p (float): The cumulative probability for nucleus sampling.\n            top_k (float): The number of the highest probability tokens for top-k sampling.\n            repetition_penalty (float): The parameter for repetition penalty. 
1.0 means no penalty.\n            num_captions (int): Number of captions generated for each image.\n            num_patches (int): Number of patches sampled for each image.\n\n        Returns:\n            samples (dict): A dictionary containing the following keys:\n                - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W)\n                - text_input (list): A list of strings of length batch_size\n                - gradcams (torch.Tensor): A tensor of shape (batch_size, H*W)\n                - captions (nested list): A nested list of strings of total length batch_size * num_captions\n        \"\"\"\n        encoder_out = self.image_captioning_model.forward_encoder(samples)\n        captions = [[] for _ in range(encoder_out.size(0))]\n\n        min_num_captions = 0\n\n        while min_num_captions < num_captions:\n            encoder_out_samples = []\n            for i in range(num_captions):\n                patch_id = (\n                    torch.multinomial(\n                        samples[\"gradcams\"].to(self.image_captioning_model.device),\n                        num_patches,\n                    ).reshape(encoder_out.size(0), -1)\n                    + 1\n                )\n                patch_id = (\n                    patch_id.sort(dim=1)\n                    .values.unsqueeze(-1)\n                    .expand(-1, -1, encoder_out.size(2))\n                )\n                encoder_out_sample = torch.gather(encoder_out, 1, patch_id)\n                encoder_out_samples.append(encoder_out_sample)\n\n            stacked = torch.stack(encoder_out_samples, dim=1)\n            image_embeds = torch.flatten(\n                stacked, start_dim=0, end_dim=1\n            )  # (bsz*num_seq, num_patch, dim)\n\n            image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(\n                self.image_captioning_model.device\n            )\n            model_kwargs = {\n                \"encoder_hidden_states\": image_embeds,\n                \"encoder_attention_mask\": image_atts,\n            }\n\n            prompt = [self.image_captioning_model.prompt] * image_embeds.size(0)\n            prompt = self.image_captioning_model.tokenizer(\n                prompt, return_tensors=\"pt\"\n            ).to(self.image_captioning_model.device)\n            prompt.input_ids[:, 0] = self.image_captioning_model.tokenizer.bos_token_id\n            prompt.input_ids = prompt.input_ids[:, :-1]\n\n            decoder_out = self.image_captioning_model.text_decoder.generate(\n                input_ids=prompt.input_ids,\n                max_length=cap_max_length,\n                min_length=cap_min_length,\n                do_sample=True,\n                top_p=top_p,\n                top_k=top_k,\n                num_return_sequences=1,\n                eos_token_id=self.image_captioning_model.tokenizer.sep_token_id,\n                pad_token_id=self.image_captioning_model.tokenizer.pad_token_id,\n                repetition_penalty=repetition_penalty,\n                **model_kwargs\n            )\n\n            itm_outputs = self.image_question_matching_model.itm_rank(\n                image_embeds, image_atts, encoder_input_ids=decoder_out\n            )  # caption filter\n\n            outputs = self.image_captioning_model.tokenizer.batch_decode(\n                decoder_out, skip_special_tokens=True\n            )\n\n            for counter, output in enumerate(outputs):\n                ind = counter // num_captions\n                if len(captions[ind]) 
< num_captions:\n                    caption = output[len(self.image_captioning_model.prompt) :]\n                    overlap_caption = [1 for caps in captions[ind] if caption in caps]\n                    # print(itm_outputs)\n                    if (\n                        len(overlap_caption) == 0 and itm_outputs[counter] >= 0.5\n                    ):  # image filter\n                        captions[ind].append(caption)\n\n            min_num_captions = min([len(i) for i in captions])\n\n        samples[\"captions\"] = captions\n\n        return samples\n\n    def answer_extraction(self, caption, num_question_generation=30):\n        cap_use = \"\"\n        # print(caption)\n        caption = caption\n        ans_to_cap_dict = {}\n        answers = []\n        for cap_idx, cap in enumerate(caption):\n            # print(cap)\n            cap_use += cap\n            cap = cap.strip().strip(\".\")\n            # print(cap)\n            cap = self.nlp(cap)\n            for token in cap:  # Noun /Verb/Adj//NUM\n                if token.pos_ in open_pos:\n                    if token.text.lower() not in ans_to_cap_dict:\n                        ans_to_cap_dict[token.text.lower()] = [cap_idx]\n                    else:\n                        if cap_idx not in ans_to_cap_dict[token.text.lower()]:\n                            ans_to_cap_dict[token.text.lower()].append(cap_idx)\n                    answers.append(token.text)\n            for ent in cap.ents:\n\n                if ent.text not in answers:\n                    if ent.text.lower() not in ans_to_cap_dict:\n                        ans_to_cap_dict[ent.text.lower()] = [cap_idx]\n                    else:\n                        if cap_idx not in ans_to_cap_dict[ent.text.lower()]:\n                            ans_to_cap_dict[ent.text.lower()].append(cap_idx)\n                    answers.append(ent.text)\n            for chunk in cap.noun_chunks:\n                if len(chunk.text.split()) < 4:\n                    if chunk.text.lower() not in ans_to_cap_dict:\n                        ans_to_cap_dict[chunk.text.lower()] = [cap_idx]\n                    else:\n                        if cap_idx not in ans_to_cap_dict[chunk.text.lower()]:\n                            ans_to_cap_dict[chunk.text.lower()].append(cap_idx)\n                    #                 print(chunk.text)\n                    answers.append(chunk.text)\n        answers = sorted(answers, key=answers.count, reverse=True)\n        real_answers = []\n        for i in answers:\n            i = i + \".\"\n            if i not in real_answers:\n                real_answers.append(i)\n\n        contexts_for_question_generation = []\n        answers = []\n        for ans in real_answers[\n            :num_question_generation\n        ]:  # Generate questions for 30 answers with max frequencies.\n            contexts_for_question_generation.append(\n                \"answer: %s  context: %s.\" % (ans, cap_use)\n            )\n            answers.append(ans)\n        contexts_for_question_generation.append(\n            \"answer: %s  context: %s.\" % (\"yes.\", cap_use)\n        )\n        answers.append(\"yes.\")\n        return contexts_for_question_generation, answers, ans_to_cap_dict\n\n    def forward_qa_generation(self, samples):\n        caption = samples[\"captions\"][0]\n        (\n            contexts_for_question_generation,\n            answers,\n            ans_to_cap_dict,\n        ) = self.answer_extraction(caption)\n        inputs = 
self.question_generation_tokenizer(\n            contexts_for_question_generation,\n            padding=\"longest\",\n            truncation=True,\n            max_length=2048,\n            return_tensors=\"pt\",\n        ).to(self.device)\n        question_size = inputs.input_ids.shape[0]\n        cur_b = 0\n        true_input_size = 10\n        outputs_list = []\n        while cur_b < question_size:\n            outputs = self.question_generation_model.generate(\n                input_ids=inputs.input_ids[cur_b : cur_b + true_input_size],\n                attention_mask=inputs.attention_mask[cur_b : cur_b + true_input_size],\n                num_beams=3,\n                max_length=30,\n            )\n            questions = self.question_generation_tokenizer.batch_decode(\n                outputs, skip_special_tokens=True\n            )\n            outputs_list += questions\n            cur_b += true_input_size\n        questions = outputs_list\n        samples[\"questions\"] = questions\n        samples[\"answers\"] = answers\n        samples[\"ans_to_cap_dict\"] = ans_to_cap_dict\n        # results.append({\"question_id\": ques_id, \"question\":questions,\"answer\":answers})\n        return samples\n\n    def create_context_prompt(self, samples, num_caps_per_img=30):\n        ans_dict_queid = samples[\"ans_to_cap_dict\"]\n        # print(ans_dict_queid)\n        caption = samples[\"captions\"][0]\n        answers = samples[\"answers\"]\n        Context_Prompt = \"\"\n        mycontexts_id = []\n        for idx in range(num_caps_per_img):\n            cap_id_list = ans_dict_queid.get(\n                answers[(len(answers) - 1 - idx) % len(answers)][:-1].lower(), [0]\n            )\n            for cap_id in cap_id_list:\n                if cap_id not in mycontexts_id:\n                    Context_Prompt += caption[cap_id]\n                    mycontexts_id.append(cap_id)\n                    break  # We just take one cap for each answer\n        samples[\"Context_Prompt\"] = Context_Prompt\n        return Context_Prompt\n\n    def create_task_prompt(\n        self, samples, question_type=\"neural\", num_question_per_img=30\n    ):\n        syn_question_queid = samples[\"questions\"]\n        syn_ans_queid = samples[\"answers\"]\n        Task_Prompt = \"\"\n        for idx in range(num_question_per_img):\n            # if config['random_question']:\n            #     qa_idx = random.randint(0, len(syn_question_queid) - 1)\n            # else:\n            qa_idx = idx\n            if (\n                question_type != \"rule\" and num_question_per_img > 0 and idx < 1\n            ):  ## yes and no questions for vqav2\n                # Task_Prompt += \"Question:\"\n                # Task_Prompt += syn_question_queid_next[-1]\n                # Task_Prompt += '\\n'\n                # Task_Prompt += \"Answer:no\\n\"\n                Task_Prompt += \"Question:\"\n                Task_Prompt += syn_question_queid[-1]\n                Task_Prompt += \"\\n\"\n                Task_Prompt += \"Answer:\"\n                Task_Prompt += \"yes\\n\"\n                Task_Prompt += \"Question:Is this a toilet?\\n\"\n                Task_Prompt += \"Answer:no\\n\"\n            if \"question_type\" == \"rule\":  # Rule-Based Question Generation\n                Noun_Questions = [\n                    \"What item is this in this picture?\",\n                    \"What item is that in this picture?\",\n                ]\n\n                Verb_Questions = [\n                    \"What action is being 
done in this picture?\",\n                    \"Why is this item doing in this picture?\",\n                    \"Which action is being taken in this picture?\",\n                    \"What action is item doing in this picture?\",\n                    \"What action is item performing in this picture?\",\n                ]\n\n                Adj_Questions = [\n                    \"How to describe one item in this picture?\",\n                    \"What is item's ADJ TYPE in this picture?\",\n                    \"What is the ADJ TYPE in this picture?\",\n                ]\n\n                Task_Prompt += \"Question:\"\n                doc = self.nlp(syn_ans_queid[(qa_idx) % len(syn_ans_queid)][:-1].lower())\n                if doc[-1].pos_ == \"NOUN\":\n                    Task_Prompt += Noun_Questions[\n                        random.randint(0, len(Noun_Questions) - 1)\n                    ]\n                elif doc[-1].pos_ == \"VERB\":\n                    Task_Prompt += Verb_Questions[\n                        random.randint(0, len(Verb_Questions) - 1)\n                    ]\n                elif doc[-1].pos_ == \"ADJ\":\n                    Task_Prompt += Adj_Questions[\n                        random.randint(0, len(Adj_Questions) - 1)\n                    ]\n\n                Task_Prompt += \"\\n\"\n\n                Task_Prompt += \"Answer:\"\n                Task_Prompt += syn_ans_queid[(qa_idx) % len(syn_ans_queid)][:-1].lower()\n                Task_Prompt += \"\\n\"\n        samples[\"Task_Prompt\"] = Task_Prompt\n        # print(Task_Prompt)\n        return Task_Prompt\n\n    def prompts_construction(\n        self,\n        samples,\n        question_type=\"neural\",\n        num_caps_per_img=30,\n        num_question_per_img=30,\n    ):\n        Prompt = \"Please reason the answer of the questions according to the given contexts.\\n\"\n\n        Context_Prompt = self.create_context_prompt(samples, num_caps_per_img)\n\n        Task_Prompt = self.create_task_prompt(\n            samples, question_type, num_question_per_img\n        )\n\n        Img2Prompt = (\n            Prompt\n            + \"Contexts:\"\n            + Context_Prompt\n            + \"\\n\"\n            + Task_Prompt\n            + \"Question:\"\n            + samples[\"text_input\"][0]\n            + \"\\nAnswer:\"\n        )\n        return Img2Prompt\n\n    def prepare_LLM_input(\n        self,\n        samples,\n        num_beams=1,\n        inference_method=\"generate\",\n        max_len=20,\n        min_len=0,\n        internal_bsz_fid=1,\n        num_captions=50,\n        num_captions_fid=1,\n        cap_max_length=20,\n        cap_min_length=10,\n        top_k=50,\n        top_p=1,\n        repetition_penalty=1,\n        num_patches=20,\n        block_num=7,\n    ):\n        \"\"\"\n        Args:\n            samples (dict): A dictionary containing the following keys:\n                - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W). Default H=480, W=480.\n                - text_input (str or [str]): String or a list of strings, each string is a question.\n                                             The number of questions must be equal to the batch size. If a single string, will be converted to a list of string, with length 1 first.\n            num_beams (int): Number of beams for beam search. 1 means no beam search.\n            inference_method (str): Inference method. Must be \"generate\". 
The model will generate answers.\n            max_len (int): Maximum length of generated answers.\n            min_len (int): Minimum length of generated answers.\n            internal_bsz_fid (int): Internal batch size when using FiD decoding.\n            num_captions (int): Number of captions generated for each image.\n            num_captions_fid (int): Number of captions concatenated with a question during FiD decoding.\n            cap_max_length (int): The maximum length of the caption to be generated.\n            cap_min_length (int): The minimum length of the caption to be generated.\n            top_k (float): The number of the highest probability tokens for top-k sampling.\n            top_p (float): The cumulative probability for nucleus sampling.\n            repetition_penalty (float): The parameter for repetition penalty. 1.0 means no penalty.\n            num_patches (int): Number of patches sampled for each image.\n            block_num (int): The index of cross-attention block for gradcam computation.\n\n        Returns:\n            List: A list of strings, each string is an answer.\n            gradcams (torch.Tensor): A tensor of shape (batch_size, H*W)\n            captions (nested list): A nested list of strings of total length batch_size * num_captions\n        \"\"\"\n        assert inference_method in [\n            \"generate\",\n        ], \"Inference method must be 'generate', got {}.\".format(inference_method)\n\n        if isinstance(samples[\"text_input\"], str):\n            samples[\"text_input\"] = [samples[\"text_input\"]]\n\n        assert len(samples[\"text_input\"]) == samples[\"image\"].size(\n            0\n        ), \"The number of questions must be equal to the batch size.\"\n\n        samples = self.forward_itm(samples, block_num=block_num)\n\n        samples = self.forward_cap(\n            samples,\n            cap_max_length=cap_max_length,\n            cap_min_length=cap_min_length,\n            top_k=top_k,\n            top_p=top_p,\n            repetition_penalty=repetition_penalty,\n            num_captions=num_captions,\n            num_patches=num_patches,\n        )\n\n        if self.offload_model:\n            samples[\"image\"] = samples[\"image\"].to(\"cpu\")\n            self.image_question_matching_model.to(\"cpu\")\n            self.image_captioning_model.to(\"cpu\")\n        torch.cuda.empty_cache()\n\n        pred_answers = self.forward_qa(\n            samples,\n            num_beams=num_beams,\n            max_len=max_len,\n            min_len=min_len,\n            internal_bsz_fid=internal_bsz_fid,\n            num_captions=num_captions,\n            num_captions_fid=num_captions_fid,\n        )\n\n        if self.offload_model:\n            self.image_question_matching_model.to(self.question_answering_model.device)\n            self.image_captioning_model.to(self.question_answering_model.device)\n\n        return pred_answers, samples[\"captions\"], samples[\"gradcams\"]\n\n    @classmethod\n    def from_config(cls, model_config):\n        itm_config = model_config.image_question_matching_model\n        cap_config = model_config.image_captioning_model\n\n        itm_cls = registry.get_model_class(itm_config.arch)\n        cap_cls = registry.get_model_class(cap_config.arch)\n\n        image_question_matching_model = itm_cls.from_config(itm_config)\n        image_captioning_model = cap_cls.from_config(cap_config)\n\n        question_generation_tokenizer = T5Tokenizer.from_pretrained(\n            
\"google/t5-large-lm-adapt\"\n        )\n        question_generation_model = T5ForConditionalGeneration.from_pretrained(\n            \"google/t5-large-lm-adapt\"\n        )\n        cached_file = download_cached_file(\n            \"https://storage.googleapis.com/sfr-vision-language-research/LAVIS/projects/img2prompt/T5_large_QG.pth\",\n            check_hash=False,\n            progress=True,\n        )\n        checkpoint = torch.load(cached_file, map_location=\"cpu\")\n        state_dict = checkpoint[\"model\"]\n        question_generation_model.load_state_dict(state_dict)\n        model = cls(\n            image_question_matching_model=image_question_matching_model,\n            image_captioning_model=image_captioning_model,\n            question_generation_model=question_generation_model,\n            question_generation_tokenizer=question_generation_tokenizer,\n            offload_model=False,\n        )\n\n        return model\n"
  },
  {
    "path": "lavis/models/med.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n \n Based on huggingface code base\n https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert\n\"\"\"\n\nimport math\nimport os\nimport warnings\nfrom dataclasses import dataclass\nfrom typing import Optional, Tuple\n\nimport torch\nfrom torch import Tensor, device\nimport torch.utils.checkpoint\nfrom torch import nn\nfrom torch.nn import CrossEntropyLoss\nimport torch.nn.functional as F\nfrom transformers import BatchEncoding, PreTrainedTokenizer\n\nfrom transformers.activations import ACT2FN\nfrom transformers.file_utils import (\n    ModelOutput,\n)\nfrom transformers.modeling_outputs import (\n    BaseModelOutputWithPastAndCrossAttentions,\n    BaseModelOutputWithPoolingAndCrossAttentions,\n    CausalLMOutputWithCrossAttentions,\n    MaskedLMOutput,\n    MultipleChoiceModelOutput,\n    NextSentencePredictorOutput,\n    QuestionAnsweringModelOutput,\n    SequenceClassifierOutput,\n    TokenClassifierOutput,\n)\nfrom transformers.modeling_utils import (\n    PreTrainedModel,\n    apply_chunking_to_forward,\n    find_pruneable_heads_and_indices,\n    prune_linear_layer,\n)\nfrom transformers.utils import logging\nfrom transformers.models.bert.configuration_bert import BertConfig\nfrom lavis.common.utils import get_abs_path\n\nfrom lavis.models.base_model import BaseEncoder\n\nlogging.set_verbosity_error()\nlogger = logging.get_logger(__name__)\n\n\nclass BertEmbeddings(nn.Module):\n    \"\"\"Construct the embeddings from word and position embeddings.\"\"\"\n\n    def __init__(self, config):\n        super().__init__()\n        self.word_embeddings = nn.Embedding(\n            config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id\n        )\n        self.position_embeddings = nn.Embedding(\n            config.max_position_embeddings, config.hidden_size\n        )\n\n        if config.add_type_embeddings:\n            self.token_type_embeddings = nn.Embedding(\n                config.type_vocab_size, config.hidden_size\n            )\n\n        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load\n        # any TensorFlow checkpoint file\n        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)\n        self.dropout = nn.Dropout(config.hidden_dropout_prob)\n\n        # position_ids (1, len position emb) is contiguous in memory and exported when serialized\n        self.register_buffer(\n            \"position_ids\", torch.arange(config.max_position_embeddings).expand((1, -1))\n        )\n        self.position_embedding_type = getattr(\n            config, \"position_embedding_type\", \"absolute\"\n        )\n\n        self.config = config\n\n    def forward(\n        self,\n        input_ids=None,\n        token_type_ids=None,\n        position_ids=None,\n        inputs_embeds=None,\n        past_key_values_length=0,\n    ):\n        if input_ids is not None:\n            input_shape = input_ids.size()\n        else:\n            input_shape = inputs_embeds.size()[:-1]\n\n        seq_length = input_shape[1]\n\n        if position_ids is None:\n            position_ids = self.position_ids[\n                :, past_key_values_length : seq_length + past_key_values_length\n            ]\n\n        if inputs_embeds is None:\n            
inputs_embeds = self.word_embeddings(input_ids)\n\n        if token_type_ids is not None:\n            token_type_embeddings = self.token_type_embeddings(token_type_ids)\n\n            embeddings = inputs_embeds + token_type_embeddings\n        else:\n            embeddings = inputs_embeds\n\n        if self.position_embedding_type == \"absolute\":\n            position_embeddings = self.position_embeddings(position_ids)\n            embeddings += position_embeddings\n        embeddings = self.LayerNorm(embeddings)\n        embeddings = self.dropout(embeddings)\n        return embeddings\n\n\nclass BertSelfAttention(nn.Module):\n    def __init__(self, config, is_cross_attention):\n        super().__init__()\n        self.config = config\n        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(\n            config, \"embedding_size\"\n        ):\n            raise ValueError(\n                \"The hidden size (%d) is not a multiple of the number of attention \"\n                \"heads (%d)\" % (config.hidden_size, config.num_attention_heads)\n            )\n\n        self.num_attention_heads = config.num_attention_heads\n        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)\n        self.all_head_size = self.num_attention_heads * self.attention_head_size\n\n        self.query = nn.Linear(config.hidden_size, self.all_head_size)\n        if is_cross_attention:\n            self.key = nn.Linear(config.encoder_width, self.all_head_size)\n            self.value = nn.Linear(config.encoder_width, self.all_head_size)\n        else:\n            self.key = nn.Linear(config.hidden_size, self.all_head_size)\n            self.value = nn.Linear(config.hidden_size, self.all_head_size)\n\n        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)\n        self.position_embedding_type = getattr(\n            config, \"position_embedding_type\", \"absolute\"\n        )\n        if (\n            self.position_embedding_type == \"relative_key\"\n            or self.position_embedding_type == \"relative_key_query\"\n        ):\n            self.max_position_embeddings = config.max_position_embeddings\n            self.distance_embedding = nn.Embedding(\n                2 * config.max_position_embeddings - 1, self.attention_head_size\n            )\n        self.save_attention = False\n\n    def save_attn_gradients(self, attn_gradients):\n        self.attn_gradients = attn_gradients\n\n    def get_attn_gradients(self):\n        return self.attn_gradients\n\n    def save_attention_map(self, attention_map):\n        self.attention_map = attention_map\n\n    def get_attention_map(self):\n        return self.attention_map\n\n    def transpose_for_scores(self, x):\n        new_x_shape = x.size()[:-1] + (\n            self.num_attention_heads,\n            self.attention_head_size,\n        )\n        x = x.view(*new_x_shape)\n        return x.permute(0, 2, 1, 3)\n\n    def forward(\n        self,\n        hidden_states,\n        attention_mask=None,\n        head_mask=None,\n        encoder_hidden_states=None,\n        encoder_attention_mask=None,\n        past_key_value=None,\n        output_attentions=False,\n    ):\n        mixed_query_layer = self.query(hidden_states)\n\n        # If this is instantiated as a cross-attention module, the keys\n        # and values come from an encoder; the attention mask needs to be\n        # such that the encoder's padding tokens are not attended to.\n        is_cross_attention = encoder_hidden_states is 
not None\n\n        if is_cross_attention:\n            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))\n            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))\n            attention_mask = encoder_attention_mask\n        elif past_key_value is not None:\n            key_layer = self.transpose_for_scores(self.key(hidden_states))\n            value_layer = self.transpose_for_scores(self.value(hidden_states))\n            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)\n            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)\n        else:\n            key_layer = self.transpose_for_scores(self.key(hidden_states))\n            value_layer = self.transpose_for_scores(self.value(hidden_states))\n\n        query_layer = self.transpose_for_scores(mixed_query_layer)\n\n        past_key_value = (key_layer, value_layer)\n\n        # Take the dot product between \"query\" and \"key\" to get the raw attention scores.\n        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))\n\n        if (\n            self.position_embedding_type == \"relative_key\"\n            or self.position_embedding_type == \"relative_key_query\"\n        ):\n            seq_length = hidden_states.size()[1]\n            position_ids_l = torch.arange(\n                seq_length, dtype=torch.long, device=hidden_states.device\n            ).view(-1, 1)\n            position_ids_r = torch.arange(\n                seq_length, dtype=torch.long, device=hidden_states.device\n            ).view(1, -1)\n            distance = position_ids_l - position_ids_r\n            positional_embedding = self.distance_embedding(\n                distance + self.max_position_embeddings - 1\n            )\n            positional_embedding = positional_embedding.to(\n                dtype=query_layer.dtype\n            )  # fp16 compatibility\n\n            if self.position_embedding_type == \"relative_key\":\n                relative_position_scores = torch.einsum(\n                    \"bhld,lrd->bhlr\", query_layer, positional_embedding\n                )\n                attention_scores = attention_scores + relative_position_scores\n            elif self.position_embedding_type == \"relative_key_query\":\n                relative_position_scores_query = torch.einsum(\n                    \"bhld,lrd->bhlr\", query_layer, positional_embedding\n                )\n                relative_position_scores_key = torch.einsum(\n                    \"bhrd,lrd->bhlr\", key_layer, positional_embedding\n                )\n                attention_scores = (\n                    attention_scores\n                    + relative_position_scores_query\n                    + relative_position_scores_key\n                )\n\n        attention_scores = attention_scores / math.sqrt(self.attention_head_size)\n        if attention_mask is not None:\n            # Apply the attention mask is (precomputed for all layers in BertModel forward() function)\n            attention_scores = attention_scores + attention_mask\n\n        # Normalize the attention scores to probabilities.\n        attention_probs = nn.Softmax(dim=-1)(attention_scores)\n\n        if is_cross_attention and self.save_attention:\n            self.save_attention_map(attention_probs)\n            attention_probs.register_hook(self.save_attn_gradients)\n\n        # This is actually dropping out entire tokens to attend to, which might\n        # seem a bit unusual, but is taken from the 
original Transformer paper.\n        attention_probs_dropped = self.dropout(attention_probs)\n\n        # Mask heads if we want to\n        if head_mask is not None:\n            attention_probs_dropped = attention_probs_dropped * head_mask\n\n        context_layer = torch.matmul(attention_probs_dropped, value_layer)\n\n        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()\n        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)\n        context_layer = context_layer.view(*new_context_layer_shape)\n\n        outputs = (\n            (context_layer, attention_probs) if output_attentions else (context_layer,)\n        )\n\n        outputs = outputs + (past_key_value,)\n        return outputs\n\n\nclass BertSelfOutput(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.dense = nn.Linear(config.hidden_size, config.hidden_size)\n        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)\n        self.dropout = nn.Dropout(config.hidden_dropout_prob)\n\n    def forward(self, hidden_states, input_tensor):\n        hidden_states = self.dense(hidden_states)\n        hidden_states = self.dropout(hidden_states)\n        hidden_states = self.LayerNorm(hidden_states + input_tensor)\n        return hidden_states\n\n\nclass BertAttention(nn.Module):\n    def __init__(self, config, is_cross_attention=False):\n        super().__init__()\n        self.self = BertSelfAttention(config, is_cross_attention)\n        self.output = BertSelfOutput(config)\n        self.pruned_heads = set()\n\n    def prune_heads(self, heads):\n        if len(heads) == 0:\n            return\n        heads, index = find_pruneable_heads_and_indices(\n            heads,\n            self.self.num_attention_heads,\n            self.self.attention_head_size,\n            self.pruned_heads,\n        )\n\n        # Prune linear layers\n        self.self.query = prune_linear_layer(self.self.query, index)\n        self.self.key = prune_linear_layer(self.self.key, index)\n        self.self.value = prune_linear_layer(self.self.value, index)\n        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)\n\n        # Update hyper params and store pruned heads\n        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)\n        self.self.all_head_size = (\n            self.self.attention_head_size * self.self.num_attention_heads\n        )\n        self.pruned_heads = self.pruned_heads.union(heads)\n\n    def forward(\n        self,\n        hidden_states,\n        attention_mask=None,\n        head_mask=None,\n        encoder_hidden_states=None,\n        encoder_attention_mask=None,\n        past_key_value=None,\n        output_attentions=False,\n    ):\n        self_outputs = self.self(\n            hidden_states,\n            attention_mask,\n            head_mask,\n            encoder_hidden_states,\n            encoder_attention_mask,\n            past_key_value,\n            output_attentions,\n        )\n        attention_output = self.output(self_outputs[0], hidden_states)\n        outputs = (attention_output,) + self_outputs[\n            1:\n        ]  # add attentions if we output them\n        return outputs\n\n\nclass BertIntermediate(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)\n        if isinstance(config.hidden_act, str):\n            self.intermediate_act_fn = 
ACT2FN[config.hidden_act]\n        else:\n            self.intermediate_act_fn = config.hidden_act\n\n    def forward(self, hidden_states):\n        hidden_states = self.dense(hidden_states)\n        hidden_states = self.intermediate_act_fn(hidden_states)\n        return hidden_states\n\n\nclass BertOutput(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)\n        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)\n        self.dropout = nn.Dropout(config.hidden_dropout_prob)\n\n    def forward(self, hidden_states, input_tensor):\n        hidden_states = self.dense(hidden_states)\n        hidden_states = self.dropout(hidden_states)\n        hidden_states = self.LayerNorm(hidden_states + input_tensor)\n        return hidden_states\n\n\nclass BertLayer(nn.Module):\n    def __init__(self, config, layer_num):\n        super().__init__()\n        self.config = config\n        self.chunk_size_feed_forward = config.chunk_size_feed_forward\n        self.seq_len_dim = 1\n        self.attention = BertAttention(config)\n        self.layer_num = layer_num\n\n        # compatibility for ALBEF and BLIP\n        try:\n            # ALBEF & ALPRO\n            fusion_layer = self.config.fusion_layer\n            add_cross_attention = (\n                fusion_layer <= layer_num and self.config.add_cross_attention\n            )\n\n            self.fusion_layer = fusion_layer\n        except AttributeError:\n            # BLIP\n            self.fusion_layer = self.config.num_hidden_layers\n            add_cross_attention = self.config.add_cross_attention\n\n        # if self.config.add_cross_attention:\n        if add_cross_attention:\n            self.crossattention = BertAttention(\n                config, is_cross_attention=self.config.add_cross_attention\n            )\n        self.intermediate = BertIntermediate(config)\n        self.output = BertOutput(config)\n\n    def forward(\n        self,\n        hidden_states,\n        attention_mask=None,\n        head_mask=None,\n        encoder_hidden_states=None,\n        encoder_attention_mask=None,\n        past_key_value=None,\n        output_attentions=False,\n        mode=None,\n    ):\n        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2\n        self_attn_past_key_value = (\n            past_key_value[:2] if past_key_value is not None else None\n        )\n        self_attention_outputs = self.attention(\n            hidden_states,\n            attention_mask,\n            head_mask,\n            output_attentions=output_attentions,\n            past_key_value=self_attn_past_key_value,\n        )\n        attention_output = self_attention_outputs[0]\n\n        outputs = self_attention_outputs[1:-1]\n        present_key_value = self_attention_outputs[-1]\n\n        # TODO line 482 in albef/models/xbert.py\n        # compatibility for ALBEF and BLIP\n        if mode in [\"multimodal\", \"fusion\"] and hasattr(self, \"crossattention\"):\n            assert (\n                encoder_hidden_states is not None\n            ), \"encoder_hidden_states must be given for cross-attention layers\"\n\n            if isinstance(encoder_hidden_states, list):\n                cross_attention_outputs = self.crossattention(\n                    attention_output,\n                    attention_mask,\n                    head_mask,\n                    encoder_hidden_states[\n                        
(self.layer_num - self.fusion_layer)\n                        % len(encoder_hidden_states)\n                    ],\n                    encoder_attention_mask[\n                        (self.layer_num - self.fusion_layer)\n                        % len(encoder_hidden_states)\n                    ],\n                    output_attentions=output_attentions,\n                )\n                attention_output = cross_attention_outputs[0]\n                outputs = outputs + cross_attention_outputs[1:-1]\n\n            else:\n                cross_attention_outputs = self.crossattention(\n                    attention_output,\n                    attention_mask,\n                    head_mask,\n                    encoder_hidden_states,\n                    encoder_attention_mask,\n                    output_attentions=output_attentions,\n                )\n                attention_output = cross_attention_outputs[0]\n                outputs = (\n                    outputs + cross_attention_outputs[1:-1]\n                )  # add cross attentions if we output attention weights\n        layer_output = apply_chunking_to_forward(\n            self.feed_forward_chunk,\n            self.chunk_size_feed_forward,\n            self.seq_len_dim,\n            attention_output,\n        )\n        outputs = (layer_output,) + outputs\n\n        outputs = outputs + (present_key_value,)\n\n        return outputs\n\n    def feed_forward_chunk(self, attention_output):\n        intermediate_output = self.intermediate(attention_output)\n        layer_output = self.output(intermediate_output, attention_output)\n        return layer_output\n\n\nclass BertEncoder(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.config = config\n        self.layer = nn.ModuleList(\n            [BertLayer(config, i) for i in range(config.num_hidden_layers)]\n        )\n        self.gradient_checkpointing = False\n\n    def forward(\n        self,\n        hidden_states,\n        attention_mask=None,\n        head_mask=None,\n        encoder_hidden_states=None,\n        encoder_attention_mask=None,\n        past_key_values=None,\n        use_cache=None,\n        output_attentions=False,\n        output_hidden_states=False,\n        return_dict=True,\n        mode=\"multimodal\",\n    ):\n        all_hidden_states = () if output_hidden_states else None\n        all_self_attentions = () if output_attentions else None\n        all_cross_attentions = (\n            () if output_attentions and self.config.add_cross_attention else None\n        )\n\n        next_decoder_cache = () if use_cache else None\n\n        try:\n            # ALBEF\n            fusion_layer = self.config.fusion_layer\n        except AttributeError:\n            # BLIP\n            fusion_layer = self.config.num_hidden_layers\n\n        if mode == \"text\":\n            start_layer = 0\n            # output_layer = self.config.fusion_layer\n            output_layer = fusion_layer\n\n        elif mode == \"fusion\":\n            # start_layer = self.config.fusion_layer\n            start_layer = fusion_layer\n            output_layer = self.config.num_hidden_layers\n\n        elif mode == \"multimodal\":\n            start_layer = 0\n            output_layer = self.config.num_hidden_layers\n\n        # compatibility for ALBEF and BLIP\n        # for i in range(self.config.num_hidden_layers):\n        for i in range(start_layer, output_layer):\n            layer_module = self.layer[i]\n            if output_hidden_states:\n     
           all_hidden_states = all_hidden_states + (hidden_states,)\n\n            layer_head_mask = head_mask[i] if head_mask is not None else None\n            past_key_value = past_key_values[i] if past_key_values is not None else None\n\n            # TODO pay attention to this.\n            if self.gradient_checkpointing and self.training:\n\n                if use_cache:\n                    logger.warn(\n                        \"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...\"\n                    )\n                    use_cache = False\n\n                def create_custom_forward(module):\n                    def custom_forward(*inputs):\n                        return module(*inputs, past_key_value, output_attentions)\n\n                    return custom_forward\n\n                layer_outputs = torch.utils.checkpoint.checkpoint(\n                    create_custom_forward(layer_module),\n                    hidden_states,\n                    attention_mask,\n                    layer_head_mask,\n                    encoder_hidden_states,\n                    encoder_attention_mask,\n                    mode=mode,\n                )\n            else:\n                layer_outputs = layer_module(\n                    hidden_states,\n                    attention_mask,\n                    layer_head_mask,\n                    encoder_hidden_states,\n                    encoder_attention_mask,\n                    past_key_value,\n                    output_attentions,\n                    mode=mode,\n                )\n\n            hidden_states = layer_outputs[0]\n            if use_cache:\n                next_decoder_cache += (layer_outputs[-1],)\n            if output_attentions:\n                all_self_attentions = all_self_attentions + (layer_outputs[1],)\n\n        if output_hidden_states:\n            all_hidden_states = all_hidden_states + (hidden_states,)\n\n        if not return_dict:\n            return tuple(\n                v\n                for v in [\n                    hidden_states,\n                    next_decoder_cache,\n                    all_hidden_states,\n                    all_self_attentions,\n                    all_cross_attentions,\n                ]\n                if v is not None\n            )\n        return BaseModelOutputWithPastAndCrossAttentions(\n            last_hidden_state=hidden_states,\n            past_key_values=next_decoder_cache,\n            hidden_states=all_hidden_states,\n            attentions=all_self_attentions,\n            cross_attentions=all_cross_attentions,\n        )\n\n\nclass BertPooler(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.dense = nn.Linear(config.hidden_size, config.hidden_size)\n        self.activation = nn.Tanh()\n\n    def forward(self, hidden_states):\n        # We \"pool\" the model by simply taking the hidden state corresponding\n        # to the first token.\n        first_token_tensor = hidden_states[:, 0]\n        pooled_output = self.dense(first_token_tensor)\n        pooled_output = self.activation(pooled_output)\n        return pooled_output\n\n\nclass BertPredictionHeadTransform(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.dense = nn.Linear(config.hidden_size, config.hidden_size)\n        if isinstance(config.hidden_act, str):\n            self.transform_act_fn = ACT2FN[config.hidden_act]\n        else:\n            self.transform_act_fn = 
config.hidden_act\n        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)\n\n    def forward(self, hidden_states):\n        hidden_states = self.dense(hidden_states)\n        hidden_states = self.transform_act_fn(hidden_states)\n        hidden_states = self.LayerNorm(hidden_states)\n        return hidden_states\n\n\nclass BertLMPredictionHead(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.transform = BertPredictionHeadTransform(config)\n\n        # The output weights are the same as the input embeddings, but there is\n        # an output-only bias for each token.\n        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n\n        self.bias = nn.Parameter(torch.zeros(config.vocab_size))\n\n        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`\n        self.decoder.bias = self.bias\n\n    def forward(self, hidden_states):\n        hidden_states = self.transform(hidden_states)\n        hidden_states = self.decoder(hidden_states)\n        return hidden_states\n\n\nclass BertOnlyMLMHead(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.predictions = BertLMPredictionHead(config)\n\n    def forward(self, sequence_output):\n        prediction_scores = self.predictions(sequence_output)\n        return prediction_scores\n\n\nclass BertPreTrainedModel(PreTrainedModel):\n    \"\"\"\n    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained\n    models.\n    \"\"\"\n\n    config_class = BertConfig\n    base_model_prefix = \"bert\"\n    _keys_to_ignore_on_load_missing = [r\"position_ids\"]\n\n    def _init_weights(self, module):\n        \"\"\"Initialize the weights\"\"\"\n        if isinstance(module, (nn.Linear, nn.Embedding)):\n            # Slightly different from the TF version which uses truncated_normal for initialization\n            # cf https://github.com/pytorch/pytorch/pull/5617\n            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)\n        elif isinstance(module, nn.LayerNorm):\n            module.bias.data.zero_()\n            module.weight.data.fill_(1.0)\n        if isinstance(module, nn.Linear) and module.bias is not None:\n            module.bias.data.zero_()\n\n\nclass BertModel(BertPreTrainedModel):\n    \"\"\"\n    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of\n    cross-attention is added between the self-attention layers, following the architecture described in `Attention is\n    all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,\n    Llion Jones, Aidan N. 
Gomez, Lukasz Kaiser and Illia Polosukhin.\n    argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an\n    input to the forward pass.\n    \"\"\"\n\n    def __init__(self, config, add_pooling_layer=True):\n        super().__init__(config)\n        self.config = config\n\n        self.embeddings = BertEmbeddings(config)\n\n        self.encoder = BertEncoder(config)\n\n        self.pooler = BertPooler(config) if add_pooling_layer else None\n\n        self.init_weights()\n\n    def get_input_embeddings(self):\n        return self.embeddings.word_embeddings\n\n    def set_input_embeddings(self, value):\n        self.embeddings.word_embeddings = value\n\n    def _prune_heads(self, heads_to_prune):\n        \"\"\"\n        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base\n        class PreTrainedModel\n        \"\"\"\n        for layer, heads in heads_to_prune.items():\n            self.encoder.layer[layer].attention.prune_heads(heads)\n\n    def get_extended_attention_mask(\n        self,\n        attention_mask: Tensor,\n        input_shape: Tuple[int],\n        device: device,\n        is_decoder: bool,\n    ) -> Tensor:\n        \"\"\"\n        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.\n\n        Arguments:\n            attention_mask (:obj:`torch.Tensor`):\n                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.\n            input_shape (:obj:`Tuple[int]`):\n                The shape of the input to the model.\n            device: (:obj:`torch.device`):\n                The device of the input to the model.\n\n        Returns:\n            :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.\n        \"\"\"\n        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]\n        # ourselves in which case we just need to make it broadcastable to all heads.\n        if attention_mask.dim() == 3:\n            extended_attention_mask = attention_mask[:, None, :, :]\n        elif attention_mask.dim() == 2:\n            # Provided a padding mask of dimensions [batch_size, seq_length]\n            # - if the model is a decoder, apply a causal mask in addition to the padding mask\n            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]\n            if is_decoder:\n                batch_size, seq_length = input_shape\n\n                seq_ids = torch.arange(seq_length, device=device)\n                causal_mask = (\n                    seq_ids[None, None, :].repeat(batch_size, seq_length, 1)\n                    <= seq_ids[None, :, None]\n                )\n                # in case past_key_values are used we need to add a prefix ones mask to the causal mask\n                # causal and attention masks must have same type with pytorch version < 1.3\n                causal_mask = causal_mask.to(attention_mask.dtype)\n\n                if causal_mask.shape[1] < attention_mask.shape[1]:\n                    prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]\n                    causal_mask = torch.cat(\n                        [\n                            torch.ones(\n                                (batch_size, seq_length, prefix_seq_len),\n                                device=device,\n                                
dtype=causal_mask.dtype,\n                            ),\n                            causal_mask,\n                        ],\n                        axis=-1,\n                    )\n\n                extended_attention_mask = (\n                    causal_mask[:, None, :, :] * attention_mask[:, None, None, :]\n                )\n            else:\n                extended_attention_mask = attention_mask[:, None, None, :]\n        else:\n            raise ValueError(\n                \"Wrong shape for input_ids (shape {}) or attention_mask (shape {})\".format(\n                    input_shape, attention_mask.shape\n                )\n            )\n\n        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for\n        # masked positions, this operation will create a tensor which is 0.0 for\n        # positions we want to attend and -10000.0 for masked positions.\n        # Since we are adding it to the raw scores before the softmax, this is\n        # effectively the same as removing these entirely.\n        extended_attention_mask = extended_attention_mask.to(\n            dtype=self.dtype\n        )  # fp16 compatibility\n        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0\n        return extended_attention_mask\n\n    def forward(\n        self,\n        input_ids=None,\n        attention_mask=None,\n        token_type_ids=None,\n        position_ids=None,\n        head_mask=None,\n        inputs_embeds=None,\n        encoder_embeds=None,\n        encoder_hidden_states=None,\n        encoder_attention_mask=None,\n        past_key_values=None,\n        use_cache=None,\n        output_attentions=None,\n        output_hidden_states=None,\n        return_dict=None,\n        is_decoder=False,\n        mode=\"multimodal\",\n    ):\n        r\"\"\"\n        encoder_hidden_states  (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):\n            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if\n            the model is configured as a decoder.\n        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):\n            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in\n            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:\n            - 1 for tokens that are **not masked**,\n            - 0 for tokens that are **masked**.\n        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):\n            Contains precomputed key and value hidden states of the attention blocks. 
Can be used to speed up decoding.\n            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`\n            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`\n            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.\n        use_cache (:obj:`bool`, `optional`):\n            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up\n            decoding (see :obj:`past_key_values`).\n        \"\"\"\n        output_attentions = (\n            output_attentions\n            if output_attentions is not None\n            else self.config.output_attentions\n        )\n        output_hidden_states = (\n            output_hidden_states\n            if output_hidden_states is not None\n            else self.config.output_hidden_states\n        )\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        if is_decoder:\n            use_cache = use_cache if use_cache is not None else self.config.use_cache\n        else:\n            use_cache = False\n\n        if input_ids is not None and inputs_embeds is not None:\n            raise ValueError(\n                \"You cannot specify both input_ids and inputs_embeds at the same time\"\n            )\n        elif input_ids is not None:\n            input_shape = input_ids.size()\n            batch_size, seq_length = input_shape\n            device = input_ids.device\n        elif inputs_embeds is not None:\n            input_shape = inputs_embeds.size()[:-1]\n            batch_size, seq_length = input_shape\n            device = inputs_embeds.device\n        elif encoder_embeds is not None:\n            input_shape = encoder_embeds.size()[:-1]\n            batch_size, seq_length = input_shape\n            device = encoder_embeds.device\n        else:\n            raise ValueError(\n                \"You have to specify either input_ids or inputs_embeds or encoder_embeds\"\n            )\n\n        # past_key_values_length\n        past_key_values_length = (\n            past_key_values[0][0].shape[2] if past_key_values is not None else 0\n        )\n\n        if attention_mask is None:\n            attention_mask = torch.ones(\n                ((batch_size, seq_length + past_key_values_length)), device=device\n            )\n\n        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]\n        # ourselves in which case we just need to make it broadcastable to all heads.\n        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(\n            attention_mask, input_shape, device, is_decoder\n        )\n\n        # If a 2D or 3D attention mask is provided for the cross-attention\n        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]\n        if encoder_hidden_states is not None:\n            if type(encoder_hidden_states) == list:\n                encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[\n                    0\n                ].size()\n            else:\n                (\n                    encoder_batch_size,\n                    encoder_sequence_length,\n                    _,\n                ) = encoder_hidden_states.size()\n            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)\n\n            if 
type(encoder_attention_mask) == list:\n                encoder_extended_attention_mask = [\n                    self.invert_attention_mask(mask) for mask in encoder_attention_mask\n                ]\n            elif encoder_attention_mask is None:\n                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)\n                encoder_extended_attention_mask = self.invert_attention_mask(\n                    encoder_attention_mask\n                )\n            else:\n                encoder_extended_attention_mask = self.invert_attention_mask(\n                    encoder_attention_mask\n                )\n        else:\n            encoder_extended_attention_mask = None\n\n        # Prepare head mask if needed\n        # 1.0 in head_mask indicate we keep the head\n        # attention_probs has shape bsz x n_heads x N x N\n        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]\n        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]\n        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)\n\n        if encoder_embeds is None:\n            embedding_output = self.embeddings(\n                input_ids=input_ids,\n                position_ids=position_ids,\n                token_type_ids=token_type_ids,\n                inputs_embeds=inputs_embeds,\n                past_key_values_length=past_key_values_length,\n            )\n        else:\n            embedding_output = encoder_embeds\n\n        encoder_outputs = self.encoder(\n            embedding_output,\n            attention_mask=extended_attention_mask,\n            head_mask=head_mask,\n            encoder_hidden_states=encoder_hidden_states,\n            encoder_attention_mask=encoder_extended_attention_mask,\n            past_key_values=past_key_values,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n            mode=mode,\n        )\n        sequence_output = encoder_outputs[0]\n        pooled_output = (\n            self.pooler(sequence_output) if self.pooler is not None else None\n        )\n\n        if not return_dict:\n            return (sequence_output, pooled_output) + encoder_outputs[1:]\n\n        return BaseModelOutputWithPoolingAndCrossAttentions(\n            last_hidden_state=sequence_output,\n            pooler_output=pooled_output,\n            past_key_values=encoder_outputs.past_key_values,\n            hidden_states=encoder_outputs.hidden_states,\n            attentions=encoder_outputs.attentions,\n            cross_attentions=encoder_outputs.cross_attentions,\n        )\n\n\nclass BertForMaskedLM(BertPreTrainedModel):\n\n    _keys_to_ignore_on_load_unexpected = [r\"pooler\"]\n    _keys_to_ignore_on_load_missing = [r\"position_ids\", r\"predictions.decoder.bias\"]\n\n    def __init__(self, config):\n        super().__init__(config)\n\n        self.bert = BertModel(config, add_pooling_layer=False)\n        self.cls = BertOnlyMLMHead(config)\n\n        self.init_weights()\n\n    def get_output_embeddings(self):\n        return self.cls.predictions.decoder\n\n    def set_output_embeddings(self, new_embeddings):\n        self.cls.predictions.decoder = new_embeddings\n\n    def forward(\n        self,\n        input_ids=None,\n        attention_mask=None,\n        # token_type_ids=None,\n        position_ids=None,\n        head_mask=None,\n        
inputs_embeds=None,\n        encoder_embeds=None,\n        encoder_hidden_states=None,\n        encoder_attention_mask=None,\n        labels=None,\n        output_attentions=None,\n        output_hidden_states=None,\n        return_dict=None,\n        is_decoder=False,\n        mode=\"multimodal\",\n        soft_labels=None,\n        alpha=0,\n        return_logits=False,\n    ):\n        r\"\"\"\n        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):\n            Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,\n            config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored\n            (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``\n        \"\"\"\n\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        outputs = self.bert(\n            input_ids,\n            attention_mask=attention_mask,\n            # token_type_ids=token_type_ids,\n            position_ids=position_ids,\n            head_mask=head_mask,\n            inputs_embeds=inputs_embeds,\n            encoder_embeds=encoder_embeds,\n            encoder_hidden_states=encoder_hidden_states,\n            encoder_attention_mask=encoder_attention_mask,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n            is_decoder=is_decoder,\n            mode=mode,\n        )\n\n        sequence_output = outputs[0]\n        prediction_scores = self.cls(sequence_output)\n\n        if return_logits:\n            return prediction_scores\n\n        masked_lm_loss = None\n        if labels is not None:\n            loss_fct = CrossEntropyLoss()  # -100 index = padding token\n            masked_lm_loss = loss_fct(\n                prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)\n            )\n\n        if soft_labels is not None:\n            loss_distill = -torch.sum(\n                F.log_softmax(prediction_scores, dim=-1) * soft_labels, dim=-1\n            )\n            loss_distill = loss_distill[labels != -100].mean()\n            masked_lm_loss = (1 - alpha) * masked_lm_loss + alpha * loss_distill\n\n        if not return_dict:\n            output = (prediction_scores,) + outputs[2:]\n            return (\n                ((masked_lm_loss,) + output) if masked_lm_loss is not None else output\n            )\n\n        return MaskedLMOutput(\n            loss=masked_lm_loss,\n            logits=prediction_scores,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n\n    def prepare_inputs_for_generation(\n        self, input_ids, attention_mask=None, **model_kwargs\n    ):\n        input_shape = input_ids.shape\n        effective_batch_size = input_shape[0]\n\n        #  add a dummy token\n        assert (\n            self.config.pad_token_id is not None\n        ), \"The PAD token should be defined for generation\"\n        attention_mask = torch.cat(\n            [attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))],\n            dim=-1,\n        )\n        dummy_token = torch.full(\n            (effective_batch_size, 1),\n            self.config.pad_token_id,\n            dtype=torch.long,\n            device=input_ids.device,\n        )\n        input_ids = torch.cat([input_ids, 
dummy_token], dim=1)\n\n        return {\"input_ids\": input_ids, \"attention_mask\": attention_mask}\n\n\nclass BertLMHeadModel(BertPreTrainedModel):\n\n    _keys_to_ignore_on_load_unexpected = [r\"pooler\"]\n    _keys_to_ignore_on_load_missing = [r\"position_ids\", r\"predictions.decoder.bias\"]\n\n    def __init__(self, config):\n        super().__init__(config)\n\n        self.bert = BertModel(config, add_pooling_layer=False)\n        self.cls = BertOnlyMLMHead(config)\n\n        self.init_weights()\n\n    def get_output_embeddings(self):\n        return self.cls.predictions.decoder\n\n    def set_output_embeddings(self, new_embeddings):\n        self.cls.predictions.decoder = new_embeddings\n\n    def forward(\n        self,\n        input_ids=None,\n        attention_mask=None,\n        position_ids=None,\n        head_mask=None,\n        inputs_embeds=None,\n        encoder_hidden_states=None,\n        encoder_attention_mask=None,\n        labels=None,\n        past_key_values=None,\n        use_cache=None,\n        output_attentions=None,\n        output_hidden_states=None,\n        return_dict=None,\n        return_logits=False,\n        is_decoder=True,\n        reduction=\"mean\",\n        mode=\"multimodal\",\n        soft_labels=None,\n        alpha=0,\n    ):\n        r\"\"\"\n        encoder_hidden_states  (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):\n            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if\n            the model is configured as a decoder.\n        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):\n            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in\n            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:\n            - 1 for tokens that are **not masked**,\n            - 0 for tokens that are **masked**.\n        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):\n            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in\n            ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are\n            ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]``\n        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):\n            Contains precomputed key and value hidden states of the attention blocks. 
Can be used to speed up decoding.\n            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`\n            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`\n            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.\n        use_cache (:obj:`bool`, `optional`):\n            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up\n            decoding (see :obj:`past_key_values`).\n        Returns:\n        Example::\n            >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig\n            >>> import torch\n            >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')\n            >>> config = BertConfig.from_pretrained(\"bert-base-cased\")\n            >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)\n            >>> inputs = tokenizer(\"Hello, my dog is cute\", return_tensors=\"pt\")\n            >>> outputs = model(**inputs)\n            >>> prediction_logits = outputs.logits\n        \"\"\"\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n        if labels is not None:\n            use_cache = False\n\n        outputs = self.bert(\n            input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            head_mask=head_mask,\n            inputs_embeds=inputs_embeds,\n            encoder_hidden_states=encoder_hidden_states,\n            encoder_attention_mask=encoder_attention_mask,\n            past_key_values=past_key_values,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n            is_decoder=is_decoder,\n            mode=mode,\n        )\n\n        sequence_output = outputs[0]\n        prediction_scores = self.cls(sequence_output)\n\n        if return_logits:\n            return prediction_scores[:, :-1, :].contiguous()\n\n        lm_loss = None\n        if labels is not None:\n            # we are doing next-token prediction; shift prediction scores and input ids by one\n            shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()\n            labels = labels[:, 1:].contiguous()\n            loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)\n            lm_loss = loss_fct(\n                shifted_prediction_scores.view(-1, self.config.vocab_size),\n                labels.view(-1),\n            )\n            if reduction == \"none\":\n                lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)\n\n        if soft_labels is not None:\n            loss_distill = -torch.sum(\n                F.log_softmax(shifted_prediction_scores, dim=-1) * soft_labels, dim=-1\n            )\n            loss_distill = (loss_distill * (labels != -100)).sum(1)\n            lm_loss = (1 - alpha) * lm_loss + alpha * loss_distill\n\n        if not return_dict:\n            output = (prediction_scores,) + outputs[2:]\n            return ((lm_loss,) + output) if lm_loss is not None else output\n\n        return CausalLMOutputWithCrossAttentions(\n            loss=lm_loss,\n            logits=prediction_scores,\n            past_key_values=outputs.past_key_values,\n            hidden_states=outputs.hidden_states,\n            
attentions=outputs.attentions,\n            cross_attentions=outputs.cross_attentions,\n        )\n\n    def prepare_inputs_for_generation(\n        self, input_ids, past=None, attention_mask=None, **model_kwargs\n    ):\n        input_shape = input_ids.shape\n        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly\n        if attention_mask is None:\n            attention_mask = input_ids.new_ones(input_shape)\n\n        # cut decoder_input_ids if past is used\n        if past is not None:\n            input_ids = input_ids[:, -1:]\n\n        return {\n            \"input_ids\": input_ids,\n            \"attention_mask\": attention_mask,\n            \"past_key_values\": past,\n            \"encoder_hidden_states\": model_kwargs.get(\"encoder_hidden_states\", None),\n            \"encoder_attention_mask\": model_kwargs.get(\"encoder_attention_mask\", None),\n            \"is_decoder\": True,\n        }\n\n    def _reorder_cache(self, past, beam_idx):\n        reordered_past = ()\n        for layer_past in past:\n            reordered_past += (\n                tuple(\n                    past_state.index_select(0, beam_idx) for past_state in layer_past\n                ),\n            )\n        return reordered_past\n\n\nclass XBertLMHeadDecoder(BertLMHeadModel):\n    \"\"\"\n    This class decouples the decoder forward logic from the VL model.\n    In this way, different VL models can share this decoder as long as\n    they feed encoder_embeds as required.\n    \"\"\"\n\n    @classmethod\n    def from_config(cls, cfg, from_pretrained=False):\n\n        med_config_path = get_abs_path(cfg.get(\"med_config_path\"))\n        med_config = BertConfig.from_json_file(med_config_path)\n\n        if from_pretrained:\n            return cls.from_pretrained(\"bert-base-uncased\", config=med_config)\n        else:\n            return cls(config=med_config)\n\n    def generate_from_encoder(\n        self,\n        tokenized_prompt,\n        visual_embeds,\n        sep_token_id,\n        pad_token_id,\n        use_nucleus_sampling=False,\n        num_beams=3,\n        max_length=30,\n        min_length=10,\n        top_p=0.9,\n        repetition_penalty=1.0,\n        **kwargs\n    ):\n\n        if not use_nucleus_sampling:\n            num_beams = num_beams\n            visual_embeds = visual_embeds.repeat_interleave(num_beams, dim=0)\n\n        image_atts = torch.ones(visual_embeds.size()[:-1], dtype=torch.long).to(\n            self.device\n        )\n\n        model_kwargs = {\n            \"encoder_hidden_states\": visual_embeds,\n            \"encoder_attention_mask\": image_atts,\n        }\n\n        if use_nucleus_sampling:\n            # nucleus sampling\n            outputs = self.generate(\n                input_ids=tokenized_prompt.input_ids,\n                max_length=max_length,\n                min_length=min_length,\n                do_sample=True,\n                top_p=top_p,\n                num_return_sequences=1,\n                eos_token_id=sep_token_id,\n                pad_token_id=pad_token_id,\n                repetition_penalty=1.1,\n                **model_kwargs\n            )\n        else:\n            # beam search\n            outputs = self.generate(\n                input_ids=tokenized_prompt.input_ids,\n                max_length=max_length,\n                min_length=min_length,\n                num_beams=num_beams,\n                eos_token_id=sep_token_id,\n                pad_token_id=pad_token_id,\n 
               repetition_penalty=repetition_penalty,\n                **model_kwargs\n            )\n\n        return outputs\n\n\nclass XBertEncoder(BertModel, BaseEncoder):\n    @classmethod\n    def from_config(cls, cfg, from_pretrained=False):\n\n        med_config_path = get_abs_path(cfg.get(\"med_config_path\"))\n        med_config = BertConfig.from_json_file(med_config_path)\n\n        if from_pretrained:\n            return cls.from_pretrained(\n                \"bert-base-uncased\", config=med_config, add_pooling_layer=False\n            )\n        else:\n            return cls(config=med_config, add_pooling_layer=False)\n\n    def forward_automask(self, tokenized_text, visual_embeds, **kwargs):\n        image_atts = torch.ones(visual_embeds.size()[:-1], dtype=torch.long).to(\n            self.device\n        )\n\n        text = tokenized_text\n        text_output = super().forward(\n            text.input_ids,\n            attention_mask=text.attention_mask,\n            encoder_hidden_states=visual_embeds,\n            encoder_attention_mask=image_atts,\n            return_dict=True,\n        )\n\n        return text_output\n\n    def forward_text(self, tokenized_text, **kwargs):\n        text = tokenized_text\n        token_type_ids = kwargs.get(\"token_type_ids\", None)\n\n        text_output = super().forward(\n            text.input_ids,\n            attention_mask=text.attention_mask,\n            token_type_ids=token_type_ids,\n            return_dict=True,\n            mode=\"text\",\n        )\n\n        return text_output\n"
  },
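  {
    "path": "examples/sketch_med_layer_modes.py",
    "content": "\"\"\"\n Editor's illustrative sketch; a hypothetical file, not part of the original repo.\n\n The BertEncoder in lavis/models/med.py selects which transformer layers to run\n from the mode argument: 'text' runs layers [0, fusion_layer), 'fusion' runs\n [fusion_layer, num_hidden_layers), and 'multimodal' runs all layers. ALBEF-style\n configs define fusion_layer; BLIP-style configs fall back to num_hidden_layers.\n The helper below mirrors that dispatch so the ranges can be checked in isolation.\n\"\"\"\n\n\ndef layer_range(mode, fusion_layer, num_hidden_layers):\n    # Mirrors the start_layer/output_layer selection in BertEncoder.forward.\n    if mode == 'text':\n        return 0, fusion_layer\n    elif mode == 'fusion':\n        return fusion_layer, num_hidden_layers\n    elif mode == 'multimodal':\n        return 0, num_hidden_layers\n    raise ValueError('unknown mode: {}'.format(mode))\n\n\nif __name__ == '__main__':\n    # A 12-layer encoder with cross-modal fusion starting at layer 6 (ALBEF-style).\n    print(layer_range('text', 6, 12))        # (0, 6)\n    print(layer_range('fusion', 6, 12))      # (6, 12)\n    print(layer_range('multimodal', 6, 12))  # (0, 12)\n"
  },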
  {
    "path": "lavis/models/pnp_vqa_models/__init__.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport torch\n\n\ndef prepare_qa_input(sample, num_captions, num_captions_fid):\n    sample_question_captions = []\n\n    for question, captions in zip(sample['text_input'], sample['captions']):\n        assert isinstance(captions, list)\n        question_captions = []\n        question_caption = ''\n        for cap_id, cap_ in enumerate(captions[0:num_captions]):\n            question_caption += (cap_.strip() + '. ')\n            if (cap_id + 1) != num_captions and ((cap_id + 1) % num_captions_fid == 0):\n                question_caption = question.lower().strip() + \" \\\\n \" + question_caption.lower().strip()\n                question_captions.append(question_caption)\n                question_caption = ''\n            if (cap_id + 1) == num_captions:\n                question_caption = question.lower().strip() + \" \\\\n \" + question_caption.lower().strip()\n                question_captions.append(question_caption)\n        sample_question_captions.append(question_captions)\n\n    sample['question_captions'] = sample_question_captions\n"
  },
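  {
    "path": "examples/sketch_prepare_qa_input.py",
    "content": "\"\"\"\n Editor's illustrative sketch; a hypothetical file, not part of the original repo.\n\n Shows how prepare_qa_input (lavis/models/pnp_vqa_models/__init__.py) turns a\n question and its generated captions into FiD-style contexts: captions are grouped\n num_captions_fid at a time, each group is prefixed with the lowercased question,\n and the results are stored in place under sample['question_captions'].\n\"\"\"\n\nfrom lavis.models.pnp_vqa_models import prepare_qa_input\n\nsample = {\n    'text_input': ['What is on the table?'],\n    'captions': [['A red apple sits on a table', 'A wooden table with fruit']],\n}\n\n# With num_captions_fid=1 every caption becomes its own question+caption context;\n# with num_captions_fid=2 both captions would be packed into a single context.\nprepare_qa_input(sample, num_captions=2, num_captions_fid=1)\n\nfor context in sample['question_captions'][0]:\n    print(repr(context))\n"
  },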
  {
    "path": "lavis/models/pnp_vqa_models/pnp_unifiedqav2_fid.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\n Based on facebookresearch code base\n https://github.com/facebookresearch/FiD\n\"\"\"\n\nimport torch\nimport torch.nn as nn\nfrom lavis.common.registry import registry\nfrom lavis.models.base_model import BaseModel\nfrom lavis.common.utils import get_abs_path\nfrom transformers import T5Config, T5Tokenizer, T5ForConditionalGeneration\n\n\n@registry.register_model(\"pnp_unifiedqav2_fid\")\nclass PNPUnifiedQAv2FiD(T5ForConditionalGeneration, BaseModel):\n\n    PRETRAINED_MODEL_CONFIG_DICT = {}\n\n    def __init__(self, config, model_path):\n        super().__init__(config)\n        \n        self.tokenizer = T5Tokenizer.from_pretrained(model_path)\n\n    def forward(self, input_ids=None, attention_mask=None, **kwargs):\n        if input_ids != None:\n            if input_ids.dim() == 3:\n                self.encoder.num_contexts = input_ids.size(1)\n            input_ids = input_ids.view(input_ids.size(0), -1)\n        if attention_mask != None:\n            attention_mask = attention_mask.view(attention_mask.size(0), -1)\n\n        return super().forward(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            **kwargs\n        )\n\n    def generate(self, input_ids, attention_mask, num_beams=1, min_length=0, max_length=20):\n        self.encoder.num_contexts = input_ids.size(1)\n\n        return super().generate(\n            input_ids=input_ids.view(input_ids.size(0), -1),\n            attention_mask=attention_mask.view(attention_mask.size(0), -1),\n            num_beams=num_beams,\n            min_length=min_length,\n            max_length=max_length\n        )\n\n    def load_unifiedqa(self, state_dict):\n        self.load_state_dict(state_dict)\n        self.encoder = T5EncoderWrapper(self.encoder)\n\n    @classmethod\n    def from_config(cls, cfg):\n        model_path = cfg.get('pretrained')\n        t5_config_path = get_abs_path(cfg.get(\"t5_config_path\"))\n        t5_config = T5Config.from_json_file(t5_config_path)\n        model = cls(t5_config, model_path)\n        model.load_unifiedqa(T5ForConditionalGeneration.from_pretrained(model_path).state_dict())\n\n        return model\n\n\nclass T5EncoderWrapper(torch.nn.Module):\n\n    def __init__(self, encoder):\n        super().__init__()\n\n        self.encoder = encoder\n        self.block = self.encoder.block\n        self.parallelize = self.encoder.parallelize\n        self.main_input_name = encoder.main_input_name\n\n    def forward(self, input_ids=None, attention_mask=None, **kwargs):\n        bsz, total_length = input_ids.shape\n        context_length = total_length // self.num_contexts\n        input_ids = input_ids.view(bsz*self.num_contexts, context_length)\n        attention_mask = attention_mask.view(bsz*self.num_contexts, context_length)\n        outputs = self.encoder(input_ids, attention_mask, **kwargs)\n        outputs = (outputs[0].view(bsz, self.num_contexts*context_length, -1), ) + outputs[1:]\n\n        return outputs"
  },
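  {
    "path": "examples/sketch_fid_reshaping.py",
    "content": "\"\"\"\n Editor's illustrative sketch; a hypothetical file, not part of the original repo.\n\n Traces the tensor reshaping behind Fusion-in-Decoder in\n lavis/models/pnp_vqa_models/pnp_unifiedqav2_fid.py: generate() flattens\n (bsz, num_contexts, context_len) token ids into one row per sample, the\n T5EncoderWrapper splits that row back up so every context is encoded\n independently, and the encoder outputs are re-concatenated along the sequence\n dimension so the decoder attends over all contexts jointly. A random tensor\n stands in for the T5 encoder here; only the shapes are of interest.\n\"\"\"\n\nimport torch\n\nbsz, num_contexts, context_len, hidden = 2, 4, 16, 8\n\n# (bsz, num_contexts, context_len) ids, as prepared in PNPVQA.forward_qa.\ninput_ids = torch.randint(0, 100, (bsz, num_contexts, context_len))\n\n# generate(): flatten the contexts into a single sequence per sample.\nflat_ids = input_ids.view(bsz, -1)  # (bsz, num_contexts*context_len)\n\n# T5EncoderWrapper.forward: encode each context independently ...\nper_context = flat_ids.view(bsz * num_contexts, context_len)\nencoded = torch.randn(bsz * num_contexts, context_len, hidden)  # stand-in for self.encoder(...)\n\n# ... then concatenate the encoded contexts back per sample for the decoder.\nfused = encoded.view(bsz, num_contexts * context_len, hidden)\n\nprint(flat_ids.shape, per_context.shape, fused.shape)\n"
  },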
  {
    "path": "lavis/models/pnp_vqa_models/pnp_vqa.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport torch\nimport torch.nn as nn\nfrom itertools import chain\nfrom lavis.common.registry import registry\nfrom lavis.models.base_model import BaseModel\nfrom torch.nn import CrossEntropyLoss, MSELoss\nfrom transformers import T5ForConditionalGeneration\nfrom lavis.models.pnp_vqa_models import prepare_qa_input\nfrom lavis.models.blip_models.blip_image_text_matching import compute_gradcam\nfrom transformers.modeling_outputs import CausalLMOutputWithCrossAttentions\n\n\n@registry.register_model(\"pnp_vqa\")\nclass PNPVQA(BaseModel):\n    \"\"\"\n    PNPVQA model consists of three submodels for zero-shot VQA:\n        1. Image-questioning matching model\n        2. Image captioning model\n        3. Question answering model\n\n    Supported model types:\n        - base: BLIPITM, BLIPCaption, PNPUnifiedQAv2FiD (t5-base)\n        - large: BLIPITM, BLIPCaption, PNPUnifiedQAv2FiD (t5-large)\n        - 3b: BLIPITM, BLIPCaption, PNPUnifiedQAv2FiD (t5-3b)\n\n    Usage:\n        >>> from lavis.models import load_model\n        >>> model = load_model(\"pnp_vqa\", \"base\", is_eval=True)\n        >>> model = load_model(\"pnp_vqa\", \"large\", is_eval=True)\n        >>> model = load_model(\"pnp_vqa\", \"3b\", is_eval=True)\n    \"\"\"\n\n    PRETRAINED_MODEL_CONFIG_DICT = {\"base\": \"configs/models/pnp-vqa/pnp_vqa_base.yaml\",\n                                    \"large\": \"configs/models/pnp-vqa/pnp_vqa_large.yaml\",\n                                    \"3b\": \"configs/models/pnp-vqa/pnp_vqa_3b.yaml\",\n                                    }\n\n    def __init__(self, image_question_matching_model, image_captioning_model,\n                 question_answering_model, offload_model=False):\n        super().__init__()\n\n        self.image_question_matching_model = image_question_matching_model\n        self.image_captioning_model = image_captioning_model\n        self.question_answering_model = question_answering_model\n        self.offload_model = offload_model\n\n    def forward_itm(self, samples, block_num=7):\n        \"\"\"\n        Args:\n            samples (dict): A dictionary containing the following keys:\n                - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W)\n                - text_input (list): A list of strings of length batch_size\n            block_num (int): The index of cross-attention block for gradcam computation.\n\n        Returns:\n            samples (dict): A dictionary containing the following keys:\n                - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W)\n                - text_input (list): A list of strings of length batch_size\n                - gradcams (torch.Tensor): A tensor of shape (batch_size, H*W)\n        \"\"\"\n        image = samples['image']\n        question = [text.strip('?') for text in samples['text_input']]\n        tokenized_text = self.image_question_matching_model.tokenizer(question, padding='longest', truncation=True,\n                                                return_tensors=\"pt\").to(self.image_question_matching_model.device)\n        with torch.set_grad_enabled(True):\n            gradcams, _ = compute_gradcam(model=self.image_question_matching_model,\n                            visual_input=image,\n                            
text_input=question,\n                            tokenized_text=tokenized_text,\n                            block_num=block_num)\n\n        gradcams = [gradcam_[1] for gradcam_ in gradcams]\n        samples['gradcams'] = torch.stack(gradcams).reshape(samples['image'].size(0), -1)\n\n        return samples\n\n    def forward_cap(\n            self,\n            samples,\n            cap_max_length=20,\n            cap_min_length=0,\n            top_p=1,\n            top_k=50,\n            repetition_penalty=1.0,\n            num_captions=100,\n            num_patches=20,\n    ):\n        \"\"\"\n        Args:\n            samples (dict): A dictionary containing the following keys:\n                - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W)\n                - text_input (list): A list of strings of length batch_size\n                - gradcams (torch.Tensor): A tensor of shape (batch_size, H*W)\n            cap_max_length (int): The maximum length of the caption to be generated.\n            cap_min_length (int): The minimum length of the caption to be generated.\n            top_p (float): The cumulative probability for nucleus sampling.\n            top_k (float): The number of the highest probability tokens for top-k sampling.\n            repetition_penalty (float): The parameter for repetition penalty. 1.0 means no penalty.\n            num_captions (int): Number of captions generated for each image.\n            num_patches (int): Number of patches sampled for each image.\n\n        Returns:\n            samples (dict): A dictionary containing the following keys:\n                - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W)\n                - text_input (list): A list of strings of length batch_size\n                - gradcams (torch.Tensor): A tensor of shape (batch_size, H*W)\n                - captions (nested list): A nested list of strings of total length batch_size * num_captions\n        \"\"\"\n        encoder_out = self.image_captioning_model.forward_encoder(samples)\n        captions = [[] for _ in range(encoder_out.size(0))]\n\n        min_num_captions = 0\n\n        while min_num_captions < num_captions:\n            encoder_out_samples = []\n            for i in range(num_captions):\n                patch_id = torch.multinomial(samples['gradcams'].to(self.image_captioning_model.device),\n                                             num_patches).reshape(encoder_out.size(0), -1) + 1\n                patch_id = patch_id.sort(dim=1).values.unsqueeze(-1).expand(-1, -1, encoder_out.size(2))\n                encoder_out_sample = torch.gather(encoder_out, 1, patch_id)\n                encoder_out_samples.append(encoder_out_sample)\n\n            stacked = torch.stack(encoder_out_samples, dim=1)\n            image_embeds = torch.flatten(stacked, start_dim=0, end_dim=1) #(bsz*num_seq, num_patch, dim)\n\n            image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.image_captioning_model.device)\n            model_kwargs = {\n                \"encoder_hidden_states\": image_embeds,\n                \"encoder_attention_mask\": image_atts,\n            }\n\n            prompt = [self.image_captioning_model.prompt] * image_embeds.size(0)\n            prompt = self.image_captioning_model.tokenizer(prompt,\n                                                           return_tensors=\"pt\").to(self.image_captioning_model.device)\n            prompt.input_ids[:, 0] = self.image_captioning_model.tokenizer.bos_token_id\n        
    prompt.input_ids = prompt.input_ids[:, :-1]\n\n            decoder_out = self.image_captioning_model.text_decoder.generate(\n                input_ids=prompt.input_ids,\n                max_length=cap_max_length,\n                min_length=cap_min_length,\n                do_sample=True,\n                top_p=top_p,\n                top_k=top_k,\n                num_return_sequences=1,\n                eos_token_id=self.image_captioning_model.tokenizer.sep_token_id,\n                pad_token_id=self.image_captioning_model.tokenizer.pad_token_id,\n                repetition_penalty=repetition_penalty,\n                **model_kwargs)\n\n            outputs = self.image_captioning_model.tokenizer.batch_decode(decoder_out, skip_special_tokens=True)\n\n            for counter, output in enumerate(outputs):\n                ind = counter//num_captions\n                if len(captions[ind]) < num_captions:\n                    caption = output[len(self.image_captioning_model.prompt):]\n                    overlap_caption = [1 for caps in captions[ind] if caption in caps]\n                    if len(overlap_caption) == 0:\n                        captions[ind].append(caption)\n\n            min_num_captions = min([len(i) for i in captions])\n\n        samples['captions'] = captions\n\n        return samples\n\n    def forward_qa(\n            self,\n            samples,\n            num_beams=1,\n            max_len=20,\n            min_len=0,\n            internal_bsz_fid=1,\n            num_captions=100,\n            num_captions_fid=1,\n    ):\n        \"\"\"\n        Args:\n            samples (dict): A dictionary containing the following keys:\n                - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W)\n                - text_input (list): A list of strings of length batch_size\n                - gradcams (torch.Tensor): A tensor of shape (batch_size, H*W)\n                - captions (nested list): A nested list of strings of total length batch_size * num_captions\n                - question_captions (nested list): A nested list of concatenated strings of questions and captions\n            num_beams (int): Number of beams for beam search. 
1 means no beam search.\n            max_len (int): Maximum length of generated answers.\n            min_len (int): Minimum length of generated answers.\n            internal_bsz_fid (int): Internal batch size when using FiD decoding.\n            num_captions (int): Number of captions generated for each image.\n            num_captions_fid (int): Number of captions concatenated with a question during FiD decoding.\n\n        Returns:\n            List: A list of strings, each string is an answer.\n        \"\"\"\n        prepare_qa_input(samples, num_captions=num_captions, num_captions_fid=num_captions_fid)\n\n        pred_answers = []\n        question_captions = samples['question_captions']\n        question_captions_chunk = [question_captions[i:i + internal_bsz_fid]\n                                   for i in range(0, len(question_captions), internal_bsz_fid)]\n        question_captions_chunk = list(chain(*question_captions_chunk))\n\n        for question_caption in question_captions_chunk:\n            question_caption_input = self.question_answering_model.tokenizer(question_caption, padding='longest',\n                                        truncation=True, return_tensors=\"pt\").to(self.question_answering_model.device)\n\n            question_caption_input.input_ids = question_caption_input.input_ids.reshape(\n                                               internal_bsz_fid, -1, question_caption_input.input_ids.size(1))\n            question_caption_input.attention_mask = question_caption_input.attention_mask.reshape(\n                                               internal_bsz_fid, -1, question_caption_input.attention_mask.size(1))\n\n            outputs = self.question_answering_model.generate(input_ids=question_caption_input.input_ids,\n                                            attention_mask=question_caption_input.attention_mask,\n                                            num_beams=num_beams,\n                                            min_length=min_len,\n                                            max_length=max_len,\n                                            )\n\n            for output in outputs:\n                pred_answer = self.question_answering_model.tokenizer.decode(output, skip_special_tokens=True)\n                pred_answers.append(pred_answer)\n\n        return pred_answers\n\n    def predict_answers(\n        self,\n        samples,\n        num_beams=1,\n        inference_method=\"generate\",\n        max_len=20,\n        min_len=0,\n        internal_bsz_fid=1,\n        num_captions=50,\n        num_captions_fid=1,\n        cap_max_length=20,\n        cap_min_length=10,\n        top_k=50,\n        top_p=1,\n        repetition_penalty=1,\n        num_patches=50,\n        block_num=7,\n    ):\n        \"\"\"\n        Args:\n            samples (dict): A dictionary containing the following keys:\n                - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W). Default H=480, W=480.\n                - text_input (str or [str]): String or a list of strings, each string is a question.\n                                             The number of questions must be equal to the batch size. If a single string, will be converted to a list of string, with length 1 first.\n            num_beams (int): Number of beams for beam search. 1 means no beam search.\n            inference_method (str): Inference method. Must be \"generate\". 
The model will generate answers.\n            max_len (int): Maximum length of generated answers.\n            min_len (int): Minimum length of generated answers.\n            internal_bsz_fid (int): Internal batch size when using FiD decoding.\n            num_captions (int): Number of captions generated for each image.\n            num_captions_fid (int): Number of captions concatenated with a question during FiD decoding.\n            cap_max_length (int): The maximum length of the caption to be generated.\n            cap_min_length (int): The minimum length of the caption to be generated.\n            top_k (float): The number of the highest probability tokens for top-k sampling.\n            top_p (float): The cumulative probability for nucleus sampling.\n            repetition_penalty (float): The parameter for repetition penalty. 1.0 means no penalty.\n            num_patches (int): Number of patches sampled for each image.\n            block_num (int): The index of cross-attention block for gradcam computation.\n\n        Returns:\n            List: A list of strings, each string is an answer.\n            gradcams (torch.Tensor): A tensor of shape (batch_size, H*W)\n            captions (nested list): A nested list of strings of total length batch_size * num_captions\n        \"\"\"\n        assert inference_method in [\n            \"generate\",\n        ], \"Inference method must be 'generate', got {}.\".format(\n            inference_method\n        )\n\n        if isinstance(samples[\"text_input\"], str):\n            samples[\"text_input\"] = [samples[\"text_input\"]]\n\n        assert len(samples[\"text_input\"]) == samples[\"image\"].size(\n            0\n        ), \"The number of questions must be equal to the batch size.\"\n\n        samples = self.forward_itm(samples, block_num=block_num)\n\n        samples = self.forward_cap(samples,\n                                   cap_max_length=cap_max_length,\n                                   cap_min_length=cap_min_length,\n                                   top_k=top_k,\n                                   top_p=top_p,\n                                   repetition_penalty=repetition_penalty,\n                                   num_captions=num_captions,\n                                   num_patches=num_patches)\n\n        if self.offload_model:\n            samples['image'] = samples['image'].to('cpu')\n            self.image_question_matching_model.to('cpu')\n            self.image_captioning_model.to('cpu')\n        torch.cuda.empty_cache()\n\n        pred_answers = self.forward_qa(samples,\n                                  num_beams=num_beams,\n                                  max_len=max_len,\n                                  min_len=min_len,\n                                  internal_bsz_fid=internal_bsz_fid,\n                                  num_captions=num_captions,\n                                  num_captions_fid=num_captions_fid)\n\n        if self.offload_model:\n            self.image_question_matching_model.to(self.question_answering_model.device)\n            self.image_captioning_model.to(self.question_answering_model.device)\n\n        return pred_answers, samples['captions'], samples['gradcams']\n\n    @classmethod\n    def from_config(cls, model_config):\n        itm_config = model_config.image_question_matching_model\n        cap_config = model_config.image_captioning_model\n        qa_config = model_config.question_answering_model\n\n        itm_cls = registry.get_model_class(itm_config.arch)\n  
      cap_cls = registry.get_model_class(cap_config.arch)\n        qa_cls = registry.get_model_class(qa_config.arch)\n\n        image_question_matching_model = itm_cls.from_config(itm_config)\n        image_captioning_model = cap_cls.from_config(cap_config)\n        question_answering_model = qa_cls.from_config(qa_config)\n\n        model = cls(image_question_matching_model=image_question_matching_model,\n                    image_captioning_model=image_captioning_model,\n                    question_answering_model=question_answering_model,\n                    offload_model=(model_config.model_type == '3b'),\n                    )\n\n        return model"
  },
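  {
    "path": "examples/sketch_pnp_vqa_usage.py",
    "content": "\"\"\"\n Editor's illustrative sketch; a hypothetical file, not part of the original repo,\n and the image tensor below is a random placeholder for a preprocessed input.\n\n End-to-end use of the PNPVQA pipeline from lavis/models/pnp_vqa_models/pnp_vqa.py:\n predict_answers chains forward_itm (GradCAM relevance over image patches),\n forward_cap (caption sampling from the relevant patches) and forward_qa\n (Fusion-in-Decoder question answering), and returns the answers together with\n the captions and GradCAMs produced along the way.\n\"\"\"\n\nimport torch\n\nfrom lavis.models import load_model\n\nmodel = load_model('pnp_vqa', 'base', is_eval=True)\n\nsamples = {\n    # Batch of preprocessed images; the docstring's default resolution is 480x480.\n    'image': torch.randn(1, 3, 480, 480),\n    'text_input': ['What is the person holding?'],\n}\n\nanswers, captions, gradcams = model.predict_answers(\n    samples,\n    num_captions=50,\n    num_patches=20,\n    block_num=7,\n)\nprint(answers)\n"
  },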
  {
    "path": "lavis/models/sevila_models/__init__.py",
    "content": ""
  },
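  {
    "path": "examples/sketch_sevila_pseudo_labels.py",
    "content": "\"\"\"\n Editor's illustrative sketch; a hypothetical file, not part of the original repo.\n\n Mirrors the pseudo-labelling rule in SeViLA's localizer self-refinement\n (lavis/models/sevila_models/sevila.py, 'train_loc' branch): the frozen answerer\n predicts an answer option from every sampled frame, and a frame is labelled\n 'yes' for the localizer exactly when that per-frame prediction matches the\n ground-truth option index (A..E map to 0..4 via ANS_MAP).\n\"\"\"\n\n\ndef frame_pseudo_labels(per_frame_preds, gt_answer_idx):\n    # per_frame_preds: predicted option indices, one per sampled frame of a video.\n    # gt_answer_idx: index of the ground-truth option (0..4 for A..E).\n    return ['yes' if p == gt_answer_idx else 'no' for p in per_frame_preds]\n\n\nif __name__ == '__main__':\n    # Four frames; the answerer recovers the correct option ('C' -> 2) on frames 1 and 3.\n    print(frame_pseudo_labels([0, 2, 1, 2], gt_answer_idx=2))  # ['no', 'yes', 'no', 'yes']\n"
  },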
  {
    "path": "lavis/models/sevila_models/sevila.py",
    "content": "\"\"\"\n Copyright (c) 2023, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\nimport logging\n\nimport copy\nimport torch\nimport torch.nn as nn\nfrom torch.cuda.amp import autocast as autocast\nfrom transformers import T5TokenizerFast, BertTokenizer\n\nfrom lavis.common.registry import registry\nfrom lavis.models.blip2_models.blip2 import Blip2Base, disabled_train\nfrom lavis.models.blip2_models.modeling_t5 import T5Config, T5ForConditionalGeneration\n\n@registry.register_model(\"sevila\")\nclass SeViLA(Blip2Base):\n    \"\"\"\n    BLIP2 T5 model.\n    Supported model types:\n        - pretrain_flant5xl: pretrained model with FlanT5-XL\n        - pretrain_flant5xxl: pretrained model with FlanT5-XXL\n        - caption_coco_flant5xl: fintuned image captioning model with FlanT5-XL\n    Usage:\n        >>> from lavis.models import load_model\n        >>> model = load_model(\"blip2_t5\", \"pretrain_flant5xl\")\n    \"\"\"\n    PRETRAINED_MODEL_CONFIG_DICT = {\n        \"pretrain_flant5xl\": \"configs/models/blip2/blip2_pretrain_flant5xl.yaml\",\n        \"pretrain_flant5xxl\": \"configs/models/blip2/blip2_pretrain_flant5xxl.yaml\",\n        \"caption_coco_flant5xl\": \"configs/models/blip2/blip2_caption_flant5xl.yaml\",\n    }\n\n    def __init__( self, img_size=224, drop_path_rate=0,\n        use_grad_checkpoint=False, vit_precision=\"fp16\", freeze_vit=True,\n        num_query_token=32, t5_model=\"google/flan-t5-xl\", prompt=\"\",\n        max_txt_len=32, frame_num=8, answer_num=5, apply_lemmatizer=False, task='qa'):\n        \"\"\"\n        apply_lemmatizer: when set to True, postprocess predict_answers() result with lemmas.\n        \"\"\"\n        super().__init__()\n        \n        self.task = task\n        \n        # vision backbone\n        self.visual_encoder, self.ln_vision, self.ln_vision_loc = self.init_vision_encoder_sevila(\n            img_size, drop_path_rate, use_grad_checkpoint, vit_precision)\n\n        # freeze ViT\n        if freeze_vit:\n            for name, param in self.visual_encoder.named_parameters():\n                param.requires_grad = False         \n            self.visual_encoder = self.visual_encoder.eval()\n            self.visual_encoder.train = disabled_train\n            logging.info(\"freeze vision encoder\")\n            \n        # text backbone\n        self.t5_tokenizer = T5TokenizerFast.from_pretrained(t5_model)\n        t5_config = T5Config.from_pretrained(t5_model)\n        t5_config.dense_act_fn = \"gelu\"\n        self.t5_model = T5ForConditionalGeneration.from_pretrained(\n        t5_model, config=t5_config)\n\n        # freeze T5\n        for name, param in self.t5_model.named_parameters():\n            param.requires_grad = False\n            param.data = param.data.bfloat16() \n\n        # Q-Former for Answerer\n        self.Qformer, self.query_tokens = self.init_Qformer(\n        num_query_token, self.visual_encoder.num_features)\n        self.Qformer.cls = None\n        self.Qformer.bert.embeddings.word_embeddings = None\n        self.Qformer.bert.embeddings.position_embeddings = None\n        for layer in self.Qformer.bert.encoder.layer:\n            layer.output = None\n            layer.intermediate = None\n        self.num_query_token = num_query_token\n        self.t5_proj = nn.Linear(\n        self.Qformer.config.hidden_size, self.t5_model.config.hidden_size)\n  
      \n        # Q-Former for Localizer\n        if 'loc' in task:\n            self.Qformer_loc, self.query_tokens_loc = self.init_Qformer(\n            num_query_token, self.visual_encoder.num_features)\n\n            self.Qformer_loc.cls = None\n            self.Qformer_loc.bert.embeddings.word_embeddings = None\n            self.Qformer_loc.bert.embeddings.position_embeddings = None\n            for layer in self.Qformer_loc.bert.encoder.layer:\n                layer.output = None\n                layer.intermediate = None\n            self.t5_proj_loc = nn.Linear(\n            self.Qformer_loc.config.hidden_size, self.t5_model.config.hidden_size\n            )\n            \n        self.max_txt_len = 77\n        answer_id = [71, 272, 205, 309, 262] # A B C D E\n        self.answer_id = answer_id[:answer_num]\n        self.yes_id, self.no_id = 4273, 150\n        \n        self._apply_lemmatizer = apply_lemmatizer\n        self._lemmatizer = None\n        \n        self.frame_num = frame_num\n        self.ANS_MAP = {'A':0, 'B':1, 'C':2, 'D':3, 'E':4}\n        self.frame_prefix = ['Frame: ']\n        self.vid_prefix = ['Frame {}: '.format(str(i+1)) for i in range(frame_num)]\n        \n        \n        if 'freeze_qa' in task:\n            for name, param in self.Qformer.named_parameters():\n                param.requires_grad = False\n            self.query_tokens.requires_grad = False\n            self.t5_proj.requires_grad = False\n\n        if 'freeze_loc' in task:\n            for name, param in self.Qformer_loc.named_parameters():\n                param.requires_grad = False\n            self.query_tokens_loc.requires_grad = False\n            self.t5_proj_loc.requires_grad = False\n            \n    def forward(self, samples,\n        use_nucleus_sampling=False,\n        num_beams=5, max_length=30,\n        min_length=1, top_p=0.9,\n        repetition_penalty=1.0, length_penalty=1.0,\n        num_captions=1, temperature=1,):\n\n        image = samples[\"video\"]\n        \n        b, t, c, w, h = image.shape     \n        image = image.reshape(-1, c, w, h)\n        image_embeds = self.visual_encoder(image) \n        _, n, _ = image_embeds.shape\n        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device) # bt n c\n        \n        # Localizer self-refinement\n        if 'train_loc' in self.task:\n\n            # ========= Generate pseudo labels by frozen answerer ============\n            with torch.no_grad():\n                \n                image_embeds_, image_atts_ = image_embeds.detach().clone(), image_atts.detach().clone()\n                image_embeds_ = self.ln_vision(image_embeds_)\n                \n                query_tokens_qa = self.query_tokens.expand(image_embeds_.shape[0], -1, -1)\n                query_output_qa = self.Qformer.bert(\n                    query_embeds=query_tokens_qa, encoder_hidden_states=image_embeds_,\n                    encoder_attention_mask=image_atts_, return_dict=True)\n                inputs_t5_qa = self.t5_proj(query_output_qa.last_hidden_state)\n                atts_t5_qa = torch.ones(inputs_t5_qa.size()[:-1], dtype=torch.long).to(image.device)\n                text_input_qa = samples['qa_input']\n                answer = samples['qa_output']\n                ans_idx = [self.ANS_MAP[a[-1]] for a in answer]\n\n                with torch.cuda.amp.autocast(dtype=torch.bfloat16):\n                    # Frame Prefix\n                    frame_prefix = self.t5_tokenizer(\n                        
self.frame_prefix, padding=\"longest\", add_special_tokens=False,\n                        truncation=True, max_length=self.max_txt_len, return_tensors=\"pt\",\n                        ).to(image.device) # \n                    frame_prefix_id = torch.repeat_interleave(frame_prefix.input_ids, b*t, 0)\n                    frame_prefix_mask = torch.repeat_interleave(frame_prefix.attention_mask, b*t, 0)\n                    # Question, options input\n                    input_tokens_qa = self.t5_tokenizer(\n                        text_input_qa, padding=\"longest\", truncation=True,\n                        max_length=self.max_txt_len, return_tensors=\"pt\").to(image.device)\n                    input_ids_qa = torch.repeat_interleave(input_tokens_qa.input_ids, t, 0)\n                    input_attention_mask_qa = torch.repeat_interleave(input_tokens_qa.attention_mask, t, 0)\n\n                    # Output target\n                    output_tokens_qa = self.t5_tokenizer(\n                        answer, padding=\"longest\", truncation=True,\n                        max_length=self.max_txt_len, return_tensors=\"pt\").to(image.device)\n                    targets_qa = output_tokens_qa.input_ids.masked_fill(\n                        output_tokens_qa.input_ids == self.t5_tokenizer.pad_token_id, -100)\n                    output_tokens_mask_qa = torch.repeat_interleave(output_tokens_qa.attention_mask, t, dim=0)\n                    targets_qa = torch.repeat_interleave(targets_qa, t, dim=0)\n                    \n                    # input for QA\n                    frame_predix_embed = self.t5_model.encoder.embed_tokens(frame_prefix_id)\n                    inputs_embeds_qa = self.t5_model.encoder.embed_tokens(input_ids_qa)\n                    inputs_embeds_qa = torch.cat([frame_predix_embed, inputs_t5_qa, inputs_embeds_qa], dim=1)\n                    encoder_atts_qa = torch.cat([frame_prefix_mask, atts_t5_qa, input_attention_mask_qa], dim=1)\n\n                    outputs_embed_qa = self.t5_model(\n                        inputs_embeds=inputs_embeds_qa, attention_mask=encoder_atts_qa,\n                        decoder_attention_mask=output_tokens_mask_qa, return_dict=True, labels=targets_qa)\n                    pred_logits_qa = outputs_embed_qa.logits.detach()\n                    pred_logits_qa = pred_logits_qa[:, 1, self.answer_id] # b*t, 5\n                    pred_ans = torch.argmax(pred_logits_qa, dim=-1)\n                    pred_ans = pred_ans.reshape(b, -1) # b, t\n                    # print('pred_ans', pred_ans)\n                    pseudo_label = []\n                    for i, preds in enumerate(pred_ans):\n                        for p in preds:\n                            if p == ans_idx[i]:\n                                pseudo_label.append('yes')\n                            else:\n                                pseudo_label.append('no')\n            # ================================================================\n                \n            # ============== Train localizer with pseudo labels =================\n            text_input_loc = samples['loc_input']\n            query_tokens_loc = self.query_tokens_loc.expand(image_embeds.shape[0], -1, -1)\n            image_embeds = self.ln_vision_loc(image_embeds)\n            \n            query_output_loc = self.Qformer_loc.bert(\n                query_embeds=query_tokens_loc, encoder_hidden_states=image_embeds,\n                encoder_attention_mask=image_atts, return_dict=True) # bt, n, c\n            inputs_t5_loc = 
self.t5_proj_loc(query_output_loc.last_hidden_state) # bt, n, c\n            atts_t5_loc = torch.ones(inputs_t5_loc.size()[:-1], dtype=torch.long).to(image.device)\n            with torch.cuda.amp.autocast(dtype=torch.bfloat16):\n                frame_prefix = self.t5_tokenizer(\n                    self.frame_prefix, padding=\"longest\", add_special_tokens=False,\n                    truncation=True, max_length=self.max_txt_len, return_tensors=\"pt\").to(image.device) \n                frame_prefix_id = torch.repeat_interleave(frame_prefix.input_ids, b*t, 0)\n                frame_prefix_mask = torch.repeat_interleave(frame_prefix.attention_mask, b*t, 0)\n                frame_predix_embed = self.t5_model.encoder.embed_tokens(frame_prefix_id)\n\n                input_tokens_loc = self.t5_tokenizer(\n                    text_input_loc, padding=\"longest\", truncation=True,\n                    max_length=self.max_txt_len, return_tensors=\"pt\").to(image.device)\n                input_ids_loc = torch.repeat_interleave(input_tokens_loc.input_ids, t, 0)\n                input_attention_mask_loc = torch.repeat_interleave(input_tokens_loc.attention_mask, t, 0)\n                inputs_embeds_loc = self.t5_model.encoder.embed_tokens(input_ids_loc)\n                    \n                inputs_embeds_loc = torch.cat([frame_predix_embed, inputs_t5_loc, inputs_embeds_loc], dim=1)\n                encoder_atts_loc = torch.cat([frame_prefix_mask, atts_t5_loc, input_attention_mask_loc], dim=1)\n\n                output_tokens_loc = self.t5_tokenizer(\n                    pseudo_label, padding=\"longest\", truncation=True,\n                    max_length=self.max_txt_len, return_tensors=\"pt\").to(image.device)\n                targets_loc = output_tokens_loc.input_ids.masked_fill(\n                    output_tokens_loc.input_ids == self.t5_tokenizer.pad_token_id, -100)\n                output_tokens_loc_mask = output_tokens_loc.attention_mask\n                \n                outputs_loc = self.t5_model(\n                    inputs_embeds=inputs_embeds_loc, attention_mask=encoder_atts_loc,\n                    decoder_attention_mask=output_tokens_loc_mask,\n                    return_dict=True, labels=targets_loc)\n                loss = outputs_loc.loss\n                                \n            return {\"loss\": loss}\n        \n        # Finetune answerer with localizer\n        elif 'train_qa_with_loc' in self.task:\n            # frame selection\n            with torch.no_grad():\n                image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device) # bt n c\n                image_embeds_, image_atts_ = image_embeds.detach().clone(), image_atts.detach().clone()\n                image_embeds_ = self.ln_vision_loc(image_embeds_)\n            \n                text_input_loc = samples['loc_input']\n                query_tokens_loc = self.query_tokens_loc.expand(image_embeds_.shape[0], -1, -1)\n                query_output_loc = self.Qformer_loc.bert(\n                    query_embeds=query_tokens_loc, encoder_hidden_states=image_embeds_,\n                    encoder_attention_mask=image_atts_, return_dict=True)\n                inputs_t5_loc = self.t5_proj_loc(query_output_loc.last_hidden_state)\n\n                atts_t5_loc = torch.ones(inputs_t5_loc.size()[:-1], dtype=torch.long).to(image.device)\n                with torch.cuda.amp.autocast(dtype=torch.bfloat16):\n\n                    frame_prefix = self.t5_tokenizer(\n                        
self.frame_prefix, padding=\"longest\", add_special_tokens=False,\n                        truncation=True, max_length=self.max_txt_len, return_tensors=\"pt\").to(image.device)\n                    frame_prefix_id = torch.repeat_interleave(frame_prefix.input_ids, b*t, 0)\n                    frame_prefix_mask = torch.repeat_interleave(frame_prefix.attention_mask, b*t, 0)\n                    frame_predix_embed = self.t5_model.encoder.embed_tokens(frame_prefix_id)\n                    input_tokens_loc = self.t5_tokenizer(\n                        text_input_loc, padding=\"longest\", truncation=True,\n                        max_length=self.max_txt_len, return_tensors=\"pt\").to(image.device)\n                    input_ids_loc = torch.repeat_interleave(input_tokens_loc.input_ids, t, 0)\n                    input_attention_mask_loc = torch.repeat_interleave(input_tokens_loc.attention_mask, t, 0)\n                    inputs_embeds_loc = self.t5_model.encoder.embed_tokens(input_ids_loc)              \n                    inputs_embeds_loc = torch.cat([frame_predix_embed, inputs_t5_loc, inputs_embeds_loc], dim=1)\n                    encoder_atts_loc = torch.cat([frame_prefix_mask, atts_t5_loc, input_attention_mask_loc], dim=1)\n    \n                    outputs_loc = self.t5_model.generate(\n                        inputs_embeds=inputs_embeds_loc, attention_mask=encoder_atts_loc,\n                        do_sample=use_nucleus_sampling, top_p=top_p, temperature=temperature, num_beams=1,\n                        max_new_tokens=max_length, min_length=min_length, repetition_penalty=repetition_penalty,\n                        length_penalty=length_penalty, num_return_sequences=num_captions,\n                        return_dict_in_generate=True, output_hidden_states=True, output_scores=True)\n                            \n                    pred_logits_loc = outputs_loc.scores[0]\n                    loc_yes = pred_logits_loc[:, self.yes_id]\n                    loc_yes = loc_yes.reshape(b, -1)\n                    \n            text_input_qa = samples['qa_input']\n            answer = samples['qa_output'] # Option A ...\n            select_frames_idx = torch.topk(loc_yes, self.frame_num, dim=-1).indices.tolist()\n            sorted_frames_idx = []\n            image_embeds = self.ln_vision(image_embeds)\n            image_embeds = image_embeds.reshape(b, t, n, -1)\n            for frames in select_frames_idx:\n                sorted_frames_idx.append(sorted(frames))\n            select_frames = []\n            for i, fs in enumerate(sorted_frames_idx): \n                video = []\n                for j, f in enumerate(fs):\n                    video.append(image_embeds[i][f])\n                video = torch.stack(video, dim=0) # 4, n , -1\n                select_frames.append(video)\n                    \n            select_frames = torch.stack(select_frames, dim=0) # b 4, n , -1\n            select_frames = select_frames.reshape(-1, select_frames.shape[-2], select_frames.shape[-1])\n            image_atts = torch.ones(select_frames.size()[:-1], dtype=torch.long).to(image.device) # bt n c\n            query_tokens_qa = self.query_tokens.expand(select_frames.shape[0], -1, -1)\n            query_output_qa = self.Qformer.bert(\n                query_embeds=query_tokens_qa, encoder_hidden_states=select_frames,\n                encoder_attention_mask=image_atts, return_dict=True)\n            inputs_t5_qa = self.t5_proj(query_output_qa.last_hidden_state)\n            inputs_t5_qa = 
inputs_t5_qa.reshape(b, -1, inputs_t5_qa.shape[-2], inputs_t5_qa.shape[-1])\n            atts_t5_qa = torch.ones(inputs_t5_qa.size()[:-1], dtype=torch.long).to(image.device)\n            \n            with torch.cuda.amp.autocast(dtype=torch.bfloat16):        \n                vid_prefix = self.t5_tokenizer(\n                    self.vid_prefix, padding=\"longest\", add_special_tokens=False,\n                    truncation=True, max_length=self.max_txt_len, return_tensors=\"pt\",).to(image.device) # \n                vid_prefix_id = torch.repeat_interleave(vid_prefix.input_ids.unsqueeze(0), b, 0)\n                vid_prefix_mask = torch.repeat_interleave(vid_prefix.attention_mask.unsqueeze(0), b, 0)\n                vid_prefix_embed = self.t5_model.encoder.embed_tokens(vid_prefix_id) # b t n_word c\n                        \n                inputs_t5_qa = torch.cat([vid_prefix_embed, inputs_t5_qa], dim=2) # b, t, n_word + m, c\n                atts_t5_qa = torch.cat([vid_prefix_mask, atts_t5_qa], dim=2) # b, t, n_word + m \n                inputs_t5_qa = inputs_t5_qa.reshape(b, -1, inputs_t5_qa.shape[-1])\n                atts_t5_qa = atts_t5_qa.reshape(b, -1)\n                        \n                input_tokens_qa = self.t5_tokenizer(\n                    text_input_qa, padding=\"longest\", truncation=True,\n                    max_length=self.max_txt_len, return_tensors=\"pt\").to(image.device)\n                inputs_embeds_qa = self.t5_model.encoder.embed_tokens(input_tokens_qa.input_ids) \n                inputs_embeds_qa = torch.cat([inputs_t5_qa, inputs_embeds_qa], dim=1)\n                encoder_atts_qa = torch.cat([atts_t5_qa, input_tokens_qa.attention_mask], dim=1)\n                \n                output_tokens_qa = self.t5_tokenizer(\n                    answer, padding=\"longest\", truncation=True,\n                    max_length=self.max_txt_len, return_tensors=\"pt\").to(image.device)\n                targets_qa = output_tokens_qa.input_ids.masked_fill(\n                    output_tokens_qa.input_ids == self.t5_tokenizer.pad_token_id, -100)\n                output_tokens_mask_qa = output_tokens_qa.attention_mask\n                \n                outputs_qa = self.t5_model(\n                    inputs_embeds=inputs_embeds_qa, attention_mask=encoder_atts_qa,\n                    decoder_attention_mask=output_tokens_mask_qa, return_dict=True, labels=targets_qa)\n                loss = outputs_qa.loss\n                \n                return {\"loss\": loss}\n        \n        # finetune answerer with random frames\n        elif 'loc' not in self.task or 'train_qa_wo_loc' in self.task:\n            #pass\n            query_tokens_qa = self.query_tokens.expand(image_embeds.shape[0], -1, -1)\n            image_embeds = self.ln_vision(image_embeds)\n            \n            query_output_qa = self.Qformer.bert(\n                query_embeds=query_tokens_qa, encoder_hidden_states=image_embeds,\n                encoder_attention_mask=image_atts, return_dict=True)\n            inputs_t5_qa = self.t5_proj(query_output_qa.last_hidden_state)\n            text_input_qa = samples['qa_input'] \n            answer = samples['qa_output'] \n\n            with torch.cuda.amp.autocast(dtype=torch.bfloat16):\n                    # Frame Prefix\n                if 'qa_vid' not in self.task:\n                    atts_t5_qa = torch.ones(inputs_t5_qa.size()[:-1], dtype=torch.long).to(image.device) \n                    frame_prefix = self.t5_tokenizer(\n                        
self.frame_prefix, padding=\"longest\", add_special_tokens=False,\n                        truncation=True, max_length=self.max_txt_len,return_tensors=\"pt\",\n                        ).to(image.device) \n                    frame_prefix_id = torch.repeat_interleave(frame_prefix.input_ids, b*t, 0)\n                    frame_prefix_mask = torch.repeat_interleave(frame_prefix.attention_mask, b*t, 0)\n                    # Question, Options input\n                    input_tokens_qa = self.t5_tokenizer(\n                        text_input_qa, padding=\"longest\", truncation=True,\n                        max_length=self.max_txt_len, return_tensors=\"pt\").to(image.device)\n                    input_ids_qa = torch.repeat_interleave(input_tokens_qa.input_ids, t, 0)\n                    input_attention_mask_qa = torch.repeat_interleave(input_tokens_qa.attention_mask, t, 0)\n\n                    # Output target\n                    output_tokens_qa = self.t5_tokenizer(\n                        answer, padding=\"longest\", truncation=True,\n                        max_length=self.max_txt_len, return_tensors=\"pt\").to(image.device)\n                    targets_qa = output_tokens_qa.input_ids.masked_fill(\n                        output_tokens_qa.input_ids == self.t5_tokenizer.pad_token_id, -100)\n                    output_tokens_mask_qa = torch.repeat_interleave(output_tokens_qa.attention_mask, t, dim=0)\n                    targets_qa = torch.repeat_interleave(targets_qa, t, dim=0)\n                    \n                    # input for QA\n                    frame_predix_embed = self.t5_model.encoder.embed_tokens(frame_prefix_id)\n                    inputs_embeds_qa = self.t5_model.encoder.embed_tokens(input_ids_qa)\n                    inputs_embeds_qa = torch.cat([frame_predix_embed, inputs_t5_qa, inputs_embeds_qa], dim=1)\n                    encoder_atts_qa = torch.cat([frame_prefix_mask, atts_t5_qa, input_attention_mask_qa], dim=1)\n                else:\n                    vid_prefix = self.t5_tokenizer(\n                        self.vid_prefix, padding=\"longest\", add_special_tokens=False,\n                        truncation=True, max_length=self.max_txt_len, return_tensors=\"pt\",).to(image.device) # \n                    vid_prefix_id = torch.repeat_interleave(vid_prefix.input_ids.unsqueeze(0), b, 0)\n                    vid_prefix_mask = torch.repeat_interleave(vid_prefix.attention_mask.unsqueeze(0), b, 0)\n                    vid_prefix_embed = self.t5_model.encoder.embed_tokens(vid_prefix_id) # b t n_word c\n                    \n                    inputs_t5_qa = inputs_t5_qa.reshape(b, t, inputs_t5_qa.shape[-2], -1) # b, t, m ,c\n                    atts_t5_qa = torch.ones(inputs_t5_qa.size()[:-1], dtype=torch.long).to(image.device)\n                    \n                    inputs_t5_qa = torch.cat([vid_prefix_embed, inputs_t5_qa], dim=2) # b, t, n_word + m, c\n                    atts_t5_qa = torch.cat([vid_prefix_mask, atts_t5_qa], dim=2) # b, t, n_word + m \n                    inputs_t5_qa = inputs_t5_qa.reshape(b, -1, inputs_t5_qa.shape[-1])\n                    atts_t5_qa = atts_t5_qa.reshape(b, -1)\n                    \n                    input_tokens_qa = self.t5_tokenizer(\n                        text_input_qa, padding=\"longest\", truncation=True,\n                        max_length=self.max_txt_len, return_tensors=\"pt\").to(image.device)\n                    inputs_embeds_qa = self.t5_model.encoder.embed_tokens(input_tokens_qa.input_ids) \n                    
inputs_embeds_qa = torch.cat([inputs_t5_qa, inputs_embeds_qa], dim=1)\n                    encoder_atts_qa = torch.cat([atts_t5_qa, input_tokens_qa.attention_mask], dim=1)\n                    \n                    output_tokens_qa = self.t5_tokenizer(\n                        answer, padding=\"longest\", truncation=True,\n                        max_length=self.max_txt_len, return_tensors=\"pt\").to(image.device)\n                    targets_qa = output_tokens_qa.input_ids.masked_fill(\n                        output_tokens_qa.input_ids == self.t5_tokenizer.pad_token_id, -100)\n                    output_tokens_mask_qa = output_tokens_qa.attention_mask\n\n                outputs_qa = self.t5_model(\n                    inputs_embeds=inputs_embeds_qa, attention_mask=encoder_atts_qa,\n                    decoder_attention_mask=output_tokens_mask_qa, return_dict=True, labels=targets_qa)\n                loss = outputs_qa.loss\n                \n                return {\"loss\": loss}\n        \n\n    @torch.no_grad()\n    def generate(self,\n        samples,\n        use_nucleus_sampling=False,\n        num_beams=5, max_length=30,\n        min_length=1, top_p=0.9,\n        repetition_penalty=1.0, length_penalty=1.0,\n        num_captions=1, temperature=1,):\n        \"\"\"\n        Args:\n            samples (dict): A dictionary containing the following keys:\n                - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W)\n            use_nucleus_sampling (bool): Whether to use nucleus sampling. If False, use top-k sampling.\n            num_beams (int): Number of beams for beam search. 1 means no beam search.\n            max_length (int): The maximum length of the sequence to be generated.\n            min_length (int): The minimum length of the sequence to be generated.\n            top_p (float): The cumulative probability for nucleus sampling.\n            repetition_penalty (float): The parameter for repetition penalty. 
1.0 means no penalty.\n            num_captions (int): Number of captions to be generated for each image.\n        Returns:\n            captions (list): A list of strings of length batch_size * num_captions.\n        \"\"\"\n        out = {}\n        image, qid = samples[\"video\"], samples['question_id']\n        text_input_qa, answer = samples['qa_input'], samples['qa_output']\n        \n        # uniform sampling\n        if 'loc' not in self.task or 'uni_eval' in self.task:\n            b, t, c, w, h = image.shape        \n            image = image.reshape(-1, c, w, h)\n            with torch.cuda.amp.autocast(enabled=(self.device != torch.device(\"cpu\"))):\n                image_embeds = self.ln_vision(self.visual_encoder(image)) # bt, n, c\n            _, n, _ = image_embeds.shape\n            image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device) # bt n c\n            \n            query_tokens_qa = self.query_tokens.expand(image_embeds.shape[0], -1, -1)\n            query_output_qa = self.Qformer.bert(\n                query_embeds=query_tokens_qa, encoder_hidden_states=image_embeds,\n                encoder_attention_mask=image_atts, return_dict=True)\n            inputs_t5_qa = self.t5_proj(query_output_qa.last_hidden_state)\n            \n            with torch.cuda.amp.autocast(dtype=torch.bfloat16):\n                    # Frame Prefix\n                if 'vid' not in self.task: \n                    atts_t5_qa = torch.ones(inputs_t5_qa.size()[:-1], dtype=torch.long).to(image.device) \n                    frame_prefix = self.t5_tokenizer(\n                        self.frame_prefix, padding=\"longest\", add_special_tokens=False,\n                        truncation=True, max_length=self.max_txt_len, return_tensors=\"pt\",).to(image.device) # \n                    frame_prefix_id = torch.repeat_interleave(frame_prefix.input_ids, b*t, 0)\n                    frame_prefix_mask = torch.repeat_interleave(frame_prefix.attention_mask, b*t, 0)\n                    input_tokens_qa = self.t5_tokenizer(\n                        text_input_qa, padding=\"longest\", truncation=True,\n                        max_length=self.max_txt_len, return_tensors=\"pt\").to(image.device)\n                    input_ids_qa = torch.repeat_interleave(input_tokens_qa.input_ids, t, 0)\n                    input_attention_mask_qa = torch.repeat_interleave(input_tokens_qa.attention_mask, t, 0)\n                                        \n                    # input for QA\n                    frame_predix_embed = self.t5_model.encoder.embed_tokens(frame_prefix_id)\n                    inputs_embeds_qa = self.t5_model.encoder.embed_tokens(input_ids_qa)\n                    inputs_embeds_qa = torch.cat([frame_predix_embed, inputs_t5_qa, inputs_embeds_qa], dim=1)\n                    encoder_atts_qa = torch.cat([frame_prefix_mask, atts_t5_qa, input_attention_mask_qa], dim=1)\n                \n                elif 'qa_vid' in self.task:\n                    vid_prefix = self.t5_tokenizer(\n                        self.vid_prefix, padding=\"longest\", add_special_tokens=False,\n                        truncation=True, max_length=self.max_txt_len, return_tensors=\"pt\",).to(image.device) # \n                    vid_prefix_id = torch.repeat_interleave(vid_prefix.input_ids.unsqueeze(0), b, 0)\n                    vid_prefix_mask = torch.repeat_interleave(vid_prefix.attention_mask.unsqueeze(0), b, 0)\n                    vid_prefix_embed = self.t5_model.encoder.embed_tokens(vid_prefix_id) # b 
t n_word c\n                    \n                    inputs_t5_qa = inputs_t5_qa.reshape(b, t, inputs_t5_qa.shape[-2], -1) # b, t, m ,c\n                    atts_t5_qa = torch.ones(inputs_t5_qa.size()[:-1], dtype=torch.long).to(image.device)\n                    \n                    inputs_t5_qa = torch.cat([vid_prefix_embed, inputs_t5_qa], dim=2) # b, t, n_word + m, c\n                    atts_t5_qa = torch.cat([vid_prefix_mask, atts_t5_qa], dim=2) # b, t, n_word + m \n                    inputs_t5_qa = inputs_t5_qa.reshape(b, -1, inputs_t5_qa.shape[-1])\n                    atts_t5_qa = atts_t5_qa.reshape(b, -1)\n                    \n                    input_tokens_qa = self.t5_tokenizer(\n                        text_input_qa, padding=\"longest\", truncation=True,\n                        max_length=self.max_txt_len, return_tensors=\"pt\").to(image.device)\n                    inputs_embeds_qa = self.t5_model.encoder.embed_tokens(input_tokens_qa.input_ids)\n                    inputs_embeds_qa = torch.cat([inputs_t5_qa, inputs_embeds_qa], dim=1)\n                    encoder_atts_qa = torch.cat([atts_t5_qa, input_tokens_qa.attention_mask], dim=1)\n\n                outputs_qa = self.t5_model.generate(\n                    inputs_embeds=inputs_embeds_qa, attention_mask=encoder_atts_qa,\n                    do_sample=use_nucleus_sampling, top_p=top_p,\n                    temperature=temperature, num_beams=1,\n                    max_new_tokens=max_length, min_length=min_length,\n                    repetition_penalty=repetition_penalty, length_penalty=length_penalty,\n                    num_return_sequences=num_captions, return_dict_in_generate=True,\n                    output_hidden_states=True, output_scores=True)\n                try:\n                    pred_logits_qa = outputs_qa.scores[1]\n                except:\n                    pred_logits_qa = outputs_qa.scores[0]\n                pred_logits_qa = pred_logits_qa[:, self.answer_id] # b, 5\n                pred_ans = torch.argmax(pred_logits_qa, dim=-1).cpu().tolist() \n        \n        # inference with localizer             \n        else:\n            \n            b, t, c, w, h = image.shape        \n            image = image.reshape(-1, c, w, h)\n            with torch.cuda.amp.autocast(enabled=(self.device != torch.device(\"cpu\"))):\n                image_embeds = self.visual_encoder(image) # bt, n, c\n                \n            _, n, _ = image_embeds.shape\n            image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device) # bt n c\n            image_embeds_, image_atts_ = image_embeds.detach().clone(), image_atts.detach().clone()\n            image_embeds_ = self.ln_vision_loc(image_embeds_)\n            \n            text_input_loc = samples['loc_input'] # Q + Prompt: Is this a good frame can answer the question?\n            query_tokens_loc = self.query_tokens_loc.expand(image_embeds_.shape[0], -1, -1)\n            query_output_loc = self.Qformer_loc.bert(\n                query_embeds=query_tokens_loc, encoder_hidden_states=image_embeds_,\n                encoder_attention_mask=image_atts_, return_dict=True)\n            inputs_t5_loc = self.t5_proj_loc(query_output_loc.last_hidden_state)\n\n            atts_t5_loc = torch.ones(inputs_t5_loc.size()[:-1], dtype=torch.long).to(image.device)\n            with torch.cuda.amp.autocast(dtype=torch.bfloat16):\n\n                frame_prefix = self.t5_tokenizer(\n                    self.frame_prefix, padding=\"longest\", 
add_special_tokens=False,\n                    truncation=True, max_length=self.max_txt_len, return_tensors=\"pt\").to(image.device) # \n                #print('frame_prefix 1', frame_prefix.input_ids.shape) 8, 4\n                frame_prefix_id = torch.repeat_interleave(frame_prefix.input_ids, b*t, 0)\n                frame_prefix_mask = torch.repeat_interleave(frame_prefix.attention_mask, b*t, 0)\n                frame_predix_embed = self.t5_model.encoder.embed_tokens(frame_prefix_id)\n                input_tokens_loc = self.t5_tokenizer(\n                    text_input_loc, padding=\"longest\", truncation=True,\n                    max_length=self.max_txt_len, return_tensors=\"pt\").to(image.device)\n                #print('input_ids_loc.input_ids', input_tokens_loc.input_ids)\n                input_ids_loc = torch.repeat_interleave(input_tokens_loc.input_ids, t, 0)\n                #print('input_ids_loc', input_ids_loc)\n                input_attention_mask_loc = torch.repeat_interleave(input_tokens_loc.attention_mask, t, 0)\n                inputs_embeds_loc = self.t5_model.encoder.embed_tokens(input_ids_loc)              \n                inputs_embeds_loc = torch.cat([frame_predix_embed, inputs_t5_loc, inputs_embeds_loc], dim=1)\n                encoder_atts_loc = torch.cat([frame_prefix_mask, atts_t5_loc, input_attention_mask_loc], dim=1)\n    \n                outputs_loc = self.t5_model.generate(\n                    inputs_embeds=inputs_embeds_loc, attention_mask=encoder_atts_loc,\n                    do_sample=use_nucleus_sampling, top_p=top_p, temperature=temperature, num_beams=1,\n                    max_new_tokens=max_length, min_length=min_length, repetition_penalty=repetition_penalty,\n                    length_penalty=length_penalty, num_return_sequences=num_captions,\n                    return_dict_in_generate=True, output_hidden_states=True, output_scores=True)\n                        \n                pred_logits_loc = outputs_loc.scores[0]\n                loc_yes = pred_logits_loc[:, self.yes_id]\n                loc_yes = loc_yes.reshape(b, -1)\n                if 'qa_vid' in self.task:\n                    select_frames_idx = torch.topk(loc_yes, self.frame_num, dim=-1).indices.tolist()\n                    sorted_frames_idx = []\n                    image_embeds = self.ln_vision(image_embeds)\n                    image_embeds = image_embeds.reshape(b, t, n, -1)\n                    for frames in select_frames_idx:\n                        sorted_frames_idx.append(sorted(frames))\n                    out['frame_idx'] = sorted_frames_idx\n                    select_frames = []\n                    for i, fs in enumerate(sorted_frames_idx): \n                        video = []\n                        for j, f in enumerate(fs):\n                            video.append(image_embeds[i][f])\n                        video = torch.stack(video, dim=0)\n                        select_frames.append(video)\n                    \n                    select_frames = torch.stack(select_frames, dim=0) # b 4, n , -1\n                    select_frames = select_frames.reshape(-1, select_frames.shape[-2], select_frames.shape[-1])\n                    image_atts = torch.ones(select_frames.size()[:-1], dtype=torch.long).to(image.device) # bt n c\n                    query_tokens_qa = self.query_tokens.expand(select_frames.shape[0], -1, -1)\n                    query_output_qa = self.Qformer.bert(\n                        query_embeds=query_tokens_qa, 
encoder_hidden_states=select_frames,\n                        encoder_attention_mask=image_atts, return_dict=True)\n                    inputs_t5_qa = self.t5_proj(query_output_qa.last_hidden_state)\n                    inputs_t5_qa = inputs_t5_qa.reshape(b, -1, inputs_t5_qa.shape[-2], inputs_t5_qa.shape[-1])\n                    atts_t5_qa = torch.ones(inputs_t5_qa.size()[:-1], dtype=torch.long).to(image.device)\n                    \n                    vid_prefix = self.t5_tokenizer(\n                        self.vid_prefix, padding=\"longest\", add_special_tokens=False,\n                        truncation=True, max_length=self.max_txt_len, return_tensors=\"pt\",).to(image.device) # \n                    vid_prefix_id = torch.repeat_interleave(vid_prefix.input_ids.unsqueeze(0), b, 0)\n                    vid_prefix_mask = torch.repeat_interleave(vid_prefix.attention_mask.unsqueeze(0), b, 0)\n                    vid_prefix_embed = self.t5_model.encoder.embed_tokens(vid_prefix_id) # b t n_word c\n                    \n                    inputs_t5_qa = torch.cat([vid_prefix_embed, inputs_t5_qa], dim=2) # b, t, n_word + m, c\n                    atts_t5_qa = torch.cat([vid_prefix_mask, atts_t5_qa], dim=2) # b, t, n_word + m \n                    inputs_t5_qa = inputs_t5_qa.reshape(b, -1, inputs_t5_qa.shape[-1])\n                    atts_t5_qa = atts_t5_qa.reshape(b, -1)\n                    \n                    input_tokens_qa = self.t5_tokenizer(\n                        text_input_qa, padding=\"longest\", truncation=True,\n                        max_length=self.max_txt_len, return_tensors=\"pt\").to(image.device)\n                    inputs_embeds_qa = self.t5_model.encoder.embed_tokens(input_tokens_qa.input_ids) \n                    inputs_embeds_qa = torch.cat([inputs_t5_qa, inputs_embeds_qa], dim=1)\n                    encoder_atts_qa = torch.cat([atts_t5_qa, input_tokens_qa.attention_mask], dim=1)\n                    \n                else:\n                    select_frames_idx = torch.argmax(loc_yes, -1)\n                    select_frames = []\n                    image_embeds = self.ln_vision(image_embeds)\n                    image_embeds = image_embeds.reshape(b, t, n, -1)\n                    for i, f in enumerate(select_frames_idx):\n                        select_frames.append(image_embeds[i][f])\n                        \n                    select_frames = torch.stack(select_frames, dim=0)\n                    image_atts = torch.ones(select_frames.size()[:-1], dtype=torch.long).to(image.device) # bt n c\n                    query_tokens_qa = self.query_tokens.expand(select_frames.shape[0], -1, -1)\n                    query_output_qa = self.Qformer.bert(\n                        query_embeds=query_tokens_qa, encoder_hidden_states=select_frames,\n                        encoder_attention_mask=image_atts, return_dict=True)\n                    inputs_t5_qa = self.t5_proj(query_output_qa.last_hidden_state)\n                    atts_t5_qa = torch.ones(inputs_t5_qa.size()[:-1], dtype=torch.long).to(image.device)\n                    \n                    frame_prefix = self.t5_tokenizer(\n                        self.frame_prefix, padding=\"longest\", add_special_tokens=False, \n                        truncation=True, max_length=self.max_txt_len, return_tensors=\"pt\").to(image.device) # \n                    frame_prefix_id = torch.repeat_interleave(frame_prefix.input_ids, b, 0)\n                    frame_prefix_mask = torch.repeat_interleave(frame_prefix.attention_mask, 
b, 0)\n\n                    input_tokens_qa = self.t5_tokenizer(\n                        text_input_qa, padding=\"longest\", truncation=True,\n                        max_length=self.max_txt_len, return_tensors=\"pt\").to(image.device)\n\n                    frame_predix_embed = self.t5_model.encoder.embed_tokens(frame_prefix_id)\n                    inputs_embeds_qa = self.t5_model.encoder.embed_tokens(input_tokens_qa.input_ids)\n\n                    inputs_embeds_qa = torch.cat([frame_predix_embed, inputs_t5_qa, inputs_embeds_qa], dim=1)\n                    encoder_atts_qa = torch.cat([frame_prefix_mask, atts_t5_qa, input_tokens_qa.attention_mask], dim=1)\n                            \n                outputs_qa = self.t5_model.generate(\n                    inputs_embeds=inputs_embeds_qa, attention_mask=encoder_atts_qa,\n                    do_sample=use_nucleus_sampling, top_p=top_p,\n                    temperature=temperature, num_beams=1,\n                    max_new_tokens=max_length, min_length=min_length,\n                    repetition_penalty=repetition_penalty, length_penalty=length_penalty,\n                    num_return_sequences=num_captions, return_dict_in_generate=True,\n                    output_hidden_states=True, output_scores=True)\n                pred_logits_qa = outputs_qa.scores[1]\n                pred_logits_qa = pred_logits_qa[:, self.answer_id] # b, 5\n                pred_ans = torch.argmax(pred_logits_qa, dim=-1).cpu().tolist()\n        \n        out['output_text'] = pred_ans\n        if 'qa_vid' not in self.task: \n            out['temp_idx'] = [j for i in range(b) for j in range(t)]\n            out['answer'] = [a for a in answer for i in range(t)]\n            out['qid'] = [q for q in qid for i in range(t)]\n        else:\n            out['answer'] = answer\n            out['qid'] = qid\n\n        return out\n    \n    @torch.no_grad()\n    def generate_demo(self,\n        video,\n        text_input_qa,\n        text_input_loc,\n        keyframe_num,\n        qid='demo',\n        use_nucleus_sampling=False,\n        num_beams=5, max_length=30,\n        min_length=1, top_p=0.9,\n        repetition_penalty=1.0, length_penalty=1.0,\n        num_captions=1, temperature=1,):\n        \"\"\"\n        Args:\n            samples (dict): A dictionary containing the following keys:\n                - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W)\n            use_nucleus_sampling (bool): Whether to use nucleus sampling. If False, use top-k sampling.\n            num_beams (int): Number of beams for beam search. 1 means no beam search.\n            max_length (int): The maximum length of the sequence to be generated.\n            min_length (int): The minimum length of the sequence to be generated.\n            top_p (float): The cumulative probability for nucleus sampling.\n            repetition_penalty (float): The parameter for repetition penalty. 
1.0 means no penalty.\n            num_captions (int): Number of captions to be generated for each image.\n        Returns:\n            captions (list): A list of strings of length batch_size * num_captions.\n        \"\"\"\n        out = {}\n        image, qid = video, qid\n        text_input_qa, answer = text_input_qa, 0\n        \n        # inference with localizer             \n            \n        b, t, c, w, h = image.shape        \n        image = image.reshape(-1, c, w, h)\n        with torch.cuda.amp.autocast(enabled=(self.device != torch.device(\"cpu\"))):\n            image_embeds = self.visual_encoder(image) # bt, n, c\n                \n        _, n, _ = image_embeds.shape\n        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device) # bt n c\n        image_embeds_, image_atts_ = image_embeds.detach().clone(), image_atts.detach().clone()\n        image_embeds_ = self.ln_vision_loc(image_embeds_)\n            \n        text_input_loc = text_input_loc # Q + Prompt: Is this a good frame can answer the question?\n        query_tokens_loc = self.query_tokens_loc.expand(image_embeds_.shape[0], -1, -1)\n        query_output_loc = self.Qformer_loc.bert(\n            query_embeds=query_tokens_loc, encoder_hidden_states=image_embeds_,\n            encoder_attention_mask=image_atts_, return_dict=True)\n        inputs_t5_loc = self.t5_proj_loc(query_output_loc.last_hidden_state)\n\n        atts_t5_loc = torch.ones(inputs_t5_loc.size()[:-1], dtype=torch.long).to(image.device)\n        with torch.cuda.amp.autocast(dtype=torch.bfloat16):\n\n            frame_prefix = self.t5_tokenizer(\n                self.frame_prefix, padding=\"longest\", add_special_tokens=False,\n                truncation=True, max_length=self.max_txt_len, return_tensors=\"pt\").to(image.device) # \n                #print('frame_prefix 1', frame_prefix.input_ids.shape) 8, 4\n            frame_prefix_id = torch.repeat_interleave(frame_prefix.input_ids, b*t, 0)\n            frame_prefix_mask = torch.repeat_interleave(frame_prefix.attention_mask, b*t, 0)\n            frame_predix_embed = self.t5_model.encoder.embed_tokens(frame_prefix_id)\n            input_tokens_loc = self.t5_tokenizer(\n                text_input_loc, padding=\"longest\", truncation=True,\n                max_length=self.max_txt_len, return_tensors=\"pt\").to(image.device)\n                #print('input_ids_loc.input_ids', input_tokens_loc.input_ids)\n            input_ids_loc = torch.repeat_interleave(input_tokens_loc.input_ids, t, 0)\n                #print('input_ids_loc', input_ids_loc)\n            input_attention_mask_loc = torch.repeat_interleave(input_tokens_loc.attention_mask, t, 0)\n            inputs_embeds_loc = self.t5_model.encoder.embed_tokens(input_ids_loc)              \n            inputs_embeds_loc = torch.cat([frame_predix_embed, inputs_t5_loc, inputs_embeds_loc], dim=1)\n            encoder_atts_loc = torch.cat([frame_prefix_mask, atts_t5_loc, input_attention_mask_loc], dim=1)\n    \n            outputs_loc = self.t5_model.generate(\n                inputs_embeds=inputs_embeds_loc, attention_mask=encoder_atts_loc,\n                do_sample=use_nucleus_sampling, top_p=top_p, temperature=temperature, num_beams=1,\n                max_new_tokens=max_length, min_length=min_length, repetition_penalty=repetition_penalty,\n                length_penalty=length_penalty, num_return_sequences=num_captions,\n                return_dict_in_generate=True, output_hidden_states=True, output_scores=True)\n        
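\n            # scores[0] holds the logits of the first decoded token; the 'yes' token logit per frame acts as the localizer's keyframe relevance score.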
                \n            pred_logits_loc = outputs_loc.scores[0]\n            loc_yes = pred_logits_loc[:, self.yes_id]\n            loc_yes = loc_yes.reshape(b, -1)\n            if 'qa_vid' in self.task:\n                select_frames_idx = torch.topk(loc_yes, keyframe_num, dim=-1).indices.tolist()\n                sorted_frames_idx = []\n                image_embeds = self.ln_vision(image_embeds)\n                image_embeds = image_embeds.reshape(b, t, n, -1)\n                for frames in select_frames_idx:\n                    sorted_frames_idx.append(sorted(frames))\n                out['frame_idx'] = sorted_frames_idx\n                select_frames = []\n                for i, fs in enumerate(sorted_frames_idx): \n                    video = []\n                    for j, f in enumerate(fs):\n                        video.append(image_embeds[i][f])\n                    video = torch.stack(video, dim=0)\n                    select_frames.append(video)\n                    \n                select_frames = torch.stack(select_frames, dim=0) # b 4, n , -1\n                select_frames = select_frames.reshape(-1, select_frames.shape[-2], select_frames.shape[-1])\n                image_atts = torch.ones(select_frames.size()[:-1], dtype=torch.long).to(image.device) # bt n c\n                query_tokens_qa = self.query_tokens.expand(select_frames.shape[0], -1, -1)\n                query_output_qa = self.Qformer.bert(\n                    query_embeds=query_tokens_qa, encoder_hidden_states=select_frames,\n                    encoder_attention_mask=image_atts, return_dict=True)\n                inputs_t5_qa = self.t5_proj(query_output_qa.last_hidden_state)\n                inputs_t5_qa = inputs_t5_qa.reshape(b, -1, inputs_t5_qa.shape[-2], inputs_t5_qa.shape[-1])\n                atts_t5_qa = torch.ones(inputs_t5_qa.size()[:-1], dtype=torch.long).to(image.device)\n                \n                vid_prefix = self.t5_tokenizer(\n                        self.vid_prefix, padding=\"longest\", add_special_tokens=False,\n                        truncation=True, max_length=self.max_txt_len, return_tensors=\"pt\",).to(image.device) # \n                vid_prefix_id = torch.repeat_interleave(vid_prefix.input_ids.unsqueeze(0), b, 0)\n                vid_prefix_mask = torch.repeat_interleave(vid_prefix.attention_mask.unsqueeze(0), b, 0)\n                vid_prefix_embed = self.t5_model.encoder.embed_tokens(vid_prefix_id) # b t n_word c\n                    \n                inputs_t5_qa = torch.cat([vid_prefix_embed, inputs_t5_qa], dim=2) # b, t, n_word + m, c\n                atts_t5_qa = torch.cat([vid_prefix_mask, atts_t5_qa], dim=2) # b, t, n_word + m \n                inputs_t5_qa = inputs_t5_qa.reshape(b, -1, inputs_t5_qa.shape[-1])\n                atts_t5_qa = atts_t5_qa.reshape(b, -1)\n                    \n                input_tokens_qa = self.t5_tokenizer(\n                        text_input_qa, padding=\"longest\", truncation=True,\n                        max_length=self.max_txt_len, return_tensors=\"pt\").to(image.device)\n                inputs_embeds_qa = self.t5_model.encoder.embed_tokens(input_tokens_qa.input_ids) \n                inputs_embeds_qa = torch.cat([inputs_t5_qa, inputs_embeds_qa], dim=1)\n                encoder_atts_qa = torch.cat([atts_t5_qa, input_tokens_qa.attention_mask], dim=1)\n                    \n            else:\n                select_frames_idx = torch.argmax(loc_yes, -1)\n                select_frames = []\n                image_embeds = 
self.ln_vision(image_embeds)\n                image_embeds = image_embeds.reshape(b, t, n, -1)\n                for i, f in enumerate(select_frames_idx):\n                    select_frames.append(image_embeds[i][f])\n                        \n                select_frames = torch.stack(select_frames, dim=0)\n                image_atts = torch.ones(select_frames.size()[:-1], dtype=torch.long).to(image.device) # bt n c\n                query_tokens_qa = self.query_tokens.expand(select_frames.shape[0], -1, -1)\n                query_output_qa = self.Qformer.bert(\n                    query_embeds=query_tokens_qa, encoder_hidden_states=select_frames,\n                        encoder_attention_mask=image_atts, return_dict=True)\n                inputs_t5_qa = self.t5_proj(query_output_qa.last_hidden_state)\n                atts_t5_qa = torch.ones(inputs_t5_qa.size()[:-1], dtype=torch.long).to(image.device)\n                    \n                frame_prefix = self.t5_tokenizer(\n                        self.frame_prefix, padding=\"longest\", add_special_tokens=False, \n                        truncation=True, max_length=self.max_txt_len, return_tensors=\"pt\").to(image.device) # \n                frame_prefix_id = torch.repeat_interleave(frame_prefix.input_ids, b, 0)\n                frame_prefix_mask = torch.repeat_interleave(frame_prefix.attention_mask, b, 0)\n\n                input_tokens_qa = self.t5_tokenizer(\n                    text_input_qa, padding=\"longest\", truncation=True,\n                    max_length=self.max_txt_len, return_tensors=\"pt\").to(image.device)\n\n                frame_predix_embed = self.t5_model.encoder.embed_tokens(frame_prefix_id)\n                inputs_embeds_qa = self.t5_model.encoder.embed_tokens(input_tokens_qa.input_ids)\n\n                inputs_embeds_qa = torch.cat([frame_predix_embed, inputs_t5_qa, inputs_embeds_qa], dim=1)\n                encoder_atts_qa = torch.cat([frame_prefix_mask, atts_t5_qa, input_tokens_qa.attention_mask], dim=1)\n                            \n            outputs_qa = self.t5_model.generate(\n                    inputs_embeds=inputs_embeds_qa, attention_mask=encoder_atts_qa,\n                    do_sample=use_nucleus_sampling, top_p=top_p,\n                    temperature=temperature, num_beams=1,\n                    max_new_tokens=max_length, min_length=min_length,\n                    repetition_penalty=repetition_penalty, length_penalty=length_penalty,\n                    num_return_sequences=num_captions, return_dict_in_generate=True,\n                    output_hidden_states=True, output_scores=True)\n            pred_logits_qa = outputs_qa.scores[1]\n            pred_logits_qa = pred_logits_qa[:, self.answer_id] # b, 5\n            pred_ans = torch.argmax(pred_logits_qa, dim=-1).cpu().tolist()\n        \n        out['output_text'] = pred_ans\n        if 'qa_vid' not in self.task: \n            out['temp_idx'] = [j for i in range(b) for j in range(t)]\n            # out['answer'] = [a for a in answer for i in range(t)]\n            out['qid'] = [q for q in qid for i in range(t)]\n        else:\n            # out['answer'] = answer\n            out['qid'] = qid\n\n        return out\n\n    def predict_answers(\n        self,\n        samples,\n        num_beams=5,\n        inference_method=\"generate\",\n        max_len=10,\n        min_len=1,\n        num_ans_candidates=128,\n        answer_list=None,\n        prompt=\"\",\n        length_penalty=-1,\n        **kwargs\n    ):\n        image = samples[\"image\"]\n    
    with torch.cuda.amp.autocast(enabled=(self.device != torch.device(\"cpu\"))):\n            image_embeds = self.ln_vision(self.visual_encoder(image))\n        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(\n            image.device\n        )\n\n        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)\n        query_output = self.Qformer.bert(\n            query_embeds=query_tokens,\n            encoder_hidden_states=image_embeds,\n            encoder_attention_mask=image_atts,\n            return_dict=True,\n        )\n\n        inputs_t5 = self.t5_proj(query_output.last_hidden_state)\n        atts_t5 = torch.ones(inputs_t5.size()[:-1], dtype=torch.long).to(image.device)\n\n        if isinstance(samples[\"text_input\"], str):\n            samples[\"text_input\"] = [samples[\"text_input\"]]\n        if prompt:\n            text_input = [prompt.format(question) for question in samples[\"text_input\"]]\n        else:\n            text_input = samples[\"text_input\"]\n\n        input_tokens = self.t5_tokenizer(\n            text_input, padding=\"longest\", return_tensors=\"pt\"\n        ).to(image.device)\n\n        encoder_atts = torch.cat([atts_t5, input_tokens.attention_mask], dim=1)\n\n        device_type = \"cuda\" if \"cuda\" in str(self.device) else \"cpu\"\n        with torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16):\n            inputs_embeds = self.t5_model.encoder.embed_tokens(input_tokens.input_ids)\n            inputs_embeds = torch.cat([inputs_t5, inputs_embeds], dim=1)\n\n            outputs = self.t5_model.generate(\n                inputs_embeds=inputs_embeds,\n                attention_mask=encoder_atts,\n                do_sample=False,\n                num_beams=num_beams,\n                max_new_tokens=max_len,\n                min_length=min_len,\n                length_penalty=length_penalty,\n            )\n            output_text = self.t5_tokenizer.batch_decode(\n                outputs, skip_special_tokens=True\n            )\n\n        if self._apply_lemmatizer:\n            output_text = self._lemmatize(output_text)\n\n        return output_text\n\n    def _lemmatize(self, answers):\n        def apply(answer):\n            doc = self.lemmatizer(answer)\n\n            words = []\n            for token in doc:\n                if token.pos_ in [\"NOUN\", \"VERB\"]:\n                    words.append(token.lemma_)\n                else:\n                    words.append(token.text)\n            answer = \" \".join(words)\n\n            return answer\n\n        return [apply(answer) for answer in answers]\n\n    @property\n    def lemmatizer(self):\n        if self._lemmatizer is None:\n            try:\n                import spacy\n\n                self._lemmatizer = spacy.load(\"en_core_web_sm\")\n            except ImportError:\n                logging.error(\n                    \"\"\"\n                    Please install spacy and en_core_web_sm model to apply lemmatization.\n                    python -m spacy download en_core_web_sm\n                    OR\n                    import spacy.cli\n                    spacy.cli.download(\"en_core_web_sm\")\n                    \"\"\"\n                )\n                exit(1)\n\n        return self._lemmatizer\n\n    @classmethod\n    def from_config(cls, cfg):\n        img_size = cfg.get(\"image_size\")\n        num_query_token = cfg.get(\"num_query_token\")\n        t5_model = cfg.get(\"t5_model\")\n\n        drop_path_rate = 
cfg.get(\"drop_path_rate\", 0)\n        use_grad_checkpoint = cfg.get(\"use_grad_checkpoint\", False)\n        vit_precision = cfg.get(\"vit_precision\", \"fp16\")\n        freeze_vit = cfg.get(\"freeze_vit\", True)\n\n        prompt = cfg.get(\"prompt\", \"\")\n        max_txt_len = cfg.get(\"max_txt_len\", 32)\n        frame_num = cfg.get(\"frame_num\", 8)\n        answer_num = cfg.get(\"answer_num\", 5) \n        apply_lemmatizer = cfg.get(\"apply_lemmatizer\", False)\n        task = cfg.get(\"task\", 'train_loc_freeze_qa')\n\n        model = cls(\n            img_size=img_size,\n            drop_path_rate=drop_path_rate,\n            use_grad_checkpoint=use_grad_checkpoint,\n            vit_precision=vit_precision,\n            freeze_vit=freeze_vit,\n            num_query_token=num_query_token,\n            t5_model=t5_model,\n            prompt=prompt,\n            max_txt_len=max_txt_len,\n            apply_lemmatizer=apply_lemmatizer,\n            frame_num=frame_num,\n            answer_num=answer_num,\n            task=task,\n        )\n        model.load_checkpoint_from_config(cfg)\n        # for sevila with qvh pretraining\n        # need load blip-2 q-former ckpt to q-former_loc\n        if 'loc' in task and 'qvh' not in task:\n           model.load_qformer_loc()\n\n        return model"
  },
  {
    "path": "lavis/models/timesformer/__init__.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\n Based on https://github.com/facebookresearch/TimeSformer\n\"\"\"\n"
  },
  {
    "path": "lavis/models/timesformer/conv2d_same.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\n Based on https://github.com/facebookresearch/TimeSformer\n\"\"\"\n\n# Copyright 2020 Ross Wightman\n# Conv2d w/ Same Padding\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom typing import Tuple, Optional\n\nimport math\nfrom typing import List, Tuple\n\nfrom .vit_utils import is_static_pad, get_padding\n\n# Dynamically pad input x with 'SAME' padding for conv with specified args\ndef pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1), value: float = 0):\n    ih, iw = x.size()[-2:]\n    pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding(\n        iw, k[1], s[1], d[1]\n    )\n    if pad_h > 0 or pad_w > 0:\n        x = F.pad(\n            x,\n            [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2],\n            value=value,\n        )\n    return x\n\n\n# Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution\ndef get_same_padding(x: int, k: int, s: int, d: int):\n    return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0)\n\n\ndef get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]:\n    dynamic = False\n    if isinstance(padding, str):\n        # for any string padding, the padding will be calculated for you, one of three ways\n        padding = padding.lower()\n        if padding == \"same\":\n            # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact\n            if is_static_pad(kernel_size, **kwargs):\n                # static case, no extra overhead\n                padding = get_padding(kernel_size, **kwargs)\n            else:\n                # dynamic 'SAME' padding, has runtime/GPU memory overhead\n                padding = 0\n                dynamic = True\n        elif padding == \"valid\":\n            # 'VALID' padding, same as padding=0\n            padding = 0\n        else:\n            # Default to PyTorch style 'same'-ish symmetric padding\n            padding = get_padding(kernel_size, **kwargs)\n    return padding, dynamic\n\n\ndef conv2d_same(\n    x,\n    weight: torch.Tensor,\n    bias: Optional[torch.Tensor] = None,\n    stride: Tuple[int, int] = (1, 1),\n    padding: Tuple[int, int] = (0, 0),\n    dilation: Tuple[int, int] = (1, 1),\n    groups: int = 1,\n):\n    x = pad_same(x, weight.shape[-2:], stride, dilation)\n    return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups)\n\n\nclass Conv2dSame(nn.Conv2d):\n    \"\"\"Tensorflow like 'SAME' convolution wrapper for 2D convolutions\"\"\"\n\n    def __init__(\n        self,\n        in_channels,\n        out_channels,\n        kernel_size,\n        stride=1,\n        padding=0,\n        dilation=1,\n        groups=1,\n        bias=True,\n    ):\n        super(Conv2dSame, self).__init__(\n            in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias\n        )\n\n    def forward(self, x):\n        return conv2d_same(\n            x,\n            self.weight,\n            self.bias,\n            self.stride,\n            self.padding,\n            self.dilation,\n            self.groups,\n        )\n\n\ndef create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs):\n    padding = kwargs.pop(\"padding\", \"\")\n    kwargs.setdefault(\"bias\", False)\n    padding, is_dynamic = 
get_padding_value(padding, kernel_size, **kwargs)\n    if is_dynamic:\n        return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs)\n    else:\n        return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)\n"
  },
  {
    "path": "lavis/models/timesformer/features.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\n Based on https://github.com/facebookresearch/TimeSformer\n\"\"\"\n\n# Copyright 2020 Ross Wightman\n\nfrom collections import OrderedDict, defaultdict\nfrom copy import deepcopy\nfrom functools import partial\nfrom typing import Dict, List, Tuple\n\nimport torch\nimport torch.nn as nn\n\n\nclass FeatureInfo:\n    def __init__(self, feature_info: List[Dict], out_indices: Tuple[int]):\n        prev_reduction = 1\n        for fi in feature_info:\n            # sanity check the mandatory fields, there may be additional fields depending on the model\n            assert \"num_chs\" in fi and fi[\"num_chs\"] > 0\n            assert \"reduction\" in fi and fi[\"reduction\"] >= prev_reduction\n            prev_reduction = fi[\"reduction\"]\n            assert \"module\" in fi\n        self.out_indices = out_indices\n        self.info = feature_info\n\n    def from_other(self, out_indices: Tuple[int]):\n        return FeatureInfo(deepcopy(self.info), out_indices)\n\n    def get(self, key, idx=None):\n        \"\"\"Get value by key at specified index (indices)\n        if idx == None, returns value for key at each output index\n        if idx is an integer, return value for that feature module index (ignoring output indices)\n        if idx is a list/tupple, return value for each module index (ignoring output indices)\n        \"\"\"\n        if idx is None:\n            return [self.info[i][key] for i in self.out_indices]\n        if isinstance(idx, (tuple, list)):\n            return [self.info[i][key] for i in idx]\n        else:\n            return self.info[idx][key]\n\n    def get_dicts(self, keys=None, idx=None):\n        \"\"\"return info dicts for specified keys (or all if None) at specified indices (or out_indices if None)\"\"\"\n        if idx is None:\n            if keys is None:\n                return [self.info[i] for i in self.out_indices]\n            else:\n                return [{k: self.info[i][k] for k in keys} for i in self.out_indices]\n        if isinstance(idx, (tuple, list)):\n            return [\n                self.info[i] if keys is None else {k: self.info[i][k] for k in keys}\n                for i in idx\n            ]\n        else:\n            return (\n                self.info[idx] if keys is None else {k: self.info[idx][k] for k in keys}\n            )\n\n    def channels(self, idx=None):\n        \"\"\"feature channels accessor\"\"\"\n        return self.get(\"num_chs\", idx)\n\n    def reduction(self, idx=None):\n        \"\"\"feature reduction (output stride) accessor\"\"\"\n        return self.get(\"reduction\", idx)\n\n    def module_name(self, idx=None):\n        \"\"\"feature module name accessor\"\"\"\n        return self.get(\"module\", idx)\n\n    def __getitem__(self, item):\n        return self.info[item]\n\n    def __len__(self):\n        return len(self.info)\n\n\nclass FeatureHooks:\n    \"\"\"Feature Hook Helper\n    This module helps with the setup and extraction of hooks for extracting features from\n    internal nodes in a model by node name. 
This works quite well in eager Python but needs\n    redesign for torchscript.\n    \"\"\"\n\n    def __init__(self, hooks, named_modules, out_map=None, default_hook_type=\"forward\"):\n        # setup feature hooks\n        modules = {k: v for k, v in named_modules}\n        for i, h in enumerate(hooks):\n            hook_name = h[\"module\"]\n            m = modules[hook_name]\n            hook_id = out_map[i] if out_map else hook_name\n            hook_fn = partial(self._collect_output_hook, hook_id)\n            hook_type = h[\"hook_type\"] if \"hook_type\" in h else default_hook_type\n            if hook_type == \"forward_pre\":\n                m.register_forward_pre_hook(hook_fn)\n            elif hook_type == \"forward\":\n                m.register_forward_hook(hook_fn)\n            else:\n                assert False, \"Unsupported hook type\"\n        self._feature_outputs = defaultdict(OrderedDict)\n\n    def _collect_output_hook(self, hook_id, *args):\n        x = args[\n            -1\n        ]  # tensor we want is last argument, output for fwd, input for fwd_pre\n        if isinstance(x, tuple):\n            x = x[0]  # unwrap input tuple\n        self._feature_outputs[x.device][hook_id] = x\n\n    def get_output(self, device) -> Dict[str, torch.tensor]:\n        output = self._feature_outputs[device]\n        self._feature_outputs[device] = OrderedDict()  # clear after reading\n        return output\n\n\ndef _module_list(module, flatten_sequential=False):\n    # a yield/iter would be better for this but wouldn't be compatible with torchscript\n    ml = []\n    for name, module in module.named_children():\n        if flatten_sequential and isinstance(module, nn.Sequential):\n            # first level of Sequential containers is flattened into containing model\n            for child_name, child_module in module.named_children():\n                combined = [name, child_name]\n                ml.append((\"_\".join(combined), \".\".join(combined), child_module))\n        else:\n            ml.append((name, name, module))\n    return ml\n\n\ndef _get_feature_info(net, out_indices):\n    feature_info = getattr(net, \"feature_info\")\n    if isinstance(feature_info, FeatureInfo):\n        return feature_info.from_other(out_indices)\n    elif isinstance(feature_info, (list, tuple)):\n        return FeatureInfo(net.feature_info, out_indices)\n    else:\n        assert False, \"Provided feature_info is not valid\"\n\n\ndef _get_return_layers(feature_info, out_map):\n    module_names = feature_info.module_name()\n    return_layers = {}\n    for i, name in enumerate(module_names):\n        return_layers[name] = (\n            out_map[i] if out_map is not None else feature_info.out_indices[i]\n        )\n    return return_layers\n\n\nclass FeatureDictNet(nn.ModuleDict):\n    \"\"\"Feature extractor with OrderedDict return\n    Wrap a model and extract features as specified by the out indices; the network is\n    partially re-built from contained modules.\n    There is a strong assumption that the modules have been registered into the model in the same\n    order as they are used. 
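In practice (a\n    hypothetical sketch), FeatureDictNet(backbone, out_indices=(1, 2, 3)) simply re-runs the retained child modules in\n    registration order and returns an OrderedDict keyed by str(index), or by the ids given in out_map.\n    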
There should be no reuse of the same nn.Module more than once, including\n    trivial modules like `self.relu = nn.ReLU`.\n    Only submodules that are directly assigned to the model class (`model.feature1`) or at most\n    one Sequential container deep (`model.features.1`, with flatten_sequential=True) can be captured.\n    All Sequential containers that are directly assigned to the original model will have their\n    modules assigned to this module with the name `model.features.1` being changed to `model.features_1`.\n    Arguments:\n        model (nn.Module): model from which we will extract the features\n        out_indices (tuple[int]): model output indices to extract features for\n        out_map (sequence): list or tuple specifying desired return id for each out index,\n            otherwise str(index) is used\n        feature_concat (bool): whether to concatenate intermediate features that are lists or tuples\n            vs select element [0]\n        flatten_sequential (bool): whether to flatten sequential modules assigned to model\n    \"\"\"\n\n    def __init__(\n        self,\n        model,\n        out_indices=(0, 1, 2, 3, 4),\n        out_map=None,\n        feature_concat=False,\n        flatten_sequential=False,\n    ):\n        super(FeatureDictNet, self).__init__()\n        self.feature_info = _get_feature_info(model, out_indices)\n        self.concat = feature_concat\n        self.return_layers = {}\n        return_layers = _get_return_layers(self.feature_info, out_map)\n        modules = _module_list(model, flatten_sequential=flatten_sequential)\n        remaining = set(return_layers.keys())\n        layers = OrderedDict()\n        for new_name, old_name, module in modules:\n            layers[new_name] = module\n            if old_name in remaining:\n                # return id has to be consistently str type for torchscript\n                self.return_layers[new_name] = str(return_layers[old_name])\n                remaining.remove(old_name)\n            if not remaining:\n                break\n        assert not remaining and len(self.return_layers) == len(\n            return_layers\n        ), f\"Return layers ({remaining}) are not present in model\"\n        self.update(layers)\n\n    def _collect(self, x) -> (Dict[str, torch.Tensor]):\n        out = OrderedDict()\n        for name, module in self.items():\n            x = module(x)\n            if name in self.return_layers:\n                out_id = self.return_layers[name]\n                if isinstance(x, (tuple, list)):\n                    # If model tap is a tuple or list, concat or select first element\n                    # FIXME this may need to be more generic / flexible for some nets\n                    out[out_id] = torch.cat(x, 1) if self.concat else x[0]\n                else:\n                    out[out_id] = x\n        return out\n\n    def forward(self, x) -> Dict[str, torch.Tensor]:\n        return self._collect(x)\n\n\nclass FeatureListNet(FeatureDictNet):\n    \"\"\"Feature extractor with list return\n    See docstring for FeatureDictNet above, this class exists only to appease Torchscript typing constraints.\n    In eager Python we could have returned List[Tensor] vs Dict[id, Tensor] based on a member bool.\n    \"\"\"\n\n    def __init__(\n        self,\n        model,\n        out_indices=(0, 1, 2, 3, 4),\n        out_map=None,\n        feature_concat=False,\n        flatten_sequential=False,\n    ):\n        super(FeatureListNet, self).__init__(\n            model,\n            
out_indices=out_indices,\n            out_map=out_map,\n            feature_concat=feature_concat,\n            flatten_sequential=flatten_sequential,\n        )\n\n    def forward(self, x) -> (List[torch.Tensor]):\n        return list(self._collect(x).values())\n\n\nclass FeatureHookNet(nn.ModuleDict):\n    \"\"\"FeatureHookNet\n    Wrap a model and extract features specified by the out indices using forward/forward-pre hooks.\n    If `no_rewrite` is True, features are extracted via hooks without modifying the underlying\n    network in any way.\n    If `no_rewrite` is False, the model will be re-written as in the\n    FeatureList/FeatureDict case by folding first to second (Sequential only) level modules into this one.\n    FIXME this does not currently work with Torchscript, see FeatureHooks class\n    \"\"\"\n\n    def __init__(\n        self,\n        model,\n        out_indices=(0, 1, 2, 3, 4),\n        out_map=None,\n        out_as_dict=False,\n        no_rewrite=False,\n        feature_concat=False,\n        flatten_sequential=False,\n        default_hook_type=\"forward\",\n    ):\n        super(FeatureHookNet, self).__init__()\n        assert not torch.jit.is_scripting()\n        self.feature_info = _get_feature_info(model, out_indices)\n        self.out_as_dict = out_as_dict\n        layers = OrderedDict()\n        hooks = []\n        if no_rewrite:\n            assert not flatten_sequential\n            if hasattr(model, \"reset_classifier\"):  # make sure classifier is removed?\n                model.reset_classifier(0)\n            layers[\"body\"] = model\n            hooks.extend(self.feature_info.get_dicts())\n        else:\n            modules = _module_list(model, flatten_sequential=flatten_sequential)\n            remaining = {\n                f[\"module\"]: f[\"hook_type\"] if \"hook_type\" in f else default_hook_type\n                for f in self.feature_info.get_dicts()\n            }\n            for new_name, old_name, module in modules:\n                layers[new_name] = module\n                for fn, fm in module.named_modules(prefix=old_name):\n                    if fn in remaining:\n                        hooks.append(dict(module=fn, hook_type=remaining[fn]))\n                        del remaining[fn]\n                if not remaining:\n                    break\n            assert (\n                not remaining\n            ), f\"Return layers ({remaining}) are not present in model\"\n        self.update(layers)\n        self.hooks = FeatureHooks(hooks, model.named_modules(), out_map=out_map)\n\n    def forward(self, x):\n        for name, module in self.items():\n            x = module(x)\n        out = self.hooks.get_output(x.device)\n        return out if self.out_as_dict else list(out.values())\n"
  },
  {
    "path": "lavis/models/timesformer/helpers.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\n Based on https://github.com/facebookresearch/TimeSformer\n\"\"\"\n\n# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.\n# Copyright 2020 Ross Wightman\n# Modified model creation / weight loading / state_dict helpers\n\nimport logging, warnings\nimport os\nimport math\nfrom collections import OrderedDict\n\nimport torch\nimport torch.utils.model_zoo as model_zoo\nimport torch.nn.functional as F\n\n\ndef load_state_dict(checkpoint_path, use_ema=False):\n    if checkpoint_path and os.path.isfile(checkpoint_path):\n        checkpoint = torch.load(checkpoint_path, map_location=\"cpu\")\n        state_dict_key = \"state_dict\"\n        if isinstance(checkpoint, dict):\n            if use_ema and \"state_dict_ema\" in checkpoint:\n                state_dict_key = \"state_dict_ema\"\n        if state_dict_key and state_dict_key in checkpoint:\n            new_state_dict = OrderedDict()\n            for k, v in checkpoint[state_dict_key].items():\n                # strip `module.` prefix\n                name = k[7:] if k.startswith(\"module\") else k\n                new_state_dict[name] = v\n            state_dict = new_state_dict\n        elif \"model_state\" in checkpoint:\n            state_dict_key = \"model_state\"\n            new_state_dict = OrderedDict()\n            for k, v in checkpoint[state_dict_key].items():\n                # strip `model.` prefix\n                name = k[6:] if k.startswith(\"model\") else k\n                new_state_dict[name] = v\n            state_dict = new_state_dict\n        else:\n            state_dict = checkpoint\n        logging.info(\n            \"Loaded {} from checkpoint '{}'\".format(state_dict_key, checkpoint_path)\n        )\n        return state_dict\n    else:\n        logging.error(\"No checkpoint found at '{}'\".format(checkpoint_path))\n        raise FileNotFoundError()\n\n\ndef load_checkpoint(model, checkpoint_path, use_ema=False, strict=True):\n    state_dict = load_state_dict(checkpoint_path, use_ema)\n    model.load_state_dict(state_dict, strict=strict)\n\n\n# def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True):\n#     resume_epoch = None\n# if os.path.isfile(checkpoint_path):\n#     checkpoint = torch.load(checkpoint_path, map_location='cpu')\n#     if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:\n#         if log_info:\n#             _logger.info('Restoring model state from checkpoint...')\n#         new_state_dict = OrderedDict()\n#         for k, v in checkpoint['state_dict'].items():\n#             name = k[7:] if k.startswith('module') else k\n#             new_state_dict[name] = v\n#         model.load_state_dict(new_state_dict)\n\n#         if optimizer is not None and 'optimizer' in checkpoint:\n#             if log_info:\n#                 _logger.info('Restoring optimizer state from checkpoint...')\n#             optimizer.load_state_dict(checkpoint['optimizer'])\n\n#         if loss_scaler is not None and loss_scaler.state_dict_key in checkpoint:\n#             if log_info:\n#                 _logger.info('Restoring AMP loss scaler state from checkpoint...')\n#             loss_scaler.load_state_dict(checkpoint[loss_scaler.state_dict_key])\n\n#         if 'epoch' in checkpoint:\n#             
resume_epoch = checkpoint['epoch']\n#             if 'version' in checkpoint and checkpoint['version'] > 1:\n#                 resume_epoch += 1  # start at the next epoch, old checkpoints incremented before save\n\n#         if log_info:\n#             _logger.info(\"Loaded checkpoint '{}' (epoch {})\".format(checkpoint_path, checkpoint['epoch']))\n#     else:\n#         model.load_state_dict(checkpoint)\n#         if log_info:\n#             _logger.info(\"Loaded checkpoint '{}'\".format(checkpoint_path))\n#     return resume_epoch\n# else:\n#     _logger.error(\"No checkpoint found at '{}'\".format(checkpoint_path))\n#     raise FileNotFoundError()\n\n\ndef load_pretrained(\n    model,\n    cfg=None,\n    num_classes=1000,\n    in_chans=3,\n    filter_fn=None,\n    img_size=224,\n    num_frames=8,\n    num_patches=196,\n    attention_type=\"divided_space_time\",\n    pretrained_model=\"\",\n    strict=True,\n):\n    if cfg is None:\n        cfg = getattr(model, \"default_cfg\")\n    if cfg is None or \"url\" not in cfg or not cfg[\"url\"]:\n        logging.warning(\"Pretrained model URL is invalid, using random initialization.\")\n        return\n\n    if len(pretrained_model) == 0:\n        if cfg is None:\n            logging.info(f\"loading from default config {model.default_cfg}.\")\n        state_dict = model_zoo.load_url(cfg[\"url\"], progress=False, map_location=\"cpu\")\n    else:\n        try:\n            state_dict = load_state_dict(pretrained_model)[\"model\"]\n        except:\n            state_dict = load_state_dict(pretrained_model)\n\n    if filter_fn is not None:\n        state_dict = filter_fn(state_dict)\n\n    if in_chans == 1:\n        conv1_name = cfg[\"first_conv\"]\n        logging.info(\n            \"Converting first conv (%s) pretrained weights from 3 to 1 channel\"\n            % conv1_name\n        )\n        conv1_weight = state_dict[conv1_name + \".weight\"]\n        conv1_type = conv1_weight.dtype\n        conv1_weight = conv1_weight.float()\n        O, I, J, K = conv1_weight.shape\n        if I > 3:\n            assert conv1_weight.shape[1] % 3 == 0\n            # For models with space2depth stems\n            conv1_weight = conv1_weight.reshape(O, I // 3, 3, J, K)\n            conv1_weight = conv1_weight.sum(dim=2, keepdim=False)\n        else:\n            conv1_weight = conv1_weight.sum(dim=1, keepdim=True)\n        conv1_weight = conv1_weight.to(conv1_type)\n        state_dict[conv1_name + \".weight\"] = conv1_weight\n    elif in_chans != 3:\n        conv1_name = cfg[\"first_conv\"]\n        conv1_weight = state_dict[conv1_name + \".weight\"]\n        conv1_type = conv1_weight.dtype\n        conv1_weight = conv1_weight.float()\n        O, I, J, K = conv1_weight.shape\n        if I != 3:\n            logging.warning(\n                \"Deleting first conv (%s) from pretrained weights.\" % conv1_name\n            )\n            del state_dict[conv1_name + \".weight\"]\n            strict = False\n        else:\n            logging.info(\n                \"Repeating first conv (%s) weights in channel dim.\" % conv1_name\n            )\n            repeat = int(math.ceil(in_chans / 3))\n            conv1_weight = conv1_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]\n            conv1_weight *= 3 / float(in_chans)\n            conv1_weight = conv1_weight.to(conv1_type)\n            state_dict[conv1_name + \".weight\"] = conv1_weight\n\n    classifier_name = cfg[\"classifier\"]\n    if num_classes == 1000 and cfg[\"num_classes\"] == 1001:\n        # 
special case for imagenet trained models with extra background class in pretrained weights\n        classifier_weight = state_dict[classifier_name + \".weight\"]\n        state_dict[classifier_name + \".weight\"] = classifier_weight[1:]\n        classifier_bias = state_dict[classifier_name + \".bias\"]\n        state_dict[classifier_name + \".bias\"] = classifier_bias[1:]\n    elif num_classes != state_dict[classifier_name + \".weight\"].size(0):\n        # print('Removing the last fully connected layer due to dimensions mismatch ('+str(num_classes)+ ' != '+str(state_dict[classifier_name + '.weight'].size(0))+').', flush=True)\n        # completely discard fully connected for all other differences between pretrained and created model\n        del state_dict[classifier_name + \".weight\"]\n        del state_dict[classifier_name + \".bias\"]\n        strict = False\n\n    ## Resizing the positional embeddings in case they don't match\n    logging.info(\n        f\"Resizing spatial position embedding from {state_dict['pos_embed'].size(1)} to {num_patches + 1}\"\n    )\n    if num_patches + 1 != state_dict[\"pos_embed\"].size(1):\n        pos_embed = state_dict[\"pos_embed\"]\n        cls_pos_embed = pos_embed[0, 0, :].unsqueeze(0).unsqueeze(1)\n        other_pos_embed = pos_embed[0, 1:, :].unsqueeze(0).transpose(1, 2)\n        new_pos_embed = F.interpolate(\n            other_pos_embed, size=(num_patches), mode=\"nearest\"\n        )\n        new_pos_embed = new_pos_embed.transpose(1, 2)\n        new_pos_embed = torch.cat((cls_pos_embed, new_pos_embed), 1)\n        state_dict[\"pos_embed\"] = new_pos_embed\n\n    ## Resizing time embeddings in case they don't match\n    if \"time_embed\" in state_dict and num_frames != state_dict[\"time_embed\"].size(1):\n        logging.info(\n            f\"Resizing temporal position embedding from {state_dict['time_embed'].size(1)} to {num_frames}\"\n        )\n        time_embed = state_dict[\"time_embed\"].transpose(1, 2)\n        new_time_embed = F.interpolate(time_embed, size=(num_frames), mode=\"nearest\")\n        state_dict[\"time_embed\"] = new_time_embed.transpose(1, 2)\n\n    ## Initializing temporal attention\n    if attention_type == \"divided_space_time\":\n        new_state_dict = state_dict.copy()\n        for key in state_dict:\n            if \"blocks\" in key and \"attn\" in key:\n                new_key = key.replace(\"attn\", \"temporal_attn\")\n                if not new_key in state_dict:\n                    new_state_dict[new_key] = state_dict[key]\n                else:\n                    new_state_dict[new_key] = state_dict[new_key]\n            if \"blocks\" in key and \"norm1\" in key:\n                new_key = key.replace(\"norm1\", \"temporal_norm1\")\n                if not new_key in state_dict:\n                    new_state_dict[new_key] = state_dict[key]\n                else:\n                    new_state_dict[new_key] = state_dict[new_key]\n        state_dict = new_state_dict\n\n    ## Loading the weights\n    model.load_state_dict(state_dict, strict=False)\n\n\ndef load_pretrained_imagenet(\n    model,\n    pretrained_model,\n    cfg=None,\n    ignore_classifier=True,\n    num_frames=8,\n    num_patches=196,\n    **kwargs,\n):\n    import timm\n\n    logging.info(f\"Loading vit_base_patch16_224 checkpoints.\")\n    loaded_state_dict = timm.models.vision_transformer.vit_base_patch16_224(\n        pretrained=True\n    ).state_dict()\n\n    del loaded_state_dict[\"head.weight\"]\n    del 
loaded_state_dict[\"head.bias\"]\n\n    ## Initializing temporal attention\n    new_state_dict = loaded_state_dict.copy()\n    for key in loaded_state_dict:\n        if \"blocks\" in key and \"attn\" in key:\n            new_key = key.replace(\"attn\", \"temporal_attn\")\n            if not new_key in loaded_state_dict:\n                new_state_dict[new_key] = loaded_state_dict[key]\n            else:\n                new_state_dict[new_key] = loaded_state_dict[new_key]\n        if \"blocks\" in key and \"norm1\" in key:\n            new_key = key.replace(\"norm1\", \"temporal_norm1\")\n            if not new_key in loaded_state_dict:\n                new_state_dict[new_key] = loaded_state_dict[key]\n            else:\n                new_state_dict[new_key] = loaded_state_dict[new_key]\n\n    loaded_state_dict = new_state_dict\n\n    loaded_keys = loaded_state_dict.keys()\n    model_keys = model.state_dict().keys()\n\n    load_not_in_model = [k for k in loaded_keys if k not in model_keys]\n    model_not_in_load = [k for k in model_keys if k not in loaded_keys]\n\n    toload = dict()\n    mismatched_shape_keys = []\n    for k in model_keys:\n        if k in loaded_keys:\n            if model.state_dict()[k].shape != loaded_state_dict[k].shape:\n                mismatched_shape_keys.append(k)\n            else:\n                toload[k] = loaded_state_dict[k]\n\n    logging.info(\"Keys in loaded but not in model:\")\n    logging.info(f\"In total {len(load_not_in_model)}, {sorted(load_not_in_model)}\")\n    logging.info(\"Keys in model but not in loaded:\")\n    logging.info(f\"In total {len(model_not_in_load)}, {sorted(model_not_in_load)}\")\n    logging.info(\"Keys in model and loaded, but shape mismatched:\")\n    logging.info(\n        f\"In total {len(mismatched_shape_keys)}, {sorted(mismatched_shape_keys)}\"\n    )\n\n    model.load_state_dict(toload, strict=False)\n\n\ndef load_pretrained_kinetics(\n    model,\n    pretrained_model,\n    cfg=None,\n    ignore_classifier=True,\n    num_frames=8,\n    num_patches=196,\n    **kwargs,\n):\n    if cfg is None:\n        cfg = getattr(model, \"default_cfg\")\n    if cfg is None or \"url\" not in cfg or not cfg[\"url\"]:\n        logging.warning(\"Pretrained model URL is invalid, using random initialization.\")\n        return\n\n    assert (\n        len(pretrained_model) > 0\n    ), \"Path to pre-trained Kinetics weights not provided.\"\n\n    state_dict = load_state_dict(pretrained_model)\n\n    classifier_name = cfg[\"classifier\"]\n    if ignore_classifier:\n\n        classifier_weight_key = classifier_name + \".weight\"\n        classifier_bias_key = classifier_name + \".bias\"\n\n        state_dict[classifier_weight_key] = model.state_dict()[classifier_weight_key]\n        state_dict[classifier_bias_key] = model.state_dict()[classifier_bias_key]\n\n    else:\n        raise NotImplementedError(\n            \"[dxli] Not supporting loading Kinetics-pretrained ckpt with classifier.\"\n        )\n\n    ## Resizing the positional embeddings in case they don't match\n    if num_patches + 1 != state_dict[\"pos_embed\"].size(1):\n        new_pos_embed = resize_spatial_embedding(state_dict, \"pos_embed\", num_patches)\n        state_dict[\"pos_embed\"] = new_pos_embed\n\n    ## Resizing time embeddings in case they don't match\n    if \"time_embed\" in state_dict and num_frames != state_dict[\"time_embed\"].size(1):\n        state_dict[\"time_embed\"] = resize_temporal_embedding(\n            state_dict, \"time_embed\", num_frames\n        
)\n\n    ## Loading the weights\n    try:\n        model.load_state_dict(state_dict, strict=True)\n        logging.info(\"Succeeded in loading Kinetics pre-trained weights.\")\n    except:\n        logging.error(\"Error in loading Kinetics pre-trained weights.\")\n\n\ndef resize_spatial_embedding(state_dict, key, num_patches):\n    logging.info(\n        f\"Resizing spatial position embedding from {state_dict[key].size(1)} to {num_patches + 1}\"\n    )\n\n    pos_embed = state_dict[key]\n\n    cls_pos_embed = pos_embed[0, 0, :].unsqueeze(0).unsqueeze(1)\n    other_pos_embed = pos_embed[0, 1:, :].unsqueeze(0).transpose(1, 2)\n\n    new_pos_embed = F.interpolate(other_pos_embed, size=(num_patches), mode=\"nearest\")\n    new_pos_embed = new_pos_embed.transpose(1, 2)\n    new_pos_embed = torch.cat((cls_pos_embed, new_pos_embed), 1)\n\n    return new_pos_embed\n\n\ndef resize_temporal_embedding(state_dict, key, num_frames):\n    logging.info(\n        f\"Resizing temporal position embedding from {state_dict[key].size(1)} to {num_frames}\"\n    )\n\n    time_embed = state_dict[key].transpose(1, 2)\n    new_time_embed = F.interpolate(time_embed, size=(num_frames), mode=\"nearest\")\n\n    return new_time_embed.transpose(1, 2)\n\n\ndef detach_variable(inputs):\n    if isinstance(inputs, tuple):\n        out = []\n        for inp in inputs:\n            x = inp.detach()\n            x.requires_grad = inp.requires_grad\n            out.append(x)\n        return tuple(out)\n    else:\n        raise RuntimeError(\n            \"Only tuple of tensors is supported. Got Unsupported input type: \",\n            type(inputs).__name__,\n        )\n\n\ndef check_backward_validity(inputs):\n    if not any(inp.requires_grad for inp in inputs):\n        warnings.warn(\n            \"None of the inputs have requires_grad=True. Gradients will be None\"\n        )\n"
  },
  {
    "path": "lavis/models/timesformer/linear.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\n\"\"\" Linear layer (alternate definition)\n\"\"\"\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn as nn\n\n\nclass Linear(nn.Linear):\n    def forward(self, input: torch.Tensor) -> torch.Tensor:\n        if torch.jit.is_scripting():\n            bias = self.bias.to(dtype=input.dtype) if self.bias is not None else None\n            return F.linear(input, self.weight.to(dtype=input.dtype), bias=bias)\n        else:\n            return F.linear(input, self.weight, self.bias)\n"
  },
  {
    "path": "lavis/models/timesformer/vit.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\n Based on https://github.com/facebookresearch/TimeSformer\n\"\"\"\n\n# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.\n# Copyright 2020 Ross Wightman\n# Modified Model definition\n\nimport logging\nfrom functools import partial\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.utils\nimport torch.utils.checkpoint\nfrom einops import rearrange\nfrom fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper\n\nfrom .helpers import load_pretrained, load_pretrained_imagenet, load_pretrained_kinetics\nfrom .vit_utils import (\n    IMAGENET_DEFAULT_MEAN,\n    IMAGENET_DEFAULT_STD,\n    DropPath,\n    to_2tuple,\n    trunc_normal_,\n)\n\n\ndef _cfg(url=\"\", **kwargs):\n    return {\n        \"url\": url,\n        \"num_classes\": 1000,\n        \"input_size\": (3, 224, 224),\n        \"pool_size\": None,\n        \"crop_pct\": 0.9,\n        \"interpolation\": \"bicubic\",\n        \"mean\": IMAGENET_DEFAULT_MEAN,\n        \"std\": IMAGENET_DEFAULT_STD,\n        \"first_conv\": \"patch_embed.proj\",\n        \"classifier\": \"head\",\n        **kwargs,\n    }\n\n\ndefault_cfgs = {\n    \"vit_base_patch16_224\": _cfg(\n        url=\"https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth\",\n        mean=(0.5, 0.5, 0.5),\n        std=(0.5, 0.5, 0.5),\n    ),\n}\n\n\nclass Mlp(nn.Module):\n    def __init__(\n        self,\n        in_features,\n        hidden_features=None,\n        out_features=None,\n        act_layer=nn.GELU,\n        drop=0.0,\n    ):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features, out_features)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\n\nclass Attention(nn.Module):\n    def __init__(\n        self,\n        dim,\n        num_heads=8,\n        qkv_bias=False,\n        qk_scale=None,\n        attn_drop=0.0,\n        proj_drop=0.0,\n        with_qkv=True,\n    ):\n        super().__init__()\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        self.scale = qk_scale or head_dim**-0.5\n        self.with_qkv = with_qkv\n        if self.with_qkv:\n            self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n            self.proj = nn.Linear(dim, dim)\n            self.proj_drop = nn.Dropout(proj_drop)\n        self.attn_drop = nn.Dropout(attn_drop)\n\n    def forward(self, x):\n        B, N, C = x.shape\n        if self.with_qkv:\n            qkv = (\n                self.qkv(x)\n                .reshape(B, N, 3, self.num_heads, C // self.num_heads)\n                .permute(2, 0, 3, 1, 4)\n            )\n            q, k, v = qkv[0], qkv[1], qkv[2]\n        else:\n            qkv = x.reshape(B, N, self.num_heads, C // self.num_heads).permute(\n                0, 2, 1, 3\n            )\n            q, k, v = qkv, qkv, qkv\n\n        attn = (q @ k.transpose(-2, -1)) * self.scale\n        attn = 
attn.softmax(dim=-1)\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B, N, C)\n        if self.with_qkv:\n            x = self.proj(x)\n            x = self.proj_drop(x)\n        return x\n\n\nclass Block(nn.Module):\n    def __init__(\n        self,\n        dim,\n        num_heads,\n        layer_num,\n        mlp_ratio=4.0,\n        qkv_bias=False,\n        qk_scale=None,\n        drop=0.0,\n        attn_drop=0.0,\n        drop_path=0.1,\n        act_layer=nn.GELU,\n        norm_layer=nn.LayerNorm,\n        attention_type=\"divided_space_time\",\n        use_grad_checkpointing=False,\n    ):\n        super().__init__()\n        self.attention_type = attention_type\n        assert attention_type in [\n            \"divided_space_time\",\n            \"space_only\",\n            \"joint_space_time\",\n        ]\n\n        self.norm1 = norm_layer(dim)\n        self.attn = Attention(\n            dim,\n            num_heads=num_heads,\n            qkv_bias=qkv_bias,\n            qk_scale=qk_scale,\n            attn_drop=attn_drop,\n            proj_drop=drop,\n        )\n\n        # Temporal Attention Parameters\n        if self.attention_type == \"divided_space_time\":\n            self.temporal_norm1 = norm_layer(dim)\n            self.temporal_attn = Attention(\n                dim,\n                num_heads=num_heads,\n                qkv_bias=qkv_bias,\n                qk_scale=qk_scale,\n                attn_drop=attn_drop,\n                proj_drop=drop,\n            )\n            self.temporal_fc = nn.Linear(dim, dim)\n\n        # drop path\n        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = Mlp(\n            in_features=dim,\n            hidden_features=mlp_hidden_dim,\n            act_layer=act_layer,\n            drop=drop,\n        )\n\n        # [dxli]\n        self.layer_num = layer_num\n        self.use_grad_checkpointing = use_grad_checkpointing\n\n        if use_grad_checkpointing:\n            self.temporal_attn = checkpoint_wrapper(self.temporal_attn)\n            self.attn = checkpoint_wrapper(self.attn)\n            self.mlp = checkpoint_wrapper(self.mlp)\n\n    def forward(self, x, B, T, W):\n        num_spatial_tokens = (x.size(1) - 1) // T\n        H = num_spatial_tokens // W\n\n        if self.attention_type in [\"space_only\", \"joint_space_time\"]:\n            x = x + self.drop_path(self.attn(self.norm1(x)))\n            x = x + self.drop_path(self.mlp(self.norm2(x)))\n            return x\n        elif self.attention_type == \"divided_space_time\":\n            # Temporal\n            xt = x[:, 1:, :]\n            xt = rearrange(xt, \"b (h w t) m -> (b h w) t m\", b=B, h=H, w=W, t=T)\n\n            temporal_attn_out = self.temporal_attn(self.temporal_norm1(xt))\n\n            res_temporal = self.drop_path(temporal_attn_out)\n\n            res_temporal = rearrange(\n                res_temporal, \"(b h w) t m -> b (h w t) m\", b=B, h=H, w=W, t=T\n            )\n            res_temporal = self.temporal_fc(res_temporal)\n            xt = x[:, 1:, :] + res_temporal\n\n            # Spatial\n            init_cls_token = x[:, 0, :].unsqueeze(1)\n            cls_token = init_cls_token.repeat(1, T, 1)\n            cls_token = rearrange(cls_token, \"b t m -> (b t) m\", b=B, t=T).unsqueeze(1)\n            xs = xt\n            xs = rearrange(xs, \"b (h w t) m -> (b t) (h w) m\", b=B, h=H, w=W, 
t=T)\n            xs = torch.cat((cls_token, xs), 1)\n\n            spatial_attn_out = self.attn(self.norm1(xs))\n            res_spatial = self.drop_path(spatial_attn_out)\n\n            # Taking care of CLS token\n            cls_token = res_spatial[:, 0, :]\n            cls_token = rearrange(cls_token, \"(b t) m -> b t m\", b=B, t=T)\n            # averaging for every frame\n            cls_token = torch.mean(cls_token, 1, True)\n            res_spatial = res_spatial[:, 1:, :]\n            res_spatial = rearrange(\n                res_spatial, \"(b t) (h w) m -> b (h w t) m\", b=B, h=H, w=W, t=T\n            )\n            res = res_spatial\n            x = xt\n\n            # Mlp\n            x = torch.cat((init_cls_token, x), 1) + torch.cat((cls_token, res), 1)\n\n            x_res = x\n\n            x = self.norm2(x)\n            # x = x + self.drop_path(self.mlp(self.norm2(x)))\n\n            # MLP\n            mlp_out = self.mlp(x)\n\n            x = x_res + self.drop_path(mlp_out)\n            return x\n\n\nclass PatchEmbed(nn.Module):\n    \"\"\"Image to Patch Embedding\"\"\"\n\n    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):\n        super().__init__()\n        img_size = to_2tuple(img_size)\n        patch_size = to_2tuple(patch_size)\n        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])\n        self.img_size = img_size\n        self.patch_size = patch_size\n        self.num_patches = num_patches\n\n        self.proj = nn.Conv2d(\n            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size\n        )\n\n    def forward(self, x):\n        B, C, T, H, W = x.shape\n        x = rearrange(x, \"b c t h w -> (b t) c h w\")\n        x = self.proj(x)\n        W = x.size(-1)\n        x = x.flatten(2).transpose(1, 2)\n        return x, T, W\n\n\nclass VisionTransformer(nn.Module):\n    \"\"\"Vision Transformer\"\"\"\n\n    def __init__(\n        self,\n        img_size=224,\n        patch_size=16,\n        in_chans=3,\n        num_classes=1000,\n        embed_dim=768,\n        depth=12,\n        num_heads=12,\n        mlp_ratio=4.0,\n        qkv_bias=False,\n        qk_scale=None,\n        drop_rate=0.0,\n        attn_drop_rate=0.0,\n        drop_path_rate=0.1,\n        hybrid_backbone=None,\n        norm_layer=nn.LayerNorm,\n        num_frames=8,\n        attention_type=\"divided_space_time\",\n        dropout=0.0,\n        use_grad_checkpointing=False,\n        ckpt_layer=0,\n    ):\n        super().__init__()\n\n        self.attention_type = attention_type\n        self.depth = depth\n        self.dropout = nn.Dropout(dropout)\n        self.num_classes = num_classes\n        # num_features for consistency with other models\n        self.num_features = self.embed_dim = embed_dim\n        self.patch_embed = PatchEmbed(\n            img_size=img_size,\n            patch_size=patch_size,\n            in_chans=in_chans,\n            embed_dim=embed_dim,\n        )\n        num_patches = self.patch_embed.num_patches\n\n        # Positional Embeddings\n        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))\n        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))\n        self.pos_drop = nn.Dropout(p=drop_rate)\n        if self.attention_type != \"space_only\":\n            self.time_embed = nn.Parameter(torch.zeros(1, num_frames, embed_dim))\n            self.time_drop = nn.Dropout(p=drop_rate)\n\n        # Attention Blocks\n        dpr = [\n            x.item() for x in 
torch.linspace(0, drop_path_rate, self.depth)\n        ]  # stochastic depth decay rule\n        self.blocks = nn.ModuleList(\n            [\n                Block(\n                    layer_num=i,\n                    use_grad_checkpointing=(\n                        use_grad_checkpointing and i >= self.depth - ckpt_layer\n                    ),\n                    dim=embed_dim,\n                    num_heads=num_heads,\n                    mlp_ratio=mlp_ratio,\n                    qkv_bias=qkv_bias,\n                    qk_scale=qk_scale,\n                    drop=drop_rate,\n                    attn_drop=attn_drop_rate,\n                    drop_path=dpr[i],\n                    norm_layer=norm_layer,\n                    attention_type=self.attention_type,\n                )\n                for i in range(self.depth)\n            ]\n        )\n        self.norm = norm_layer(embed_dim)\n\n        # Classifier head\n        self.head = (\n            nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()\n        )\n\n        trunc_normal_(self.pos_embed, std=0.02)\n        trunc_normal_(self.cls_token, std=0.02)\n        self.apply(self._init_weights)\n\n        # initialization of temporal attention weights\n        if self.attention_type == \"divided_space_time\":\n            i = 0\n            for m in self.blocks.modules():\n                m_str = str(m)\n                if \"Block\" in m_str:\n                    if i > 0:\n                        nn.init.constant_(m.temporal_fc.weight, 0)\n                        nn.init.constant_(m.temporal_fc.bias, 0)\n                    i += 1\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=0.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n\n    @torch.jit.ignore\n    def no_weight_decay(self):\n        return {\"pos_embed\", \"cls_token\", \"time_embed\"}\n\n    def get_classifier(self):\n        return self.head\n\n    def reset_classifier(self, num_classes, global_pool=\"\"):\n        self.num_classes = num_classes\n        self.head = (\n            nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()\n        )\n\n    def remove_classifier(self):\n        self.num_classes = 0\n        self.head = None\n\n    def forward_features(self, x):\n        B = x.shape[0]\n        x, T, W = self.patch_embed(x)\n        cls_tokens = self.cls_token.expand(x.size(0), -1, -1)\n        x = torch.cat((cls_tokens, x), dim=1)\n\n        # resizing the positional embeddings in case they don't match the input at inference\n        if x.size(1) != self.pos_embed.size(1):\n            pos_embed = self.pos_embed\n            cls_pos_embed = pos_embed[0, 0, :].unsqueeze(0).unsqueeze(1)\n            other_pos_embed = pos_embed[0, 1:, :].unsqueeze(0).transpose(1, 2)\n            P = int(other_pos_embed.size(2) ** 0.5)\n            H = x.size(1) // W\n            other_pos_embed = other_pos_embed.reshape(1, x.size(2), P, P)\n            new_pos_embed = F.interpolate(other_pos_embed, size=(H, W), mode=\"nearest\")\n            new_pos_embed = new_pos_embed.flatten(2)\n            new_pos_embed = new_pos_embed.transpose(1, 2)\n            new_pos_embed = torch.cat((cls_pos_embed, new_pos_embed), 1)\n            x = x + new_pos_embed\n        else:\n            
x = x + self.pos_embed\n        x = self.pos_drop(x)\n\n        # Time Embeddings\n        if self.attention_type != \"space_only\":\n            cls_tokens = x[:B, 0, :].unsqueeze(1)\n            x = x[:, 1:]\n            x = rearrange(x, \"(b t) n m -> (b n) t m\", b=B, t=T)\n            # Resizing time embeddings in case they don't match\n            if T != self.time_embed.size(1):\n                time_embed = self.time_embed.transpose(1, 2)\n                new_time_embed = F.interpolate(time_embed, size=(T), mode=\"nearest\")\n                new_time_embed = new_time_embed.transpose(1, 2)\n                x = x + new_time_embed\n            else:\n                x = x + self.time_embed\n            x = self.time_drop(x)\n            x = rearrange(x, \"(b n) t m -> b (n t) m\", b=B, t=T)\n            x = torch.cat((cls_tokens, x), dim=1)\n\n        # Attention blocks\n        for blk in self.blocks:\n            x = blk(x, B, T, W)\n\n        # Predictions for space-only baseline\n        if self.attention_type == \"space_only\":\n            x = rearrange(x, \"(b t) n m -> b t n m\", b=B, t=T)\n            x = torch.mean(x, 1)  # averaging predictions for every frame\n\n        x = self.norm(x)\n\n        return x\n\n    def forward(self, x):\n        x = self.forward_features(x)\n        x = self.head(x)\n        return x\n\n\ndef _conv_filter(state_dict, patch_size=16):\n    \"\"\"convert patch embedding weight from manual patchify + linear proj to conv\"\"\"\n    out_dict = {}\n    for k, v in state_dict.items():\n        if \"patch_embed.proj.weight\" in k:\n            if v.shape[-1] != patch_size:\n                patch_size = v.shape[-1]\n            v = v.reshape((v.shape[0], 3, patch_size, patch_size))\n        out_dict[k] = v\n    return out_dict\n\n\nclass vit_base_patch16_224(nn.Module):\n    def __init__(self, cfg, **kwargs):\n        super(vit_base_patch16_224, self).__init__()\n        self.pretrained = True\n        patch_size = 16\n        self.model = VisionTransformer(\n            img_size=cfg.DATA.TRAIN_CROP_SIZE,\n            num_classes=cfg.MODEL.NUM_CLASSES,\n            patch_size=patch_size,\n            embed_dim=768,\n            depth=12,\n            num_heads=12,\n            mlp_ratio=4,\n            qkv_bias=True,\n            norm_layer=partial(nn.LayerNorm, eps=1e-6),\n            drop_rate=0.0,\n            attn_drop_rate=0.0,\n            drop_path_rate=0.1,\n            num_frames=cfg.DATA.NUM_FRAMES,\n            attention_type=cfg.TIMESFORMER.ATTENTION_TYPE,\n            **kwargs,\n        )\n\n        self.attention_type = cfg.TIMESFORMER.ATTENTION_TYPE\n        self.model.default_cfg = default_cfgs[\"vit_base_patch16_224\"]\n        self.num_patches = (cfg.DATA.TRAIN_CROP_SIZE // patch_size) * (\n            cfg.DATA.TRAIN_CROP_SIZE // patch_size\n        )\n        pretrained_model = cfg.TIMESFORMER.PRETRAINED_MODEL\n        if self.pretrained:\n            load_pretrained(\n                self.model,\n                num_classes=self.model.num_classes,\n                in_chans=kwargs.get(\"in_chans\", 3),\n                filter_fn=_conv_filter,\n                img_size=cfg.DATA.TRAIN_CROP_SIZE,\n                num_patches=self.num_patches,\n                attention_type=self.attention_type,\n                pretrained_model=pretrained_model,\n            )\n\n    def forward(self, x):\n        x = self.model(x)\n        return x\n\n\nclass TimeSformer(nn.Module):\n    def __init__(\n        self,\n        image_size=224,\n        
patch_size=16,\n        n_frms=8,\n        attn_drop_rate=0.0,\n        drop_path_rate=0.1,\n        drop_rate=0,\n        use_grad_ckpt=False,\n        ckpt_layer=0,\n        remove_classifier=True,\n        **kwargs,\n    ):\n        super(TimeSformer, self).__init__()\n\n        self.img_size = image_size\n        self.patch_size = patch_size\n        self.num_frames = n_frms\n        self.attn_drop_rate = attn_drop_rate\n        self.drop_path_rate = drop_path_rate\n        self.drop_rate = drop_rate\n        self.use_grad_ckpt = use_grad_ckpt\n        self.ckpt_layer = ckpt_layer\n\n        self.attention_type = \"divided_space_time\"\n\n        logging.info(\n            f\"Initializing TimeSformer with img_size={self.img_size}, patch_size={self.patch_size}, num_frames={self.num_frames}\"\n        )\n\n        # will be ignored when loading official pretrained ckpt\n        self.num_classes = 400\n\n        self.model = VisionTransformer(\n            img_size=self.img_size,\n            num_classes=self.num_classes,\n            patch_size=self.patch_size,\n            embed_dim=768,\n            depth=12,\n            num_heads=12,\n            mlp_ratio=4,\n            qkv_bias=True,\n            norm_layer=partial(nn.LayerNorm, eps=1e-6),\n            drop_rate=self.drop_rate,\n            attn_drop_rate=self.attn_drop_rate,\n            drop_path_rate=self.drop_path_rate,\n            num_frames=self.num_frames,\n            attention_type=self.attention_type,\n            use_grad_checkpointing=self.use_grad_ckpt,\n            ckpt_layer=self.ckpt_layer,\n            **kwargs,\n        )\n\n        if remove_classifier:\n            self.model.remove_classifier()\n\n        self.model.default_cfg = default_cfgs[\n            \"vit_base_patch\" + str(self.patch_size) + \"_224\"\n        ]\n        self.num_patches = (self.img_size // self.patch_size) * (\n            self.img_size // self.patch_size\n        )\n\n    def forward(self, x):\n        x = self.model(x)\n        return x\n\n    def forward_features(self, x):\n        # b, c, t, h, w = x.shape\n        x = self.model.forward_features(x)\n\n        ## apply pooling\n        W = H = self.img_size // self.patch_size\n        T = self.num_frames\n\n        cls_tokens = x[:, 0, :].unsqueeze(1)\n        other_tokens = x[:, 1:, :]\n\n        x = rearrange(other_tokens, \"b (h w t) m -> b t (h w) m\", h=H, w=W, t=T)\n\n        x = torch.mean(x, dim=1)\n        x = torch.cat((cls_tokens, x), dim=1)\n\n        return x\n\n    def load_state_dict(self, pretrained_ckpt_path):\n        logging.info(\n            \"Loading TimeSformer checkpoints from {}\".format(pretrained_ckpt_path)\n        )\n\n        if pretrained_ckpt_path == \"vit_base_patch16_224\":\n            load_ckpt_func = load_pretrained_imagenet\n        else:\n            load_ckpt_func = load_pretrained_kinetics\n\n        load_ckpt_func(\n            self.model,\n            num_classes=self.model.num_classes,\n            in_chans=3,\n            filter_fn=_conv_filter,\n            img_size=self.img_size,\n            num_frames=self.num_frames,\n            num_patches=self.num_patches,\n            attention_type=self.attention_type,\n            pretrained_model=pretrained_ckpt_path,\n        )\n"
  },
  {
    "path": "lavis/models/timesformer/vit_utils.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\n Based on https://github.com/facebookresearch/TimeSformer\n\"\"\"\n\n# Copyright 2020 Ross Wightman\n# Various utility functions\n\nimport torch\nimport torch.nn as nn\nimport math\nimport warnings\nimport torch.nn.functional as F\n\nfrom itertools import repeat\nimport collections.abc as container_abcs\n\nDEFAULT_CROP_PCT = 0.875\nIMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)\nIMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)\nIMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)\nIMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)\nIMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255)\nIMAGENET_DPN_STD = tuple([1 / (0.0167 * 255)] * 3)\n\n\ndef _no_grad_trunc_normal_(tensor, mean, std, a, b):\n    def norm_cdf(x):\n        # Computes standard normal cumulative distribution function\n        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0\n\n    if (mean < a - 2 * std) or (mean > b + 2 * std):\n        warnings.warn(\n            \"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. \"\n            \"The distribution of values may be incorrect.\",\n            stacklevel=2,\n        )\n\n    with torch.no_grad():\n        # Values are generated by using a truncated uniform distribution and\n        # then using the inverse CDF for the normal distribution.\n        # Get upper and lower cdf values\n        l = norm_cdf((a - mean) / std)\n        u = norm_cdf((b - mean) / std)\n\n        # Uniformly fill tensor with values from [l, u], then translate to\n        # [2l-1, 2u-1].\n        tensor.uniform_(2 * l - 1, 2 * u - 1)\n\n        # Use inverse cdf transform for normal distribution to get truncated\n        # standard normal\n        tensor.erfinv_()\n\n        # Transform to proper mean, std\n        tensor.mul_(std * math.sqrt(2.0))\n        tensor.add_(mean)\n\n        # Clamp to ensure it's in the proper range\n        tensor.clamp_(min=a, max=b)\n        return tensor\n\n\ndef trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):\n    r\"\"\"Fills the input Tensor with values drawn from a truncated\n    normal distribution. The values are effectively drawn from the\n    normal distribution :math:`\\mathcal{N}(\\text{mean}, \\text{std}^2)`\n    with values outside :math:`[a, b]` redrawn until they are within\n    the bounds. 
The method used for generating the random values works\n    best when :math:`a \\leq \\text{mean} \\leq b`.\n    Args:\n        tensor: an n-dimensional `torch.Tensor`\n        mean: the mean of the normal distribution\n        std: the standard deviation of the normal distribution\n        a: the minimum cutoff value\n        b: the maximum cutoff value\n    Examples:\n        >>> w = torch.empty(3, 5)\n        >>> nn.init.trunc_normal_(w)\n    \"\"\"\n    return _no_grad_trunc_normal_(tensor, mean, std, a, b)\n\n\n# From PyTorch internals\ndef _ntuple(n):\n    def parse(x):\n        if isinstance(x, container_abcs.Iterable):\n            return x\n        return tuple(repeat(x, n))\n\n    return parse\n\n\nto_2tuple = _ntuple(2)\n\n# Calculate symmetric padding for a convolution\ndef get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int:\n    padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2\n    return padding\n\n\ndef get_padding_value(padding, kernel_size, **kwargs):\n    dynamic = False\n    if isinstance(padding, str):\n        # for any string padding, the padding will be calculated for you, one of three ways\n        padding = padding.lower()\n        if padding == \"same\":\n            # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact\n            if is_static_pad(kernel_size, **kwargs):\n                # static case, no extra overhead\n                padding = get_padding(kernel_size, **kwargs)\n            else:\n                # dynamic 'SAME' padding, has runtime/GPU memory overhead\n                padding = 0\n                dynamic = True\n        elif padding == \"valid\":\n            # 'VALID' padding, same as padding=0\n            padding = 0\n        else:\n            # Default to PyTorch style 'same'-ish symmetric padding\n            padding = get_padding(kernel_size, **kwargs)\n    return padding, dynamic\n\n\n# Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution\ndef get_same_padding(x: int, k: int, s: int, d: int):\n    return max((int(math.ceil(x // s)) - 1) * s + (k - 1) * d + 1 - x, 0)\n\n\n# Can SAME padding for given args be done statically?\ndef is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_):\n    return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0\n\n\n# Dynamically pad input x with 'SAME' padding for conv with specified args\n# def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1), value: float = 0):\ndef pad_same(x, k, s, d=(1, 1), value=0):\n    ih, iw = x.size()[-2:]\n    pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding(\n        iw, k[1], s[1], d[1]\n    )\n    if pad_h > 0 or pad_w > 0:\n        x = F.pad(\n            x,\n            [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2],\n            value=value,\n        )\n    return x\n\n\ndef adaptive_pool_feat_mult(pool_type=\"avg\"):\n    if pool_type == \"catavgmax\":\n        return 2\n    else:\n        return 1\n\n\ndef drop_path(x, drop_prob: float = 0.0, training: bool = False):\n    \"\"\"Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).\n    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,\n    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...\n    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... 
I've opted for\n    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use\n    'survival rate' as the argument.\n    \"\"\"\n    if drop_prob == 0.0 or not training:\n        return x\n    keep_prob = 1 - drop_prob\n    shape = (x.shape[0],) + (1,) * (\n        x.ndim - 1\n    )  # work with diff dim tensors, not just 2D ConvNets\n    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)\n    random_tensor.floor_()  # binarize\n    output = x.div(keep_prob) * random_tensor\n    return output\n\n\nclass DropPath(nn.Module):\n    \"\"\"Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).\"\"\"\n\n    def __init__(self, drop_prob=None):\n        super(DropPath, self).__init__()\n        self.drop_prob = drop_prob\n\n    def forward(self, x):\n        return drop_path(x, self.drop_prob, self.training)\n"
  },
  {
    "path": "lavis/models/topk.py",
    "content": "# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved\n\"\"\"\nDETR model and criterion classes.\n\"\"\"\n\nimport math\nimport torch\nimport copy\nimport einops\nimport torch.nn.functional as F\nfrom torch import nn\n\nfrom dataclasses import dataclass\nfrom typing import Optional\nfrom enum import IntEnum\nfrom einops import rearrange\n\nclass PerturbedTopK(nn.Module):\n    def __init__(self, k: int, num_samples: int = 1000):\n        super(PerturbedTopK, self).__init__()\n        self.num_samples = num_samples\n        self.k = k\n\n    def __call__(self, x, sigma):\n        return PerturbedTopKFunction.apply(x, self.k, self.num_samples, sigma)\n\n\nclass PerturbedTopKFunction(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, x, k: int, num_samples: int = 1000, sigma: float = 0.05):\n        #print('x', x.shape)\n        b, d = x.shape\n        # for Gaussian: noise and gradient are the same.\n        noise = torch.normal(mean=0.0, std=1.0, size=(b, num_samples, d)).to(x.device)\n        perturbed_x = x[:, None, :] + noise * sigma # b, nS, d\n        #print('perturbed_x', perturbed_x.shape)\n        topk_results = torch.topk(perturbed_x, k=k, dim=-1, sorted=False)\n        #print('topk_results',topk_results)\n\n        indices = topk_results.indices # b, nS, k\n        indices = torch.sort(indices, dim=-1).values # b, nS, k\n        # print('indices', indices.shape ,indices[0,0,0])\n\n        perturbed_output = torch.nn.functional.one_hot(indices, num_classes=d).float()\n        indicators = perturbed_output.mean(dim=1) # b, k, d\n        # print('perturbed_output', perturbed_output.shape, perturbed_output[0,indices[0,0,0],0,0])\n\n        # constants for backward\n        ctx.k = k\n        ctx.num_samples = num_samples\n        ctx.sigma = sigma\n\n        # tensors for backward\n        ctx.perturbed_output = perturbed_output\n        ctx.noise = noise\n        return indicators\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        if grad_output is None:\n            return tuple([None] * 5)\n\n        noise_gradient = ctx.noise\n        if ctx.sigma <= 1e-20:\n            b, _, k, d = ctx.perturbed_output.size()\n            expected_gradient = torch.zeros(b, k, d).to(grad_output.device)\n        else:\n            expected_gradient = (\n                torch.einsum(\"bnkd,bnd->bkd\", ctx.perturbed_output, noise_gradient)\n                / ctx.num_samples\n                / (ctx.sigma)\n            )\n\n        grad_input = torch.einsum(\"bkd,bkd->bd\", grad_output, expected_gradient)\n\n        return (grad_input,) + tuple([None] * 5)\n\ndef HardTopK(k, x):\n    topk_results = torch.topk(x, k=k, dim=-1, sorted=False)\n    indices = topk_results.indices # b, k\n    indices = torch.sort(indices, dim=-1).values\n    return indices\n\n\ndef batched_index_select(input, dim, index):\n    for i in range(1, len(input.shape)):\n        if i != dim:\n            index = index.unsqueeze(i)\n    expanse = list(input.shape)\n    expanse[0] = -1\n    expanse[dim] = -1\n    index = index.expand(expanse)\n    return torch.gather(input, dim, index)\n\ndef extract_frames_from_indices(x, indices):\n    batch_size, _, n, channels = x.shape\n    k = indices.shape[-1]\n    all_frame = x\n    frames = batched_index_select(all_frame, 1, indices)\n    frames = frames.contiguous().view(batch_size, k, n, channels)\n    return frames\n\n\ndef extract_frames_from_indicators(x, indicators):\n    indicators = rearrange(indicators, \"b d k -> b k 
d\")\n    frames = torch.einsum(\"b k d, b d n c-> b k n c\",\n                         indicators, x)\n    return frames\n\n\nclass ModalityEmbeddingsID(IntEnum):\n    TEXT_QUESTION = 0\n    TEXT_EMBEDDING = 1\n    TEXT_UNUSED = 2  # ignore\n    VISUAL_EMBEDDING = 3\n    VISUAL_UNUSED = 4  # ignore\n\nclass ModalityEmbeddings(nn.Module):\n    \"\"\"\n    Provides embeddings that indicate type of modality; for use with multimodal inputs for ATP. See atp.py for usage.\n    \"\"\"\n    def __init__(self,\n                 d_model: int,\n                 use_text_query: bool = False,\n                 use_text_cands: bool = False,\n                 n_cands: int = 5):\n        \"\"\"\n        Details for each of these arguments are provided in ATPConfig.\n        \"\"\"\n        super().__init__()\n        self.d_model = d_model\n        self.embedding = nn.Embedding(num_embeddings=len(ModalityEmbeddingsID),\n                                      embedding_dim=d_model)\n\n        self.use_text_query = use_text_query\n        self.use_text_cands = use_text_cands\n        self.n_cands = n_cands if use_text_cands else 0\n        self.n_text_feats = 1 if use_text_query else 0\n        if use_text_cands:\n            self.n_text_feats += n_cands\n\n    def forward(self, x, num_frame):\n        \"\"\"\n        x: torch.tensor of size (L, N, D)\n        returns modality embeddings for x of size (L, *, D)\n        \"\"\"\n        L, N, D = x.size()  # (sequence_length, batch_size, feature_dim)\n        num_txt = L - num_frame\n        \n        # assemble the IDs for the modality encodings, language inputs then vision inputs\n        class_ids = []\n        if self.use_text_query:\n            class_ids.extend([ModalityEmbeddingsID.TEXT_QUESTION,] * num_txt)\n        # if self.use_text_cands:\n        #     class_ids.extend([ModalityEmbeddingsID.TEXT_EMBEDDING,] * self.n_cands)\n        class_ids.extend([ModalityEmbeddingsID.VISUAL_EMBEDDING,] * num_frame)\n        \n        class_ids = torch.tensor(\n            class_ids,\n            dtype=torch.long,\n            device=x.device\n        ).unsqueeze(-1)\n        \n        # return modality embeddings\n        return self.embedding(class_ids)\n\n@dataclass\nclass ATPConfig:\n    '''\n    ATPConfig contains the parameters needed for the ATPSelectorModel (and its ATPEncoder).\n    '''\n    # ATPEncoder params\n    n_layers: int = 6\n    n_heads: int = 4\n    d_model: int = 256\n    d_input_t: int = 2048\n    d_input_v: int = 1408\n    d_model_ff: int = 256\n    enc_dropout: float = 0.1\n    use_text_query: bool = True  # at least one use_text_* needs to be true for ATP to be multimodal\n    use_text_cands: bool = False  # ^ see above. (note: if both are false, ATP is vision-only)\n    n_cands: int = 5  # only relevant when use_text_cands is set to true\n    # ATPSelector params\n    use_ste: bool = True  # controls type of selector during ATP training; see ATPSelectorModel.forward\n    sel_dropout: float = 0.0\n    d_input: int = 512  # size of the input vision-language embeddings (e.g. 
CLIP-ViT-B32 is size 512)\n    \n    @classmethod\n    def default_args(cls):\n        return cls(n_layers = 6,\n                   n_heads = 4,\n                   d_model = 256,\n                   d_input_t = 2048,\n                   d_input_v = 1408,\n                   d_model_ff = 256,\n                   enc_dropout = 0.1,\n                   use_text_query = True,\n                   use_text_cands = False,\n                   n_cands = 5,\n                   use_ste = True,\n                   sel_dropout = 0.0,\n                   d_input = 512)\n\n    @classmethod\n    def from_args(cls, args):\n        return cls(n_layers = args.n_layers,\n                   n_heads = args.n_heads,\n                   d_model = args.d_model,\n                   d_model_ff = args.d_model_ff,\n                   enc_dropout = args.enc_dropout,\n                   use_text_query = args.use_text_query,\n                   use_text_cands = args.use_text_cands,\n                   n_cands = args.n_cands,\n                   use_ste = args.use_ste,\n                   sel_dropout = args.sel_dropout,\n                   d_input = args.d_input)\n\nclass ATPEncoder(nn.Module):\n    \"\"\"\n    The multimodal transformer encoder for the ATP model. For analysis purposes, the ATP encoder\n    does not use any positional information (no positional encodings + transformer / self-attention)\n    and is generally kept low-capacity. If the goal is raw accuracy (not analysis), you can relax these constraints.\n    \"\"\"\n    def __init__(self, config: ATPConfig):\n        \"\"\"\n        config: ATPConfig with parameters for the (transformer-based, atemporal) encoder for ATP.\n        See ATPConfig documentation for details.\n        \"\"\"\n        super().__init__()\n        self.d_model = config.d_model\n\n        self.dropout = nn.Dropout(p=config.enc_dropout)\n\n\n        self.modality_encoding = ModalityEmbeddings(d_model=self.d_model,\n                                                    use_text_query=config.use_text_query,\n                                                    use_text_cands=config.use_text_cands,\n                                                    n_cands=config.n_cands)\n        \n        atp_encoder_layer = nn.TransformerEncoderLayer(\n            d_model=self.d_model,\n            nhead=config.n_heads,\n            dim_feedforward=config.d_model_ff,\n            dropout=config.enc_dropout,\n            activation='relu'\n        )\n\n        self.transformer_encoder = nn.TransformerEncoder(atp_encoder_layer, config.n_layers)\n\n    def forward(self, x_inputs: torch.tensor, vis_L):\n        \"\"\"\n        x_inputs: torch.tensor of shape (L, N, D)\n        \"\"\"\n        L, N, D = x_inputs.size()  # (sequence_length, batch_size, d_model)\n        assert D == self.d_model, \"inputs dimension mismatch\"\n        x_encoded = x_inputs * math.sqrt(self.d_model)\n        x_encoded += self.modality_encoding(x_encoded, vis_L)\n        x_encoded = self.dropout(x_encoded)\n        x_encoded = self.transformer_encoder(x_encoded)\n\n        return x_encoded\n\nclass TopK_Selector(nn.Module):\n    \"\"\"\n    The Atemporal Probe (ATP) selector model. 
Takes as input a sequence of image-language \n    encoding and outputs a (discrete) selection over the input frames, to help analyze \n    downstream discriminative video-language tasks.\n    \"\"\"\n    \n    def __init__(self, config=ATPConfig, num_select=4):\n        \"\"\"\n        config: ATPConfig with parameters for initializing the ATPSelectorModel (and its encoder).\n        See ATPConfig documentation for details.\n        \"\"\"\n        super().__init__()\n        self.config = config\n        self.t_embedding = nn.Linear(config.d_input_t, config.d_input)\n        self.v_embedding = nn.Linear(config.d_input_v, config.d_input)\n        self.embedding = nn.Linear(config.d_input, config.d_model)\n        self.atp_encoder = ATPEncoder(config)\n        self.dropout = nn.Dropout(p=config.sel_dropout)\n        self.logits = nn.Linear(config.d_model, 1)\n        self.num_select = num_select\n        self.sigma = 0.1\n\n    def forward(self,\n                x_vis, # [b, t, d]\n                x_txt, # [b, n, d]\n                **kwargs):\n        \"\"\"\n        \"\"\"\n        x_vis_cls = x_vis[:, :, 0, :] # b t n c\n        N, vis_L, D = x_vis_cls.size()  # (batch_size, sequence_length, feature_dimension)\n        # embed the input sequence to the (smaller) model dimension (d_model) with modality encodings.\n        x_vis_cls = self.v_embedding(self.dropout(x_vis_cls))\n        x_txt = self.t_embedding(self.dropout(x_txt))\n        x_inputs = []\n        x_vis_cls = x_vis_cls.permute(1, 0, 2)\n        x_inputs.append(x_txt.permute(1,0,2)) # (n, b, d)\n        x_inputs.append(x_vis_cls)\n        x_inputs = torch.cat(x_inputs, dim=0)\n        x_encoded = self.embedding(self.dropout(x_inputs))\n        x_atp_encoded = self.atp_encoder(x_encoded, vis_L)\n        x_atp_encoded = x_atp_encoded.permute(1, 0, 2)\n        x_encoded_v = x_atp_encoded[:, -vis_L: , :]\n        # obtain selection scores (logits)\n        x_logits = self.logits(self.dropout(x_encoded_v)).squeeze()\n        #print('x_logits', x_logits.shape)\n\n        if self.training:\n            indices = PerturbedTopKFunction.apply(x_logits, self.num_select)\n            #print('indices', indices.shape)\n            indices = einops.rearrange(indices, \"b k d -> b d k\")\n\n            if indices is not None:\n                qa_frames = extract_frames_from_indicators(x_vis, indices)\n            else:\n                raise RuntimeError(\"Empty indices!\")\n        else:\n            indices = HardTopK(self.num_select, x_logits)\n            if indices is not None:\n                qa_frames = extract_frames_from_indices(x_vis, indices)\n            else:\n                raise RuntimeError(\"Empty indices!\")\n\n\n        return qa_frames\n\nif __name__ == \"__main__\":\n    selector_config = ATPConfig.default_args\n\n    Selector = TopK_Selector(num_select=4) #.eval()\n\n    x_vis = torch.rand([2, 8, 257, 1408])\n    x_txt = torch.rand([2, 68, 2048])\n\n    out = Selector(x_vis, x_txt)\n    print(out.shape)\n\n\n"
  },
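  {
    "path": "examples/topk_selector_demo.py",
    "content": "\"\"\"\nIllustrative usage sketch (hypothetical file, not part of the original SeViLA/LAVIS release).\nIt shows how the perturbed top-k relaxation in lavis/models/topk.py lets gradients flow\nthrough a discrete frame-selection step during training, while HardTopK returns hard\nindices at inference time. The toy tensor shapes below are assumptions made for this\nexample only.\n\"\"\"\nimport torch\n\nfrom lavis.models.topk import HardTopK, PerturbedTopKFunction\n\n# Toy relevance scores for a batch of 2 videos with 8 candidate frames each.\nscores = torch.randn(2, 8, requires_grad=True)\n\n# Training path: a soft indicator matrix of shape (batch, k, num_frames),\n# averaged over Gaussian perturbations of the scores.\nindicators = PerturbedTopKFunction.apply(scores, 4, 500, 0.05)\nindicators.sum().backward()\nprint(indicators.shape)   # torch.Size([2, 4, 8])\nprint(scores.grad.shape)  # gradients reach the scores despite the top-k\n\n# Inference path: hard, sorted top-k indices of shape (batch, k).\nprint(HardTopK(4, scores.detach()))\n"
  },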
  {
    "path": "lavis/models/vit.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n \n Based on timm code base\n https://github.com/rwightman/pytorch-image-models/tree/master/timm\n\"\"\"\n\nimport math\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom functools import partial\n\nfrom timm.models.vision_transformer import _cfg, PatchEmbed\nfrom timm.models.registry import register_model\nfrom timm.models.layers import trunc_normal_, DropPath\nfrom timm.models.helpers import named_apply, adapt_input_conv\n\nfrom fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper\nfrom lavis.models.base_model import BaseEncoder\n\n\nclass Mlp(nn.Module):\n    \"\"\"MLP as used in Vision Transformer, MLP-Mixer and related networks\"\"\"\n\n    def __init__(\n        self,\n        in_features,\n        hidden_features=None,\n        out_features=None,\n        act_layer=nn.GELU,\n        drop=0.0,\n    ):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features, out_features)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\n\nclass Attention(nn.Module):\n    def __init__(\n        self,\n        dim,\n        num_heads=8,\n        qkv_bias=False,\n        qk_scale=None,\n        attn_drop=0.0,\n        proj_drop=0.0,\n    ):\n        super().__init__()\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights\n        self.scale = qk_scale or head_dim**-0.5\n        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n        self.attn_gradients = None\n        self.attention_map = None\n\n    def save_attn_gradients(self, attn_gradients):\n        self.attn_gradients = attn_gradients\n\n    def get_attn_gradients(self):\n        return self.attn_gradients\n\n    def save_attention_map(self, attention_map):\n        self.attention_map = attention_map\n\n    def get_attention_map(self):\n        return self.attention_map\n\n    def forward(self, x, register_hook=False):\n        B, N, C = x.shape\n        qkv = (\n            self.qkv(x)\n            .reshape(B, N, 3, self.num_heads, C // self.num_heads)\n            .permute(2, 0, 3, 1, 4)\n        )\n        q, k, v = (\n            qkv[0],\n            qkv[1],\n            qkv[2],\n        )  # make torchscript happy (cannot use tensor as tuple)\n\n        attn = (q @ k.transpose(-2, -1)) * self.scale\n        attn = attn.softmax(dim=-1)\n        attn = self.attn_drop(attn)\n\n        if register_hook:\n            self.save_attention_map(attn)\n            attn.register_hook(self.save_attn_gradients)\n\n        x = (attn @ v).transpose(1, 2).reshape(B, N, C)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\n\nclass Block(nn.Module):\n    def __init__(\n        self,\n        dim,\n        
num_heads,\n        mlp_ratio=4.0,\n        qkv_bias=False,\n        qk_scale=None,\n        drop=0.0,\n        attn_drop=0.0,\n        drop_path=0.0,\n        act_layer=nn.GELU,\n        norm_layer=nn.LayerNorm,\n        use_grad_checkpointing=False,\n    ):\n        super().__init__()\n        self.norm1 = norm_layer(dim)\n        self.attn = Attention(\n            dim,\n            num_heads=num_heads,\n            qkv_bias=qkv_bias,\n            qk_scale=qk_scale,\n            attn_drop=attn_drop,\n            proj_drop=drop,\n        )\n        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here\n        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = Mlp(\n            in_features=dim,\n            hidden_features=mlp_hidden_dim,\n            act_layer=act_layer,\n            drop=drop,\n        )\n\n        if use_grad_checkpointing:\n            self.attn = checkpoint_wrapper(self.attn)\n            self.mlp = checkpoint_wrapper(self.mlp)\n\n    def forward(self, x, register_hook=False):\n        x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))\n        x = x + self.drop_path(self.mlp(self.norm2(x)))\n        return x\n\n\nclass VisionTransformer(nn.Module):\n    \"\"\"Vision Transformer\n    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`  -\n        https://arxiv.org/abs/2010.11929\n    \"\"\"\n\n    def __init__(\n        self,\n        img_size=224,\n        patch_size=16,\n        in_chans=3,\n        num_classes=1000,\n        embed_dim=768,\n        depth=12,\n        num_heads=12,\n        mlp_ratio=4.0,\n        qkv_bias=True,\n        qk_scale=None,\n        representation_size=None,\n        drop_rate=0.0,\n        attn_drop_rate=0.0,\n        drop_path_rate=0.0,\n        norm_layer=None,\n        use_grad_checkpointing=False,\n        ckpt_layer=0,\n    ):\n        \"\"\"\n        Args:\n            img_size (int, tuple): input image size\n            patch_size (int, tuple): patch size\n            in_chans (int): number of input channels\n            num_classes (int): number of classes for classification head\n            embed_dim (int): embedding dimension\n            depth (int): depth of transformer\n            num_heads (int): number of attention heads\n            mlp_ratio (int): ratio of mlp hidden dim to embedding dim\n            qkv_bias (bool): enable bias for qkv if True\n            qk_scale (float): override default qk scale of head_dim ** -0.5 if set\n            representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set\n            drop_rate (float): dropout rate\n            attn_drop_rate (float): attention dropout rate\n            drop_path_rate (float): stochastic depth rate\n            norm_layer: (nn.Module): normalization layer\n        \"\"\"\n        super().__init__()\n        self.num_features = (\n            self.embed_dim\n        ) = embed_dim  # num_features for consistency with other models\n        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)\n\n        self.patch_embed = PatchEmbed(\n            img_size=img_size,\n            patch_size=patch_size,\n            in_chans=in_chans,\n            embed_dim=embed_dim,\n        )\n\n        num_patches = self.patch_embed.num_patches\n\n        self.cls_token = 
nn.Parameter(torch.zeros(1, 1, embed_dim))\n        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))\n        self.pos_drop = nn.Dropout(p=drop_rate)\n\n        dpr = [\n            x.item() for x in torch.linspace(0, drop_path_rate, depth)\n        ]  # stochastic depth decay rule\n        self.blocks = nn.ModuleList(\n            [\n                Block(\n                    dim=embed_dim,\n                    num_heads=num_heads,\n                    mlp_ratio=mlp_ratio,\n                    qkv_bias=qkv_bias,\n                    qk_scale=qk_scale,\n                    drop=drop_rate,\n                    attn_drop=attn_drop_rate,\n                    drop_path=dpr[i],\n                    norm_layer=norm_layer,\n                    use_grad_checkpointing=(\n                        use_grad_checkpointing and i >= depth - ckpt_layer\n                    ),\n                )\n                for i in range(depth)\n            ]\n        )\n        self.norm = norm_layer(embed_dim)\n\n        trunc_normal_(self.pos_embed, std=0.02)\n        trunc_normal_(self.cls_token, std=0.02)\n        self.apply(self._init_weights)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=0.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n\n    @torch.jit.ignore\n    def no_weight_decay(self):\n        return {\"pos_embed\", \"cls_token\"}\n\n    def forward(self, x, register_blk=-1):\n        B = x.shape[0]\n        x = self.patch_embed(x)\n\n        cls_tokens = self.cls_token.expand(\n            B, -1, -1\n        )  # stole cls_tokens impl from Phil Wang, thanks\n        x = torch.cat((cls_tokens, x), dim=1)\n\n        x = x + self.pos_embed[:, : x.size(1), :]\n        x = self.pos_drop(x)\n\n        for i, blk in enumerate(self.blocks):\n            x = blk(x, register_blk == i)\n        x = self.norm(x)\n\n        return x\n\n    @torch.jit.ignore()\n    def load_pretrained(self, checkpoint_path, prefix=\"\"):\n        _load_weights(self, checkpoint_path, prefix)\n\n\n@torch.no_grad()\ndef _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = \"\"):\n    \"\"\"Load weights from .npz checkpoints for official Google Brain Flax implementation\"\"\"\n    import numpy as np\n\n    def _n2p(w, t=True):\n        if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:\n            w = w.flatten()\n        if t:\n            if w.ndim == 4:\n                w = w.transpose([3, 2, 0, 1])\n            elif w.ndim == 3:\n                w = w.transpose([2, 0, 1])\n            elif w.ndim == 2:\n                w = w.transpose([1, 0])\n        return torch.from_numpy(w)\n\n    w = np.load(checkpoint_path)\n    if not prefix and \"opt/target/embedding/kernel\" in w:\n        prefix = \"opt/target/\"\n\n    if hasattr(model.patch_embed, \"backbone\"):\n        # hybrid\n        backbone = model.patch_embed.backbone\n        stem_only = not hasattr(backbone, \"stem\")\n        stem = backbone if stem_only else backbone.stem\n        stem.conv.weight.copy_(\n            adapt_input_conv(\n                stem.conv.weight.shape[1], _n2p(w[f\"{prefix}conv_root/kernel\"])\n            )\n        )\n        stem.norm.weight.copy_(_n2p(w[f\"{prefix}gn_root/scale\"]))\n        
stem.norm.bias.copy_(_n2p(w[f\"{prefix}gn_root/bias\"]))\n        if not stem_only:\n            for i, stage in enumerate(backbone.stages):\n                for j, block in enumerate(stage.blocks):\n                    bp = f\"{prefix}block{i + 1}/unit{j + 1}/\"\n                    for r in range(3):\n                        getattr(block, f\"conv{r + 1}\").weight.copy_(\n                            _n2p(w[f\"{bp}conv{r + 1}/kernel\"])\n                        )\n                        getattr(block, f\"norm{r + 1}\").weight.copy_(\n                            _n2p(w[f\"{bp}gn{r + 1}/scale\"])\n                        )\n                        getattr(block, f\"norm{r + 1}\").bias.copy_(\n                            _n2p(w[f\"{bp}gn{r + 1}/bias\"])\n                        )\n                    if block.downsample is not None:\n                        block.downsample.conv.weight.copy_(\n                            _n2p(w[f\"{bp}conv_proj/kernel\"])\n                        )\n                        block.downsample.norm.weight.copy_(\n                            _n2p(w[f\"{bp}gn_proj/scale\"])\n                        )\n                        block.downsample.norm.bias.copy_(_n2p(w[f\"{bp}gn_proj/bias\"]))\n        embed_conv_w = _n2p(w[f\"{prefix}embedding/kernel\"])\n    else:\n        embed_conv_w = adapt_input_conv(\n            model.patch_embed.proj.weight.shape[1], _n2p(w[f\"{prefix}embedding/kernel\"])\n        )\n    model.patch_embed.proj.weight.copy_(embed_conv_w)\n    model.patch_embed.proj.bias.copy_(_n2p(w[f\"{prefix}embedding/bias\"]))\n    model.cls_token.copy_(_n2p(w[f\"{prefix}cls\"], t=False))\n    pos_embed_w = _n2p(w[f\"{prefix}Transformer/posembed_input/pos_embedding\"], t=False)\n    if pos_embed_w.shape != model.pos_embed.shape:\n        pos_embed_w = resize_pos_embed(  # resize pos embedding when different size from pretrained weights\n            pos_embed_w,\n            model.pos_embed,\n            getattr(model, \"num_tokens\", 1),\n            model.patch_embed.grid_size,\n        )\n    model.pos_embed.copy_(pos_embed_w)\n    model.norm.weight.copy_(_n2p(w[f\"{prefix}Transformer/encoder_norm/scale\"]))\n    model.norm.bias.copy_(_n2p(w[f\"{prefix}Transformer/encoder_norm/bias\"]))\n    #     if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:\n    #         model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))\n    #         model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))\n    #     if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:\n    #         model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))\n    #         model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))\n    for i, block in enumerate(model.blocks.children()):\n        block_prefix = f\"{prefix}Transformer/encoderblock_{i}/\"\n        mha_prefix = block_prefix + \"MultiHeadDotProductAttention_1/\"\n        block.norm1.weight.copy_(_n2p(w[f\"{block_prefix}LayerNorm_0/scale\"]))\n        block.norm1.bias.copy_(_n2p(w[f\"{block_prefix}LayerNorm_0/bias\"]))\n        block.attn.qkv.weight.copy_(\n            torch.cat(\n                [\n                    _n2p(w[f\"{mha_prefix}{n}/kernel\"], t=False).flatten(1).T\n                    for n in (\"query\", \"key\", \"value\")\n                ]\n            )\n        )\n        block.attn.qkv.bias.copy_(\n            torch.cat(\n                [\n                    
_n2p(w[f\"{mha_prefix}{n}/bias\"], t=False).reshape(-1)\n                    for n in (\"query\", \"key\", \"value\")\n                ]\n            )\n        )\n        block.attn.proj.weight.copy_(_n2p(w[f\"{mha_prefix}out/kernel\"]).flatten(1))\n        block.attn.proj.bias.copy_(_n2p(w[f\"{mha_prefix}out/bias\"]))\n        for r in range(2):\n            getattr(block.mlp, f\"fc{r + 1}\").weight.copy_(\n                _n2p(w[f\"{block_prefix}MlpBlock_3/Dense_{r}/kernel\"])\n            )\n            getattr(block.mlp, f\"fc{r + 1}\").bias.copy_(\n                _n2p(w[f\"{block_prefix}MlpBlock_3/Dense_{r}/bias\"])\n            )\n        block.norm2.weight.copy_(_n2p(w[f\"{block_prefix}LayerNorm_2/scale\"]))\n        block.norm2.bias.copy_(_n2p(w[f\"{block_prefix}LayerNorm_2/bias\"]))\n\n\ndef resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()):\n    # Rescale the grid of position embeddings when loading from state_dict. Adapted from\n    # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224\n    print(\"Resized position embedding: %s to %s\", posemb.shape, posemb_new.shape)\n    ntok_new = posemb_new.shape[1]\n    if num_tokens:\n        posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:]\n        ntok_new -= num_tokens\n    else:\n        posemb_tok, posemb_grid = posemb[:, :0], posemb[0]\n    gs_old = int(math.sqrt(len(posemb_grid)))\n    if not len(gs_new):  # backwards compatibility\n        gs_new = [int(math.sqrt(ntok_new))] * 2\n    assert len(gs_new) >= 2\n    print(\"Position embedding grid-size from %s to %s\", [gs_old, gs_old], gs_new)\n    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)\n    posemb_grid = F.interpolate(\n        posemb_grid, size=gs_new, mode=\"bicubic\", align_corners=False\n    )\n    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)\n    posemb = torch.cat([posemb_tok, posemb_grid], dim=1)\n    return\n\n\ndef interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):\n    # interpolate position embedding\n    embedding_size = pos_embed_checkpoint.shape[-1]\n    num_patches = visual_encoder.patch_embed.num_patches\n    num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches\n    # height (== width) for the checkpoint position embedding\n    orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)\n    # height (== width) for the new position embedding\n    new_size = int(num_patches**0.5)\n\n    if orig_size != new_size:\n        # class_token and dist_token are kept unchanged\n        extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]\n        # only the position tokens are interpolated\n        pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]\n        pos_tokens = pos_tokens.reshape(\n            -1, orig_size, orig_size, embedding_size\n        ).permute(0, 3, 1, 2)\n        pos_tokens = torch.nn.functional.interpolate(\n            pos_tokens, size=(new_size, new_size), mode=\"bicubic\", align_corners=False\n        )\n        pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)\n        new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)\n        print(\n            \"reshape position embedding from %d to %d\" % (orig_size**2, new_size**2)\n        )\n\n        return new_pos_embed\n    else:\n        return pos_embed_checkpoint\n\n\nclass VisionTransformerEncoder(VisionTransformer, BaseEncoder):\n    @classmethod\n   
 def from_config(cls, cfg, from_pretrained=False):\n\n        vit_type = cfg.get(\"vit_type\", \"base\")\n        image_size = cfg.get(\"image_size\", 384)\n        ckpt_layer = cfg.get(\"vit_ckpt_layer\", 0)\n        drop_path_rate = cfg.get(\"vit_drop_path_rate\", 0)\n        norm_layer_eps = cfg.get(\"vit_layer_norm_epsilon\", -1)\n        use_grad_checkpointing = cfg.get(\"vit_grad_ckpt\", False)\n\n        if norm_layer_eps == -1:\n            norm_layer = None\n        else:\n            norm_layer = partial(nn.LayerNorm, eps=norm_layer_eps)\n\n        #     norm_layer=partial(nn.LayerNorm, eps=1e-6),\n        assert vit_type in [\"base\", \"large\"], \"vit parameter must be base or large\"\n        if vit_type == \"base\":\n            vision_width = 768\n            visual_encoder = cls(\n                img_size=image_size,\n                patch_size=16,\n                embed_dim=vision_width,\n                depth=12,\n                num_heads=12,\n                use_grad_checkpointing=use_grad_checkpointing,\n                ckpt_layer=ckpt_layer,\n                drop_path_rate=0 or drop_path_rate,\n                norm_layer=norm_layer,\n            )\n\n            if from_pretrained:\n                checkpoint = torch.hub.load_state_dict_from_url(\n                    url=\"https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth\",\n                    map_location=\"cpu\",\n                    check_hash=True,\n                )\n                state_dict = checkpoint[\"model\"]\n                state_dict[\"pos_embed\"] = interpolate_pos_embed(\n                    state_dict[\"pos_embed\"], visual_encoder\n                )\n                msg = visual_encoder.load_state_dict(state_dict, strict=False)\n\n        elif vit_type == \"large\":\n            vision_width = 1024\n            visual_encoder = cls(\n                img_size=image_size,\n                patch_size=16,\n                embed_dim=vision_width,\n                depth=24,\n                num_heads=16,\n                use_grad_checkpointing=use_grad_checkpointing,\n                ckpt_layer=ckpt_layer,\n                drop_path_rate=0.1 or drop_path_rate,\n                norm_layer=norm_layer,\n            )\n            if from_pretrained:\n                from timm.models.helpers import load_custom_pretrained\n                from timm.models.vision_transformer import default_cfgs\n\n                load_custom_pretrained(\n                    visual_encoder, default_cfgs[\"vit_large_patch16_224_in21k\"]\n                )\n\n        visual_encoder.vision_width = vision_width\n        return visual_encoder\n\n    def forward_features(self, x, register_blk=-1):\n        return super().forward(x, register_blk)\n"
  },
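  {
    "path": "examples/vit_encoder_demo.py",
    "content": "\"\"\"\nIllustrative usage sketch (hypothetical file, not part of the original repo).\nIt builds the VisionTransformerEncoder defined in lavis/models/vit.py from an\nOmegaConf config, mirroring the keys read in VisionTransformerEncoder.from_config.\nThe config values below are assumptions chosen for this example, not recommended\nsettings.\n\"\"\"\nimport torch\nfrom omegaconf import OmegaConf\n\nfrom lavis.models.vit import VisionTransformerEncoder\n\ncfg = OmegaConf.create(\n    {\n        \"vit_type\": \"base\",             # \"base\" (width 768) or \"large\" (width 1024)\n        \"image_size\": 224,\n        \"vit_ckpt_layer\": 0,            # checkpoint the last N blocks when vit_grad_ckpt is True\n        \"vit_drop_path_rate\": 0.0,\n        \"vit_layer_norm_epsilon\": 1e-6,\n        \"vit_grad_ckpt\": False,\n    }\n)\n\nencoder = VisionTransformerEncoder.from_config(cfg, from_pretrained=False)\nfeatures = encoder.forward_features(torch.randn(1, 3, 224, 224))\nprint(features.shape)  # (1, num_patches + 1, 768) for the base model\n"
  },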
  {
    "path": "lavis/processors/__init__.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nfrom lavis.processors.base_processor import BaseProcessor\n\nfrom lavis.processors.alpro_processors import (\n    AlproVideoTrainProcessor,\n    AlproVideoEvalProcessor,\n)\nfrom lavis.processors.blip_processors import (\n    BlipImageTrainProcessor,\n    Blip2ImageTrainProcessor,\n    BlipImageEvalProcessor,\n    BlipCaptionProcessor,\n)\nfrom lavis.processors.gpt_processors import (\n    GPTVideoFeatureProcessor,\n    GPTDialogueProcessor,\n)\nfrom lavis.processors.clip_processors import ClipImageTrainProcessor\n\nfrom lavis.common.registry import registry\n\n__all__ = [\n    \"BaseProcessor\",\n    # ALPRO\n    \"AlproVideoTrainProcessor\",\n    \"AlproVideoEvalProcessor\",\n    # BLIP\n    \"BlipImageTrainProcessor\",\n    \"Blip2ImageTrainProcessor\",\n    \"BlipImageEvalProcessor\",\n    \"BlipCaptionProcessor\",\n    \"ClipImageTrainProcessor\",\n    # GPT\n    \"GPTVideoFeatureProcessor\",\n    \"GPTDialogueProcessor\",\n]\n\n\ndef load_processor(name, cfg=None):\n    \"\"\"\n    Example\n\n    >>> processor = load_processor(\"alpro_video_train\", cfg=None)\n    \"\"\"\n    processor = registry.get_processor_class(name).from_config(cfg)\n\n    return processor\n"
  },
  {
    "path": "lavis/processors/alpro_processors.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport torch\nfrom lavis.common.registry import registry\nfrom lavis.datasets.data_utils import load_video\nfrom lavis.processors import transforms_video\nfrom lavis.processors.base_processor import BaseProcessor\nfrom lavis.processors.randaugment import VideoRandomAugment\nfrom lavis.processors import functional_video as F\nfrom omegaconf import OmegaConf\nfrom torchvision import transforms\n\nMAX_INT = registry.get(\"MAX_INT\")\n\n\nclass AlproVideoBaseProcessor(BaseProcessor):\n    def __init__(self, mean=None, std=None, n_frms=MAX_INT):\n        if mean is None:\n            mean = (0.48145466, 0.4578275, 0.40821073)\n        if std is None:\n            std = (0.26862954, 0.26130258, 0.27577711)\n\n        self.normalize = transforms_video.NormalizeVideo(mean, std)\n\n        self.n_frms = n_frms\n\n\nclass ToUint8(object):\n    def __init__(self):\n        pass\n\n    def __call__(self, tensor):\n        return tensor.to(torch.uint8)\n\n    def __repr__(self):\n        return self.__class__.__name__\n\n\nclass ToTHWC(object):\n    \"\"\"\n    Args:\n        clip (torch.tensor, dtype=torch.uint8): Size is (C, T, H, W)\n    Return:\n        clip (torch.tensor, dtype=torch.float): Size is (T, H, W, C)\n    \"\"\"\n\n    def __init__(self):\n        pass\n\n    def __call__(self, tensor):\n        return tensor.permute(1, 2, 3, 0)\n\n    def __repr__(self):\n        return self.__class__.__name__\n\n\nclass ResizeVideo(object):\n    def __init__(self, target_size, interpolation_mode=\"bilinear\"):\n        self.target_size = target_size\n        self.interpolation_mode = interpolation_mode\n\n    def __call__(self, clip):\n        \"\"\"\n        Args:\n            clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)\n        Returns:\n            torch.tensor: central cropping of video clip. 
Size is\n            (C, T, crop_size, crop_size)\n        \"\"\"\n        return F.resize(clip, self.target_size, self.interpolation_mode)\n\n    def __repr__(self):\n        return self.__class__.__name__ + \"(resize_size={0})\".format(self.target_size)\n\n\n@registry.register_processor(\"alpro_video_train\")\nclass AlproVideoTrainProcessor(AlproVideoBaseProcessor):\n    def __init__(\n        self,\n        image_size=384,\n        mean=None,\n        std=None,\n        min_scale=0.5,\n        max_scale=1.0,\n        n_frms=MAX_INT,\n    ):\n        super().__init__(mean=mean, std=std, n_frms=n_frms)\n\n        self.image_size = image_size\n\n        self.transform = transforms.Compose(\n            [\n                # Video size is (C, T, H, W)\n                transforms_video.RandomResizedCropVideo(\n                    image_size,\n                    scale=(min_scale, max_scale),\n                    interpolation_mode=\"bicubic\",\n                ),\n                transforms_video.RandomHorizontalFlipVideo(),\n                ToTHWC(),  # C, T, H, W -> T, H, W, C\n                VideoRandomAugment(\n                    2,\n                    5,\n                    augs=[\n                        \"Identity\",\n                        \"AutoContrast\",\n                        \"Brightness\",\n                        \"Sharpness\",\n                        \"Equalize\",\n                        \"ShearX\",\n                        \"ShearY\",\n                        \"TranslateX\",\n                        \"TranslateY\",\n                        \"Rotate\",\n                    ],\n                ),\n                ToUint8(),\n                transforms_video.ToTensorVideo(),  # T, H, W, C -> C, T, H, W\n                self.normalize,\n            ]\n        )\n\n    def __call__(self, vpath):\n        \"\"\"\n        Args:\n            clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)\n        Returns:\n            torch.tensor: video clip after transforms. 
Size is (C, T, size, size).\n        \"\"\"\n        clip = load_video(\n            video_path=vpath,\n            n_frms=self.n_frms,\n            height=self.image_size,\n            width=self.image_size,\n            sampling=\"headtail\",\n        )\n\n        return self.transform(clip)\n\n    @classmethod\n    def from_config(cls, cfg=None):\n        if cfg is None:\n            cfg = OmegaConf.create()\n\n        image_size = cfg.get(\"image_size\", 256)\n\n        mean = cfg.get(\"mean\", None)\n        std = cfg.get(\"std\", None)\n\n        min_scale = cfg.get(\"min_scale\", 0.5)\n        max_scale = cfg.get(\"max_scale\", 1.0)\n\n        n_frms = cfg.get(\"n_frms\", MAX_INT)\n\n        return cls(\n            image_size=image_size,\n            mean=mean,\n            std=std,\n            min_scale=min_scale,\n            max_scale=max_scale,\n            n_frms=n_frms,\n        )\n\n\n@registry.register_processor(\"alpro_video_eval\")\nclass AlproVideoEvalProcessor(AlproVideoBaseProcessor):\n    def __init__(self, image_size=256, mean=None, std=None, n_frms=MAX_INT):\n        super().__init__(mean=mean, std=std, n_frms=n_frms)\n\n        self.image_size = image_size\n\n        # Input video size is (C, T, H, W)\n        self.transform = transforms.Compose(\n            [\n                # frames will be resized during decord loading.\n                ToUint8(),  # C, T, H, W\n                ToTHWC(),  # T, H, W, C\n                transforms_video.ToTensorVideo(),  # C, T, H, W\n                self.normalize,  # C, T, H, W\n            ]\n        )\n\n    def __call__(self, vpath):\n        \"\"\"\n        Args:\n            clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)\n        Returns:\n            torch.tensor: video clip after transforms. Size is (C, T, size, size).\n        \"\"\"\n        clip = load_video(\n            video_path=vpath,\n            n_frms=self.n_frms,\n            height=self.image_size,\n            width=self.image_size,\n        )\n\n        return self.transform(clip)\n\n    @classmethod\n    def from_config(cls, cfg=None):\n        if cfg is None:\n            cfg = OmegaConf.create()\n\n        image_size = cfg.get(\"image_size\", 256)\n\n        mean = cfg.get(\"mean\", None)\n        std = cfg.get(\"std\", None)\n\n        n_frms = cfg.get(\"n_frms\", MAX_INT)\n\n        return cls(image_size=image_size, mean=mean, std=std, n_frms=n_frms)\n"
  },
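  {
    "path": "examples/alpro_processor_demo.py",
    "content": "\"\"\"\nIllustrative usage sketch (hypothetical file, not part of the original repo).\nIt builds the ALPRO video processors defined in lavis/processors/alpro_processors.py\nfrom an OmegaConf config, the same way load_processor does via the registry.\nThe config values and the video path are placeholders for this example only.\n\"\"\"\nfrom omegaconf import OmegaConf\n\nfrom lavis.processors.alpro_processors import (\n    AlproVideoEvalProcessor,\n    AlproVideoTrainProcessor,\n)\n\ncfg = OmegaConf.create({\"image_size\": 224, \"n_frms\": 8})\n\ntrain_processor = AlproVideoTrainProcessor.from_config(cfg)\neval_processor = AlproVideoEvalProcessor.from_config(cfg)\n\n# Each processor is called with a video file path and returns a normalized clip\n# tensor of shape (C, n_frms, image_size, image_size), e.g.:\n# clip = eval_processor(\"/path/to/video.mp4\")\n"
  },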
  {
    "path": "lavis/processors/base_processor.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nfrom omegaconf import OmegaConf\n\n\nclass BaseProcessor:\n    def __init__(self):\n        self.transform = lambda x: x\n        return\n\n    def __call__(self, item):\n        return self.transform(item)\n\n    @classmethod\n    def from_config(cls, cfg=None):\n        return cls()\n\n    def build(self, **kwargs):\n        cfg = OmegaConf.create(kwargs)\n\n        return self.from_config(cfg)\n"
  },
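  {
    "path": "examples/base_processor_demo.py",
    "content": "\"\"\"\nIllustrative usage sketch (hypothetical file, not part of the original repo).\nBaseProcessor in lavis/processors/base_processor.py is an identity transform by\ndefault; concrete processors override self.transform (or __call__) and from_config.\nThe lowercasing processor below is a made-up example, not a processor shipped with\nthe library.\n\"\"\"\nfrom omegaconf import OmegaConf\n\nfrom lavis.processors.base_processor import BaseProcessor\n\n\nclass LowercaseProcessor(BaseProcessor):\n    def __init__(self):\n        super().__init__()\n        self.transform = lambda text: text.lower()\n\n    @classmethod\n    def from_config(cls, cfg=None):\n        return cls()\n\n\nprocessor = LowercaseProcessor.from_config(OmegaConf.create())\nprint(processor(\"Hello WORLD\"))  # -> \"hello world\"\n"
  },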
  {
    "path": "lavis/processors/blip_processors.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport re\nimport torch\nfrom lavis.processors import transforms_video\nfrom lavis.common.registry import registry\nfrom lavis.processors.base_processor import BaseProcessor\nfrom lavis.datasets.data_utils import load_video\nfrom lavis.processors.randaugment import RandomAugment\nfrom omegaconf import OmegaConf\nfrom torchvision import transforms\nfrom torchvision.transforms.functional import InterpolationMode\n\nMAX_INT = registry.get(\"MAX_INT\")\n\nclass ToUint8(object):\n    def __init__(self):\n        pass\n\n    def __call__(self, tensor):\n        return tensor.to(torch.uint8)\n\n    def __repr__(self):\n        return self.__class__.__name__\n\n\nclass ToTHWC(object):\n    \"\"\"\n    Args:\n        clip (torch.tensor, dtype=torch.uint8): Size is (C, T, H, W)\n    Return:\n        clip (torch.tensor, dtype=torch.float): Size is (T, H, W, C)\n    \"\"\"\n\n    def __init__(self):\n        pass\n\n    def __call__(self, tensor):\n        return tensor.permute(1, 2, 3, 0)\n\n    def __repr__(self):\n        return self.__class__.__name__\n    \nclass BlipImageBaseProcessor(BaseProcessor):\n    def __init__(self, mean=None, std=None):\n        if mean is None:\n            mean = (0.48145466, 0.4578275, 0.40821073)\n        if std is None:\n            std = (0.26862954, 0.26130258, 0.27577711)\n\n        self.normalize = transforms.Normalize(mean, std)\n\nclass BlipVideoBaseProcessor(BaseProcessor):\n    def __init__(self, mean=None, std=None, n_frms=MAX_INT):\n        if mean is None:\n            mean = (0.48145466, 0.4578275, 0.40821073)\n        if std is None:\n            std = (0.26862954, 0.26130258, 0.27577711)\n\n        self.normalize = transforms_video.NormalizeVideo(mean, std)\n\n        self.n_frms = n_frms\n\n@registry.register_processor(\"blip_caption\")\nclass BlipCaptionProcessor(BaseProcessor):\n    def __init__(self, prompt=\"\", max_words=50):\n        self.prompt = prompt\n        self.max_words = max_words\n\n    def __call__(self, caption):\n        caption = self.prompt + self.pre_caption(caption)\n\n        return caption\n\n    @classmethod\n    def from_config(cls, cfg=None):\n        if cfg is None:\n            cfg = OmegaConf.create()\n\n        prompt = cfg.get(\"prompt\", \"\")\n        max_words = cfg.get(\"max_words\", 50)\n\n        return cls(prompt=prompt, max_words=max_words)\n\n    def pre_caption(self, caption):\n        caption = re.sub(\n            r\"([.!\\\"()*#:;~])\",\n            \" \",\n            caption.lower(),\n        )\n        caption = re.sub(\n            r\"\\s{2,}\",\n            \" \",\n            caption,\n        )\n        caption = caption.rstrip(\"\\n\")\n        caption = caption.strip(\" \")\n\n        # truncate caption\n        caption_words = caption.split(\" \")\n        if len(caption_words) > self.max_words:\n            caption = \" \".join(caption_words[: self.max_words])\n\n        return caption\n\n@registry.register_processor(\"blip_question\")\nclass BlipQuestionProcessor(BaseProcessor):\n    def __init__(self, max_words=50):\n        self.max_words = max_words\n\n    def __call__(self, question):\n        return self.pre_question(question)\n\n    @classmethod\n    def from_config(cls, cfg=None):\n        if cfg is None:\n            cfg = 
OmegaConf.create()\n\n        max_words = cfg.get(\"max_words\", 50)\n\n        return cls(max_words=max_words)\n\n    def pre_question(self, question):\n        question = re.sub(\n            r\"([.!\\\"()*#:;~])\",\n            \"\",\n            question.lower(),\n        )\n        question = question.rstrip(\" \")\n\n        # truncate question\n        question_words = question.split(\" \")\n        if len(question_words) > self.max_words:\n            question = \" \".join(question_words[: self.max_words])\n\n        return question\n\n\n\n@registry.register_processor(\"blip_image_train\")\nclass BlipImageTrainProcessor(BlipImageBaseProcessor):\n    def __init__(\n        self, image_size=384, mean=None, std=None, min_scale=0.5, max_scale=1.0\n    ):\n        super().__init__(mean=mean, std=std)\n\n        self.transform = transforms.Compose(\n            [\n                transforms.RandomResizedCrop(\n                    image_size,\n                    scale=(min_scale, max_scale),\n                    interpolation=InterpolationMode.BICUBIC,\n                ),\n                transforms.RandomHorizontalFlip(),\n                RandomAugment(\n                    2,\n                    5,\n                    isPIL=True,\n                    augs=[\n                        \"Identity\",\n                        \"AutoContrast\",\n                        \"Brightness\",\n                        \"Sharpness\",\n                        \"Equalize\",\n                        \"ShearX\",\n                        \"ShearY\",\n                        \"TranslateX\",\n                        \"TranslateY\",\n                        \"Rotate\",\n                    ],\n                ),\n                transforms.ToTensor(),\n                self.normalize,\n            ]\n        )\n\n    def __call__(self, item):\n        return self.transform(item)\n\n    @classmethod\n    def from_config(cls, cfg=None):\n        if cfg is None:\n            cfg = OmegaConf.create()\n\n        image_size = cfg.get(\"image_size\", 384)\n\n        mean = cfg.get(\"mean\", None)\n        std = cfg.get(\"std\", None)\n\n        min_scale = cfg.get(\"min_scale\", 0.5)\n        max_scale = cfg.get(\"max_scale\", 1.0)\n\n        return cls(\n            image_size=image_size,\n            mean=mean,\n            std=std,\n            min_scale=min_scale,\n            max_scale=max_scale,\n        )\n\n\n@registry.register_processor(\"blip_image_eval\")\nclass BlipImageEvalProcessor(BlipImageBaseProcessor):\n    def __init__(self, image_size=384, mean=None, std=None):\n        super().__init__(mean=mean, std=std)\n\n        self.transform = transforms.Compose(\n            [\n                transforms.Resize(\n                    (image_size, image_size), interpolation=InterpolationMode.BICUBIC\n                ),\n                transforms.ToTensor(),\n                self.normalize,\n            ]\n        )\n\n    def __call__(self, item):\n        return self.transform(item)\n\n    @classmethod\n    def from_config(cls, cfg=None):\n        if cfg is None:\n            cfg = OmegaConf.create()\n\n        image_size = cfg.get(\"image_size\", 384)\n\n        mean = cfg.get(\"mean\", None)\n        std = cfg.get(\"std\", None)\n\n        return cls(image_size=image_size, mean=mean, std=std)\n\n\n@registry.register_processor(\"blip2_image_train\")\nclass Blip2ImageTrainProcessor(BlipImageBaseProcessor):\n    def __init__(\n        self, image_size=364, mean=None, std=None, min_scale=0.5, 
max_scale=1.0\n    ):\n        super().__init__(mean=mean, std=std)\n\n        self.transform = transforms.Compose(\n            [\n                transforms.RandomResizedCrop(\n                    image_size,\n                    scale=(min_scale, max_scale),\n                    interpolation=InterpolationMode.BICUBIC,\n                ),\n                transforms.RandomHorizontalFlip(),\n                transforms.ToTensor(),\n                self.normalize,\n            ]\n        )\n\n    def __call__(self, item):\n        return self.transform(item)\n\n    @classmethod\n    def from_config(cls, cfg=None):\n        if cfg is None:\n            cfg = OmegaConf.create()\n\n        image_size = cfg.get(\"image_size\", 364)\n\n        mean = cfg.get(\"mean\", None)\n        std = cfg.get(\"std\", None)\n\n        min_scale = cfg.get(\"min_scale\", 0.5)\n        max_scale = cfg.get(\"max_scale\", 1.0)\n\n        return cls(\n            image_size=image_size,\n            mean=mean,\n            std=std,\n            min_scale=min_scale,\n            max_scale=max_scale,\n        )\n\n@registry.register_processor(\"blip2_video_train\")\nclass Blip2VideoTrainProcessor(BlipVideoBaseProcessor):\n    def __init__(\n        self, \n        image_size=384,\n        mean=None,\n        std=None,\n        min_scale=0.5,\n        max_scale=1.0,\n        n_frms=MAX_INT,\n    ):\n        super().__init__(mean=mean, std=std, n_frms=n_frms)\n\n        self.image_size = image_size\n\n        self.transform = transforms.Compose(\n            [\n                # Video size is (C, T, H, W)\n                transforms_video.RandomResizedCropVideo(\n                    image_size,\n                    scale=(min_scale, max_scale),\n                    interpolation_mode=\"bicubic\",\n                ),\n                ToTHWC(),  # C, T, H, W -> T, H, W, C\n                ToUint8(),\n                transforms_video.ToTensorVideo(),  # T, H, W, C -> C, T, H, W\n                self.normalize,\n            ]\n        )\n\n    def __call__(self, vpath, clip_proposal=None):\n\n        clip, indices, fps = load_video(\n            video_path=vpath,\n            n_frms=self.n_frms,\n            height=self.image_size,\n            width=self.image_size,\n            sampling=\"random\",\n            clip_proposal=clip_proposal\n        )\n\n        return self.transform(clip), indices, fps\n\n    @classmethod\n    def from_config(cls, cfg=None):\n        if cfg is None:\n            cfg = OmegaConf.create()\n\n        image_size = cfg.get(\"image_size\", 364)\n\n        mean = cfg.get(\"mean\", None)\n        std = cfg.get(\"std\", None)\n\n        min_scale = cfg.get(\"min_scale\", 0.5)\n        max_scale = cfg.get(\"max_scale\", 1.0)\n        n_frms = cfg.get(\"n_frms\", MAX_INT)\n\n        return cls(\n            image_size=image_size,\n            mean=mean,\n            std=std,\n            min_scale=min_scale,\n            max_scale=max_scale,\n            n_frms=n_frms\n        )\n\n\n@registry.register_processor(\"blip_video_eval\")\nclass BlipVideoEvalProcessor(BlipVideoBaseProcessor):\n    def __init__(self, image_size=384, mean=None, std=None, n_frms=MAX_INT):\n        super().__init__(mean=mean, std=std, n_frms=n_frms)\n\n        self.image_size = image_size\n        self.transform = transforms.Compose(\n            [\n                ToUint8(),  # C, T, H, W\n                ToTHWC(),  # T, H, W, C\n                transforms_video.ToTensorVideo(),  # C, T, H, W\n                
self.normalize,  # C, T, H, W\n            ]\n        )\n        self.n_frms = n_frms\n\n    def __call__(self, vpath, clip_proposal=None):\n        clip, indices, fps = load_video(\n            video_path=vpath,\n            n_frms=self.n_frms,\n            height=self.image_size,\n            width=self.image_size,\n            sampling=\"uniform\",\n            clip_proposal=clip_proposal\n        )\n\n        return self.transform(clip), indices, fps\n\n    @classmethod\n    def from_config(cls, cfg=None):\n        if cfg is None:\n            cfg = OmegaConf.create()\n\n        image_size = cfg.get(\"image_size\", 256)\n\n        mean = cfg.get(\"mean\", None)\n        std = cfg.get(\"std\", None)\n\n        n_frms = cfg.get(\"n_frms\", MAX_INT)\n\n        return cls(image_size=image_size, mean=mean, std=std, n_frms=n_frms)"
  },
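  {
    "path": "examples/blip_processor_demo.py",
    "content": "\"\"\"\nIllustrative usage sketch (hypothetical file, not part of the original repo).\nIt shows how the BLIP-2 image processor and the caption processor defined in\nlavis/processors/blip_processors.py are built from OmegaConf configs and applied.\nThe config values and the dummy inputs are placeholders for this example only.\n\"\"\"\nfrom PIL import Image\nfrom omegaconf import OmegaConf\n\nfrom lavis.processors.blip_processors import (\n    Blip2ImageTrainProcessor,\n    BlipCaptionProcessor,\n)\n\nimage_processor = Blip2ImageTrainProcessor.from_config(\n    OmegaConf.create({\"image_size\": 224})\n)\ncaption_processor = BlipCaptionProcessor.from_config(\n    OmegaConf.create({\"prompt\": \"a photo of \", \"max_words\": 50})\n)\n\nimage = Image.new(\"RGB\", (256, 256))   # dummy image stand-in\npixel_values = image_processor(image)  # tensor of shape (3, 224, 224)\ncaption = caption_processor(\"A Cat; sitting on the MAT!\")\nprint(pixel_values.shape, caption)\n"
  },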
  {
    "path": "lavis/processors/clip_processors.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nfrom lavis.common.registry import registry\nfrom lavis.processors.blip_processors import BlipImageBaseProcessor\nfrom omegaconf import OmegaConf\nfrom torchvision import transforms\nfrom torchvision.transforms.functional import InterpolationMode\n\n\ndef _convert_to_rgb(image):\n    return image.convert(\"RGB\")\n\n\n@registry.register_processor(\"clip_image_train\")\nclass ClipImageTrainProcessor(BlipImageBaseProcessor):\n    def __init__(\n        self, image_size=224, mean=None, std=None, min_scale=0.9, max_scale=1.0\n    ):\n\n        super().__init__(mean=mean, std=std)\n\n        self.transform = transforms.Compose(\n            [\n                transforms.RandomResizedCrop(\n                    image_size,\n                    scale=(min_scale, max_scale),\n                    interpolation=InterpolationMode.BICUBIC,\n                ),\n                _convert_to_rgb,\n                transforms.ToTensor(),\n                self.normalize,\n            ]\n        )\n\n    @classmethod\n    def from_config(cls, cfg=None):\n        if cfg is None:\n            cfg = OmegaConf.create()\n\n        image_size = cfg.get(\"image_size\", 224)\n\n        mean = cfg.get(\"mean\", None)\n        std = cfg.get(\"std\", None)\n\n        min_scale = cfg.get(\"min_scale\", 0.9)\n        max_scale = cfg.get(\"max_scale\", 1.0)\n\n        return cls(\n            image_size=image_size,\n            mean=mean,\n            std=std,\n            min_scale=min_scale,\n            max_scale=max_scale,\n        )\n\n\n@registry.register_processor(\"clip_image_eval\")\nclass ClipImageEvalProcessor(BlipImageBaseProcessor):\n    def __init__(self, image_size=224, mean=None, std=None):\n\n        super().__init__(mean=mean, std=std)\n\n        self.transform = transforms.Compose(\n            [\n                transforms.Resize(image_size, interpolation=InterpolationMode.BICUBIC),\n                transforms.CenterCrop(image_size),\n                _convert_to_rgb,\n                transforms.ToTensor(),\n                self.normalize,\n            ]\n        )\n\n    @classmethod\n    def from_config(cls, cfg=None):\n        if cfg is None:\n            cfg = OmegaConf.create()\n\n        image_size = cfg.get(\"image_size\", 224)\n\n        mean = cfg.get(\"mean\", None)\n        std = cfg.get(\"std\", None)\n\n        return cls(\n            image_size=image_size,\n            mean=mean,\n            std=std,\n        )\n"
  },
  {
    "path": "lavis/processors/functional_video.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport warnings\n\nimport torch\n\n\ndef _is_tensor_video_clip(clip):\n    if not torch.is_tensor(clip):\n        raise TypeError(\"clip should be Tensor. Got %s\" % type(clip))\n\n    if not clip.ndimension() == 4:\n        raise ValueError(\"clip should be 4D. Got %dD\" % clip.dim())\n\n    return True\n\n\ndef crop(clip, i, j, h, w):\n    \"\"\"\n    Args:\n        clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)\n    \"\"\"\n    if len(clip.size()) != 4:\n        raise ValueError(\"clip should be a 4D tensor\")\n    return clip[..., i : i + h, j : j + w]\n\n\ndef resize(clip, target_size, interpolation_mode):\n    if len(target_size) != 2:\n        raise ValueError(\n            f\"target size should be tuple (height, width), instead got {target_size}\"\n        )\n    return torch.nn.functional.interpolate(\n        clip, size=target_size, mode=interpolation_mode, align_corners=False\n    )\n\n\ndef resized_crop(clip, i, j, h, w, size, interpolation_mode=\"bilinear\"):\n    \"\"\"\n    Do spatial cropping and resizing to the video clip\n    Args:\n        clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)\n        i (int): i in (i,j) i.e coordinates of the upper left corner.\n        j (int): j in (i,j) i.e coordinates of the upper left corner.\n        h (int): Height of the cropped region.\n        w (int): Width of the cropped region.\n        size (tuple(int, int)): height and width of resized clip\n    Returns:\n        clip (torch.tensor): Resized and cropped clip. Size is (C, T, H, W)\n    \"\"\"\n    if not _is_tensor_video_clip(clip):\n        raise ValueError(\"clip should be a 4D torch.tensor\")\n    clip = crop(clip, i, j, h, w)\n    clip = resize(clip, size, interpolation_mode)\n    return clip\n\n\ndef center_crop(clip, crop_size):\n    if not _is_tensor_video_clip(clip):\n        raise ValueError(\"clip should be a 4D torch.tensor\")\n    h, w = clip.size(-2), clip.size(-1)\n    th, tw = crop_size\n    if h < th or w < tw:\n        raise ValueError(\"height and width must be no smaller than crop_size\")\n\n    i = int(round((h - th) / 2.0))\n    j = int(round((w - tw) / 2.0))\n    return crop(clip, i, j, th, tw)\n\n\ndef to_tensor(clip):\n    \"\"\"\n    Convert tensor data type from uint8 to float, divide value by 255.0 and\n    permute the dimensions of clip tensor\n    Args:\n        clip (torch.tensor, dtype=torch.uint8): Size is (T, H, W, C)\n    Return:\n        clip (torch.tensor, dtype=torch.float): Size is (C, T, H, W)\n    \"\"\"\n    _is_tensor_video_clip(clip)\n    if not clip.dtype == torch.uint8:\n        raise TypeError(\n            \"clip tensor should have data type uint8. Got %s\" % str(clip.dtype)\n        )\n    return clip.float().permute(3, 0, 1, 2) / 255.0\n\n\ndef normalize(clip, mean, std, inplace=False):\n    \"\"\"\n    Args:\n        clip (torch.tensor): Video clip to be normalized. Size is (C, T, H, W)\n        mean (tuple): pixel RGB mean. Size is (3)\n        std (tuple): pixel standard deviation. 
Size is (3)\n    Returns:\n        normalized clip (torch.tensor): Size is (C, T, H, W)\n    \"\"\"\n    if not _is_tensor_video_clip(clip):\n        raise ValueError(\"clip should be a 4D torch.tensor\")\n    if not inplace:\n        clip = clip.clone()\n    mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device)\n    std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device)\n    clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])\n    return clip\n\n\ndef hflip(clip):\n    \"\"\"\n    Args:\n        clip (torch.tensor): Video clip to be normalized. Size is (C, T, H, W)\n    Returns:\n        flipped clip (torch.tensor): Size is (C, T, H, W)\n    \"\"\"\n    if not _is_tensor_video_clip(clip):\n        raise ValueError(\"clip should be a 4D torch.tensor\")\n    return clip.flip(-1)\n"
  },
  {
    "path": "lavis/processors/gpt_processors.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport re\n\nfrom lavis.common.registry import registry\nfrom lavis.processors.base_processor import BaseProcessor\nfrom lavis.processors.randaugment import RandomAugment\nfrom omegaconf import OmegaConf\nfrom torchvision import transforms\nfrom torchvision.transforms.functional import InterpolationMode\nimport os\nfrom itertools import chain\nimport numpy as np\nimport torch\nfrom transformers import GPT2Tokenizer\n\nSPECIAL_TOKENS_DICT = {\n    \"bos_token\": \"<bos>\",\n    \"eos_token\": \"<eos>\",\n    \"additional_special_tokens\": [\"<speaker1>\", \"<speaker2>\", \"<video>\", \"<cap>\"],\n    \"pad_token\": \"<pad>\",\n}\nSPECIAL_TOKENS = [\n    \"<bos>\",\n    \"<eos>\",\n    \"<speaker1>\",\n    \"<speaker2>\",\n    \"<cap>\",\n    \"<video>\",\n    \"<pad>\",\n]\n\n\nclass GPTVideoFeatureBaseProcessor(BaseProcessor):\n    def __init__(self, visual_ft=[\"i3d_rgb\"], audio_ft=[\"vggish\"]):\n        self.visual_ft = visual_ft\n        self.audio_ft = audio_ft\n\n\n@registry.register_processor(\"gpt_dialogue\")\nclass GPTDialogueProcessor(BaseProcessor):\n    def __init__(self, max_turns=3, use_caption=True):\n        self.max_turns = max_turns\n        self.use_caption = use_caption\n        self.tokenizer = GPT2Tokenizer.from_pretrained(\"gpt2\")\n        self.tokenizer.add_special_tokens(SPECIAL_TOKENS_DICT)\n\n    def sample_sequence(self, caption, history, answer):\n        bos, eos, speaker1, speaker2, cap = self.tokenizer.convert_tokens_to_ids(\n            SPECIAL_TOKENS[:-2]\n        )\n        instance = {}\n        sequence = [caption] + history + [answer]\n        sequence = [s + [eos] for s in sequence]\n\n        instance[\"input_ids\"] = list(chain(*sequence))\n        instance[\"token_type_ids\"] = [cap] * len(sequence[0]) + [\n            speaker2 if i % 2 else speaker1\n            for i, s in enumerate(sequence[1:])\n            for _ in s\n        ]\n        instance[\"labels\"] = ([-1] * sum(len(s) for s in sequence[:-1])) + sequence[-1]\n\n        assert len(instance[\"input_ids\"]) == len(instance[\"token_type_ids\"])\n        assert len(instance[\"token_type_ids\"]) == len(instance[\"labels\"])\n\n        for k, v in instance.items():\n            instance[k] = torch.Tensor(v).long()\n\n        return instance\n\n    def padding(self, seq, pad_token=-1):\n        if pad_token == -1:\n            pad_token = self.tokenizer.pad_token_id\n        padded_seq = torch.nn.utils.rnn.pad_sequence(\n            seq, batch_first=True, padding_value=pad_token\n        )\n        return padded_seq\n\n    def get_attention_mask(self, seq, pad_token=-1):\n        if pad_token == -1:\n            pad_token = self.tokenizer.pad_token_id\n        return seq != pad_token\n\n    def __call__(self, ann):\n        if self.use_caption:\n            caption = \" \".join([ann[\"caption\"], ann[\"summary\"]])\n            caption = self.tokenizer.encode(caption)\n        else:\n            caption = []\n\n        dial_history = []\n        for turn in ann[\"dialog\"][-self.max_turns :]:\n            dial_history.append(turn[\"question\"])\n            dial_history.append(turn[\"answer\"])\n        dial_history.append(ann[\"question\"])\n        dial_history = [self.tokenizer.encode(t) for t in dial_history]\n\n        answer = 
self.tokenizer.encode(ann[\"answer\"])\n\n        item = self.sample_sequence(caption, dial_history, answer)\n\n        return item\n\n    @classmethod\n    def from_config(cls, cfg=None):\n        if cfg is None:\n            cfg = OmegaConf.create()\n\n        use_caption = cfg.get(\"use_caption\", True)\n        max_turns = cfg.get(\"max_turns\", 3)\n\n        return cls(max_turns=max_turns, use_caption=use_caption)\n\n\n@registry.register_processor(\"gpt_video_ft\")\nclass GPTVideoFeatureProcessor(GPTVideoFeatureBaseProcessor):\n    def __init__(self, visual_ft, audio_ft):\n        super().__init__(visual_ft, audio_ft)\n        self.tokenizer = GPT2Tokenizer.from_pretrained(\"gpt2\")\n        self.tokenizer.add_special_tokens(SPECIAL_TOKENS_DICT)\n\n    def padding(self, seq):\n        padded_seq = torch.nn.utils.rnn.pad_sequence(\n            seq, batch_first=True, padding_value=1.0\n        )\n        return padded_seq\n\n    def get_attention_mask(self, seq):\n        return torch.sum(seq != 1, dim=2) != 0\n\n    def __call__(self, ft_root, vname):\n        all_ft = []\n\n        for ft_name in self.visual_ft:\n            ft_path = os.path.join(ft_root, ft_name, vname)\n            all_ft.append(np.load(ft_path + \".npy\"))\n\n        for ft_name in self.audio_ft:\n            ft_path = os.path.join(ft_root, ft_name, vname)\n            all_ft.append(np.load(ft_path + \".npy\"))\n\n        min_len = min([len(ft) for ft in all_ft])\n\n        # TODO: use other sampling method (e.g. uniform sampling)\n        sampled_ft = [ft[:min_len] for ft in all_ft]\n        sampled_ft = np.concatenate(sampled_ft, axis=1)\n        item = {}\n        item[\"video_fts\"] = torch.Tensor(sampled_ft)\n\n        video_type_token = self.tokenizer.convert_tokens_to_ids(\"<video>\")\n        item[\"token_type_ids\"] = torch.Tensor(\n            [video_type_token] * len(sampled_ft)\n        ).long()\n\n        return item\n\n    @classmethod\n    def from_config(cls, cfg=None):\n        if cfg is None:\n            cfg = OmegaConf.create()\n\n        visual_ft = cfg.get(\"visual_ft\", [\"i3d_rgb\"])\n        audio_ft = cfg.get(\"audio_ft\", [\"vggish\"])\n\n        return cls(visual_ft=visual_ft, audio_ft=audio_ft)\n"
  },
  {
    "path": "lavis/processors/randaugment.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport cv2\nimport numpy as np\n\nimport torch\n\n\n## aug functions\ndef identity_func(img):\n    return img\n\n\ndef autocontrast_func(img, cutoff=0):\n    \"\"\"\n    same output as PIL.ImageOps.autocontrast\n    \"\"\"\n    n_bins = 256\n\n    def tune_channel(ch):\n        n = ch.size\n        cut = cutoff * n // 100\n        if cut == 0:\n            high, low = ch.max(), ch.min()\n        else:\n            hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])\n            low = np.argwhere(np.cumsum(hist) > cut)\n            low = 0 if low.shape[0] == 0 else low[0]\n            high = np.argwhere(np.cumsum(hist[::-1]) > cut)\n            high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]\n        if high <= low:\n            table = np.arange(n_bins)\n        else:\n            scale = (n_bins - 1) / (high - low)\n            offset = -low * scale\n            table = np.arange(n_bins) * scale + offset\n            table[table < 0] = 0\n            table[table > n_bins - 1] = n_bins - 1\n        table = table.clip(0, 255).astype(np.uint8)\n        return table[ch]\n\n    channels = [tune_channel(ch) for ch in cv2.split(img)]\n    out = cv2.merge(channels)\n    return out\n\n\ndef equalize_func(img):\n    \"\"\"\n    same output as PIL.ImageOps.equalize\n    PIL's implementation is different from cv2.equalize\n    \"\"\"\n    n_bins = 256\n\n    def tune_channel(ch):\n        hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])\n        non_zero_hist = hist[hist != 0].reshape(-1)\n        step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)\n        if step == 0:\n            return ch\n        n = np.empty_like(hist)\n        n[0] = step // 2\n        n[1:] = hist[:-1]\n        table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)\n        return table[ch]\n\n    channels = [tune_channel(ch) for ch in cv2.split(img)]\n    out = cv2.merge(channels)\n    return out\n\n\ndef rotate_func(img, degree, fill=(0, 0, 0)):\n    \"\"\"\n    like PIL, rotate by degree, not radians\n    \"\"\"\n    H, W = img.shape[0], img.shape[1]\n    center = W / 2, H / 2\n    M = cv2.getRotationMatrix2D(center, degree, 1)\n    out = cv2.warpAffine(img, M, (W, H), borderValue=fill)\n    return out\n\n\ndef solarize_func(img, thresh=128):\n    \"\"\"\n    same output as PIL.ImageOps.posterize\n    \"\"\"\n    table = np.array([el if el < thresh else 255 - el for el in range(256)])\n    table = table.clip(0, 255).astype(np.uint8)\n    out = table[img]\n    return out\n\n\ndef color_func(img, factor):\n    \"\"\"\n    same output as PIL.ImageEnhance.Color\n    \"\"\"\n    ## implementation according to PIL definition, quite slow\n    #  degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis]\n    #  out = blend(degenerate, img, factor)\n    #  M = (\n    #      np.eye(3) * factor\n    #      + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. 
- factor)\n    #  )[np.newaxis, np.newaxis, :]\n    M = np.float32(\n        [[0.886, -0.114, -0.114], [-0.587, 0.413, -0.587], [-0.299, -0.299, 0.701]]\n    ) * factor + np.float32([[0.114], [0.587], [0.299]])\n    out = np.matmul(img, M).clip(0, 255).astype(np.uint8)\n    return out\n\n\ndef contrast_func(img, factor):\n    \"\"\"\n    same output as PIL.ImageEnhance.Contrast\n    \"\"\"\n    mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))\n    table = (\n        np.array([(el - mean) * factor + mean for el in range(256)])\n        .clip(0, 255)\n        .astype(np.uint8)\n    )\n    out = table[img]\n    return out\n\n\ndef brightness_func(img, factor):\n    \"\"\"\n    same output as PIL.ImageEnhance.Brightness\n    \"\"\"\n    table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8)\n    out = table[img]\n    return out\n\n\ndef sharpness_func(img, factor):\n    \"\"\"\n    The differences between this result and PIL are all on the 4 boundaries; the center\n    areas are the same\n    \"\"\"\n    kernel = np.ones((3, 3), dtype=np.float32)\n    kernel[1][1] = 5\n    kernel /= 13\n    degenerate = cv2.filter2D(img, -1, kernel)\n    if factor == 0.0:\n        out = degenerate\n    elif factor == 1.0:\n        out = img\n    else:\n        out = img.astype(np.float32)\n        degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]\n        out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate)\n        out = out.astype(np.uint8)\n    return out\n\n\ndef shear_x_func(img, factor, fill=(0, 0, 0)):\n    H, W = img.shape[0], img.shape[1]\n    M = np.float32([[1, factor, 0], [0, 1, 0]])\n    out = cv2.warpAffine(\n        img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR\n    ).astype(np.uint8)\n    return out\n\n\ndef translate_x_func(img, offset, fill=(0, 0, 0)):\n    \"\"\"\n    same output as PIL.Image.transform\n    \"\"\"\n    H, W = img.shape[0], img.shape[1]\n    M = np.float32([[1, 0, -offset], [0, 1, 0]])\n    out = cv2.warpAffine(\n        img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR\n    ).astype(np.uint8)\n    return out\n\n\ndef translate_y_func(img, offset, fill=(0, 0, 0)):\n    \"\"\"\n    same output as PIL.Image.transform\n    \"\"\"\n    H, W = img.shape[0], img.shape[1]\n    M = np.float32([[1, 0, 0], [0, 1, -offset]])\n    out = cv2.warpAffine(\n        img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR\n    ).astype(np.uint8)\n    return out\n\n\ndef posterize_func(img, bits):\n    \"\"\"\n    same output as PIL.ImageOps.posterize\n    \"\"\"\n    out = np.bitwise_and(img, np.uint8(255 << (8 - bits)))\n    return out\n\n\ndef shear_y_func(img, factor, fill=(0, 0, 0)):\n    H, W = img.shape[0], img.shape[1]\n    M = np.float32([[1, 0, 0], [factor, 1, 0]])\n    out = cv2.warpAffine(\n        img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR\n    ).astype(np.uint8)\n    return out\n\n\ndef cutout_func(img, pad_size, replace=(0, 0, 0)):\n    replace = np.array(replace, dtype=np.uint8)\n    H, W = img.shape[0], img.shape[1]\n    rh, rw = np.random.random(2)\n    pad_size = pad_size // 2\n    ch, cw = int(rh * H), int(rw * W)\n    x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H)\n    y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W)\n    out = img.copy()\n    out[x1:x2, y1:y2, :] = replace\n    return out\n\n\n### level to args\ndef enhance_level_to_args(MAX_LEVEL):\n    def level_to_args(level):\n        return ((level / MAX_LEVEL) * 1.8 + 0.1,)\n\n    return 
level_to_args\n\n\ndef shear_level_to_args(MAX_LEVEL, replace_value):\n    def level_to_args(level):\n        level = (level / MAX_LEVEL) * 0.3\n        if np.random.random() > 0.5:\n            level = -level\n        return (level, replace_value)\n\n    return level_to_args\n\n\ndef translate_level_to_args(translate_const, MAX_LEVEL, replace_value):\n    def level_to_args(level):\n        level = (level / MAX_LEVEL) * float(translate_const)\n        if np.random.random() > 0.5:\n            level = -level\n        return (level, replace_value)\n\n    return level_to_args\n\n\ndef cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):\n    def level_to_args(level):\n        level = int((level / MAX_LEVEL) * cutout_const)\n        return (level, replace_value)\n\n    return level_to_args\n\n\ndef solarize_level_to_args(MAX_LEVEL):\n    def level_to_args(level):\n        level = int((level / MAX_LEVEL) * 256)\n        return (level,)\n\n    return level_to_args\n\n\ndef none_level_to_args(level):\n    return ()\n\n\ndef posterize_level_to_args(MAX_LEVEL):\n    def level_to_args(level):\n        level = int((level / MAX_LEVEL) * 4)\n        return (level,)\n\n    return level_to_args\n\n\ndef rotate_level_to_args(MAX_LEVEL, replace_value):\n    def level_to_args(level):\n        level = (level / MAX_LEVEL) * 30\n        if np.random.random() < 0.5:\n            level = -level\n        return (level, replace_value)\n\n    return level_to_args\n\n\nfunc_dict = {\n    \"Identity\": identity_func,\n    \"AutoContrast\": autocontrast_func,\n    \"Equalize\": equalize_func,\n    \"Rotate\": rotate_func,\n    \"Solarize\": solarize_func,\n    \"Color\": color_func,\n    \"Contrast\": contrast_func,\n    \"Brightness\": brightness_func,\n    \"Sharpness\": sharpness_func,\n    \"ShearX\": shear_x_func,\n    \"TranslateX\": translate_x_func,\n    \"TranslateY\": translate_y_func,\n    \"Posterize\": posterize_func,\n    \"ShearY\": shear_y_func,\n}\n\ntranslate_const = 10\nMAX_LEVEL = 10\nreplace_value = (128, 128, 128)\narg_dict = {\n    \"Identity\": none_level_to_args,\n    \"AutoContrast\": none_level_to_args,\n    \"Equalize\": none_level_to_args,\n    \"Rotate\": rotate_level_to_args(MAX_LEVEL, replace_value),\n    \"Solarize\": solarize_level_to_args(MAX_LEVEL),\n    \"Color\": enhance_level_to_args(MAX_LEVEL),\n    \"Contrast\": enhance_level_to_args(MAX_LEVEL),\n    \"Brightness\": enhance_level_to_args(MAX_LEVEL),\n    \"Sharpness\": enhance_level_to_args(MAX_LEVEL),\n    \"ShearX\": shear_level_to_args(MAX_LEVEL, replace_value),\n    \"TranslateX\": translate_level_to_args(translate_const, MAX_LEVEL, replace_value),\n    \"TranslateY\": translate_level_to_args(translate_const, MAX_LEVEL, replace_value),\n    \"Posterize\": posterize_level_to_args(MAX_LEVEL),\n    \"ShearY\": shear_level_to_args(MAX_LEVEL, replace_value),\n}\n\n\nclass RandomAugment(object):\n    def __init__(self, N=2, M=10, isPIL=False, augs=[]):\n        self.N = N\n        self.M = M\n        self.isPIL = isPIL\n        if augs:\n            self.augs = augs\n        else:\n            self.augs = list(arg_dict.keys())\n\n    def get_random_ops(self):\n        sampled_ops = np.random.choice(self.augs, self.N)\n        return [(op, 0.5, self.M) for op in sampled_ops]\n\n    def __call__(self, img):\n        if self.isPIL:\n            img = np.array(img)\n        ops = self.get_random_ops()\n        for name, prob, level in ops:\n            if np.random.random() > prob:\n                continue\n            args 
= arg_dict[name](level)\n            img = func_dict[name](img, *args)\n        return img\n\n\nclass VideoRandomAugment(object):\n    def __init__(self, N=2, M=10, p=0.0, tensor_in_tensor_out=True, augs=[]):\n        self.N = N\n        self.M = M\n        self.p = p\n        self.tensor_in_tensor_out = tensor_in_tensor_out\n        if augs:\n            self.augs = augs\n        else:\n            self.augs = list(arg_dict.keys())\n\n    def get_random_ops(self):\n        sampled_ops = np.random.choice(self.augs, self.N, replace=False)\n        return [(op, self.M) for op in sampled_ops]\n\n    def __call__(self, frames):\n        assert (\n            frames.shape[-1] == 3\n        ), \"Expecting last dimension for 3-channels RGB (b, h, w, c).\"\n\n        if self.tensor_in_tensor_out:\n            frames = frames.numpy().astype(np.uint8)\n\n        num_frames = frames.shape[0]\n\n        ops = num_frames * [self.get_random_ops()]\n        apply_or_not = num_frames * [np.random.random(size=self.N) > self.p]\n\n        frames = torch.stack(\n            list(map(self._aug, frames, ops, apply_or_not)), dim=0\n        ).float()\n\n        return frames\n\n    def _aug(self, img, ops, apply_or_not):\n        for i, (name, level) in enumerate(ops):\n            if not apply_or_not[i]:\n                continue\n            args = arg_dict[name](level)\n            img = func_dict[name](img, *args)\n        return torch.from_numpy(img)\n\n\nif __name__ == \"__main__\":\n    a = RandomAugment()\n    img = np.random.randn(32, 32, 3)\n    a(img)\n"
  },
  {
    "path": "lavis/processors/transforms_video.py",
    "content": "#!/usr/bin/env python3\n\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\n\nimport numbers\nimport random\n\nfrom torchvision.transforms import (\n    RandomCrop,\n    RandomResizedCrop,\n)\n\nimport lavis.processors.functional_video as F\n\n\n__all__ = [\n    \"RandomCropVideo\",\n    \"RandomResizedCropVideo\",\n    \"CenterCropVideo\",\n    \"NormalizeVideo\",\n    \"ToTensorVideo\",\n    \"RandomHorizontalFlipVideo\",\n]\n\n\nclass RandomCropVideo(RandomCrop):\n    def __init__(self, size):\n        if isinstance(size, numbers.Number):\n            self.size = (int(size), int(size))\n        else:\n            self.size = size\n\n    def __call__(self, clip):\n        \"\"\"\n        Args:\n            clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)\n        Returns:\n            torch.tensor: randomly cropped/resized video clip.\n                size is (C, T, OH, OW)\n        \"\"\"\n        i, j, h, w = self.get_params(clip, self.size)\n        return F.crop(clip, i, j, h, w)\n\n    def __repr__(self) -> str:\n        return f\"{self.__class__.__name__}(size={self.size})\"\n\n\nclass RandomResizedCropVideo(RandomResizedCrop):\n    def __init__(\n        self,\n        size,\n        scale=(0.08, 1.0),\n        ratio=(3.0 / 4.0, 4.0 / 3.0),\n        interpolation_mode=\"bilinear\",\n    ):\n        if isinstance(size, tuple):\n            if len(size) != 2:\n                raise ValueError(\n                    f\"size should be tuple (height, width), instead got {size}\"\n                )\n            self.size = size\n        else:\n            self.size = (size, size)\n\n        self.interpolation_mode = interpolation_mode\n        self.scale = scale\n        self.ratio = ratio\n\n    def __call__(self, clip):\n        \"\"\"\n        Args:\n            clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)\n        Returns:\n            torch.tensor: randomly cropped/resized video clip.\n                size is (C, T, H, W)\n        \"\"\"\n        i, j, h, w = self.get_params(clip, self.scale, self.ratio)\n        return F.resized_crop(clip, i, j, h, w, self.size, self.interpolation_mode)\n\n    def __repr__(self) -> str:\n        return f\"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}, scale={self.scale}, ratio={self.ratio})\"\n\n\nclass CenterCropVideo:\n    def __init__(self, crop_size):\n        if isinstance(crop_size, numbers.Number):\n            self.crop_size = (int(crop_size), int(crop_size))\n        else:\n            self.crop_size = crop_size\n\n    def __call__(self, clip):\n        \"\"\"\n        Args:\n            clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)\n        Returns:\n            torch.tensor: central cropping of video clip. 
Size is\n            (C, T, crop_size, crop_size)\n        \"\"\"\n        return F.center_crop(clip, self.crop_size)\n\n    def __repr__(self) -> str:\n        return f\"{self.__class__.__name__}(crop_size={self.crop_size})\"\n\n\nclass NormalizeVideo:\n    \"\"\"\n    Normalize the video clip by mean subtraction and division by standard deviation\n    Args:\n        mean (3-tuple): pixel RGB mean\n        std (3-tuple): pixel RGB standard deviation\n        inplace (boolean): whether to do in-place normalization\n    \"\"\"\n\n    def __init__(self, mean, std, inplace=False):\n        self.mean = mean\n        self.std = std\n        self.inplace = inplace\n\n    def __call__(self, clip):\n        \"\"\"\n        Args:\n            clip (torch.tensor): video clip to be normalized. Size is (C, T, H, W)\n        \"\"\"\n        return F.normalize(clip, self.mean, self.std, self.inplace)\n\n    def __repr__(self) -> str:\n        return f\"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})\"\n\n\nclass ToTensorVideo:\n    \"\"\"\n    Convert tensor data type from uint8 to float, divide value by 255.0 and\n    permute the dimensions of clip tensor\n    \"\"\"\n\n    def __init__(self):\n        pass\n\n    def __call__(self, clip):\n        \"\"\"\n        Args:\n            clip (torch.tensor, dtype=torch.uint8): Size is (T, H, W, C)\n        Return:\n            clip (torch.tensor, dtype=torch.float): Size is (C, T, H, W)\n        \"\"\"\n        return F.to_tensor(clip)\n\n    def __repr__(self) -> str:\n        return self.__class__.__name__\n\n\nclass RandomHorizontalFlipVideo:\n    \"\"\"\n    Flip the video clip along the horizontal direction with a given probability\n    Args:\n        p (float): probability of the clip being flipped. Default value is 0.5\n    \"\"\"\n\n    def __init__(self, p=0.5):\n        self.p = p\n\n    def __call__(self, clip):\n        \"\"\"\n        Args:\n            clip (torch.tensor): Size is (C, T, H, W)\n        Return:\n            clip (torch.tensor): Size is (C, T, H, W)\n        \"\"\"\n        if random.random() < self.p:\n            clip = F.hflip(clip)\n        return clip\n\n    def __repr__(self) -> str:\n        return f\"{self.__class__.__name__}(p={self.p})\"\n"
  },
  {
    "path": "lavis/projects/albef/eval/nlvr_eval.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: albef_nlvr\n  model_type: nlvr\n\ndatasets:\n  nlvr: # name of the dataset builder\n    vis_processor:\n        eval:\n          name: \"blip_image_eval\"\n          image_size: 384\n    text_processor:\n        eval:\n          name: \"blip_caption\"\n\nrun:\n  task: multimodal_classification\n\n  batch_size_train: 16\n  batch_size_eval: 64\n  num_workers: 4\n\n  seed: 42\n  output_dir: \"output/ALBEF/NLVR\"\n\n  evaluate: True\n  test_splits: [\"val\", \"test\"]\n\n  # distribution-specific\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n"
  },
  {
    "path": "lavis/projects/albef/eval/ret_coco_eval.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: albef_retrieval\n  model_type: coco\n\ndatasets:\n  coco_retrieval: # name of the dataset builder\n    vis_processor:\n        eval:\n          name: \"blip_image_eval\"\n          image_size: 384\n    text_processor:\n        eval:\n          name: \"blip_caption\"\n\nrun:\n  task: retrieval\n\n  # dataloading\n  num_workers: 4\n  batch_size_train: 32\n  batch_size_eval: 64\n\n  test_splits: [\"test\"]\n\n  # distribution\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n  use_dist_eval_sampler: False\n\n  # model specific\n  k_test: 128\n\n  # misc\n  seed: 42\n  output_dir: \"output/ALBEF/Retrieval_COCO\"\n\n  evaluate: True\n"
  },
  {
    "path": "lavis/projects/albef/eval/ret_flickr30k_eval.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: albef_retrieval\n  model_type: flickr\n\ndatasets:\n  flickr30k: # name of the dataset builder\n    vis_processor:\n        eval:\n          name: \"blip_image_eval\"\n          image_size: 384\n    text_processor:\n        eval:\n          name: \"blip_caption\"\n\nrun:\n  task: retrieval\n\n  # dataloading\n  num_workers: 4\n  batch_size_train: 32\n  batch_size_eval: 64\n\n  test_splits: [\"test\"]\n\n  # distribution\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n  use_dist_eval_sampler: False\n\n  # model specific\n  k_test: 128\n\n  # misc\n  seed: 42\n  output_dir: \"output/ALBEF/Retrieval_Flickr30k\"\n\n  evaluate: True\n"
  },
  {
    "path": "lavis/projects/albef/eval/snli_ve_eval.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: albef_classification\n  model_type: ve\n\ndatasets:\n  snli_ve: # name of the dataset builder\n    vis_processor:\n        eval:\n          name: \"blip_image_eval\"\n    text_processor:\n        eval:\n          name: \"blip_caption\"\n\nrun:\n  task: multimodal_classification\n  # optimization-specific\n  batch_size_train: 32\n  batch_size_eval: 64\n  num_workers: 4\n\n  seed: 42\n  output_dir: \"output/ALBEF/SNLI_VE\"\n\n  evaluate: True\n  test_splits: [\"val\", \"test\"]\n\n  # distribution-specific\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n"
  },
  {
    "path": "lavis/projects/albef/eval/vqa_test.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: albef_vqa\n  model_type: vqav2\n\n  image_size: 384\n\n\ndatasets:\n  coco_vqa: # name of the dataset builder\n    vis_processor:\n        eval:\n          name: \"blip_image_eval\"\n          image_size: 384\n    text_processor:\n        eval:\n          name: \"blip_question\"\n\nrun:\n  task: vqa\n\n  # optimization-specific\n  batch_size_train: 16\n  batch_size_eval: 64\n  num_workers: 4\n\n  # inference-specific\n  max_len: 10\n  min_len: 1\n  num_beams: 3\n  num_ans_candidates: 128\n  inference_method: \"rank\"\n\n  seed: 42\n  output_dir: \"output/ALBEF/VQA\"\n\n  evaluate: True\n  train_splits: [\"train\"]\n  test_splits: [\"test\"]\n\n  # distribution-specific\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n"
  },
  {
    "path": "lavis/projects/albef/eval/vqa_val.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: albef_vqa\n  model_type: vqav2\n\n  image_size: 384\n\ndatasets:\n  coco_vqa: # name of the dataset builder\n    type: eval\n    vis_processor:\n        eval:\n          name: \"blip_image_eval\"\n          image_size: 384\n    text_processor:\n        eval:\n          name: \"blip_question\"\n\nrun:\n  task: vqa\n\n  # optimization-specific\n  batch_size_train: 16\n  batch_size_eval: 64\n  num_workers: 4\n\n  # inference-specific\n  max_len: 10\n  min_len: 1\n  num_beams: 3\n  num_ans_candidates: 128\n  inference_method: \"rank\"\n\n  seed: 42\n  output_dir: \"output/ALBEF/VQA\"\n\n  evaluate: True\n  test_splits: [\"val\"]\n\n  # distribution-specific\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n"
  },
  {
    "path": "lavis/projects/albef/train/aokvqa_ft.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: albef_vqa\n  model_type: vqav2\n\n  image_size: 384\n\ndatasets:\n  aok_vqa: # name of the dataset builder\n    vis_processor:\n        train:\n          name: \"blip_image_train\"\n          image_size: 384\n        eval:\n          name: \"blip_image_eval\"\n          image_size: 384\n    text_processor:\n        train:\n          name: \"blip_question\"\n        eval:\n          name: \"blip_question\"\n\nrun:\n  task: aok_vqa\n  # optimization-specific\n  lr_sched: \"linear_warmup_cosine_lr\"\n  init_lr: 2e-5\n  min_lr: 1e-6\n  weight_decay: 0.02\n  max_epoch: 6\n  batch_size_train: 16\n  batch_size_eval: 16\n  num_workers: 4\n\n  # inference-specific\n  max_len: 10\n  min_len: 1\n  num_beams: 256\n  num_ans_candidates: 128\n  inference_method: \"rank\"\n\n  seed: 42\n  output_dir: \"output/BLIP/AOKVQA\"\n\n  amp: False\n  resume_ckpt_path: null\n\n  evaluate: False \n  train_splits: [\"train\"]\n  valid_splits: [\"val\"]\n  test_splits: [\"test\"]\n\n  # distribution-specific\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n"
  },
  {
    "path": "lavis/projects/albef/train/nlvr_ft.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: albef_nlvr\n  model_type: nlvr\n  load_finetuned: False\n\ndatasets:\n  nlvr: # name of the dataset builder\n    vis_processor:\n        train:\n          name: \"blip_image_train\"\n          image_size: 384\n        eval:\n          name: \"blip_image_eval\"\n          image_size: 384\n    text_processor:\n        train:\n          name: \"blip_caption\"\n        eval:\n          name: \"blip_caption\"\n\nrun:\n  task: multimodal_classification\n  # optimization-specific\n  lr_sched: \"linear_warmup_cosine_lr\"\n  init_lr: 2e-5\n  min_lr: 1e-6\n  weight_decay: 0.02\n  warmup_lr: 1e-5\n  warmup_steps: 650\n  max_epoch: 10\n  batch_size_train: 16\n  batch_size_eval: 64\n  num_workers: 4\n\n  seed: 42\n  output_dir: \"output/ALBEF/NLVR\"\n\n  amp: False\n  resume_ckpt_path: null\n\n  evaluate: False \n  train_splits: [\"train\"]\n  valid_splits: [\"val\", \"test\"]\n  test_splits: [\"test\"]\n\n  # distribution-specific\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n"
  },
  {
    "path": "lavis/projects/albef/train/okvqa_ft.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: albef_vqa\n  model_type: vqav2\n\n  image_size: 384\n\ndatasets:\n  ok_vqa: # name of the dataset builder\n    vis_processor:\n        train:\n          name: \"blip_image_train\"\n          image_size: 384\n        eval:\n          name: \"blip_image_eval\"\n          image_size: 384\n    text_processor:\n        train:\n          name: \"blip_question\"\n        eval:\n          name: \"blip_question\"\n\nrun:\n  task: vqa\n  # optimization-specific\n  lr_sched: \"linear_warmup_cosine_lr\"\n  init_lr: 2e-5\n  min_lr: 1e-6\n  weight_decay: 0.02\n  max_epoch: 6\n  batch_size_train: 16\n  batch_size_eval: 16\n  num_workers: 4\n\n  # inference-specific\n  max_len: 10\n  min_len: 1\n  num_beams: 256\n  num_ans_candidates: 128\n  inference_method: \"rank\"\n\n  seed: 42\n  output_dir: \"output/BLIP/OKVQA\"\n\n  amp: False\n  resume_ckpt_path: null\n\n  evaluate: False \n  train_splits: [\"train\"]\n  valid_splits: [\"val\"]\n  test_splits: [\"test\"]\n\n  # distribution-specific\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n"
  },
  {
    "path": "lavis/projects/albef/train/pretrain.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: albef_pretrain\n\n  model_type: base\n  load_pretrained: False\n\n  queue_size: 65536\n\n  image_size: 256\n\n\ndatasets:\n  coco_caption:\n    vis_processor:\n        train:\n          name: \"blip_image_train\"\n          image_size: 256\n    text_processor:\n        train:\n          name: \"blip_caption\"\n  conceptual_caption_3m: # name of the dataset builder\n    vis_processor:\n        train:\n          name: \"blip_image_train\"\n          image_size: 256\n    text_processor:\n        train:\n          name: \"blip_caption\"\n  conceptual_caption_12m: # name of the dataset builder\n    vis_processor:\n        train:\n          name: \"blip_image_train\"\n          image_size: 256\n    text_processor:\n        train:\n          name: \"blip_caption\"\n  vg_caption: # name of the dataset builder\n    vis_processor:\n        train:\n          name: \"blip_image_train\"\n          image_size: 256\n    text_processor:\n        train:\n          name: \"blip_caption\"\n  sbu_caption: # name of the dataset builder\n    vis_processor:\n        train:\n          name: \"blip_image_train\"\n          image_size: 256\n    text_processor:\n        train:\n          name: \"blip_caption\"\n\nrun:\n  task: image_text_pretrain\n  # optimizer\n  lr_sched: \"linear_warmup_step_lr\"\n  # lr_sched: \"linear_warmup_cosine_lr\"\n  init_lr: 3e-4\n  min_lr: 1e-6\n  warmup_lr: 1e-6\n  lr_decay_rate: 0.9\n\n  weight_decay: 0.05\n  max_epoch: 20\n  batch_size_train: 64\n  batch_size_eval: 64\n  num_workers: 4\n  warmup_steps: 3000\n\n  seed: 42\n  output_dir: \"output/ALBEF/Pretrain\"\n\n  amp: False\n  resume_ckpt_path: null\n\n  evaluate: False \n  train_splits: [\"train\"]\n\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n"
  },
  {
    "path": "lavis/projects/albef/train/ret_coco_ft.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: albef_retrieval\n  model_type: coco\n  load_finetuned: False\n\n  queue_size: 65536\n\ndatasets:\n  coco_retrieval: # name of the dataset builder\n    vis_processor:\n        train:\n          name: \"blip_image_train\"\n          image_size: 384\n        eval:\n          name: \"blip_image_eval\"\n          image_size: 384\n    text_processor:\n        train:\n          name: \"blip_caption\"\n        eval:\n          name: \"blip_caption\"\n\nrun:\n  task: retrieval\n  # optimizer\n  lr_sched: \"linear_warmup_cosine_lr\"\n  init_lr: 1e-5\n  min_lr: 1e-6\n  weight_decay: 0.05\n  max_epoch: 5\n\n  # dataloading\n  num_workers: 4\n  batch_size_train: 32\n  batch_size_eval: 64\n\n  train_splits: [\"train\"]\n  valid_splits: [\"val\"]\n  test_splits: [\"test\"]\n\n  # distribution\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n  use_dist_eval_sampler: False\n\n  # model specific\n  k_test: 256\n\n  # misc\n  seed: 42\n  output_dir: \"output/ALBEF/Retrieval_COCO\"\n\n  amp: False\n  resume_ckpt_path: null\n\n  evaluate: False \n"
  },
  {
    "path": "lavis/projects/albef/train/ret_flickr30k_ft.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: albef_retrieval\n  model_type: flickr\n  load_finetuned: False\n\n  queue_size: 65536\n\ndatasets:\n  flickr30k: # name of the dataset builder\n    vis_processor:\n        train:\n          name: \"blip_image_train\"\n          image_size: 384\n        eval:\n          name: \"blip_image_eval\"\n          image_size: 384\n    text_processor:\n        train:\n          name: \"blip_caption\"\n        eval:\n          name: \"blip_caption\"\n\nrun:\n  task: retrieval\n  # optimizer\n  lr_sched: \"linear_warmup_cosine_lr\"\n  init_lr: 1e-5\n  min_lr: 1e-6\n  weight_decay: 0.02\n  max_epoch: 10\n\n  # dataloading\n  num_workers: 4\n  batch_size_train: 32\n  batch_size_eval: 64\n\n  train_splits: [\"train\"]\n  valid_splits: [\"val\", \"test\"]\n  test_splits: [\"test\"]\n\n  # distribution\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n  use_dist_eval_sampler: False\n\n  # model specific\n  k_test: 128\n\n  # misc\n  seed: 42\n  output_dir: \"output/ALBEF/Retrieval_Flickr30k\"\n\n  amp: False\n  resume_ckpt_path: null\n\n  evaluate: False \n"
  },
  {
    "path": "lavis/projects/albef/train/snli_ve_ft.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: albef_classification\n  model_type: ve\n  load_finetuned: False\n  num_classes: 3\n\ndatasets:\n  snli_ve: # name of the dataset builder\n    vis_processor:\n        train:\n          name: \"blip_image_train\"\n        eval:\n          name: \"blip_image_eval\"\n    text_processor:\n        train:\n          name: \"blip_caption\"\n        eval:\n          name: \"blip_caption\"\n\nrun:\n  task: multimodal_classification\n  # optimization-specific\n  lr_sched: \"linear_warmup_cosine_lr\"\n  init_lr: 2e-5\n  min_lr: 0\n  weight_decay: 0.05\n  max_epoch: 10\n  batch_size_train: 32\n  batch_size_eval: 64\n  num_workers: 4\n\n  seed: 42\n  output_dir: \"output/ALBEF/SNLI_VE\"\n\n  amp: False\n  resume_ckpt_path: null\n\n  evaluate: False \n  train_splits: [\"train\"]\n  valid_splits: [\"val\"]\n  test_splits: [\"test\"]\n\n  # distribution-specific\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n"
  },
  {
    "path": "lavis/projects/albef/train/vqa_ft.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: albef_vqa\n  model_type: vqav2\n  load_finetuned: False\n\n  image_size: 384\n\ndatasets:\n  coco_vqa: # name of the dataset builder\n    vis_processor:\n        train:\n          name: \"blip_image_train\"\n          image_size: 384\n        eval:\n          name: \"blip_image_eval\"\n          image_size: 384\n    text_processor:\n        train:\n          name: \"blip_question\"\n        eval:\n          name: \"blip_question\"\n  vg_vqa: # name of the dataset builder\n    vis_processor:\n        train:\n          name: \"blip_image_train\"\n          image_size: 384\n    text_processor:\n        train:\n          name: \"blip_question\"\n\nrun:\n  task: vqa\n  # optimization-specific\n  lr_sched: \"linear_warmup_cosine_lr\"\n  init_lr: 2e-5\n  min_lr: 1e-6\n  weight_decay: 0.02\n  max_epoch: 8\n  batch_size_train: 32\n  batch_size_eval: 64\n  num_workers: 4\n\n  # inference-specific\n  max_len: 10\n  min_len: 1\n  num_beams: 3\n  num_ans_candidates: 128\n  inference_method: \"rank\"\n\n  seed: 42\n  output_dir: \"output/ALBEF/VQA\"\n\n  amp: False\n  resume_ckpt_path: null\n\n  evaluate: False \n  train_splits: [\"train\"]\n  test_splits: [\"test\"]\n\n  # distribution-specific\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n"
  },
  {
    "path": "lavis/projects/alpro/eval/didemo_ret_eval.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: alpro_retrieval\n  model_type: didemo\n\n  max_txt_len: 50\n\n  timesformer:\n    n_frms: 8\n    image_size: 224\n\n\ndatasets:\n  didemo_retrieval: # name of the dataset builder\n    vis_processor:\n        eval:\n          name: \"alpro_video_eval\"\n          n_frms: 8\n          image_size: 224\n    text_processor:\n        eval:\n          name: \"blip_caption\"\n\nrun:\n  task: retrieval\n  # optimization-specific\n  batch_size_train: 8\n  batch_size_eval: 64\n  num_workers: 4\n\n  # k_test: 256\n  k_test: 1000\n\n  seed: 42\n  output_dir: \"output/ALPRO/didemo_retrieval\"\n\n  evaluate: True\n  train_splits: [\"train\"]\n  valid_splits: [\"val\", \"test\"]\n  test_splits: [\"test\"]\n\n  # distribution-specific\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n  use_dist_eval_sampler: False\n"
  },
  {
    "path": "lavis/projects/alpro/eval/msrvtt_qa_eval.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: alpro_qa\n  model_type: msrvtt\n\ndatasets:\n  msrvtt_qa: # name of the dataset builder\n    vis_processor:\n        eval:\n          name: \"alpro_video_eval\"\n          n_frms: 16\n          image_size: 224\n    text_processor:\n        eval:\n          name: \"blip_caption\"\n\nrun:\n  task: multimodal_classification\n  # optimization-specific\n  batch_size_train: 32\n  batch_size_eval: 64\n  num_workers: 4\n\n  seed: 42\n  output_dir: \"output/ALPRO/msrvtt_qa\"\n\n  evaluate: True\n  valid_splits: [\"val\"]\n  test_splits: [\"test\"]\n\n  # distribution-specific\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n"
  },
  {
    "path": "lavis/projects/alpro/eval/msrvtt_ret_eval.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: alpro_retrieval\n  model_type: msrvtt\n\ndatasets:\n  msrvtt_retrieval: # name of the dataset builder\n    vis_processor:\n        eval:\n          name: \"alpro_video_eval\"\n          n_frms: 8\n          image_size: 224\n    text_processor:\n        eval:\n          name: \"blip_caption\"\n\nrun:\n  task: retrieval\n  # optimization-specific\n  batch_size_train: 24\n  batch_size_eval: 64\n  num_workers: 4\n\n  # k_test: 256\n  k_test: 1000\n\n  seed: 42\n  output_dir: \"output/ALPRO/msrvtt_retrieval\"\n\n  evaluate: True\n  test_splits: [\"test\"]\n\n  # distribution-specific\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n  use_dist_eval_sampler: False\n"
  },
  {
    "path": "lavis/projects/alpro/eval/msvd_qa_eval.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: alpro_qa\n  model_type: msvd\n\ndatasets:\n  msvd_qa: # name of the dataset builder\n    vis_processor:\n        eval:\n          name: \"alpro_video_eval\"\n          n_frms: 16\n          image_size: 224\n    text_processor:\n        train:\n          name: \"blip_caption\"\n        eval:\n          name: \"blip_caption\"\n\nrun:\n  task: multimodal_classification\n  # optimization-specific\n  batch_size_train: 24\n  batch_size_eval: 64\n  num_workers: 4\n\n  seed: 42\n  output_dir: \"output/ALPRO/msvd_qa\"\n\n  evaluate: True\n  test_splits: [\"test\"]\n\n  # distribution-specific\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n"
  },
  {
    "path": "lavis/projects/alpro/train/didemo_ret_ft.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: alpro_retrieval\n  model_type: didemo\n  load_finetuned: False\n\n  max_txt_len: 50\n\n  timesformer:\n    n_frms: 8\n    image_size: 224\n\n\ndatasets:\n  didemo_retrieval: # name of the dataset builder\n    vis_processor:\n        train:\n          name: \"alpro_video_train\"\n          n_frms: 8\n          image_size: 224\n        eval:\n          name: \"alpro_video_eval\"\n          n_frms: 8\n          image_size: 224\n    text_processor:\n        train:\n          name: \"blip_caption\"\n        eval:\n          name: \"blip_caption\"\n\nrun:\n  task: retrieval\n  # optimization-specific\n  lr_sched: \"linear_warmup_cosine_lr\"\n  init_lr: 3e-5\n  min_lr: 1e-6\n  weight_decay: 1e-4\n  max_epoch: 10\n  batch_size_train: 12\n  batch_size_eval: 32\n  num_workers: 4\n\n  k_test: 256\n  # k_test: 1000\n\n  seed: 42\n  output_dir: \"output/ALPRO/didemo_retrieval\"\n\n  amp: False\n  resume_ckpt_path: null\n\n  evaluate: False \n  train_splits: [\"train\"]\n  valid_splits: [\"val\"]\n  test_splits: [\"test\"]\n\n  # distribution-specific\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n  use_dist_eval_sampler: False\n"
  },
  {
    "path": "lavis/projects/alpro/train/msrvtt_qa_ft.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: alpro_qa\n  model_type: msrvtt\n  load_finetuned: False\n\n  num_classes: 1500\n\n  timesformer:\n    use_grad_ckpt: True\n    ckpt_layer: 12\n\n\ndatasets:\n  msrvtt_qa: # name of the dataset builder\n    vis_processor:\n        train:\n          name: \"alpro_video_train\"\n          n_frms: 16\n          image_size: 224\n        eval:\n          name: \"alpro_video_eval\"\n          n_frms: 16\n          image_size: 224\n    text_processor:\n        train:\n          name: \"blip_caption\"\n        eval:\n          name: \"blip_caption\"\n\nrun:\n  task: multimodal_classification\n  # optimization-specific\n  lr_sched: \"linear_warmup_cosine_lr\"\n  init_lr: 5e-5\n  min_lr: 1e-6\n  weight_decay: 1e-4\n  max_epoch: 10\n  batch_size_train: 24\n  batch_size_eval: 64\n  num_workers: 4\n\n  seed: 42\n  output_dir: \"output/ALPRO/msrvtt_qa\"\n\n  amp: False\n  resume_ckpt_path: null\n\n  evaluate: False \n  train_splits: [\"train\"]\n  valid_splits: [\"val\", \"test\"]\n  test_splits: [\"test\"]\n\n  # distribution-specific\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n"
  },
  {
    "path": "lavis/projects/alpro/train/msrvtt_retrieval_ft.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: alpro_retrieval\n  model_type: msrvtt\n  load_finetuned: False\n\ndatasets:\n  msrvtt_retrieval: # name of the dataset builder\n    vis_processor:\n        train:\n          name: \"alpro_video_train\"\n          n_frms: 8\n          image_size: 224\n        eval:\n          name: \"alpro_video_eval\"\n          n_frms: 8\n          image_size: 224\n    text_processor:\n        train:\n          name: \"blip_caption\"\n        eval:\n          name: \"blip_caption\"\n\nrun:\n  task: retrieval\n  # optimization-specific\n  lr_sched: \"linear_warmup_cosine_lr\"\n  init_lr: 3e-5\n  min_lr: 1e-6\n  weight_decay: 1e-4\n  max_epoch: 5\n  batch_size_train: 8\n  batch_size_eval: 8\n  num_workers: 4\n\n  k_test: 1000\n\n  seed: 42\n  output_dir: \"output/ALPRO/msrvtt_retrieval\"\n\n  amp: False\n  resume_ckpt_path: null\n\n  evaluate: False \n  train_splits: [\"train\"]\n  valid_splits: [\"val\"]\n  test_splits: [\"test\"]\n\n  # distribution-specific\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n  use_dist_eval_sampler: False\n"
  },
  {
    "path": "lavis/projects/alpro/train/msvd_qa_ft.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: alpro_qa\n  model_type: msvd\n  load_finetuned: False\n\n  num_classes: 2423\n\n  timesformer:\n    use_grad_ckpt: True\n    ckpt_layer: 12\n\n\ndatasets:\n  msvd_qa: # name of the dataset builder\n    vis_processor:\n        train:\n          name: \"alpro_video_train\"\n          n_frms: 16\n          image_size: 224\n        eval:\n          name: \"alpro_video_eval\"\n          n_frms: 16\n          image_size: 224\n    text_processor:\n        train:\n          name: \"blip_caption\"\n        eval:\n          name: \"blip_caption\"\n\nrun:\n  task: multimodal_classification\n  # optimization-specific\n  lr_sched: \"linear_warmup_cosine_lr\"\n  init_lr: 5e-5\n  min_lr: 1e-6\n  weight_decay: 1e-4\n  max_epoch: 10\n  batch_size_train: 24\n  batch_size_eval: 64\n  num_workers: 4\n\n  seed: 42\n  output_dir: \"output/ALPRO/msvd_qa\"\n\n  amp: False\n  resume_ckpt_path: null\n\n  evaluate: False \n  train_splits: [\"train\"]\n  valid_splits: [\"val\", \"test\"]\n  test_splits: [\"test\"]\n\n  # distribution-specific\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n"
  },
  {
    "path": "lavis/projects/blip/coco_cap_ft_iter.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: blip_caption\n  model_type: large\n\ndatasets:\n  coco_caption: # name of the dataset builder\n    vis_processor:\n        train:\n          name: \"blip_image_train\"\n        eval:\n          name: \"blip_image_eval\"\n    text_processor:\n        train:\n          name: \"blip_caption\"\n          prompt: \"a picture of \"\n        eval:\n          name: \"blip_caption\"\n\nrun:\n  runner: runner_iter\n\n  max_iters: 2e4\n  iters_per_inner_epoch: 2e3\n\n  # task: retrieval\n  task: captioning\n  # optimizer\n  lr_sched: \"linear_warmup_cosine_lr\"\n  init_lr: 2e-6\n  min_lr: 0\n  weight_decay: 0.05\n  batch_size_train: 16\n  batch_size_eval: 64\n  num_workers: 4\n\n  max_len: 20\n  min_len: 5\n  num_beams: 3\n\n  seed: 42\n  output_dir: \"output/BLIP/Caption_coco\"\n\n  amp: False\n  resume_ckpt_path: null\n\n  evaluate: False \n  train_splits: [\"train\"]\n  valid_splits: [\"val\", \"test\"]\n\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n"
  },
  {
    "path": "lavis/projects/blip/eval/aokvqa_eval.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: blip_vqa\n  model_type: aokvqa\n  image_size: 480\n\ndatasets:\n  aok_vqa: # name of the dataset builder\n    vis_processor:\n        eval:\n          name: \"blip_image_eval\"\n          image_size: 480\n    text_processor:\n        eval:\n          name: \"blip_question\"\n\nrun:\n  task: aok_vqa\n  # optimization-specific\n  batch_size_train: 64\n  batch_size_eval: 64\n  num_workers: 4\n\n  # inference-specific\n  max_len: 10\n  min_len: 1\n  num_beams: 3\n  num_ans_candidates: 128\n  inference_method: \"rank\"\n\n  seed: 42\n  output_dir: \"output/BLIP/AOKVQA\"\n\n  evaluate: True\n  test_splits: [\"val\", \"test\"]\n\n  # distribution-specific\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n"
  },
  {
    "path": "lavis/projects/blip/eval/caption_coco_eval.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: blip_caption\n  model_type: base_coco\n\ndatasets:\n  coco_caption: # name of the dataset builder\n    vis_processor:\n        eval:\n          name: \"blip_image_eval\"\n    text_processor:\n        eval:\n          name: \"blip_caption\"\n\nrun:\n  # task: retrieval\n  task: captioning\n  # optimizer\n  batch_size_train: 32\n  batch_size_eval: 64\n  num_workers: 4\n\n  max_len: 20\n  min_len: 5\n  num_beams: 3\n\n  seed: 42\n  output_dir: \"output/BLIP/Caption_coco\"\n\n  evaluate: True\n  test_splits: [\"test\"]\n\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n"
  },
  {
    "path": "lavis/projects/blip/eval/caption_coco_eval_large.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: blip_caption\n  model_type: large_coco\n\ndatasets:\n  coco_caption: # name of the dataset builder\n    vis_processor:\n        eval:\n          name: \"blip_image_eval\"\n    text_processor:\n        eval:\n          name: \"blip_caption\"\n\nrun:\n  # task: retrieval\n  task: captioning\n  # optimizer\n  batch_size_train: 32\n  batch_size_eval: 64\n  num_workers: 4\n\n  max_len: 20\n  min_len: 5\n  num_beams: 3\n\n  seed: 42\n  output_dir: \"output/BLIP/Caption_coco\"\n\n  evaluate: True\n  test_splits: [\"test\"]\n\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n"
  },
  {
    "path": "lavis/projects/blip/eval/nlvr_eval.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: blip_nlvr\n  model_type: nlvr\n\ndatasets:\n  nlvr: # name of the dataset builder\n    vis_processor:\n        eval:\n          name: \"blip_image_eval\"\n          image_size: 384\n    text_processor:\n        eval:\n          name: \"blip_caption\"\n\nrun:\n  task: multimodal_classification\n\n  batch_size_train: 16\n  batch_size_eval: 64\n  num_workers: 4\n\n  seed: 42\n  output_dir: \"output/BLIP/NLVR\"\n\n  evaluate: True\n  test_splits: [\"val\", \"test\"]\n\n  # distribution-specific\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n"
  },
  {
    "path": "lavis/projects/blip/eval/nocaps_eval.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: blip_caption\n  model_type: base_coco\n  # pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'\n\ndatasets:\n  nocaps: # name of the dataset builder\n    vis_processor:\n        eval:\n          name: \"blip_image_eval\"\n          image_size: 384\n    text_processor:\n        eval:\n          name: \"blip_caption\"\n          prompt: \"a picture of \"\n\nrun:\n  # task: retrieval\n  task: captioning\n  # optimizer\n  batch_size_train: 32\n  batch_size_eval: 64\n  num_workers: 4\n\n  max_len: 20\n  min_len: 5\n  num_beams: 3\n\n  seed: 42\n  output_dir: \"output/BLIP/NoCaps\"\n\n  evaluate: True\n  test_splits: [\"val\", \"test\"]\n\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n\n  report_metric: False\n"
  },
  {
    "path": "lavis/projects/blip/eval/okvqa_eval.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: blip_vqa\n  model_type: okvqa\n  image_size: 480\n\ndatasets:\n  ok_vqa: # name of the dataset builder\n    vis_processor:\n        eval:\n          name: \"blip_image_eval\"\n          image_size: 480\n    text_processor:\n        eval:\n          name: \"blip_question\"\n\nrun:\n  task: vqa\n  # optimization-specific\n  batch_size_train: 16\n  batch_size_eval: 16\n  num_workers: 4\n\n  # inference-specific\n  max_len: 10\n  min_len: 1\n  num_beams: 3\n  num_ans_candidates: 128\n  inference_method: \"rank\"\n\n  seed: 42\n  output_dir: \"output/BLIP/OKVQA\"\n\n  evaluate: True\n  test_splits: [\"test\"]\n\n  # distribution-specific\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n"
  },
  {
    "path": "lavis/projects/blip/eval/ret_coco_eval.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: blip_retrieval\n  model_type: coco\n\ndatasets:\n  coco_retrieval: # name of the dataset builder\n    vis_processor:\n        train:\n          name: \"blip_image_train\"\n          image_size: 384\n        eval:\n          name: \"blip_image_eval\"\n          image_size: 384\n    text_processor:\n        train:\n          name: \"blip_caption\"\n        eval:\n          name: \"blip_caption\"\n\nrun:\n  task: retrieval\n\n  # dataloading\n  num_workers: 4\n  batch_size_train: 32\n  batch_size_eval: 128\n\n  train_splits: [\"train\"]\n  valid_splits: [\"val\"]\n  test_splits: [\"test\"]\n\n  # distribution\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n  use_dist_eval_sampler: False\n\n  # model specific\n  k_test: 256\n\n  # misc\n  seed: 42\n  output_dir: \"output/BLIP/Retrieval_COCO\"\n\n  evaluate: True\n"
  },
  {
    "path": "lavis/projects/blip/eval/ret_flickr_eval.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: blip_retrieval\n  model_type: flickr\n\ndatasets:\n  flickr30k: # name of the dataset builder\n    vis_processor:\n        eval:\n          name: \"blip_image_eval\"\n          image_size: 384\n    text_processor:\n        eval:\n          name: \"blip_caption\"\n\nrun:\n  task: retrieval\n\n  # dataloading\n  num_workers: 4\n  batch_size_train: 32\n  batch_size_eval: 64\n\n  test_splits: [\"test\"]\n\n  # distribution\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n  use_dist_eval_sampler: False\n\n  # model specific\n  k_test: 128\n\n  # misc\n  seed: 42\n  output_dir: \"output/Retrieval_Flickr30k\"\n\n  evaluate: True\n"
  },
  {
    "path": "lavis/projects/blip/eval/vqav2_eval.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: blip_vqa\n  model_type: vqav2\n  image_size: 480\n\ndatasets:\n  coco_vqa: # name of the dataset builder\n    type: eval\n    vis_processor:\n        eval:\n          name: \"blip_image_eval\"\n          image_size: 480\n    text_processor:\n        eval:\n          name: \"blip_question\"\n\nrun:\n  task: vqa\n  # optimization-specific\n  batch_size_train: 16\n  batch_size_eval: 64\n  num_workers: 4\n\n  # inference-specific\n  max_len: 10\n  min_len: 1\n  num_beams: 3\n  num_ans_candidates: 128\n  inference_method: \"rank\"\n\n  seed: 42\n  output_dir: \"output/BLIP/VQA\"\n\n  evaluate: True\n  test_splits: [\"val\"]\n\n  # distribution-specific\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n"
  },
  {
    "path": "lavis/projects/blip/train/aokvqa_ft.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: blip_vqa\n\n  model_type: aokvqa\n  load_finetuned: False\n\n  image_size: 480\n\ndatasets:\n  aok_vqa: # name of the dataset builder\n    vis_processor:\n        train:\n          name: \"blip_image_train\"\n          image_size: 480\n        eval:\n          name: \"blip_image_eval\"\n          image_size: 480\n    text_processor:\n        train:\n          name: \"blip_question\"\n        eval:\n          name: \"blip_question\"\n\nrun:\n  task: aok_vqa\n  # optimization-specific\n  lr_sched: \"linear_warmup_cosine_lr\"\n  init_lr: 2e-5\n  min_lr: 1e-5\n  weight_decay: 0.02\n  max_epoch: 7\n  batch_size_train: 16\n  batch_size_eval: 16\n  num_workers: 4\n\n  # inference-specific\n  max_len: 10\n  min_len: 1\n  num_beams: 256\n  num_ans_candidates: 128\n  inference_method: \"rank\"\n\n  seed: 42\n  output_dir: \"output/BLIP/AOKVQA\"\n\n  amp: False\n  resume_ckpt_path: null\n\n  evaluate: False \n  train_splits: [\"train\"]\n  valid_splits: [\"val\"]\n  test_splits: [\"test\"]\n\n  # distribution-specific\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n"
  },
  {
    "path": "lavis/projects/blip/train/caption_coco_ft.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: blip_caption\n\n  model_type: base_coco\n  load_finetuned: False\n\ndatasets:\n  coco_caption: # name of the dataset builder\n    vis_processor:\n        train:\n          name: \"blip_image_train\"\n        eval:\n          name: \"blip_image_eval\"\n    text_processor:\n        train:\n          name: \"blip_caption\"\n          prompt: \"a picture of \"\n        eval:\n          name: \"blip_caption\"\n\nrun:\n  # task: retrieval\n  task: captioning\n  # optimizer\n  lr_sched: \"linear_warmup_cosine_lr\"\n  init_lr: 1e-5\n  min_lr: 0\n  weight_decay: 0.05\n  max_epoch: 5\n  batch_size_train: 32\n  batch_size_eval: 64\n  num_workers: 4\n\n  max_len: 20\n  min_len: 5\n  num_beams: 3\n\n  seed: 42\n  output_dir: \"output/BLIP/Caption_coco\"\n\n  amp: False\n  resume_ckpt_path: null\n\n  evaluate: False \n  train_splits: [\"train\"]\n  valid_splits: [\"val\"]\n  test_splits: [\"test\"]\n\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n"
  },
  {
    "path": "lavis/projects/blip/train/caption_coco_large_ft.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: blip_caption\n\n  model_type: large_coco\n  load_finetuned: False\n\ndatasets:\n  coco_caption: # name of the dataset builder\n    vis_processor:\n        train:\n          name: \"blip_image_train\"\n        eval:\n          name: \"blip_image_eval\"\n    text_processor:\n        train:\n          name: \"blip_caption\"\n          prompt: \"a picture of \"\n        eval:\n          name: \"blip_caption\"\n\nrun:\n  task: captioning\n  # optimizer\n  lr_sched: \"linear_warmup_cosine_lr\"\n  init_lr: 2e-6\n  min_lr: 0\n  weight_decay: 0.05\n  max_epoch: 5\n  batch_size_train: 16\n  batch_size_eval: 64\n  num_workers: 4\n\n  max_len: 20\n  min_len: 5\n  num_beams: 3\n\n  seed: 42\n  output_dir: \"output/BLIP/Caption_coco\"\n\n  amp: False\n  resume_ckpt_path: null\n\n  evaluate: False \n  train_splits: [\"train\"]\n  valid_splits: [\"val\"]\n  test_splits: [\"test\"]\n\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n"
  },
  {
    "path": "lavis/projects/blip/train/nlvr_ft.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: blip_nlvr\n\n  model_type: nlvr\n  load_finetuned: False\n\ndatasets:\n  nlvr: # name of the dataset builder\n    vis_processor:\n        train:\n          name: \"blip_image_train\"\n          image_size: 384\n        eval:\n          name: \"blip_image_eval\"\n          image_size: 384\n    text_processor:\n        train:\n          name: \"blip_caption\"\n        eval:\n          name: \"blip_caption\"\n\nrun:\n  task: multimodal_classification\n\n  lr_sched: \"linear_warmup_cosine_lr\"\n  init_lr: 2.5e-5\n  min_lr: 0\n  weight_decay: 0.05\n  max_epoch: 15\n\n  batch_size_train: 16\n  batch_size_eval: 64\n  num_workers: 4\n\n  seed: 42\n  output_dir: \"output/BLIP/NLVR\"\n\n  amp: False\n  resume_ckpt_path: null\n\n  evaluate: False \n  train_splits: [\"train\"]\n  valid_splits: [\"val\", \"test\"]\n  test_splits: [\"test\"]\n\n  # distribution-specific\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n"
  },
  {
    "path": "lavis/projects/blip/train/okvqa_ft.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: blip_vqa\n\n  model_type: okvqa\n  load_finetuned: False\n\n  image_size: 480\n\ndatasets:\n  ok_vqa: # name of the dataset builder\n    vis_processor:\n        train:\n          name: \"blip_image_train\"\n          image_size: 480\n        eval:\n          name: \"blip_image_eval\"\n          image_size: 480\n    text_processor:\n        train:\n          name: \"blip_question\"\n        eval:\n          name: \"blip_question\"\n\nrun:\n  task: vqa\n  # optimization-specific\n  lr_sched: \"linear_warmup_cosine_lr\"\n  init_lr: 3e-5\n  min_lr: 1e-5\n  weight_decay: 0.02\n  max_epoch: 7\n  batch_size_train: 16\n  batch_size_eval: 16\n  num_workers: 4\n\n  # inference-specific\n  max_len: 10\n  min_len: 1\n  num_beams: 256\n  num_ans_candidates: 128\n  inference_method: \"rank\"\n\n  seed: 42\n  output_dir: \"output/BLIP/OKVQA\"\n\n  amp: False\n  resume_ckpt_path: null\n\n  evaluate: False \n  train_splits: [\"train\"]\n  test_splits: [\"test\"]\n\n  # distribution-specific\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n"
  },
  {
    "path": "lavis/projects/blip/train/pretrain_14m.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: blip_pretrain\n\n  model_type: base\n  load_pretrained: False\n\n  queue_size: 57600\n  alpha: 0.4\n\ndatasets:\n  coco_caption:\n    vis_processor:\n        train:\n          name: \"blip_image_train\"\n          image_size: 224\n    text_processor:\n        train:\n          name: \"blip_caption\"\n  conceptual_caption_3m: # name of the dataset builder\n    vis_processor:\n        train:\n          name: \"blip_image_train\"\n          image_size: 224\n    text_processor:\n        train:\n          name: \"blip_caption\"\n  conceptual_caption_12m: # name of the dataset builder\n    vis_processor:\n        train:\n          name: \"blip_image_train\"\n          image_size: 224\n    text_processor:\n        train:\n          name: \"blip_caption\"\n  vg_caption: # name of the dataset builder\n    vis_processor:\n        train:\n          name: \"blip_image_train\"\n          image_size: 224\n    text_processor:\n        train:\n          name: \"blip_caption\"\n  sbu_caption: # name of the dataset builder\n    vis_processor:\n        train:\n          name: \"blip_image_train\"\n          image_size: 224\n    text_processor:\n        train:\n          name: \"blip_caption\"\n\nrun:\n  task: image_text_pretrain\n  # optimizer\n  lr_sched: \"linear_warmup_step_lr\"\n  # lr_sched: \"linear_warmup_cosine_lr\"\n  init_lr: 3e-4\n  min_lr: 1e-6\n  warmup_lr: 1e-6\n  lr_decay_rate: 0.9\n\n  weight_decay: 0.05\n  max_epoch: 20\n  batch_size_train: 75\n  batch_size_eval: 75\n  num_workers: 4\n  warmup_steps: 3000\n\n  seed: 42\n  output_dir: \"output/BLIP/Pretrain\"\n\n  amp: False\n  resume_ckpt_path: null\n\n  evaluate: False \n  train_splits: [\"train\"]\n\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n"
  },
  {
    "path": "lavis/projects/blip/train/retrieval_coco_ft.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: blip_retrieval\n\n  model_type: coco\n  load_finetuned: False\n\n  queue_size: 57600\n  alpha: 0.4\n\n  negative_all_rank: True\n\ndatasets:\n  coco_retrieval: # name of the dataset builder\n    vis_processor:\n        train:\n          name: \"blip_image_train\"\n          image_size: 384\n        eval:\n          name: \"blip_image_eval\"\n          image_size: 384\n    text_processor:\n        train:\n          name: \"blip_caption\"\n        eval:\n          name: \"blip_caption\"\n\nrun:\n  task: retrieval\n  # optimizer\n  lr_sched: \"linear_warmup_cosine_lr\"\n  init_lr: 2e-5\n  min_lr: 0\n  weight_decay: 0.04\n  max_epoch: 6\n\n  # dataloading\n  num_workers: 4\n  batch_size_train: 32\n  batch_size_eval: 128\n\n  train_splits: [\"train\"]\n  valid_splits: [\"val\", \"test\"]\n  test_splits: [\"test\"]\n\n  # distribution\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n  use_dist_eval_sampler: False\n\n  # model specific\n  k_test: 256\n\n  # misc\n  seed: 42\n  output_dir: \"output/BLIP/Retrieval_COCO\"\n\n  amp: False\n  resume_ckpt_path: null\n\n  evaluate: False \n"
  },
  {
    "path": "lavis/projects/blip/train/retrieval_flickr_ft.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: blip_retrieval\n\n  model_type: flickr\n  load_finetuned: False\n\n  queue_size: 57600\n\n  negative_all_rank: False\n\ndatasets:\n  flickr30k: # name of the dataset builder\n    vis_processor:\n        train:\n          name: \"blip_image_train\"\n          image_size: 384\n        eval:\n          name: \"blip_image_eval\"\n          image_size: 384\n    text_processor:\n        train:\n          name: \"blip_caption\"\n        eval:\n          name: \"blip_caption\"\n\nrun:\n  task: retrieval\n  # optimizer\n  lr_sched: \"linear_warmup_cosine_lr\"\n  init_lr: 1e-5\n  min_lr: 0\n  weight_decay: 0.05\n  max_epoch: 6\n\n  # dataloading\n  num_workers: 4\n  batch_size_train: 32\n  batch_size_eval: 64\n\n  train_splits: [\"train\"]\n  valid_splits: [\"val\", \"test\"]\n  test_splits: [\"test\"]\n\n  # distribution\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n  use_dist_eval_sampler: False\n\n  # model specific\n  k_test: 128\n\n  # misc\n  seed: 42\n  output_dir: \"output/Retrieval_Flickr30k\"\n\n  amp: False\n  resume_ckpt_path: null\n\n  evaluate: False \n"
  },
  {
    "path": "lavis/projects/blip/train/vqav2_ft.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: blip_vqa\n\n  model_type: vqav2\n  load_finetuned: False\n\n  image_size: 480\n\ndatasets:\n  coco_vqa: # name of the dataset builder\n    vis_processor:\n        train:\n          name: \"blip_image_train\"\n          image_size: 480\n        eval:\n          name: \"blip_image_eval\"\n          image_size: 480\n    text_processor:\n        train:\n          name: \"blip_question\"\n        eval:\n          name: \"blip_question\"\n  vg_vqa: # name of the dataset builder\n    vis_processor:\n        train:\n          name: \"blip_image_train\"\n          image_size: 480\n    text_processor:\n        train:\n          name: \"blip_question\"\n\nrun:\n  task: vqa\n  # optimization-specific\n  lr_sched: \"linear_warmup_cosine_lr\"\n  init_lr: 2e-5\n  min_lr: 0\n  weight_decay: 0.05\n  max_epoch: 10\n  batch_size_train: 16\n  batch_size_eval: 64\n  num_workers: 4\n\n  # inference-specific\n  max_len: 10\n  min_len: 1\n  num_beams: 3\n  num_ans_candidates: 128\n  inference_method: \"rank\"\n\n  seed: 42\n  output_dir: \"output/BLIP/VQA\"\n\n  amp: False\n  resume_ckpt_path: null\n\n  evaluate: False \n  train_splits: [\"train\"]\n  test_splits: [\"test\"]\n\n  # distribution-specific\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n"
  },
  {
    "path": "lavis/projects/blip2/eval/caption_coco_flant5xl_eval.yaml",
    "content": ""
  },
  {
    "path": "lavis/projects/blip2/eval/caption_coco_opt2.7b_eval.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\n# Bleu_1: 0.832\n# Bleu_2: 0.691\n# Bleu_3: 0.556\n# Bleu_4: 0.438\n# METEOR: 0.317\n# ROUGE_L: 0.620\n# CIDEr: 1.461\n# SPICE: 0.252\n\nmodel:\n  arch: blip2_opt\n  model_type: caption_coco_opt2.7b\n  use_grad_checkpoint: False\n\ndatasets:\n  coco_caption: # name of the dataset builder\n    vis_processor:\n        eval:\n          name: \"blip_image_eval\"\n          image_size: 364\n    text_processor:\n        eval:\n          name: \"blip_caption\"\n#     build_info:\n#         images:\n#             storage: '/export/share/datasets/vision/coco/images/'\n\nrun:\n  task: captioning\n  # optimizer\n  batch_size_train: 32\n  batch_size_eval: 16\n  num_workers: 4\n\n  max_len: 30\n  min_len: 8\n  num_beams: 5\n\n  seed: 42\n  output_dir: \"output/BLIP2/Caption_coco_opt2.7b\"\n\n  evaluate: True\n  test_splits: [\"test\"]\n\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n"
  },
  {
    "path": "lavis/projects/blip2/eval/caption_coco_opt6.7b_eval.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\n# Bleu_1: 0.831\n# Bleu_2: 0.689\n# Bleu_3: 0.552\n# Bleu_4: 0.434\n# METEOR: 0.316\n# ROUGE_L: 0.618\n# CIDEr: 1.451\n# SPICE: 0.251\n\nmodel:\n  arch: blip2_opt\n  model_type: caption_coco_opt6.7b\n  use_grad_checkpoint: False\n\ndatasets:\n  coco_caption: # name of the dataset builder\n    vis_processor:\n        eval:\n          name: \"blip_image_eval\"\n          image_size: 364\n    text_processor:\n        eval:\n          name: \"blip_caption\"\n#     build_info:\n#         images:\n#             storage: '/export/share/datasets/vision/coco/images/'\n\nrun:\n  task: captioning\n  # optimizer\n  batch_size_train: 32\n  batch_size_eval: 16\n  num_workers: 4\n\n  max_len: 30\n  min_len: 8\n  num_beams: 5\n\n  seed: 42\n  output_dir: \"output/BLIP2/Caption_coco_opt6.7b\"\n\n  evaluate: True\n  test_splits: [\"test\"]\n\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n"
  },
  {
    "path": "lavis/projects/blip2/eval/gqa_zeroshot_flant5xl_eval.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\n# Overall Accuracy is: 43.98\nmodel:\n  arch: blip2_t5\n  model_type: pretrain_flant5xl\n  use_grad_checkpoint: False\n\ndatasets:\n  gqa: # name of the dataset builder\n    type: balanced_testdev\n    vis_processor:\n        eval:\n          name: \"blip_image_eval\"\n          image_size: 224\n    text_processor:\n        eval:\n          name: \"blip_question\"\n    build_info:\n        images:\n            storage: \"/export/share/datasets/vision/GQA/images/\"\n\nrun:\n  task: gqa\n  # optimization-specific\n  batch_size_train: 16\n  batch_size_eval: 64\n  num_workers: 4\n\n  # inference-specific\n  max_len: 10\n  min_len: 1\n  num_beams: 5\n  inference_method: \"generate\"\n  prompt: \"Question: {} Short answer:\"\n\n  seed: 42\n  output_dir: \"output/BLIP2/GQA\"\n\n  evaluate: True\n  test_splits: [\"val\"]\n\n  # distribution-specific\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n"
  },
  {
    "path": "lavis/projects/blip2/eval/okvqa_zeroshot_flant5xl_eval.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\n# Overall Accuracy is: 41.22\n\nmodel:\n  arch: blip2_t5\n  model_type: pretrain_flant5xl\n  use_grad_checkpoint: False\n\n  # for OKVQA evaluation\n  apply_lemmatizer: True\n\ndatasets:\n  ok_vqa: # name of the dataset builder\n    vis_processor:\n        eval:\n          name: \"blip_image_eval\"\n          image_size: 224\n    text_processor:\n        eval:\n          name: \"blip_question\"\n#     build_info:\n#         images:\n#             storage: '/export/share/datasets/vision/coco/images/'\n\nrun:\n  task: vqa\n  # optimization-specific\n  batch_size_train: 16\n  batch_size_eval: 64\n  num_workers: 4\n\n  # inference-specific\n  max_len: 10\n  min_len: 1\n  num_beams: 5\n  inference_method: \"generate\"\n  prompt: \"Question: {} Short answer:\"\n\n  seed: 42\n  output_dir: \"output/BLIP2/OKVQA\"\n\n  evaluate: True\n  test_splits: [\"test\"]\n\n  # distribution-specific\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n"
  },
  {
    "path": "lavis/projects/blip2/eval/ret_coco_eval.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: blip2\n  model_type: coco\n  use_grad_checkpoint: False\n\ndatasets:\n  coco_retrieval: # name of the dataset builder\n    vis_processor:\n        train:\n          name: \"blip_image_train\"\n          image_size: 364\n        eval:\n          name: \"blip_image_eval\"\n          image_size: 364\n    text_processor:\n        train:\n          name: \"blip_caption\"\n        eval:\n          name: \"blip_caption\"\n#     build_info:\n#         images:\n#             storage: '/export/share/datasets/vision/coco/images/'\nrun:\n  task: retrieval\n\n  # dataloading\n  num_workers: 4\n  batch_size_train: 16\n  batch_size_eval: 32\n\n  train_splits: [\"train\"]\n  valid_splits: [\"val\"]\n  test_splits: [\"test\"]\n\n  # distribution\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n  use_dist_eval_sampler: False\n\n  # model specific\n  k_test: 128\n\n  # misc\n  seed: 42\n  output_dir: \"output/BLIP2/Retrieval_COCO\"\n\n  evaluate: True\n"
  },
  {
    "path": "lavis/projects/blip2/eval/ret_flickr_eval.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: blip2\n  model_type: coco\n  use_grad_checkpoint: False\n\ndatasets:\n  flickr30k: # name of the dataset builder\n    vis_processor:\n        eval:\n          name: \"blip_image_eval\"\n          image_size: 364\n    text_processor:\n        eval:\n          name: \"blip_caption\"\n\nrun:\n  task: retrieval\n\n  # dataloading\n  num_workers: 4\n  batch_size_train: 16\n  batch_size_eval: 32\n\n  test_splits: [\"test\"]\n\n  # distribution\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n  use_dist_eval_sampler: False\n\n  # model specific\n  k_test: 128\n\n  # misc\n  seed: 42\n  output_dir: \"output/BLIP2/Retrieval_Flickr30k\"\n\n  evaluate: True"
  },
  {
    "path": "lavis/projects/blip2/eval/vqav2_zeroshot_flant5xl_eval.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\n# Overall Accuracy is: 63.13\n# Per Answer Type Accuracy is the following:\n# other : 52.90\n# yes/no : 84.28\n# number : 41.01\n\nmodel:\n  arch: blip2_t5\n  model_type: pretrain_flant5xl\n  use_grad_checkpoint: False\n\ndatasets:\n  coco_vqa: # name of the dataset builder\n    type: eval\n    vis_processor:\n        eval:\n          name: \"blip_image_eval\"\n          image_size: 224\n    text_processor:\n        eval:\n          name: \"blip_question\"\n#     build_info:\n#         images:\n#             storage: '/export/share/datasets/vision/coco/images/'\n\nrun:\n  task: vqa\n  # optimization-specific\n  batch_size_train: 16\n  batch_size_eval: 64\n  num_workers: 4\n\n  # inference-specific\n  max_len: 10\n  min_len: 1\n  num_beams: 5\n  inference_method: \"generate\"\n  prompt: \"Question: {} Short answer:\"\n\n  seed: 42\n  output_dir: \"output/BLIP2/VQA\"\n\n  evaluate: True\n  test_splits: [\"val\"]\n\n  # distribution-specific\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n"
  },
  {
    "path": "lavis/projects/blip2/train/caption_coco_ft.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: blip2_opt\n  model_type: caption_coco_opt2.7b\n  load_finetuned: False\n  use_grad_checkpoint: True\n  freeze_vit: False\n\ndatasets:\n  coco_caption: # name of the dataset builder\n    vis_processor:\n        train:\n          name: \"blip2_image_train\"\n          image_size: 364\n        eval:\n          name: \"blip_image_eval\"\n          image_size: 364\n    text_processor:\n        train:\n          name: \"blip_caption\"\n          prompt: \"a photo of \"\n        eval:\n          name: \"blip_caption\"\n    # build_info:\n    #     images:\n    #         storage: '/export/share/datasets/vision/coco/images/'\n\nrun:\n  task: captioning\n  # optimizer\n  lr_sched: \"linear_warmup_cosine_lr\"\n  init_lr: 1e-5\n  min_lr: 0\n  warmup_lr: 1e-8\n  warmup_steps: 1000\n  weight_decay: 0.05\n  max_epoch: 5\n  batch_size_train: 16\n  batch_size_eval: 8\n  num_workers: 4\n  accum_grad_iters: 1\n\n  max_len: 30\n  min_len: 8\n  num_beams: 5\n\n  seed: 42\n  output_dir: \"output/BLIP2/Caption_coco\"\n\n  amp: True\n  resume_ckpt_path: null\n\n  evaluate: False\n  train_splits: [\"train\"]\n  valid_splits: [\"val\"]\n  test_splits: [\"test\"]\n\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n"
  },
  {
    "path": "lavis/projects/blip2/train/pretrain_stage1.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: blip2\n  model_type: pretrain\n  # TODO: support stage 1 pretraining from scratch (load_pretrained=False does not have effect as of now)\n  load_pretrained: False\n  freeze_vit: True\n\n\ndatasets:\n  coco_caption:\n    vis_processor:\n        train:\n          name: \"blip2_image_train\"\n          image_size: 224\n    text_processor:\n        train:\n          name: \"blip_caption\"\n#     build_info:\n#         images:\n#             storage: '/export/share/datasets/vision/coco/images/'          \n  vg_caption: # name of the dataset builder\n    vis_processor:\n        train:\n          name: \"blip_image_train\"\n          image_size: 224\n    text_processor:\n        train:\n          name: \"blip_caption\"\n#     build_info:\n#         images:\n#             storage: '//export/share/datasets/vision/visual-genome/image/'\n\nrun:\n  task: image_text_pretrain\n  # optimizer\n  lr_sched: \"linear_warmup_cosine_lr\"\n  init_lr: 1e-4\n  min_lr: 1e-5\n  warmup_lr: 1e-6\n\n  weight_decay: 0.05\n  max_epoch: 10\n  batch_size_train: 100\n  batch_size_eval: 64\n  num_workers: 4\n  warmup_steps: 5000\n\n  seed: 42\n  output_dir: \"output/BLIP2/Pretrain_stage1\"\n\n  amp: True\n  resume_ckpt_path: null\n\n  evaluate: False \n  train_splits: [\"train\"]\n\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True"
  },
  {
    "path": "lavis/projects/blip2/train/pretrain_stage2.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: blip2_opt\n  model_type: pretrain_opt2.7b \n  load_pretrained: True\n  # intialize stage 2 pretraining from stage 1 pretrained model\n  pretrained: \"https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained.pth\"\n  freeze_vit: True\n\n\ndatasets:\n  coco_caption:\n    vis_processor:\n        train:\n          name: \"blip2_image_train\"\n          image_size: 224\n    text_processor:\n        train:\n          name: \"blip_caption\"\n    # build_info:\n    #     images:\n    #         storage: '/export/share/datasets/vision/coco/images/'          \n  vg_caption: # name of the dataset builder\n    vis_processor:\n        train:\n          name: \"blip_image_train\"\n          image_size: 224\n    text_processor:\n        train:\n          name: \"blip_caption\"\n    # build_info:\n    #     images:\n    #         storage: '//export/share/datasets/vision/visual-genome/image/'\n\nrun:\n  task: image_text_pretrain\n  # optimizer\n  lr_sched: \"linear_warmup_cosine_lr\"\n  init_lr: 1e-4\n  min_lr: 1e-5\n  warmup_lr: 1e-6\n\n  weight_decay: 0.05\n  max_epoch: 10\n  batch_size_train: 64\n  batch_size_eval: 64\n  num_workers: 4\n  warmup_steps: 2000\n\n  seed: 42\n  output_dir: \"output/BLIP2/Pretrain_stage2\"\n\n  amp: True\n  resume_ckpt_path: null\n\n  evaluate: False \n  train_splits: [\"train\"]\n\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True"
  },
  {
    "path": "lavis/projects/clip/exp_coco_ret_eval.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: clip\n\n  model_type: ViT-L-14-336\n\ndatasets:\n  coco_retrieval: # name of the dataset builder\n    vis_processor:\n        train:\n          name: \"clip_image_train\"\n          image_size: 336\n        eval:\n          name: \"clip_image_eval\"\n          image_size: 336\n    text_processor:\n        train:\n          name: \"blip_caption\"\n        eval:\n          name: \"blip_caption\"\n\nrun:\n  task: retrieval\n\n  # dataloading\n  num_workers: 4\n  batch_size_train: 32\n  batch_size_eval: 128\n\n  test_splits: [\"test\"]\n\n  # distribution\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n  use_dist_eval_sampler: True\n\n  # misc\n  seed: 42\n  output_dir: \"output/clip/Retrieval_COCO\"\n\n  evaluate: True\n"
  },
  {
    "path": "lavis/projects/clip/exp_flickr_ret_eval.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: clip\n\n  model_type: ViT-L-14-336\n\ndatasets:\n  flickr30k: # name of the dataset builder\n    vis_processor:\n        train:\n          name: \"clip_image_train\"\n          image_size: 336\n        eval:\n          name: \"clip_image_eval\"\n          image_size: 336\n    text_processor:\n        train:\n          name: \"blip_caption\"\n        eval:\n          name: \"blip_caption\"\n\nrun:\n  task: retrieval\n\n  # dataloading\n  num_workers: 4\n  batch_size_train: 32\n  batch_size_eval: 128\n\n  test_splits: [\"test\"]\n\n  # distribution\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n  use_dist_eval_sampler: True\n\n  # misc\n  seed: 42\n  output_dir: \"output/clip/Retrieval_Flickr\"\n\n  evaluate: True\n"
  },
  {
    "path": "lavis/projects/clip/exp_imnet_zs_eval.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: clip\n\n  model_type: ViT-L-14-336\n\ndatasets:\n  imagenet: # name of the dataset builder\n    vis_processor:\n        eval:\n          name: \"clip_image_eval\"\n          # image_size: 224\n          image_size: 336\n\nrun:\n  task: multimodal_classification\n\n  # dataloading\n  num_workers: 4\n  batch_size_train: 32\n  batch_size_eval: 128\n\n  test_splits: [\"val\"]\n\n  # distribution\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n\n  # misc\n  seed: 42\n  output_dir: \"output/clip/zs_imnet\"\n\n  evaluate: True\n"
  },
  {
    "path": "lavis/projects/gpt/eval/dialogue_avsd_eval.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: gpt_dialogue\n  model_type: base\n\ndatasets:\n  avsd_dialogue: # name of the dataset builder\n    vis_processor:\n        eval:\n          name: \"gpt_video_ft\"\n          visual_ft: [\"i3d_flow\", \"i3d_rgb\"]\n          audio_ft: [\"vggish\"]\n    text_processor:\n        eval:\n          name: \"gpt_dialogue\"\n          max_turns:  3\n          use_caption: True\n\nrun:\n  task: dialogue\n  # optimizer\n  batch_size_train: 16\n  batch_size_eval: 16\n  num_workers: 0\n\n  max_len: 20\n  min_len: 5\n  num_beams: 5\n\n  seed: 42\n  output_dir: \"output/gpt2/dialogue_avsd\"\n\n  evaluate: True\n  valid_splits: [\"test\"]\n\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n"
  },
  {
    "path": "lavis/projects/gpt/train/dialogue_avsd_ft.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: gpt_dialogue\n  model_type: base\n\ndatasets:\n  avsd_dialogue: # name of the dataset builder\n    vis_processor:\n        train:\n          name: \"gpt_video_ft\"\n          visual_ft: [\"i3d_flow\", \"i3d_rgb\"]\n          audio_ft: [\"vggish\"]\n        eval:\n          name: \"gpt_video_ft\"\n          visual_ft: [\"i3d_flow\", \"i3d_rgb\"]\n          audio_ft: [\"vggish\"]\n    text_processor:\n        train:\n          name: \"gpt_dialogue\"\n          max_turns:  3\n          use_caption: True\n        eval:\n          name: \"gpt_dialogue\"\n          max_turns:  3\n          use_caption: True\n\nrun:\n  task: dialogue\n  # optimizer\n  lr_sched: \"linear_warmup_cosine_lr\"\n  init_lr: 1e-5\n  min_lr: 0\n  weight_decay: 0.05\n  max_epoch: 20\n  batch_size_train: 16\n  batch_size_eval: 16\n  num_workers: 0\n\n  max_len: 20\n  min_len: 5\n  num_beams: 5\n\n  seed: 42\n  output_dir: \"output/gpt2/dialogue_avsd\"\n\n  amp: False\n  resume_ckpt_path: null\n\n  evaluate: False \n  train_splits: [\"train\"]\n  valid_splits: [\"val\"]\n\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n"
  },
  {
    "path": "lavis/projects/pnp-vqa/eval/gqa_eval.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: pnp_vqa\n  model_type: base\n\ndatasets:\n  gqa: # name of the dataset builder\n    type: balanced_testdev\n    vis_processor:\n        eval:\n          name: \"blip_image_eval\"\n          image_size: 384\n    text_processor:\n        eval:\n          name: \"blip_question\"\n\nrun:\n  task: gqa_reading_comprehension\n\n  # optimization-specific\n  batch_size_train: 16\n  batch_size_eval: 16\n  num_workers: 4\n\n  # image question matching specific\n  block_num: 7\n\n  # image captioning specific\n  top_k: 50\n  top_p: 1\n  cap_min_length: 10\n  cap_max_length: 20\n  repetition_penalty: 1\n  num_patches: 20\n  num_captions: 100\n  prompt: 'a picture of '\n\n  # question answering specific\n  internal_bsz_fid: 1\n  num_captions_fid: 5\n  min_len: 0\n  max_len: 20\n  num_beams: 1\n  inference_method: \"generate\"\n\n  seed: 42\n  output_dir: \"output/PNP-VQA/GQA\"\n\n  evaluate: True\n  test_splits: [\"val\"]\n\n  # distribution-specific\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n"
  },
  {
    "path": "lavis/projects/pnp-vqa/eval/gqa_eval_3b.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: pnp_vqa\n  model_type: 3b\n\ndatasets:\n  gqa: # name of the dataset builder\n    type: balanced_testdev\n    vis_processor:\n        eval:\n          name: \"blip_image_eval\"\n          image_size: 384\n    text_processor:\n        eval:\n          name: \"blip_question\"\n\nrun:\n  task: gqa_reading_comprehension\n\n  # optimization-specific\n  batch_size_train: 4\n  batch_size_eval: 4\n  num_workers: 4\n\n  # image question matching specific\n  block_num: 7\n\n  # image captioning specific\n  top_k: 50\n  top_p: 1\n  cap_min_length: 10\n  cap_max_length: 20\n  repetition_penalty: 1\n  num_patches: 20\n  num_captions: 100\n  prompt: 'a picture of '\n\n  # question answering specific\n  internal_bsz_fid: 1\n  num_captions_fid: 5\n  min_len: 0\n  max_len: 20\n  num_beams: 1\n  inference_method: \"generate\"\n\n  seed: 42\n  output_dir: \"output/PNP-VQA-3b/GQA\"\n\n  evaluate: True\n  test_splits: [\"val\"]\n\n  # distribution-specific\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n"
  },
  {
    "path": "lavis/projects/pnp-vqa/eval/gqa_eval_large.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: pnp_vqa\n  model_type: large\n\ndatasets:\n  gqa: # name of the dataset builder\n    type: balanced_testdev\n    vis_processor:\n        eval:\n          name: \"blip_image_eval\"\n          image_size: 384\n    text_processor:\n        eval:\n          name: \"blip_question\"\n\nrun:\n  task: gqa_reading_comprehension\n\n  # optimization-specific\n  batch_size_train: 12\n  batch_size_eval: 12\n  num_workers: 4\n\n  # image question matching specific\n  block_num: 7\n\n  # image captioning specific\n  top_k: 50\n  top_p: 1\n  cap_min_length: 10\n  cap_max_length: 20\n  repetition_penalty: 1\n  num_patches: 20\n  num_captions: 100\n  prompt: 'a picture of '\n\n  # question answering specific\n  internal_bsz_fid: 1\n  num_captions_fid: 5\n  min_len: 0\n  max_len: 20\n  num_beams: 1\n  inference_method: \"generate\"\n\n  seed: 42\n  output_dir: \"output/PNP-VQA-large/GQA\"\n\n  evaluate: True\n  test_splits: [\"val\"]\n\n  # distribution-specific\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n"
  },
  {
    "path": "lavis/projects/pnp-vqa/eval/okvqa_eval.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: pnp_vqa\n  model_type: base\n\ndatasets:\n  ok_vqa: # name of the dataset builder\n    vis_processor:\n        eval:\n          name: \"blip_image_eval\"\n          image_size: 384\n    text_processor:\n        eval:\n          name: \"blip_question\"\n\nrun:\n  task: vqa_reading_comprehension\n\n  # optimization-specific\n  batch_size_train: 16\n  batch_size_eval: 16\n  num_workers: 4\n\n  # image question matching specific\n  block_num: 7\n\n  # image captioning specific\n  top_k: 50\n  top_p: 1\n  cap_min_length: 10\n  cap_max_length: 20\n  repetition_penalty: 1\n  num_patches: 20\n  num_captions: 100\n  prompt: 'a picture of '\n\n  # question answering specific\n  internal_bsz_fid: 1\n  num_captions_fid: 1\n  min_len: 0\n  max_len: 20\n  num_beams: 1\n  inference_method: \"generate\"\n\n  seed: 42\n  output_dir: \"output/PNP-VQA/OKVQA\"\n\n  evaluate: True\n  test_splits: [\"test\"]\n\n  # distribution-specific\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n"
  },
  {
    "path": "lavis/projects/pnp-vqa/eval/okvqa_eval_3b.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: pnp_vqa\n  model_type: 3b\n\ndatasets:\n  ok_vqa: # name of the dataset builder\n    vis_processor:\n        eval:\n          name: \"blip_image_eval\"\n          image_size: 384\n    text_processor:\n        eval:\n          name: \"blip_question\"\n\nrun:\n  task: vqa_reading_comprehension\n\n  # optimization-specific\n  batch_size_train: 4\n  batch_size_eval: 4\n  num_workers: 4\n\n  # image question matching specific\n  block_num: 7\n\n  # image captioning specific\n  top_k: 50\n  top_p: 1\n  cap_min_length: 10\n  cap_max_length: 20\n  repetition_penalty: 1\n  num_patches: 20\n  num_captions: 100\n  prompt: 'a picture of '\n\n  # question answering specific\n  internal_bsz_fid: 1\n  num_captions_fid: 1\n  min_len: 0\n  max_len: 20\n  num_beams: 1\n  inference_method: \"generate\"\n\n  seed: 42\n  output_dir: \"output/PNP-VQA-3b/OKVQA\"\n\n  evaluate: True\n  test_splits: [\"test\"]\n\n  # distribution-specific\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n"
  },
  {
    "path": "lavis/projects/pnp-vqa/eval/okvqa_eval_large.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: pnp_vqa\n  model_type: large\n\ndatasets:\n  ok_vqa: # name of the dataset builder\n    vis_processor:\n        eval:\n          name: \"blip_image_eval\"\n          image_size: 384\n    text_processor:\n        eval:\n          name: \"blip_question\"\n\nrun:\n  task: vqa_reading_comprehension\n\n  # optimization-specific\n  batch_size_train: 12\n  batch_size_eval: 12\n  num_workers: 4\n\n  # image question matching specific\n  block_num: 7\n\n  # image captioning specific\n  top_k: 50\n  top_p: 1\n  cap_min_length: 10\n  cap_max_length: 20\n  repetition_penalty: 1\n  num_patches: 20\n  num_captions: 100\n  prompt: 'a picture of '\n\n  # question answering specific\n  internal_bsz_fid: 1\n  num_captions_fid: 1\n  min_len: 0\n  max_len: 20\n  num_beams: 1\n  inference_method: \"generate\"\n\n  seed: 42\n  output_dir: \"output/PNP-VQA-large/OKVQA\"\n\n  evaluate: True\n  test_splits: [\"test\"]\n\n  # distribution-specific\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n"
  },
  {
    "path": "lavis/projects/pnp-vqa/eval/vqav2_eval.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: pnp_vqa\n  model_type: base\n\ndatasets:\n  coco_vqa: # name of the dataset builder\n    type: eval\n    vis_processor:\n        eval:\n          name: \"blip_image_eval\"\n          image_size: 384\n    text_processor:\n        eval:\n          name: \"blip_question\"\n\nrun:\n  task: vqa_reading_comprehension\n\n  # optimization-specific\n  batch_size_train: 16\n  batch_size_eval: 16\n  num_workers: 4\n\n  # image question matching specific\n  block_num: 7\n\n  # image captioning specific\n  top_k: 50\n  top_p: 1\n  cap_min_length: 10\n  cap_max_length: 20\n  repetition_penalty: 1\n  num_patches: 20\n  num_captions: 100\n  prompt: 'a picture of '\n\n  # question answering specific\n  internal_bsz_fid: 1\n  num_captions_fid: 1\n  min_len: 0\n  max_len: 20\n  num_beams: 1\n  inference_method: \"generate\"\n\n  seed: 42\n  output_dir: \"output/PNP-VQA/VQAv2_val\"\n\n  evaluate: True\n  test_splits: [\"val\"]\n\n  # distribution-specific\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n"
  },
  {
    "path": "lavis/projects/pnp-vqa/eval/vqav2_eval_3b.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: pnp_vqa\n  model_type: 3b\n\ndatasets:\n  coco_vqa: # name of the dataset builder\n    type: eval\n    vis_processor:\n        eval:\n          name: \"blip_image_eval\"\n          image_size: 384\n    text_processor:\n        eval:\n          name: \"blip_question\"\n\nrun:\n  task: vqa_reading_comprehension\n\n  # optimization-specific\n  batch_size_train: 4\n  batch_size_eval: 4\n  num_workers: 4\n\n  # image question matching specific\n  block_num: 7\n\n  # image captioning specific\n  top_k: 50\n  top_p: 1\n  cap_min_length: 10\n  cap_max_length: 20\n  repetition_penalty: 1\n  num_patches: 20\n  num_captions: 100\n  prompt: 'a picture of '\n\n  # question answering specific\n  internal_bsz_fid: 1\n  num_captions_fid: 1\n  min_len: 0\n  max_len: 20\n  num_beams: 1\n  inference_method: \"generate\"\n\n  seed: 42\n  output_dir: \"output/PNP-VQA-3b/VQAv2_val\"\n\n  evaluate: True\n  test_splits: [\"val\"]\n\n  # distribution-specific\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n"
  },
  {
    "path": "lavis/projects/pnp-vqa/eval/vqav2_eval_large.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: pnp_vqa\n  model_type: large\n\ndatasets:\n  coco_vqa: # name of the dataset builder\n    type: eval\n    vis_processor:\n        eval:\n          name: \"blip_image_eval\"\n          image_size: 384\n    text_processor:\n        eval:\n          name: \"blip_question\"\n\nrun:\n  task: vqa_reading_comprehension\n\n  # optimization-specific\n  batch_size_train: 12\n  batch_size_eval: 12\n  num_workers: 4\n\n  # image question matching specific\n  block_num: 7\n\n  # image captioning specific\n  top_k: 50\n  top_p: 1\n  cap_min_length: 10\n  cap_max_length: 20\n  repetition_penalty: 1\n  num_patches: 20\n  num_captions: 100\n  prompt: 'a picture of '\n\n  # question answering specific\n  internal_bsz_fid: 1\n  num_captions_fid: 1\n  min_len: 0\n  max_len: 20\n  num_beams: 1\n  inference_method: \"generate\"\n\n  seed: 42\n  output_dir: \"output/PNP-VQA-large/VQAv2_val\"\n\n  evaluate: True\n  test_splits: [\"val\"]\n\n  # distribution-specific\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n"
  },
  {
    "path": "lavis/projects/pnp-vqa/eval/vqav2_test_eval.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: pnp_vqa\n  model_type: base\n\ndatasets:\n  coco_vqa: # name of the dataset builder\n    type: default\n    vis_processor:\n        eval:\n          name: \"blip_image_eval\"\n          image_size: 384\n    text_processor:\n        eval:\n          name: \"blip_question\"\n\nrun:\n  task: vqa_reading_comprehension\n\n  # optimization-specific\n  batch_size_train: 16\n  batch_size_eval: 16\n  num_workers: 4\n\n  # image question matching specific\n  block_num: 7\n\n  # image captioning specific\n  top_k: 50\n  top_p: 1\n  cap_min_length: 10\n  cap_max_length: 20\n  repetition_penalty: 1\n  num_patches: 20\n  num_captions: 100\n  prompt: 'a picture of '\n\n  # question answering specific\n  internal_bsz_fid: 1\n  num_captions_fid: 1\n  min_len: 0\n  max_len: 20\n  num_beams: 1\n  inference_method: \"generate\"\n\n  seed: 42\n  output_dir: \"output/PNP-VQA/VQAv2_test\"\n\n  evaluate: True\n  test_splits: [\"test\"]\n\n  # distribution-specific\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n"
  },
  {
    "path": "lavis/projects/pnp-vqa/eval/vqav2_test_eval_3b.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: pnp_vqa\n  model_type: 3b\n\ndatasets:\n  coco_vqa: # name of the dataset builder\n    type: default\n    vis_processor:\n        eval:\n          name: \"blip_image_eval\"\n          image_size: 384\n    text_processor:\n        eval:\n          name: \"blip_question\"\n\nrun:\n  task: vqa_reading_comprehension\n\n  # optimization-specific\n  batch_size_train: 4\n  batch_size_eval: 4\n  num_workers: 4\n\n  # image question matching specific\n  block_num: 7\n\n  # image captioning specific\n  top_k: 50\n  top_p: 1\n  cap_min_length: 10\n  cap_max_length: 20\n  repetition_penalty: 1\n  num_patches: 20\n  num_captions: 100\n  prompt: 'a picture of '\n\n  # question answering specific\n  internal_bsz_fid: 1\n  num_captions_fid: 1\n  min_len: 0\n  max_len: 20\n  num_beams: 1\n  inference_method: \"generate\"\n\n  seed: 42\n  output_dir: \"output/PNP-VQA-3b/VQAv2_test\"\n\n  evaluate: True\n  test_splits: [\"test\"]\n\n  # distribution-specific\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n"
  },
  {
    "path": "lavis/projects/pnp-vqa/eval/vqav2_test_eval_large.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: pnp_vqa\n  model_type: large\n\ndatasets:\n  coco_vqa: # name of the dataset builder\n    type: default\n    vis_processor:\n        eval:\n          name: \"blip_image_eval\"\n          image_size: 384\n    text_processor:\n        eval:\n          name: \"blip_question\"\n\nrun:\n  task: vqa_reading_comprehension\n\n  # optimization-specific\n  batch_size_train: 12\n  batch_size_eval: 12\n  num_workers: 4\n\n  # image question matching specific\n  block_num: 7\n\n  # image captioning specific\n  top_k: 50\n  top_p: 1\n  cap_min_length: 10\n  cap_max_length: 20\n  repetition_penalty: 1\n  num_patches: 20\n  num_captions: 100\n  prompt: 'a picture of '\n\n  # question answering specific\n  internal_bsz_fid: 1\n  num_captions_fid: 1\n  min_len: 0\n  max_len: 20\n  num_beams: 1\n  inference_method: \"generate\"\n\n  seed: 42\n  output_dir: \"output/PNP-VQA-large/VQAv2_test\"\n\n  evaluate: True\n  test_splits: [\"test\"]\n\n  # distribution-specific\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n"
  },
  {
    "path": "lavis/projects/sevila/eval/how2qa_eval.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: sevila\n  model_type: pretrain_flant5xl\n  load_finetuned: True\n  finetuned: 'https://huggingface.co/Shoubin/SeViLA/resolve/main/sevila_pretrained.pth'\n  use_grad_checkpoint: False\n  freeze_vit: True\n  frame_num: 8\n  answer_num: 4\n  task: freeze_qa_vid\n  qformer_input_text: False\n\ndatasets:\n  how2qa:\n    vis_processor:\n        eval:\n          name: \"blip_video_eval\"\n          n_frms: 4\n          image_size: 224\n    text_processor:\n        eval:\n          name: \"blip_question\"\n          max_words: 50\n\nrun:\n  task: videoqa\n  # optimizer\n  lr_sched: \"linear_warmup_cosine_lr\"\n  init_lr: 1e-5\n  min_lr: 0\n  warmup_lr: 1e-8\n  warmup_steps: 200\n  weight_decay: 0.05\n  max_epoch: 5\n  batch_size_train: 32\n  batch_size_eval: 32\n  num_workers: 8\n  accum_grad_iters: 1\n\n  max_len: 30\n  min_len: 8\n  num_beams: 5\n\n  seed: 42\n  output_dir: \"\"\n\n  amp: True\n  resume_ckpt_path: null\n\n  evaluate: False\n  train_splits: [\"train\"]\n  valid_splits: [\"val\"]\n  test_splits: [\"val\"]\n\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n  find_unused_parameters: True"
  },
  {
    "path": "lavis/projects/sevila/eval/nextqa_eval.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: sevila\n  model_type: pretrain_flant5xl\n  load_finetuned: True\n  finetuned: 'https://huggingface.co/Shoubin/SeViLA/resolve/main/sevila_pretrained.pth'\n  use_grad_checkpoint: False\n  freeze_vit: True\n  frame_num: 8\n  answer_num: 5\n  task: freeze_qa_vid\n  qformer_input_text: False\n\ndatasets:\n  nextqa: \n    vis_processor:\n        eval:\n          name: \"blip_video_eval\"\n          n_frms: 4\n          image_size: 224\n    text_processor:\n        eval:\n          name: \"blip_question\"\n          max_words: 50\n\nrun:\n  task: videoqa\n  # optimizer\n  lr_sched: \"linear_warmup_cosine_lr\"\n  init_lr: 1e-5\n  min_lr: 0\n  warmup_lr: 1e-8\n  warmup_steps: 200\n  weight_decay: 0.05\n  max_epoch: 5\n  batch_size_train: 32\n  batch_size_eval: 32\n  num_workers: 8\n  accum_grad_iters: 1\n\n  max_len: 30\n  min_len: 8\n  num_beams: 5\n\n  seed: 42\n  output_dir: \"\"\n\n  amp: True\n  resume_ckpt_path: null\n\n  evaluate: False\n  train_splits: [\"train\"]\n  valid_splits: [\"val\"]\n  test_splits: [\"val\"]\n\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n  find_unused_parameters: True"
  },
  {
    "path": "lavis/projects/sevila/eval/qvh_eval.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\n# Overall Accuracy is: 41.22\n\nmodel:\n  arch: blip2_fmr\n  model_type: pretrain_flant5xl\n  load_finetuned: True\n  finetuned: 'https://huggingface.co/Shoubin/SeViLA/resolve/main/sevila_pretrained.pth'\n  use_grad_checkpoint: False\n  freeze_vit: True\n\ndatasets:\n  qvh:\n    vis_processor:\n        eval:\n          name: \"blip_video_eval\"\n          n_frms: 64\n          image_size: 224\n    text_processor:\n        eval:\n          name: \"blip_question\"\n          max_words: 50\n\nrun:\n  task: moment_retrieval\n  # optimization-specific\n  batch_size_train: 16\n  batch_size_eval: 32\n  num_workers: 4\n\n  # inference-specific\n  max_len: 10\n  min_len: 1\n  num_beams: 5\n  inference_method: \"generate\"\n\n  seed: 42\n  output_dir: \"\"\n\n  evaluate: True\n  test_splits: [\"val\"]\n\n  # distribution-specific\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n  ind_unused_parameters: False"
  },
  {
    "path": "lavis/projects/sevila/eval/star_eval.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: sevila\n  model_type: pretrain_flant5xl\n  load_finetuned: True\n  finetuned: 'https://huggingface.co/Shoubin/SeViLA/resolve/main/sevila_pretrained.pth'\n  use_grad_checkpoint: False\n  freeze_vit: True\n  frame_num: 8\n  answer_num: 4\n  task: freeze_qa_vid\n  qformer_input_text: False\n\ndatasets:\n  star: \n    vis_processor:\n        eval:\n          name: \"blip_video_eval\"\n          n_frms: 4\n          image_size: 224\n    text_processor:\n        eval:\n          name: \"blip_question\"\n          max_words: 50\n\nrun:\n  task: videoqa\n  # optimizer\n  lr_sched: \"linear_warmup_cosine_lr\"\n  init_lr: 1e-5\n  min_lr: 0\n  warmup_lr: 1e-8\n  warmup_steps: 200\n  weight_decay: 0.05\n  max_epoch: 5\n  batch_size_train: 32\n  batch_size_eval: 32\n  num_workers: 8\n  accum_grad_iters: 1\n\n  max_len: 30\n  min_len: 8\n  num_beams: 5\n\n  seed: 42\n  output_dir: \"\"\n\n  amp: True\n  resume_ckpt_path: null\n\n  evaluate: False\n  train_splits: [\"train\"]\n  valid_splits: [\"val\"]\n  test_splits: [\"val\"]\n\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n  find_unused_parameters: True"
  },
  {
    "path": "lavis/projects/sevila/eval/tvqa_eval.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: sevila\n  model_type: pretrain_flant5xl\n  load_finetuned: True\n  finetuned: 'https://huggingface.co/Shoubin/SeViLA/resolve/main/sevila_pretrained.pth'\n  use_grad_checkpoint: False\n  freeze_vit: True\n  frame_num: 8\n  answer_num: 5\n  task: freeze_qa_vid\n  qformer_input_text: False\n\ndatasets:\n  tvqa:\n    vis_processor:\n        eval:\n          name: \"blip_video_eval\"\n          n_frms: 4\n          image_size: 224\n    text_processor:\n        eval:\n          name: \"blip_question\"\n          max_words: 50\n\nrun:\n  task: videoqa\n  # optimizer\n  lr_sched: \"linear_warmup_cosine_lr\"\n  init_lr: 1e-5\n  min_lr: 0\n  warmup_lr: 1e-8\n  warmup_steps: 200\n  weight_decay: 0.05\n  max_epoch: 5\n  batch_size_train: 32\n  batch_size_eval: 32\n  num_workers: 8\n  accum_grad_iters: 1\n\n  max_len: 30\n  min_len: 8\n  num_beams: 5\n\n  seed: 42\n  output_dir: \"\"\n\n  amp: True\n  resume_ckpt_path: null\n\n  evaluate: False\n  train_splits: [\"train\"]\n  valid_splits: [\"val\"]\n  test_splits: [\"val\"]\n\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n  find_unused_parameters: True"
  },
  {
    "path": "lavis/projects/sevila/eval/vlep_eval.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: sevila\n  model_type: pretrain_flant5xl\n  load_finetuned: True\n  finetuned: 'https://huggingface.co/Shoubin/SeViLA/resolve/main/sevila_pretrained.pth'\n  use_grad_checkpoint: False\n  freeze_vit: True\n  frame_num: 8\n  answer_num: 2\n  task: freeze_qa_vid\n  qformer_input_text: False\n\ndatasets:\n  vlep:\n    vis_processor:\n        eval:\n          name: \"blip_video_eval\"\n          n_frms: 4\n          image_size: 224\n    text_processor:\n        eval:\n          name: \"blip_question\"\n          max_words: 50\n\nrun:\n  task: videoqa\n  # optimizer\n  lr_sched: \"linear_warmup_cosine_lr\"\n  init_lr: 1e-5\n  min_lr: 0\n  warmup_lr: 1e-8\n  warmup_steps: 200\n  weight_decay: 0.05\n  max_epoch: 5\n  batch_size_train: 32\n  batch_size_eval: 32\n  num_workers: 8\n  accum_grad_iters: 1\n\n  max_len: 30\n  min_len: 8\n  num_beams: 5\n\n  seed: 42\n  output_dir: \"\"\n\n  amp: True\n  resume_ckpt_path: null\n\n  evaluate: False\n  train_splits: [\"train\"]\n  valid_splits: [\"val\"]\n  test_splits: [\"val\"]\n\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n  find_unused_parameters: True"
  },
  {
    "path": "lavis/projects/sevila/train/how2qa.yaml",
    "content": "model:\n  arch: sevila\n  model_type: pretrain_flant5xl\n  load_finetuned: True\n  finetuned: 'https://huggingface.co/Shoubin/SeViLA/resolve/main/sevila_pretrained.pth'\n  use_grad_checkpoint: False\n  freeze_vit: True\n  frame_num: 8\n  answer_num: 4\n  task: train_loc_freeze_qa_vid\n\ndatasets:\n  how2qa: # name of the dataset builder\n    vis_processor:\n        train:\n          name: \"blip2_video_train\"\n          n_frms: 4\n          image_size: 224\n        eval:\n          name: \"blip_video_eval\"\n          n_frms: 4\n          image_size: 224\n    text_processor:\n        train:\n          name: \"blip_question\"\n          max_words: 50\n        eval:\n          name: \"blip_question\"\n          max_words: 50\n    # build_info:\n    #     images:\n    #         storage: '/export/share/datasets/vision/coco/images/'\n\nrun:\n  task: videoqa\n  # optimizer\n  lr_sched: \"linear_warmup_cosine_lr\"\n  init_lr: 5e-5\n  min_lr: 0\n  warmup_lr: 1e-8\n  warmup_steps: 500\n  weight_decay: 0.05\n  max_epoch: 5\n  batch_size_train: 16\n  batch_size_eval: 8\n  num_workers: 8\n  accum_grad_iters: 1\n\n  max_len: 30\n  min_len: 8\n  num_beams: 5\n\n  seed: 42\n  output_dir: \"/nas-hdd/shoubin/result/BLIP2/NextQA/QA/\"\n\n  amp: True\n  resume_ckpt_path: null\n\n  evaluate: False\n  train_splits: [\"train\"]\n  valid_splits: [\"val\"]\n  test_splits: [\"val\"]\n\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n  find_unused_parameters: True"
  },
  {
    "path": "lavis/projects/sevila/train/nextqa.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\nmodel:\n  arch: sevila\n  model_type: pretrain_flant5xl\n  load_finetuned: True\n  finetuned: 'https://huggingface.co/Shoubin/SeViLA/resolve/main/sevila_pretrained.pth'\n  use_grad_checkpoint: False\n  freeze_vit: True\n  frame_num: 8\n  answer_num: 5\n  task: train_loc_freeze_qa_vid\n\ndatasets:\n  nextqa: # name of the dataset builder\n    vis_processor:\n        train:\n          name: \"blip2_video_train\"\n          n_frms: 4\n          image_size: 224\n        eval:\n          name: \"blip_video_eval\"\n          n_frms: 4\n          image_size: 224\n    text_processor:\n        train:\n          name: \"blip_question\"\n          max_words: 50\n        eval:\n          name: \"blip_question\"\n          max_words: 50\n    # build_info:\n    #     images:\n    #         storage: '/export/share/datasets/vision/coco/images/'\n\nrun:\n  task: videoqa\n  # optimizer\n  lr_sched: \"linear_warmup_cosine_lr\"\n  init_lr: 5e-5\n  min_lr: 0\n  warmup_lr: 1e-8\n  warmup_steps: 500\n  weight_decay: 0.05\n  max_epoch: 5\n  batch_size_train: 16\n  batch_size_eval: 8\n  num_workers: 8\n  accum_grad_iters: 1\n\n  max_len: 30\n  min_len: 8\n  num_beams: 5\n\n  seed: 42\n  output_dir: \"/nas-hdd/shoubin/result/BLIP2/NextQA/QA/\"\n\n  amp: True\n  resume_ckpt_path: null\n\n  evaluate: False\n  train_splits: [\"train\"]\n  valid_splits: [\"val\"]\n  test_splits: [\"val\"]\n\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n  find_unused_parameters: True"
  },
  {
    "path": "lavis/projects/sevila/train/qvh.yaml",
    "content": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\n\nmodel:\n  arch: blip2_fmr\n  model_type: pretrain_flant5xl\n  use_grad_checkpoint: False\n  freeze_vit: True\n\ndatasets:\n  qvh: # name of the dataset builder\n    vis_processor:\n        train:\n          name: \"blip2_video_train\"\n          n_frms: 4\n          image_size: 224\n        eval:\n          name: \"blip_video_eval\"\n          n_frms: 64\n          image_size: 224\n    text_processor:\n        train:\n          name: \"blip_question\"\n          max_words: 50\n        eval:\n          name: \"blip_question\"\n          max_words: 50\n\n\nrun:\n  task: moment_retrieval\n  # optimization-specific\n  lr_sched: \"linear_warmup_cosine_lr\"\n  init_lr: 5e-5\n  min_lr: 0\n  warmup_lr: 1e-8\n  warmup_steps: 500\n  weight_decay: 0.05\n  max_epoch: 5\n  batch_size_train: 16\n  batch_size_eval: 8\n  num_workers: 8\n  accum_grad_iters: 1\n\n  max_len: 30\n  min_len: 8\n  num_beams: 5\n\n  seed: 42\n  output_dir: \"\"\n\n  amp: True\n  resume_ckpt_path: null\n\n  evaluate: False\n  train_splits: [\"train\"]\n  valid_splits: [\"val\"]\n  test_splits: [\"test\"]\n\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n  find_unused_parameters: False"
  },
  {
    "path": "lavis/projects/sevila/train/star.yaml",
    "content": "model:\n  arch: sevila\n  model_type: pretrain_flant5xl\n  load_finetuned: True\n  finetuned: 'https://huggingface.co/Shoubin/SeViLA/resolve/main/sevila_pretrained.pth'\n  use_grad_checkpoint: False\n  freeze_vit: True\n  frame_num: 8\n  answer_num: 4\n  task: train_loc_freeze_qa_vid\n  qformer_input_text: False\n\ndatasets:\n  star: # name of the dataset builder\n    vis_processor:\n        train:\n          name: \"blip2_video_train\"\n          n_frms: 4\n          image_size: 224\n        eval:\n          name: \"blip_video_eval\"\n          n_frms: 4\n          image_size: 224\n    text_processor:\n        train:\n          name: \"blip_question\"\n          max_words: 50\n        eval:\n          name: \"blip_question\"\n          max_words: 50\n\nrun:\n  task: videoqa\n  # optimizer\n  lr_sched: \"linear_warmup_cosine_lr\"\n  init_lr: 5e-5\n  min_lr: 0\n  warmup_lr: 1e-8\n  warmup_steps: 500\n  weight_decay: 0.05\n  max_epoch: 5\n  batch_size_train: 16\n  batch_size_eval: 8\n  num_workers: 8\n  accum_grad_iters: 1\n\n  max_len: 30\n  min_len: 8\n  num_beams: 5\n\n  seed: 42\n  output_dir: \"/nas-hdd/shoubin/result/BLIP2/NextQA/QA/\"\n\n  amp: True\n  resume_ckpt_path: null\n\n  evaluate: False\n  train_splits: [\"train\"]\n  valid_splits: [\"val\"]\n  test_splits: [\"val\"]\n\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n  find_unused_parameters: True"
  },
  {
    "path": "lavis/projects/sevila/train/tvqa.yaml",
    "content": "model:\n  arch: sevila\n  model_type: pretrain_flant5xl\n  load_finetuned: True\n  finetuned: 'https://huggingface.co/Shoubin/SeViLA/resolve/main/sevila_pretrained.pth'\n  use_grad_checkpoint: False\n  freeze_vit: True\n  frame_num: 8\n  answer_num: 5\n  task: train_loc_freeze_qa_vid\n  qformer_input_text: False\n\ndatasets:\n  tvqa: # name of the dataset builder\n    vis_processor:\n        train:\n          name: \"blip2_video_train\"\n          n_frms: 4\n          image_size: 224\n        eval:\n          name: \"blip_video_eval\"\n          n_frms: 4\n          image_size: 224\n    text_processor:\n        train:\n          name: \"blip_question\"\n          max_words: 50\n        eval:\n          name: \"blip_question\"\n          max_words: 50\n\nrun:\n  task: videoqa\n  # optimizer\n  lr_sched: \"linear_warmup_cosine_lr\"\n  init_lr: 5e-5\n  min_lr: 0\n  warmup_lr: 1e-8\n  warmup_steps: 500\n  weight_decay: 0.05\n  max_epoch: 5\n  batch_size_train: 16\n  batch_size_eval: 8\n  num_workers: 8\n  accum_grad_iters: 1\n\n  max_len: 30\n  min_len: 8\n  num_beams: 5\n\n  seed: 42\n  output_dir: \"/nas-hdd/shoubin/result/BLIP2/NextQA/QA/\"\n\n  amp: True\n  resume_ckpt_path: null\n\n  evaluate: False\n  train_splits: [\"train\"]\n  valid_splits: [\"val\"]\n  test_splits: [\"val\"]\n\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n  find_unused_parameters: True"
  },
  {
    "path": "lavis/projects/sevila/train/vlep.yaml",
    "content": "model:\n  arch: sevila\n  model_type: pretrain_flant5xl\n  load_finetuned: True\n  finetuned: 'https://huggingface.co/Shoubin/SeViLA/resolve/main/sevila_pretrained.pth'\n  use_grad_checkpoint: False\n  freeze_vit: True\n  frame_num: 8\n  answer_num: 2\n  task: train_loc_freeze_qa_vid\n  qformer_input_text: False\n\ndatasets:\n  vlep: # name of the dataset builder\n    vis_processor:\n        train:\n          name: \"blip2_video_train\"\n          n_frms: 4\n          image_size: 224\n        eval:\n          name: \"blip_video_eval\"\n          n_frms: 4\n          image_size: 224\n    text_processor:\n        train:\n          name: \"blip_question\"\n          max_words: 50\n        eval:\n          name: \"blip_question\"\n          max_words: 50\n\nrun:\n  task: videoqa\n  # optimizer\n  lr_sched: \"linear_warmup_cosine_lr\"\n  init_lr: 5e-5\n  min_lr: 0\n  warmup_lr: 1e-8\n  warmup_steps: 500\n  weight_decay: 0.05\n  max_epoch: 5\n  batch_size_train: 16\n  batch_size_eval: 8\n  num_workers: 8\n  accum_grad_iters: 1\n\n  max_len: 30\n  min_len: 8\n  num_beams: 5\n\n  seed: 42\n  output_dir: \"/nas-hdd/shoubin/result/BLIP2/NextQA/QA/\"\n\n  amp: True\n  resume_ckpt_path: null\n\n  evaluate: False\n  train_splits: [\"train\"]\n  valid_splits: [\"val\"]\n  test_splits: [\"val\"]\n\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n  find_unused_parameters: True"
  },
  {
    "path": "lavis/runners/__init__.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nfrom lavis.runners.runner_base import RunnerBase\nfrom lavis.runners.runner_iter import RunnerIter\n\n__all__ = [\"RunnerBase\", \"RunnerIter\"]\n"
  },
  {
    "path": "lavis/runners/runner_base.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport datetime\nimport json\nimport logging\nimport os\nimport time\nfrom pathlib import Path\n\nimport torch\nimport torch.distributed as dist\nimport webdataset as wds\nfrom lavis.common.dist_utils import (\n    download_cached_file,\n    get_rank,\n    get_world_size,\n    is_main_process,\n    main_process,\n)\nfrom lavis.common.registry import registry\nfrom lavis.common.utils import is_url\nfrom lavis.datasets.data_utils import concat_datasets, reorg_datasets_by_split\nfrom lavis.datasets.datasets.dataloader_utils import (\n    IterLoader,\n    MultiIterLoader,\n    PrefetchLoader,\n)\nfrom torch.nn.parallel import DistributedDataParallel as DDP\nfrom torch.utils.data import DataLoader, DistributedSampler\nfrom torch.utils.data.dataset import ChainDataset\n\n\n@registry.register_runner(\"runner_base\")\nclass RunnerBase:\n    \"\"\"\n    A runner class to train and evaluate a model given a task and datasets.\n\n    The runner uses pytorch distributed data parallel by default. Future release\n    will support other distributed frameworks.\n    \"\"\"\n\n    def __init__(self, cfg, task, model, datasets, job_id):\n        self.config = cfg\n        self.job_id = job_id\n\n        self.task = task\n        self.datasets = datasets\n\n        self._model = model\n\n        self._wrapped_model = None\n        self._device = None\n        self._optimizer = None\n        self._scaler = None\n        self._dataloaders = None\n        self._lr_sched = None\n\n        self.start_epoch = 0\n\n        # self.setup_seeds()\n        self.setup_output_dir()\n\n    @property\n    def device(self):\n        if self._device is None:\n            self._device = torch.device(self.config.run_cfg.device)\n\n        return self._device\n\n    @property\n    def use_distributed(self):\n        return self.config.run_cfg.distributed\n\n    @property\n    def model(self):\n        \"\"\"\n        A property to get the DDP-wrapped model on the device.\n        \"\"\"\n        # move model to device\n        if self._model.device != self.device:\n            self._model = self._model.to(self.device)\n\n            # distributed training wrapper\n            if self.use_distributed:\n                if self._wrapped_model is None:\n                    self._wrapped_model = DDP(\n                        self._model, device_ids=[self.config.run_cfg.gpu],\n                        broadcast_buffers=False,\n                        find_unused_parameters=self.config.run_cfg.find_unused_parameters\n                    )\n            else:\n                self._wrapped_model = self._model\n\n        return self._wrapped_model\n\n    @property\n    def optimizer(self):\n        # TODO make optimizer class and configurations\n        if self._optimizer is None:\n            num_parameters = 0\n            p_wd, p_non_wd = [], []\n            for n, p in self.model.named_parameters():\n                if not p.requires_grad:\n                    continue  # frozen weights\n                if p.ndim < 2 or \"bias\" in n or \"ln\" in n or \"bn\" in n:\n                    p_non_wd.append(p)\n                else:\n                    p_wd.append(p)\n                num_parameters += p.data.nelement()\n            logging.info(\"number of trainable parameters: %d\" % 
num_parameters)\n            optim_params = [\n                {\n                    \"params\": p_wd,\n                    \"weight_decay\": float(self.config.run_cfg.weight_decay),\n                },\n                {\"params\": p_non_wd, \"weight_decay\": 0},\n            ]\n            beta2 = self.config.run_cfg.get(\"beta2\", 0.999)\n            self._optimizer = torch.optim.AdamW(\n                optim_params,\n                lr=float(self.config.run_cfg.init_lr),\n                weight_decay=float(self.config.run_cfg.weight_decay),\n                betas=(0.9, beta2),\n            )\n\n        return self._optimizer\n\n    @property\n    def scaler(self):\n        amp = self.config.run_cfg.get(\"amp\", False)\n\n        if amp:\n            if self._scaler is None:\n                self._scaler = torch.cuda.amp.GradScaler()\n\n        return self._scaler\n\n    @property\n    def lr_scheduler(self):\n        \"\"\"\n        A property to get and create learning rate scheduler by split just in need.\n        \"\"\"\n        if self._lr_sched is None:\n            lr_sched_cls = registry.get_lr_scheduler_class(self.config.run_cfg.lr_sched)\n\n            # max_epoch = self.config.run_cfg.max_epoch\n            max_epoch = self.max_epoch\n            # min_lr = self.config.run_cfg.min_lr\n            min_lr = self.min_lr\n            # init_lr = self.config.run_cfg.init_lr\n            init_lr = self.init_lr\n\n            # optional parameters\n            decay_rate = self.config.run_cfg.get(\"lr_decay_rate\", None)\n            warmup_start_lr = self.config.run_cfg.get(\"warmup_lr\", -1)\n            warmup_steps = self.config.run_cfg.get(\"warmup_steps\", 0)\n\n            self._lr_sched = lr_sched_cls(\n                optimizer=self.optimizer,\n                max_epoch=max_epoch,\n                min_lr=min_lr,\n                init_lr=init_lr,\n                decay_rate=decay_rate,\n                warmup_start_lr=warmup_start_lr,\n                warmup_steps=warmup_steps,\n            )\n\n        return self._lr_sched\n\n    @property\n    def dataloaders(self) -> dict:\n        \"\"\"\n        A property to get and create dataloaders by split just in need.\n\n        If no train_dataset_ratio is provided, concatenate map-style datasets and\n        chain wds.DataPipe datasets separately. Training set becomes a tuple\n        (ConcatDataset, ChainDataset), both are optional but at least one of them is\n        required. The resultant ConcatDataset and ChainDataset will be sampled evenly.\n\n        If train_dataset_ratio is provided, create a MultiIterLoader to sample\n        each dataset by ratios during training.\n\n        Currently do not support multiple datasets for validation and test.\n\n        Returns:\n            dict: {split_name: (tuples of) dataloader}\n        \"\"\"\n        if self._dataloaders is None:\n            # reoganize datasets by split and concatenate/chain if necessary\n            dataset_ratios = self.config.run_cfg.get(\"train_dataset_ratios\", None)\n\n            # concatenate map-style datasets and chain wds.DataPipe datasets separately\n            # training set becomes a tuple (ConcatDataset, ChainDataset), both are\n            # optional but at least one of them is required. 
The resultant ConcatDataset\n            # and ChainDataset will be sampled evenly.\n            logging.info(\n                \"dataset_ratios not specified, datasets will be concatenated (map-style datasets) or chained (webdataset.DataPipeline).\"\n            )\n\n            datasets = reorg_datasets_by_split(self.datasets)\n            self.datasets = concat_datasets(datasets)\n\n            # print dataset statistics after concatenation/chaining\n            for split_name in self.datasets:\n                if isinstance(self.datasets[split_name], tuple) or isinstance(\n                    self.datasets[split_name], list\n                ):\n                    # mixed wds.DataPipeline and torch.utils.data.Dataset\n                    num_records = sum(\n                        [\n                            len(d)\n                            if not type(d) in [wds.DataPipeline, ChainDataset]\n                            else 0\n                            for d in self.datasets[split_name]\n                        ]\n                    )\n\n                else:\n                    if hasattr(self.datasets[split_name], \"__len__\"):\n                        # a single map-style dataset\n                        num_records = len(self.datasets[split_name])\n                    else:\n                        # a single wds.DataPipeline\n                        num_records = -1\n                        logging.info(\n                            \"Only a single wds.DataPipeline dataset, no __len__ attribute.\"\n                        )\n\n                if num_records >= 0:\n                    logging.info(\n                        \"Loaded {} records for {} split from the dataset.\".format(\n                            num_records, split_name\n                        )\n                    )\n\n            # create dataloaders\n            split_names = sorted(self.datasets.keys())\n\n            datasets = [self.datasets[split] for split in split_names]\n            is_trains = [split in self.train_splits for split in split_names]\n\n            batch_sizes = [\n                self.config.run_cfg.batch_size_train\n                if split == \"train\"\n                else self.config.run_cfg.batch_size_eval\n                for split in split_names\n            ]\n\n            collate_fns = []\n            for dataset in datasets:\n                if isinstance(dataset, tuple) or isinstance(dataset, list):\n                    collate_fns.append([getattr(d, \"collater\", None) for d in dataset])\n                else:\n                    collate_fns.append(getattr(dataset, \"collater\", None))\n\n            dataloaders = self.create_loaders(\n                datasets=datasets,\n                num_workers=self.config.run_cfg.num_workers,\n                batch_sizes=batch_sizes,\n                is_trains=is_trains,\n                collate_fns=collate_fns,\n                dataset_ratios=dataset_ratios,\n            )\n\n            self._dataloaders = {k: v for k, v in zip(split_names, dataloaders)}\n\n        return self._dataloaders\n\n    @property\n    def cuda_enabled(self):\n        return self.device.type == \"cuda\"\n\n    @property\n    def max_epoch(self):\n        return int(self.config.run_cfg.max_epoch)\n\n    @property\n    def log_freq(self):\n        log_freq = self.config.run_cfg.get(\"log_freq\", 50)\n        return int(log_freq)\n\n    @property\n    def init_lr(self):\n        return float(self.config.run_cfg.init_lr)\n\n    @property\n    def 
min_lr(self):\n        return float(self.config.run_cfg.min_lr)\n\n    @property\n    def accum_grad_iters(self):\n        return int(self.config.run_cfg.get(\"accum_grad_iters\", 1))\n\n    @property\n    def valid_splits(self):\n        valid_splits = self.config.run_cfg.get(\"valid_splits\", [])\n\n        if len(valid_splits) == 0:\n            logging.info(\"No validation splits found.\")\n\n        return valid_splits\n\n    @property\n    def test_splits(self):\n        test_splits = self.config.run_cfg.get(\"test_splits\", [])\n\n        return test_splits\n\n    @property\n    def train_splits(self):\n        train_splits = self.config.run_cfg.get(\"train_splits\", [])\n\n        if len(train_splits) == 0:\n            logging.info(\"Empty train splits.\")\n\n        return train_splits\n\n    @property\n    def evaluate_only(self):\n        \"\"\"\n        Set to True to skip training.\n        \"\"\"\n        return self.config.run_cfg.evaluate\n\n    @property\n    def use_dist_eval_sampler(self):\n        return self.config.run_cfg.get(\"use_dist_eval_sampler\", True)\n\n    @property\n    def resume_ckpt_path(self):\n        return self.config.run_cfg.get(\"resume_ckpt_path\", None)\n\n    @property\n    def train_loader(self):\n        train_dataloader = self.dataloaders[\"train\"]\n\n        return train_dataloader\n\n    def setup_output_dir(self):\n        lib_root = Path(registry.get_path(\"library_root\"))\n\n        output_dir = lib_root / self.config.run_cfg.output_dir # / self.job_id\n        result_dir = output_dir / \"result\"\n\n        output_dir.mkdir(parents=True, exist_ok=True)\n        result_dir.mkdir(parents=True, exist_ok=True)\n\n        registry.register_path(\"result_dir\", str(result_dir))\n        registry.register_path(\"output_dir\", str(output_dir))\n\n        self.result_dir = result_dir\n        self.output_dir = output_dir\n\n    def train(self):\n        start_time = time.time()\n        best_agg_metric = 0\n        best_epoch = 0\n\n        self.log_config()\n\n        # resume from checkpoint if specified\n        if not self.evaluate_only and self.resume_ckpt_path is not None:\n            self._load_checkpoint(self.resume_ckpt_path)\n\n        for cur_epoch in range(self.start_epoch, self.max_epoch):\n            # training phase\n            if not self.evaluate_only:\n                logging.info(\"Start training\")\n                train_stats = self.train_epoch(cur_epoch)\n                self.log_stats(split_name=\"train\", stats=train_stats)\n\n            # evaluation phase\n            if len(self.valid_splits) > 0:\n                for split_name in self.valid_splits:\n                    logging.info(\"Evaluating on {}.\".format(split_name))\n\n                    val_log = self.eval_epoch(\n                        split_name=split_name, cur_epoch=cur_epoch\n                    )\n                    if val_log is not None:\n                        if is_main_process():\n                            assert (\n                                \"agg_metrics\" in val_log\n                            ), \"No agg_metrics found in validation log.\"\n\n                            agg_metrics = val_log[\"agg_metrics\"]\n                            if agg_metrics > best_agg_metric and split_name == \"val\":\n                                best_epoch, best_agg_metric = cur_epoch, agg_metrics\n\n                                self._save_checkpoint(cur_epoch, is_best=True)\n\n                            val_log.update({\"best_epoch\": 
best_epoch})\n                            self.log_stats(val_log, split_name)\n\n            else:\n                # if no validation split is provided, we just save the checkpoint at the end of each epoch.\n                if not self.evaluate_only:\n                    self._save_checkpoint(cur_epoch, is_best=False)\n\n            if self.evaluate_only:\n                break\n\n            dist.barrier()\n\n        # testing phase\n        test_epoch = \"best\" if len(self.valid_splits) > 0 else cur_epoch\n        self.evaluate(cur_epoch=test_epoch, skip_reload=self.evaluate_only)\n\n        total_time = time.time() - start_time\n        total_time_str = str(datetime.timedelta(seconds=int(total_time)))\n        logging.info(\"Training time {}\".format(total_time_str))\n\n    def evaluate(self, cur_epoch=\"best\", skip_reload=False):\n        test_logs = dict()\n\n        if len(self.test_splits) > 0:\n            for split_name in self.test_splits:\n                test_logs[split_name] = self.eval_epoch(\n                    split_name=split_name, cur_epoch=cur_epoch, skip_reload=skip_reload\n                )\n\n            return test_logs\n\n    def train_epoch(self, epoch):\n        # train\n        self.model.train()\n\n        return self.task.train_epoch(\n            epoch=epoch,\n            model=self.model,\n            data_loader=self.train_loader,\n            optimizer=self.optimizer,\n            scaler=self.scaler,\n            lr_scheduler=self.lr_scheduler,\n            cuda_enabled=self.cuda_enabled,\n            log_freq=self.log_freq,\n            accum_grad_iters=self.accum_grad_iters,\n        )\n\n    @torch.no_grad()\n    def eval_epoch(self, split_name, cur_epoch, skip_reload=False):\n        \"\"\"\n        Evaluate the model on a given split.\n\n        Args:\n            split_name (str): name of the split to evaluate on.\n            cur_epoch (int): current epoch.\n            skip_reload_best (bool): whether to skip reloading the best checkpoint.\n                During training, we will reload the best checkpoint for validation.\n                During testing, we will use provided weights and skip reloading the best checkpoint .\n        \"\"\"\n        data_loader = self.dataloaders.get(split_name, None)\n        assert data_loader, \"data_loader for split {} is None.\".format(split_name)\n\n        # TODO In validation, you need to compute loss as well as metrics\n        # TODO consider moving to model.before_evaluation()\n        model = self.unwrap_dist_model(self.model)\n        if not skip_reload and cur_epoch == \"best\":\n            model = self._reload_best_model(model)\n        model.eval()\n\n        self.task.before_evaluation(\n            model=model,\n            dataset=self.datasets[split_name],\n        )\n        results = self.task.evaluation(model, data_loader)\n        \n        if results is not None:\n            return self.task.after_evaluation(\n                val_result=results,\n                split_name=split_name,\n                epoch=cur_epoch,\n            )\n\n    def unwrap_dist_model(self, model):\n        if self.use_distributed:\n            return model.module\n        else:\n            return model\n\n    def create_loaders(\n        self,\n        datasets,\n        num_workers,\n        batch_sizes,\n        is_trains,\n        collate_fns,\n        dataset_ratios=None,\n    ):\n        \"\"\"\n        Create dataloaders for training and validation.\n        \"\"\"\n\n        def 
_create_loader(dataset, num_workers, bsz, is_train, collate_fn):\n            # create a single dataloader for each split\n            if isinstance(dataset, ChainDataset) or isinstance(\n                dataset, wds.DataPipeline\n            ):\n                # wds.WebdDataset instance are chained together\n                # webdataset.DataPipeline has its own sampler and collate_fn\n                loader = iter(\n                    DataLoader(\n                        dataset,\n                        batch_size=bsz,\n                        num_workers=num_workers,\n                        pin_memory=True,\n                    )\n                )\n            else:\n                # map-style dataset are concatenated together\n                # setup distributed sampler\n                if self.use_distributed:\n                    sampler = DistributedSampler(\n                        dataset,\n                        shuffle=is_train,\n                        num_replicas=get_world_size(),\n                        rank=get_rank(),\n                    )\n                    if not self.use_dist_eval_sampler:\n                        # e.g. retrieval evaluation\n                        sampler = sampler if is_train else None\n                else:\n                    sampler = None\n\n                loader = DataLoader(\n                    dataset,\n                    batch_size=bsz,\n                    num_workers=num_workers,\n                    pin_memory=True,\n                    sampler=sampler,\n                    shuffle=sampler is None and is_train,\n                    collate_fn=collate_fn,\n                    drop_last=True if is_train else False,\n                )\n                loader = PrefetchLoader(loader)\n\n                if is_train:\n                    loader = IterLoader(loader, use_distributed=self.use_distributed)\n\n            return loader\n\n        loaders = []\n\n        for dataset, bsz, is_train, collate_fn in zip(\n            datasets, batch_sizes, is_trains, collate_fns\n        ):\n            if isinstance(dataset, list) or isinstance(dataset, tuple):\n                loader = MultiIterLoader(\n                    loaders=[\n                        _create_loader(d, num_workers, bsz, is_train, collate_fn[i])\n                        for i, d in enumerate(dataset)\n                    ],\n                    ratios=dataset_ratios,\n                )\n            else:\n                loader = _create_loader(dataset, num_workers, bsz, is_train, collate_fn)\n\n            loaders.append(loader)\n\n        return loaders\n\n    @main_process\n    def _save_checkpoint(self, cur_epoch, is_best=False):\n        \"\"\"\n        Save the checkpoint at the current epoch.\n        \"\"\"\n        model_no_ddp = self.unwrap_dist_model(self.model)\n        param_grad_dic = {\n            k: v.requires_grad for (k, v) in model_no_ddp.named_parameters()\n        }\n        state_dict = model_no_ddp.state_dict()\n        for k in list(state_dict.keys()):\n            if k in param_grad_dic.keys() and not param_grad_dic[k]:\n                # delete parameters that do not require gradient\n                #if 't5_model' not in k and 'visual_encoder' not in k:\n                # print(k)\n                del state_dict[k]\n        save_obj = {\n            \"model\": state_dict,\n            \"optimizer\": self.optimizer.state_dict(),\n            \"config\": self.config.to_dict(),\n            \"scaler\": self.scaler.state_dict() if self.scaler 
else None,\n            \"epoch\": cur_epoch,\n        }\n        save_to = os.path.join(\n            self.output_dir,\n            \"checkpoint_{}.pth\".format(\"best\" if is_best else cur_epoch),\n        )\n        logging.info(\"Saving checkpoint at epoch {} to {}.\".format(cur_epoch, save_to))\n        torch.save(save_obj, save_to)\n\n    def _reload_best_model(self, model):\n        \"\"\"\n        Load the best checkpoint for evaluation.\n        \"\"\"\n        checkpoint_path = os.path.join(self.output_dir, \"checkpoint_best.pth\")\n\n        logging.info(\"Loading checkpoint from {}.\".format(checkpoint_path))\n        checkpoint = torch.load(checkpoint_path, map_location=\"cpu\")\n        try:\n            model.load_state_dict(checkpoint[\"model\"])\n        except RuntimeError as e:\n            logging.warning(\n                \"\"\"\n                Key mismatch when loading checkpoint. This is expected if only part of the model is saved.\n                Trying to load the model with strict=False.\n                \"\"\"\n            )\n            model.load_state_dict(checkpoint[\"model\"], strict=False)\n        return model\n\n    def _load_checkpoint(self, url_or_filename):\n        \"\"\"\n        Resume from a checkpoint.\n        \"\"\"\n        if is_url(url_or_filename):\n            cached_file = download_cached_file(\n                url_or_filename, check_hash=False, progress=True\n            )\n            checkpoint = torch.load(cached_file, map_location=self.device)\n        elif os.path.isfile(url_or_filename):\n            checkpoint = torch.load(url_or_filename, map_location=self.device)\n        else:\n            raise RuntimeError(\"checkpoint url or path is invalid\")\n\n        state_dict = checkpoint[\"model\"]\n        self.unwrap_dist_model(self.model).load_state_dict(state_dict)\n\n        self.optimizer.load_state_dict(checkpoint[\"optimizer\"])\n        if self.scaler and \"scaler\" in checkpoint:\n            self.scaler.load_state_dict(checkpoint[\"scaler\"])\n\n        self.start_epoch = checkpoint[\"epoch\"] + 1\n        logging.info(\"Resume checkpoint from {}\".format(url_or_filename))\n\n    @main_process\n    def log_stats(self, stats, split_name):\n        if isinstance(stats, dict):\n            log_stats = {**{f\"{split_name}_{k}\": v for k, v in stats.items()}}\n            with open(os.path.join(self.output_dir, \"log.txt\"), \"a\") as f:\n                f.write(json.dumps(log_stats) + \"\\n\")\n        elif isinstance(stats, list):\n            pass\n\n    @main_process\n    def log_config(self):\n        with open(os.path.join(self.output_dir, \"log.txt\"), \"a\") as f:\n            f.write(json.dumps(self.config.to_dict(), indent=4) + \"\\n\")\n"
  },
  {
    "path": "lavis/runners/runner_iter.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport datetime\nimport logging\nimport os\nimport time\n\nimport torch\nimport torch.distributed as dist\nimport webdataset as wds\nfrom lavis.common.dist_utils import download_cached_file, is_main_process, main_process\nfrom lavis.common.registry import registry\nfrom lavis.common.utils import is_url\nfrom lavis.datasets.data_utils import concat_datasets, reorg_datasets_by_split\nfrom lavis.runners.runner_base import RunnerBase\nfrom torch.utils.data.dataset import ChainDataset\n\n\n@registry.register_runner(\"runner_iter\")\nclass RunnerIter(RunnerBase):\n    \"\"\"\n    Run training based on the number of iterations. This is common when\n    the training dataset size is large. Underhood logic is similar to\n    epoch-based training by considering every #iters_per_inner_epoch as an\n    inner epoch.\n\n    In iter-based runner, after every #iters_per_inner_epoch steps, we\n\n        1) do a validation epoch;\n        2) schedule the learning rate;\n        3) save the checkpoint.\n\n    We refer every #iters_per_inner_epoch steps as an inner epoch.\n    \"\"\"\n\n    def __init__(self, cfg, task, model, datasets, job_id):\n        super().__init__(cfg, task, model, datasets, job_id)\n\n        self.start_iters = 0\n\n        self.max_iters = int(self.config.run_cfg.get(\"max_iters\", -1))\n        assert self.max_iters > 0, \"max_iters must be greater than 0.\"\n\n        self.iters_per_inner_epoch = int(\n            self.config.run_cfg.get(\"iters_per_inner_epoch\", -1)\n        )\n        assert (\n            self.iters_per_inner_epoch > 0\n        ), \"iters_per_inner_epoch must be greater than 0.\"\n\n    @property\n    def max_epoch(self):\n        return int(self.max_iters / self.iters_per_inner_epoch)\n\n    @property\n    def cur_epoch(self):\n        try:\n            return self.train_loader.epoch\n        except AttributeError:\n            # pipeline data (e.g. 
LAION) is streaming, have no concept of epoch\n            return 0\n\n    def _progress(self, cur_iters):\n        return \"{}_iters={}\".format(self.cur_epoch, cur_iters)\n\n    def train(self):\n        start_time = time.time()\n        best_agg_metric = 0\n        best_iters = 0\n\n        self.log_config()\n\n        # resume from checkpoint if specified\n        if not self.evaluate_only and self.resume_ckpt_path is not None:\n            self._load_checkpoint(self.resume_ckpt_path)\n\n        for start_iters in range(\n            self.start_iters, self.max_iters, self.iters_per_inner_epoch\n        ):\n            end_iters = start_iters + self.iters_per_inner_epoch\n\n            # training phase\n            if not self.evaluate_only:\n                logging.info(\n                    \"Start training, max_iters={}, in total {} inner epochs.\".format(\n                        self.max_iters, int(self.max_iters / self.iters_per_inner_epoch)\n                    )\n                )\n\n                train_stats = self.train_iters(self.cur_epoch, start_iters)\n                self.log_stats(split_name=\"train\", stats=train_stats)\n\n            # evaluation phase\n            if len(self.valid_splits) > 0:\n                for split_name in self.valid_splits:\n                    logging.info(\"Evaluating on {}.\".format(split_name))\n\n                    val_log = self.eval_epoch(\n                        split_name=split_name, cur_epoch=self._progress(end_iters)\n                    )\n                    if val_log is not None:\n                        if is_main_process():\n                            assert (\n                                \"agg_metrics\" in val_log\n                            ), \"No agg_metrics found in validation log.\"\n\n                            agg_metrics = val_log[\"agg_metrics\"]\n                            if agg_metrics > best_agg_metric and split_name == \"val\":\n                                best_iters, best_agg_metric = end_iters, agg_metrics\n\n                                self._save_checkpoint(end_iters, is_best=True)\n\n                            val_log.update({\"best_iters\": best_iters})\n                            self.log_stats(val_log, split_name)\n\n            else:\n                # if no validation split is provided, we just save the checkpoint at the end of each inner epoch.\n                if not self.evaluate_only:\n                    self._save_checkpoint(end_iters, is_best=False)\n\n            if self.evaluate_only:\n                break\n            dist.barrier()\n\n        # testing phase\n        self.evaluate(cur_epoch=self.cur_epoch)\n\n        total_time = time.time() - start_time\n        total_time_str = str(datetime.timedelta(seconds=int(total_time)))\n        logging.info(\"Training time {}\".format(total_time_str))\n\n    def train_iters(self, epoch, start_iters):\n        # train by iterations\n        self.model.train()\n\n        return self.task.train_iters(\n            epoch=epoch,\n            start_iters=start_iters,\n            iters_per_inner_epoch=self.iters_per_inner_epoch,\n            model=self.model,\n            data_loader=self.train_loader,\n            optimizer=self.optimizer,\n            scaler=self.scaler,\n            lr_scheduler=self.lr_scheduler,\n            cuda_enabled=self.cuda_enabled,\n            log_freq=self.log_freq,\n            accum_grad_iters=self.accum_grad_iters,\n        )\n\n    @main_process\n    def _save_checkpoint(self, cur_iters, 
is_best=False):\n        save_obj = {\n            \"model\": self.unwrap_dist_model(self.model).state_dict(),\n            \"optimizer\": self.optimizer.state_dict(),\n            \"config\": self.config.to_dict(),\n            \"scaler\": self.scaler.state_dict() if self.scaler else None,\n            \"iters\": cur_iters,\n        }\n        save_to = os.path.join(\n            self.output_dir,\n            \"checkpoint_{}.pth\".format(\"best\" if is_best else cur_iters),\n        )\n        logging.info(\"Saving checkpoint at iters {} to {}.\".format(cur_iters, save_to))\n        torch.save(save_obj, save_to)\n\n    def _load_checkpoint(self, url_or_filename):\n        \"\"\"\n        Resume from a checkpoint.\n        \"\"\"\n        if is_url(url_or_filename):\n            cached_file = download_cached_file(\n                url_or_filename, check_hash=False, progress=True\n            )\n            checkpoint = torch.load(cached_file, map_location=self.device)\n        elif os.path.isfile(url_or_filename):\n            checkpoint = torch.load(url_or_filename, map_location=self.device)\n        else:\n            raise RuntimeError(\"checkpoint url or path is invalid\")\n\n        state_dict = checkpoint[\"model\"]\n        self.unwrap_dist_model(self.model).load_state_dict(state_dict)\n\n        self.optimizer.load_state_dict(checkpoint[\"optimizer\"])\n        if self.scaler and \"scaler\" in checkpoint:\n            self.scaler.load_state_dict(checkpoint[\"scaler\"])\n\n        self.start_iters = checkpoint[\"iters\"] + 1\n        logging.info(\"Resume checkpoint from {}\".format(url_or_filename))\n\n    @property\n    def dataloaders(self) -> dict:\n        \"\"\"\n        A property to get and create dataloaders by split just in need.\n\n        If no train_dataset_ratio is provided, concatenate map-style datasets and\n        chain wds.DataPipe datasets separately. Training set becomes a tuple\n        (ConcatDataset, ChainDataset), both are optional but at least one of them is\n        required. The resultant ConcatDataset and ChainDataset will be sampled evenly.\n\n        If train_dataset_ratio is provided, create a MultiIterLoader to sample\n        each dataset by ratios during training.\n\n        Currently do not support multiple datasets for validation and test.\n\n        Returns:\n            dict: {split_name: (tuples of) dataloader}\n        \"\"\"\n        if self._dataloaders is None:\n            # reoganize datasets by split and concatenate/chain if necessary\n            dataset_ratios = self.config.run_cfg.get(\"train_dataset_ratios\", None)\n\n            if dataset_ratios is None:\n                # concatenate map-style datasets and chain wds.DataPipe datasets separately\n                # training set becomes a tuple (ConcatDataset, ChainDataset), both are\n                # optional but at least one of them is required. 
The resultant ConcatDataset\n                # and ChainDataset will be sampled evenly.\n                logging.info(\n                    \"dataset_ratios not specified, datasets will be concatenated (map-style datasets) or chained (webdataset.DataPipeline).\"\n                )\n\n                datasets = reorg_datasets_by_split(self.datasets)\n                self.datasets = concat_datasets(datasets)\n            else:\n                # create multi-loader with the provided ratios, without concatenating or chaining\n                missing_keys = [k for k in dataset_ratios if k not in self.datasets]\n                if len(missing_keys) > 0:\n                    raise ValueError(\n                        \"Datasets with the following split names are not found: {}\".format(\n                            missing_keys\n                        )\n                    )\n\n                unexpected_keys = [k for k in self.datasets if k not in dataset_ratios]\n                if len(unexpected_keys) > 0:\n                    raise ValueError(\n                        \"Datasets with the following split names are not expected: {}\".format(\n                            unexpected_keys\n                        )\n                    )\n\n                dataset_ratios = [float(dataset_ratios[k]) for k in self.datasets]\n                self.datasets = reorg_datasets_by_split(self.datasets)\n                # to keep the same structure as return value of concat_datasets\n                self.datasets = {\n                    k: v[0] if len(v) == 1 else v for k, v in datasets.items()\n                }\n\n            # print dataset statistics after concatenation/chaining\n            for split_name in self.datasets:\n                if isinstance(self.datasets[split_name], tuple) or isinstance(\n                    self.datasets[split_name], list\n                ):\n                    # mixed wds.DataPipeline and torch.utils.data.Dataset\n                    num_records = sum(\n                        [\n                            len(d)\n                            if not type(d) in [wds.DataPipeline, ChainDataset]\n                            else 0\n                            for d in self.datasets[split_name]\n                        ]\n                    )\n\n                else:\n                    try:\n                        # a single map-style dataset\n                        num_records = len(self.datasets[split_name])\n                    except TypeError:\n                        # a single wds.DataPipeline or ChainDataset\n                        num_records = -1\n                        logging.info(\n                            \"Only a single wds.DataPipeline dataset, no __len__ attribute.\"\n                        )\n\n                if num_records >= 0:\n                    logging.info(\n                        \"Loaded {} records for {} split from the dataset.\".format(\n                            num_records, split_name\n                        )\n                    )\n\n            # create dataloaders\n            split_names = sorted(self.datasets.keys())\n\n            datasets = [self.datasets[split] for split in split_names]\n            is_trains = [split in self.train_splits for split in split_names]\n\n            batch_sizes = [\n                self.config.run_cfg.batch_size_train\n                if split == \"train\"\n                else self.config.run_cfg.batch_size_eval\n                for split in split_names\n            ]\n\n            
collate_fns = []\n            for dataset in datasets:\n                if isinstance(dataset, tuple) or isinstance(dataset, list):\n                    collate_fns.append([getattr(d, \"collater\", None) for d in dataset])\n                else:\n                    collate_fns.append(getattr(dataset, \"collater\", None))\n\n            dataloaders = self.create_loaders(\n                datasets=datasets,\n                num_workers=self.config.run_cfg.num_workers,\n                batch_sizes=batch_sizes,\n                is_trains=is_trains,\n                collate_fns=collate_fns,\n                dataset_ratios=dataset_ratios,\n            )\n\n            self._dataloaders = {k: v for k, v in zip(split_names, dataloaders)}\n\n        return self._dataloaders\n"
  },
  {
    "path": "lavis/tasks/__init__.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nfrom lavis.common.registry import registry\nfrom lavis.tasks.base_task import BaseTask\nfrom lavis.tasks.captioning import CaptionTask\nfrom lavis.tasks.image_text_pretrain import ImageTextPretrainTask\nfrom lavis.tasks.multimodal_classification import (\n    MultimodalClassificationTask,\n)\nfrom lavis.tasks.retrieval import RetrievalTask\nfrom lavis.tasks.vqa import VQATask, GQATask, AOKVQATask, VideoQA, FrameQA\nfrom lavis.tasks.vqa_reading_comprehension import VQARCTask, GQARCTask\nfrom lavis.tasks.dialogue import DialogueTask\n\n\ndef setup_task(cfg):\n    assert \"task\" in cfg.run_cfg, \"Task name must be provided.\"\n\n    task_name = cfg.run_cfg.task\n    task = registry.get_task_class(task_name).setup_task(cfg=cfg)\n    assert task is not None, \"Task {} not properly registered.\".format(task_name)\n\n    return task\n\n\n__all__ = [\n    \"BaseTask\",\n    \"AOKVQATask\",\n    \"RetrievalTask\",\n    \"CaptionTask\",\n    \"VQATask\",\n    \"GQATask\",\n    \"VQARCTask\",\n    \"GQARCTask\",\n    \"MultimodalClassificationTask\",\n    # \"VisualEntailmentTask\",\n    \"VideoQA\",\n    \"FrameQA\",\n    \"ImageTextPretrainTask\",\n    \"DialogueTask\",\n]\n"
  },
  {
    "path": "lavis/tasks/base_task.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport logging\nimport os\n\nimport torch\nimport torch.distributed as dist\nfrom lavis.common.dist_utils import get_rank, get_world_size, is_main_process, is_dist_avail_and_initialized\nfrom lavis.common.logger import MetricLogger, SmoothedValue\nfrom lavis.common.registry import registry\nfrom lavis.datasets.data_utils import prepare_sample\n\n\nclass BaseTask:\n    def __init__(self, **kwargs):\n        super().__init__()\n\n        self.inst_id_key = \"instance_id\"\n\n    @classmethod\n    def setup_task(cls, **kwargs):\n        return cls()\n\n    def build_model(self, cfg):\n        model_config = cfg.model_cfg\n\n        model_cls = registry.get_model_class(model_config.arch)\n        return model_cls.from_config(model_config)\n\n    def build_datasets(self, cfg):\n        \"\"\"\n        Build a dictionary of datasets, keyed by split 'train', 'valid', 'test'.\n        Download dataset and annotations automatically if not exist.\n\n        Args:\n            cfg (common.config.Config): _description_\n\n        Returns:\n            dict: Dictionary of torch.utils.data.Dataset objects by split.\n        \"\"\"\n\n        datasets = dict()\n\n        datasets_config = cfg.datasets_cfg\n        assert len(datasets_config) > 0, \"At least one dataset has to be specified.\"\n\n        for name in datasets_config:\n            dataset_config = datasets_config[name]\n            builder = registry.get_builder_class(name)(dataset_config)\n            dataset = builder.build_datasets()\n\n            datasets[name] = dataset\n\n        return datasets\n\n    def train_step(self, model, samples):\n        loss = model(samples)[\"loss\"]\n        return loss\n\n    def valid_step(self, model, samples):\n        raise NotImplementedError\n\n    def before_evaluation(self, model, dataset, **kwargs):\n        model.before_evaluation(dataset=dataset, task_type=type(self))\n\n    def after_evaluation(self, **kwargs):\n        pass\n\n    def inference_step(self):\n        raise NotImplementedError\n\n    def evaluation(self, model, data_loader, cuda_enabled=True):\n        metric_logger = MetricLogger(delimiter=\"  \")\n        header = \"Evaluation\"\n        # TODO make it configurable\n        print_freq = 10\n\n        results = []\n\n        for samples in metric_logger.log_every(data_loader, print_freq, header):\n            samples = prepare_sample(samples, cuda_enabled=cuda_enabled)\n\n            eval_output = self.valid_step(model=model, samples=samples)\n            results.extend(eval_output)\n            #break\n\n        if is_dist_avail_and_initialized():\n            dist.barrier()\n\n        return results\n\n    def train_epoch(\n        self,\n        epoch,\n        model,\n        data_loader,\n        optimizer,\n        lr_scheduler,\n        scaler=None,\n        cuda_enabled=False,\n        log_freq=50,\n        accum_grad_iters=1,\n    ):\n        return self._train_inner_loop(\n            epoch=epoch,\n            iters_per_epoch=len(data_loader),\n            model=model,\n            data_loader=data_loader,\n            optimizer=optimizer,\n            scaler=scaler,\n            lr_scheduler=lr_scheduler,\n            log_freq=log_freq,\n            cuda_enabled=cuda_enabled,\n            
accum_grad_iters=accum_grad_iters,\n        )\n\n    def train_iters(\n        self,\n        epoch,\n        start_iters,\n        iters_per_inner_epoch,\n        model,\n        data_loader,\n        optimizer,\n        lr_scheduler,\n        scaler=None,\n        cuda_enabled=False,\n        log_freq=50,\n        accum_grad_iters=1,\n    ):\n        return self._train_inner_loop(\n            epoch=epoch,\n            start_iters=start_iters,\n            iters_per_epoch=iters_per_inner_epoch,\n            model=model,\n            data_loader=data_loader,\n            optimizer=optimizer,\n            scaler=scaler,\n            lr_scheduler=lr_scheduler,\n            log_freq=log_freq,\n            cuda_enabled=cuda_enabled,\n            accum_grad_iters=accum_grad_iters,\n        )\n\n    def _train_inner_loop(\n        self,\n        epoch,\n        iters_per_epoch,\n        model,\n        data_loader,\n        optimizer,\n        lr_scheduler,\n        scaler=None,\n        start_iters=None,\n        log_freq=50,\n        cuda_enabled=False,\n        accum_grad_iters=1,\n    ):\n        \"\"\"\n        An inner training loop compatible with both epoch-based and iter-based training.\n\n        When using epoch-based, training stops after one epoch; when using iter-based,\n        training stops after #iters_per_epoch iterations.\n        \"\"\"\n        use_amp = scaler is not None\n\n        if not hasattr(data_loader, \"__next__\"):\n            # convert to iterator if not already\n            data_loader = iter(data_loader)\n\n        metric_logger = MetricLogger(delimiter=\"  \")\n        metric_logger.add_meter(\"lr\", SmoothedValue(window_size=1, fmt=\"{value:.6f}\"))\n        metric_logger.add_meter(\"loss\", SmoothedValue(window_size=1, fmt=\"{value:.4f}\"))\n\n        # if iter-based runner, schedule lr based on inner epoch.\n        logging.info(\n            \"Start training epoch {}, {} iters per inner epoch.\".format(\n                epoch, iters_per_epoch\n            )\n        )\n        header = \"Train: data epoch: [{}]\".format(epoch)\n        if start_iters is None:\n            # epoch-based runner\n            inner_epoch = epoch\n        else:\n            # In iter-based runner, we schedule the learning rate based on iterations.\n            inner_epoch = start_iters // iters_per_epoch\n            header = header + \"; inner epoch [{}]\".format(inner_epoch)\n\n        for i in metric_logger.log_every(range(iters_per_epoch), log_freq, header):\n            # if using iter-based runner, we stop after iters_per_epoch iterations.\n            if i >= iters_per_epoch:\n                break\n\n            samples = next(data_loader)\n\n            samples = prepare_sample(samples, cuda_enabled=cuda_enabled)\n            samples.update(\n                {\n                    \"epoch\": inner_epoch,\n                    \"num_iters_per_epoch\": iters_per_epoch,\n                    \"iters\": i,\n                }\n            )\n\n            lr_scheduler.step(cur_epoch=inner_epoch, cur_step=i)\n\n            with torch.cuda.amp.autocast(enabled=use_amp):\n                loss = self.train_step(model=model, samples=samples)\n\n            # after_train_step()\n            if use_amp:\n                scaler.scale(loss).backward()\n            else:\n                loss.backward()\n\n            # update gradients every accum_grad_iters iterations\n            if (i + 1) % accum_grad_iters == 0:\n                if use_amp:\n                    
scaler.step(optimizer)\n                    scaler.update()                     \n                else:    \n                    optimizer.step()\n                optimizer.zero_grad()\n\n            metric_logger.update(loss=loss.item())\n            metric_logger.update(lr=optimizer.param_groups[0][\"lr\"])\n\n        # after train_epoch()\n        # gather the stats from all processes\n        metric_logger.synchronize_between_processes()\n        logging.info(\"Averaged stats: \" + str(metric_logger.global_avg()))\n        return {\n            k: \"{:.3f}\".format(meter.global_avg)\n            for k, meter in metric_logger.meters.items()\n        }\n\n    @staticmethod\n    def save_result(result, result_dir, filename, remove_duplicate=\"\"):\n        import json\n\n        result_file = os.path.join(\n            result_dir, \"%s_rank%d.json\" % (filename, get_rank())\n        )\n        final_result_file = os.path.join(result_dir, \"%s.json\" % filename)\n\n        json.dump(result, open(result_file, \"w\"))\n\n        if is_dist_avail_and_initialized():\n            dist.barrier()\n\n        if is_main_process():\n            logging.warning(\"rank %d starts merging results.\" % get_rank())\n            # combine results from all processes\n            result = []\n\n            for rank in range(get_world_size()):\n                result_file = os.path.join(\n                    result_dir, \"%s_rank%d.json\" % (filename, rank)\n                )\n                res = json.load(open(result_file, \"r\"))\n                result += res\n\n            if remove_duplicate:\n                result_new = []\n                id_list = []\n                for res in result:\n                    if res[remove_duplicate] not in id_list:\n                        id_list.append(res[remove_duplicate])\n                        result_new.append(res)\n                result = result_new\n\n            json.dump(result, open(final_result_file, \"w\"))\n            print(\"result file saved to %s\" % final_result_file)\n\n        return final_result_file\n"
  },
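The heart of `_train_inner_loop` above is the mixed-precision plus gradient-accumulation update: every iteration backpropagates a (possibly scaled) loss, but the optimizer only steps once per `accum_grad_iters` iterations. A minimal self-contained sketch of that pattern, using a toy model and random data rather than anything from the repo:

```python
# Standalone sketch of the AMP + gradient-accumulation pattern in
# BaseTask._train_inner_loop (toy model/data; not part of the repo).
import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
use_amp = torch.cuda.is_available()
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)  # pass-through no-op on CPU
accum_grad_iters = 2                                 # step the optimizer every 2 iters

batches = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(4)]

for i, (x, y) in enumerate(batches):
    with torch.cuda.amp.autocast(enabled=use_amp):
        loss = torch.nn.functional.mse_loss(model(x), y)

    scaler.scale(loss).backward()        # gradients accumulate across iterations
    if (i + 1) % accum_grad_iters == 0:
        scaler.step(optimizer)           # unscales grads, then optimizer.step()
        scaler.update()
        optimizer.zero_grad()
```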
  {
    "path": "lavis/tasks/captioning.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport json\nimport os\n\nfrom lavis.common.dist_utils import main_process\nfrom lavis.common.registry import registry\nfrom lavis.tasks.base_task import BaseTask\n\n\n@registry.register_task(\"captioning\")\nclass CaptionTask(BaseTask):\n    def __init__(self, num_beams, max_len, min_len, evaluate, report_metric=True):\n        super().__init__()\n\n        self.num_beams = num_beams\n        self.max_len = max_len\n        self.min_len = min_len\n        self.evaluate = evaluate\n\n        self.report_metric = report_metric\n\n    @classmethod\n    def setup_task(cls, cfg):\n        run_cfg = cfg.run_cfg\n\n        num_beams = run_cfg.num_beams\n        max_len = run_cfg.max_len\n        min_len = run_cfg.min_len\n        evaluate = run_cfg.evaluate\n\n        report_metric = run_cfg.get(\"report_metric\", True)\n\n        return cls(\n            num_beams=num_beams,\n            max_len=max_len,\n            min_len=min_len,\n            evaluate=evaluate,\n            report_metric=report_metric,\n        )\n\n    def valid_step(self, model, samples):\n        results = []\n\n        # run_cfg = slf.cfg.run_cfg\n        captions = model.generate(\n            samples,\n            use_nucleus_sampling=False,\n            num_beams=self.num_beams,\n            max_length=self.max_len,\n            min_length=self.min_len,\n        )\n\n        img_ids = samples[\"image_id\"]\n        for caption, img_id in zip(captions, img_ids):\n            results.append({\"caption\": caption, \"image_id\": int(img_id)})\n\n        return results\n\n    def after_evaluation(self, val_result, split_name, epoch, **kwargs):\n        eval_result_file = self.save_result(\n            result=val_result,\n            result_dir=registry.get_path(\"result_dir\"),\n            filename=\"{}_epoch{}\".format(split_name, epoch),\n            remove_duplicate=\"image_id\",\n        )\n\n        if self.report_metric:\n            metrics = self._report_metrics(\n                eval_result_file=eval_result_file, split_name=split_name\n            )\n        else:\n            metrics = {\"agg_metrics\": 0.0}\n\n        return metrics\n\n    @main_process\n    def _report_metrics(self, eval_result_file, split_name):\n\n        # TODO better way to define this\n        coco_gt_root = os.path.join(registry.get_path(\"cache_root\"), \"coco_gt\")\n        coco_val = coco_caption_eval(coco_gt_root, eval_result_file, split_name)\n\n        agg_metrics = coco_val.eval[\"CIDEr\"] + coco_val.eval[\"Bleu_4\"]\n        log_stats = {split_name: {k: v for k, v in coco_val.eval.items()}}\n\n        with open(\n            os.path.join(registry.get_path(\"output_dir\"), \"evaluate.txt\"), \"a\"\n        ) as f:\n            f.write(json.dumps(log_stats) + \"\\n\")\n\n        coco_res = {k: v for k, v in coco_val.eval.items()}\n        coco_res[\"agg_metrics\"] = agg_metrics\n\n        return coco_res\n\n\n# TODO better structure for this.\nfrom pycocoevalcap.eval import COCOEvalCap\nfrom pycocotools.coco import COCO\nfrom torchvision.datasets.utils import download_url\n\n\ndef coco_caption_eval(coco_gt_root, results_file, split):\n    urls = {\n        \"val\": \"https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val_gt.json\",\n        \"test\": 
\"https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test_gt.json\",\n    }\n    filenames = {\n        \"val\": \"coco_karpathy_val_gt.json\",\n        \"test\": \"coco_karpathy_test_gt.json\",\n    }\n\n    download_url(urls[split], coco_gt_root)\n    annotation_file = os.path.join(coco_gt_root, filenames[split])\n\n    # create coco object and coco_result object\n    coco = COCO(annotation_file)\n    coco_result = coco.loadRes(results_file)\n\n    # create coco_eval object by taking coco and coco_result\n    coco_eval = COCOEvalCap(coco, coco_result)\n\n    # evaluate on a subset of images by setting\n    # coco_eval.params['image_id'] = coco_result.getImgIds()\n    # please remove this line when evaluating the full validation set\n    # coco_eval.params['image_id'] = coco_result.getImgIds()\n\n    # evaluate results\n    # SPICE will take a few minutes the first time, but speeds up due to caching\n    coco_eval.evaluate()\n\n    # print output evaluation scores\n    for metric, score in coco_eval.eval.items():\n        print(f\"{metric}: {score:.3f}\")\n\n    return coco_eval\n"
  },
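The caption results that `valid_step` collects are written by `save_result` as a COCO-style JSON list, which `coco_caption_eval` then loads with `COCO.loadRes` and scores with `COCOEvalCap`. A toy sketch of that file format (made-up captions and ids, not real model outputs):

```python
# Illustrative result file for captioning evaluation (hypothetical values).
import json
import os
import tempfile

results = [
    {"caption": "a dog running on the beach", "image_id": 42},
    {"caption": "two people riding bicycles", "image_id": 7},
]

result_file = os.path.join(tempfile.gettempdir(), "val_epoch0.json")
with open(result_file, "w") as f:
    json.dump(results, f)

# coco_caption_eval(coco_gt_root, result_file, "val") would then download the
# Karpathy ground-truth annotations into coco_gt_root and report BLEU/CIDEr/etc.
```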
  {
    "path": "lavis/tasks/dialogue.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport json\nimport os\n\nfrom lavis.common.dist_utils import main_process\nfrom lavis.common.logger import MetricLogger\nfrom lavis.common.registry import registry\nfrom lavis.tasks.base_task import BaseTask\nfrom lavis.datasets.data_utils import prepare_sample\n\nimport numpy as np\n\n\n@registry.register_task(\"dialogue\")\nclass DialogueTask(BaseTask):\n    def __init__(self, num_beams, max_len, min_len, evaluate, report_metric=True):\n        super().__init__()\n\n        self.num_beams = num_beams\n        self.max_len = max_len\n        self.min_len = min_len\n        self.evaluate = evaluate\n\n        self.report_metric = report_metric\n\n    @classmethod\n    def setup_task(cls, cfg):\n        run_cfg = cfg.run_cfg\n\n        num_beams = run_cfg.num_beams\n        max_len = run_cfg.max_len\n        min_len = run_cfg.min_len\n        evaluate = run_cfg.evaluate\n\n        report_metric = run_cfg.get(\"report_metric\", True)\n\n        return cls(\n            num_beams=num_beams,\n            max_len=max_len,\n            min_len=min_len,\n            evaluate=evaluate,\n            report_metric=report_metric,\n        )\n\n    def valid_step(self, model, samples):\n        results = []\n        loss = model(samples)[\"loss\"].item()\n\n        return [loss]\n\n    def after_evaluation(self, val_result, split_name, epoch, **kwargs):\n\n        if self.report_metric:\n            avg_loss = np.mean(val_result)\n            metrics = {\"agg_metrics\": avg_loss}\n        else:\n            metrics = {\"agg_metrics\": 0.0}\n\n        return metrics\n\n    @main_process\n    def _report_metrics(self, eval_result_file, split_name):\n        # TODO better way to define this\n        coco_gt_root = os.path.join(registry.get_path(\"cache_root\"), \"coco_gt\")\n        coco_val = coco_dialogue_eval(coco_gt_root, eval_result_file, split_name)\n\n        agg_metrics = coco_val.eval[\"CIDEr\"] + coco_val.eval[\"Bleu_4\"]\n        log_stats = {split_name: {k: v for k, v in coco_val.eval.items()}}\n\n        with open(\n            os.path.join(registry.get_path(\"output_dir\"), \"evaluate.txt\"), \"a\"\n        ) as f:\n            f.write(json.dumps(log_stats) + \"\\n\")\n\n        coco_res = {k: v for k, v in coco_val.eval.items()}\n        coco_res[\"agg_metrics\"] = agg_metrics\n\n        return coco_res\n\n\n# TODO better structure for this.\nfrom pycocoevalcap.eval import COCOEvalCap\nfrom pycocotools.coco import COCO\nfrom torchvision.datasets.utils import download_url\n\n\ndef coco_dialogue_eval(coco_gt_root, results_file, split):\n\n    urls = {\n        \"val\": \"https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val_gt.json\",\n        \"test\": \"https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test_gt.json\",\n    }\n    filenames = {\n        \"val\": \"coco_karpathy_val_gt.json\",\n        \"test\": \"coco_karpathy_test_gt.json\",\n    }\n\n    download_url(urls[split], coco_gt_root)\n    annotation_file = os.path.join(coco_gt_root, filenames[split])\n\n    # create coco object and coco_result object\n    coco = COCO(annotation_file)\n    coco_result = coco.loadRes(results_file)\n\n    # create coco_eval object by taking coco and coco_result\n    coco_eval = 
COCOEvalCap(coco, coco_result)\n\n    # evaluate on a subset of images by setting\n    # coco_eval.params['image_id'] = coco_result.getImgIds()\n    # please remove this line when evaluating the full validation set\n    # coco_eval.params['image_id'] = coco_result.getImgIds()\n\n    # evaluate results\n    # SPICE will take a few minutes the first time, but speeds up due to caching\n    coco_eval.evaluate()\n\n    # print output evaluation scores\n    for metric, score in coco_eval.eval.items():\n        print(f\"{metric}: {score:.3f}\")\n\n    return coco_eval\n"
  },
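`DialogueTask` differs from the other generation tasks: `valid_step` returns the batch loss itself, so `after_evaluation` simply averages those losses into `agg_metrics` (the `_report_metrics`/`coco_dialogue_eval` path defined below it is not invoked by that flow). A toy illustration with made-up loss values:

```python
# Sketch of DialogueTask.after_evaluation: agg_metrics is the mean validation loss.
import numpy as np

val_result = [2.31, 2.07, 2.45]   # per-batch losses from valid_step (hypothetical)
metrics = {"agg_metrics": float(np.mean(val_result))}
print(metrics)                    # {'agg_metrics': 2.2766...}
```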
  {
    "path": "lavis/tasks/image_text_pretrain.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nfrom lavis.common.registry import registry\nfrom lavis.tasks.base_task import BaseTask\n\n\n@registry.register_task(\"image_text_pretrain\")\nclass ImageTextPretrainTask(BaseTask):\n    def __init__(self):\n        super().__init__()\n\n    def evaluation(self, model, data_loader, cuda_enabled=True):\n        pass\n"
  },
  {
    "path": "lavis/tasks/multimodal_classification.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport json\nimport os\nimport logging\n\nimport numpy as np\nimport torch\nfrom lavis.common.dist_utils import main_process\nfrom lavis.common.registry import registry\nfrom lavis.tasks.base_task import BaseTask\n\n\n@registry.register_task(\"multimodal_classification\")\nclass MultimodalClassificationTask(BaseTask):\n    def __init__(self):\n        super().__init__()\n\n    def valid_step(self, model, samples):\n        results = []\n\n        outputs = model.predict(samples)\n\n        predictions = outputs[\"predictions\"]\n        targets = outputs[\"targets\"]\n\n        predictions = predictions.max(1)[1].cpu().numpy()\n        targets = targets.cpu().numpy()\n\n        indices = samples[self.inst_id_key]\n\n        for pred, tgt, index in zip(predictions, targets, indices):\n            if isinstance(index, torch.Tensor):\n                index = index.item()\n\n            results.append(\n                {\n                    self.inst_id_key: index,\n                    \"prediction\": pred.item(),\n                    \"target\": tgt.item(),\n                }\n            )\n\n        return results\n\n    def after_evaluation(self, val_result, split_name, epoch, **kwargs):\n        eval_result_file = self.save_result(\n            result=val_result,\n            result_dir=registry.get_path(\"result_dir\"),\n            filename=\"{}_epoch{}\".format(split_name, epoch),\n            remove_duplicate=self.inst_id_key,\n        )\n\n        metrics = self._report_metrics(\n            eval_result_file=eval_result_file, split_name=split_name\n        )\n\n        return metrics\n\n    @main_process\n    def _report_metrics(self, eval_result_file, split_name):\n        results = json.load(open(eval_result_file))\n\n        predictions = np.array([res[\"prediction\"] for res in results])\n        targets = np.array([res[\"target\"] for res in results])\n\n        accuracy = (targets == predictions).sum() / targets.shape[0]\n        metrics = {\"agg_metrics\": accuracy, \"acc\": accuracy}\n\n        log_stats = {split_name: {k: v for k, v in metrics.items()}}\n\n        with open(\n            os.path.join(registry.get_path(\"output_dir\"), \"evaluate.txt\"), \"a\"\n        ) as f:\n            f.write(json.dumps(log_stats) + \"\\n\")\n\n        logging.info(metrics)\n        return metrics\n"
  },
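`MultimodalClassificationTask._report_metrics` reloads the merged result file and computes plain accuracy over the saved prediction/target pairs. A toy sketch with hypothetical records:

```python
# Sketch of the accuracy computation in _report_metrics (hypothetical records).
import numpy as np

results = [
    {"instance_id": 0, "prediction": 1, "target": 1},
    {"instance_id": 1, "prediction": 0, "target": 2},
    {"instance_id": 2, "prediction": 3, "target": 3},
]

predictions = np.array([res["prediction"] for res in results])
targets = np.array([res["target"] for res in results])

accuracy = (targets == predictions).sum() / targets.shape[0]
print({"agg_metrics": accuracy, "acc": accuracy})  # accuracy = 2/3
```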
  {
    "path": "lavis/tasks/retrieval.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport json\nimport logging\nimport os\n\nimport numpy as np\nimport torch\nfrom lavis.common.dist_utils import is_main_process\nfrom lavis.common.registry import registry\nfrom lavis.tasks.base_task import BaseTask\n\n\n@registry.register_task(\"retrieval\")\nclass RetrievalTask(BaseTask):\n    def __init__(self, cfg):\n        super().__init__()\n\n        self.cfg = cfg\n\n    @classmethod\n    def setup_task(cls, cfg):\n        run_cfg = cfg.run_cfg\n\n        return cls(cfg=run_cfg)\n\n    def evaluation(self, model, data_loader, **kwargs):\n        # score_i2t, score_t2i = model.compute_sim_matrix(model, data_loader)\n        score_i2t, score_t2i = model.compute_sim_matrix(data_loader, task_cfg=self.cfg)\n\n        if is_main_process():\n            eval_result = self._report_metrics(\n                score_i2t,\n                score_t2i,\n                data_loader.dataset.txt2img,\n                data_loader.dataset.img2txt,\n            )\n            logging.info(eval_result)\n        else:\n            eval_result = None\n\n        return eval_result\n\n    def after_evaluation(self, val_result, **kwargs):\n        return val_result\n\n    @staticmethod\n    @torch.no_grad()\n    def _report_metrics(scores_i2t, scores_t2i, txt2img, img2txt):\n\n        # Images->Text\n        ranks = np.zeros(scores_i2t.shape[0])\n        for index, score in enumerate(scores_i2t):\n            inds = np.argsort(score)[::-1]\n            # Score\n            rank = 1e20\n            for i in img2txt[index]:\n                tmp = np.where(inds == i)[0][0]\n                if tmp < rank:\n                    rank = tmp\n            ranks[index] = rank\n\n        # Compute metrics\n        tr1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)\n        tr5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)\n        tr10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)\n\n        # Text->Images\n        ranks = np.zeros(scores_t2i.shape[0])\n\n        for index, score in enumerate(scores_t2i):\n            inds = np.argsort(score)[::-1]\n            ranks[index] = np.where(inds == txt2img[index])[0][0]\n\n        # Compute metrics\n        ir1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)\n        ir5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)\n        ir10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)\n\n        tr_mean = (tr1 + tr5 + tr10) / 3\n        ir_mean = (ir1 + ir5 + ir10) / 3\n        r_mean = (tr_mean + ir_mean) / 2\n\n        agg_metrics = (tr1 + tr5 + tr10) / 3\n\n        eval_result = {\n            \"txt_r1\": tr1,\n            \"txt_r5\": tr5,\n            \"txt_r10\": tr10,\n            \"txt_r_mean\": tr_mean,\n            \"img_r1\": ir1,\n            \"img_r5\": ir5,\n            \"img_r10\": ir10,\n            \"img_r_mean\": ir_mean,\n            \"r_mean\": r_mean,\n            \"agg_metrics\": agg_metrics,\n        }\n        with open(\n            os.path.join(registry.get_path(\"output_dir\"), \"evaluate.txt\"), \"a\"\n        ) as f:\n            f.write(json.dumps(eval_result) + \"\\n\")\n        return eval_result\n"
  },
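`RetrievalTask._report_metrics` turns the two similarity matrices into Recall@{1,5,10}: for each query it sorts candidates by score and records the rank of the ground-truth match. A toy text-to-image example with an assumed 2×3 score matrix:

```python
# Toy Recall@K computation mirroring the text->image branch of _report_metrics.
import numpy as np

scores_t2i = np.array([
    [0.9, 0.1, 0.3],   # text 0; ground-truth image 0
    [0.2, 0.4, 0.8],   # text 1; ground-truth image 1
])
txt2img = {0: 0, 1: 1}

ranks = np.zeros(scores_t2i.shape[0])
for index, score in enumerate(scores_t2i):
    inds = np.argsort(score)[::-1]                        # candidates, best first
    ranks[index] = np.where(inds == txt2img[index])[0][0]

ir1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)    # Recall@1
print(ir1)  # 50.0: text 1's true image is only ranked second
```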
  {
    "path": "lavis/tasks/vqa.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport logging\nimport json\nimport os\nimport torch\nimport numpy as np\nimport random\n\nimport lavis.common.dist_utils as dist_utils\nfrom lavis.common.registry import registry\nfrom lavis.common.vqa_tools.vqa import VQA\nfrom lavis.common.vqa_tools.vqa_eval import VQAEval\nfrom lavis.tasks.base_task import BaseTask\nfrom lavis.common.dist_utils import main_process\n\n@registry.register_task(\"vqa\")\nclass VQATask(BaseTask):\n    def __init__(\n        self,\n        num_beams,\n        max_len,\n        min_len,\n        evaluate,\n        num_ans_candidates,\n        inference_method=\"rank\",\n        prompt=\"\",\n    ):\n        super().__init__()\n\n        self.num_beams = num_beams\n        self.max_len = max_len\n        self.min_len = min_len\n\n        self.evaluate = evaluate\n        self.inference_method = inference_method\n        self.num_ans_candidates = num_ans_candidates\n        self.prompt = prompt\n\n        self.answer_list = None\n\n        self.ques_files = dict()\n        self.anno_files = dict()\n\n    @classmethod\n    def setup_task(cls, cfg):\n        run_cfg = cfg.run_cfg\n\n        num_beams = run_cfg.get(\"num_beams\", 3)\n        max_len = run_cfg.get(\"max_len\", 10)\n        min_len = run_cfg.get(\"min_len\", 1)\n\n        evaluate = run_cfg.get(\"evaluate\", False)\n\n        inference_method = run_cfg.get(\"inference_method\", \"rank\")\n        num_ans_candidates = run_cfg.get(\"num_ans_candidates\", 128)\n        prompt = run_cfg.get(\"prompt\", \"\")\n\n        return cls(\n            num_beams=num_beams,\n            max_len=max_len,\n            min_len=min_len,\n            evaluate=evaluate,\n            num_ans_candidates=num_ans_candidates,\n            inference_method=inference_method,\n            prompt=prompt,\n        )\n\n    def build_datasets(self, cfg):\n        datasets = super().build_datasets(cfg)\n\n        # get question file, annotation file and anwser list in COCO format\n        for dataset in datasets.values():\n            for split in dataset:\n                if (\n                    hasattr(dataset[split], \"coco_fmt_qust_file\")\n                    and dataset[split].coco_fmt_qust_file is not None\n                ):\n                    self.ques_files[split] = dataset[split].coco_fmt_qust_file\n                    self.anno_files[split] = dataset[split].coco_fmt_anno_file\n\n                try:\n                    self.answer_list = dataset[split].answer_list\n                except AttributeError:\n                    # if answer_list is not provided, then set it to None\n                    pass\n\n        if len(self.ques_files) > 0:\n            assert len(self.ques_files) == len(\n                self.anno_files\n            ), \"Only support one split for evaluation.\"\n\n        return datasets\n\n    def valid_step(self, model, samples):\n        answers = model.predict_answers(\n            samples=samples,\n            answer_list=self.answer_list,\n            inference_method=self.inference_method,\n            num_beams=self.num_beams,\n            max_len=self.max_len,\n            min_len=self.min_len,\n            num_ans_candidates=self.num_ans_candidates,\n            prompt=self.prompt,\n        )\n        pred_qa_pairs = []\n\n        
question_id = samples[\"question_id\"]\n        for answer, ques_id in zip(answers, question_id):\n            ques_id = int(ques_id.item())\n            pred_qa_pairs.append({\"question_id\": ques_id, \"answer\": answer})\n\n        return pred_qa_pairs\n\n    def after_evaluation(self, val_result, split_name, **kwargs):\n        result_file = self.save_result(\n            val_result,\n            result_dir=registry.get_path(\"result_dir\"),\n            filename=f\"{split_name}_vqa_result\",\n            remove_duplicate=\"question_id\",\n        )\n\n        metrics = self._report_metrics(result_file=result_file, split=split_name)\n\n        return metrics\n\n    @dist_utils.main_process\n    def _report_metrics(self, result_file, split):\n        \"\"\"\n        Use official VQA evaluation script to report metrics.\n        \"\"\"\n        metrics = {}\n\n        if split in self.ques_files and split in self.anno_files:\n            vqa = VQA(self.anno_files[split], self.ques_files[split])\n            vqa_result = vqa.loadRes(\n                resFile=result_file, quesFile=self.ques_files[split]\n            )\n\n            # create vqaEval object by taking vqa and vqaRes\n            # n is precision of accuracy (number of places after decimal), default is 2\n            vqa_scorer = VQAEval(vqa, vqa_result, n=2)\n            logging.info(\"Start VQA evaluation.\")\n            vqa_scorer.evaluate()\n\n            # print accuracies\n            overall_acc = vqa_scorer.accuracy[\"overall\"]\n            metrics[\"agg_metrics\"] = overall_acc\n\n            logging.info(\"Overall Accuracy is: %.02f\\n\" % overall_acc)\n            logging.info(\"Per Answer Type Accuracy is the following:\")\n\n            for ans_type in vqa_scorer.accuracy[\"perAnswerType\"]:\n                logging.info(\n                    \"%s : %.02f\"\n                    % (ans_type, vqa_scorer.accuracy[\"perAnswerType\"][ans_type])\n                )\n                metrics[ans_type] = vqa_scorer.accuracy[\"perAnswerType\"][ans_type]\n\n            with open(\n                os.path.join(registry.get_path(\"output_dir\"), \"evaluate.txt\"), \"a\"\n            ) as f:\n                f.write(json.dumps(metrics) + \"\\n\")\n\n        return metrics\n\n@registry.register_task(\"gqa\")\nclass GQATask(VQATask):\n    def valid_step(self, model, samples):\n        answers = model.predict_answers(\n            samples=samples,\n            answer_list=self.answer_list,\n            inference_method=self.inference_method,\n            num_beams=self.num_beams,\n            max_len=self.max_len,\n            min_len=self.min_len,\n            num_ans_candidates=self.num_ans_candidates,\n            prompt=self.prompt,\n        )\n        pred_qa_pairs = []\n\n        question_id = samples[\"question_id\"]\n        gt_answers = samples[\"answer\"]\n        \n        for answer, ques_id, gt_answer in zip(answers, question_id, gt_answers):\n            ques_id = int(ques_id.item())\n            pred_qa_pairs.append({\"question_id\": ques_id, \"pred_ans\": answer, \"gt_ans\": gt_answer})\n\n        return pred_qa_pairs\n        \n    @dist_utils.main_process\n    def _report_metrics(self, result_file, split):\n        \"\"\"\n        TODO: add other evaluation metrics for GQA\n        \"\"\"\n\n        results = json.load(open(result_file, \"r\"))\n        acc = []\n        vqa_tool = VQAEval()\n\n        for res in results:\n            if res[\"gt_ans\"] is None:\n                # prepare test results for 
leaderboard evaluation\n                self._save_result_leaderboard(results)\n                return\n\n            gt_ans = res[\"gt_ans\"]\n            pred = res[\"pred_ans\"]\n\n            if self.inference_method == \"generate\":\n                pred = vqa_tool.processPunctuation(pred)\n                pred = vqa_tool.processDigitArticle(pred)\n\n            vqa_acc = 1 if pred == gt_ans else 0\n\n            acc.append(vqa_acc)\n\n        accuracy = sum(acc) / len(acc) * 100\n        metrics = {\"agg_metrics\": accuracy, \"acc\": accuracy}\n\n        with open(\n            os.path.join(registry.get_path(\"output_dir\"), \"evaluate.txt\"), \"a\"\n        ) as f:\n            f.write(json.dumps(metrics) + \"\\n\")\n\n        logging.info(metrics)\n\n        return metrics\n        \n\n@registry.register_task(\"aok_vqa\")\nclass AOKVQATask(VQATask):\n    def valid_step(self, model, samples):\n        answers = model.predict_answers(\n            samples=samples,\n            answer_list=self.answer_list,\n            inference_method=self.inference_method,\n            num_beams=self.num_beams,\n            max_len=self.max_len,\n            min_len=self.min_len,\n            num_ans_candidates=self.num_ans_candidates,\n        )\n\n        pred_qa_pairs = []\n\n        question_id = samples[\"question_id\"]\n        gt_answers = samples[\"direct_answers\"]\n\n        for pred_answer, ques_id, gt_answer in zip(answers, question_id, gt_answers):\n            pred_qa_pairs.append(\n                {\"question_id\": ques_id, \"pred_ans\": pred_answer, \"gt_ans\": gt_answer}\n            )\n\n        return pred_qa_pairs\n\n    @dist_utils.main_process\n    def _report_metrics(self, result_file, split):\n        \"\"\"\n        Implementing accuracy computation for AOKVQA, see\n        https://github.com/allenai/aokvqa/blob/main/evaluation/eval_predictions.py#L45 for details.\n        \"\"\"\n        # TODO add evaluation for multi-choice\n\n        results = json.load(open(result_file, \"r\"))\n        acc = []\n\n        for res in results:\n            if res[\"gt_ans\"] is None:\n                # prepare test results for leaderboard evaluation\n                self._save_result_leaderboard(results)\n                return\n\n            pred = res[\"pred_ans\"]\n            gt_ans = res[\"gt_ans\"]\n\n            num_match = sum([pred == gt for gt in gt_ans])\n            vqa_acc = min(1.0, num_match / 3.0)\n\n            acc.append(vqa_acc)\n\n        accuracy = sum(acc) / len(acc) * 100\n        metrics = {\"agg_metrics\": accuracy, \"acc\": accuracy}\n\n        with open(\n            os.path.join(registry.get_path(\"output_dir\"), \"evaluate.txt\"), \"a\"\n        ) as f:\n            f.write(json.dumps(metrics) + \"\\n\")\n\n        logging.info(metrics)\n\n        return metrics\n\n    @dist_utils.main_process\n    def _save_result_leaderboard(self, results):\n        \"\"\"\n        Saving the results in the format required for leaderboard evaluation.\n\n        [TODO] add support for multi-choice.\n        \"\"\"\n        result_leaderboard = dict()\n        for res in results:\n            result_leaderboard[res[\"question_id\"]] = {\n                \"direct_answer\": res[\"pred_ans\"],\n                \"multiple_choice\": \"\",\n            }\n\n        result_file = registry.get_path(\"result_dir\") + \"_leaderboard.json\"\n\n        with open(result_file, \"w\") as f:\n            json.dump(result_leaderboard, f)\n\n        logging.info(f\"Saved results for 
leaderboard evaluation at {result_file}\")\n    \n@registry.register_task(\"frameqa\")\nclass FrameQA(BaseTask):\n    def __init__(self):\n        super().__init__()\n        self.ANS_MAPPING = {'A':0, 'B':1, 'C':2, 'D':3, 'E':4}\n\n    def valid_step(self, model, samples):\n        results = []\n\n        outputs = model.generate(samples)\n\n        answer = outputs[\"answer\"]\n        qid = outputs[\"qid\"]\n        output_text = outputs['output_text']\n        temp_idx = outputs['temp_idx']\n        assert len(qid)==len(temp_idx)\n        assert len(qid)==len(output_text)\n        assert len(qid)==len(answer) \n        \n        for a, q, o, i in zip(answer, qid, output_text, temp_idx):\n            # l =  l[self.ANS_MAPPING[a[-1]]]\n            results.append(\n                {\n                    \"qid\": q,\n                    'idx': i,\n                    \"prediction\": o,\n                    \"target\": self.ANS_MAPPING[a[-1]],\n                }\n            )\n\n        return results\n\n    def after_evaluation(self, val_result, split_name, epoch, **kwargs):\n        eval_result_file = self.save_result(\n            result=val_result,\n            result_dir=registry.get_path(\"result_dir\"),\n            filename=\"{}_epoch{}\".format(split_name, epoch)\n        )\n\n        metrics = self._report_metrics(\n            eval_result_file=eval_result_file, split_name=split_name\n        )\n\n        return metrics\n\n    @main_process\n    def _report_metrics(self, eval_result_file, split_name):\n        results = json.load(open(eval_result_file))\n        total_num = len(results)\n        acc = 0\n        group_by_qid = {}\n        qtype_correct_dict = {}\n        qtype_total_dict = {}\n        for r in results:\n            \n            if r['qid'] not in group_by_qid:\n                group_by_qid[r['qid']] = {} \n                group_by_qid[r['qid']]['idx'] = [r['idx']]\n                group_by_qid[r['qid']]['pred'] = [r['prediction']]\n                group_by_qid[r['qid']]['target'] = r['target']\n            else:\n                group_by_qid[r['qid']]['idx'].append(r['idx'])\n                group_by_qid[r['qid']]['pred'].append(r['prediction'])\n                \n            qtype = r['qid'][0]\n            if qtype not in qtype_total_dict:\n                qtype_total_dict[qtype] = 1\n            else:\n                qtype_total_dict[qtype] += 1 \n\n            if r['prediction'] == r['target']:\n                acc += 1\n                if qtype not in qtype_correct_dict:\n                    qtype_correct_dict[qtype] = 1\n                else:\n                    qtype_correct_dict[qtype] += 1 \n                \n        oracle = 0 \n        num = len(group_by_qid.keys())\n        for q in group_by_qid:\n            if group_by_qid[q]['target'] in group_by_qid[q]['pred']:\n                oracle += 1\n        \n        metrics = {\"agg_metrics\": oracle/num , 'num': num, 'avg_acc': acc/total_num * 100, 'total':total_num}\n        \n        for qtype in qtype_total_dict:\n            metrics[qtype] = qtype_correct_dict[qtype] / qtype_total_dict[qtype] * 100\n\n        log_stats = {split_name: {k: v for k, v in metrics.items()}}\n\n        with open(\n            os.path.join(registry.get_path(\"output_dir\"), \"evaluate.txt\"), \"a\"\n        ) as f:\n            f.write(json.dumps(log_stats) + \"\\n\")\n\n        logging.info(metrics)\n        return metrics\n    \n\n@registry.register_task(\"videoqa\")\nclass VideoQA(BaseTask):\n    def __init__(self):\n 
       super().__init__()\n        self.ANS_MAPPING = {'A':0, 'B':1, 'C':2, 'D':3, 'E':4}\n\n    def valid_step(self, model, samples):\n        results = []\n\n        outputs = model.generate(samples)\n\n        answer = outputs[\"answer\"]\n        qid = outputs[\"qid\"]\n        output_text = outputs['output_text']\n        if 'frame_idx' in outputs:\n            frame_idx = outputs['frame_idx']\n        else:\n            frame_idx = [0 for i in range(len(qid))]\n        # print(qid)\n        # print(len(output_text), output_text)\n        assert len(qid)==len(output_text)\n        assert len(qid)==len(answer) \n        \n        for a, q, o, f in zip(answer, qid, output_text, frame_idx):\n            # l =  l[self.ANS_MAPPING[a[-1]]]\n            results.append(\n                {\n                    \"qid\": q,\n                    \"prediction\": o,\n                    \"target\": self.ANS_MAPPING[a[-1]],\n                    \"frame_idx\": f\n                }\n            )\n\n        return results\n\n    def after_evaluation(self, val_result, split_name, epoch, **kwargs):\n        eval_result_file = self.save_result(\n            result=val_result,\n            result_dir=registry.get_path(\"result_dir\"),\n            filename=\"{}_epoch{}\".format(split_name, epoch)\n        )\n\n        metrics = self._report_metrics(\n            eval_result_file=eval_result_file, split_name=split_name\n        )\n\n        return metrics\n\n    @main_process\n    def _report_metrics(self, eval_result_file, split_name):\n        results = json.load(open(eval_result_file))\n        total_num = len(results)\n        acc = 0\n        qtype_correct_dict = {}\n        qtype_total_dict = {}\n        for r in results:    \n            qtype = r['qid'].split('_')[0]\n            if qtype not in qtype_total_dict:\n                qtype_total_dict[qtype] = 1\n            else:\n                qtype_total_dict[qtype] += 1 \n\n            if r['prediction'] == r['target']:\n                acc += 1\n                if qtype not in qtype_correct_dict:\n                    qtype_correct_dict[qtype] = 1\n                else:\n                    qtype_correct_dict[qtype] += 1 \n        \n        metrics = {\"agg_metrics\": acc/total_num , 'total':total_num}\n        \n        for qtype in qtype_total_dict:\n            metrics[qtype] = qtype_correct_dict[qtype] / qtype_total_dict[qtype] * 100\n            \n        # for STAR\n        if ('Interaction' in metrics) and ('Sequence' in metrics) and ('Prediction' in metrics) and ('Feasibility' in metrics):\n            metrics[\"agg_metrics\"] = (metrics['Interaction'] + metrics['Sequence'] + metrics['Prediction'] + metrics['Feasibility']) / 4\n\n        log_stats = {split_name: {k: v for k, v in metrics.items()}}\n\n        with open(\n            os.path.join(registry.get_path(\"output_dir\"), \"evaluate.txt\"), \"a\"\n        ) as f:\n            f.write(json.dumps(log_stats) + \"\\n\")\n\n        logging.info(metrics)\n        return metrics\n    \n    \n@registry.register_task(\"moment_retrieval\")\nclass MR(BaseTask):\n    def __init__(self):\n        super().__init__()\n        self.ANS_MAPPING = {'no': 0, 'yes': 1}\n\n    def valid_step(self, model, samples):\n        results = []\n\n        outputs = model.generate(samples)\n        answer = outputs['answer']\n        qid = outputs['qid']\n        score = outputs['yes_score']\n        pred = outputs['pred_ans']\n        assert len(qid)==len(answer)\n        assert len(qid)==len(score)\n        
assert len(qid)==len(pred) \n        \n        i = 0\n        for a, q, s, p in zip(answer, qid, score, pred):\n            # l =  l[self.ANS_MAPPING[a[-1]]]\n            results.append(\n                {\n                    \"qid\": q + '_' + str(i),\n                    \"prediction\": p,\n                    \"target\": self.ANS_MAPPING[a],\n                    'score': s\n                }\n            )\n            i += 1\n\n        return results\n\n    def after_evaluation(self, val_result, split_name, epoch, **kwargs):\n        eval_result_file = self.save_result(\n            result=val_result,\n            result_dir=registry.get_path(\"result_dir\"),\n            filename=\"{}_epoch{}\".format(split_name, epoch)\n        )\n\n        metrics = self._report_metrics(\n            eval_result_file=eval_result_file, split_name=split_name\n        )\n\n        return metrics\n\n    @main_process\n    def _report_metrics(self, eval_result_file, split_name):\n        results = json.load(open(eval_result_file))\n        total_num = len(results)\n        acc = 0\n        for r in results:\n            if r['prediction'] == r['target']:\n                acc += 1\n        metrics = {\"agg_metrics\": acc / total_num, 'total': total_num}\n        log_stats = {split_name: {k: v for k, v in metrics.items()}}\n\n        with open(\n            os.path.join(registry.get_path(\"output_dir\"), \"evaluate.txt\"), \"a\"\n        ) as f:\n            f.write(json.dumps(log_stats) + \"\\n\")\n\n        logging.info(metrics)\n        return metrics"
  },
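In the `videoqa` task above, the ground-truth option letter is mapped to an index via `ANS_MAPPING`, accuracy is tracked per question type (the prefix of `qid`), and for STAR the four type accuracies are averaged into `agg_metrics`. A small sketch of that bookkeeping with hypothetical records:

```python
# Sketch of the per-question-type accuracy in VideoQA._report_metrics
# (hypothetical qids and predictions).
ANS_MAPPING = {'A': 0, 'B': 1, 'C': 2, 'D': 3, 'E': 4}

results = [
    {"qid": "Interaction_T1_13", "prediction": 2, "target": ANS_MAPPING['C']},
    {"qid": "Sequence_T2_4",     "prediction": 0, "target": ANS_MAPPING['B']},
    {"qid": "Sequence_T2_9",     "prediction": 1, "target": ANS_MAPPING['B']},
]

total, correct = {}, {}
for r in results:
    qtype = r["qid"].split('_')[0]
    total[qtype] = total.get(qtype, 0) + 1
    if r["prediction"] == r["target"]:
        correct[qtype] = correct.get(qtype, 0) + 1

metrics = {q: 100.0 * correct.get(q, 0) / total[q] for q in total}
print(metrics)  # {'Interaction': 100.0, 'Sequence': 50.0}
# For STAR, agg_metrics becomes the mean of the Interaction/Sequence/
# Prediction/Feasibility entries once all four types are present.
```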
  {
    "path": "lavis/tasks/vqa_reading_comprehension.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport logging\nimport json\nimport os\nimport torch\nimport torch.distributed as dist\nfrom itertools import chain\n\nimport lavis.common.dist_utils as dist_utils\nfrom lavis.common.dist_utils import get_rank, get_world_size, is_main_process\nfrom lavis.common.registry import registry\nfrom lavis.common.vqa_tools.vqa_eval import VQAEval as VQATool\nfrom lavis.tasks.vqa import VQATask\n\n\n@registry.register_task(\"vqa_reading_comprehension\")\nclass VQARCTask(VQATask):\n    def __init__(\n        self,\n        num_beams,\n        max_len,\n        min_len,\n        evaluate,\n        num_ans_candidates,\n        inference_method=\"rank\",\n        **kwargs,\n    ):\n        super().__init__(num_beams, max_len, min_len, evaluate, num_ans_candidates, inference_method)\n\n        self.config = kwargs.get('config')\n\n    @classmethod\n    def setup_task(cls, cfg):\n        run_cfg = cfg.run_cfg\n\n        num_beams = run_cfg.get(\"num_beams\", 3)\n        max_len = run_cfg.get(\"max_len\", 10)\n        min_len = run_cfg.get(\"min_len\", 1)\n\n        evaluate = run_cfg.get(\"evaluate\", False)\n\n        inference_method = run_cfg.get(\"inference_method\", \"rank\")\n        num_ans_candidates = run_cfg.get(\"num_ans_candidates\", 128)\n\n        return cls(\n            num_beams=num_beams,\n            max_len=max_len,\n            min_len=min_len,\n            evaluate=evaluate,\n            num_ans_candidates=num_ans_candidates,\n            inference_method=inference_method,\n            config=run_cfg,\n        )\n\n    def valid_step(self, model, samples):\n        answers, captions, gradcams = model.predict_answers(\n            samples=samples,\n            inference_method=self.inference_method,\n            num_beams=self.num_beams,\n            max_len=self.max_len,\n            min_len=self.min_len,\n            internal_bsz_fid=self.config['internal_bsz_fid'],\n            num_captions=self.config['num_captions'],\n            num_captions_fid=self.config['num_captions_fid'],\n            cap_max_length=self.config['cap_max_length'],\n            cap_min_length=self.config['cap_min_length'],\n            top_k=self.config['top_k'],\n            top_p=self.config['top_p'],\n            repetition_penalty=self.config['repetition_penalty'],\n            num_patches=self.config['num_patches'],\n            block_num=self.config['block_num'],\n        )\n\n        pred_qa_pairs = []\n        sample_captions = []\n        sample_gradcams = []\n\n        question_id = samples[\"question_id\"]\n        for answer, caption, gradcam, ques_id in zip(answers, captions, gradcams, question_id):\n            ques_id = int(ques_id.item())\n            pred_qa_pairs.append({\"question_id\": ques_id, \"answer\": answer})\n            sample_captions.append({\"question_id\": ques_id, \"caption\": caption})\n            sample_gradcams.append({\"question_id\": ques_id, \"gradcam\": gradcam})\n\n        return [sample_gradcams, sample_captions, pred_qa_pairs]\n\n    def after_evaluation(self, val_result, split_name, **kwargs):\n        result_ = list(chain(*val_result[0::3]))\n        result_file = self.save_gradcam(\n            result_,\n            result_dir=registry.get_path(\"result_dir\"),\n            
filename=f\"{split_name}_gradcam_result\",\n            remove_duplicate=\"question_id\",\n        )\n\n        result_ = list(chain(*val_result[1::3]))\n        result_file = self.save_result(\n            result_,\n            result_dir=registry.get_path(\"result_dir\"),\n            filename=f\"{split_name}_caption_result\",\n            remove_duplicate=\"question_id\",\n        )\n\n        result_ = list(chain(*val_result[2::3]))\n        result_file = self.save_result(\n            result_,\n            result_dir=registry.get_path(\"result_dir\"),\n            filename=f\"{split_name}_vqa_result\",\n            remove_duplicate=\"question_id\",\n        )\n\n        metrics = self._report_metrics(result_file=result_file, split=split_name)\n\n        return metrics\n\n    def save_gradcam(self, result, result_dir, filename, remove_duplicate=\"\"):\n        result_file = os.path.join(result_dir, '%s_rank%d.pth' % (filename, get_rank()))\n        final_result_file = os.path.join(result_dir, '%s.pth' % filename)\n        torch.save({'result': result}, result_file)\n\n        dist.barrier()\n\n        if is_main_process():\n            logging.warning(\"rank %d starts merging results.\" % get_rank())\n            # combine results from all processes\n            result = []\n\n            for rank in range(get_world_size()):\n                result_file = os.path.join(result_dir, '%s_rank%d.pth' % (filename, rank))\n                res_ckpt = torch.load(result_file, map_location='cpu')\n                res = res_ckpt['result']\n\n                result += res\n\n            if remove_duplicate:\n                result_new = []\n                id_list = []\n                for res in result:\n                    if res[remove_duplicate] not in id_list:\n                        id_list.append(res[remove_duplicate])\n                        result_new.append(res)\n                result = result_new\n\n            torch.save({'result': result}, final_result_file)\n            print(\"result file saved to %s\" % final_result_file)\n\n        return final_result_file\n\n\n@registry.register_task(\"gqa_reading_comprehension\")\nclass GQARCTask(VQARCTask):\n    def valid_step(self, model, samples):\n        answers, captions, gradcams = model.predict_answers(\n            samples=samples,\n            inference_method=self.inference_method,\n            num_beams=self.num_beams,\n            max_len=self.max_len,\n            min_len=self.min_len,\n            internal_bsz_fid=self.config['internal_bsz_fid'],\n            num_captions=self.config['num_captions'],\n            num_captions_fid=self.config['num_captions_fid'],\n            cap_max_length=self.config['cap_max_length'],\n            cap_min_length=self.config['cap_min_length'],\n            top_k=self.config['top_k'],\n            top_p=self.config['top_p'],\n            repetition_penalty=self.config['repetition_penalty'],\n            num_patches=self.config['num_patches'],\n            block_num=self.config['block_num'],\n        )\n\n        pred_qa_pairs = []\n        sample_captions = []\n        sample_gradcams = []\n\n        question_id = samples[\"question_id\"]\n        gt_answers = samples[\"answer\"]\n\n        for pred_answer, caption, gradcam, ques_id, gt_answer in zip(answers, captions, gradcams, question_id, gt_answers):\n            ques_id = int(ques_id.item())\n            pred_qa_pairs.append({\"question_id\": ques_id, \"pred_ans\": pred_answer, \"gt_ans\": gt_answer})\n            
sample_captions.append({\"question_id\": ques_id, \"caption\": caption})\n            sample_gradcams.append({\"question_id\": ques_id, \"gradcam\": gradcam})\n\n        return [sample_gradcams, sample_captions, pred_qa_pairs]\n\n    @dist_utils.main_process\n    def _report_metrics(self, result_file, split):\n        \"\"\"\n        TODO: add other evaluation metrics for GQA\n        \"\"\"\n\n        results = json.load(open(result_file, \"r\"))\n        acc = []\n        vqa_tool = VQATool()\n\n        for res in results:\n            if res[\"gt_ans\"] is None:\n                # prepare test results for leaderboard evaluation\n                self._save_result_leaderboard(results)\n                return\n\n            gt_ans = res[\"gt_ans\"]\n            pred = res[\"pred_ans\"]\n\n            if self.inference_method == \"generate\":\n                pred = vqa_tool.processPunctuation(pred)\n                pred = vqa_tool.processDigitArticle(pred)\n\n            vqa_acc = 1 if pred == gt_ans else 0\n\n            acc.append(vqa_acc)\n\n        accuracy = sum(acc) / len(acc) * 100\n        metrics = {\"agg_metrics\": accuracy, \"acc\": accuracy}\n\n        with open(\n            os.path.join(registry.get_path(\"output_dir\"), \"evaluate.txt\"), \"a\"\n        ) as f:\n            f.write(json.dumps(metrics) + \"\\n\")\n\n        logging.info(metrics)\n\n        return metrics\n\n    @dist_utils.main_process\n    def _save_result_leaderboard(self, results):\n        \"\"\"\n        Saving the results in the format required for leaderboard evaluation.\n        \"\"\"\n        result_leaderboard = []\n        for res in results:\n            result_leaderboard.append({\n                \"questionId\": str(res['question_id']),\n                \"prediction\": str(res[\"pred_ans\"]),\n            })\n\n        result_file = registry.get_path(\"result_dir\") + \"_leaderboard.json\"\n\n        with open(result_file, \"w\") as f:\n            json.dump(result_leaderboard, f)\n\n        logging.info(f\"Saved results for leaderboard evaluation at {result_file}\")\n"
  },
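Because `VQARCTask.valid_step` returns three parallel lists per batch (gradcams, captions, QA pairs) and `BaseTask.evaluation` extends them into one flat list, `after_evaluation` recovers each stream by slicing with a stride of 3 and chaining. A toy illustration of that de-interleaving:

```python
# Sketch of the stride-3 de-interleaving in VQARCTask.after_evaluation
# (placeholder strings instead of real gradcams/captions/answers).
from itertools import chain

val_result = [
    ["g0", "g1"], ["c0", "c1"], ["q0", "q1"],   # batch 0: gradcams, captions, qa
    ["g2"],       ["c2"],       ["q2"],         # batch 1
]

gradcams = list(chain(*val_result[0::3]))
captions = list(chain(*val_result[1::3]))
qa_pairs = list(chain(*val_result[2::3]))
print(gradcams, captions, qa_pairs)
# ['g0', 'g1', 'g2'] ['c0', 'c1', 'c2'] ['q0', 'q1', 'q2']
```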
  {
    "path": "pyproject.toml",
    "content": "[build-system]\nrequires      = [\"setuptools>=61.0.0\", \"wheel\"]\nbuild-backend = \"setuptools.build_meta\"\n"
  },
  {
    "path": "requirements.txt",
    "content": "contexttimer\ndecord\neinops>=0.4.1\nfairscale==0.4.4\nftfy\niopath\nipython\nomegaconf\nopencv-python-headless==4.5.5.64\nopendatasets\npackaging\npandas\nplotly\npre-commit\npycocoevalcap\npycocotools\npython-magic\nscikit-image\nsentencepiece\nspacy\nstreamlit\ntimm==0.4.12\ntorch>=1.10.0\ntorchvision\ntqdm\ntransformers>=4.25.0\nwebdataset\nwheel\n"
  },
  {
    "path": "run_scripts/sevila/finetune/nexqa_ft.sh",
    "content": "# parameters\nresult_dir=\"\"\n\nexp_name='nextqa_ft'\nckpt='sevila_checkpoints/sevila_pretrained.pth'\nCUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.run --nproc_per_node=4 train.py \\\n--cfg-path lavis/projects/sevila/train/nextqa.yaml \\\n--options run.output_dir=${result_dir}${exp_name} \\\nmodel.frame_num=4 \\\ndatasets.nextqa.vis_processor.train.n_frms=32 \\\ndatasets.nextqa.vis_processor.eval.n_frms=32 \\\nrun.batch_size_train=8 \\\nrun.batch_size_eval=8 \\\nrun.init_lr=3e-5 \\\nrun.max_epoch=10 \\\nrun.warmup_steps=1000 \\\nrun.accum_grad_iters=2 \\\nmodel.task='qvh_freeze_loc_train_qa_with_loc_train_qa_vid' \\\nmodel.finetuned=${ckpt} \\\nrun.task='videoqa'"
  },
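The `--options` arguments in these run scripts are dot-list `key=value` overrides applied on top of the YAML config named by `--cfg-path`. As a rough illustration of how such a dot-list merges into a nested config, here is a sketch using omegaconf (listed in requirements.txt); the keys are copied from the script above, but this is not the repo's actual argument parser:

```python
# Illustrative only: merging dot-list overrides into a nested config with omegaconf.
from omegaconf import OmegaConf

base = OmegaConf.create({
    "run": {"init_lr": 1e-4, "max_epoch": 5, "batch_size_train": 4},
    "model": {"frame_num": 32},
})
overrides = OmegaConf.from_dotlist([
    "run.init_lr=3e-5", "run.max_epoch=10", "model.frame_num=4",
])
cfg = OmegaConf.merge(base, overrides)
print(cfg.run.init_lr, cfg.run.max_epoch, cfg.model.frame_num)  # 3e-05 10 4
```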
  {
    "path": "run_scripts/sevila/inference/nexqa_infer.sh",
    "content": "# parameters/data path\nresult_dir=\"\"\n\nexp_name='nextqa_infer'\nckpt='sevila_checkpoints/sevila_pretrained.pth'\nCUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.run --nproc_per_node=4 evaluate.py \\\n--cfg-path lavis/projects/sevila/eval/nextqa_eval.yaml \\\n--options run.output_dir=${result_dir}${exp_name} \\\nmodel.frame_num=4 \\\ndatasets.nextqa.vis_processor.eval.n_frms=32 \\\nrun.batch_size_eval=8 \\\nmodel.task='qvh_freeze_loc_freeze_qa_vid' \\\nmodel.finetuned=${ckpt} \\\nrun.task='videoqa'"
  },
  {
    "path": "run_scripts/sevila/pre-train/pretrain_qvh.sh",
    "content": "result_dir=\"/nas-hdd/shoubin/result/\"\n\nexp_name='qvh_pretraining'\nCUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.run --nproc_per_node=4 train.py \\\n--cfg-path lavis/projects/sevila/train/qvh.yaml \\\n--options run.output_dir=${result_dir}${exp_name} \\\nmodel.frame_num=4 \\\ndatasets.qvh.vis_processor.train.n_frms=4 \\\ndatasets.qvh.vis_processor.eval.n_frms=75 \\\nrun.batch_size_train=16 \\\nrun.batch_size_eval=4 \\\nrun.init_lr=3e-5 \\\nrun.max_epoch=80 \\\nrun.warmup_steps=1000 \\\nrun.accum_grad_iters=1 \\\nrun.task='moment_retrieval'"
  },
  {
    "path": "run_scripts/sevila/refinement/nexqa_sr.sh",
    "content": "# parameters\nresult_dir=\"\"\n\nexp_name='nextqa_sr'\nckpt='sevila_checkpoints/sevila_pretrained.pth'\nCUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.run --nproc_per_node=4 train.py \\\n--cfg-path lavis/projects/sevila/train/nextqa.yaml \\\n--options run.output_dir=${result_dir}${exp_name} \\\nmodel.frame_num=4 \\\ndatasets.nextqa.vis_processor.train.n_frms=4 \\\ndatasets.nextqa.vis_processor.eval.n_frms=32 \\\nrun.batch_size_train=16 \\\nrun.batch_size_eval=12 \\\nrun.init_lr=3e-5 \\\nrun.max_epoch=10 \\\nrun.warmup_steps=500 \\\nrun.accum_grad_iters=1 \\\nmodel.task='train_loc_freeze_qa_vid' \\\nmodel.finetuned=${ckpt} \\\nrun.task='videoqa'"
  },
  {
    "path": "setup.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nfrom setuptools import setup, find_namespace_packages\nimport platform\n\nDEPENDENCY_LINKS = []\nif platform.system() == \"Windows\":\n    DEPENDENCY_LINKS.append(\"https://download.pytorch.org/whl/torch_stable.html\")\n\n\ndef fetch_requirements(filename):\n    with open(filename) as f:\n        return [ln.strip() for ln in f.read().split(\"\\n\")]\n\n\nsetup(\n    name=\"salesforce-lavis\",\n    version=\"1.0.0.dev1\",\n    author=\"Dongxu Li, Junnan Li, Hung Le, Guangsen Wang, Silvio Savarese, Steven C.H. Hoi\",\n    description=\"LAVIS - A One-stop Library for Language-Vision Intelligence\",\n    long_description=open(\"README.md\", \"r\", encoding=\"utf-8\").read(),\n    long_description_content_type=\"text/markdown\",\n    keywords=\"Vision-Language, Multimodal, Image Captioning, Generative AI, Deep Learning, Library, PyTorch\",\n    license=\"3-Clause BSD\",\n    packages=find_namespace_packages(include=\"lavis.*\"),\n    install_requires=fetch_requirements(\"requirements.txt\"),\n    python_requires=\">=3.7.0\",\n    include_package_data=True,\n    dependency_links=DEPENDENCY_LINKS,\n    zip_safe=False,\n)\n"
  },
  {
    "path": "sevila_checkpoints/__init__.py",
    "content": ""
  },
  {
    "path": "sevila_data/Data Preprocess.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"4204c7e1\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import os\\n\",\n    \"import json\\n\",\n    \"import pandas as pd\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"5e807f09\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# create folder for each dataset first    \"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"09845339\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def save_json(content, save_path):\\n\",\n    \"    with open(save_path, 'w') as f:\\n\",\n    \"        f.write(json.dumps(content))\\n\",\n    \"def load_jsonl(filename):\\n\",\n    \"    with open(filename, \\\"r\\\") as f:\\n\",\n    \"        return [json.loads(l.strip(\\\"\\\\n\\\")) for l in f.readlines()]\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"2edfddc1\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# nextqa\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"9d49722d\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"raw_train_csv = 'train.csv'\\n\",\n    \"raw_val_csv = 'val.csv'\\n\",\n    \"raw_train = pd.read_csv(raw_train_csv, delimiter=',')\\n\",\n    \"raw_val = pd.read_csv(raw_val_csv, delimiter=',')\\n\",\n    \"train = []\\n\",\n    \"val = []\\n\",\n    \"key = ['video', 'question', 'a0', 'a1', 'a2', 'a3', 'a4', 'answer', 'qid', 'type'] \\n\",\n    \"for i in range(len(raw_train)):\\n\",\n    \"    data = {}\\n\",\n    \"    for k in key:\\n\",\n    \"        data[k] = raw_train.iloc[i][k]\\n\",\n    \"    train.append(data)\\n\",\n    \"\\n\",\n    \"for i in range(len(raw_val)):\\n\",\n    \"    data = {}\\n\",\n    \"    for k in key:\\n\",\n    \"        data[k] = raw_val.iloc[i][k]\\n\",\n    \"    val.append(data) \\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"943e9523\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vid_map = json.load(open('map_vid_vidorID.json'))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"afe73814\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"new_train = []\\n\",\n    \"new_val = []\\n\",\n    \"for qa in train:\\n\",\n    \"    qa_dict = {}\\n\",\n    \"    qa_dict['video'] = vid_map[str(qa['video'])]\\n\",\n    \"    qa_dict['num_option'] = 5\\n\",\n    \"    qa_dict['qid'] = '_'.join([qa['type'], str(qa['video']), str(qa['qid'])])\\n\",\n    \"    for i in range(5):\\n\",\n    \"        qa_dict['a{}'.format(str(i))] = qa['a{}'.format(str(i))]+'.'\\n\",\n    \"    qa_dict['answer'] = qa['answer']\\n\",\n    \"    qa_dict['question'] = qa['question']+'?'\\n\",\n    \"    new_train.append(qa_dict)\\n\",\n    \"\\n\",\n    \"for qa in val:\\n\",\n    \"    qa_dict = {}\\n\",\n    \"    qa_dict['video'] = vid_map[str(qa['video'])]\\n\",\n    \"    qa_dict['num_option'] = 5\\n\",\n    \"    qa_dict['qid'] = '_'.join([qa['type'], str(qa['video']), str(qa['qid'])])\\n\",\n    \"    for i in range(5):\\n\",\n    \"        qa_dict['a{}'.format(str(i))] = qa['a{}'.format(str(i))]+'.'\\n\",\n    \"    qa_dict['answer'] = qa['answer']\\n\",\n    \"    qa_dict['question'] = qa['question']+'?'\\n\",\n    \"    new_val.append(qa_dict)\"\n   
]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"218a75d3\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"save_json(new_train, 'nextqa/train.json')\\n\",\n    \"save_json(new_val, 'nextqa/val.json')\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"67638a5e\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# STAR\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"fed28d5a\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"train_path = 'STAR_train.json'\\n\",\n    \"val_path = 'STAR_val.json'\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"71918325\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"train = json.load(open(train_path))\\n\",\n    \"val = json.load(open(val_path))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"209c3b2a\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"new_train = []\\n\",\n    \"new_val = []\\n\",\n    \"for qa in train:\\n\",\n    \"    qa_dict = {}\\n\",\n    \"    qa_dict['video'] = qa['video_id']\\n\",\n    \"    qa_dict['num_option'] = 4\\n\",\n    \"    qa_dict['qid'] = qa['question_id']\\n\",\n    \"    for i, choice in enumerate(qa['choices']):\\n\",\n    \"        qa_dict['a{}'.format(str(i))] = choice['choice']\\n\",\n    \"        if choice['choice'] == qa['answer']:\\n\",\n    \"            answer = i\\n\",\n    \"    qa_dict['answer'] = answer\\n\",\n    \"    qa_dict['question'] = qa['question']\\n\",\n    \"    qa_dict['start'] = qa['start']\\n\",\n    \"    qa_dict['end'] = qa['end']\\n\",\n    \"    new_train.append(qa_dict)\\n\",\n    \"\\n\",\n    \"for qa in val:\\n\",\n    \"    qa_dict = {}\\n\",\n    \"    qa_dict['video'] = qa['video_id']\\n\",\n    \"    qa_dict['num_option'] = 4\\n\",\n    \"    qa_dict['qid'] = qa['question_id']\\n\",\n    \"    for i, choice in enumerate(qa['choices']):\\n\",\n    \"        qa_dict['a{}'.format(str(i))] = choice['choice']\\n\",\n    \"        if choice['choice'] == qa['answer']:\\n\",\n    \"            answer = i\\n\",\n    \"    qa_dict['answer'] = answer\\n\",\n    \"    qa_dict['question'] = qa['question']\\n\",\n    \"    qa_dict['start'] = qa['start']\\n\",\n    \"    qa_dict['end'] = qa['end']\\n\",\n    \"    new_val.append(qa_dict)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"e6ced28c\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"save_json(new_train, 'star/train.json')\\n\",\n    \"save_json(new_val, 'star/val.json')\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"c72d66f0\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# How2QA\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"9ab388e9\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"train_path = 'how2qa_train_release.jsonl'\\n\",\n    \"val_path = 'how2qa_val_release.jsonl'\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"81bdadb9\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"train = load_jsonl(train_path)\\n\",\n    \"val = load_jsonl(val_path)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"5164d95e\",\n   
\"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"new_train = []\\n\",\n    \"new_val = []\\n\",\n    \"for i, qa in enumerate(train):\\n\",\n    \"    qa_dict = {}\\n\",\n    \"    qa_dict['video'] = qa['vid_name']\\n\",\n    \"    qa_dict['num_option'] = 4\\n\",\n    \"    qa_dict['qid'] = 'HOW2QA_' + str(i)\\n\",\n    \"    for j in range(4):\\n\",\n    \"        qa_dict['a{}'.format(str(j))] = qa['a{}'.format(str(j))]\\n\",\n    \"        \\n\",\n    \"    qa_dict['answer'] = qa['answer_idx']\\n\",\n    \"    qa_dict['question'] = qa['q']\\n\",\n    \"    qa_dict['start'] = qa['ts'].split('-')[0]\\n\",\n    \"    qa_dict['end'] = qa['ts'].split('-')[1]\\n\",\n    \"        \\n\",\n    \"    new_train.append(qa_dict)\\n\",\n    \"\\n\",\n    \"for i, qa in enumerate(val):\\n\",\n    \"    qa_dict = {}\\n\",\n    \"    qa_dict['video'] = qa['vid_name']\\n\",\n    \"    qa_dict['num_option'] = 4\\n\",\n    \"    qa_dict['qid'] = 'HOW2QA_' + str(i)\\n\",\n    \"    for j in range(4):\\n\",\n    \"        qa_dict['a{}'.format(str(j))] = qa['a{}'.format(str(j))]\\n\",\n    \"        \\n\",\n    \"    qa_dict['answer'] = qa['answer_idx']\\n\",\n    \"    qa_dict['question'] = qa['q']\\n\",\n    \"    qa_dict['start'] = qa['ts'].split('-')[0]\\n\",\n    \"    qa_dict['end'] = qa['ts'].split('-')[1]\\n\",\n    \"        \\n\",\n    \"    new_val.append(qa_dict)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"2daa2a4a\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"save_json(new_train, 'how2qa/train.json')\\n\",\n    \"save_json(new_val, 'how2qa/val.json')\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"4e569c5b\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# TVQA\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"319d0fb5\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"train_path = 'tvqa_train.jsonl'\\n\",\n    \"val_path = 'tvqa_val.jsonl'\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"e0f498f5\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"train = load_jsonl(train_path)\\n\",\n    \"val = load_jsonl(val_path)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"7c24578a\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"new_train = []\\n\",\n    \"new_val = []\\n\",\n    \"\\n\",\n    \"for i, qa in enumerate(train):\\n\",\n    \"    qa_dict = {}\\n\",\n    \"    qa_dict['video'] = qa['vid_name']\\n\",\n    \"    qa_dict['num_option'] = 5\\n\",\n    \"    qa_dict['qid'] = 'TVQA_' + str(i)\\n\",\n    \"    for j in range(5):\\n\",\n    \"        qa_dict['a{}'.format(str(j))] = qa['a{}'.format(str(j))]\\n\",\n    \"    qa_dict['answer'] = qa['answer_idx']\\n\",\n    \"    qa_dict['question'] = qa['q']\\n\",\n    \"    qa_dict['start'] = qa['ts'].split('-')[0]\\n\",\n    \"    qa_dict['end'] = qa['ts'].split('-')[1]\\n\",\n    \"        \\n\",\n    \"    new_train.append(qa_dict)\\n\",\n    \"\\n\",\n    \"for i, qa in enumerate(val):\\n\",\n    \"    qa_dict = {}\\n\",\n    \"    qa_dict['video'] = qa['vid_name']\\n\",\n    \"    qa_dict['num_option'] = 5\\n\",\n    \"    qa_dict['qid'] = 'TVQA_' + str(i)\\n\",\n    \"    for j in range(5):\\n\",\n    \"        qa_dict['a{}'.format(str(j))] = qa['a{}'.format(str(j))]\\n\",\n    \"    qa_dict['answer'] = 
qa['answer_idx']\\n\",\n    \"    qa_dict['question'] = qa['q']\\n\",\n    \"    qa_dict['start'] = qa['ts'].split('-')[0]\\n\",\n    \"    qa_dict['end'] = qa['ts'].split('-')[1]\\n\",\n    \"        \\n\",\n    \"    new_val.append(qa_dict)\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"348cafde\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"save_json(new_train, 'tvqa/train.json')\\n\",\n    \"save_json(new_val, 'tvqa/val.json')\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"52259cd2\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# VLEP\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"53646ebf\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"train_path = 'vlep_train_release.jsonl'\\n\",\n    \"val_path = 'vlep_dev_release.jsonl'\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"ff92c404\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"train = load_jsonl(train_path)\\n\",\n    \"val = load_jsonl(val_path)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"dd62a11e\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"new_train = []\\n\",\n    \"new_val = []\\n\",\n    \"\\n\",\n    \"for i, qa in enumerate(train):\\n\",\n    \"    qa_dict = {}\\n\",\n    \"    qa_dict['video'] = qa['vid_name']\\n\",\n    \"    qa_dict['num_option'] = 2\\n\",\n    \"    qa_dict['qid'] = 'VLEP_' + str(qa['example_id'])\\n\",\n    \"\\n\",\n    \"    for j in range(2):\\n\",\n    \"        qa_dict['a{}'.format(str(j))] = qa['events'][j]\\n\",\n    \"    qa_dict['answer'] = qa['answer']\\n\",\n    \"    # qa_dict['question'] = qa['q']\\n\",\n    \"    qa_dict['start'] = qa['ts'][0]\\n\",\n    \"    qa_dict['end'] = qa['ts'][1]\\n\",\n    \"    \\n\",\n    \"    new_train.append(qa_dict)\\n\",\n    \"\\n\",\n    \"for i, qa in enumerate(val):\\n\",\n    \"    qa_dict = {}\\n\",\n    \"    qa_dict['video'] = qa['vid_name']\\n\",\n    \"    qa_dict['num_option'] = 2\\n\",\n    \"    qa_dict['qid'] = 'VLEP_' + str(qa['example_id'])\\n\",\n    \"\\n\",\n    \"    for j in range(2):\\n\",\n    \"        qa_dict['a{}'.format(str(j))] = qa['events'][j]\\n\",\n    \"    qa_dict['answer'] = qa['answer']\\n\",\n    \"    # qa_dict['question'] = qa['q']\\n\",\n    \"    qa_dict['start'] = qa['ts'][0]\\n\",\n    \"    qa_dict['end'] = qa['ts'][1]\\n\",\n    \"        \\n\",\n    \"    new_val.append(qa_dict)\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"186de10a\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"save_json(new_train, 'vlep/train.json')\\n\",\n    \"save_json(new_val, 'vlep/val.json')\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"f7083bf5\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# qvh\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"1a00480e\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"train_path = 'highlight_train_release.jsonl'\\n\",\n    \"val_path = 'highlight_val_release.jsonl'\\n\",\n    \"test_path = 'highlight_test_release.jsonl'\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"8cc2f260\",\n   \"metadata\": {},\n   
\"outputs\": [],\n   \"source\": [\n    \"train = load_jsonl(train_path)\\n\",\n    \"val = load_jsonl(val_path)\\n\",\n    \"test = load_jsonl(test_path)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"507365fc\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"new_train = []\\n\",\n    \"new_val = []\\n\",\n    \"new_test = []\\n\",\n    \"for i, qa in enumerate(train):\\n\",\n    \"    qa_dict = {}\\n\",\n    \"    qa_dict['video'] = qa['vid']\\n\",\n    \"    qa_dict['qid'] = 'QVHighlight_' + str(qa['qid'])\\n\",\n    \"    qa_dict['query'] = qa['query']\\n\",\n    \"    qa_dict['duration'] = qa['duration']\\n\",\n    \"    qa_dict['relevant_windows'] = qa['relevant_windows']\\n\",\n    \"    new_train.append(qa_dict)\\n\",\n    \"\\n\",\n    \"for i, qa in enumerate(val):\\n\",\n    \"    qa_dict = {}\\n\",\n    \"    qa_dict['video'] = qa['vid']\\n\",\n    \"    qa_dict['qid'] = 'QVHighlight_' + str(qa['qid'])\\n\",\n    \"    qa_dict['query'] = qa['query']\\n\",\n    \"    qa_dict['duration'] = qa['duration']\\n\",\n    \"    qa_dict['relevant_windows'] = qa['relevant_windows']\\n\",\n    \"    new_val.append(qa_dict)\\n\",\n    \"\\n\",\n    \"for i, qa in enumerate(test):\\n\",\n    \"    qa_dict = {}\\n\",\n    \"    qa_dict['video'] = qa['vid']\\n\",\n    \"    qa_dict['qid'] = 'QVHighlight_' + str(qa['qid'])\\n\",\n    \"    qa_dict['query'] = qa['query']\\n\",\n    \"    qa_dict['duration'] = qa['duration']\\n\",\n    \"    new_test.append(qa_dict)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"0f9754fc\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"save_json(new_train, 'qvh/train.json')\\n\",\n    \"save_json(new_val, 'qvh/val.json')\\n\",\n    \"save_json(new_test, 'qvh/test.json')\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3 (ipykernel)\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.9.13\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 5\n}\n"
  },
  {
    "path": "sevila_data/README.md",
    "content": "# Self-Chained Image-Language Model for Video Localization and Question Answering\n\n\n## Dataset Preparation\nWe test our model on:\n+ [NExT-QA](https://doc-doc.github.io/docs/nextqa.html)\n\n+ [STAR](https://star.csail.mit.edu/)\n\n+ [How2QA](https://value-benchmark.github.io/index.html)\n\n+ [TVQA](https://tvqa.cs.unc.edu/)\n\n+ [VLEP](https://value-benchmark.github.io/index.html)\n\n+ [QVHighlights](https://github.com/jayleicn/moment_detr)\n\nWe re-format original json/csv/jsonl files in different dataset to the same json format via jupyter script.\n\nPlease set your own dataset/video path in running scripts or in dataset config files. For example:\n\n* Option 1: change in running scripts\n\n```bash\nresult_dir=\"YOUR_PATH\"\ntrain_path=\"YOUR_PATH\"\nval_path=\"YOUR_PATH\"\nvideo_path=\"YOUR_PATH\"\n\nexp_name='nextqa_infer'\nckpt='sevila_checkpoints/sevila_pretrained.pth'\nCUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.run --nproc_per_node=4 evaluate.py \\\n--cfg-path lavis/projects/sevila/eval/nextqa_eval.yaml \\\n--options run.output_dir=${result_dir}${exp_name} \\\ndatasets.nextqa.build_info.annotations.train.storage=${train_path} \\\ndatasets.nextqa.build_info.annotations.val.storage=${val_path} \\\ndatasets.nextqa.build_info.annotations.test.storage=${val_path} \\\ndatasets.nextqa.build_info.videos.storage=${video_path} \\\nmodel.frame_num=4 \\\ndatasets.nextqa.vis_processor.eval.n_frms=32 \\\nrun.batch_size_eval=8 \\\nmodel.task='qvh_freeze_loc_freeze_qa_vid' \\\nmodel.finetuned=${ckpt} \\\nrun.task='videoqa'\n\n```\n\n* Option 2: change in dataset config file:\n\nchange [config files](../lavis/configs/datasets/nextqa/defaults_qa.yaml)\n\n\n\n\n\n"
  },
  {
    "path": "train.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport argparse\nimport os\nimport random\n\nimport numpy as np\nimport torch\nimport torch.backends.cudnn as cudnn\n\nimport lavis.tasks as tasks\nfrom lavis.common.config import Config\nfrom lavis.common.dist_utils import get_rank, init_distributed_mode\nfrom lavis.common.logger import setup_logger\nfrom lavis.common.optims import (\n    LinearWarmupCosineLRScheduler,\n    LinearWarmupStepLRScheduler,\n)\nfrom lavis.common.registry import registry\nfrom lavis.common.utils import now\n\n# imports modules for registration\nfrom lavis.datasets.builders import *\nfrom lavis.models import *\nfrom lavis.processors import *\nfrom lavis.runners import *\nfrom lavis.tasks import *\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Training\")\n\n    parser.add_argument(\"--cfg-path\", required=True, help=\"path to configuration file.\")\n    parser.add_argument(\n        \"--options\",\n        nargs=\"+\",\n        help=\"override some settings in the used config, the key-value pair \"\n        \"in xxx=yyy format will be merged into config file (deprecate), \"\n        \"change to --cfg-options instead.\",\n    )\n\n    args = parser.parse_args()\n    # if 'LOCAL_RANK' not in os.environ:\n    #     os.environ['LOCAL_RANK'] = str(args.local_rank)\n\n    return args\n\n\ndef setup_seeds(config):\n    seed = config.run_cfg.seed + get_rank()\n\n    random.seed(seed)\n    np.random.seed(seed)\n    torch.manual_seed(seed)\n\n    cudnn.benchmark = False\n    cudnn.deterministic = True\n\n\ndef get_runner_class(cfg):\n    \"\"\"\n    Get runner class from config. Default to epoch-based runner.\n    \"\"\"\n    runner_cls = registry.get_runner_class(cfg.run_cfg.get(\"runner\", \"runner_base\"))\n\n    return runner_cls\n\n\ndef main():\n    # allow auto-dl completes on main process without timeout when using NCCL backend.\n    # os.environ[\"NCCL_BLOCKING_WAIT\"] = \"1\"\n\n    # set before init_distributed_mode() to ensure the same job_id shared across all ranks.\n    job_id = now()\n\n    cfg = Config(parse_args())\n\n    init_distributed_mode(cfg.run_cfg)\n\n    setup_seeds(cfg)\n\n    # set after init_distributed_mode() to only log on master.\n    setup_logger()\n\n    cfg.pretty_print()\n\n    task = tasks.setup_task(cfg)\n    datasets = task.build_datasets(cfg)\n    model = task.build_model(cfg)\n\n    runner = get_runner_class(cfg)(\n        cfg=cfg, job_id=job_id, task=task, model=model, datasets=datasets\n    )\n    runner.train()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  }
]