[
  {
    "path": ".gitignore",
    "content": "*.err\n*.out\n*.pyc\nwandb\n/data_preparation/vis_results/\n/data_preparation/vis_results_new/\n/LLAVA_Stage1_Pretrained/\n/work_dirs/\n/llava.egg-info/\n/data_preparation/data/\n/vis_results/\nmodel_worker*\n/playground/\n*.jsonl\n*.pth\ngradio_demo/tmp_files\nllava_bench_results\nsymmary_results\neval_gpt4\nvis_results_pdf_precision\nvis_results_pdf_recall\noutput/\ndatasets/\noutput\ndatasets\n*.log\n*.json\n__pycache__/\n*/__pycache__\n*/*/__pycache__\n*/*/*/__pycache__\n*/*/*/*/__pycache__\ngradio_demo/examples/*.mp4\n"
  },
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "README.md",
    "content": "🌋 LLaVA-Grounding: Grounded Visual Chat with Large Multimodal Models\n========\n\n[[Project Page](https://llava-vl.github.io/llava-grounding)] [[Arxiv](https://arxiv.org/abs/2312.02949)]  [[Demo](https://llava-grounding.deepdataspace.com/\n)]  [[Model Zoo](https://github.com/UX-Decoder/LLaVA-Grounding/blob/main/docs/MODEL_ZOO.md)] \n<!-- [[`Paper`](xxx)] [[`BibTex`](#black_nib-citation)] -->\n\n## :fire: News\n[2024/1/14] Our training code is released.\n\n[2023/12/6] Our paper is available in arxiv.\n\n\n## Contents\n- [🌋 LLaVA-Grounding: Grounded Visual Chat with Large Multimodal Models](#-llava-grounding-grounded-visual-chat-with-large-multimodal-models)\n  - [:fire: News](#fire-news)\n  - [Contents](#contents)\n    - [Install](#install)\n    - [LLaVA-Grounding Weights](#llava-grounding-weights)\n    - [Demo](#demo)\n    - [Training data](#training-data)\n      - [Flickr30k](#flickr30k)\n      - [COCO](#coco)\n      - [LLaVA](#llava)\n    - [Training](#training)\n    - [Citation](#citation)\n\n### Install\n1. Clone this repository and navigate to LLaVA-Grounding fold:\n```shell\ngit clone https://github.com/UX-Decoder/LLaVA-Grounding.git\ncd LLaVA-Grounding\n```\n2. Install required packages:\n```\nconda create -n llava python=3.10 -y\nconda activate llava\npip install --upgrade pip  # enable PEP 660 support\npip install -e .\n```\n\n3. Install additional packages for training cases\n```\npip install -e \".[train]\"\npip install flash-attn --no-build-isolation\n```\n4. 
Install packages necessary for [OpenSeeD](https://github.com/IDEA-Research/OpenSeeD) and [Semantic-SAM](https://github.com/UX-Decoder/Semantic-SAM).\n\n### LLaVA-Grounding Weights\nPlease check out our [Model Zoo](https://github.com/UX-Decoder/LLaVA-Grounding/blob/main/docs/MODEL_ZOO.md) for all public LLaVA-Grounding checkpoints, and the instructions on how to use the weights.\n### Demo\nAfter downloading the model weights, run the following commands to launch the demo on your own machine.\n```shell\nCUDA_VISIBLE_DEVICES=0 python gradio_demo/LLaVA_G_Demo.py --path_vision_cfg path_to_vision_cfg --path_inter_cfg path_to_inter_cfg --model_path path_to_ckpt_dir\n\n# for example, after downloading weights into checkpoints/llava_grounding\nCUDA_VISIBLE_DEVICES=0 python gradio_demo/LLaVA_G_Demo.py --path_vision_cfg configs/openseed/openseed_swint_lang_joint_2st_visual_prompt.yaml --path_inter_cfg configs/semsam/visual_prompt_encoder.yaml --model_path checkpoints/llava_grounding\n```\n\nPlease refer to our [Online Demo](https://llava-grounding.deepdataspace.com/) for a more detailed user guide.\n### Training data\n```text\ndata\n├── flickr30k_entities\n│   ├── train/\n│   ├── val/\n│   ├── annotations\n│          ├──final_flickr_separateGT_train.json\n│          ├──final_flickr_separateGT_val.json\n├── coco\n│   ├── train2014/\n│   ├── train2017/\n│   ├── panoptic_train2017/\n│   ├── panoptic_semseg_train2017/\n│   ├── annotations\n│   │      ├──instances_train2017.json\n│   │      ├──instances_train2017_gvc.json\n│   │      ├──grounded_visual_chat_data.json\n│   │      ├──instances_train2014_filter.json\n│   │      ├──panoptic_train2017_filter.json\n│   │      ├──grounding_train2017.json\n├── llava\n│   ├── annotations\n│          ├── cap600k_brackets_all.json\n│          ├── llava_instruct_150k.json\n│          ├── llava_instruct_150k_visual_prompt.json\n\n```\n#### Flickr30k\nPlease refer to [MDETR's pre-processed flickr30k 
data](https://github.com/ashkamath/mdetr/blob/main/.github/flickr.md).\n#### COCO\nPlease download the COCO train2014 and train2017 images, along with the panoptic segmentation and semantic segmentation data. Other annotations can be downloaded [here](https://github.com/UX-Decoder/LLaVA-Grounding/releases/tag/train_data).\n#### LLaVA\nThe processed annotations can be downloaded [here](https://github.com/UX-Decoder/LLaVA-Grounding/releases/tag/train_data).\n### Training\nStage 1\n```shell\nbash scripts/pretrain_joint.py\n```\nStage 2\n```shell\nbash scripts/finetune.py\n```\nStage 3\n```shell\nbash scripts/finetune_visual_prompt.py\n```\n### Citation\nIf you find LLaVA-Grounding useful for your research and applications, please cite using this BibTeX:\n```bibtex\n\n@misc{zhang2023llavagrounding,\n      title={LLaVA-Grounding: Grounded Visual Chat with Large Multimodal Models},\n      author={Hao Zhang and Hongyang Li and Feng Li and Tianhe Ren and Xueyan Zou and Shilong Liu and Shijia Huang and Jianfeng Gao and Lei Zhang and Chunyuan Li and Jianwei Yang},\n      year={2023},\n      booktitle={arXiv}\n}\n\n@misc{liu2023llava,\n      title={Visual Instruction Tuning}, \n      author={Liu, Haotian and Li, Chunyuan and Wu, Qingyang and Lee, Yong Jae},\n      publisher={arXiv:2304.08485},\n      year={2023}\n}\n```\n"
  },
  {
    "path": "configs/openseed/openseed_swint_lang_joint.yaml",
    "content": "# --------------------------------------------------------\n# X-Decoder -- Generalized Decoding for Pixel, Image, and Language\n# Copyright (c) 2022 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# Written by Xueyan Zou (xueyan@cs.wisc.edu)\n# --------------------------------------------------------\n\n##################\n# Task settings\n##################\nWEIGHT: ''\nPORT: 53711\nVERBOSE: true\n#OUTPUT_DIR: '../../data/output/test'\ninference_only: true\nOUTPUT_DIR: '../../data/output/test'\nclip: true\n# misc\nLOADER:\n  JOINT: True\n  KEY_DATASET: 'flickr'\n# model\nMODEL:\n  NAME: openseed_model\n  HEAD: openseed_head\n  MASK_ON: false\n  KEYPOINT_ON: false\n  LOAD_PROPOSALS: false\n  DIM_PROJ: 4096\n  BACKBONE_DIM: 768\n  BACKGROUND: False\n  WEIGHTS: ''\n  TEXT:\n    ARCH: encoder\n    NAME: transformer\n    TOKENIZER: clip\n    CONTEXT_LENGTH: 18 # 18\n    WIDTH: 512\n    HEADS: 8\n    LAYERS: 12\n    AUTOGRESSIVE: True\n  BACKBONE:\n    NAME: swin\n    PRETRAINED: 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth'\n    LOAD_PRETRAINED: true\n    SWIN:\n      PRETRAIN_IMG_SIZE: 224\n      PATCH_SIZE: 4\n      EMBED_DIM: 96\n      DEPTHS: [ 2, 2, 6, 2 ]\n      NUM_HEADS: [ 3, 6, 12, 24 ]\n      WINDOW_SIZE: 7\n      MLP_RATIO: 4.0\n      QKV_BIAS: true\n      QK_SCALE: ~\n      DROP_RATE: 0.0\n      ATTN_DROP_RATE: 0.0\n      DROP_PATH_RATE: 0.3\n      APE: false\n      PATCH_NORM: true\n      USE_CHECKPOINT: false\n      OUT_FEATURES: [ 'res2', 'res3', 'res4', 'res5' ]\n  ENCODER:\n    NAME: encoder_deform\n    IGNORE_VALUE: 255\n    NUM_CLASSES: 133\n    LOSS_WEIGHT: 1.0\n    CONVS_DIM: 256\n    MASK_DIM: 256\n    NORM: \"GN\"\n    IN_FEATURES: [ \"res2\", \"res3\", \"res4\", \"res5\" ]\n    DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES: [ \"res3\", \"res4\", \"res5\" ]\n    COMMON_STRIDE: 4\n    TRANSFORMER_ENC_LAYERS: 6\n    TOTAL_NUM_FEATURE_LEVELS: 4\n    
NUM_FEATURE_LEVELS: 3\n    FEATURE_ORDER: \"low2high\"\n  DECODER:\n    NAME: openseed_decoder\n    TRANSFORMER_IN_FEATURE: \"multi_scale_pixel_decoder\"\n    MASK: True\n    BOX: True\n    GROUNDING:\n      ENABLED: False\n      MAX_LEN: 5\n      TEXT_WEIGHT: 2.0\n      CLASS_WEIGHT: 0.5\n    CAPTION:\n      ENABLED: False\n      PHRASE_PROB: 0.0\n      SIM_THRES: 0.95\n    CAPTIONING:\n      ENABLED: False\n      STEP: 50\n    RETRIEVAL:\n      ENABLED: False\n      DIM_IMG: 768\n      ENSEMBLE: True\n    OPENIMAGE:\n      ENABLED: False\n      NEGATIVE_SAMPLES: 5\n      GROUNDING:\n        ENABLED: False\n        MAX_LEN: 5\n    DEEP_SUPERVISION: True\n    NO_OBJECT_WEIGHT: 0.1\n    CLASS_WEIGHT: 4.0\n    MASK_WEIGHT: 5.0\n    DICE_WEIGHT: 5.0\n    BOX_WEIGHT: 5.0\n    GIOU_WEIGHT: 2.0\n    COST_CLASS_WEIGHT: 4.0\n    COST_DICE_WEIGHT: 5.0\n    COST_MASK_WEIGHT: 5.0\n    COST_BOX_WEIGHT: 5.0\n    COST_GIOU_WEIGHT: 2.0\n    HIDDEN_DIM: 256\n    NUM_OBJECT_QUERIES: 300\n    NHEADS: 8\n    DROPOUT: 0.0\n    DIM_FEEDFORWARD: 2048\n    ENC_LAYERS: 0\n    PRE_NORM: False\n    ENFORCE_INPUT_PROJ: False\n    SIZE_DIVISIBILITY: 32\n    DEC_LAYERS: 9  # 9 decoder layers, add one for the loss on learnable query\n    TRAIN_NUM_POINTS: 12544\n    OVERSAMPLE_RATIO: 3.0\n    IMPORTANCE_SAMPLE_RATIO: 0.75\n    TWO_STAGE: True\n    INITIALIZE_BOX_TYPE: 'no'\n    DN: seg\n    DN_NOISE_SCALE: 0.4\n    DN_NUM: 100\n    INITIAL_PRED: True\n    LEARN_TGT: False\n    TOTAL_NUM_FEATURE_LEVELS: 4\n    SEMANTIC_CE_LOSS: False\n    PANO_BOX_LOSS: False\n    COCO: True\n    O365: False\n    TEST:\n      SEMANTIC_ON: True\n      INSTANCE_ON: True\n      PANOPTIC_ON: True\n      OVERLAP_THRESHOLD: 0.8\n      OBJECT_MASK_THRESHOLD: 0.25\n      SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE: false\n      TEST_FOUCUS_ON_BOX: False\n      PANO_TRANSFORM_EVAL: True\n      PANO_TEMPERATURE: 0.06\n\nTEST:\n  EVAL_PERIOD: 500000\n  PRECISE_BN:\n    NUM_ITER: 1\n    ENABLED: False\n  AUG:\n    ENABLED: 
False\n\nSAM:\n  INPUT:\n    MIN_SIZE_TEST: 800\n    MAX_SIZE_TEST: 1333\n    IMAGE_SIZE: 1024\n    MIN_SCALE: 0.99\n    MAX_SCALE: 1.01\n    DATASET_MAPPER_NAME: \"sam\"\n    IGNORE_VALUE: 255\n    COLOR_AUG_SSD: False\n    SIZE_DIVISIBILITY: 32\n    RANDOM_FLIP: \"horizontal\"\n    MASK_FORMAT: \"polygon\"\n    FORMAT: \"RGB\"\n    CROP:\n      ENABLED: True\n  DATASET:\n    DATASET: 'sam'\n  TEST:\n    DETECTIONS_PER_IMAGE: 100\n    NAME: coco_eval\n    IOU_TYPE: ['bbox', 'segm']\n    USE_MULTISCALE: false\n    BATCH_SIZE_TOTAL: 8\n    MODEL_FILE: ''\n    AUG:\n      ENABLED: False\n  TRAIN:\n    BATCH_SIZE_TOTAL: 1\n    BATCH_SIZE_PER_GPU: 1\n    SHUFFLE: true\n  DATALOADER:\n    FILTER_EMPTY_ANNOTATIONS: False\n    NUM_WORKERS: 4\n    LOAD_PROPOSALS: False\n    SAMPLER_TRAIN: \"TrainingSampler\"\n    ASPECT_RATIO_GROUPING: True\n\nCOCO:\n  INPUT:\n    MIN_SIZE_TEST: 800\n    MAX_SIZE_TEST: 1333\n    IMAGE_SIZE: 1024\n    MIN_SCALE: 1.0\n    MAX_SCALE: 1.0\n    DATASET_MAPPER_NAME: \"coco_ref_panoptic_lsj\"\n    IGNORE_VALUE: 255\n    COLOR_AUG_SSD: False\n    SIZE_DIVISIBILITY: 32\n    RANDOM_FLIP: \"horizontal\"\n    MASK_FORMAT: \"polygon\"\n    FORMAT: \"RGB\"\n    CROP:\n      ENABLED: True\n  DATASET:\n    DATASET: 'coco'\n  TEST:\n    DETECTIONS_PER_IMAGE: 100\n    NAME: coco_eval\n    IOU_TYPE: ['bbox', 'segm']\n    USE_MULTISCALE: false\n    BATCH_SIZE_TOTAL: 1\n    MODEL_FILE: ''\n    AUG:\n      ENABLED: False\n  TRAIN:\n    BATCH_SIZE_TOTAL: 1\n    BATCH_SIZE_PER_GPU: 1\n    SHUFFLE: true\n  DATALOADER:\n    FILTER_EMPTY_ANNOTATIONS: False\n    NUM_WORKERS: 16\n    LOAD_PROPOSALS: False\n    SAMPLER_TRAIN: \"TrainingSampler\"\n    ASPECT_RATIO_GROUPING: True\n\nVLP:\n  INPUT:\n    IMAGE_SIZE: 224\n    DATASET_MAPPER_NAME: \"vlpretrain\"\n    IGNORE_VALUE: 255\n    COLOR_AUG_SSD: False\n    SIZE_DIVISIBILITY: 32\n    MASK_FORMAT: \"polygon\"\n    FORMAT: \"RGB\"\n    CROP:\n      ENABLED: True\n  TRAIN:\n    BATCH_SIZE_TOTAL: 2\n    
BATCH_SIZE_PER_GPU: 2\n  TEST:\n    BATCH_SIZE_TOTAL: 256\n  DATALOADER:\n    FILTER_EMPTY_ANNOTATIONS: False\n    NUM_WORKERS: 16\n    LOAD_PROPOSALS: False\n    SAMPLER_TRAIN: \"TrainingSampler\"\n    ASPECT_RATIO_GROUPING: True\n\nINPUT:\n  PIXEL_MEAN: [123.675, 116.280, 103.530]\n  PIXEL_STD: [58.395, 57.120, 57.375]\n\nDATASETS:\n  TRAIN: [\"flickr_train\",\"coco_2017_train_panoptic_ref_full_with_sem_seg_caption_grounding\"]\n\n  TEST: [\"flickr_val\"]\n\n  CLASS_CONCAT: false\n  SIZE_DIVISIBILITY: 32\n  PROPOSAL_FILES_TRAIN: []\n\nDATALOADER:\n  FILTER_EMPTY_ANNOTATIONS: False\n  NUM_WORKERS: 16\n  LOAD_PROPOSALS: False\n  SAMPLER_TRAIN: \"TrainingSampler\"\n  ASPECT_RATIO_GROUPING: True\n\n# Detectron2 training config for optimizer and lr scheduler\nSOLVER:\n  BASE_LR_END: 0.0\n  MOMENTUM: 0.9\n  NESTEROV: False\n  CHECKPOINT_PERIOD: 5000\n  IMS_PER_BATCH: 1\n  REFERENCE_WORLD_SIZE: 0\n  BIAS_LR_FACTOR: 1.0\n  WEIGHT_DECAY_BIAS: None\n  # original\n  BASE_LR: 0.0001\n  STEPS: [327778, 355092]\n  MAX_ITER: 368750\n  GAMMA: 0.1\n  WARMUP_FACTOR: 1.0\n  WARMUP_ITERS: 10\n  WARMUP_METHOD: \"linear\"\n  WEIGHT_DECAY: 0.05\n  OPTIMIZER: \"ADAMW\"\n  LR_SCHEDULER_NAME: \"WarmupMultiStepLR\"\n  LR_MULTIPLIER:\n    backbone: 0.1\n    lang_encoder: 0.1\n  WEIGHT_DECAY_NORM: 0.0\n  WEIGHT_DECAY_EMBED: 0.0\n  CLIP_GRADIENTS:\n    ENABLED: True\n    CLIP_TYPE: \"full_model\"\n    CLIP_VALUE: 0.01\n    NORM_TYPE: 2.0\n  AMP:\n    ENABLED: True\n\n# Evaluation Dataset\nADE20K:\n  INPUT:\n    MIN_SIZE_TRAIN: [320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 1024, 1088, 1152, 1216, 1280]\n    MIN_SIZE_TRAIN_SAMPLING: \"choice\"\n    MIN_SIZE_TEST: 640\n    MAX_SIZE_TRAIN: 2560\n    MAX_SIZE_TEST: 2560\n    MASK_FORMAT: \"polygon\"\n    CROP:\n      ENABLED: True\n      TYPE: \"absolute\"\n      SIZE: [640, 640]\n      SINGLE_CATEGORY_MAX_AREA: 1.0\n    IGNORE_VALUE: 255\n    COLOR_AUG_SSD: True\n    SIZE_DIVISIBILITY: 640  # used in dataset mapper\n    
DATASET_MAPPER_NAME: \"mask_former_panoptic\"\n    FORMAT: \"RGB\"\n  DATASET:\n    DATASET: 'ade'\n  TRAIN:\n    ASPECT_RATIO_GROUPING: true\n    BATCH_SIZE_TOTAL: 16\n    BATCH_SIZE_PER_GPU: 2\n    SHUFFLE: true\n  TEST:\n    DETECTIONS_PER_IMAGE: 100\n    NAME: coco_eval\n    IOU_TYPE: ['bbox', 'segm']\n    USE_MULTISCALE: false\n    BATCH_SIZE_TOTAL: 8\n    MODEL_FILE: ''\n    AUG:\n      ENABLED: False\n  DATALOADER:\n    FILTER_EMPTY_ANNOTATIONS: False\n    NUM_WORKERS: 16\n    LOAD_PROPOSALS: False\n    SAMPLER_TRAIN: \"TrainingSampler\"\n    ASPECT_RATIO_GROUPING: True\n\nREF:\n  INPUT:\n    PIXEL_MEAN: [123.675, 116.280, 103.530]\n    PIXEL_STD: [58.395, 57.120, 57.375]\n    MIN_SIZE_TEST: 512\n    MAX_SIZE_TEST: 1024\n    FORMAT: \"RGB\"\n  DATALOADER:\n    FILTER_EMPTY_ANNOTATIONS: False\n    NUM_WORKERS: 16\n    LOAD_PROPOSALS: False\n    SAMPLER_TRAIN: \"TrainingSampler\"\n    ASPECT_RATIO_GROUPING: False\n  TEST:\n    BATCH_SIZE_TOTAL: 8\n\nSUN:\n  INPUT:\n    PIXEL_MEAN: [123.675, 116.280, 103.530]\n    PIXEL_STD: [58.395, 57.120, 57.375]\n    MIN_SIZE_TEST: 512\n    MAX_SIZE_TEST: 1024\n  DATALOADER:\n    FILTER_EMPTY_ANNOTATIONS: False\n    NUM_WORKERS: 16\n    LOAD_PROPOSALS: False\n    SAMPLER_TRAIN: \"TrainingSampler\"\n    ASPECT_RATIO_GROUPING: False\n  TEST:\n    BATCH_SIZE_TOTAL: 8\n\nSCAN:\n  INPUT:\n    PIXEL_MEAN: [123.675, 116.280, 103.530]\n    PIXEL_STD: [58.395, 57.120, 57.375]\n    MIN_SIZE_TEST: 512\n    MAX_SIZE_TEST: 1024\n  DATALOADER:\n    FILTER_EMPTY_ANNOTATIONS: False\n    NUM_WORKERS: 16\n    LOAD_PROPOSALS: False\n    SAMPLER_TRAIN: \"TrainingSampler\"\n    ASPECT_RATIO_GROUPING: False\n  TEST:\n    BATCH_SIZE_TOTAL: 8\n\nBDD:\n  INPUT:\n    PIXEL_MEAN: [123.675, 116.280, 103.530]\n    PIXEL_STD: [58.395, 57.120, 57.375]\n    MIN_SIZE_TEST: 800\n    MAX_SIZE_TEST: 1333\n  DATALOADER:\n    FILTER_EMPTY_ANNOTATIONS: False\n    NUM_WORKERS: 16\n    LOAD_PROPOSALS: False\n    SAMPLER_TRAIN: \"TrainingSampler\"\n    
ASPECT_RATIO_GROUPING: False\n  TEST:\n    BATCH_SIZE_TOTAL: 8\n\nCITY:\n  INPUT:\n    MIN_SIZE_TRAIN: [ 512, 614, 716, 819, 921, 1024, 1126, 1228, 1331, 1433, 1536, 1638, 1740, 1843, 1945, 2048 ]\n    MIN_SIZE_TRAIN_SAMPLING: \"choice\"\n    MIN_SIZE_TEST: 1024\n    MAX_SIZE_TRAIN: 4096\n    MAX_SIZE_TEST: 2048\n    CROP:\n      ENABLED: True\n      TYPE: \"absolute\"\n      SIZE: [ 512, 1024 ]\n      SINGLE_CATEGORY_MAX_AREA: 1.0\n    IGNORE_VALUE: 255\n    COLOR_AUG_SSD: True\n    SIZE_DIVISIBILITY: -1\n    FORMAT: \"RGB\"\n    DATASET_MAPPER_NAME: \"mask_former_panoptic\"\n    MASK_FORMAT: \"polygon\"\n    TEST:\n      EVAL_PERIOD: 5000\n      BATCH_SIZE_TOTAL: 1\n      AUG:\n        ENABLED: False\n        MIN_SIZES: [ 512, 768, 1024, 1280, 1536, 1792 ]\n        MAX_SIZE: 4096\n        FLIP: True\n    DATALOADER:\n      FILTER_EMPTY_ANNOTATIONS: True\n      NUM_WORKERS: 16\n      LOAD_PROPOSALS: False\n      SAMPLER_TRAIN: \"TrainingSampler\"\n      ASPECT_RATIO_GROUPING: True\n    TRAIN:\n      ASPECT_RATIO_GROUPING: true\n      BATCH_SIZE_TOTAL: 2\n      BATCH_SIZE_PER_GPU: 2\n      SHUFFLE: true\n\nPSACAL_PART:\n  INPUT:\n      MIN_SIZE_TEST: 800\n      MAX_SIZE_TEST: 1333\n      IMAGE_SIZE: 1024\n      MIN_SCALE: 0.1\n      MAX_SCALE: 2.0\n      DATASET_MAPPER_NAME: \"pascal_part_lsj\"\n      IGNORE_VALUE: 255\n      COLOR_AUG_SSD: False\n      SIZE_DIVISIBILITY: 32\n      RANDOM_FLIP: \"horizontal\"\n      MASK_FORMAT: \"polygon\"\n      FORMAT: \"RGB\"\n      CROP:\n        ENABLED: True\n  MODEL:\n    MASK_ON: True\n    KEYPOINT_ON: False\n    LOAD_PROPOSALS: False\n  # DATASET:\n  #   DATASET: 'coco'\n  TEST:\n    DETECTIONS_PER_IMAGE: 100\n    NAME: coco_eval\n    IOU_TYPE: ['bbox', 'segm']\n    USE_MULTISCALE: false\n    BATCH_SIZE_TOTAL: 8\n    MODEL_FILE: ''\n    AUG:\n      ENABLED: False\n  TRAIN:\n    BATCH_SIZE_TOTAL: 1\n    BATCH_SIZE_PER_GPU: 1\n    SHUFFLE: true\n  DATALOADER:\n    FILTER_EMPTY_ANNOTATIONS: False\n    NUM_WORKERS: 16\n    
LOAD_PROPOSALS: False\n    SAMPLER_TRAIN: \"TrainingSampler\"\n    ASPECT_RATIO_GROUPING: True\n\nllava:\n  INPUT:\n    MIN_SIZE_TEST: 800\n    MAX_SIZE_TEST: 1333\n    IMAGE_SIZE: 1024\n    MIN_SCALE: 1.0\n    MAX_SCALE: 1.0\n    DATASET_MAPPER_NAME: \"llava\"\n    IGNORE_VALUE: 255\n    COLOR_AUG_SSD: False\n    SIZE_DIVISIBILITY: 32\n    RANDOM_FLIP: \"horizontal\"\n    MASK_FORMAT: \"polygon\"\n    FORMAT: \"RGB\"\n    CROP:\n      ENABLED: True\n  TEST:\n    DETECTIONS_PER_IMAGE: 100\n    NAME: coco_eval\n    IOU_TYPE: ['bbox', 'segm']\n    USE_MULTISCALE: false\n    BATCH_SIZE_TOTAL: 1\n    MODEL_FILE: ''\n    AUG:\n      ENABLED: False\n  TRAIN:\n    BATCH_SIZE_TOTAL: 1\n    BATCH_SIZE_PER_GPU: 1\n    SHUFFLE: true\n  DATALOADER:\n    FILTER_EMPTY_ANNOTATIONS: False\n    NUM_WORKERS: 16\n    LOAD_PROPOSALS: False\n    SAMPLER_TRAIN: \"TrainingSampler\"\n    ASPECT_RATIO_GROUPING: True\n\nflickr:\n  INPUT:\n    MIN_SIZE_TEST: 800\n    MAX_SIZE_TEST: 1333\n    IMAGE_SIZE: 1024\n    MIN_SCALE: 1.0\n    MAX_SCALE: 1.0\n    DATASET_MAPPER_NAME: \"flickr\"\n    IGNORE_VALUE: 255\n    COLOR_AUG_SSD: False\n    SIZE_DIVISIBILITY: 32\n    RANDOM_FLIP: \"horizontal\"\n    MASK_FORMAT: \"polygon\"\n    FORMAT: \"RGB\"\n    CROP:\n      ENABLED: True\n  TEST:\n    DETECTIONS_PER_IMAGE: 100\n    NAME: coco_eval\n    IOU_TYPE: ['bbox', 'segm']\n    USE_MULTISCALE: false\n    BATCH_SIZE_TOTAL: 1\n    MODEL_FILE: ''\n    AUG:\n      ENABLED: False\n  TRAIN:\n    BATCH_SIZE_TOTAL: 1\n    BATCH_SIZE_PER_GPU: 1\n    SHUFFLE: true\n  DATALOADER:\n    FILTER_EMPTY_ANNOTATIONS: False\n    NUM_WORKERS: 16\n    LOAD_PROPOSALS: False\n    SAMPLER_TRAIN: \"TrainingSampler\"\n    ASPECT_RATIO_GROUPING: True\n\nvg:\n  INPUT:\n    MIN_SIZE_TEST: 800\n    MAX_SIZE_TEST: 1333\n    IMAGE_SIZE: 1024\n    MIN_SCALE: 1.0\n    MAX_SCALE: 1.0\n    DATASET_MAPPER_NAME: \"vg\"\n    IGNORE_VALUE: 255\n    COLOR_AUG_SSD: False\n    SIZE_DIVISIBILITY: 32\n    RANDOM_FLIP: \"horizontal\"\n    
MASK_FORMAT: \"polygon\"\n    FORMAT: \"RGB\"\n    CROP:\n      ENABLED: True\n  TEST:\n    DETECTIONS_PER_IMAGE: 100\n    NAME: coco_eval\n    IOU_TYPE: ['bbox', 'segm']\n    USE_MULTISCALE: false\n    BATCH_SIZE_TOTAL: 1\n    MODEL_FILE: ''\n    AUG:\n      ENABLED: False\n  TRAIN:\n    BATCH_SIZE_TOTAL: 1\n    BATCH_SIZE_PER_GPU: 1\n    SHUFFLE: true\n  DATALOADER:\n    FILTER_EMPTY_ANNOTATIONS: False\n    NUM_WORKERS: 16\n    LOAD_PROPOSALS: False\n    SAMPLER_TRAIN: \"TrainingSampler\"\n    ASPECT_RATIO_GROUPING: True"
  },
  {
    "path": "configs/openseed/openseed_swint_lang_joint_2st.yaml",
    "content": "# --------------------------------------------------------\n# X-Decoder -- Generalized Decoding for Pixel, Image, and Language\n# Copyright (c) 2022 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# Written by Xueyan Zou (xueyan@cs.wisc.edu)\n# --------------------------------------------------------\n\n##################\n# Task settings\n##################\nWEIGHT: ''\nPORT: 53711\ndetach_seg: False\nVERBOSE: true\n#OUTPUT_DIR: '../../data/output/test'\ninference_only: true\nOUTPUT_DIR: '../../data/output/test'\nclip: true\n# misc\nLOADER:\n  JOINT: True\n  KEY_DATASET: 'flickr'\n# model\nMODEL:\n  NAME: openseed_model\n  HEAD: openseed_head\n  MASK_ON: false\n  KEYPOINT_ON: false\n  LOAD_PROPOSALS: false\n  DIM_PROJ: 4096\n  BACKBONE_DIM: 768\n  BACKGROUND: False\n  WEIGHTS: ''\n  TEXT:\n    ARCH: encoder\n    NAME: transformer\n    TOKENIZER: clip\n    CONTEXT_LENGTH: 18 # 18\n    WIDTH: 512\n    HEADS: 8\n    LAYERS: 12\n    AUTOGRESSIVE: True\n  BACKBONE:\n    NAME: swin\n    PRETRAINED: 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth'\n    LOAD_PRETRAINED: true\n    SWIN:\n      PRETRAIN_IMG_SIZE: 224\n      PATCH_SIZE: 4\n      EMBED_DIM: 96\n      DEPTHS: [ 2, 2, 6, 2 ]\n      NUM_HEADS: [ 3, 6, 12, 24 ]\n      WINDOW_SIZE: 7\n      MLP_RATIO: 4.0\n      QKV_BIAS: true\n      QK_SCALE: ~\n      DROP_RATE: 0.0\n      ATTN_DROP_RATE: 0.0\n      DROP_PATH_RATE: 0.3\n      APE: false\n      PATCH_NORM: true\n      USE_CHECKPOINT: false\n      OUT_FEATURES: [ 'res2', 'res3', 'res4', 'res5' ]\n  ENCODER:\n    NAME: encoder_deform\n    IGNORE_VALUE: 255\n    NUM_CLASSES: 133\n    LOSS_WEIGHT: 1.0\n    CONVS_DIM: 256\n    MASK_DIM: 256\n    NORM: \"GN\"\n    IN_FEATURES: [ \"res2\", \"res3\", \"res4\", \"res5\" ]\n    DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES: [ \"res3\", \"res4\", \"res5\" ]\n    COMMON_STRIDE: 4\n    TRANSFORMER_ENC_LAYERS: 6\n    
TOTAL_NUM_FEATURE_LEVELS: 4\n    NUM_FEATURE_LEVELS: 3\n    FEATURE_ORDER: \"low2high\"\n  DECODER:\n    NAME: openseed_decoder\n    TRANSFORMER_IN_FEATURE: \"multi_scale_pixel_decoder\"\n    MASK: True\n    BOX: True\n    COCO_ONLY: True\n    GROUNDING:\n      ENABLED: False\n      MAX_LEN: 5\n      TEXT_WEIGHT: 2.0\n      CLASS_WEIGHT: 0.5\n    CAPTION:\n      ENABLED: False\n      PHRASE_PROB: 0.0\n      SIM_THRES: 0.95\n    CAPTIONING:\n      ENABLED: False\n      STEP: 50\n    RETRIEVAL:\n      ENABLED: False\n      DIM_IMG: 768\n      ENSEMBLE: True\n    OPENIMAGE:\n      ENABLED: False\n      NEGATIVE_SAMPLES: 5\n      GROUNDING:\n        ENABLED: False\n        MAX_LEN: 5\n    DEEP_SUPERVISION: True\n    NO_OBJECT_WEIGHT: 0.1\n    CLASS_WEIGHT: 4.0\n    MASK_WEIGHT: 5.0\n    DICE_WEIGHT: 5.0\n    BOX_WEIGHT: 5.0\n    GIOU_WEIGHT: 2.0\n    LLM_WEIGHT: 1.0\n    WEIGHT_MULTIPLIER: 1.0\n    COST_CLASS_WEIGHT: 4.0\n    COST_DICE_WEIGHT: 5.0\n    COST_MASK_WEIGHT: 5.0\n    COST_BOX_WEIGHT: 5.0\n    COST_GIOU_WEIGHT: 2.0\n    HIDDEN_DIM: 256\n    NUM_OBJECT_QUERIES: 300\n    NHEADS: 8\n    DROPOUT: 0.0\n    DIM_FEEDFORWARD: 2048\n    ENC_LAYERS: 0\n    PRE_NORM: False\n    ENFORCE_INPUT_PROJ: False\n    SIZE_DIVISIBILITY: 32\n    DEC_LAYERS: 9  # 9 decoder layers, add one for the loss on learnable query\n    TRAIN_NUM_POINTS: 12544\n    OVERSAMPLE_RATIO: 3.0\n    IMPORTANCE_SAMPLE_RATIO: 0.75\n    TWO_STAGE: True\n    INITIALIZE_BOX_TYPE: 'no'\n    DN: seg\n    DN_NOISE_SCALE: 0.4\n    DN_NUM: 100\n    INITIAL_PRED: True\n    LEARN_TGT: False\n    TOTAL_NUM_FEATURE_LEVELS: 4\n    SEMANTIC_CE_LOSS: False\n    PANO_BOX_LOSS: False\n    COCO: True\n    O365: False\n    TEST:\n      SEMANTIC_ON: True\n      INSTANCE_ON: True\n      PANOPTIC_ON: True\n      OVERLAP_THRESHOLD: 0.8\n      OBJECT_MASK_THRESHOLD: 0.25\n      SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE: false\n      TEST_FOUCUS_ON_BOX: False\n      PANO_TRANSFORM_EVAL: True\n      PANO_TEMPERATURE: 
0.06\n\nTEST:\n  EVAL_PERIOD: 500000\n  PRECISE_BN:\n    NUM_ITER: 1\n    ENABLED: False\n  AUG:\n    ENABLED: False\n\nSAM:\n  INPUT:\n    MIN_SIZE_TEST: 800\n    MAX_SIZE_TEST: 1333\n    IMAGE_SIZE: 1024\n    MIN_SCALE: 0.99\n    MAX_SCALE: 1.01\n    DATASET_MAPPER_NAME: \"sam\"\n    IGNORE_VALUE: 255\n    COLOR_AUG_SSD: False\n    SIZE_DIVISIBILITY: 32\n    RANDOM_FLIP: \"horizontal\"\n    MASK_FORMAT: \"polygon\"\n    FORMAT: \"RGB\"\n    CROP:\n      ENABLED: True\n  DATASET:\n    DATASET: 'sam'\n  TEST:\n    DETECTIONS_PER_IMAGE: 100\n    NAME: coco_eval\n    IOU_TYPE: ['bbox', 'segm']\n    USE_MULTISCALE: false\n    BATCH_SIZE_TOTAL: 8\n    MODEL_FILE: ''\n    AUG:\n      ENABLED: False\n  TRAIN:\n    BATCH_SIZE_TOTAL: 1\n    BATCH_SIZE_PER_GPU: 1\n    SHUFFLE: true\n  DATALOADER:\n    FILTER_EMPTY_ANNOTATIONS: False\n    NUM_WORKERS: 4\n    LOAD_PROPOSALS: False\n    SAMPLER_TRAIN: \"TrainingSampler\"\n    ASPECT_RATIO_GROUPING: True\n\nCOCO:\n  INPUT:\n    MIN_SIZE_TEST: 800\n    MAX_SIZE_TEST: 1333\n    IMAGE_SIZE: 1024\n    MIN_SCALE: 1.0\n    MAX_SCALE: 1.0\n    DATASET_MAPPER_NAME: \"coco_ref_panoptic_lsj\"\n    IGNORE_VALUE: 255\n    COLOR_AUG_SSD: False\n    SIZE_DIVISIBILITY: 32\n    RANDOM_FLIP: \"horizontal\"\n    MASK_FORMAT: \"polygon\"\n    FORMAT: \"RGB\"\n    CROP:\n      ENABLED: True\n  DATASET:\n    DATASET: 'coco'\n  TEST:\n    DETECTIONS_PER_IMAGE: 100\n    NAME: coco_eval\n    IOU_TYPE: ['bbox', 'segm']\n    USE_MULTISCALE: false\n    BATCH_SIZE_TOTAL: 1\n    MODEL_FILE: ''\n    AUG:\n      ENABLED: False\n  TRAIN:\n    BATCH_SIZE_TOTAL: 1\n    BATCH_SIZE_PER_GPU: 1\n    SHUFFLE: true\n  DATALOADER:\n    FILTER_EMPTY_ANNOTATIONS: False\n    NUM_WORKERS:  4\n    LOAD_PROPOSALS: False\n    SAMPLER_TRAIN: \"TrainingSampler\"\n    ASPECT_RATIO_GROUPING: True\n\nVLP:\n  INPUT:\n    IMAGE_SIZE: 224\n    DATASET_MAPPER_NAME: \"vlpretrain\"\n    IGNORE_VALUE: 255\n    COLOR_AUG_SSD: False\n    SIZE_DIVISIBILITY: 32\n    MASK_FORMAT: 
\"polygon\"\n    FORMAT: \"RGB\"\n    CROP:\n      ENABLED: True\n  TRAIN:\n    BATCH_SIZE_TOTAL: 2\n    BATCH_SIZE_PER_GPU: 2\n  TEST:\n    BATCH_SIZE_TOTAL: 256\n  DATALOADER:\n    FILTER_EMPTY_ANNOTATIONS: False\n    NUM_WORKERS:  4\n    LOAD_PROPOSALS: False\n    SAMPLER_TRAIN: \"TrainingSampler\"\n    ASPECT_RATIO_GROUPING: True\n\nINPUT:\n  PIXEL_MEAN: [123.675, 116.280, 103.530]\n  PIXEL_STD: [58.395, 57.120, 57.375]\n\nDATASETS:\n  TRAIN: [\"coco_instruct_train_v3\",\"flickr_train\"]\n\nDATALOADER:\n  FILTER_EMPTY_ANNOTATIONS: False\n  NUM_WORKERS:  4\n  LOAD_PROPOSALS: False\n  SAMPLER_TRAIN: \"TrainingSampler\"\n  ASPECT_RATIO_GROUPING: True\n\n# Detectron2 training config for optimizer and lr scheduler\nSOLVER:\n  BASE_LR_END: 0.0\n  MOMENTUM: 0.9\n  NESTEROV: False\n  CHECKPOINT_PERIOD: 5000\n  IMS_PER_BATCH: 1\n  REFERENCE_WORLD_SIZE: 0\n  BIAS_LR_FACTOR: 1.0\n  WEIGHT_DECAY_BIAS: None\n  # original\n  BASE_LR: 0.0001\n  STEPS: [327778, 355092]\n  MAX_ITER: 368750\n  GAMMA: 0.1\n  WARMUP_FACTOR: 1.0\n  WARMUP_ITERS: 10\n  WARMUP_METHOD: \"linear\"\n  WEIGHT_DECAY: 0.05\n  OPTIMIZER: \"ADAMW\"\n  LR_SCHEDULER_NAME: \"WarmupMultiStepLR\"\n  LR_MULTIPLIER:\n    backbone: 0.1\n    lang_encoder: 0.1\n  WEIGHT_DECAY_NORM: 0.0\n  WEIGHT_DECAY_EMBED: 0.0\n  CLIP_GRADIENTS:\n    ENABLED: True\n    CLIP_TYPE: \"full_model\"\n    CLIP_VALUE: 0.01\n    NORM_TYPE: 2.0\n  AMP:\n    ENABLED: True\n\n# Evaluation Dataset\nADE20K:\n  INPUT:\n    MIN_SIZE_TRAIN: [320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 1024, 1088, 1152, 1216, 1280]\n    MIN_SIZE_TRAIN_SAMPLING: \"choice\"\n    MIN_SIZE_TEST: 640\n    MAX_SIZE_TRAIN: 2560\n    MAX_SIZE_TEST: 2560\n    MASK_FORMAT: \"polygon\"\n    CROP:\n      ENABLED: True\n      TYPE: \"absolute\"\n      SIZE: [640, 640]\n      SINGLE_CATEGORY_MAX_AREA: 1.0\n    IGNORE_VALUE: 255\n    COLOR_AUG_SSD: True\n    SIZE_DIVISIBILITY: 640  # used in dataset mapper\n    DATASET_MAPPER_NAME: \"mask_former_panoptic\"\n    FORMAT: 
\"RGB\"\n  DATASET:\n    DATASET: 'ade'\n  TRAIN:\n    ASPECT_RATIO_GROUPING: true\n    BATCH_SIZE_TOTAL: 16\n    BATCH_SIZE_PER_GPU: 2\n    SHUFFLE: true\n  TEST:\n    DETECTIONS_PER_IMAGE: 100\n    NAME: coco_eval\n    IOU_TYPE: ['bbox', 'segm']\n    USE_MULTISCALE: false\n    BATCH_SIZE_TOTAL: 8\n    MODEL_FILE: ''\n    AUG:\n      ENABLED: False\n  DATALOADER:\n    FILTER_EMPTY_ANNOTATIONS: False\n    NUM_WORKERS:  4\n    LOAD_PROPOSALS: False\n    SAMPLER_TRAIN: \"TrainingSampler\"\n    ASPECT_RATIO_GROUPING: True\n\nREF:\n  INPUT:\n    PIXEL_MEAN: [123.675, 116.280, 103.530]\n    PIXEL_STD: [58.395, 57.120, 57.375]\n    MIN_SIZE_TEST: 512\n    MAX_SIZE_TEST: 1024\n    FORMAT: \"RGB\"\n  DATALOADER:\n    FILTER_EMPTY_ANNOTATIONS: False\n    NUM_WORKERS:  4\n    LOAD_PROPOSALS: False\n    SAMPLER_TRAIN: \"TrainingSampler\"\n    ASPECT_RATIO_GROUPING: False\n  TEST:\n    BATCH_SIZE_TOTAL: 8\n\nSUN:\n  INPUT:\n    PIXEL_MEAN: [123.675, 116.280, 103.530]\n    PIXEL_STD: [58.395, 57.120, 57.375]\n    MIN_SIZE_TEST: 512\n    MAX_SIZE_TEST: 1024\n  DATALOADER:\n    FILTER_EMPTY_ANNOTATIONS: False\n    NUM_WORKERS:  4\n    LOAD_PROPOSALS: False\n    SAMPLER_TRAIN: \"TrainingSampler\"\n    ASPECT_RATIO_GROUPING: False\n  TEST:\n    BATCH_SIZE_TOTAL: 8\n\nSCAN:\n  INPUT:\n    PIXEL_MEAN: [123.675, 116.280, 103.530]\n    PIXEL_STD: [58.395, 57.120, 57.375]\n    MIN_SIZE_TEST: 512\n    MAX_SIZE_TEST: 1024\n  DATALOADER:\n    FILTER_EMPTY_ANNOTATIONS: False\n    NUM_WORKERS:  4\n    LOAD_PROPOSALS: False\n    SAMPLER_TRAIN: \"TrainingSampler\"\n    ASPECT_RATIO_GROUPING: False\n  TEST:\n    BATCH_SIZE_TOTAL: 8\n\nBDD:\n  INPUT:\n    PIXEL_MEAN: [123.675, 116.280, 103.530]\n    PIXEL_STD: [58.395, 57.120, 57.375]\n    MIN_SIZE_TEST: 800\n    MAX_SIZE_TEST: 1333\n  DATALOADER:\n    FILTER_EMPTY_ANNOTATIONS: False\n    NUM_WORKERS:  4\n    LOAD_PROPOSALS: False\n    SAMPLER_TRAIN: \"TrainingSampler\"\n    ASPECT_RATIO_GROUPING: False\n  TEST:\n    BATCH_SIZE_TOTAL: 
8\n\nCITY:\n  INPUT:\n    MIN_SIZE_TRAIN: [ 512, 614, 716, 819, 921, 1024, 1126, 1228, 1331, 1433, 1536, 1638, 1740, 1843, 1945, 2048 ]\n    MIN_SIZE_TRAIN_SAMPLING: \"choice\"\n    MIN_SIZE_TEST: 1024\n    MAX_SIZE_TRAIN: 4096\n    MAX_SIZE_TEST: 2048\n    CROP:\n      ENABLED: True\n      TYPE: \"absolute\"\n      SIZE: [ 512, 1024 ]\n      SINGLE_CATEGORY_MAX_AREA: 1.0\n    IGNORE_VALUE: 255\n    COLOR_AUG_SSD: True\n    SIZE_DIVISIBILITY: -1\n    FORMAT: \"RGB\"\n    DATASET_MAPPER_NAME: \"mask_former_panoptic\"\n    MASK_FORMAT: \"polygon\"\n    TEST:\n      EVAL_PERIOD: 5000\n      BATCH_SIZE_TOTAL: 1\n      AUG:\n        ENABLED: False\n        MIN_SIZES: [ 512, 768, 1024, 1280, 1536, 1792 ]\n        MAX_SIZE: 4096\n        FLIP: True\n    DATALOADER:\n      FILTER_EMPTY_ANNOTATIONS: True\n      NUM_WORKERS:  4\n      LOAD_PROPOSALS: False\n      SAMPLER_TRAIN: \"TrainingSampler\"\n      ASPECT_RATIO_GROUPING: True\n    TRAIN:\n      ASPECT_RATIO_GROUPING: true\n      BATCH_SIZE_TOTAL: 2\n      BATCH_SIZE_PER_GPU: 2\n      SHUFFLE: true\n\nPSACAL_PART:\n  INPUT:\n      MIN_SIZE_TEST: 800\n      MAX_SIZE_TEST: 1333\n      IMAGE_SIZE: 1024\n      MIN_SCALE: 0.1\n      MAX_SCALE: 2.0\n      DATASET_MAPPER_NAME: \"pascal_part_lsj\"\n      IGNORE_VALUE: 255\n      COLOR_AUG_SSD: False\n      SIZE_DIVISIBILITY: 32\n      RANDOM_FLIP: \"horizontal\"\n      MASK_FORMAT: \"polygon\"\n      FORMAT: \"RGB\"\n      CROP:\n        ENABLED: True\n  MODEL:\n    MASK_ON: True\n    KEYPOINT_ON: False\n    LOAD_PROPOSALS: False\n  # DATASET:\n  #   DATASET: 'coco'\n  TEST:\n    DETECTIONS_PER_IMAGE: 100\n    NAME: coco_eval\n    IOU_TYPE: ['bbox', 'segm']\n    USE_MULTISCALE: false\n    BATCH_SIZE_TOTAL: 8\n    MODEL_FILE: ''\n    AUG:\n      ENABLED: False\n  TRAIN:\n    BATCH_SIZE_TOTAL: 1\n    BATCH_SIZE_PER_GPU: 1\n    SHUFFLE: true\n  DATALOADER:\n    FILTER_EMPTY_ANNOTATIONS: False\n    NUM_WORKERS:  4\n    LOAD_PROPOSALS: False\n    SAMPLER_TRAIN: \"TrainingSampler\"\n 
   ASPECT_RATIO_GROUPING: True\n\nllava:\n  INPUT:\n    MIN_SIZE_TEST: 800\n    MAX_SIZE_TEST: 1333\n    IMAGE_SIZE: 1024\n    MIN_SCALE: 1.0\n    MAX_SCALE: 1.0\n    DATASET_MAPPER_NAME: \"llava\"\n    IGNORE_VALUE: 255\n    COLOR_AUG_SSD: False\n    SIZE_DIVISIBILITY: 32\n    RANDOM_FLIP: \"horizontal\"\n    MASK_FORMAT: \"polygon\"\n    FORMAT: \"RGB\"\n    CROP:\n      ENABLED: True\n  TEST:\n    DETECTIONS_PER_IMAGE: 100\n    NAME: coco_eval\n    IOU_TYPE: ['bbox', 'segm']\n    USE_MULTISCALE: false\n    BATCH_SIZE_TOTAL: 1\n    MODEL_FILE: ''\n    AUG:\n      ENABLED: False\n  TRAIN:\n    BATCH_SIZE_TOTAL: 1\n    BATCH_SIZE_PER_GPU: 1\n    SHUFFLE: true\n  DATALOADER:\n    FILTER_EMPTY_ANNOTATIONS: False\n    NUM_WORKERS:  4\n    LOAD_PROPOSALS: False\n    SAMPLER_TRAIN: \"TrainingSampler\"\n    ASPECT_RATIO_GROUPING: True\n\nflickr:\n  INPUT:\n    MIN_SIZE_TEST: 800\n    MAX_SIZE_TEST: 1333\n    IMAGE_SIZE: 1024\n    MIN_SCALE: 1.0\n    MAX_SCALE: 1.0\n    DATASET_MAPPER_NAME: \"flickr\"\n    IGNORE_VALUE: 255\n    COLOR_AUG_SSD: False\n    SIZE_DIVISIBILITY: 32\n    RANDOM_FLIP: \"horizontal\"\n    MASK_FORMAT: \"polygon\"\n    FORMAT: \"RGB\"\n    CROP:\n      ENABLED: True\n  TEST:\n    DETECTIONS_PER_IMAGE: 100\n    NAME: coco_eval\n    IOU_TYPE: ['bbox', 'segm']\n    USE_MULTISCALE: false\n    BATCH_SIZE_TOTAL: 1\n    MODEL_FILE: ''\n    AUG:\n      ENABLED: False\n  TRAIN:\n    BATCH_SIZE_TOTAL: 1\n    BATCH_SIZE_PER_GPU: 1\n    SHUFFLE: true\n  DATALOADER:\n    FILTER_EMPTY_ANNOTATIONS: False\n    NUM_WORKERS:  4\n    LOAD_PROPOSALS: False\n    SAMPLER_TRAIN: \"TrainingSampler\"\n    ASPECT_RATIO_GROUPING: True\n\ncoco_instruct:\n  INPUT:\n    MIN_SIZE_TEST: 800\n    MAX_SIZE_TEST: 1333\n    IMAGE_SIZE: 1024\n    MIN_SCALE: 1.0\n    MAX_SCALE: 1.0\n    DATASET_MAPPER_NAME: \"coco_instruct\"\n    IGNORE_VALUE: 255\n    COLOR_AUG_SSD: False\n    SIZE_DIVISIBILITY: 32\n    RANDOM_FLIP: \"horizontal\"\n    MASK_FORMAT: \"polygon\"\n    FORMAT: \"RGB\"\n   
 CROP:\n      ENABLED: True\n  TEST:\n    DETECTIONS_PER_IMAGE: 100\n    NAME: coco_eval\n    IOU_TYPE: ['bbox', 'segm']\n    USE_MULTISCALE: false\n    BATCH_SIZE_TOTAL: 1\n    MODEL_FILE: ''\n    AUG:\n      ENABLED: False\n  TRAIN:\n    BATCH_SIZE_TOTAL: 1\n    BATCH_SIZE_PER_GPU: 1\n    SHUFFLE: true\n  DATALOADER:\n    FILTER_EMPTY_ANNOTATIONS: False\n    NUM_WORKERS:  4\n    LOAD_PROPOSALS: False\n    SAMPLER_TRAIN: \"TrainingSampler\"\n    ASPECT_RATIO_GROUPING: True\n\nvg:\n  INPUT:\n    MIN_SIZE_TEST: 800\n    MAX_SIZE_TEST: 1333\n    IMAGE_SIZE: 1024\n    MIN_SCALE: 1.0\n    MAX_SCALE: 1.0\n    DATASET_MAPPER_NAME: \"vg\"\n    IGNORE_VALUE: 255\n    COLOR_AUG_SSD: False\n    SIZE_DIVISIBILITY: 32\n    RANDOM_FLIP: \"horizontal\"\n    MASK_FORMAT: \"polygon\"\n    FORMAT: \"RGB\"\n    CROP:\n      ENABLED: True\n  TEST:\n    DETECTIONS_PER_IMAGE: 100\n    NAME: coco_eval\n    IOU_TYPE: ['bbox', 'segm']\n    USE_MULTISCALE: false\n    BATCH_SIZE_TOTAL: 1\n    MODEL_FILE: ''\n    AUG:\n      ENABLED: False\n  TRAIN:\n    BATCH_SIZE_TOTAL: 1\n    BATCH_SIZE_PER_GPU: 1\n    SHUFFLE: true\n  DATALOADER:\n    FILTER_EMPTY_ANNOTATIONS: False\n    NUM_WORKERS:  4\n    LOAD_PROPOSALS: False\n    SAMPLER_TRAIN: \"TrainingSampler\"\n    ASPECT_RATIO_GROUPING: True"
  },
  {
    "path": "configs/openseed/openseed_swint_lang_joint_2st_visual_prompt.yaml",
    "content": "# --------------------------------------------------------\n# X-Decoder -- Generalized Decoding for Pixel, Image, and Language\n# Copyright (c) 2022 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# Written by Xueyan Zou (xueyan@cs.wisc.edu)\n# --------------------------------------------------------\n\n##################\n# Task settings\n##################\nWEIGHT: ''\nPORT: 53711\ndetach_seg: False\nVERBOSE: true\n#OUTPUT_DIR: '../../data/output/test'\ninference_only: true\nOUTPUT_DIR: '../../data/output/test'\nclip: true\n# misc\nLOADER:\n  JOINT: True\n  KEY_DATASET: 'flickr'\n# model\nMODEL:\n  NAME: openseed_model\n  HEAD: openseed_head\n  MASK_ON: false\n  KEYPOINT_ON: false\n  LOAD_PROPOSALS: false\n  DIM_PROJ: 4096\n  BACKBONE_DIM: 768\n  BACKGROUND: False\n  WEIGHTS: ''\n  TEXT:\n    ARCH: encoder\n    NAME: transformer\n    TOKENIZER: clip\n    CONTEXT_LENGTH: 18 # 18\n    WIDTH: 512\n    HEADS: 8\n    LAYERS: 12\n    AUTOGRESSIVE: True\n  BACKBONE:\n    NAME: swin\n    PRETRAINED: 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth'\n    LOAD_PRETRAINED: true\n    SWIN:\n      PRETRAIN_IMG_SIZE: 224\n      PATCH_SIZE: 4\n      EMBED_DIM: 96\n      DEPTHS: [ 2, 2, 6, 2 ]\n      NUM_HEADS: [ 3, 6, 12, 24 ]\n      WINDOW_SIZE: 7\n      MLP_RATIO: 4.0\n      QKV_BIAS: true\n      QK_SCALE: ~\n      DROP_RATE: 0.0\n      ATTN_DROP_RATE: 0.0\n      DROP_PATH_RATE: 0.3\n      APE: false\n      PATCH_NORM: true\n      USE_CHECKPOINT: false\n      OUT_FEATURES: [ 'res2', 'res3', 'res4', 'res5' ]\n  ENCODER:\n    NAME: encoder_deform\n    IGNORE_VALUE: 255\n    NUM_CLASSES: 133\n    LOSS_WEIGHT: 1.0\n    CONVS_DIM: 256\n    MASK_DIM: 256\n    NORM: \"GN\"\n    IN_FEATURES: [ \"res2\", \"res3\", \"res4\", \"res5\" ]\n    DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES: [ \"res3\", \"res4\", \"res5\" ]\n    COMMON_STRIDE: 4\n    TRANSFORMER_ENC_LAYERS: 6\n    
TOTAL_NUM_FEATURE_LEVELS: 4\n    NUM_FEATURE_LEVELS: 3\n    FEATURE_ORDER: \"low2high\"\n  DECODER:\n    NAME: openseed_decoder\n    TRANSFORMER_IN_FEATURE: \"multi_scale_pixel_decoder\"\n    MASK: True\n    BOX: True\n    COCO_ONLY: True\n    GROUNDING:\n      ENABLED: False\n      MAX_LEN: 5\n      TEXT_WEIGHT: 2.0\n      CLASS_WEIGHT: 0.5\n    CAPTION:\n      ENABLED: False\n      PHRASE_PROB: 0.0\n      SIM_THRES: 0.95\n    CAPTIONING:\n      ENABLED: False\n      STEP: 50\n    RETRIEVAL:\n      ENABLED: False\n      DIM_IMG: 768\n      ENSEMBLE: True\n    OPENIMAGE:\n      ENABLED: False\n      NEGATIVE_SAMPLES: 5\n      GROUNDING:\n        ENABLED: False\n        MAX_LEN: 5\n    DEEP_SUPERVISION: True\n    NO_OBJECT_WEIGHT: 0.1\n    CLASS_WEIGHT: 4.0\n    MASK_WEIGHT: 5.0\n    DICE_WEIGHT: 5.0\n    BOX_WEIGHT: 5.0\n    GIOU_WEIGHT: 2.0\n    LLM_WEIGHT: 1.0\n    WEIGHT_MULTIPLIER: 1.0\n    COST_CLASS_WEIGHT: 4.0\n    COST_DICE_WEIGHT: 5.0\n    COST_MASK_WEIGHT: 5.0\n    COST_BOX_WEIGHT: 5.0\n    COST_GIOU_WEIGHT: 2.0\n    HIDDEN_DIM: 256\n    NUM_OBJECT_QUERIES: 300\n    NHEADS: 8\n    DROPOUT: 0.0\n    DIM_FEEDFORWARD: 2048\n    ENC_LAYERS: 0\n    PRE_NORM: False\n    ENFORCE_INPUT_PROJ: False\n    SIZE_DIVISIBILITY: 32\n    DEC_LAYERS: 9  # 9 decoder layers, add one for the loss on learnable query\n    TRAIN_NUM_POINTS: 12544\n    OVERSAMPLE_RATIO: 3.0\n    IMPORTANCE_SAMPLE_RATIO: 0.75\n    TWO_STAGE: True\n    INITIALIZE_BOX_TYPE: 'no'\n    DN: seg\n    DN_NOISE_SCALE: 0.4\n    DN_NUM: 100\n    INITIAL_PRED: True\n    LEARN_TGT: False\n    TOTAL_NUM_FEATURE_LEVELS: 4\n    SEMANTIC_CE_LOSS: False\n    PANO_BOX_LOSS: False\n    COCO: True\n    O365: False\n    TEST:\n      SEMANTIC_ON: True\n      INSTANCE_ON: True\n      PANOPTIC_ON: True\n      OVERLAP_THRESHOLD: 0.8\n      OBJECT_MASK_THRESHOLD: 0.25\n      SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE: false\n      TEST_FOUCUS_ON_BOX: False\n      PANO_TRANSFORM_EVAL: True\n      PANO_TEMPERATURE: 
0.06\n\nTEST:\n  EVAL_PERIOD: 500000\n  PRECISE_BN:\n    NUM_ITER: 1\n    ENABLED: False\n  AUG:\n    ENABLED: False\n\nSAM:\n  INPUT:\n    MIN_SIZE_TEST: 800\n    MAX_SIZE_TEST: 1333\n    IMAGE_SIZE: 1024\n    MIN_SCALE: 0.99\n    MAX_SCALE: 1.01\n    DATASET_MAPPER_NAME: \"sam\"\n    IGNORE_VALUE: 255\n    COLOR_AUG_SSD: False\n    SIZE_DIVISIBILITY: 32\n    RANDOM_FLIP: \"horizontal\"\n    MASK_FORMAT: \"polygon\"\n    FORMAT: \"RGB\"\n    CROP:\n      ENABLED: True\n  DATASET:\n    DATASET: 'sam'\n  TEST:\n    DETECTIONS_PER_IMAGE: 100\n    NAME: coco_eval\n    IOU_TYPE: ['bbox', 'segm']\n    USE_MULTISCALE: false\n    BATCH_SIZE_TOTAL: 8\n    MODEL_FILE: ''\n    AUG:\n      ENABLED: False\n  TRAIN:\n    BATCH_SIZE_TOTAL: 1\n    BATCH_SIZE_PER_GPU: 1\n    SHUFFLE: true\n  DATALOADER:\n    FILTER_EMPTY_ANNOTATIONS: False\n    NUM_WORKERS: 4\n    LOAD_PROPOSALS: False\n    SAMPLER_TRAIN: \"TrainingSampler\"\n    ASPECT_RATIO_GROUPING: True\n\nCOCO:\n  INPUT:\n    MIN_SIZE_TEST: 800\n    MAX_SIZE_TEST: 1333\n    IMAGE_SIZE: 1024\n    MIN_SCALE: 1.0\n    MAX_SCALE: 1.0\n    DATASET_MAPPER_NAME: \"coco_ref_panoptic_lsj\"\n    IGNORE_VALUE: 255\n    COLOR_AUG_SSD: False\n    SIZE_DIVISIBILITY: 32\n    RANDOM_FLIP: \"horizontal\"\n    MASK_FORMAT: \"polygon\"\n    FORMAT: \"RGB\"\n    CROP:\n      ENABLED: True\n  DATASET:\n    DATASET: 'coco'\n  TEST:\n    DETECTIONS_PER_IMAGE: 100\n    NAME: coco_eval\n    IOU_TYPE: ['bbox', 'segm']\n    USE_MULTISCALE: false\n    BATCH_SIZE_TOTAL: 1\n    MODEL_FILE: ''\n    AUG:\n      ENABLED: False\n  TRAIN:\n    BATCH_SIZE_TOTAL: 1\n    BATCH_SIZE_PER_GPU: 1\n    SHUFFLE: true\n  DATALOADER:\n    FILTER_EMPTY_ANNOTATIONS: False\n    NUM_WORKERS:  4\n    LOAD_PROPOSALS: False\n    SAMPLER_TRAIN: \"TrainingSampler\"\n    ASPECT_RATIO_GROUPING: True\n\nVLP:\n  INPUT:\n    IMAGE_SIZE: 224\n    DATASET_MAPPER_NAME: \"vlpretrain\"\n    IGNORE_VALUE: 255\n    COLOR_AUG_SSD: False\n    SIZE_DIVISIBILITY: 32\n    MASK_FORMAT: 
\"polygon\"\n    FORMAT: \"RGB\"\n    CROP:\n      ENABLED: True\n  TRAIN:\n    BATCH_SIZE_TOTAL: 2\n    BATCH_SIZE_PER_GPU: 2\n  TEST:\n    BATCH_SIZE_TOTAL: 256\n  DATALOADER:\n    FILTER_EMPTY_ANNOTATIONS: False\n    NUM_WORKERS:  4\n    LOAD_PROPOSALS: False\n    SAMPLER_TRAIN: \"TrainingSampler\"\n    ASPECT_RATIO_GROUPING: True\n\nINPUT:\n  PIXEL_MEAN: [123.675, 116.280, 103.530]\n  PIXEL_STD: [58.395, 57.120, 57.375]\n\nDATASETS:\n  TRAIN: [\"coco_interactive_refcoco\",\"coco_interactive\",\"flickr_train\"]\n\nDATALOADER:\n  FILTER_EMPTY_ANNOTATIONS: False\n  NUM_WORKERS:  4\n  LOAD_PROPOSALS: False\n  SAMPLER_TRAIN: \"TrainingSampler\"\n  ASPECT_RATIO_GROUPING: True\n\n# Detectron2 training config for optimizer and lr scheduler\nSOLVER:\n  BASE_LR_END: 0.0\n  MOMENTUM: 0.9\n  NESTEROV: False\n  CHECKPOINT_PERIOD: 5000\n  IMS_PER_BATCH: 1\n  REFERENCE_WORLD_SIZE: 0\n  BIAS_LR_FACTOR: 1.0\n  WEIGHT_DECAY_BIAS: None\n  # original\n  BASE_LR: 0.0001\n  STEPS: [327778, 355092]\n  MAX_ITER: 368750\n  GAMMA: 0.1\n  WARMUP_FACTOR: 1.0\n  WARMUP_ITERS: 10\n  WARMUP_METHOD: \"linear\"\n  WEIGHT_DECAY: 0.05\n  OPTIMIZER: \"ADAMW\"\n  LR_SCHEDULER_NAME: \"WarmupMultiStepLR\"\n  LR_MULTIPLIER:\n    backbone: 0.1\n    lang_encoder: 0.1\n  WEIGHT_DECAY_NORM: 0.0\n  WEIGHT_DECAY_EMBED: 0.0\n  CLIP_GRADIENTS:\n    ENABLED: True\n    CLIP_TYPE: \"full_model\"\n    CLIP_VALUE: 0.01\n    NORM_TYPE: 2.0\n  AMP:\n    ENABLED: True\n\n# Evaluation Dataset\nADE20K:\n  INPUT:\n    MIN_SIZE_TRAIN: [320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 1024, 1088, 1152, 1216, 1280]\n    MIN_SIZE_TRAIN_SAMPLING: \"choice\"\n    MIN_SIZE_TEST: 640\n    MAX_SIZE_TRAIN: 2560\n    MAX_SIZE_TEST: 2560\n    MASK_FORMAT: \"polygon\"\n    CROP:\n      ENABLED: True\n      TYPE: \"absolute\"\n      SIZE: [640, 640]\n      SINGLE_CATEGORY_MAX_AREA: 1.0\n    IGNORE_VALUE: 255\n    COLOR_AUG_SSD: True\n    SIZE_DIVISIBILITY: 640  # used in dataset mapper\n    DATASET_MAPPER_NAME: 
\"mask_former_panoptic\"\n    FORMAT: \"RGB\"\n  DATASET:\n    DATASET: 'ade'\n  TRAIN:\n    ASPECT_RATIO_GROUPING: true\n    BATCH_SIZE_TOTAL: 16\n    BATCH_SIZE_PER_GPU: 2\n    SHUFFLE: true\n  TEST:\n    DETECTIONS_PER_IMAGE: 100\n    NAME: coco_eval\n    IOU_TYPE: ['bbox', 'segm']\n    USE_MULTISCALE: false\n    BATCH_SIZE_TOTAL: 8\n    MODEL_FILE: ''\n    AUG:\n      ENABLED: False\n  DATALOADER:\n    FILTER_EMPTY_ANNOTATIONS: False\n    NUM_WORKERS:  4\n    LOAD_PROPOSALS: False\n    SAMPLER_TRAIN: \"TrainingSampler\"\n    ASPECT_RATIO_GROUPING: True\n\nREF:\n  INPUT:\n    PIXEL_MEAN: [123.675, 116.280, 103.530]\n    PIXEL_STD: [58.395, 57.120, 57.375]\n    MIN_SIZE_TEST: 512\n    MAX_SIZE_TEST: 1024\n    FORMAT: \"RGB\"\n  DATALOADER:\n    FILTER_EMPTY_ANNOTATIONS: False\n    NUM_WORKERS:  4\n    LOAD_PROPOSALS: False\n    SAMPLER_TRAIN: \"TrainingSampler\"\n    ASPECT_RATIO_GROUPING: False\n  TEST:\n    BATCH_SIZE_TOTAL: 8\n\nSUN:\n  INPUT:\n    PIXEL_MEAN: [123.675, 116.280, 103.530]\n    PIXEL_STD: [58.395, 57.120, 57.375]\n    MIN_SIZE_TEST: 512\n    MAX_SIZE_TEST: 1024\n  DATALOADER:\n    FILTER_EMPTY_ANNOTATIONS: False\n    NUM_WORKERS:  4\n    LOAD_PROPOSALS: False\n    SAMPLER_TRAIN: \"TrainingSampler\"\n    ASPECT_RATIO_GROUPING: False\n  TEST:\n    BATCH_SIZE_TOTAL: 8\n\nSCAN:\n  INPUT:\n    PIXEL_MEAN: [123.675, 116.280, 103.530]\n    PIXEL_STD: [58.395, 57.120, 57.375]\n    MIN_SIZE_TEST: 512\n    MAX_SIZE_TEST: 1024\n  DATALOADER:\n    FILTER_EMPTY_ANNOTATIONS: False\n    NUM_WORKERS:  4\n    LOAD_PROPOSALS: False\n    SAMPLER_TRAIN: \"TrainingSampler\"\n    ASPECT_RATIO_GROUPING: False\n  TEST:\n    BATCH_SIZE_TOTAL: 8\n\nBDD:\n  INPUT:\n    PIXEL_MEAN: [123.675, 116.280, 103.530]\n    PIXEL_STD: [58.395, 57.120, 57.375]\n    MIN_SIZE_TEST: 800\n    MAX_SIZE_TEST: 1333\n  DATALOADER:\n    FILTER_EMPTY_ANNOTATIONS: False\n    NUM_WORKERS:  4\n    LOAD_PROPOSALS: False\n    SAMPLER_TRAIN: \"TrainingSampler\"\n    ASPECT_RATIO_GROUPING: False\n  
TEST:\n    BATCH_SIZE_TOTAL: 8\n\nCITY:\n  INPUT:\n    MIN_SIZE_TRAIN: [ 512, 614, 716, 819, 921, 1024, 1126, 1228, 1331, 1433, 1536, 1638, 1740, 1843, 1945, 2048 ]\n    MIN_SIZE_TRAIN_SAMPLING: \"choice\"\n    MIN_SIZE_TEST: 1024\n    MAX_SIZE_TRAIN: 4096\n    MAX_SIZE_TEST: 2048\n    CROP:\n      ENABLED: True\n      TYPE: \"absolute\"\n      SIZE: [ 512, 1024 ]\n      SINGLE_CATEGORY_MAX_AREA: 1.0\n    IGNORE_VALUE: 255\n    COLOR_AUG_SSD: True\n    SIZE_DIVISIBILITY: -1\n    FORMAT: \"RGB\"\n    DATASET_MAPPER_NAME: \"mask_former_panoptic\"\n    MASK_FORMAT: \"polygon\"\n    TEST:\n      EVAL_PERIOD: 5000\n      BATCH_SIZE_TOTAL: 1\n      AUG:\n        ENABLED: False\n        MIN_SIZES: [ 512, 768, 1024, 1280, 1536, 1792 ]\n        MAX_SIZE: 4096\n        FLIP: True\n    DATALOADER:\n      FILTER_EMPTY_ANNOTATIONS: True\n      NUM_WORKERS:  4\n      LOAD_PROPOSALS: False\n      SAMPLER_TRAIN: \"TrainingSampler\"\n      ASPECT_RATIO_GROUPING: True\n    TRAIN:\n      ASPECT_RATIO_GROUPING: true\n      BATCH_SIZE_TOTAL: 2\n      BATCH_SIZE_PER_GPU: 2\n      SHUFFLE: true\n\nPSACAL_PART:\n  INPUT:\n      MIN_SIZE_TEST: 800\n      MAX_SIZE_TEST: 1333\n      IMAGE_SIZE: 1024\n      MIN_SCALE: 0.1\n      MAX_SCALE: 2.0\n      DATASET_MAPPER_NAME: \"pascal_part_lsj\"\n      IGNORE_VALUE: 255\n      COLOR_AUG_SSD: False\n      SIZE_DIVISIBILITY: 32\n      RANDOM_FLIP: \"horizontal\"\n      MASK_FORMAT: \"polygon\"\n      FORMAT: \"RGB\"\n      CROP:\n        ENABLED: True\n  MODEL:\n    MASK_ON: True\n    KEYPOINT_ON: False\n    LOAD_PROPOSALS: False\n  # DATASET:\n  #   DATASET: 'coco'\n  TEST:\n    DETECTIONS_PER_IMAGE: 100\n    NAME: coco_eval\n    IOU_TYPE: ['bbox', 'segm']\n    USE_MULTISCALE: false\n    BATCH_SIZE_TOTAL: 8\n    MODEL_FILE: ''\n    AUG:\n      ENABLED: False\n  TRAIN:\n    BATCH_SIZE_TOTAL: 1\n    BATCH_SIZE_PER_GPU: 1\n    SHUFFLE: true\n  DATALOADER:\n    FILTER_EMPTY_ANNOTATIONS: False\n    NUM_WORKERS:  4\n    LOAD_PROPOSALS: False\n    
SAMPLER_TRAIN: \"TrainingSampler\"\n    ASPECT_RATIO_GROUPING: True\n\nllava:\n  INPUT:\n    MIN_SIZE_TEST: 800\n    MAX_SIZE_TEST: 1333\n    IMAGE_SIZE: 1024\n    MIN_SCALE: 1.0\n    MAX_SCALE: 1.0\n    DATASET_MAPPER_NAME: \"llava\"\n    IGNORE_VALUE: 255\n    COLOR_AUG_SSD: False\n    SIZE_DIVISIBILITY: 32\n    RANDOM_FLIP: \"horizontal\"\n    MASK_FORMAT: \"polygon\"\n    FORMAT: \"RGB\"\n    CROP:\n      ENABLED: True\n  TEST:\n    DETECTIONS_PER_IMAGE: 100\n    NAME: coco_eval\n    IOU_TYPE: ['bbox', 'segm']\n    USE_MULTISCALE: false\n    BATCH_SIZE_TOTAL: 1\n    MODEL_FILE: ''\n    AUG:\n      ENABLED: False\n  TRAIN:\n    BATCH_SIZE_TOTAL: 1\n    BATCH_SIZE_PER_GPU: 1\n    SHUFFLE: true\n  DATALOADER:\n    FILTER_EMPTY_ANNOTATIONS: False\n    NUM_WORKERS:  4\n    LOAD_PROPOSALS: False\n    SAMPLER_TRAIN: \"TrainingSampler\"\n    ASPECT_RATIO_GROUPING: True\n\nflickr:\n  INPUT:\n    MIN_SIZE_TEST: 800\n    MAX_SIZE_TEST: 1333\n    IMAGE_SIZE: 1024\n    MIN_SCALE: 1.0\n    MAX_SCALE: 1.0\n    DATASET_MAPPER_NAME: \"flickr\"\n    IGNORE_VALUE: 255\n    COLOR_AUG_SSD: False\n    SIZE_DIVISIBILITY: 32\n    RANDOM_FLIP: \"horizontal\"\n    MASK_FORMAT: \"polygon\"\n    FORMAT: \"RGB\"\n    CROP:\n      ENABLED: True\n  TEST:\n    DETECTIONS_PER_IMAGE: 100\n    NAME: coco_eval\n    IOU_TYPE: ['bbox', 'segm']\n    USE_MULTISCALE: false\n    BATCH_SIZE_TOTAL: 1\n    MODEL_FILE: ''\n    AUG:\n      ENABLED: False\n  TRAIN:\n    BATCH_SIZE_TOTAL: 1\n    BATCH_SIZE_PER_GPU: 1\n    SHUFFLE: true\n  DATALOADER:\n    FILTER_EMPTY_ANNOTATIONS: False\n    NUM_WORKERS:  4\n    LOAD_PROPOSALS: False\n    SAMPLER_TRAIN: \"TrainingSampler\"\n    ASPECT_RATIO_GROUPING: True\n\ncoco_instruct:\n  INPUT:\n    MIN_SIZE_TEST: 800\n    MAX_SIZE_TEST: 1333\n    IMAGE_SIZE: 1024\n    MIN_SCALE: 1.0\n    MAX_SCALE: 1.0\n    DATASET_MAPPER_NAME: \"coco_instruct\"\n    IGNORE_VALUE: 255\n    COLOR_AUG_SSD: False\n    SIZE_DIVISIBILITY: 32\n    RANDOM_FLIP: \"horizontal\"\n    MASK_FORMAT: 
\"polygon\"\n    FORMAT: \"RGB\"\n    CROP:\n      ENABLED: True\n  TEST:\n    DETECTIONS_PER_IMAGE: 100\n    NAME: coco_eval\n    IOU_TYPE: ['bbox', 'segm']\n    USE_MULTISCALE: false\n    BATCH_SIZE_TOTAL: 1\n    MODEL_FILE: ''\n    AUG:\n      ENABLED: False\n  TRAIN:\n    BATCH_SIZE_TOTAL: 1\n    BATCH_SIZE_PER_GPU: 1\n    SHUFFLE: true\n  DATALOADER:\n    FILTER_EMPTY_ANNOTATIONS: False\n    NUM_WORKERS:  4\n    LOAD_PROPOSALS: False\n    SAMPLER_TRAIN: \"TrainingSampler\"\n    ASPECT_RATIO_GROUPING: True\n\ncoco_interactive:\n  INPUT:\n    MIN_SIZE_TEST: 800\n    MAX_SIZE_TEST: 1333\n    IMAGE_SIZE: 1024\n    MIN_SCALE: 1.0\n    MAX_SCALE: 1.0\n    DATASET_MAPPER_NAME: \"coco_interactive\"\n    IGNORE_VALUE: 255\n    COLOR_AUG_SSD: False\n    SIZE_DIVISIBILITY: 32\n    RANDOM_FLIP: \"horizontal\"\n    MASK_FORMAT: \"polygon\"\n    FORMAT: \"RGB\"\n    CROP:\n      ENABLED: True\n  TEST:\n    DETECTIONS_PER_IMAGE: 100\n    NAME: coco_eval\n    IOU_TYPE: ['bbox', 'segm']\n    USE_MULTISCALE: false\n    BATCH_SIZE_TOTAL: 1\n    MODEL_FILE: ''\n    AUG:\n      ENABLED: False\n  TRAIN:\n    BATCH_SIZE_TOTAL: 1\n    BATCH_SIZE_PER_GPU: 1\n    SHUFFLE: true\n  DATALOADER:\n    FILTER_EMPTY_ANNOTATIONS: False\n    NUM_WORKERS:  4\n    LOAD_PROPOSALS: False\n    SAMPLER_TRAIN: \"TrainingSampler\"\n    ASPECT_RATIO_GROUPING: True\n\n\nvg:\n  INPUT:\n    MIN_SIZE_TEST: 800\n    MAX_SIZE_TEST: 1333\n    IMAGE_SIZE: 1024\n    MIN_SCALE: 1.0\n    MAX_SCALE: 1.0\n    DATASET_MAPPER_NAME: \"vg\"\n    IGNORE_VALUE: 255\n    COLOR_AUG_SSD: False\n    SIZE_DIVISIBILITY: 32\n    RANDOM_FLIP: \"horizontal\"\n    MASK_FORMAT: \"polygon\"\n    FORMAT: \"RGB\"\n    CROP:\n      ENABLED: True\n  TEST:\n    DETECTIONS_PER_IMAGE: 100\n    NAME: coco_eval\n    IOU_TYPE: ['bbox', 'segm']\n    USE_MULTISCALE: false\n    BATCH_SIZE_TOTAL: 1\n    MODEL_FILE: ''\n    AUG:\n      ENABLED: False\n  TRAIN:\n    BATCH_SIZE_TOTAL: 1\n    BATCH_SIZE_PER_GPU: 1\n    SHUFFLE: true\n  DATALOADER:\n   
 FILTER_EMPTY_ANNOTATIONS: False\n    NUM_WORKERS:  4\n    LOAD_PROPOSALS: False\n    SAMPLER_TRAIN: \"TrainingSampler\"\n    ASPECT_RATIO_GROUPING: True"
  },
  {
    "path": "configs/semsam/visual_prompt_encoder.yaml",
    "content": "# --------------------------------------------------------\n# X-Decoder -- Generalized Decoding for Pixel, Image, and Language\n# Copyright (c) 2022 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# Written by Xueyan Zou (xueyan@cs.wisc.edu)\n# --------------------------------------------------------\n\n##################\n# Task settings\n##################\nWEIGHT: ''\nPORT: 53711\nVERBOSE: true\n#OUTPUT_DIR: '../../data/output/test'\ninference_only: true\nOUTPUT_DIR: '../../data/output/test'\n# misc\nLOADER:\n  JOINT: True\n  KEY_DATASET: 'coco'\n# model\nMODEL:\n  NAME: idino_model_partwhole_all_llm_ref_feats_all_det_pretrainv1\n  HEAD: openseed_head\n  MASK_ON: false\n  KEYPOINT_ON: false\n  LOAD_PROPOSALS: false\n  DIM_PROJ: 512\n  BACKBONE_DIM: 768\n  BACKGROUND: False\n  WEIGHTS: None\n  LLAMA:\n    model_name_or_path: '/comp_robot/liushilong/data/LLAVA/LLAVA_7b'\n    cache_dir: None\n    model_max_length: 2048\n    hidden_size: 4096\n    tune_mm_mlp_adapter: True\n    im_width: 16\n    load_fp16: False\n    lora_r: 0\n    lora_alpha: 16\n    lora_dropout: 0.05\n\n  TEXT:\n    ARCH: llama_encoder\n  BACKBONE:\n    NAME: swin\n    PRETRAINED: 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth'\n    LOAD_PRETRAINED: true\n    SWIN:\n      PRETRAIN_IMG_SIZE: 224\n      PATCH_SIZE: 4\n      EMBED_DIM: 96\n      DEPTHS: [ 2, 2, 6, 2 ]\n      NUM_HEADS: [ 3, 6, 12, 24 ]\n      WINDOW_SIZE: 7\n      MLP_RATIO: 4.0\n      QKV_BIAS: true\n      QK_SCALE: ~\n      DROP_RATE: 0.0\n      ATTN_DROP_RATE: 0.0\n      DROP_PATH_RATE: 0.3\n      APE: false\n      PATCH_NORM: true\n      USE_CHECKPOINT: false\n      OUT_FEATURES: [ 'res2', 'res3', 'res4', 'res5' ]\n  ENCODER:\n    NAME: encoder_deform\n    IGNORE_VALUE: 255\n    NUM_CLASSES: 1\n    LOSS_WEIGHT: 1.0\n    CONVS_DIM: 256\n    MASK_DIM: 256\n    NORM: \"GN\"\n    IN_FEATURES: [ \"res2\", \"res3\", \"res4\", \"res5\" ]\n    
DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES: [ \"res3\", \"res4\", \"res5\" ]\n    COMMON_STRIDE: 4\n    TRANSFORMER_ENC_LAYERS: 6\n    TOTAL_NUM_FEATURE_LEVELS: 4\n    NUM_FEATURE_LEVELS: 3\n    FEATURE_ORDER: \"low2high\"\n  DECODER:\n    NAME: idino_decoder_no_iou_token_partwhole_all_llm\n    TRANSFORMER_IN_FEATURE: \"multi_scale_pixel_decoder\"\n    MASK: True\n    BOX: True\n    PART: True\n    pretrain: True\n    match_loss: True\n    GROUNDING:\n      ENABLED: True\n      MAX_LEN: 5\n      TEXT_WEIGHT: 2.0\n      CLASS_WEIGHT: 0.5\n    CAPTION:\n      ENABLED: True\n      PHRASE_PROB: 0.0\n      SIM_THRES: 0.95\n    CAPTIONING:\n      ENABLED: True\n      STEP: 50\n    RETRIEVAL:\n      ENABLED: True\n      DIM_IMG: 768\n      ENSEMBLE: True\n    OPENIMAGE:\n      ENABLED: False\n      NEGATIVE_SAMPLES: 5\n      GROUNDING:\n        ENABLED: False\n        MAX_LEN: 5\n    DEEP_SUPERVISION: True\n    NO_OBJECT_WEIGHT: 0.1\n    CLASS_WEIGHT: 4.0\n    MASK_WEIGHT: 5.0\n    DICE_WEIGHT: 5.0\n    BOX_WEIGHT: 5.0\n    GIOU_WEIGHT: 2.0\n    IOU_WEIGHT: 1.0\n    LLAMA_WEIGHT: 5.0\n    llama_det_weight: 2.0\n    llama_ref_weight: 1.0\n    llama_region_cap_weight: 1.0\n    llama_img_cap_weight: 1.0\n    llama_gd_weight: 20.0\n    llama_gd_text_weight: 2.0\n    REFER_WEIGHT: 5.0\n    COST_CLASS_WEIGHT: 4.0\n    COST_DICE_WEIGHT: 5.0\n    COST_MASK_WEIGHT: 5.0\n    COST_BOX_WEIGHT: 5.0\n    COST_GIOU_WEIGHT: 2.0\n    HIDDEN_DIM: 256\n    NUM_OBJECT_QUERIES: 0\n    NHEADS: 8\n    DROPOUT: 0.0\n    DIM_FEEDFORWARD: 2048\n    ENC_LAYERS: 0\n    PRE_NORM: False\n    ENFORCE_INPUT_PROJ: False\n    SIZE_DIVISIBILITY: 32\n    DEC_LAYERS: 9  # 9 decoder layers, add one for the loss on learnable query\n    TRAIN_NUM_POINTS: 12544\n    OVERSAMPLE_RATIO: 3.0\n    IMPORTANCE_SAMPLE_RATIO: 0.75\n    TWO_STAGE: False\n    INITIALIZE_BOX_TYPE: 'no'\n    DN: seg\n    DN_NOISE_SCALE: 0.4\n    DN_NUM: 100\n    INITIAL_PRED: False\n    LEARN_TGT: False\n    TOTAL_NUM_FEATURE_LEVELS: 4\n    
SEMANTIC_CE_LOSS: False\n    PANO_BOX_LOSS: False\n    COCO: True\n    O365: False\n    SAM: True\n    PASCAL: True\n    RE_POINT: True\n    NUM_INTERACTIVE_TOKENS: 3\n    TEST:\n      SEMANTIC_ON: True\n      INSTANCE_ON: True\n      PANOPTIC_ON: True\n      OVERLAP_THRESHOLD: 0.8\n      OBJECT_MASK_THRESHOLD: 0.25\n      SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE: false\n      TEST_FOUCUS_ON_BOX: False\n      PANO_TRANSFORM_EVAL: True\n      PANO_TEMPERATURE: 0.06\n\nTEST:\n  EVAL_PERIOD: 500000\n  PRECISE_BN:\n    NUM_ITER: 1\n    ENABLED: False\n  AUG:\n    ENABLED: False\n\nSAM:\n  INPUT:\n    MIN_SIZE_TEST: 800\n    MAX_SIZE_TEST: 1333\n    IMAGE_SIZE: 1024\n    MIN_SCALE: 0.99\n    MAX_SCALE: 1.01\n    DATASET_MAPPER_NAME: \"sam\"\n    IGNORE_VALUE: 255\n    COLOR_AUG_SSD: False\n    SIZE_DIVISIBILITY: 32\n    RANDOM_FLIP: \"horizontal\"\n    MASK_FORMAT: \"polygon\"\n    FORMAT: \"RGB\"\n    CROP:\n      ENABLED: True\n  DATASET:\n    DATASET: 'sam'\n  TEST:\n    DETECTIONS_PER_IMAGE: 100\n    NAME: coco_eval\n    IOU_TYPE: ['bbox', 'segm']\n    USE_MULTISCALE: false\n    BATCH_SIZE_TOTAL: 8\n    MODEL_FILE: ''\n    AUG:\n      ENABLED: False\n  TRAIN:\n    BATCH_SIZE_TOTAL: 1\n    BATCH_SIZE_PER_GPU: 1\n    SHUFFLE: true\n  DATALOADER:\n    FILTER_EMPTY_ANNOTATIONS: False\n    NUM_WORKERS: 4\n    LOAD_PROPOSALS: False\n    SAMPLER_TRAIN: \"TrainingSampler\"\n    ASPECT_RATIO_GROUPING: True\n\nCOCO:\n  INPUT:\n    MIN_SIZE_TEST: 800\n    MAX_SIZE_TEST: 1333\n    IMAGE_SIZE: 1024\n    MIN_SCALE: 1.0\n    MAX_SCALE: 1.0\n    DATASET_MAPPER_NAME: \"coco_interactive_panoptic_lsj\"\n    IGNORE_VALUE: 255\n    COLOR_AUG_SSD: False\n    SIZE_DIVISIBILITY: 32\n    RANDOM_FLIP: \"horizontal\"\n    MASK_FORMAT: \"polygon\"\n    FORMAT: \"RGB\"\n    CROP:\n      ENABLED: True\n  DATASET:\n    DATASET: 'coco'\n  TEST:\n    DETECTIONS_PER_IMAGE: 100\n    NAME: coco_eval\n    IOU_TYPE: ['bbox', 'segm']\n    USE_MULTISCALE: false\n    BATCH_SIZE_TOTAL: 1\n    MODEL_FILE: 
''\n    AUG:\n      ENABLED: False\n  TRAIN:\n    BATCH_SIZE_TOTAL: 1\n    BATCH_SIZE_PER_GPU: 1\n    SHUFFLE: true\n  DATALOADER:\n    FILTER_EMPTY_ANNOTATIONS: False\n    NUM_WORKERS: 2\n    LOAD_PROPOSALS: False\n    SAMPLER_TRAIN: \"TrainingSampler\"\n    ASPECT_RATIO_GROUPING: True\n\nVLP:\n  INPUT:\n    IMAGE_SIZE: 224\n    DATASET_MAPPER_NAME: \"vlpretrain\"\n    IGNORE_VALUE: 255\n    COLOR_AUG_SSD: False\n    SIZE_DIVISIBILITY: 32\n    MASK_FORMAT: \"polygon\"\n    FORMAT: \"RGB\"\n    CROP:\n      ENABLED: True\n  TRAIN:\n    BATCH_SIZE_TOTAL: 2\n    BATCH_SIZE_PER_GPU: 2\n  TEST:\n    BATCH_SIZE_TOTAL: 256\n  DATALOADER:\n    FILTER_EMPTY_ANNOTATIONS: False\n    NUM_WORKERS: 16\n    LOAD_PROPOSALS: False\n    SAMPLER_TRAIN: \"TrainingSampler\"\n    ASPECT_RATIO_GROUPING: True\n\nINPUT:\n  PIXEL_MEAN: [123.675, 116.280, 103.530]\n  PIXEL_STD: [58.395, 57.120, 57.375]\n\nDATASETS:\n  TRAIN: [\"coco_2017_train_panoptic_filtrefgumdval_with_sem_seg_caption_grounding\",\"mapillary_vistas_panoptic_train\",\"ade20k_panoptic_train\",\"sam_train\",\"pascal_part_train\",\"paco_train\",\"partimagenet_train\"] #,\"sam_train\",\"pascal_part_train\"]#,\"paco_train\",\"partimagenet_train\"]\n\nDATALOADER:\n  FILTER_EMPTY_ANNOTATIONS: False\n  NUM_WORKERS: 16\n  LOAD_PROPOSALS: False\n  SAMPLER_TRAIN: \"TrainingSampler\"\n  ASPECT_RATIO_GROUPING: True\n\n# Detectron2 training config for optimizer and lr scheduler\nSOLVER:\n  BASE_LR_END: 0.0\n  MOMENTUM: 0.9\n  NESTEROV: False\n  CHECKPOINT_PERIOD: 5000\n  IMS_PER_BATCH: 1\n  REFERENCE_WORLD_SIZE: 0\n  BIAS_LR_FACTOR: 1.0\n  WEIGHT_DECAY_BIAS: None\n  # original\n  BASE_LR: 0.0001\n  STEPS: [327778, 355092]\n  MAX_ITER: 368750\n  GAMMA: 0.1\n  WARMUP_FACTOR: 1.0\n  WARMUP_ITERS: 10\n  WARMUP_METHOD: \"linear\"\n  WEIGHT_DECAY: 0.05\n  OPTIMIZER: \"ADAMW\"\n  LR_SCHEDULER_NAME: \"WarmupMultiStepLR\"\n  LR_MULTIPLIER:\n    backbone: 0.1\n    lang_encoder: 0.1\n  WEIGHT_DECAY_NORM: 0.0\n  WEIGHT_DECAY_EMBED: 0.0\n  
CLIP_GRADIENTS:\n    ENABLED: True\n    CLIP_TYPE: \"full_model\"\n    CLIP_VALUE: 0.01\n    NORM_TYPE: 2.0\n  AMP:\n    ENABLED: True\n\n# Evaluation Dataset\nADE20K:\n  INPUT:\n    MIN_SIZE_TEST: 800\n    MAX_SIZE_TEST: 1333\n    IMAGE_SIZE: 1024\n    MIN_SCALE: 1.0\n    MAX_SCALE: 1.0\n    DATASET_MAPPER_NAME: \"coco_interactive_panoptic_lsj\"\n    IGNORE_VALUE: 255\n    COLOR_AUG_SSD: False\n    SIZE_DIVISIBILITY: 32\n    RANDOM_FLIP: \"horizontal\"\n    MASK_FORMAT: \"polygon\"\n    FORMAT: \"RGB\"\n    CROP:\n      ENABLED: True\n\n  DATASET:\n    DATASET: 'ade'\n  TRAIN:\n    ASPECT_RATIO_GROUPING: true\n    BATCH_SIZE_TOTAL: 1\n    BATCH_SIZE_PER_GPU: 1\n    SHUFFLE: true\n  TEST:\n    DETECTIONS_PER_IMAGE: 100\n    NAME: coco_eval\n    IOU_TYPE: ['bbox', 'segm']\n    USE_MULTISCALE: false\n    BATCH_SIZE_TOTAL: 8\n    MODEL_FILE: ''\n    AUG:\n      ENABLED: False\n  DATALOADER:\n    FILTER_EMPTY_ANNOTATIONS: False\n    NUM_WORKERS: 8\n    LOAD_PROPOSALS: False\n    SAMPLER_TRAIN: \"TrainingSampler\"\n    ASPECT_RATIO_GROUPING: True\n\nREF:\n  INPUT:\n    PIXEL_MEAN: [123.675, 116.280, 103.530]\n    PIXEL_STD: [58.395, 57.120, 57.375]\n    MIN_SIZE_TEST: 512\n    MAX_SIZE_TEST: 1024\n    FORMAT: \"RGB\"\n  DATALOADER:\n    FILTER_EMPTY_ANNOTATIONS: False\n    NUM_WORKERS: 0\n    LOAD_PROPOSALS: False\n    SAMPLER_TRAIN: \"TrainingSampler\"\n    ASPECT_RATIO_GROUPING: False\n  TEST:\n    BATCH_SIZE_TOTAL: 8\n\nSUN:\n  INPUT:\n    PIXEL_MEAN: [123.675, 116.280, 103.530]\n    PIXEL_STD: [58.395, 57.120, 57.375]\n    MIN_SIZE_TEST: 512\n    MAX_SIZE_TEST: 1024\n  DATALOADER:\n    FILTER_EMPTY_ANNOTATIONS: False\n    NUM_WORKERS: 0\n    LOAD_PROPOSALS: False\n    SAMPLER_TRAIN: \"TrainingSampler\"\n    ASPECT_RATIO_GROUPING: False\n  TEST:\n    BATCH_SIZE_TOTAL: 8\n\nSCAN:\n  INPUT:\n    PIXEL_MEAN: 
[123.675, 116.280, 103.530]\n    PIXEL_STD: [58.395, 57.120, 57.375]\n    MIN_SIZE_TEST: 512\n    MAX_SIZE_TEST: 1024\n  DATALOADER:\n    FILTER_EMPTY_ANNOTATIONS: False\n    NUM_WORKERS: 0\n    LOAD_PROPOSALS: False\n    SAMPLER_TRAIN: \"TrainingSampler\"\n    ASPECT_RATIO_GROUPING: False\n  TEST:\n    BATCH_SIZE_TOTAL: 8\n\nBDD:\n  INPUT:\n    PIXEL_MEAN: [123.675, 116.280, 103.530]\n    PIXEL_STD: [58.395, 57.120, 57.375]\n    MIN_SIZE_TEST: 800\n    MAX_SIZE_TEST: 1333\n  DATALOADER:\n    FILTER_EMPTY_ANNOTATIONS: False\n    NUM_WORKERS: 0\n    LOAD_PROPOSALS: False\n    SAMPLER_TRAIN: \"TrainingSampler\"\n    ASPECT_RATIO_GROUPING: False\n  TEST:\n    BATCH_SIZE_TOTAL: 8\n\nCITY:\n  INPUT:\n    MIN_SIZE_TRAIN: [ 512, 614, 716, 819, 921, 1024, 1126, 1228, 1331, 1433, 1536, 1638, 1740, 1843, 1945, 2048 ]\n    MIN_SIZE_TRAIN_SAMPLING: \"choice\"\n    MIN_SIZE_TEST: 1024\n    MAX_SIZE_TRAIN: 4096\n    MAX_SIZE_TEST: 2048\n    CROP:\n      ENABLED: True\n      TYPE: \"absolute\"\n      SIZE: [ 512, 1024 ]\n      SINGLE_CATEGORY_MAX_AREA: 1.0\n    IGNORE_VALUE: 255\n    COLOR_AUG_SSD: True\n    SIZE_DIVISIBILITY: -1\n    FORMAT: \"RGB\"\n    DATASET_MAPPER_NAME: \"mask_former_panoptic\"\n    MASK_FORMAT: \"polygon\"\n    TEST:\n      EVAL_PERIOD: 5000\n      BATCH_SIZE_TOTAL: 1\n      AUG:\n        ENABLED: False\n        MIN_SIZES: [ 512, 768, 1024, 1280, 1536, 1792 ]\n        MAX_SIZE: 4096\n        FLIP: True\n    DATALOADER:\n      FILTER_EMPTY_ANNOTATIONS: True\n      NUM_WORKERS: 2\n      LOAD_PROPOSALS: False\n      SAMPLER_TRAIN: \"TrainingSampler\"\n      ASPECT_RATIO_GROUPING: True\n    TRAIN:\n      ASPECT_RATIO_GROUPING: true\n      BATCH_SIZE_TOTAL: 2\n      BATCH_SIZE_PER_GPU: 2\n      SHUFFLE: true\n\nPSACAL_PART:\n  INPUT:\n      MIN_SIZE_TEST: 800\n      MAX_SIZE_TEST: 1333\n      IMAGE_SIZE: 1024\n      MIN_SCALE: 1.0\n      MAX_SCALE: 1.0\n      DATASET_MAPPER_NAME: \"pascal_part_lsj\"\n      IGNORE_VALUE: 255\n      COLOR_AUG_SSD: False\n      
SIZE_DIVISIBILITY: 32\n      RANDOM_FLIP: \"horizontal\"\n      MASK_FORMAT: \"polygon\"\n      FORMAT: \"RGB\"\n      CROP:\n        ENABLED: True\n  MODEL:\n    MASK_ON: True\n    KEYPOINT_ON: False\n    LOAD_PROPOSALS: False\n  # DATASET:\n  #   DATASET: 'coco'\n  TEST:\n    DETECTIONS_PER_IMAGE: 100\n    NAME: coco_eval\n    IOU_TYPE: ['bbox', 'segm']\n    USE_MULTISCALE: false\n    BATCH_SIZE_TOTAL: 8\n    MODEL_FILE: ''\n    AUG:\n      ENABLED: False\n  TRAIN:\n    BATCH_SIZE_TOTAL: 1\n    BATCH_SIZE_PER_GPU: 1\n    SHUFFLE: true\n  DATALOADER:\n    FILTER_EMPTY_ANNOTATIONS: False\n    NUM_WORKERS: 2\n    LOAD_PROPOSALS: False\n    SAMPLER_TRAIN: \"TrainingSampler\"\n    ASPECT_RATIO_GROUPING: True\n\nllava:\n  INPUT:\n    MIN_SIZE_TEST: 800\n    MAX_SIZE_TEST: 1333\n    IMAGE_SIZE: 1024\n    MIN_SCALE: 1.0\n    MAX_SCALE: 1.0\n    DATASET_MAPPER_NAME: \"llava\"\n    IGNORE_VALUE: 255\n    COLOR_AUG_SSD: False\n    SIZE_DIVISIBILITY: 32\n    RANDOM_FLIP: \"horizontal\"\n    MASK_FORMAT: \"polygon\"\n    FORMAT: \"RGB\"\n    CROP:\n      ENABLED: True\n  TEST:\n    DETECTIONS_PER_IMAGE: 100\n    NAME: coco_eval\n    IOU_TYPE: ['bbox', 'segm']\n    USE_MULTISCALE: false\n    BATCH_SIZE_TOTAL: 1\n    MODEL_FILE: ''\n    AUG:\n      ENABLED: False\n  TRAIN:\n    BATCH_SIZE_TOTAL: 1\n    BATCH_SIZE_PER_GPU: 1\n    SHUFFLE: true\n  DATALOADER:\n    FILTER_EMPTY_ANNOTATIONS: False\n    NUM_WORKERS: 2\n    LOAD_PROPOSALS: False\n    SAMPLER_TRAIN: \"TrainingSampler\"\n    ASPECT_RATIO_GROUPING: True\n\nflickr:\n  INPUT:\n    MIN_SIZE_TEST: 800\n    MAX_SIZE_TEST: 1333\n    IMAGE_SIZE: 1024\n    MIN_SCALE: 1.0\n    MAX_SCALE: 1.0\n    DATASET_MAPPER_NAME: \"flickr\"\n    IGNORE_VALUE: 255\n    COLOR_AUG_SSD: False\n    SIZE_DIVISIBILITY: 32\n    RANDOM_FLIP: \"horizontal\"\n    MASK_FORMAT: \"polygon\"\n    FORMAT: \"RGB\"\n    CROP:\n      ENABLED: True\n  TEST:\n    DETECTIONS_PER_IMAGE: 100\n    NAME: coco_eval\n    IOU_TYPE: ['bbox', 'segm']\n    USE_MULTISCALE: 
false\n    BATCH_SIZE_TOTAL: 1\n    MODEL_FILE: ''\n    AUG:\n      ENABLED: False\n  TRAIN:\n    BATCH_SIZE_TOTAL: 1\n    BATCH_SIZE_PER_GPU: 1\n    SHUFFLE: true\n  DATALOADER:\n    FILTER_EMPTY_ANNOTATIONS: False\n    NUM_WORKERS: 2\n    LOAD_PROPOSALS: False\n    SAMPLER_TRAIN: \"TrainingSampler\"\n    ASPECT_RATIO_GROUPING: True\n\npart:\n  INPUT:\n    MIN_SIZE_TEST: 800\n    MAX_SIZE_TEST: 1333\n    IMAGE_SIZE: 1024\n    MIN_SCALE: 1.0\n    MAX_SCALE: 1.0\n    DATASET_MAPPER_NAME: \"part\"\n    IGNORE_VALUE: 255\n    COLOR_AUG_SSD: False\n    SIZE_DIVISIBILITY: 32\n    RANDOM_FLIP: \"horizontal\"\n    MASK_FORMAT: \"polygon\"\n    FORMAT: \"RGB\"\n    CROP:\n      ENABLED: True\n  TEST:\n    DETECTIONS_PER_IMAGE: 100\n    NAME: coco_eval\n    IOU_TYPE: ['bbox', 'segm']\n    USE_MULTISCALE: false\n    BATCH_SIZE_TOTAL: 1\n    MODEL_FILE: ''\n    AUG:\n      ENABLED: False\n  TRAIN:\n    BATCH_SIZE_TOTAL: 1\n    BATCH_SIZE_PER_GPU: 1\n    SHUFFLE: true\n  DATALOADER:\n    FILTER_EMPTY_ANNOTATIONS: False\n    NUM_WORKERS: 2\n    LOAD_PROPOSALS: False\n    SAMPLER_TRAIN: \"TrainingSampler\"\n    ASPECT_RATIO_GROUPING: True\n\nvg:\n  INPUT:\n    MIN_SIZE_TEST: 800\n    MAX_SIZE_TEST: 1333\n    IMAGE_SIZE: 1024\n    MIN_SCALE: 1.0\n    MAX_SCALE: 1.0\n    DATASET_MAPPER_NAME: \"vg\"\n    IGNORE_VALUE: 255\n    COLOR_AUG_SSD: False\n    SIZE_DIVISIBILITY: 32\n    RANDOM_FLIP: \"horizontal\"\n    MASK_FORMAT: \"polygon\"\n    FORMAT: \"RGB\"\n    CROP:\n      ENABLED: True\n  TEST:\n    DETECTIONS_PER_IMAGE: 100\n    NAME: coco_eval\n    IOU_TYPE: ['bbox', 'segm']\n    USE_MULTISCALE: false\n    BATCH_SIZE_TOTAL: 1\n    MODEL_FILE: ''\n    AUG:\n      ENABLED: False\n  TRAIN:\n    BATCH_SIZE_TOTAL: 1\n    BATCH_SIZE_PER_GPU: 1\n    SHUFFLE: true\n  DATALOADER:\n    FILTER_EMPTY_ANNOTATIONS: False\n    NUM_WORKERS: 2\n    LOAD_PROPOSALS: False\n    SAMPLER_TRAIN: \"TrainingSampler\"\n    ASPECT_RATIO_GROUPING: True"
  },
  {
    "path": "datasets_os/__init__.py",
    "content": "from . import registration\nfrom .build import *"
  },
  {
    "path": "datasets_os/build.py",
    "content": "# Copyright (c) Facebook, Inc. and its affiliates.\nimport os\nimport itertools\nimport logging\nimport copy\nfrom typing import Any, Callable, Dict, List, Optional, Union\n\nimport torch\nimport torch.utils.data\nimport torch.utils.data as torchdata\n\nimport detectron2.utils.comm as comm\nfrom detectron2.data.build import (\n    build_batch_data_loader,\n    load_proposals_into_dataset,\n    trivial_batch_collator,\n)\nfrom detectron2.data import MetadataCatalog\nfrom detectron2.data.catalog import DatasetCatalog\nfrom detectron2.data.common import DatasetFromList, MapDataset\nfrom detectron2.data.dataset_mapper import DatasetMapper\nfrom detectron2.data.samplers import InferenceSampler, TrainingSampler\n\nfrom fvcore.common.config import CfgNode\nfrom omegaconf import DictConfig, OmegaConf\n\nfrom .dataset_mappers import (\n    COCOPanopticInteractiveDatasetMapper,\n    FlickrNewBaselineDatasetMapper,\n    VGNewBaselineDatasetMapper,\n    COCOInstructGroundingDatasetMapper,\n    COCOInterGroundingDatasetMapper,\n)\n\nfrom .custom_dataset_dataloader import build_custom_test_loader\nfrom llava.model.openseed.utils import configurable\nfrom detectron2.utils.comm import get_world_size, is_main_process\nfrom typing import Any, Dict, List, Set\n\nclass JointLoader(torchdata.IterableDataset):\n    def __init__(self, loaders, key_dataset):\n        dataset_names = []\n        for key, loader in loaders.items():\n            name = \"{}\".format(key.split('_')[0])\n            # name = \"{}\".format(key)\n            setattr(self, name, loader)\n            dataset_names += [name]\n        self.dataset_names = dataset_names\n        self.key_dataset = key_dataset\n    \n    def __iter__(self):\n        for batch in zip(*[getattr(self, name) for name in self.dataset_names]):\n            yield {key: batch[i] for i, key in enumerate(self.dataset_names)}\n\n    def __len__(self):\n        return len(getattr(self, self.key_dataset))\n\ndef 
filter_images_with_only_crowd_annotations(dataset_dicts, dataset_names):\n    \"\"\"\n    Filter out images with none annotations or only crowd annotations\n    (i.e., images without non-crowd annotations).\n    A common training-time preprocessing on COCO dataset.\n\n    Args:\n        dataset_dicts (list[dict]): annotations in Detectron2 Dataset format.\n\n    Returns:\n        list[dict]: the same format, but filtered.\n    \"\"\"\n    num_before = len(dataset_dicts)\n\n    def valid(anns):\n        for ann in anns:\n            if isinstance(ann, list):\n                for instance in ann:\n                    if instance.get(\"iscrowd\", 0) == 0:\n                        return True\n            else:\n                if ann.get(\"iscrowd\", 0) == 0:\n                    return True\n        return False\n\n    dataset_dicts = [x for x in dataset_dicts if valid(x[\"annotations\"])]\n    num_after = len(dataset_dicts)\n    logger = logging.getLogger(__name__)\n    logger.info(\n        \"Removed {} images with no usable annotations. 
{} images left.\".format(\n            num_before - num_after, num_after\n        )\n    )\n    return dataset_dicts\n\n\ndef get_detection_dataset_dicts(\n    dataset_names, filter_empty=True, proposal_files=None\n):\n    \"\"\"\n    Load and prepare dataset dicts for instance detection/segmentation and semantic segmentation.\n\n    Args:\n        dataset_names (str or list[str]): a dataset name or a list of dataset names\n        filter_empty (bool): whether to filter out images without instance annotations\n        proposal_files (list[str]): if given, a list of object proposal files\n            that match each dataset in `dataset_names`.\n\n    Returns:\n        list[dict]: a list of dicts following the standard dataset dict format.\n    \"\"\"\n    if isinstance(dataset_names, str):\n        dataset_names = [dataset_names]\n    assert len(dataset_names)\n    \n    dataset_dicts = [DatasetCatalog.get(dataset_name) for dataset_name in dataset_names]\n    for dataset_name, dicts in zip(dataset_names, dataset_dicts):\n        assert len(dicts), \"Dataset '{}' is empty!\".format(dataset_name)\n\n    if proposal_files is not None:\n        assert len(dataset_names) == len(proposal_files)\n        # load precomputed proposals from proposal files\n        dataset_dicts = [\n            load_proposals_into_dataset(dataset_i_dicts, proposal_file)\n            for dataset_i_dicts, proposal_file in zip(dataset_dicts, proposal_files)\n        ]\n\n    dataset_dicts = list(itertools.chain.from_iterable(dataset_dicts))\n\n    has_instances = \"annotations\" in dataset_dicts[0]\n    if filter_empty and has_instances:\n        dataset_dicts = filter_images_with_only_crowd_annotations(dataset_dicts, dataset_names)\n\n    assert len(dataset_dicts), \"No valid data found in {}.\".format(\",\".join(dataset_names))\n    return dataset_dicts\n\n\ndef _test_loader_from_config(cfg, dataset_name, mapper=None):\n    \"\"\"\n    Uses the given `dataset_name` argument (instead of the 
names in cfg), because the\n    standard practice is to evaluate each test set individually (not combining them).\n    \"\"\"\n    if isinstance(dataset_name, str):\n        dataset_name = [dataset_name]\n\n    dataset = get_detection_dataset_dicts(\n        dataset_name,\n        filter_empty=False,\n        proposal_files=None,\n    )\n    # import ipdb;ipdb.set_trace()\n    if mapper is None:\n        if isinstance(cfg, (DictConfig)):\n            cfg = OmegaConf.to_container(copy.deepcopy(cfg))\n        mapper_cfg = CfgNode({'INPUT': cfg['INPUT'], 'MODEL': cfg['MODEL'], 'DATASETS': cfg['DATASETS']})\n        mapper = DatasetMapper(mapper_cfg, False)\n    assert cfg['TEST']['BATCH_SIZE_TOTAL'] % get_world_size() == 0, \"Evaluation total batchsize is not divisible by gpu number\"\n    batch_size = cfg['TEST']['BATCH_SIZE_TOTAL'] // get_world_size()\n\n    return {\n        \"dataset\": dataset,\n        \"mapper\": mapper,\n        \"num_workers\": cfg['DATALOADER']['NUM_WORKERS'],\n        \"sampler\": InferenceSampler(len(dataset)),\n        \"batch_size\": batch_size,\n    }\n\n\n@configurable(from_config=_test_loader_from_config)\ndef build_detection_test_loader(\n    dataset: Union[List[Any], torchdata.Dataset],\n    *,\n    mapper: Callable[[Dict[str, Any]], Any],\n    sampler: Optional[torchdata.Sampler] = None,\n    batch_size: int = 1,\n    num_workers: int = 0,\n    collate_fn: Optional[Callable[[List[Any]], Any]] = None,\n) -> torchdata.DataLoader:\n    \"\"\"\n    Similar to `build_detection_train_loader`, with default batch size = 1,\n    and sampler = :class:`InferenceSampler`. This sampler coordinates all workers\n    to produce the exact set of all samples.\n\n    Args:\n        dataset: a list of dataset dicts,\n            or a pytorch dataset (either map-style or iterable). 
They can be obtained\n            by using :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`.\n        mapper: a callable which takes a sample (dict) from dataset\n           and returns the format to be consumed by the model.\n           When using cfg, the default choice is ``DatasetMapper(cfg, is_train=False)``.\n        sampler: a sampler that produces\n            indices to be applied on ``dataset``. Default to :class:`InferenceSampler`,\n            which splits the dataset across all workers. Sampler must be None\n            if `dataset` is iterable.\n        batch_size: the batch size of the data loader to be created.\n            Default to 1 image per worker since this is the standard when reporting\n            inference time in papers.\n        num_workers: number of parallel data loading workers\n        collate_fn: same as the argument of `torch.utils.data.DataLoader`.\n            Defaults to do no collation and return a list of data.\n\n    Returns:\n        DataLoader: a torch DataLoader, that loads the given detection\n        dataset, with test-time transformation and batching.\n\n    Examples:\n    ::\n        data_loader = build_detection_test_loader(\n            DatasetRegistry.get(\"my_test\"),\n            mapper=DatasetMapper(...))\n\n        # or, instantiate with a CfgNode:\n        data_loader = build_detection_test_loader(cfg, \"my_test\")\n    \"\"\"\n\n    if isinstance(dataset, list):\n        dataset = DatasetFromList(dataset, copy=False)\n    if mapper is not None:\n        dataset = MapDataset(dataset, mapper)\n    if isinstance(dataset, torchdata.IterableDataset):\n        assert sampler is None, \"sampler must be None if dataset is IterableDataset\"\n    else:\n        if sampler is None:\n            sampler = InferenceSampler(len(dataset))\n    return torchdata.DataLoader(\n        dataset,\n        batch_size=batch_size,\n        sampler=sampler,\n        drop_last=False,\n        num_workers=num_workers,\n 
       collate_fn=trivial_batch_collator if collate_fn is None else collate_fn,\n    )\n\n\ndef _train_loader_from_config(cfg, dataset_name, mapper, *, dataset=None, sampler=None):\n    cfg_datasets = cfg['DATASETS']\n    cfg_dataloader = cfg['DATALOADER']\n    \n    if dataset is None:\n        dataset = get_detection_dataset_dicts(\n            dataset_name,\n            filter_empty=cfg_dataloader['FILTER_EMPTY_ANNOTATIONS'],\n            proposal_files=cfg_datasets['PROPOSAL_FILES_TRAIN'] if cfg_dataloader['LOAD_PROPOSALS'] else None,\n        )\n\n    if mapper is None:\n        mapper = DatasetMapper(cfg, True)\n\n    if sampler is None:\n        sampler_name = cfg_dataloader['SAMPLER_TRAIN']\n        logger = logging.getLogger(__name__)\n        logger.info(\"Using training sampler {}\".format(sampler_name))\n        sampler = TrainingSampler(len(dataset))\n\n    return {\n        \"dataset\": dataset,\n        \"sampler\": sampler,\n        \"mapper\": mapper,\n        \"total_batch_size\": cfg['TRAIN']['BATCH_SIZE_TOTAL'],\n        \"aspect_ratio_grouping\": cfg_dataloader['ASPECT_RATIO_GROUPING'],\n        \"num_workers\": cfg_dataloader['NUM_WORKERS'],\n    }\n\n\n@configurable(from_config=_train_loader_from_config)\ndef build_detection_train_loader(\n    dataset, *, mapper, sampler=None, total_batch_size, aspect_ratio_grouping=True, num_workers=0\n):\n    \"\"\"\n    Build a dataloader for object detection with some default features.\n    This interface is experimental.\n\n    Args:\n        dataset (list or torch.utils.data.Dataset): a list of dataset dicts,\n            or a map-style pytorch dataset. 
They can be obtained by using\n            :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`.\n        mapper (callable): a callable which takes a sample (dict) from dataset and\n            returns the format to be consumed by the model.\n            When using cfg, the default choice is ``DatasetMapper(cfg, is_train=True)``.\n        sampler (torch.utils.data.sampler.Sampler or None): a sampler that\n            produces indices to be applied on ``dataset``.\n            Default to :class:`TrainingSampler`, which coordinates a random shuffle\n            sequence across all workers.\n        total_batch_size (int): total batch size across all workers. Batching\n            simply puts data into a list.\n        aspect_ratio_grouping (bool): whether to group images with similar\n            aspect ratio for efficiency. When enabled, it requires each\n            element in dataset be a dict with keys \"width\" and \"height\".\n        num_workers (int): number of parallel data loading workers\n\n    Returns:\n        torch.utils.data.DataLoader: a dataloader. 
Each output from it is a\n            ``list[mapped_element]`` of length ``total_batch_size / num_workers``,\n            where ``mapped_element`` is produced by the ``mapper``.\n    \"\"\"\n    if isinstance(dataset, list):\n        dataset = DatasetFromList(dataset, copy=False)\n    if mapper is not None:\n        dataset = MapDataset(dataset, mapper)\n    if sampler is None:\n        sampler = TrainingSampler(len(dataset))\n    assert isinstance(sampler, torch.utils.data.sampler.Sampler)\n    \n    return build_batch_data_loader(\n        dataset,\n        sampler,\n        total_batch_size,\n        aspect_ratio_grouping=aspect_ratio_grouping,\n        num_workers=num_workers,\n    )\n\n\ndef get_config_from_name(cfg, dataset_name):\n    # adjust config according to dataset\n    if 'sam' in dataset_name:\n        cfg.update(cfg['SAM'])\n        return cfg\n    elif 'flickr' in dataset_name:\n        cfg.update(cfg['flickr'])\n        return cfg\n    elif 'coco_instruct' in dataset_name:\n        cfg.update(cfg['coco_instruct'])\n        return cfg\n    elif 'coco_interactive' in dataset_name:\n        cfg.update(cfg['coco_interactive'])\n        return cfg\n    elif 'lisa' in dataset_name:\n        cfg.update(cfg['LISA_REF'])\n        return cfg\n    elif 'llava' in dataset_name:\n        cfg.update(cfg['llava'])\n        return cfg\n    elif 'vg' in dataset_name:\n        cfg.update(cfg['vg'])\n        return cfg\n    elif 'part' in dataset_name and 'pascal_part' not in dataset_name and 'partimagenet' not in dataset_name:\n        cfg.update(cfg['part'])\n        return cfg\n    elif 'pascal' in dataset_name or 'paco' in dataset_name or 'partimagenet' in dataset_name :\n        cfg.update(cfg['PSACAL_PART'])\n        return cfg\n    elif 'coco' in dataset_name and 'refonly' in dataset_name:\n        # if 'COCO' in cfg.keys():\n        cfg.update(cfg['COCO_REF'])\n        return cfg\n    elif 'coco' in dataset_name:\n        if 'COCO' in cfg.keys():\n           
 cfg.update(cfg['COCO'])\n        return cfg\n    elif \"mapillary\" in dataset_name:\n        if 'MAPILLARY' in cfg.keys():\n            cfg.update(cfg['MAPILLARY'])\n        return cfg\n    elif 'ade' in dataset_name:\n        if 'ADE20K' in cfg.keys():\n            cfg.update(cfg['ADE20K'])\n        return cfg\n    elif 'imagenet' in dataset_name:\n        if 'IMAGENET' in cfg.keys():\n            cfg.update(cfg['IMAGENET'])\n        return cfg\n    elif 'vlp' in dataset_name:\n        cfg.update(cfg['VLP'])\n        return cfg\n    elif 'sun' in dataset_name:\n        cfg.update(cfg['SUN'])\n        return cfg\n    elif 'object365' in dataset_name:\n        cfg.update(cfg['OBJECT365'])\n        return cfg\n    elif 'scan' in dataset_name:\n        cfg.update(cfg['SCAN'])\n        return cfg\n    elif 'cityscape' in dataset_name:\n        cfg.update(cfg['CITY'])\n        return cfg\n    elif 'bdd' in dataset_name:\n        cfg.update(cfg['BDD'])\n        return cfg\n    else:\n        assert False, \"dataset not support.\"\n\n\n\ndef build_train_dataloader(cfg,tokenizer=None,data_args=None,preprocess=None,llava_cap_loader=None ):\n    dataset_names = cfg['DATASETS']['TRAIN']\n    \n    loaders = {}\n    cfg = copy.deepcopy(cfg)\n    for dataset_name in dataset_names:\n        cfg = get_config_from_name(cfg, dataset_name)\n        mapper_name = cfg['INPUT']['DATASET_MAPPER_NAME']\n\n        if mapper_name ==\"flickr\":\n            mapper=FlickrNewBaselineDatasetMapper(cfg,True,tokenizer=tokenizer,data_args=data_args,preprocess=preprocess)\n            loaders['flickr'] = build_detection_train_loader(cfg, dataset_name=dataset_name, mapper=mapper)\n        elif mapper_name ==\"coco_instruct\":\n            mapper=COCOInstructGroundingDatasetMapper(cfg,True,tokenizer=tokenizer,data_args=data_args,preprocess=preprocess)\n            loaders['coco_instruct'] = build_detection_train_loader(cfg, dataset_name=dataset_name, mapper=mapper)\n        elif mapper_name 
==\"coco_interactive\":\n            if \"refcoco\" in dataset_name:\n                refcoco=True\n            else:\n                refcoco=False\n            mapper=COCOInterGroundingDatasetMapper(cfg,True,tokenizer=tokenizer,data_args=data_args,preprocess=preprocess,refcoco=refcoco)\n            if refcoco:\n                loaders['interactiveref'] = build_detection_train_loader(cfg, dataset_name=dataset_name, mapper=mapper)\n            else:\n                loaders['interactive'] = build_detection_train_loader(cfg, dataset_name=dataset_name, mapper=mapper)\n        elif mapper_name ==\"vg\":\n            mapper=VGNewBaselineDatasetMapper(cfg,True,tokenizer=tokenizer,data_args=data_args,preprocess=preprocess)\n            loaders['vg'] = build_detection_train_loader(cfg, dataset_name=dataset_name, mapper=mapper)\n        elif mapper_name == \"coco_ref_panoptic_lsj\":\n            mapper = COCOPanopticInteractiveDatasetMapper(cfg, cfg.get('Train',True),tokenizer=tokenizer,data_args=data_args,preprocess=preprocess)\n            loaders['refcoco'] = build_detection_train_loader(cfg, dataset_name=dataset_name, mapper=mapper)\n        else:\n            mapper = None\n            loaders[dataset_name] = build_detection_train_loader(cfg, dataset_name=dataset_name, mapper=mapper)\n\n    if llava_cap_loader is not None:\n        loaders['llava_cap'] = llava_cap_loader\n    if len(loaders) == 1 and not cfg['LOADER'].get('JOINT', False):\n        for k, v in loaders.items():\n            print(\"number of iterations per epoch: \", v, len(loaders[k]))\n        return list(loaders.values())[0]\n    else:\n        return JointLoader(loaders, key_dataset=cfg['LOADER'].get('KEY_DATASET', 'coco'))"
  },
  {
    "path": "datasets_os/custom_dataset_dataloader.py",
    "content": "# Copyright (c) Facebook, Inc. and its affiliates.\n# Part of the code is from https://github.com/xingyizhou/UniDet/blob/master/projects/UniDet/unidet/data/multi_dataset_dataloader.py (Apache-2.0 License)\nimport copy\nimport logging\nimport numpy as np\nimport operator\nfrom typing import Any, Callable, Dict, List, Optional, Union\nimport torch\nimport torch.utils.data as torchdata\nimport json\nfrom detectron2.utils.comm import get_world_size\nfrom detectron2.utils.logger import _log_api_usage, log_first_n\n\nfrom detectron2.config import configurable\nfrom detectron2.data import samplers\nfrom torch.utils.data.sampler import BatchSampler, Sampler\nfrom detectron2.data.common import DatasetFromList, MapDataset\nfrom detectron2.data.dataset_mapper import DatasetMapper\nfrom detectron2.data.build import get_detection_dataset_dicts, build_batch_data_loader\nfrom detectron2.data.samplers import TrainingSampler, RepeatFactorTrainingSampler, InferenceSampler\nfrom detectron2.data.build import worker_init_reset_seed, print_instances_class_histogram\nfrom detectron2.data.build import filter_images_with_only_crowd_annotations\nfrom detectron2.data.build import filter_images_with_few_keypoints\nfrom detectron2.data.build import check_metadata_consistency\nfrom detectron2.data.catalog import MetadataCatalog, DatasetCatalog\nfrom detectron2.utils import comm\nimport itertools\nimport math\nfrom collections import defaultdict\nfrom typing import Optional\n\nlogger = logging.getLogger('detectron2.vlpart.data.custom_dataset_dataloader')\n\n\ndef _custom_test_loader_from_config(cfg, dataset_name, mapper=None):\n    if isinstance(dataset_name, str):\n        dataset_name = [dataset_name]\n\n    dataset = get_detection_dataset_dicts(\n        dataset_name,\n        filter_empty=False,\n        proposal_files=[\n            cfg.DATASETS.PROPOSAL_FILES_TEST[list(cfg.DATASETS.TEST).index(x)] for x in dataset_name\n        ]\n        if cfg.MODEL.LOAD_PROPOSALS_TEST\n  
      else None,\n    )\n    if mapper is None:\n        mapper = DatasetMapper(cfg, False)\n    return {\n        \"dataset\": dataset,\n        \"mapper\": mapper,\n        \"num_workers\": cfg.DATALOADER.NUM_WORKERS,\n        \"sampler\": InferenceSampler(len(dataset))\n        if not isinstance(dataset, torchdata.IterableDataset)\n        else None,\n    }\n\n\n@configurable(from_config=_custom_test_loader_from_config)\ndef build_custom_test_loader(\n    dataset: Union[List[Any], torchdata.Dataset],\n    *,\n    mapper: Callable[[Dict[str, Any]], Any],\n    sampler: Optional[torchdata.Sampler] = None,\n    batch_size: int = 1,\n    num_workers: int = 0,\n    collate_fn: Optional[Callable[[List[Any]], Any]] = None,\n) -> torchdata.DataLoader:\n    if isinstance(dataset, list):\n        dataset = DatasetFromList(dataset, copy=False)\n    if mapper is not None:\n        dataset = MapDataset(dataset, mapper)\n    if isinstance(dataset, torchdata.IterableDataset):\n        assert sampler is None, \"sampler must be None if dataset is IterableDataset\"\n    else:\n        if sampler is None:\n            sampler = InferenceSampler(len(dataset))\n    return torchdata.DataLoader(\n        dataset,\n        batch_size=batch_size,\n        sampler=sampler,\n        drop_last=False,\n        num_workers=num_workers,\n        collate_fn=trivial_batch_collator if collate_fn is None else collate_fn,\n    )\n\n\ndef trivial_batch_collator(batch):\n    return batch\n\n\ndef _custom_train_loader_from_config(cfg, mapper=None, *, dataset=None, sampler=None):\n    sampler_name = cfg.DATALOADER.SAMPLER_TRAIN\n    if 'MultiDataset' in sampler_name:\n        dataset_dicts = get_detection_dataset_dicts_with_source(\n            cfg.DATASETS.TRAIN,\n            filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,\n            min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE\n            if cfg.MODEL.KEYPOINT_ON else 0,\n            
proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None,\n        )\n    else:\n        dataset_dicts = get_detection_dataset_dicts(\n            cfg.DATASETS.TRAIN,\n            filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,\n            min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE\n            if cfg.MODEL.KEYPOINT_ON else 0,\n            proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None,\n        )\n\n    if mapper is None:\n        mapper = DatasetMapper(cfg, True)\n\n    if sampler is not None:\n        pass\n    elif sampler_name == \"TrainingSampler\":\n        sampler = TrainingSampler(len(dataset))\n    elif sampler_name == \"MultiDatasetSampler\":\n        sampler = MultiDatasetSampler(\n            dataset_dicts,\n            dataset_ratio = cfg.DATALOADER.DATASET_RATIO,\n            use_rfs = cfg.DATALOADER.USE_RFS,\n            dataset_ann = cfg.DATALOADER.DATASET_ANN,\n            repeat_threshold = cfg.DATALOADER.REPEAT_THRESHOLD,\n        )\n    elif sampler_name == \"RepeatFactorTrainingSampler\":\n        repeat_factors = RepeatFactorTrainingSampler.repeat_factors_from_category_frequency(\n            dataset_dicts, cfg.DATALOADER.REPEAT_THRESHOLD\n        )\n        sampler = RepeatFactorTrainingSampler(repeat_factors)\n    else:\n        raise ValueError(\"Unknown training sampler: {}\".format(sampler_name))\n\n    return {\n        \"dataset\": dataset_dicts,\n        \"sampler\": sampler,\n        \"mapper\": mapper,\n        \"total_batch_size\": cfg.SOLVER.IMS_PER_BATCH,\n        \"aspect_ratio_grouping\": cfg.DATALOADER.ASPECT_RATIO_GROUPING,\n        \"num_workers\": cfg.DATALOADER.NUM_WORKERS,\n        'multi_dataset_grouping': cfg.DATALOADER.MULTI_DATASET_GROUPING,\n        'use_diff_bs_size': cfg.DATALOADER.USE_DIFF_BS_SIZE,\n        'dataset_bs': cfg.DATALOADER.DATASET_BS,\n        'num_datasets': len(cfg.DATASETS.TRAIN)\n    
}\n\n\n@configurable(from_config=_custom_train_loader_from_config)\ndef build_custom_train_loader(\n        dataset, *, mapper, sampler, \n        total_batch_size=16,\n        aspect_ratio_grouping=True, \n        num_workers=0,\n        num_datasets=1,\n        multi_dataset_grouping=False,\n        use_diff_bs_size=False,\n        dataset_bs=[]\n    ):\n    \"\"\"\n    Modified from detectron2.data.build.build_custom_train_loader, but supports\n    different samplers\n    \"\"\"\n    if isinstance(dataset, list):\n        dataset = DatasetFromList(dataset, copy=False)\n    if mapper is not None:\n        dataset = MapDataset(dataset, mapper)\n    if sampler is None:\n        sampler = TrainingSampler(len(dataset))\n    assert isinstance(sampler, torch.utils.data.sampler.Sampler)\n    if multi_dataset_grouping:\n        return build_multi_dataset_batch_data_loader(\n            use_diff_bs_size,\n            dataset_bs,\n            dataset,\n            sampler,\n            total_batch_size,\n            num_datasets=num_datasets,\n            num_workers=num_workers,\n        )\n    else:\n        return build_batch_data_loader(\n            dataset,\n            sampler,\n            total_batch_size,\n            aspect_ratio_grouping=aspect_ratio_grouping,\n            num_workers=num_workers,\n        )\n\n\ndef build_multi_dataset_batch_data_loader(\n    use_diff_bs_size, dataset_bs,\n    dataset, sampler, total_batch_size, num_datasets, num_workers=0\n):\n    \"\"\"\n    \"\"\"\n    world_size = get_world_size()\n    assert (\n        total_batch_size > 0 and total_batch_size % world_size == 0\n    ), \"Total batch size ({}) must be divisible by the number of gpus ({}).\".format(\n        total_batch_size, world_size\n    )\n\n    batch_size = total_batch_size // world_size\n    data_loader = torch.utils.data.DataLoader(\n        dataset,\n        sampler=sampler,\n        num_workers=num_workers,\n        batch_sampler=None,\n        
collate_fn=operator.itemgetter(0),  # don't batch, but yield individual elements\n        worker_init_fn=worker_init_reset_seed,\n    )  # yield individual mapped dict\n    if use_diff_bs_size:\n        return DIFFMDAspectRatioGroupedDataset(\n            data_loader, dataset_bs, num_datasets)\n    else:\n        return MDAspectRatioGroupedDataset(\n            data_loader, batch_size, num_datasets)\n\n\ndef get_detection_dataset_dicts_with_source(\n    dataset_names, filter_empty=True, min_keypoints=0, proposal_files=None\n):\n    assert len(dataset_names)\n    dataset_dicts = [DatasetCatalog.get(dataset_name) for dataset_name in dataset_names]\n    for dataset_name, dicts in zip(dataset_names, dataset_dicts):\n        assert len(dicts), \"Dataset '{}' is empty!\".format(dataset_name)\n    \n    for source_id, (dataset_name, dicts) in \\\n        enumerate(zip(dataset_names, dataset_dicts)):\n        assert len(dicts), \"Dataset '{}' is empty!\".format(dataset_name)\n        for d in dicts:\n            d['dataset_source'] = source_id\n\n        if \"annotations\" in dicts[0]:\n            try:\n                class_names = MetadataCatalog.get(dataset_name).thing_classes\n                check_metadata_consistency(\"thing_classes\", dataset_name)\n                print_instances_class_histogram(dicts, class_names)\n            except AttributeError:  # class names are not available for this dataset\n                pass\n\n    assert proposal_files is None\n\n    dataset_dicts = list(itertools.chain.from_iterable(dataset_dicts))\n\n    has_instances = \"annotations\" in dataset_dicts[0]\n    if filter_empty and has_instances:\n        dataset_dicts = filter_images_with_only_crowd_annotations(dataset_dicts)\n    if min_keypoints > 0 and has_instances:\n        dataset_dicts = filter_images_with_few_keypoints(dataset_dicts, min_keypoints)\n\n    return dataset_dicts\n\n\nclass MultiDatasetSampler(Sampler):\n    def __init__(\n        self, \n        dataset_dicts, 
\n        dataset_ratio,\n        use_rfs,\n        dataset_ann,\n        repeat_threshold=0.001,\n        seed: Optional[int] = None,\n        ):\n        \"\"\"\n        \"\"\"\n        sizes = [0 for _ in range(len(dataset_ratio))]\n        for d in dataset_dicts:\n            sizes[d['dataset_source']] += 1\n        logger.info('dataset sizes {}'.format(sizes))\n        \n        self.sizes = sizes\n        assert len(dataset_ratio) == len(sizes), \\\n            'length of dataset ratio {} should be equal to number if dataset {}'.format(\n                len(dataset_ratio), len(sizes)\n            )\n        if seed is None:\n            seed = comm.shared_random_seed()\n        self._seed = int(seed)\n        self._rank = comm.get_rank()\n        self._world_size = comm.get_world_size()\n        \n        self.dataset_ids = torch.tensor(\n            [d['dataset_source'] for d in dataset_dicts], dtype=torch.long)\n\n        dataset_weight = [torch.ones(s) * max(sizes) / s * r / sum(dataset_ratio) \\\n            for i, (r, s) in enumerate(zip(dataset_ratio, sizes))]\n        dataset_weight = torch.cat(dataset_weight)\n\n        rfs_factors = []\n        st = 0\n        for i, s in enumerate(sizes):\n            if use_rfs[i]:\n                if dataset_ann[i] == 'box':\n                    rfs_func = RepeatFactorTrainingSampler.repeat_factors_from_category_frequency\n                else:\n                    rfs_func = repeat_factors_from_tag_frequency\n                rfs_factor = rfs_func(\n                    dataset_dicts[st: st + s],\n                    repeat_thresh=repeat_threshold)\n                rfs_factor = rfs_factor * (s / rfs_factor.sum())\n            else:\n                rfs_factor = torch.ones(s)\n            rfs_factors.append(rfs_factor)\n            st = st + s\n        rfs_factors = torch.cat(rfs_factors)\n\n        self.weights = dataset_weight * rfs_factors\n        self.sample_epoch_size = len(self.weights)\n\n\n    def 
__iter__(self):\n        start = self._rank\n        yield from itertools.islice(\n            self._infinite_indices(), start, None, self._world_size)\n\n\n    def _infinite_indices(self):\n        g = torch.Generator()\n        g.manual_seed(self._seed)\n        while True:\n            ids = torch.multinomial(\n                self.weights, self.sample_epoch_size, generator=g, \n                replacement=True)\n            nums = [(self.dataset_ids[ids] == i).sum().int().item() \\\n                for i in range(len(self.sizes))]\n            yield from ids\n\n\nclass MDAspectRatioGroupedDataset(torch.utils.data.IterableDataset):\n    def __init__(self, dataset, batch_size, num_datasets):\n        \"\"\"\n        \"\"\"\n        self.dataset = dataset\n        self.batch_size = batch_size\n        self._buckets = [[] for _ in range(2 * num_datasets)]\n\n    def __iter__(self):\n        for d in self.dataset:\n            w, h = d[\"width\"], d[\"height\"]\n            aspect_ratio_bucket_id = 0 if w > h else 1\n            bucket_id = d['dataset_source'] * 2 + aspect_ratio_bucket_id\n            bucket = self._buckets[bucket_id]\n            bucket.append(d)\n            if len(bucket) == self.batch_size:\n                yield bucket[:]\n                del bucket[:]\n\n\nclass DIFFMDAspectRatioGroupedDataset(torch.utils.data.IterableDataset):\n    def __init__(self, dataset, batch_sizes, num_datasets):\n        \"\"\"\n        \"\"\"\n        self.dataset = dataset\n        self.batch_sizes = batch_sizes\n        self._buckets = [[] for _ in range(2 * num_datasets)]\n\n    def __iter__(self):\n        for d in self.dataset:\n            w, h = d[\"width\"], d[\"height\"]\n            aspect_ratio_bucket_id = 0 if w > h else 1\n            bucket_id = d['dataset_source'] * 2 + aspect_ratio_bucket_id\n            bucket = self._buckets[bucket_id]\n            bucket.append(d)\n            if len(bucket) == self.batch_sizes[d['dataset_source']]:\n               
 yield bucket[:]\n                del bucket[:]\n\n\ndef repeat_factors_from_tag_frequency(dataset_dicts, repeat_thresh):\n    \"\"\"\n    \"\"\"\n    category_freq = defaultdict(int)\n    for dataset_dict in dataset_dicts:\n        cat_ids = dataset_dict['pos_category_ids']\n        for cat_id in cat_ids:\n            category_freq[cat_id] += 1\n    num_images = len(dataset_dicts)\n    for k, v in category_freq.items():\n        category_freq[k] = v / num_images\n\n    category_rep = {\n        cat_id: max(1.0, math.sqrt(repeat_thresh / cat_freq))\n        for cat_id, cat_freq in category_freq.items()\n    }\n\n    rep_factors = []\n    for dataset_dict in dataset_dicts:\n        cat_ids = dataset_dict['pos_category_ids']\n        rep_factor = max({category_rep[cat_id] for cat_id in cat_ids}, default=1.0)\n        rep_factors.append(rep_factor)\n\n    return torch.tensor(rep_factors, dtype=torch.float32)\n"
  },
  {
    "path": "datasets_os/dataset_mappers/__init__.py",
    "content": "\nfrom .coco_panoptic_interactive_dataset_mapper import COCOPanopticInteractiveDatasetMapper\nfrom .flickr_instance_new_baseline_dataset_mapper import COCOInstanceNewBaselineDatasetMapper as FlickrNewBaselineDatasetMapper\nfrom .coco_instruct_grounding_dataset_mapper import COCOInstanceNewBaselineDatasetMapper as COCOInstructGroundingDatasetMapper\nfrom .coco_instruct_grounding_dataset_interactive_mapper import COCOInstanceNewBaselineDatasetMapper as COCOInterGroundingDatasetMapper\nfrom .vg_instance_new_baseline_dataset_mapper import COCOInstanceNewBaselineDatasetMapper as VGNewBaselineDatasetMapper"
  },
  {
    "path": "datasets_os/dataset_mappers/coco_instance_new_baseline_dataset_mapper.py",
    "content": "# Copyright (c) Facebook, Inc. and its affiliates.\n# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/d2/detr/dataset_mapper.py\nimport copy\nimport logging\n\nimport numpy as np\nimport torch\n\nfrom detectron2.data import detection_utils as utils\nfrom detectron2.data import transforms as T\nfrom detectron2.data.transforms import TransformGen\nfrom detectron2.structures import BitMasks, Instances\n\nfrom pycocotools import mask as coco_mask\n\nfrom llava.model.openseed.utils import configurable\n\n__all__ = [\"COCOInstanceNewBaselineDatasetMapper\"]\n\n\ndef convert_coco_poly_to_mask(segmentations, height, width):\n    masks = []\n    for polygons in segmentations:\n        rles = coco_mask.frPyObjects(polygons, height, width)\n        mask = coco_mask.decode(rles)\n        if len(mask.shape) < 3:\n            mask = mask[..., None]\n        mask = torch.as_tensor(mask, dtype=torch.uint8)\n        mask = mask.any(dim=2)\n        masks.append(mask)\n    if masks:\n        masks = torch.stack(masks, dim=0)\n    else:\n        masks = torch.zeros((0, height, width), dtype=torch.uint8)\n    return masks\n\n\ndef build_transform_gen(cfg, is_train):\n    \"\"\"\n    Create a list of default :class:`Augmentation` from config.\n    Now it includes resizing and flipping.\n    Returns:\n        list[Augmentation]\n    \"\"\"\n    assert is_train, \"Only support training augmentation\"\n    cfg_input = cfg['INPUT']\n    image_size = cfg_input['IMAGE_SIZE']\n    min_scale = cfg_input['MIN_SCALE']\n    max_scale = cfg_input['MAX_SCALE']\n\n    augmentation = []\n\n    if cfg_input['RANDOM_FLIP'] != \"none\":\n        augmentation.append(\n            T.RandomFlip(\n                horizontal=cfg_input['RANDOM_FLIP'] == \"horizontal\",\n                vertical=cfg_input['RANDOM_FLIP'] == \"vertical\",\n            )\n        )\n\n    augmentation.extend([\n        T.ResizeScale(\n            min_scale=min_scale, 
max_scale=max_scale, target_height=image_size, target_width=image_size\n        ),\n        T.FixedSizeCrop(crop_size=(image_size, image_size)),\n    ])\n\n    return augmentation\n\n\n# This is specifically designed for the COCO dataset.\nclass COCOInstanceNewBaselineDatasetMapper:\n    \"\"\"\n    A callable which takes a dataset dict in Detectron2 Dataset format,\n    and map it into a format used by MaskFormer.\n\n    This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation.\n\n    The callable currently does the following:\n\n    1. Read the image from \"file_name\"\n    2. Applies geometric transforms to the image and annotation\n    3. Find and applies suitable cropping to the image and annotation\n    4. Prepare image and annotation to Tensors\n    \"\"\"\n\n    @configurable\n    def __init__(\n        self,\n        is_train=True,\n        *,\n        tfm_gens,\n        image_format,\n    ):\n        \"\"\"\n        NOTE: this interface is experimental.\n        Args:\n            is_train: for training or inference\n            augmentations: a list of augmentations or deterministic transforms to apply\n            tfm_gens: data augmentation\n            image_format: an image format supported by :func:`detection_utils.read_image`.\n        \"\"\"\n        self.tfm_gens = tfm_gens\n        logging.getLogger(__name__).info(\n            \"[COCOInstanceNewBaselineDatasetMapper] Full TransformGens used in training: {}\".format(str(self.tfm_gens))\n        )\n\n        self.img_format = image_format\n        self.is_train = is_train\n    \n    @classmethod\n    def from_config(cls, cfg, is_train=True):\n        # Build augmentation\n        tfm_gens = build_transform_gen(cfg, is_train)\n\n        ret = {\n            \"is_train\": is_train,\n            \"tfm_gens\": tfm_gens,\n            \"image_format\": cfg['INPUT']['FORMAT'],\n        }\n        return ret\n\n    def __call__(self, dataset_dict):\n        \"\"\"\n      
  Args:\n            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.\n\n        Returns:\n            dict: a format that builtin models in detectron2 accept\n        \"\"\"\n        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below\n        image = utils.read_image(dataset_dict[\"file_name\"], format=self.img_format)\n        utils.check_image_size(dataset_dict, image)\n\n        # TODO: get padding mask\n        # by feeding a \"segmentation mask\" to the same transforms\n        padding_mask = np.ones(image.shape[:2])\n\n        image, transforms = T.apply_transform_gens(self.tfm_gens, image)\n        # the crop transformation has default padding value 0 for segmentation\n        padding_mask = transforms.apply_segmentation(padding_mask)\n        padding_mask = ~ padding_mask.astype(bool)\n\n        image_shape = image.shape[:2]  # h, w\n\n        # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,\n        # but not efficient on large generic data structures due to the use of pickle & mp.Queue.\n        # Therefore it's important to use torch.Tensor.\n        dataset_dict[\"image\"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))\n        dataset_dict[\"padding_mask\"] = torch.as_tensor(np.ascontiguousarray(padding_mask))\n\n        if not self.is_train:\n            # USER: Modify this if you want to keep them for some reason.\n            dataset_dict.pop(\"annotations\", None)\n            return dataset_dict\n\n        if \"annotations\" in dataset_dict:\n            # USER: Modify this if you want to keep them for some reason.\n            for anno in dataset_dict[\"annotations\"]:\n                # Let's always keep mask\n                # if not self.mask_on:\n                #     anno.pop(\"segmentation\", None)\n                anno.pop(\"keypoints\", None)\n\n            # USER: Implement additional transformations if you have other types of data\n        
    annos = [\n                utils.transform_instance_annotations(obj, transforms, image_shape)\n                for obj in dataset_dict.pop(\"annotations\")\n                if obj.get(\"iscrowd\", 0) == 0\n            ]\n            # NOTE: does not support BitMask due to augmentation\n            # Current BitMask cannot handle empty objects\n            instances = utils.annotations_to_instances(annos, image_shape)\n            # After transforms such as cropping are applied, the bounding box may no longer\n            # tightly bound the object. As an example, imagine a triangle object\n            # [(0,0), (2,0), (0,2)] cropped by a box [(1,0),(2,2)] (XYXY format). The tight\n            # bounding box of the cropped triangle should be [(1,0),(2,1)], which is not equal to\n            # the intersection of original bounding box and the cropping box.\n            instances.gt_boxes = instances.gt_masks.get_bounding_boxes()\n            # Need to filter empty instances first (due to augmentation)\n            instances = utils.filter_empty_instances(instances)\n            # Generate masks from polygon\n            h, w = instances.image_size\n            # image_size_xyxy = torch.as_tensor([w, h, w, h], dtype=torch.float)\n            if hasattr(instances, 'gt_masks'):\n                gt_masks = instances.gt_masks\n                gt_masks = convert_coco_poly_to_mask(gt_masks.polygons, h, w)\n                instances.gt_masks = gt_masks\n            dataset_dict[\"instances\"] = instances\n\n        return dataset_dict\n"
  },
  {
    "path": "datasets_os/dataset_mappers/coco_instruct_grounding_dataset_interactive_mapper.py",
    "content": "# Copyright (c) Facebook, Inc. and its affiliates.\n# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/d2/detr/dataset_mapper.py\nimport copy\nimport logging\nimport random\n\nimport numpy as np\nimport torch\nimport PIL.Image as Image\nfrom detectron2.data import detection_utils as utils\nfrom detectron2.data import transforms as T\nfrom detectron2.data.transforms import TransformGen\nfrom detectron2.structures import BitMasks, Instances\n\nfrom pycocotools import mask as coco_mask\n\nfrom llava.model.openseed.utils import configurable\nfrom detectron2.structures import Boxes, BoxMode, PolygonMasks, RotatedBoxes\nfrom llava import conversation as conversation_lib\nfrom llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN\n\n# from llava.train.train_hao_seg_flickr import ,preprocess\n__all__ = [\"COCOInstanceNewBaselineDatasetMapper\"]\n\n\ndef convert_coco_poly_to_mask(segmentations, height, width):\n    masks = []\n    for polygons in segmentations:\n        rles = coco_mask.frPyObjects(polygons, height, width)\n        mask = coco_mask.decode(rles)\n        if len(mask.shape) < 3:\n            mask = mask[..., None]\n        mask = torch.as_tensor(mask, dtype=torch.uint8)\n        mask = mask.any(dim=2)\n        masks.append(mask)\n    if masks:\n        masks = torch.stack(masks, dim=0)\n    else:\n        masks = torch.zeros((0, height, width), dtype=torch.uint8)\n    return masks\n\ndef preprocess_multimodal(\n    sources,\n    data_args\n):\n    is_multimodal = data_args.is_multimodal\n    if not is_multimodal:\n        return sources\n\n    for source in sources:\n        for sentence in source:\n            if DEFAULT_IMAGE_TOKEN in sentence['value']:\n                sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip()\n                sentence['value'] = DEFAULT_IMAGE_TOKEN + '\\n' + sentence['value']\n           
     sentence['value'] = sentence['value'].strip()\n                if \"mmtag\" in conversation_lib.default_conversation.version:\n                    sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '<Image>' + DEFAULT_IMAGE_TOKEN + '</Image>')\n            replace_token = DEFAULT_IMAGE_TOKEN\n            if data_args.mm_use_im_start_end:\n                replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN\n            sentence[\"value\"] = sentence[\"value\"].replace(DEFAULT_IMAGE_TOKEN, replace_token)\n\n    return sources\n\ndef build_transform_gen(cfg, is_train):\n    \"\"\"\n    Create a list of default :class:`Augmentation` from config.\n    Now it includes resizing and flipping.\n    Returns:\n        list[Augmentation]\n    \"\"\"\n    if is_train:\n        cfg_input = cfg['INPUT']\n        image_size = cfg_input['IMAGE_SIZE']\n        min_scale = cfg_input['MIN_SCALE']\n        max_scale = cfg_input['MAX_SCALE']\n\n        augmentation = []\n\n        # if cfg_input['RANDOM_FLIP'] != \"none\":\n        #     augmentation.append(\n        #         T.RandomFlip(\n        #             horizontal=cfg_input['RANDOM_FLIP'] == \"horizontal\",\n        #             vertical=cfg_input['RANDOM_FLIP'] == \"vertical\",\n        #         )\n        #     )\n\n        augmentation.extend([\n            T.ResizeScale(\n                min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size\n            ),\n            T.FixedSizeCrop(crop_size=(image_size, image_size)),\n        ])\n    else:\n        cfg_input = cfg['INPUT']\n        image_size = cfg_input['IMAGE_SIZE']\n        min_scale = cfg_input['MIN_SCALE']\n        max_scale = cfg_input['MAX_SCALE']\n\n        augmentation = []\n\n        # if cfg_input['RANDOM_FLIP'] != \"none\":\n        #     augmentation.append(\n        #         T.RandomFlip(\n        #             horizontal=cfg_input['RANDOM_FLIP'] == \"horizontal\",\n   
     #             vertical=cfg_input['RANDOM_FLIP'] == \"vertical\",\n        #         )\n        #     )\n\n        augmentation.extend([\n            T.ResizeScale(\n                min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size\n            ),\n            T.FixedSizeCrop(crop_size=(image_size, image_size)),\n        ])\n\n    return augmentation\n\n\n# This is specifically designed for the COCO dataset.\nclass COCOInstanceNewBaselineDatasetMapper:\n    \"\"\"\n    A callable which takes a dataset dict in Detectron2 Dataset format,\n    and map it into a format used by MaskFormer.\n\n    This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation.\n\n    The callable currently does the following:\n\n    1. Read the image from \"file_name\"\n    2. Applies geometric transforms to the image and annotation\n    3. Find and applies suitable cropping to the image and annotation\n    4. Prepare image and annotation to Tensors\n    \"\"\"\n\n    @configurable\n    def __init__(\n        self,\n        is_train=True,\n        *,\n        tfm_gens,\n        image_format,\n        tokenizer,\n        data_args,\n        preprocess,\n        refcoco=None,\n        max_sampled=5,\n    ):\n        \"\"\"\n        NOTE: this interface is experimental.\n        Args:\n            is_train: for training or inference\n            augmentations: a list of augmentations or deterministic transforms to apply\n            tfm_gens: data augmentation\n            image_format: an image format supported by :func:`detection_utils.read_image`.\n        \"\"\"\n        self.tfm_gens = tfm_gens\n        logging.getLogger(__name__).info(\n            \"[COCOInstanceNewBaselineDatasetMapper] Full TransformGens used in training: {}\".format(str(self.tfm_gens))\n        )\n\n        self.img_format = image_format\n        self.is_train = is_train\n        self.tokenizer = tokenizer\n        self.processor = 
data_args.image_processor\n        self.data_args = data_args\n        self.preprocess = preprocess\n        self.refcoco=refcoco\n        self.max_sampled=max_sampled\n    \n    @classmethod\n    def from_config(cls, cfg, is_train=True,tokenizer=None,data_args=None,preprocess=None,refcoco=None):\n        # Build augmentation\n        tfm_gens = build_transform_gen(cfg, is_train)\n\n        ret = {\n            \"is_train\": is_train,\n            \"tfm_gens\": tfm_gens,\n            \"image_format\": cfg['INPUT']['FORMAT'],\n            \"tokenizer\": tokenizer,\n            \"data_args\": data_args,\n            \"preprocess\": preprocess,\n            \"refcoco\":refcoco,\n        }\n        return ret\n\n    def __call__(self, dataset_dict):\n        \"\"\"\n        Args:\n            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.\n\n        Returns:\n            dict: a format that builtin models in detectron2 accept\n        \"\"\"\n        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below\n        image = utils.read_image(dataset_dict[\"file_name\"], format=self.img_format)\n        utils.check_image_size(dataset_dict, image)\n\n        #########llava image processing\n\n        if self.data_args.image_aspect_ratio == 'pad':\n            def expand2square(pil_img, background_color):\n                width, height = pil_img.size\n                if width == height:\n                    return pil_img\n                elif width > height:\n                    result = Image.new(pil_img.mode, (width, width), background_color)\n                    result.paste(pil_img, (0, (width - height) // 2))\n                    return result\n                else:\n                    result = Image.new(pil_img.mode, (height, height), background_color)\n                    result.paste(pil_img, ((height - width) // 2, 0))\n                    return result\n\n            image_clip = expand2square(image, tuple(int(x 
* 255) for x in self.processor.image_mean))\n            image_clip = self.processor.preprocess(image_clip, return_tensors='pt')['pixel_values'][0]\n        else:\n            image_clip = self.processor.preprocess(image, return_tensors='pt')['pixel_values'][0]\n        dataset_dict[\"image_clip\"] = image_clip\n\n        ##################\n\n        # TODO: get padding mask\n        # by feeding a \"segmentation mask\" to the same transforms\n        padding_mask = np.ones(image.shape[:2])\n\n        image, transforms = T.apply_transform_gens(self.tfm_gens, image)\n        dataset_dict[\"image_ori\"]=image\n        # the crop transformation has default padding value 0 for segmentation\n        padding_mask = transforms.apply_segmentation(padding_mask)\n        padding_mask = ~ padding_mask.astype(bool)\n\n        image_shape = image.shape[:2]  # h, w\n\n        dataset_dict[\"image\"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))\n        dataset_dict[\"padding_mask\"] = torch.as_tensor(np.ascontiguousarray(padding_mask))\n\n        num_conversations = len(dataset_dict['conversations'])\n        if self.refcoco:\n            max_sampled=min(self.max_sampled,num_conversations)\n            sample_num=random.randint(1,max_sampled)\n            sampled_convs=random.sample(dataset_dict['conversations'], k=sample_num)\n            grounding_list=[]\n            selected_conversation=[]\n            sampled_convs[0][0][0]['value']='<image>\\n'+sampled_convs[0][0][0]['value']\n            for conv,gd in sampled_convs:\n                grounding_list.extend(gd)\n                conv[1]['value']=random.choice(conv[1]['value'])\n                selected_conversation.extend(conv)\n\n        else:\n            rd = np.random.choice(num_conversations)\n            selected_conversation, grounding_list = dataset_dict['conversations'][rd]\n        dataset_dict['conversation'] = [selected_conversation]\n        sources = preprocess_multimodal(\n            
copy.deepcopy(dataset_dict['conversation']),\n            self.data_args)\n        data_dict_conversation = self.preprocess(\n            sources,\n            self.tokenizer,\n            has_image=True)\n        data_dict_conversation = dict(input_ids=data_dict_conversation[\"input_ids\"][0],\n                                      labels=data_dict_conversation[\"labels\"][0])\n        dataset_dict.update(data_dict_conversation)\n        dataset_dict['tokenizer'] = self.tokenizer\n        # num_segs = sum([conv['value'].count('<seg>') for conv in selected_conversation])\n        # grounding_list=\n        assert \"grounding_info\" in dataset_dict and len(dataset_dict['grounding_info'])>0\n        anno_id2id=dict()\n        for id,obj in enumerate(dataset_dict['grounding_info']):\n            obj[\"bbox_mode\"] = BoxMode.XYWH_ABS\n            anno_id2id[obj['id']]=id\n        # id2class=[[] for _ in range(len(dataset_dict['grounding_info']))]\n\n        annos = [\n            utils.transform_instance_annotations(obj, transforms, image_shape)\n            for obj in dataset_dict[\"grounding_info\"]\n        ]\n        # assert  \"segmentation\" in annos[0]\n        instances = utils.annotations_to_instances(annos, image_shape,mask_format=\"bitmask\")\n\n        h, w = instances.image_size\n        # image_size_xyxy = torch.as_tensor([w, h, w, h], dtype=torch.float)\n        if hasattr(instances, 'gt_masks'):\n            gt_masks = instances.gt_masks\n            # gt_masks = convert_coco_poly_to_mask(gt_masks.polygons, h, w)\n            instances.gt_masks = gt_masks.tensor\n        num_objs=(data_dict_conversation['input_ids']==1273).sum()\n        grounding_list=[gd for gd in grounding_list if gd is not None]\n        merged_grounding_list=sum(grounding_list,[])\n        # assert num_objs==len(merged_grounding_list)\n        if num_objs<len(merged_grounding_list):\n            merged_grounding_list=merged_grounding_list[:num_objs]\n        elif 
num_objs>len(merged_grounding_list):\n            merged_grounding_list=merged_grounding_list+[merged_grounding_list[-1]]*(num_objs-len(merged_grounding_list))\n        merged_grounding_list=[anno_id2id[annid] for annid in merged_grounding_list]\n        dataset_dict['grounding_index']=merged_grounding_list\n        dataset_dict[\"instances\"] = instances\n            # if grounding_list is None:\n            #     dataset_dict['grounding']=False\n            #     grounding_mask=[False for _ in range(num_segs)]\n            #     dataset_dict['grounding_mask']=grounding_mask\n            # else:\n            #     grounding_mask=[True if g is not None else False for g in grounding_list]\n            #     dataset_dict['grounding_mask']=grounding_mask\n            #     new_grounding_list=[g for g in grounding_list if g is not None]\n            #     if sum(grounding_mask)==0:\n            #         dataset_dict['grounding']=False\n            #     else:\n            #         dataset_dict['grounding']=True\n            # if dataset_dict['grounding']:\n            #     # assert num_segs == len(grounding_list)\n            #     for grounding_id,grounding in enumerate(new_grounding_list):\n            #         if grounding is not None:\n            #             for annid in grounding:\n            #                 id2class[anno_id2id[annid]].append(grounding_id)\n            #\n            #     instances.gt_classes=id2class\n            # dataset_dict[\"instances\"] = instances\n\n        # else:\n        #     dataset_dict['grounding'] = False\n        #     grounding_mask = [False for _ in range(num_segs)]\n        #     dataset_dict['grounding_mask'] = grounding_mask\n\n        return dataset_dict\n"
  },
  {
    "path": "datasets_os/dataset_mappers/coco_instruct_grounding_dataset_mapper.py",
    "content": "# Copyright (c) Facebook, Inc. and its affiliates.\n# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/d2/detr/dataset_mapper.py\nimport copy\nimport logging\n\nimport numpy as np\nimport torch\nimport PIL.Image as Image\nfrom detectron2.data import detection_utils as utils\nfrom detectron2.data import transforms as T\nfrom detectron2.data.transforms import TransformGen\nfrom detectron2.structures import BitMasks, Instances\n\nfrom pycocotools import mask as coco_mask\n\nfrom llava.model.openseed.utils import configurable\nfrom detectron2.structures import Boxes, BoxMode, PolygonMasks, RotatedBoxes\nfrom llava import conversation as conversation_lib\nfrom llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN\n\n# from llava.train.train_hao_seg_flickr import ,preprocess\n__all__ = [\"COCOInstanceNewBaselineDatasetMapper\"]\nsuffix=[\n\"Please also provide the boxes and masks for the noun phrases in the response.\"\n, \"Kindly ensure that the response includes the relevant boxes and masks for each noun phrase.\"\n, \"Additionally, include the boxes and masks that match each noun phrase in the response.\"\n, \"Please provide the boxes and masks that correspond to every noun phrase in your response.\"\n, \"It’s important to have the boxes and masks that align with each noun phrase in the response.\"\n, \"Make sure to include the appropriate boxes and masks for each noun phrase in your response.\"\n, \"In your response, include the boxes and masks that pertain to each noun phrase.\"\n, \"Also, supply the boxes and masks that are linked to each noun phrase in the response.\"\n, \"Additionally, please furnish the boxes and masks that correspond to each noun phrase in the response.\"\n, \"Don’t forget to provide the boxes and masks associated with each noun phrase in your response.\"\n, \"Ensure that each noun phrase in the response has its respective boxes 
and masks.\",\n]\n\ndef convert_coco_poly_to_mask(segmentations, height, width):\n    masks = []\n    for polygons in segmentations:\n        rles = coco_mask.frPyObjects(polygons, height, width)\n        mask = coco_mask.decode(rles)\n        if len(mask.shape) < 3:\n            mask = mask[..., None]\n        mask = torch.as_tensor(mask, dtype=torch.uint8)\n        mask = mask.any(dim=2)\n        masks.append(mask)\n    if masks:\n        masks = torch.stack(masks, dim=0)\n    else:\n        masks = torch.zeros((0, height, width), dtype=torch.uint8)\n    return masks\n\ndef preprocess_multimodal(\n    sources,\n    data_args\n):\n    is_multimodal = data_args.is_multimodal\n    if not is_multimodal:\n        return sources\n\n    for source in sources:\n        for sentence in source:\n            if DEFAULT_IMAGE_TOKEN in sentence['value']:\n                sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip()\n                sentence['value'] = DEFAULT_IMAGE_TOKEN + '\\n' + sentence['value']\n                sentence['value'] = sentence['value'].strip()\n                if \"mmtag\" in conversation_lib.default_conversation.version:\n                    sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '<Image>' + DEFAULT_IMAGE_TOKEN + '</Image>')\n            replace_token = DEFAULT_IMAGE_TOKEN\n            if data_args.mm_use_im_start_end:\n                replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN\n            sentence[\"value\"] = sentence[\"value\"].replace(DEFAULT_IMAGE_TOKEN, replace_token)\n\n    return sources\n\ndef build_transform_gen(cfg, is_train):\n    \"\"\"\n    Create a list of default :class:`Augmentation` from config.\n    Now it includes resizing and flipping.\n    Returns:\n        list[Augmentation]\n    \"\"\"\n    if is_train:\n        cfg_input = cfg['INPUT']\n        image_size = cfg_input['IMAGE_SIZE']\n        min_scale = cfg_input['MIN_SCALE']\n        
max_scale = cfg_input['MAX_SCALE']\n\n        augmentation = []\n\n        # if cfg_input['RANDOM_FLIP'] != \"none\":\n        #     augmentation.append(\n        #         T.RandomFlip(\n        #             horizontal=cfg_input['RANDOM_FLIP'] == \"horizontal\",\n        #             vertical=cfg_input['RANDOM_FLIP'] == \"vertical\",\n        #         )\n        #     )\n\n        augmentation.extend([\n            T.ResizeScale(\n                min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size\n            ),\n            T.FixedSizeCrop(crop_size=(image_size, image_size)),\n        ])\n    else:\n        cfg_input = cfg['INPUT']\n        image_size = cfg_input['IMAGE_SIZE']\n        min_scale = cfg_input['MIN_SCALE']\n        max_scale = cfg_input['MAX_SCALE']\n\n        augmentation = []\n\n        # if cfg_input['RANDOM_FLIP'] != \"none\":\n        #     augmentation.append(\n        #         T.RandomFlip(\n        #             horizontal=cfg_input['RANDOM_FLIP'] == \"horizontal\",\n        #             vertical=cfg_input['RANDOM_FLIP'] == \"vertical\",\n        #         )\n        #     )\n\n        augmentation.extend([\n            T.ResizeScale(\n                min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size\n            ),\n            T.FixedSizeCrop(crop_size=(image_size, image_size)),\n        ])\n\n    return augmentation\n\n\n# This is specifically designed for the COCO dataset.\nclass COCOInstanceNewBaselineDatasetMapper:\n    \"\"\"\n    A callable which takes a dataset dict in Detectron2 Dataset format,\n    and map it into a format used by MaskFormer.\n\n    This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation.\n\n    The callable currently does the following:\n\n    1. Read the image from \"file_name\"\n    2. Applies geometric transforms to the image and annotation\n    3. 
Find and applies suitable cropping to the image and annotation\n    4. Prepare image and annotation to Tensors\n    \"\"\"\n\n    @configurable\n    def __init__(\n        self,\n        is_train=True,\n        *,\n        tfm_gens,\n        image_format,\n        tokenizer,\n        data_args,\n        preprocess,\n        replace_suffix=False,\n    ):\n        \"\"\"\n        NOTE: this interface is experimental.\n        Args:\n            is_train: for training or inference\n            augmentations: a list of augmentations or deterministic transforms to apply\n            tfm_gens: data augmentation\n            image_format: an image format supported by :func:`detection_utils.read_image`.\n        \"\"\"\n        self.tfm_gens = tfm_gens\n        logging.getLogger(__name__).info(\n            \"[COCOInstanceNewBaselineDatasetMapper] Full TransformGens used in training: {}\".format(str(self.tfm_gens))\n        )\n\n        self.img_format = image_format\n        self.is_train = is_train\n        self.tokenizer = tokenizer\n        self.processor = data_args.image_processor\n        self.data_args = data_args\n        self.preprocess = preprocess\n        self.replace_suffix=replace_suffix\n    \n    @classmethod\n    def from_config(cls, cfg, is_train=True,tokenizer=None,data_args=None,preprocess=None):\n        # Build augmentation\n        tfm_gens = build_transform_gen(cfg, is_train)\n\n        ret = {\n            \"is_train\": is_train,\n            \"tfm_gens\": tfm_gens,\n            \"image_format\": cfg['INPUT']['FORMAT'],\n            \"tokenizer\": tokenizer,\n            \"data_args\": data_args,\n            \"preprocess\": preprocess,\n            \"replace_suffix\": cfg['MODEL'].get('REPLACE_SUFFIX', False),\n        }\n        return ret\n\n    def __call__(self, dataset_dict):\n        \"\"\"\n        Args:\n            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.\n\n        Returns:\n            dict: a format 
that builtin models in detectron2 accept\n        \"\"\"\n        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below\n        image = utils.read_image(dataset_dict[\"file_name\"], format=self.img_format)\n        utils.check_image_size(dataset_dict, image)\n\n        #########llava image processing\n\n        if self.data_args.image_aspect_ratio == 'pad':\n            def expand2square(pil_img, background_color):\n                width, height = pil_img.size\n                if width == height:\n                    return pil_img\n                elif width > height:\n                    result = Image.new(pil_img.mode, (width, width), background_color)\n                    result.paste(pil_img, (0, (width - height) // 2))\n                    return result\n                else:\n                    result = Image.new(pil_img.mode, (height, height), background_color)\n                    result.paste(pil_img, ((height - width) // 2, 0))\n                    return result\n\n            image_clip = expand2square(image, tuple(int(x * 255) for x in self.processor.image_mean))\n            image_clip = self.processor.preprocess(image_clip, return_tensors='pt')['pixel_values'][0]\n        else:\n            image_clip = self.processor.preprocess(image, return_tensors='pt')['pixel_values'][0]\n        dataset_dict[\"image_clip\"] = image_clip\n\n        ##################\n\n        # TODO: get padding mask\n        # by feeding a \"segmentation mask\" to the same transforms\n        padding_mask = np.ones(image.shape[:2])\n\n        image, transforms = T.apply_transform_gens(self.tfm_gens, image)\n        dataset_dict[\"image_ori\"]=image\n        # the crop transformation has default padding value 0 for segmentation\n        padding_mask = transforms.apply_segmentation(padding_mask)\n        padding_mask = ~ padding_mask.astype(bool)\n\n        image_shape = image.shape[:2]  # h, w\n\n        dataset_dict[\"image\"] = 
torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))\n        dataset_dict[\"padding_mask\"] = torch.as_tensor(np.ascontiguousarray(padding_mask))\n        num_conversations = len(dataset_dict['conversations'])\n        rd = np.random.choice(num_conversations)\n        selected_conversation, grounding_list = dataset_dict['conversations'][rd]\n        dataset_dict['conversation'] = [selected_conversation]\n        sources = preprocess_multimodal(\n            copy.deepcopy(dataset_dict['conversation']),\n            self.data_args)\n        if self.replace_suffix:\n            for conv in sources[0]:\n                sf=np.random.choice(suffix)\n                if conv['from'] == 'human':\n                    conv['value'] = conv['value'].replace('(with grounding)', sf, 1)\n\n        data_dict_conversation = self.preprocess(\n            sources,\n            self.tokenizer,\n            has_image=True)\n        data_dict_conversation = dict(input_ids=data_dict_conversation[\"input_ids\"][0],\n                                      labels=data_dict_conversation[\"labels\"][0])\n        dataset_dict.update(data_dict_conversation)\n        dataset_dict['tokenizer'] = self.tokenizer\n        num_segs = sum([conv['value'].count('<seg>') for conv in selected_conversation])\n        # grounding_list=\n        if \"grounding_info\" in dataset_dict and len(dataset_dict['grounding_info'])>0:\n            anno_id2id=dict()\n            for id,obj in enumerate(dataset_dict['grounding_info']):\n                obj[\"bbox_mode\"] = BoxMode.XYWH_ABS\n                anno_id2id[obj['id']]=id\n            id2class=[[] for _ in range(len(dataset_dict['grounding_info']))]\n\n            annos = [\n                utils.transform_instance_annotations(obj, transforms, image_shape)\n                for obj in dataset_dict[\"grounding_info\"]\n            ]\n            # assert  \"segmentation\" in annos[0]\n            instances = utils.annotations_to_instances(annos, 
image_shape,mask_format=\"bitmask\")\n\n            h, w = instances.image_size\n            # image_size_xyxy = torch.as_tensor([w, h, w, h], dtype=torch.float)\n            if hasattr(instances, 'gt_masks'):\n                gt_masks = instances.gt_masks\n                # gt_masks = convert_coco_poly_to_mask(gt_masks.polygons, h, w)\n                instances.gt_masks = gt_masks.tensor\n\n            if grounding_list is None:\n                dataset_dict['grounding']=False\n                grounding_mask=[False for _ in range(num_segs)]\n                dataset_dict['grounding_mask']=grounding_mask\n            else:\n                grounding_mask=[True if g is not None else False for g in grounding_list]\n                dataset_dict['grounding_mask']=grounding_mask\n                new_grounding_list=[g for g in grounding_list if g is not None]\n                if sum(grounding_mask)==0:\n                    dataset_dict['grounding']=False\n                else:\n                    dataset_dict['grounding']=True\n            if dataset_dict['grounding']:\n                # assert num_segs == len(grounding_list)\n                for grounding_id,grounding in enumerate(new_grounding_list):\n                    if grounding is not None:\n                        for annid in grounding:\n                            id2class[anno_id2id[annid]].append(grounding_id)\n\n                instances.gt_classes=id2class\n            dataset_dict[\"instances\"] = instances\n\n        else:\n            dataset_dict['grounding'] = False\n            grounding_mask = [False for _ in range(num_segs)]\n            dataset_dict['grounding_mask'] = grounding_mask\n\n        return dataset_dict\n"
  },
  {
    "path": "datasets_os/dataset_mappers/coco_interactive_panoptic_new_baseline_dataset_mapper.py",
    "content": "# ------------------------------------------------------------------------\n# Copyright (c) 2022 IDEA. All Rights Reserved.\n# Licensed under the Apache License, Version 2.0 [see LICENSE for details]\n# ------------------------------------------------------------------------\n# Modified from Mask2Former https://github.com/facebookresearch/Mask2Former by Feng Li.\nimport copy\nimport logging\n\nimport numpy as np\nimport torch\n\nfrom detectron2.config import configurable\nfrom detectron2.data import detection_utils as utils\nfrom detectron2.data import transforms as T\nfrom detectron2.data.transforms import TransformGen\nfrom detectron2.structures import BitMasks, Boxes, Instances\n\n__all__ = [\"COCOInteractivePanopticNewBaselineDatasetMapper\"]\n\n\ndef filter_empty_instances_by_box(\n        instances, by_box=True, by_mask=False, box_threshold=1e-5, return_mask=False\n):\n    assert by_box or by_mask\n    r = []\n    if by_box:\n        r.append(instances.gt_boxes.nonempty(threshold=box_threshold))\n    if instances.has(\"gt_masks\") and by_mask:\n        r.append(instances.gt_masks.nonempty())\n\n    # TODO: can also filter visible keypoints\n\n    if not r:\n        return instances\n    m = r[0]\n    for x in r[1:]:\n        m = m & x\n    if return_mask:\n        return instances[m], m\n    return instances[m]\n\ndef build_transform_gen(cfg, is_train):\n    \"\"\"\n    Create a list of default :class:`Augmentation` from config.\n    Now it includes resizing and flipping.\n    Returns:\n        list[Augmentation]\n    \"\"\"\n    assert is_train, \"Only support training augmentation\"\n    image_size = cfg.INPUT.IMAGE_SIZE\n    min_scale = cfg.INPUT.MIN_SCALE\n    max_scale = cfg.INPUT.MAX_SCALE\n\n    augmentation = []\n\n    if cfg.INPUT.RANDOM_FLIP != \"none\":\n        augmentation.append(\n            T.RandomFlip(\n                horizontal=cfg.INPUT.RANDOM_FLIP == \"horizontal\",\n                vertical=cfg.INPUT.RANDOM_FLIP == 
\"vertical\",\n            )\n        )\n\n    augmentation.extend([\n        T.ResizeScale(\n            min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size\n        ),\n        T.FixedSizeCrop(crop_size=(image_size, image_size)),\n    ])\n\n    return augmentation\n\n\n# This is specifically designed for the COCO dataset.\nclass COCOInteractivePanopticNewBaselineDatasetMapper:\n    \"\"\"\n    A callable which takes a dataset dict in Detectron2 Dataset format,\n    and map it into a format used by MaskFormer.\n\n    This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation.\n\n    The callable currently does the following:\n\n    1. Read the image from \"file_name\"\n    2. Applies geometric transforms to the image and annotation\n    3. Find and applies suitable cropping to the image and annotation\n    4. Prepare image and annotation to Tensors\n    \"\"\"\n\n    @configurable\n    def __init__(\n        self,\n        is_train=True,\n        *,\n        tfm_gens,\n        image_format,\n    ):\n        \"\"\"\n        NOTE: this interface is experimental.\n        Args:\n            is_train: for training or inference\n            augmentations: a list of augmentations or deterministic transforms to apply\n            crop_gen: crop augmentation\n            tfm_gens: data augmentation\n            image_format: an image format supported by :func:`detection_utils.read_image`.\n        \"\"\"\n        self.tfm_gens = tfm_gens\n        logging.getLogger(__name__).info(\n            \"[COCOPanopticNewBaselineDatasetMapper] Full TransformGens used in training: {}\".format(\n                str(self.tfm_gens)\n            )\n        )\n\n        self.img_format = image_format\n        self.is_train = is_train\n\n    @classmethod\n    def from_config(cls, cfg, is_train=True):\n        # Build augmentation\n        tfm_gens = build_transform_gen(cfg, is_train)\n\n        ret = {\n            
\"is_train\": is_train,\n            \"tfm_gens\": tfm_gens,\n            \"image_format\": cfg.INPUT.FORMAT,\n        }\n        return ret\n\n    def __call__(self, dataset_dict):\n        \"\"\"\n        Args:\n            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.\n\n        Returns:\n            dict: a format that builtin models in detectron2 accept\n        \"\"\"\n        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below\n        image = utils.read_image(dataset_dict[\"file_name\"], format=self.img_format)\n        utils.check_image_size(dataset_dict, image)\n\n        image, transforms = T.apply_transform_gens(self.tfm_gens, image)\n        image_shape = image.shape[:2]  # h, w\n        dataset_dict[\"image_ori\"]=image\n        # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,\n        # but not efficient on large generic data structures due to the use of pickle & mp.Queue.\n        # Therefore it's important to use torch.Tensor.\n        dataset_dict[\"image\"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))\n\n        # if not self.is_train:\n        #     # USER: Modify this if you want to keep them for some reason.\n        #     dataset_dict.pop(\"annotations\", None)\n        #     return dataset_dict\n\n        if \"pan_seg_file_name\" in dataset_dict:\n            pan_seg_gt = utils.read_image(dataset_dict.pop(\"pan_seg_file_name\"), \"RGB\")\n            segments_info = dataset_dict[\"segments_info\"]\n\n            # apply the same transformation to panoptic segmentation\n            pan_seg_gt = transforms.apply_segmentation(pan_seg_gt)\n\n            from panopticapi.utils import rgb2id\n\n            pan_seg_gt = rgb2id(pan_seg_gt)\n\n            instances = Instances(image_shape)\n            classes = []\n            masks = []\n            for segment_info in segments_info:\n                class_id = segment_info[\"category_id\"]\n       
         if not segment_info[\"iscrowd\"]:\n                    classes.append(class_id)\n                    masks.append(pan_seg_gt == segment_info[\"id\"])\n\n            classes = np.array(classes)\n            instances.gt_classes = torch.tensor(classes, dtype=torch.int64)\n            if len(masks) == 0:\n                # Some image does not have annotation (all ignored)\n                instances.gt_masks = torch.zeros((0, pan_seg_gt.shape[-2], pan_seg_gt.shape[-1]))\n                instances.gt_boxes = Boxes(torch.zeros((0, 4)))\n            else:\n                masks = BitMasks(\n                    torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks])\n                )\n                instances.gt_masks = masks.tensor\n                instances.gt_boxes = masks.get_bounding_boxes()\n\n            dataset_dict[\"instances\"] = filter_empty_instances_by_box(instances)\n\n        return dataset_dict\n"
  },
  {
    "path": "datasets_os/dataset_mappers/coco_panoptic_interactive_dataset_mapper.py",
    "content": "# Copyright (c) Facebook, Inc. and its affiliates.\n# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/d2/detr/dataset_mapper.py\nimport copy\nimport logging\nimport random\n\nimport numpy as np\nimport torch\nimport PIL.Image as Image\nfrom detectron2.data import detection_utils as utils\nfrom detectron2.data import transforms as T\nfrom detectron2.data.transforms import TransformGen\nfrom detectron2.structures import BitMasks, Boxes, Instances, BoxMode\nfrom detectron2.structures.boxes import pairwise_iou\nfrom detectron2.data.datasets.builtin_meta import COCO_CATEGORIES\nfrom detectron2.data import MetadataCatalog\nfrom pycocotools import mask as coco_mask\nfrom utils.prompt_engineering import prompt_engineering, get_prompt_templates\nfrom llava.model.openseed.utils import configurable\n# from ..shapes.sampler import build_shape_sampler\n\n__all__ = [\"COCOPanopticInteractiveDatasetMapper\"]\n\ndef filter_empty_instances_by_box(\n        instances, by_box=True, by_mask=False, box_threshold=1e-5, return_mask=False\n):\n    assert by_box or by_mask\n    r = []\n    if by_box:\n        r.append(instances.gt_boxes.nonempty(threshold=box_threshold))\n    if instances.has(\"gt_masks\") and by_mask:\n        r.append(instances.gt_masks.nonempty())\n\n    # TODO: can also filter visible keypoints\n\n    if not r:\n        return instances\n    m = r[0]\n    for x in r[1:]:\n        m = m & x\n    if return_mask:\n        return instances[m], m\n    return instances[m]\n\ndef build_transform_gen(cfg, is_train):\n    \"\"\"\n    Create a list of default :class:`Augmentation` from config.\n    Now it includes resizing and flipping.\n    Returns:\n        list[Augmentation]\n    \"\"\"\n    # assert is_train, \"Only support training augmentation\"\n    cfg_input = cfg['INPUT']\n    image_size = cfg_input['IMAGE_SIZE']\n    min_scale = cfg_input['MIN_SCALE']\n    max_scale = cfg_input['MAX_SCALE']\n\n    augmentation = []\n\n  
  # if cfg_input['RANDOM_FLIP'] != \"none\":\n    #     augmentation.append(\n    #         T.RandomFlip(\n    #             horizontal=cfg_input['RANDOM_FLIP'] == \"horizontal\",\n    #             vertical=cfg_input['RANDOM_FLIP'] == \"vertical\",\n    #         )\n    #     )\n\n    augmentation.extend([\n        T.ResizeScale(\n            min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size\n        ),\n        T.FixedSizeCrop(crop_size=(image_size, image_size)),\n    ])\n\n    return augmentation\n\n\ndef convert_coco_poly_to_mask(segmentations, height, width):\n    masks = []\n    for polygons in segmentations:\n        rles = coco_mask.frPyObjects(polygons, height, width)\n        mask = coco_mask.decode(rles)\n        if len(mask.shape) < 3:\n            mask = mask[..., None]\n        mask = torch.as_tensor(mask, dtype=torch.uint8)\n        mask = mask.any(dim=2)\n        masks.append(mask)\n    if masks:\n        masks = torch.stack(masks, dim=0)\n    else:\n        masks = torch.zeros((0, height, width), dtype=torch.uint8)\n    return masks\n\n\n# This is specifically designed for the COCO dataset.\nclass COCOPanopticInteractiveDatasetMapper:\n    \"\"\"\n    A callable which takes a dataset dict in Detectron2 Dataset format,\n    and map it into a format used by MaskFormer.\n\n    This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation.\n\n    The callable currently does the following:\n\n    1. Read the image from \"file_name\"\n    2. Applies geometric transforms to the image and annotation\n    3. Find and applies suitable cropping to the image and annotation\n    4. 
Prepare image and annotation to Tensors\n    \"\"\"\n\n    @configurable\n    def __init__(\n            self,\n            is_train=True,\n            *,\n            tfm_gens,\n            image_format,\n            caption_thres,\n            # lvis,\n            # lvis_thres,\n            max_grounding_num,\n            tokenizer,\n            data_args,\n            preprocess,\n            # shape_sampler,\n    ):\n        \"\"\"\n        NOTE: this interface is experimental.\n        Args:\n            is_train: for training or inference\n            augmentations: a list of augmentations or deterministic transforms to apply\n            crop_gen: crop augmentation\n            tfm_gens: data augmentation\n            image_format: an image format supported by :func:`detection_utils.read_image`.\n        \"\"\"\n        self.tfm_gens = tfm_gens\n        logging.getLogger(__name__).info(\n            \"[COCOPanopticNewBaselineDatasetMapper] Full TransformGens used in training: {}\".format(\n                str(self.tfm_gens)\n            )\n        )\n\n        self.img_format = image_format\n        self.is_train = is_train\n        self.caption_thres = caption_thres\n        self.grounding = True\n        # self.lvis = lvis\n        # self.lvis_thres = lvis_thres\n        self.max_grounding_num = max_grounding_num\n        self.caption_similarity = torch.load(MetadataCatalog.get('logistic').get('caption_similarity_pth'))\n        self.tokenizer = tokenizer\n        self.processor = data_args.image_processor\n        self.data_args = data_args\n        self.preprocess = preprocess\n\n        # self.shape_sampler = shape_sampler\n\n    @classmethod\n    def from_config(cls, cfg, is_train=True,tokenizer=None,data_args=None,preprocess=None):\n        # Build augmentation\n        tfm_gens = build_transform_gen(cfg, is_train)\n        # shape_sampler = build_shape_sampler(cfg)\n\n        ret = {\n            \"is_train\": is_train,\n            \"tfm_gens\": 
tfm_gens,\n            \"image_format\": cfg['INPUT']['FORMAT'],\n            \"caption_thres\": cfg['MODEL']['DECODER']['CAPTION']['SIM_THRES'],\n            \"max_grounding_num\": cfg['MODEL']['DECODER']['GROUNDING']['MAX_LEN'],\n            \"tokenizer\": tokenizer,\n            \"data_args\": data_args,\n            \"preprocess\": preprocess,\n            # \"shape_sampler\": shape_sampler,\n        }\n        return ret\n\n    def __call__(self, dataset_dict):\n        \"\"\"\n        Args:\n            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.\n\n        Returns:\n            dict: a format that builtin models in detectron2 accept\n        \"\"\"\n        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below\n        image = utils.read_image(dataset_dict[\"file_name\"], format=self.img_format)\n        utils.check_image_size(dataset_dict, image)\n        #########llava image processing\n\n        if self.data_args.image_aspect_ratio == 'pad':\n            def expand2square(pil_img, background_color):\n                width, height = pil_img.size\n                if width == height:\n                    return pil_img\n                elif width > height:\n                    result = Image.new(pil_img.mode, (width, width), background_color)\n                    result.paste(pil_img, (0, (width - height) // 2))\n                    return result\n                else:\n                    result = Image.new(pil_img.mode, (height, height), background_color)\n                    result.paste(pil_img, ((height - width) // 2, 0))\n                    return result\n\n            image_clip = expand2square(image, tuple(int(x * 255) for x in self.processor.image_mean))\n            image_clip = self.processor.preprocess(image_clip, return_tensors='pt')['pixel_values'][0]\n        else:\n            image_clip = self.processor.preprocess(image, return_tensors='pt')['pixel_values'][0]\n        
dataset_dict[\"image_clip\"] = image_clip\n\n        ##################\n        image, transforms = T.apply_transform_gens(self.tfm_gens, image)\n        image_shape = image.shape[:2]  # h, w\n        dataset_dict[\"image_ori\"]=image\n        # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,\n        # but not efficient on large generic data structures due to the use of pickle & mp.Queue.\n        # Therefore it's important to use torch.Tensor.\n        dataset_dict[\"image\"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))\n\n\n        if \"pan_seg_file_name\" in dataset_dict:\n            pan_seg_gt = utils.read_image(dataset_dict.pop(\"pan_seg_file_name\"), \"RGB\")\n            segments_info = dataset_dict[\"segments_info\"]\n\n            # apply the same transformation to panoptic segmentation\n            pan_seg_gt = transforms.apply_segmentation(pan_seg_gt)\n\n            from panopticapi.utils import rgb2id\n\n            pan_seg_gt = rgb2id(pan_seg_gt)\n\n            instances = Instances(image_shape)\n            classes = []\n            masks = []\n            for segment_info in segments_info:\n                class_id = segment_info[\"category_id\"]\n                if not segment_info[\"iscrowd\"]:\n                    classes.append(class_id)\n                    masks.append(pan_seg_gt == segment_info[\"id\"])\n\n            # is_things = [COCO_CATEGORIES[idx]['isthing'] for idx in classes]\n            classes = np.array(classes)\n            # is_things = np.array(is_things)\n            instances.gt_classes = torch.tensor(classes, dtype=torch.int64)\n            # instances.is_things = torch.tensor(is_things, dtype=torch.int64)\n\n            if len(masks) == 0:\n                # Some image does not have annotation (all ignored)\n                masks = BitMasks(torch.zeros((0, pan_seg_gt.shape[-2], pan_seg_gt.shape[-1])))\n                instances.gt_masks = masks\n                
instances.gt_boxes = Boxes(torch.zeros((0, 4)))\n            else:\n                masks = BitMasks(\n                    torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks])\n                )\n                instances.gt_masks = masks\n                instances.gt_boxes = masks.get_bounding_boxes()\n\n        if self.grounding:\n            grounding_anno = dataset_dict['grounding_info']\n            if self.is_train:\n                grounding_len = random.randint(1, self.max_grounding_num - 1)\n            else:\n                grounding_len = 1\n            if len(grounding_anno) > 0:\n                masks_grd = []\n                texts_grd = []\n                mode = 'text'\n                random.shuffle(grounding_anno)\n                for ann in grounding_anno:\n                    rle = coco_mask.frPyObjects(\n                        ann['segmentation'], dataset_dict['height'], dataset_dict['width'])\n                    m = coco_mask.decode(rle)\n                    # sometimes there are multiple binary map (corresponding to multiple segs)\n                    m = np.sum(m, axis=2)>0\n                    m = m.astype(np.uint8)  # convert to np.uint8\n                    m = transforms.apply_segmentation(m[:, :, None])[:, :, 0]==1\n                    masks_grd += [m]\n                    # random select a sentence of a single annotation.\n                    rand_index = random.randint(0, len(ann['sentences']) - 1)\n                    texts_grd += [ann['sentences'][rand_index]['raw'].lower()]\n                max_len = min(grounding_len, len(texts_grd))\n                indices = np.random.permutation(max_len)\n                texts_grd = list(np.array(texts_grd)[indices])\n                masks_grd = torch.tensor(np.stack(masks_grd)[indices])\n                hash_grd = np.array([hash(txt) for txt in texts_grd])\n                gt_classes = list(range(len(texts_grd)))\n                gt_classes = [[lb] for lb in 
gt_classes]\n                label_set=texts_grd\n            else:\n                assert self.is_train\n                masks_grd = instances.gt_masks.tensor\n                mode = 'class'\n                assert len(masks_grd) > 0\n\n                texts_grd = np.array([COCO_CATEGORIES[idx]['name'] for idx in classes])\n                hash_grd = np.array([hash(txt) for txt in texts_grd])\n                unique_hash_grd = np.unique(hash_grd)\n                np.random.shuffle(unique_hash_grd)\n                max_len = min(grounding_len,len(unique_hash_grd))\n                indices = np.random.permutation(max_len)\n                selected_unique_hash_grd = unique_hash_grd[indices]\n                selected_mask = np.in1d(hash_grd, selected_unique_hash_grd)\n                texts_grd = texts_grd[selected_mask]\n                hash_grd = hash_grd[selected_mask]\n                masks_grd = masks_grd[selected_mask]\n                texts_grd = [\n                    text.replace('-other', '').replace('-merged', '').replace('-stuff', '')\n                    for text in texts_grd]\n                label_set=list(set(texts_grd))\n                gt_classes=[[label_set.index(lb)] for lb in texts_grd]\n\n            instances_gd = Instances(image_shape)\n            instances_gd.gt_masks = BitMasks(masks_grd)\n            instances_gd.gt_boxes = BitMasks(masks_grd).get_bounding_boxes()\n            instances_gd.gt_masks=instances_gd.gt_masks.tensor\n            instances_gd.gt_classes=gt_classes\n            dataset_dict[\"instances\"] = instances_gd\n            conversations=[]\n            for i in range(len(label_set)):\n                if i==0:\n                    question={'from': 'human', 'value': f\"<image>\\n Please detect the object according to the text {label_set[i]} (referring).\"}\n                else:\n                    question={'from': 'human', 'value': f\"Please detect the object according to the text {label_set[i]} (referring).\"}\n        
        answer={'from': 'gpt', 'value': '<seg> .'}\n                conversations.append(question)\n                conversations.append(answer)\n\n            dataset_dict['conversation'] = [conversations]\n\n            data_dict_conversation = self.preprocess(\n                dataset_dict['conversation'],\n                self.tokenizer,\n                has_image=True)\n            data_dict_conversation = dict(input_ids=data_dict_conversation[\"input_ids\"][0],\n                             labels=data_dict_conversation[\"labels\"][0])\n            dataset_dict.update(data_dict_conversation)\n            dataset_dict['tokenizer']=self.tokenizer\n\n\n        return dataset_dict"
  },
  {
    "path": "datasets_os/dataset_mappers/coco_panoptic_new_baseline_dataset_mapper.py",
    "content": "# ------------------------------------------------------------------------\n# Copyright (c) 2022 IDEA. All Rights Reserved.\n# Licensed under the Apache License, Version 2.0 [see LICENSE for details]\n# ------------------------------------------------------------------------\n# Modified from Mask2Former https://github.com/facebookresearch/Mask2Former by Feng Li.\nimport copy\nimport logging\n\nimport numpy as np\nimport torch\n\nfrom detectron2.config import configurable\nfrom detectron2.data import detection_utils as utils\nfrom detectron2.data import transforms as T\nfrom detectron2.data.transforms import TransformGen\nfrom detectron2.structures import BitMasks, Boxes, Instances\n\n__all__ = [\"COCOPanopticNewBaselineDatasetMapper\"]\n\n\ndef build_transform_gen(cfg, is_train):\n    \"\"\"\n    Create a list of default :class:`Augmentation` from config.\n    Now it includes resizing and flipping.\n    Returns:\n        list[Augmentation]\n    \"\"\"\n    # assert is_train, \"Only support training augmentation\"\n    image_size = cfg.INPUT.IMAGE_SIZE\n    min_scale = cfg.INPUT.MIN_SCALE\n    max_scale = cfg.INPUT.MAX_SCALE\n\n    augmentation = []\n\n    # if cfg.INPUT.RANDOM_FLIP != \"none\":\n    #     augmentation.append(\n    #         T.RandomFlip(\n    #             horizontal=cfg.INPUT.RANDOM_FLIP == \"horizontal\",\n    #             vertical=cfg.INPUT.RANDOM_FLIP == \"vertical\",\n    #         )\n    #     )\n\n    augmentation.extend([\n        T.ResizeScale(\n            min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size\n        ),\n        T.FixedSizeCrop(crop_size=(image_size, image_size)),\n    ])\n\n    return augmentation\n\n\n# This is specifically designed for the COCO dataset.\nclass COCOPanopticNewBaselineDatasetMapper:\n    \"\"\"\n    A callable which takes a dataset dict in Detectron2 Dataset format,\n    and map it into a format used by MaskFormer.\n\n    This dataset mapper applies 
the same transformation as DETR for COCO panoptic segmentation.\n\n    The callable currently does the following:\n\n    1. Read the image from \"file_name\"\n    2. Applies geometric transforms to the image and annotation\n    3. Find and applies suitable cropping to the image and annotation\n    4. Prepare image and annotation to Tensors\n    \"\"\"\n\n    @configurable\n    def __init__(\n        self,\n        is_train=True,\n        *,\n        tfm_gens,\n        image_format,\n    ):\n        \"\"\"\n        NOTE: this interface is experimental.\n        Args:\n            is_train: for training or inference\n            augmentations: a list of augmentations or deterministic transforms to apply\n            crop_gen: crop augmentation\n            tfm_gens: data augmentation\n            image_format: an image format supported by :func:`detection_utils.read_image`.\n        \"\"\"\n        self.tfm_gens = tfm_gens\n        logging.getLogger(__name__).info(\n            \"[COCOPanopticNewBaselineDatasetMapper] Full TransformGens used in training: {}\".format(\n                str(self.tfm_gens)\n            )\n        )\n\n        self.img_format = image_format\n        self.is_train = is_train\n\n    @classmethod\n    def from_config(cls, cfg, is_train=True):\n        # Build augmentation\n        tfm_gens = build_transform_gen(cfg, is_train)\n\n        ret = {\n            \"is_train\": is_train,\n            \"tfm_gens\": tfm_gens,\n            \"image_format\": cfg.INPUT.FORMAT,\n        }\n        return ret\n\n    def __call__(self, dataset_dict):\n        \"\"\"\n        Args:\n            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.\n\n        Returns:\n            dict: a format that builtin models in detectron2 accept\n        \"\"\"\n        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below\n        image = utils.read_image(dataset_dict[\"file_name\"], format=self.img_format)\n        
utils.check_image_size(dataset_dict, image)\n\n        image, transforms = T.apply_transform_gens(self.tfm_gens, image)\n        image_shape = image.shape[:2]  # h, w\n\n        # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,\n        # but not efficient on large generic data structures due to the use of pickle & mp.Queue.\n        # Therefore it's important to use torch.Tensor.\n        dataset_dict[\"image\"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))\n\n        # if not self.is_train:\n        #     # USER: Modify this if you want to keep them for some reason.\n        #     dataset_dict.pop(\"annotations\", None)\n        #     return dataset_dict\n\n        if \"pan_seg_file_name\" in dataset_dict:\n            pan_seg_gt = utils.read_image(dataset_dict.pop(\"pan_seg_file_name\"), \"RGB\")\n            segments_info = dataset_dict[\"segments_info\"]\n\n            # apply the same transformation to panoptic segmentation\n            pan_seg_gt = transforms.apply_segmentation(pan_seg_gt)\n\n            from panopticapi.utils import rgb2id\n\n            pan_seg_gt = rgb2id(pan_seg_gt)\n\n            instances = Instances(image_shape)\n            classes = []\n            masks = []\n            for segment_info in segments_info:\n                class_id = segment_info[\"category_id\"]\n                if not segment_info[\"iscrowd\"]:\n                    classes.append(class_id)\n                    masks.append(pan_seg_gt == segment_info[\"id\"])\n\n            classes = np.array(classes)\n            instances.gt_classes = torch.tensor(classes, dtype=torch.int64)\n            if len(masks) == 0:\n                # Some image does not have annotation (all ignored)\n                instances.gt_masks = torch.zeros((0, pan_seg_gt.shape[-2], pan_seg_gt.shape[-1]))\n                instances.gt_boxes = Boxes(torch.zeros((0, 4)))\n            else:\n                masks = BitMasks(\n                    
torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks])\n                )\n                instances.gt_masks = masks.tensor\n                instances.gt_boxes = masks.get_bounding_boxes()\n\n            dataset_dict[\"instances\"] = instances\n\n        return dataset_dict\n"
  },
  {
    "path": "datasets_os/dataset_mappers/flickr_instance_new_baseline_dataset_mapper.py",
    "content": "# Copyright (c) Facebook, Inc. and its affiliates.\n# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/d2/detr/dataset_mapper.py\nimport copy\nimport logging\n\nimport numpy as np\nimport torch\nimport PIL.Image as Image\nfrom detectron2.data import detection_utils as utils\nfrom detectron2.data import transforms as T\nfrom detectron2.data.transforms import TransformGen\nfrom detectron2.structures import BitMasks, Instances\n\nfrom pycocotools import mask as coco_mask\n\nfrom llava.model.openseed.utils import configurable\nfrom detectron2.structures import Boxes, BoxMode, PolygonMasks, RotatedBoxes\n# from llava.train.train_hao_seg_flickr import ,preprocess\n__all__ = [\"COCOInstanceNewBaselineDatasetMapper\"]\n\n\ndef convert_coco_poly_to_mask(segmentations, height, width):\n    masks = []\n    for polygons in segmentations:\n        rles = coco_mask.frPyObjects(polygons, height, width)\n        mask = coco_mask.decode(rles)\n        if len(mask.shape) < 3:\n            mask = mask[..., None]\n        mask = torch.as_tensor(mask, dtype=torch.uint8)\n        mask = mask.any(dim=2)\n        masks.append(mask)\n    if masks:\n        masks = torch.stack(masks, dim=0)\n    else:\n        masks = torch.zeros((0, height, width), dtype=torch.uint8)\n    return masks\n\n\ndef build_transform_gen(cfg, is_train):\n    \"\"\"\n    Create a list of default :class:`Augmentation` from config.\n    Now it includes resizing and flipping.\n    Returns:\n        list[Augmentation]\n    \"\"\"\n    if is_train:\n        cfg_input = cfg['INPUT']\n        image_size = cfg_input['IMAGE_SIZE']\n        min_scale = cfg_input['MIN_SCALE']\n        max_scale = cfg_input['MAX_SCALE']\n\n        augmentation = []\n\n        # if cfg_input['RANDOM_FLIP'] != \"none\":\n        #     augmentation.append(\n        #         T.RandomFlip(\n        #             horizontal=cfg_input['RANDOM_FLIP'] == \"horizontal\",\n        #             
vertical=cfg_input['RANDOM_FLIP'] == \"vertical\",\n        #         )\n        #     )\n\n        augmentation.extend([\n            T.ResizeScale(\n                min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size\n            ),\n            T.FixedSizeCrop(crop_size=(image_size, image_size)),\n        ])\n    else:\n        cfg_input = cfg['INPUT']\n        image_size = cfg_input['IMAGE_SIZE']\n        min_scale = cfg_input['MIN_SCALE']\n        max_scale = cfg_input['MAX_SCALE']\n\n        augmentation = []\n\n        # if cfg_input['RANDOM_FLIP'] != \"none\":\n        #     augmentation.append(\n        #         T.RandomFlip(\n        #             horizontal=cfg_input['RANDOM_FLIP'] == \"horizontal\",\n        #             vertical=cfg_input['RANDOM_FLIP'] == \"vertical\",\n        #         )\n        #     )\n\n        augmentation.extend([\n            T.ResizeScale(\n                min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size\n            ),\n            T.FixedSizeCrop(crop_size=(image_size, image_size)),\n        ])\n\n    return augmentation\n\n\n# This is specifically designed for the COCO dataset.\nclass COCOInstanceNewBaselineDatasetMapper:\n    \"\"\"\n    A callable which takes a dataset dict in Detectron2 Dataset format,\n    and map it into a format used by MaskFormer.\n\n    This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation.\n\n    The callable currently does the following:\n\n    1. Read the image from \"file_name\"\n    2. Applies geometric transforms to the image and annotation\n    3. Find and applies suitable cropping to the image and annotation\n    4. 
Prepare image and annotation to Tensors\n    \"\"\"\n\n    @configurable\n    def __init__(\n        self,\n        is_train=True,\n        *,\n        tfm_gens,\n        image_format,\n        tokenizer,\n        data_args,\n        preprocess,\n        gd_mode=\"inter\",\n    ):\n        \"\"\"\n        NOTE: this interface is experimental.\n        Args:\n            is_train: for training or inference\n            augmentations: a list of augmentations or deterministic transforms to apply\n            tfm_gens: data augmentation\n            image_format: an image format supported by :func:`detection_utils.read_image`.\n        \"\"\"\n        self.tfm_gens = tfm_gens\n        logging.getLogger(__name__).info(\n            \"[COCOInstanceNewBaselineDatasetMapper] Full TransformGens used in training: {}\".format(str(self.tfm_gens))\n        )\n\n        self.img_format = image_format\n        self.is_train = is_train\n        self.tokenizer = tokenizer\n        self.processor = data_args.image_processor\n        self.data_args = data_args\n        self.preprocess = preprocess\n        self.gd_mode= gd_mode\n    \n    @classmethod\n    def from_config(cls, cfg, is_train=True,tokenizer=None,data_args=None,preprocess=None):\n        # Build augmentation\n        tfm_gens = build_transform_gen(cfg, is_train)\n\n        ret = {\n            \"is_train\": is_train,\n            \"tfm_gens\": tfm_gens,\n            \"image_format\": cfg['INPUT']['FORMAT'],\n            \"tokenizer\": tokenizer,\n            \"data_args\": data_args,\n            \"preprocess\": preprocess,\n            \"gd_mode\": cfg.flickr.get(\"gd_mode\", \"inter\"),\n        }\n        return ret\n\n    def __call__(self, dataset_dict):\n        \"\"\"\n        Args:\n            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.\n\n        Returns:\n            dict: a format that builtin models in detectron2 accept\n        \"\"\"\n        dataset_dict = 
copy.deepcopy(dataset_dict)  # it will be modified by code below\n        image = utils.read_image(dataset_dict[\"file_name\"], format=self.img_format)\n        utils.check_image_size(dataset_dict, image)\n\n        #########llava image processing\n\n        if self.data_args.image_aspect_ratio == 'pad':\n            def expand2square(pil_img, background_color):\n                width, height = pil_img.size\n                if width == height:\n                    return pil_img\n                elif width > height:\n                    result = Image.new(pil_img.mode, (width, width), background_color)\n                    result.paste(pil_img, (0, (width - height) // 2))\n                    return result\n                else:\n                    result = Image.new(pil_img.mode, (height, height), background_color)\n                    result.paste(pil_img, ((height - width) // 2, 0))\n                    return result\n\n            image_clip = expand2square(image, tuple(int(x * 255) for x in self.processor.image_mean))\n            image_clip = self.processor.preprocess(image_clip, return_tensors='pt')['pixel_values'][0]\n        else:\n            image_clip = self.processor.preprocess(image, return_tensors='pt')['pixel_values'][0]\n        dataset_dict[\"image_clip\"] = image_clip\n\n        ##################\n\n        # TODO: get padding mask\n        # by feeding a \"segmentation mask\" to the same transforms\n        padding_mask = np.ones(image.shape[:2])\n\n        image, transforms = T.apply_transform_gens(self.tfm_gens, image)\n        dataset_dict[\"image_ori\"]=image\n        # the crop transformation has default padding value 0 for segmentation\n        padding_mask = transforms.apply_segmentation(padding_mask)\n        padding_mask = ~ padding_mask.astype(bool)\n\n        image_shape = image.shape[:2]  # h, w\n\n        # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,\n        # but not efficient on large generic data 
structures due to the use of pickle & mp.Queue.\n        # Therefore it's important to use torch.Tensor.\n        dataset_dict[\"image\"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))\n        dataset_dict[\"padding_mask\"] = torch.as_tensor(np.ascontiguousarray(padding_mask))\n\n        # if not self.is_train:\n        #     # USER: Modify this if you want to keep them for some reason.\n        #     dataset_dict.pop(\"annotations\", None)\n        #     return dataset_dict\n\n        if \"grounding_info\" in dataset_dict:\n\n            for obj in dataset_dict['grounding_info']:\n                obj[\"bbox_mode\"] = BoxMode.XYWH_ABS\n                obj['tokens']=dataset_dict['caption'][obj['tokens_positive'][0][0]:obj['tokens_positive'][0][1]]\n            # USER: Implement additional transformations if you have other types of data\n            annos = [\n                utils.transform_instance_annotations(obj, transforms, image_shape)\n                for obj in dataset_dict[\"grounding_info\"]\n            ]\n            # NOTE: does not support BitMask due to augmentation\n            # Current BitMask cannot handle empty objects\n            assert len(annos)>0\n            assert  \"segmentation\" in annos[0]\n            instances = utils.annotations_to_instances(annos, image_shape,mask_format=\"bitmask\")\n            # After transforms such as cropping are applied, the bounding box may no longer\n            # tightly bound the object. As an example, imagine a triangle object\n            # [(0,0), (2,0), (0,2)] cropped by a box [(1,0),(2,2)] (XYXY format). 
The tight\n            # bounding box of the cropped triangle should be [(1,0),(2,1)], which is not equal to\n            # the intersection of original bounding box and the cropping box.\n            # instances.gt_boxes = instances.gt_masks.get_bounding_boxes()\n            # Need to filter empty instances first (due to augmentation)\n            # instances = utils.filter_empty_instances(instances)\n            # Generate masks from polygon\n            h, w = instances.image_size\n            # image_size_xyxy = torch.as_tensor([w, h, w, h], dtype=torch.float)\n            if hasattr(instances, 'gt_masks'):\n                gt_masks = instances.gt_masks\n                # gt_masks = convert_coco_poly_to_mask(gt_masks.polygons, h, w)\n                instances.gt_masks = gt_masks.tensor\n\n            span_set = dict()\n            end_dict = dict()\n            gt_classes= []\n            for i, info in enumerate(dataset_dict['grounding_info']):\n                gt_classes.append([])\n                # if len(info['tokens_positive'])>1:\n                #     print(\"multi class\")\n                for j in range(len(info['tokens_positive'])):\n                    if info['tokens_positive'][j][0] in span_set:\n                        span_set[info['tokens_positive'][j][0]].append(i)\n                    else:\n                        span_set[info['tokens_positive'][j][0]] = [i]\n                    if info['tokens_positive'][j][0] in end_dict:\n                        assert end_dict[info['tokens_positive'][j][0]] == info['tokens_positive'][j][1]\n                    else:\n                        end_dict[info['tokens_positive'][j][0]] = info['tokens_positive'][j][1]\n                    gt_classes[-1].append(info['tokens_positive'][j][0])\n\n            end_dict = sorted(end_dict.items())\n            start2id = dict()\n            for i, (s, e) in enumerate(end_dict):\n                start2id[s] = i\n            gt_classes= [[start2id[s] for s in gt_class] 
for gt_class in gt_classes]\n            instances.gt_classes = gt_classes\n            dataset_dict[\"instances\"] = instances\n            # span_list = sorted(span_set.items())\n\n            # for k, v in span_set:\n            #     for i in range(len(v)):\n            #         v[i] = positive_new_ids[v[i]]\n            cap_pieces = []\n            last_e = 0\n            for s, e in end_dict:\n                cap_pieces.append(dataset_dict['caption'][last_e:s])\n                cap_pieces.append(dataset_dict['caption'][s:e])\n                last_e = e\n            cap_pieces.append(dataset_dict['caption'][last_e:])\n            new_cap = []\n            if 'end' in self.gd_mode:\n                k=1\n                for i, piece in enumerate(cap_pieces):\n                    if i % 2 == 1:\n                        if self.gd_mode == 'end':\n                            piece = '<g_s>' + piece + '<g_e>'\n                        else:\n                            assert self.gd_mode == 'end_num'\n                            piece = f'<g_s> {k} ' + piece + '<g_e>'\n                            k+=1\n                    new_cap.append(piece)\n                new_cap = \"\".join(new_cap)\n                tail = [f'{i + 1}: <seg>' for i in range(new_cap.count(\"<g_s>\"))]\n                tail = '; '.join(tail)\n                new_cap += f' {tail}.'\n            else:\n                for i, piece in enumerate(cap_pieces):\n                    if i % 2 == 1:\n                        piece = '<g_s>' + piece + '<g_e><seg>'\n                    new_cap.append(piece)\n                new_cap = \"\".join(new_cap)\n            # gt_ids = []\n            # for s, e in end_dict:\n            #     if len(span_set[s]) > 1:\n            #         return dataset_dict\n            #     gt_ids.append(span_set[s][0] + 1)\n            # ground_annos = dict()\n            # ground_annos['gt_ids'] = gt_ids\n            # ground_annos['gt_anno_ids'] = 
[dataset_dict['grounding_info'][gt_id_ - 1]['id'] for gt_id_ in gt_ids]\n            # ground_annos['caption'] = new_cap\n            question={'from': 'human', 'value': \"<image>\\nPresent a compact description of the photo's key features.\\n(with grounding)\"}\n            answer={'from': 'gpt', 'value': new_cap}\n            dataset_dict['conversation'] = [[question, answer]]\n            # sources = preprocess_multimodal(\n            #     copy.deepcopy(dataset_dict['conversation']),\n            #     self.data_args)\n            data_dict_conversation = self.preprocess(\n                dataset_dict['conversation'],\n                self.tokenizer,\n                has_image=True)\n            data_dict_conversation = dict(input_ids=data_dict_conversation[\"input_ids\"][0],\n                             labels=data_dict_conversation[\"labels\"][0])\n            dataset_dict.update(data_dict_conversation)\n            dataset_dict['tokenizer']=self.tokenizer\n\n        return dataset_dict\n"
  },
  {
    "path": "datasets_os/dataset_mappers/flickr_instance_new_baseline_dataset_mapper_.py",
    "content": "# Copyright (c) Facebook, Inc. and its affiliates.\n# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/d2/detr/dataset_mapper.py\nimport copy\nimport logging\n\nimport numpy as np\nimport torch\nimport PIL.Image as Image\nfrom detectron2.data import detection_utils as utils\nfrom detectron2.data import transforms as T\nfrom detectron2.data.transforms import TransformGen\nfrom detectron2.structures import BitMasks, Instances\n\nfrom pycocotools import mask as coco_mask\n\nfrom llava.model.openseed.utils import configurable\nfrom detectron2.structures import Boxes, BoxMode, PolygonMasks, RotatedBoxes\nfrom llava import conversation as conversation_lib\nfrom llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN\n\n# from llava.train.train_hao_seg_flickr import ,preprocess\n__all__ = [\"COCOInstanceNewBaselineDatasetMapper\"]\n\n\ndef convert_coco_poly_to_mask(segmentations, height, width):\n    masks = []\n    for polygons in segmentations:\n        rles = coco_mask.frPyObjects(polygons, height, width)\n        mask = coco_mask.decode(rles)\n        if len(mask.shape) < 3:\n            mask = mask[..., None]\n        mask = torch.as_tensor(mask, dtype=torch.uint8)\n        mask = mask.any(dim=2)\n        masks.append(mask)\n    if masks:\n        masks = torch.stack(masks, dim=0)\n    else:\n        masks = torch.zeros((0, height, width), dtype=torch.uint8)\n    return masks\n\ndef preprocess_multimodal(\n    sources,\n    data_args\n):\n    is_multimodal = data_args.is_multimodal\n    if not is_multimodal:\n        return sources\n\n    for source in sources:\n        for sentence in source:\n            if DEFAULT_IMAGE_TOKEN in sentence['value']:\n                sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip()\n                sentence['value'] = DEFAULT_IMAGE_TOKEN + '\\n' + sentence['value']\n                
sentence['value'] = sentence['value'].strip()\n                if \"mmtag\" in conversation_lib.default_conversation.version:\n                    sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '<Image>' + DEFAULT_IMAGE_TOKEN + '</Image>')\n            replace_token = DEFAULT_IMAGE_TOKEN\n            if data_args.mm_use_im_start_end:\n                replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN\n            sentence[\"value\"] = sentence[\"value\"].replace(DEFAULT_IMAGE_TOKEN, replace_token)\n\n    return sources\n\ndef build_transform_gen(cfg, is_train):\n    \"\"\"\n    Create a list of default :class:`Augmentation` from config.\n    Now it includes resizing and flipping.\n    Returns:\n        list[Augmentation]\n    \"\"\"\n    if is_train:\n        cfg_input = cfg['INPUT']\n        image_size = cfg_input['IMAGE_SIZE']\n        min_scale = cfg_input['MIN_SCALE']\n        max_scale = cfg_input['MAX_SCALE']\n\n        augmentation = []\n\n        # if cfg_input['RANDOM_FLIP'] != \"none\":\n        #     augmentation.append(\n        #         T.RandomFlip(\n        #             horizontal=cfg_input['RANDOM_FLIP'] == \"horizontal\",\n        #             vertical=cfg_input['RANDOM_FLIP'] == \"vertical\",\n        #         )\n        #     )\n\n        augmentation.extend([\n            T.ResizeScale(\n                min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size\n            ),\n            T.FixedSizeCrop(crop_size=(image_size, image_size)),\n        ])\n    else:\n        cfg_input = cfg['INPUT']\n        image_size = cfg_input['IMAGE_SIZE']\n        min_scale = cfg_input['MIN_SCALE']\n        max_scale = cfg_input['MAX_SCALE']\n\n        augmentation = []\n\n        # if cfg_input['RANDOM_FLIP'] != \"none\":\n        #     augmentation.append(\n        #         T.RandomFlip(\n        #             horizontal=cfg_input['RANDOM_FLIP'] == \"horizontal\",\n        
#             vertical=cfg_input['RANDOM_FLIP'] == \"vertical\",\n        #         )\n        #     )\n\n        augmentation.extend([\n            T.ResizeScale(\n                min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size\n            ),\n            T.FixedSizeCrop(crop_size=(image_size, image_size)),\n        ])\n\n    return augmentation\n\n\n# This is specifically designed for the COCO dataset.\nclass COCOInstanceNewBaselineDatasetMapper:\n    \"\"\"\n    A callable which takes a dataset dict in Detectron2 Dataset format,\n    and map it into a format used by MaskFormer.\n\n    This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation.\n\n    The callable currently does the following:\n\n    1. Read the image from \"file_name\"\n    2. Applies geometric transforms to the image and annotation\n    3. Find and applies suitable cropping to the image and annotation\n    4. Prepare image and annotation to Tensors\n    \"\"\"\n\n    @configurable\n    def __init__(\n        self,\n        is_train=True,\n        *,\n        tfm_gens,\n        image_format,\n        tokenizer,\n        data_args,\n        preprocess,\n    ):\n        \"\"\"\n        NOTE: this interface is experimental.\n        Args:\n            is_train: for training or inference\n            augmentations: a list of augmentations or deterministic transforms to apply\n            tfm_gens: data augmentation\n            image_format: an image format supported by :func:`detection_utils.read_image`.\n        \"\"\"\n        self.tfm_gens = tfm_gens\n        logging.getLogger(__name__).info(\n            \"[COCOInstanceNewBaselineDatasetMapper] Full TransformGens used in training: {}\".format(str(self.tfm_gens))\n        )\n\n        self.img_format = image_format\n        self.is_train = is_train\n        self.tokenizer = tokenizer\n        self.processor = data_args.image_processor\n        self.data_args = 
data_args\n        self.preprocess = preprocess\n    \n    @classmethod\n    def from_config(cls, cfg, is_train=True,tokenizer=None,data_args=None,preprocess=None):\n        # Build augmentation\n        tfm_gens = build_transform_gen(cfg, is_train)\n\n        ret = {\n            \"is_train\": is_train,\n            \"tfm_gens\": tfm_gens,\n            \"image_format\": cfg['INPUT']['FORMAT'],\n            \"tokenizer\": tokenizer,\n            \"data_args\": data_args,\n            \"preprocess\": preprocess,\n        }\n        return ret\n\n    def __call__(self, dataset_dict):\n        \"\"\"\n        Args:\n            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.\n\n        Returns:\n            dict: a format that builtin models in detectron2 accept\n        \"\"\"\n        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below\n        image = utils.read_image(dataset_dict[\"file_name\"], format=self.img_format)\n        utils.check_image_size(dataset_dict, image)\n\n        #########llava image processing\n\n        if self.data_args.image_aspect_ratio == 'pad':\n            def expand2square(pil_img, background_color):\n                width, height = pil_img.size\n                if width == height:\n                    return pil_img\n                elif width > height:\n                    result = Image.new(pil_img.mode, (width, width), background_color)\n                    result.paste(pil_img, (0, (width - height) // 2))\n                    return result\n                else:\n                    result = Image.new(pil_img.mode, (height, height), background_color)\n                    result.paste(pil_img, ((height - width) // 2, 0))\n                    return result\n\n            image_clip = expand2square(image, tuple(int(x * 255) for x in self.processor.image_mean))\n            image_clip = self.processor.preprocess(image_clip, return_tensors='pt')['pixel_values'][0]\n        else:\n    
        image_clip = self.processor.preprocess(image, return_tensors='pt')['pixel_values'][0]\n        dataset_dict[\"image_clip\"] = image_clip\n\n        ##################\n\n        # TODO: get padding mask\n        # by feeding a \"segmentation mask\" to the same transforms\n        padding_mask = np.ones(image.shape[:2])\n\n        image, transforms = T.apply_transform_gens(self.tfm_gens, image)\n        dataset_dict[\"image_ori\"]=image\n        # the crop transformation has default padding value 0 for segmentation\n        padding_mask = transforms.apply_segmentation(padding_mask)\n        padding_mask = ~ padding_mask.astype(bool)\n\n        image_shape = image.shape[:2]  # h, w\n\n        dataset_dict[\"image\"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))\n        dataset_dict[\"padding_mask\"] = torch.as_tensor(np.ascontiguousarray(padding_mask))\n\n        if \"grounding_info\" in dataset_dict:\n            anno_id2id=dict()\n            for id,obj in enumerate(dataset_dict['grounding_info']):\n                obj[\"bbox_mode\"] = BoxMode.XYWH_ABS\n                anno_id2id[obj['id']]=id\n            id2class=[[] for _ in range(len(dataset_dict['grounding_info']))]\n\n            annos = [\n                utils.transform_instance_annotations(obj, transforms, image_shape)\n                for obj in dataset_dict[\"grounding_info\"]\n            ]\n            # NOTE: does not support BitMask due to augmentation\n            # Current BitMask cannot handle empty objects\n            assert len(annos)>0\n            assert  \"segmentation\" in annos[0]\n            instances = utils.annotations_to_instances(annos, image_shape,mask_format=\"bitmask\")\n\n            h, w = instances.image_size\n            # image_size_xyxy = torch.as_tensor([w, h, w, h], dtype=torch.float)\n            if hasattr(instances, 'gt_masks'):\n                gt_masks = instances.gt_masks\n                # gt_masks = 
convert_coco_poly_to_mask(gt_masks.polygons, h, w)\n                instances.gt_masks = gt_masks.tensor\n\n            num_conversations = len(dataset_dict['conversations'])\n            rd = np.random.choice(num_conversations)\n            selected_conversation, grounding_list=dataset_dict['conversations'][rd]\n\n            if grounding_list is None:\n                dataset_dict['grounding']=False\n            else:\n                non_none=[1 for g in grounding_list if g is not None]\n                if len(non_none)==0:\n                    dataset_dict['grounding']=False\n                else:\n                    dataset_dict['grounding']=True\n            if dataset_dict['grounding']:\n                num_segs = sum([conv['value'].count('<seg>') for conv in selected_conversation])\n                assert num_segs == len(grounding_list)\n                for grounding_id,grounding in enumerate(grounding_list):\n                    if grounding is not None:\n                        for annid in grounding:\n                            id2class[anno_id2id[annid]].append(grounding_id)\n\n                instances.gt_classes=id2class\n            dataset_dict[\"instances\"] = instances\n\n            dataset_dict['conversation'] = [selected_conversation]\n            sources = preprocess_multimodal(\n                copy.deepcopy(dataset_dict['conversation']),\n                self.data_args)\n            data_dict_conversation = self.preprocess(\n                sources,\n                self.tokenizer,\n                has_image=True)\n            data_dict_conversation = dict(input_ids=data_dict_conversation[\"input_ids\"][0],\n                             labels=data_dict_conversation[\"labels\"][0])\n            dataset_dict.update(data_dict_conversation)\n            dataset_dict['tokenizer']=self.tokenizer\n\n        return dataset_dict\n"
  },
  {
    "path": "datasets_os/dataset_mappers/flickr_instance_new_baseline_dataset_mapper_end.py",
    "content": "# Copyright (c) Facebook, Inc. and its affiliates.\n# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/d2/detr/dataset_mapper.py\nimport copy\nimport logging\n\nimport numpy as np\nimport torch\nimport PIL.Image as Image\nfrom detectron2.data import detection_utils as utils\nfrom detectron2.data import transforms as T\nfrom detectron2.data.transforms import TransformGen\nfrom detectron2.structures import BitMasks, Instances\n\nfrom pycocotools import mask as coco_mask\n\nfrom llava.model.openseed.utils import configurable\nfrom detectron2.structures import Boxes, BoxMode, PolygonMasks, RotatedBoxes\n# from llava.train.train_hao_seg_flickr import ,preprocess\n__all__ = [\"COCOInstanceNewBaselineDatasetMapper\"]\n\n\ndef convert_coco_poly_to_mask(segmentations, height, width):\n    masks = []\n    for polygons in segmentations:\n        rles = coco_mask.frPyObjects(polygons, height, width)\n        mask = coco_mask.decode(rles)\n        if len(mask.shape) < 3:\n            mask = mask[..., None]\n        mask = torch.as_tensor(mask, dtype=torch.uint8)\n        mask = mask.any(dim=2)\n        masks.append(mask)\n    if masks:\n        masks = torch.stack(masks, dim=0)\n    else:\n        masks = torch.zeros((0, height, width), dtype=torch.uint8)\n    return masks\n\n\ndef build_transform_gen(cfg, is_train):\n    \"\"\"\n    Create a list of default :class:`Augmentation` from config.\n    Now it includes resizing and flipping.\n    Returns:\n        list[Augmentation]\n    \"\"\"\n    if is_train:\n        cfg_input = cfg['INPUT']\n        image_size = cfg_input['IMAGE_SIZE']\n        min_scale = cfg_input['MIN_SCALE']\n        max_scale = cfg_input['MAX_SCALE']\n\n        augmentation = []\n\n        # if cfg_input['RANDOM_FLIP'] != \"none\":\n        #     augmentation.append(\n        #         T.RandomFlip(\n        #             horizontal=cfg_input['RANDOM_FLIP'] == \"horizontal\",\n        #             
vertical=cfg_input['RANDOM_FLIP'] == \"vertical\",\n        #         )\n        #     )\n\n        augmentation.extend([\n            T.ResizeScale(\n                min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size\n            ),\n            T.FixedSizeCrop(crop_size=(image_size, image_size)),\n        ])\n    else:\n        cfg_input = cfg['INPUT']\n        image_size = cfg_input['IMAGE_SIZE']\n        min_scale = cfg_input['MIN_SCALE']\n        max_scale = cfg_input['MAX_SCALE']\n\n        augmentation = []\n\n        # if cfg_input['RANDOM_FLIP'] != \"none\":\n        #     augmentation.append(\n        #         T.RandomFlip(\n        #             horizontal=cfg_input['RANDOM_FLIP'] == \"horizontal\",\n        #             vertical=cfg_input['RANDOM_FLIP'] == \"vertical\",\n        #         )\n        #     )\n\n        augmentation.extend([\n            T.ResizeScale(\n                min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size\n            ),\n            T.FixedSizeCrop(crop_size=(image_size, image_size)),\n        ])\n\n    return augmentation\n\n\n# This is specifically designed for the COCO dataset.\nclass COCOInstanceNewBaselineDatasetMapper:\n    \"\"\"\n    A callable which takes a dataset dict in Detectron2 Dataset format,\n    and map it into a format used by MaskFormer.\n\n    This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation.\n\n    The callable currently does the following:\n\n    1. Read the image from \"file_name\"\n    2. Applies geometric transforms to the image and annotation\n    3. Find and applies suitable cropping to the image and annotation\n    4. 
Prepare image and annotation to Tensors\n    \"\"\"\n\n    @configurable\n    def __init__(\n        self,\n        is_train=True,\n        *,\n        tfm_gens,\n        image_format,\n        tokenizer,\n        data_args,\n        preprocess,\n    ):\n        \"\"\"\n        NOTE: this interface is experimental.\n        Args:\n            is_train: for training or inference\n            augmentations: a list of augmentations or deterministic transforms to apply\n            tfm_gens: data augmentation\n            image_format: an image format supported by :func:`detection_utils.read_image`.\n        \"\"\"\n        self.tfm_gens = tfm_gens\n        logging.getLogger(__name__).info(\n            \"[COCOInstanceNewBaselineDatasetMapper] Full TransformGens used in training: {}\".format(str(self.tfm_gens))\n        )\n\n        self.img_format = image_format\n        self.is_train = is_train\n        self.tokenizer = tokenizer\n        self.processor = data_args.image_processor\n        self.data_args = data_args\n        self.preprocess = preprocess\n    \n    @classmethod\n    def from_config(cls, cfg, is_train=True,tokenizer=None,data_args=None,preprocess=None):\n        # Build augmentation\n        tfm_gens = build_transform_gen(cfg, is_train)\n\n        ret = {\n            \"is_train\": is_train,\n            \"tfm_gens\": tfm_gens,\n            \"image_format\": cfg['INPUT']['FORMAT'],\n            \"tokenizer\": tokenizer,\n            \"data_args\": data_args,\n            \"preprocess\": preprocess,\n        }\n        return ret\n\n    def __call__(self, dataset_dict):\n        \"\"\"\n        Args:\n            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.\n\n        Returns:\n            dict: a format that builtin models in detectron2 accept\n        \"\"\"\n        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below\n        image = utils.read_image(dataset_dict[\"file_name\"], 
format=self.img_format)\n        utils.check_image_size(dataset_dict, image)\n\n        #########llava image processing\n\n        if self.data_args.image_aspect_ratio == 'pad':\n            def expand2square(pil_img, background_color):\n                width, height = pil_img.size\n                if width == height:\n                    return pil_img\n                elif width > height:\n                    result = Image.new(pil_img.mode, (width, width), background_color)\n                    result.paste(pil_img, (0, (width - height) // 2))\n                    return result\n                else:\n                    result = Image.new(pil_img.mode, (height, height), background_color)\n                    result.paste(pil_img, ((height - width) // 2, 0))\n                    return result\n\n            image_clip = expand2square(image, tuple(int(x * 255) for x in self.processor.image_mean))\n            image_clip = self.processor.preprocess(image_clip, return_tensors='pt')['pixel_values'][0]\n        else:\n            image_clip = self.processor.preprocess(image, return_tensors='pt')['pixel_values'][0]\n        dataset_dict[\"image_clip\"] = image_clip\n\n        ##################\n\n        # TODO: get padding mask\n        # by feeding a \"segmentation mask\" to the same transforms\n        padding_mask = np.ones(image.shape[:2])\n\n        image, transforms = T.apply_transform_gens(self.tfm_gens, image)\n        dataset_dict[\"image_ori\"]=image\n        # the crop transformation has default padding value 0 for segmentation\n        padding_mask = transforms.apply_segmentation(padding_mask)\n        padding_mask = ~ padding_mask.astype(bool)\n\n        image_shape = image.shape[:2]  # h, w\n\n        # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,\n        # but not efficient on large generic data structures due to the use of pickle & mp.Queue.\n        # Therefore it's important to use torch.Tensor.\n        
dataset_dict[\"image\"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))\n        dataset_dict[\"padding_mask\"] = torch.as_tensor(np.ascontiguousarray(padding_mask))\n\n        # if not self.is_train:\n        #     # USER: Modify this if you want to keep them for some reason.\n        #     dataset_dict.pop(\"annotations\", None)\n        #     return dataset_dict\n\n        if \"grounding_info\" in dataset_dict:\n\n            for obj in dataset_dict['grounding_info']:\n                obj[\"bbox_mode\"] = BoxMode.XYWH_ABS\n                obj['tokens']=dataset_dict['caption'][obj['tokens_positive'][0][0]:obj['tokens_positive'][0][1]]\n            # USER: Implement additional transformations if you have other types of data\n            annos = [\n                utils.transform_instance_annotations(obj, transforms, image_shape)\n                for obj in dataset_dict[\"grounding_info\"]\n            ]\n            # NOTE: does not support BitMask due to augmentation\n            # Current BitMask cannot handle empty objects\n            assert len(annos)>0\n            assert  \"segmentation\" in annos[0]\n            instances = utils.annotations_to_instances(annos, image_shape,mask_format=\"bitmask\")\n            # After transforms such as cropping are applied, the bounding box may no longer\n            # tightly bound the object. As an example, imagine a triangle object\n            # [(0,0), (2,0), (0,2)] cropped by a box [(1,0),(2,2)] (XYXY format). 
The tight\n            # bounding box of the cropped triangle should be [(1,0),(2,1)], which is not equal to\n            # the intersection of original bounding box and the cropping box.\n            # instances.gt_boxes = instances.gt_masks.get_bounding_boxes()\n            # Need to filter empty instances first (due to augmentation)\n            # instances = utils.filter_empty_instances(instances)\n            # Generate masks from polygon\n            h, w = instances.image_size\n            # image_size_xyxy = torch.as_tensor([w, h, w, h], dtype=torch.float)\n            if hasattr(instances, 'gt_masks'):\n                gt_masks = instances.gt_masks\n                # gt_masks = convert_coco_poly_to_mask(gt_masks.polygons, h, w)\n                instances.gt_masks = gt_masks.tensor\n\n            span_set = dict()\n            end_dict = dict()\n            gt_classes= []\n            for i, info in enumerate(dataset_dict['grounding_info']):\n                gt_classes.append([])\n                # if len(info['tokens_positive'])>1:\n                #     print(\"multi class\")\n                for j in range(len(info['tokens_positive'])):\n                    if info['tokens_positive'][j][0] in span_set:\n                        span_set[info['tokens_positive'][j][0]].append(i)\n                    else:\n                        span_set[info['tokens_positive'][j][0]] = [i]\n                    if info['tokens_positive'][j][0] in end_dict:\n                        assert end_dict[info['tokens_positive'][j][0]] == info['tokens_positive'][j][1]\n                    else:\n                        end_dict[info['tokens_positive'][j][0]] = info['tokens_positive'][j][1]\n                    gt_classes[-1].append(info['tokens_positive'][j][0])\n\n            end_dict = sorted(end_dict.items())\n            start2id = dict()\n            for i, (s, e) in enumerate(end_dict):\n                start2id[s] = i\n            gt_classes= [[start2id[s] for s in gt_class] 
for gt_class in gt_classes]\n            instances.gt_classes = gt_classes\n            dataset_dict[\"instances\"] = instances\n            # span_list = sorted(span_set.items())\n\n            # for k, v in span_set:\n            #     for i in range(len(v)):\n            #         v[i] = positive_new_ids[v[i]]\n            cap_pieces = []\n            last_e = 0\n            for s, e in end_dict:\n                cap_pieces.append(dataset_dict['caption'][last_e:s])\n                cap_pieces.append(dataset_dict['caption'][s:e])\n                last_e = e\n            cap_pieces.append(dataset_dict['caption'][last_e:])\n            new_cap = []\n            for i, piece in enumerate(cap_pieces):\n                if i % 2 == 1:\n                    piece = '<g_s>' + piece + '<g_e>'\n                new_cap.append(piece)\n            new_cap = \"\".join(new_cap)\n            tail = [f'{i + 1}: <seg>' for i in range(new_cap.count(\"<g_s>\"))]\n            tail = '; '.join(tail)\n            new_cap += f' {tail}.'\n            # gt_ids = []\n            # for s, e in end_dict:\n            #     if len(span_set[s]) > 1:\n            #         return dataset_dict\n            #     gt_ids.append(span_set[s][0] + 1)\n            # ground_annos = dict()\n            # ground_annos['gt_ids'] = gt_ids\n            # ground_annos['gt_anno_ids'] = [dataset_dict['grounding_info'][gt_id_ - 1]['id'] for gt_id_ in gt_ids]\n            # ground_annos['caption'] = new_cap\n            question={'from': 'human', 'value': \"<image>\\nPresent a compact description of the photo's key features.\\n(with grounding)\"}\n            answer={'from': 'gpt', 'value': new_cap}\n            dataset_dict['conversation'] = [[question, answer]]\n            # sources = preprocess_multimodal(\n            #     copy.deepcopy(dataset_dict['conversation']),\n            #     self.data_args)\n            data_dict_conversation = self.preprocess(\n                dataset_dict['conversation'],\n     
           self.tokenizer,\n                has_image=True)\n            data_dict_conversation = dict(input_ids=data_dict_conversation[\"input_ids\"][0],\n                             labels=data_dict_conversation[\"labels\"][0])\n            dataset_dict.update(data_dict_conversation)\n            dataset_dict['tokenizer']=self.tokenizer\n\n        return dataset_dict\n"
  },
  {
    "path": "datasets_os/dataset_mappers/flickr_new_baseline_dataset_mapper.py",
    "content": "# ------------------------------------------------------------------------\n# Copyright (c) 2022 IDEA. All Rights Reserved.\n# Licensed under the Apache License, Version 2.0 [see LICENSE for details]\n# ------------------------------------------------------------------------\n# Modified from Mask2Former https://github.com/facebookresearch/Mask2Former by Feng Li.\nimport copy\nimport logging\n\nimport numpy as np\nimport torch\n\nfrom detectron2.config import configurable\nfrom detectron2.data import detection_utils as utils\nfrom detectron2.data import transforms as T\nfrom detectron2.data.transforms import TransformGen\nfrom detectron2.structures import BitMasks, Boxes, Instances\n\n__all__ = [\"COCOInteractivePanopticNewBaselineDatasetMapper\"]\n\n\ndef filter_empty_instances_by_box(\n        instances, by_box=True, by_mask=False, box_threshold=1e-5, return_mask=False\n):\n    assert by_box or by_mask\n    r = []\n    if by_box:\n        r.append(instances.gt_boxes.nonempty(threshold=box_threshold))\n    if instances.has(\"gt_masks\") and by_mask:\n        r.append(instances.gt_masks.nonempty())\n\n    # TODO: can also filter visible keypoints\n\n    if not r:\n        return instances\n    m = r[0]\n    for x in r[1:]:\n        m = m & x\n    if return_mask:\n        return instances[m], m\n    return instances[m]\n\ndef build_transform_gen(cfg, is_train):\n    \"\"\"\n    Create a list of default :class:`Augmentation` from config.\n    Now it includes resizing and flipping.\n    Returns:\n        list[Augmentation]\n    \"\"\"\n    assert is_train, \"Only support training augmentation\"\n    image_size = cfg.INPUT.IMAGE_SIZE\n    min_scale = cfg.INPUT.MIN_SCALE\n    max_scale = cfg.INPUT.MAX_SCALE\n\n    augmentation = []\n\n    if cfg.INPUT.RANDOM_FLIP != \"none\":\n        augmentation.append(\n            T.RandomFlip(\n                horizontal=cfg.INPUT.RANDOM_FLIP == \"horizontal\",\n                vertical=cfg.INPUT.RANDOM_FLIP == 
\"vertical\",\n            )\n        )\n\n    augmentation.extend([\n        T.ResizeScale(\n            min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size\n        ),\n        T.FixedSizeCrop(crop_size=(image_size, image_size)),\n    ])\n\n    return augmentation\n\n\n# This is specifically designed for the COCO dataset.\nclass COCOInteractivePanopticNewBaselineDatasetMapper:\n    \"\"\"\n    A callable which takes a dataset dict in Detectron2 Dataset format,\n    and map it into a format used by MaskFormer.\n\n    This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation.\n\n    The callable currently does the following:\n\n    1. Read the image from \"file_name\"\n    2. Applies geometric transforms to the image and annotation\n    3. Find and applies suitable cropping to the image and annotation\n    4. Prepare image and annotation to Tensors\n    \"\"\"\n\n    @configurable\n    def __init__(\n        self,\n        is_train=True,\n        *,\n        tfm_gens,\n        image_format,\n    ):\n        \"\"\"\n        NOTE: this interface is experimental.\n        Args:\n            is_train: for training or inference\n            augmentations: a list of augmentations or deterministic transforms to apply\n            crop_gen: crop augmentation\n            tfm_gens: data augmentation\n            image_format: an image format supported by :func:`detection_utils.read_image`.\n        \"\"\"\n        self.tfm_gens = tfm_gens\n        logging.getLogger(__name__).info(\n            \"[COCOPanopticNewBaselineDatasetMapper] Full TransformGens used in training: {}\".format(\n                str(self.tfm_gens)\n            )\n        )\n\n        self.img_format = image_format\n        self.is_train = is_train\n\n    @classmethod\n    def from_config(cls, cfg, is_train=True):\n        # Build augmentation\n        tfm_gens = build_transform_gen(cfg, is_train)\n\n        ret = {\n            
\"is_train\": is_train,\n            \"tfm_gens\": tfm_gens,\n            \"image_format\": cfg.INPUT.FORMAT,\n        }\n        return ret\n\n    def __call__(self, dataset_dict):\n        \"\"\"\n        Args:\n            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.\n\n        Returns:\n            dict: a format that builtin models in detectron2 accept\n        \"\"\"\n        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below\n        image = utils.read_image(dataset_dict[\"file_name\"], format=self.img_format)\n        utils.check_image_size(dataset_dict, image)\n\n        image, transforms = T.apply_transform_gens(self.tfm_gens, image)\n        image_shape = image.shape[:2]  # h, w\n        dataset_dict[\"image_ori\"]=image\n\n        # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,\n        # but not efficient on large generic data structures due to the use of pickle & mp.Queue.\n        # Therefore it's important to use torch.Tensor.\n        dataset_dict[\"image\"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))\n\n        if not self.is_train:\n            # USER: Modify this if you want to keep them for some reason.\n            dataset_dict.pop(\"annotations\", None)\n            return dataset_dict\n\n        if \"pan_seg_file_name\" in dataset_dict:\n            pan_seg_gt = utils.read_image(dataset_dict.pop(\"pan_seg_file_name\"), \"RGB\")\n            segments_info = dataset_dict[\"segments_info\"]\n\n            # apply the same transformation to panoptic segmentation\n            pan_seg_gt = transforms.apply_segmentation(pan_seg_gt)\n\n            from panopticapi.utils import rgb2id\n\n            pan_seg_gt = rgb2id(pan_seg_gt)\n\n            instances = Instances(image_shape)\n            classes = []\n            masks = []\n            for segment_info in segments_info:\n                class_id = segment_info[\"category_id\"]\n             
   if not segment_info[\"iscrowd\"]:\n                    classes.append(class_id)\n                    masks.append(pan_seg_gt == segment_info[\"id\"])\n\n            classes = np.array(classes)\n            instances.gt_classes = torch.tensor(classes, dtype=torch.int64)\n            if len(masks) == 0:\n                # Some image does not have annotation (all ignored)\n                instances.gt_masks = torch.zeros((0, pan_seg_gt.shape[-2], pan_seg_gt.shape[-1]))\n                instances.gt_boxes = Boxes(torch.zeros((0, 4)))\n            else:\n                masks = BitMasks(\n                    torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks])\n                )\n                instances.gt_masks = masks.tensor\n                instances.gt_boxes = masks.get_bounding_boxes()\n\n            dataset_dict[\"instances\"] = filter_empty_instances_by_box(instances)\n\n        return dataset_dict\n"
  },
  {
    "path": "datasets_os/dataset_mappers/inference_mapper_with_gt.py",
    "content": "# Copyright (c) Facebook, Inc. and its affiliates.\nimport copy\nimport logging\nimport numpy as np\nfrom typing import List, Optional, Union\nimport torch\n\nfrom detectron2.config import configurable\n\nfrom detectron2.data import detection_utils as utils\nfrom detectron2.data import transforms as T\nfrom detectron2.structures import BitMasks, Boxes, Instances\n\n\"\"\"\nThis file contains the default mapping that's applied to \"dataset dicts\".\n\"\"\"\n\n__all__ = [\"CoCoInferenceDatasetMapper\"]\n\n\nclass CoCoInferenceDatasetMapper:\n    \"\"\"\n    A callable which takes a dataset dict in Detectron2 Dataset format,\n    and map it into a format used by the model.\n\n    This is the default callable to be used to map your dataset dict into training data.\n    You may need to follow it to implement your own one for customized logic,\n    such as a different way to read or transform images.\n    See :doc:`/tutorials/data_loading` for details.\n\n    The callable currently does the following:\n\n    1. Read the image from \"file_name\"\n    2. Applies cropping/geometric transforms to the image and annotations\n    3. 
Prepare data and annotations to Tensor and :class:`Instances`\n    \"\"\"\n\n    @configurable\n    def __init__(\n        self,\n        is_train: bool,\n        *,\n        augmentations: List[Union[T.Augmentation, T.Transform]],\n        image_format: str,\n        use_instance_mask: bool = False,\n        use_keypoint: bool = False,\n        instance_mask_format: str = \"polygon\",\n        keypoint_hflip_indices: Optional[np.ndarray] = None,\n        precomputed_proposal_topk: Optional[int] = None,\n        recompute_boxes: bool = False,\n    ):\n        \"\"\"\n        NOTE: this interface is experimental.\n\n        Args:\n            is_train: whether it's used in training or inference\n            augmentations: a list of augmentations or deterministic transforms to apply\n            image_format: an image format supported by :func:`detection_utils.read_image`.\n            use_instance_mask: whether to process instance segmentation annotations, if available\n            use_keypoint: whether to process keypoint annotations if available\n            instance_mask_format: one of \"polygon\" or \"bitmask\". 
Process instance segmentation\n                masks into this format.\n            keypoint_hflip_indices: see :func:`detection_utils.create_keypoint_hflip_indices`\n            precomputed_proposal_topk: if given, will load pre-computed\n                proposals from dataset_dict and keep the top k proposals for each image.\n            recompute_boxes: whether to overwrite bounding box annotations\n                by computing tight bounding boxes from instance mask annotations.\n        \"\"\"\n        if recompute_boxes:\n            assert use_instance_mask, \"recompute_boxes requires instance masks\"\n        # fmt: off\n        self.is_train               = is_train\n        self.augmentations          = T.AugmentationList(augmentations)\n        self.image_format           = image_format\n        self.use_instance_mask      = use_instance_mask\n        self.instance_mask_format   = instance_mask_format\n        self.use_keypoint           = use_keypoint\n        self.keypoint_hflip_indices = keypoint_hflip_indices\n        self.proposal_topk          = precomputed_proposal_topk\n        self.recompute_boxes        = recompute_boxes\n        # fmt: on\n        logger = logging.getLogger(__name__)\n        mode = \"training\" if is_train else \"inference\"\n        logger.info(f\"[DatasetMapper] Augmentations used in {mode}: {augmentations}\")\n\n    @classmethod\n    def from_config(cls, cfg, is_train: bool = True):\n        augs = utils.build_augmentation(cfg, is_train)\n        if cfg.INPUT.CROP.ENABLED and is_train:\n            augs.insert(0, T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE))\n            recompute_boxes = cfg.MODEL.MASK_ON\n        else:\n            recompute_boxes = False\n\n        ret = {\n            \"is_train\": is_train,\n            \"augmentations\": augs,\n            \"image_format\": cfg.INPUT.FORMAT,\n            \"use_instance_mask\": cfg.MODEL.MASK_ON,\n            \"instance_mask_format\": 
cfg.INPUT.MASK_FORMAT,\n            \"use_keypoint\": cfg.MODEL.KEYPOINT_ON,\n            \"recompute_boxes\": recompute_boxes,\n        }\n\n        if cfg.MODEL.KEYPOINT_ON:\n            ret[\"keypoint_hflip_indices\"] = utils.create_keypoint_hflip_indices(cfg.DATASETS.TRAIN)\n\n        if cfg.MODEL.LOAD_PROPOSALS:\n            ret[\"precomputed_proposal_topk\"] = (\n                cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TRAIN\n                if is_train\n                else cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TEST\n            )\n        return ret\n\n    def _transform_annotations(self, dataset_dict, transforms, image_shape):\n        # USER: Modify this if you want to keep them for some reason.\n        for anno in dataset_dict[\"annotations\"]:\n            if not self.use_instance_mask:\n                anno.pop(\"segmentation\", None)\n            if not self.use_keypoint:\n                anno.pop(\"keypoints\", None)\n\n        # USER: Implement additional transformations if you have other types of data\n        annos = [\n            utils.transform_instance_annotations(\n                obj, transforms, image_shape, keypoint_hflip_indices=self.keypoint_hflip_indices\n            )\n            for obj in dataset_dict.pop(\"annotations\")\n            if obj.get(\"iscrowd\", 0) == 0\n        ]\n        instances = utils.annotations_to_instances(\n            annos, image_shape, mask_format=self.instance_mask_format\n        )\n\n        # After transforms such as cropping are applied, the bounding box may no longer\n        # tightly bound the object. As an example, imagine a triangle object\n        # [(0,0), (2,0), (0,2)] cropped by a box [(1,0),(2,2)] (XYXY format). 
The tight\n        # bounding box of the cropped triangle should be [(1,0),(2,1)], which is not equal to\n        # the intersection of original bounding box and the cropping box.\n        if self.recompute_boxes:\n            instances.gt_boxes = instances.gt_masks.get_bounding_boxes()\n        dataset_dict[\"instances\"] = utils.filter_empty_instances(instances)\n\n    def __call__(self, dataset_dict):\n        \"\"\"\n        Args:\n            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.\n\n        Returns:\n            dict: a format that builtin models in detectron2 accept\n        \"\"\"\n        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below\n        # USER: Write your own image loading if it's not from a file\n        image = utils.read_image(dataset_dict[\"file_name\"], format=self.image_format)\n        utils.check_image_size(dataset_dict, image)\n\n        # USER: Remove if you don't do semantic/panoptic segmentation.\n        if \"sem_seg_file_name\" in dataset_dict:\n            sem_seg_gt = utils.read_image(dataset_dict.pop(\"sem_seg_file_name\"), \"L\").squeeze(2)\n        else:\n            sem_seg_gt = None\n\n        aug_input = T.AugInput(image, sem_seg=sem_seg_gt)\n        transforms = self.augmentations(aug_input)\n        image, sem_seg_gt = aug_input.image, aug_input.sem_seg\n\n        image_shape = image.shape[:2]  # h, w\n        # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,\n        # but not efficient on large generic data structures due to the use of pickle & mp.Queue.\n        # Therefore it's important to use torch.Tensor.\n        dataset_dict[\"image\"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))\n        if sem_seg_gt is not None:\n            dataset_dict[\"sem_seg\"] = torch.as_tensor(sem_seg_gt.astype(\"long\"))\n\n        # USER: Remove if you don't use pre-computed proposals.\n        # Most users would not need this 
feature.\n        if self.proposal_topk is not None:\n            utils.transform_proposals(\n                dataset_dict, image_shape, transforms, proposal_topk=self.proposal_topk\n            )\n\n        # if not self.is_train:\n            # USER: Modify this if you want to keep them for some reason.\n        # dataset_dict.pop(\"annotations\", None)\n        # dataset_dict.pop(\"sem_seg_file_name\", None)\n        # return dataset_dict\n        if \"pan_seg_file_name\" in dataset_dict:\n            pan_seg_gt = utils.read_image(dataset_dict.pop(\"pan_seg_file_name\"), \"RGB\")\n            segments_info = dataset_dict[\"segments_info\"]\n\n            # apply the same transformation to panoptic segmentation\n            pan_seg_gt = transforms.apply_segmentation(pan_seg_gt)\n\n            from panopticapi.utils import rgb2id\n\n            pan_seg_gt = rgb2id(pan_seg_gt)\n\n            instances = Instances(image_shape)\n            classes = []\n            masks = []\n            for segment_info in segments_info:\n                class_id = segment_info[\"category_id\"]\n                if not segment_info[\"iscrowd\"]:\n                    classes.append(class_id)\n                    masks.append(pan_seg_gt == segment_info[\"id\"])\n\n            classes = np.array(classes)\n            instances.gt_classes = torch.tensor(classes, dtype=torch.int64)\n            if len(masks) == 0:\n                # Some image does not have annotation (all ignored)\n                instances.gt_masks = torch.zeros((0, pan_seg_gt.shape[-2], pan_seg_gt.shape[-1]))\n                instances.gt_boxes = Boxes(torch.zeros((0, 4)))\n            else:\n                masks = BitMasks(\n                    torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks])\n                )\n                instances.gt_masks = masks.tensor\n                instances.gt_boxes = masks.get_bounding_boxes()\n\n            dataset_dict[\"instances\"] = instances\n      
      # dataset_dict[\"instances\"] = filter_empty_instances_by_box(instances)\n\n        if \"annotations\" in dataset_dict:\n            self._transform_annotations(dataset_dict, transforms, image_shape)\n\n        return dataset_dict\n"
  },
  {
    "path": "datasets_os/dataset_mappers/llava_dataset_mapper.py",
    "content": "# Copyright (c) Facebook, Inc. and its affiliates.\n# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/d2/detr/dataset_mapper.py\nimport copy\nimport logging\n\nimport numpy as np\nimport torch\n\nfrom detectron2.data import detection_utils as utils\nfrom detectron2.data import transforms as T\nfrom detectron2.data.transforms import TransformGen\nfrom detectron2.structures import BitMasks, Instances\n\nfrom pycocotools import mask as coco_mask\n\nfrom llava.model.openseed.utils import configurable\nfrom detectron2.structures import Boxes, BoxMode, PolygonMasks, RotatedBoxes\n\n__all__ = [\"COCOInstanceNewBaselineDatasetMapper\"]\n\n\ndef convert_coco_poly_to_mask(segmentations, height, width):\n    masks = []\n    for polygons in segmentations:\n        rles = coco_mask.frPyObjects(polygons, height, width)\n        mask = coco_mask.decode(rles)\n        if len(mask.shape) < 3:\n            mask = mask[..., None]\n        mask = torch.as_tensor(mask, dtype=torch.uint8)\n        mask = mask.any(dim=2)\n        masks.append(mask)\n    if masks:\n        masks = torch.stack(masks, dim=0)\n    else:\n        masks = torch.zeros((0, height, width), dtype=torch.uint8)\n    return masks\n\n\ndef build_transform_gen(cfg, is_train):\n    \"\"\"\n    Create a list of default :class:`Augmentation` from config.\n    Now it includes resizing and flipping.\n    Returns:\n        list[Augmentation]\n    \"\"\"\n    if is_train:\n        cfg_input = cfg['INPUT']\n        image_size = cfg_input['IMAGE_SIZE']\n        min_scale = cfg_input['MIN_SCALE']\n        max_scale = cfg_input['MAX_SCALE']\n\n        augmentation = []\n\n        # if cfg_input['RANDOM_FLIP'] != \"none\":\n        #     augmentation.append(\n        #         T.RandomFlip(\n        #             horizontal=cfg_input['RANDOM_FLIP'] == \"horizontal\",\n        #             vertical=cfg_input['RANDOM_FLIP'] == \"vertical\",\n        #         )\n        #     )\n\n  
      augmentation.extend([\n            T.ResizeScale(\n                min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size\n            ),\n            T.FixedSizeCrop(crop_size=(image_size, image_size)),\n        ])\n    else:\n        cfg_input = cfg['INPUT']\n        image_size = cfg_input['IMAGE_SIZE']\n        min_scale = cfg_input['MIN_SCALE']\n        max_scale = cfg_input['MAX_SCALE']\n\n        augmentation = []\n\n        # if cfg_input['RANDOM_FLIP'] != \"none\":\n        #     augmentation.append(\n        #         T.RandomFlip(\n        #             horizontal=cfg_input['RANDOM_FLIP'] == \"horizontal\",\n        #             vertical=cfg_input['RANDOM_FLIP'] == \"vertical\",\n        #         )\n        #     )\n\n        augmentation.extend([\n            T.ResizeScale(\n                min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size\n            ),\n            T.FixedSizeCrop(crop_size=(image_size, image_size)),\n        ])\n\n    return augmentation\n\n\n# This is specifically designed for the COCO dataset.\nclass COCOInstanceNewBaselineDatasetMapper:\n    \"\"\"\n    A callable which takes a dataset dict in Detectron2 Dataset format,\n    and map it into a format used by MaskFormer.\n\n    This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation.\n\n    The callable currently does the following:\n\n    1. Read the image from \"file_name\"\n    2. Applies geometric transforms to the image and annotation\n    3. Find and applies suitable cropping to the image and annotation\n    4. 
Prepare image and annotation to Tensors\n    \"\"\"\n\n    @configurable\n    def __init__(\n        self,\n        is_train=True,\n        *,\n        tfm_gens,\n        image_format,\n    ):\n        \"\"\"\n        NOTE: this interface is experimental.\n        Args:\n            is_train: for training or inference\n            augmentations: a list of augmentations or deterministic transforms to apply\n            tfm_gens: data augmentation\n            image_format: an image format supported by :func:`detection_utils.read_image`.\n        \"\"\"\n        self.tfm_gens = tfm_gens\n        logging.getLogger(__name__).info(\n            \"[COCOInstanceNewBaselineDatasetMapper] Full TransformGens used in training: {}\".format(str(self.tfm_gens))\n        )\n\n        self.img_format = image_format\n        self.is_train = is_train\n    \n    @classmethod\n    def from_config(cls, cfg, is_train=True):\n        # Build augmentation\n        tfm_gens = build_transform_gen(cfg, is_train)\n\n        ret = {\n            \"is_train\": is_train,\n            \"tfm_gens\": tfm_gens,\n            \"image_format\": cfg['INPUT']['FORMAT'],\n        }\n        return ret\n\n    def __call__(self, dataset_dict):\n        \"\"\"\n        Args:\n            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.\n\n        Returns:\n            dict: a format that builtin models in detectron2 accept\n        \"\"\"\n        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below\n        image = utils.read_image(dataset_dict[\"file_name\"], format=self.img_format)\n        utils.check_image_size(dataset_dict, image)\n\n        # TODO: get padding mask\n        # by feeding a \"segmentation mask\" to the same transforms\n        padding_mask = np.ones(image.shape[:2])\n\n        image, transforms = T.apply_transform_gens(self.tfm_gens, image)\n        # the crop transformation has default padding value 0 for segmentation\n        
padding_mask = transforms.apply_segmentation(padding_mask)\n        padding_mask = ~ padding_mask.astype(bool)\n        dataset_dict[\"image_ori\"]=image\n\n        image_shape = image.shape[:2]  # h, w\n\n        # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,\n        # but not efficient on large generic data structures due to the use of pickle & mp.Queue.\n        # Therefore it's important to use torch.Tensor.\n        dataset_dict[\"image\"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))\n        dataset_dict[\"padding_mask\"] = torch.as_tensor(np.ascontiguousarray(padding_mask))\n\n        # if not self.is_train:\n        #     # USER: Modify this if you want to keep them for some reason.\n        #     dataset_dict.pop(\"annotations\", None)\n        #     return dataset_dict\n\n        if \"grounding_info\" in dataset_dict:\n\n            for obj in dataset_dict['grounding_info']:\n                obj[\"bbox_mode\"] = BoxMode.XYWH_ABS\n                obj['tokens']=dataset_dict['caption'][obj['tokens_positive'][0][0]:obj['tokens_positive'][0][1]]\n            # USER: Implement additional transformations if you have other types of data\n            annos = [\n                utils.transform_instance_annotations(obj, transforms, image_shape)\n                for obj in dataset_dict[\"grounding_info\"]\n            ]\n            # NOTE: does not support BitMask due to augmentation\n            # Current BitMask cannot handle empty objects\n            assert len(annos)>0\n            assert  \"segmentation\" in annos[0]\n            instances = utils.annotations_to_instances(annos, image_shape,mask_format=\"bitmask\")\n            # After transforms such as cropping are applied, the bounding box may no longer\n            # tightly bound the object. As an example, imagine a triangle object\n            # [(0,0), (2,0), (0,2)] cropped by a box [(1,0),(2,2)] (XYXY format). 
The tight\n            # bounding box of the cropped triangle should be [(1,0),(2,1)], which is not equal to\n            # the intersection of original bounding box and the cropping box.\n            instances.gt_boxes = instances.gt_masks.get_bounding_boxes()\n            # Need to filter empty instances first (due to augmentation)\n            instances = utils.filter_empty_instances(instances)\n            # Generate masks from polygon\n            h, w = instances.image_size\n            # image_size_xyxy = torch.as_tensor([w, h, w, h], dtype=torch.float)\n            if hasattr(instances, 'gt_masks'):\n                gt_masks = instances.gt_masks\n                # gt_masks = convert_coco_poly_to_mask(gt_masks.polygons, h, w)\n                instances.gt_masks = gt_masks.tensor\n            dataset_dict[\"instances\"] = instances\n\n        return dataset_dict\n"
  },
  {
    "path": "datasets_os/dataset_mappers/refcoco_dataset_mapper.py",
    "content": "# --------------------------------------------------------\n# X-Decoder -- Generalized Decoding for Pixel, Image, and Language\n# Copyright (c) 2022 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# Modified by Xueyan Zou (xueyan@cs.wisc.edu)\n# --------------------------------------------------------\n# Copyright (c) Facebook, Inc. and its affiliates.\nimport copy\nimport random\n\nimport scipy.io\nimport numpy as np\nimport torch\nfrom PIL import Image\n\nfrom torchvision import transforms\n\nfrom pycocotools import mask\nfrom detectron2.data import detection_utils as utils\nfrom detectron2.data import transforms as T\n\nfrom llava.model.openseed.utils import configurable\n\n__all__ = [\"RefCOCODatasetMapper\"]\n\ndef build_transform_gen(cfg, is_train):\n    \"\"\"\n    Create a list of default :class:`Augmentation` from config.\n    Now it includes resizing and flipping.\n    Returns:\n        list[Augmentation]\n    \"\"\"\n    assert is_train, \"Only support training augmentation\"\n    cfg_input = cfg['INPUT']\n    image_size = cfg_input['IMAGE_SIZE']\n    min_scale = cfg_input['MIN_SCALE']\n    max_scale = cfg_input['MAX_SCALE']\n\n    augmentation = []\n\n\n    if cfg_input['RANDOM_FLIP'] != \"none\":\n        augmentation.append(\n            T.RandomFlip(\n                horizontal=cfg_input['RANDOM_FLIP'] == \"horizontal\",\n                vertical=cfg_input['RANDOM_FLIP'] == \"vertical\",\n            )\n        )\n\n    augmentation.extend([\n        T.ResizeScale(\n            min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size\n        ),\n        T.FixedSizeCrop(crop_size=(image_size, image_size)),\n    ])\n    \n    return augmentation\n\n\n# This is specifically designed for the COCO dataset.\nclass RefCOCODatasetMapper:\n    \"\"\"\n    A callable which takes a dataset dict in Detectron2 Dataset format,\n    and map it into a format used by MaskFormer.\n\n    This 
dataset mapper applies the same transformation as DETR for COCO panoptic segmentation.\n\n    The callable currently does the following:\n\n    1. Read the image from \"file_name\"\n    2. Applies geometric transforms to the image and annotation\n    3. Find and applies suitable cropping to the image and annotation\n    4. Prepare image and annotation to Tensors\n    \"\"\"\n\n    @configurable\n    def __init__(\n        self,\n        is_train=True,\n        tfm_gens=None,\n        image_format=None,\n        min_size_test=None,\n        max_size_test=None,\n        mean=None,\n        std=None,\n    ):\n        \"\"\"\n        NOTE: this interface is experimental.\n        Args:\n            is_train: for training or inference\n            augmentations: a list of augmentations or deterministic transforms to apply\n            tfm_gens: data augmentation\n            image_format: an image format supported by :func:`detection_utils.read_image`.\n        \"\"\"\n        self.tfm_gens = tfm_gens\n        self.img_format = image_format\n        self.is_train = is_train\n        self.min_size_test = min_size_test\n        self.max_size_test = max_size_test\n        self.pixel_mean = torch.tensor(mean)[:,None,None]\n        self.pixel_std = torch.tensor(std)[:,None,None]\n\n        t = []\n        t.append(transforms.Resize(self.min_size_test, interpolation=Image.BICUBIC))\n        self.transform = transforms.Compose(t)\n\n    @classmethod\n    def from_config(cls, cfg, is_train=True):\n        # Build augmentation\n        if is_train:\n            tfm_gens = build_transform_gen(cfg, is_train)\n        else:\n            tfm_gens = None\n\n        ret = {\n            \"is_train\": is_train,\n            \"tfm_gens\": tfm_gens,\n            \"image_format\": cfg['INPUT'].get('FORMAT', 'RGB'),\n            \"min_size_test\": cfg['INPUT']['MIN_SIZE_TEST'],\n            \"max_size_test\": cfg['INPUT']['MAX_SIZE_TEST'],\n            \"mean\": 
cfg['INPUT']['PIXEL_MEAN'],\n            \"std\": cfg['INPUT']['PIXEL_STD'],\n        }\n        return ret\n\n    def __call__(self, dataset_dict):\n        \"\"\"\n        Args:\n            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.\n\n        Returns:\n            dict: a format that builtin models in detectron2 accept\n        \"\"\"\n        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below\n        file_name = dataset_dict['file_name']\n        if self.is_train == False:\n            image = Image.open(file_name).convert('RGB')\n            dataset_dict['width'] = image.size[0]\n            dataset_dict['height'] = image.size[1]\n            image = self.transform(image)\n            image = torch.from_numpy(np.asarray(image).copy())\n            dataset_dict[\"image_ori\"] = image\n            image = image.permute(2,0,1)\n            dataset_dict['image'] = image\n\n            grounding_anno = dataset_dict['grounding_info']\n            assert len(grounding_anno) > 0\n            masks_grd = []\n            texts_grd = []\n            boxes_grd = []\n            for ann in grounding_anno:\n                rle = mask.frPyObjects(\n                    ann['segmentation'], dataset_dict['height'], dataset_dict['width'])\n                m = mask.decode(rle)\n                # sometimes there are multiple binary map (corresponding to multiple segs)\n                m = np.sum(m, axis=2)\n                m = m.astype(np.uint8)  # convert to np.uint8\n                masks_grd += [m]\n                texts_grd.append([x['raw'].lower() for x in ann['sentences']])\n                boxes_grd.append(ann['bbox']) # xywh\n            masks_grd = torch.from_numpy(np.stack(masks_grd))\n            boxes_grd = torch.tensor(boxes_grd)\n\n            groundings = {'masks': masks_grd, 'texts': texts_grd, 'boxes': boxes_grd}\n            dataset_dict[\"groundings\"] = groundings\n        else:\n            image = 
utils.read_image(dataset_dict[\"file_name\"], format=self.img_format)\n            utils.check_image_size(dataset_dict, image)\n            image, transforms = T.apply_transform_gens(self.tfm_gens, image)\n            dataset_dict[\"image_ori\"] = image\n            image_shape = image.shape[:2]  # h, w\n            dataset_dict[\"image\"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))\n\n            grounding_anno = dataset_dict['grounding_info']\n            assert len(grounding_anno) > 0\n            masks_grd = []\n            texts_grd = []\n            boxes_grd = []\n            hash_grd = []\n            for ann in grounding_anno:\n                rle = mask.frPyObjects(\n                    ann['segmentation'], dataset_dict['height'], dataset_dict['width'])\n                m = mask.decode(rle)\n                # sometimes there are multiple binary map (corresponding to multiple segs)\n                m = np.sum(m, axis=2)\n                m = m.astype(np.uint8)  # convert to np.uint8\n                m = transforms.apply_segmentation(m[:,:,None])[:,:,0]\n                masks_grd += [m]\n                rand_id = random.randint(0, len(ann['sentences'])-1)\n                texts_grd.append(ann['sentences'][rand_id]['raw'].lower())\n                hash_grd.append(hash(ann['sentences'][rand_id]['raw'].lower()))\n            masks_grd = torch.from_numpy(np.stack(masks_grd))\n            boxes_grd = torch.tensor(boxes_grd)\n            groundings = {'masks': masks_grd, 'texts': texts_grd, 'hash': hash_grd, 'mode': 'text'}\n            dataset_dict[\"groundings\"] = groundings\n        return dataset_dict"
  },
  {
    "path": "datasets_os/dataset_mappers/vg_instance_new_baseline_dataset_mapper.py",
    "content": "# Copyright (c) Facebook, Inc. and its affiliates.\n# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/d2/detr/dataset_mapper.py\nimport copy\nimport logging\n\nimport numpy as np\nimport torch\n\nfrom detectron2.data import detection_utils as utils\nfrom detectron2.data import transforms as T\nfrom detectron2.data.transforms import TransformGen\nfrom detectron2.structures import BitMasks, Instances\n\nfrom pycocotools import mask as coco_mask\n\nfrom llava.model.openseed.utils import configurable\nfrom detectron2.structures import Boxes, BoxMode, PolygonMasks, RotatedBoxes\n\n__all__ = [\"COCOInstanceNewBaselineDatasetMapper\"]\n\n\ndef convert_coco_poly_to_mask(segmentations, height, width):\n    masks = []\n    for polygons in segmentations:\n        rles = coco_mask.frPyObjects(polygons, height, width)\n        mask = coco_mask.decode(rles)\n        if len(mask.shape) < 3:\n            mask = mask[..., None]\n        mask = torch.as_tensor(mask, dtype=torch.uint8)\n        mask = mask.any(dim=2)\n        masks.append(mask)\n    if masks:\n        masks = torch.stack(masks, dim=0)\n    else:\n        masks = torch.zeros((0, height, width), dtype=torch.uint8)\n    return masks\n\n\ndef build_transform_gen(cfg, is_train):\n    \"\"\"\n    Create a list of default :class:`Augmentation` from config.\n    Now it includes resizing and flipping.\n    Returns:\n        list[Augmentation]\n    \"\"\"\n    if is_train:\n        cfg_input = cfg['INPUT']\n        image_size = cfg_input['IMAGE_SIZE']\n        min_scale = cfg_input['MIN_SCALE']\n        max_scale = cfg_input['MAX_SCALE']\n\n        augmentation = []\n\n        # if cfg_input['RANDOM_FLIP'] != \"none\":\n        #     augmentation.append(\n        #         T.RandomFlip(\n        #             horizontal=cfg_input['RANDOM_FLIP'] == \"horizontal\",\n        #             vertical=cfg_input['RANDOM_FLIP'] == \"vertical\",\n        #         )\n        #     )\n\n  
      augmentation.extend([\n            T.ResizeScale(\n                min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size\n            ),\n            T.FixedSizeCrop(crop_size=(image_size, image_size)),\n        ])\n    else:\n        cfg_input = cfg['INPUT']\n        image_size = cfg_input['IMAGE_SIZE']\n        min_scale = cfg_input['MIN_SCALE']\n        max_scale = cfg_input['MAX_SCALE']\n\n        augmentation = []\n\n        # if cfg_input['RANDOM_FLIP'] != \"none\":\n        #     augmentation.append(\n        #         T.RandomFlip(\n        #             horizontal=cfg_input['RANDOM_FLIP'] == \"horizontal\",\n        #             vertical=cfg_input['RANDOM_FLIP'] == \"vertical\",\n        #         )\n        #     )\n\n        augmentation.extend([\n            T.ResizeScale(\n                min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size\n            ),\n            T.FixedSizeCrop(crop_size=(image_size, image_size)),\n        ])\n\n    return augmentation\n\n\n# This is specifically designed for the COCO dataset.\nclass COCOInstanceNewBaselineDatasetMapper:\n    \"\"\"\n    A callable which takes a dataset dict in Detectron2 Dataset format,\n    and map it into a format used by MaskFormer.\n\n    This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation.\n\n    The callable currently does the following:\n\n    1. Read the image from \"file_name\"\n    2. Applies geometric transforms to the image and annotation\n    3. Find and applies suitable cropping to the image and annotation\n    4. 
Prepare image and annotation to Tensors\n    \"\"\"\n\n    @configurable\n    def __init__(\n        self,\n        is_train=True,\n        *,\n        tfm_gens,\n        image_format,\n        max_grounding_num,\n        tokenizer,\n        data_args,\n        preprocess,\n    ):\n        \"\"\"\n        NOTE: this interface is experimental.\n        Args:\n            is_train: for training or inference\n            augmentations: a list of augmentations or deterministic transforms to apply\n            tfm_gens: data augmentation\n            image_format: an image format supported by :func:`detection_utils.read_image`.\n        \"\"\"\n        self.tfm_gens = tfm_gens\n        logging.getLogger(__name__).info(\n            \"[COCOInstanceNewBaselineDatasetMapper] Full TransformGens used in training: {}\".format(str(self.tfm_gens))\n        )\n\n        self.img_format = image_format\n        self.is_train = is_train\n        self.max_grounding_num = max_grounding_num\n        self.tokenizer = tokenizer\n        self.processor = data_args.image_processor\n        self.data_args = data_args\n        self.preprocess = preprocess\n    \n    @classmethod\n    def from_config(cls, cfg, is_train=True,tokenizer=None,data_args=None,preprocess=None):\n        # Build augmentation\n        tfm_gens = build_transform_gen(cfg, is_train)\n\n        ret = {\n            \"is_train\": is_train,\n            \"tfm_gens\": tfm_gens,\n            \"image_format\": cfg['INPUT']['FORMAT'],\n            \"max_grounding_num\": cfg['MODEL']['DECODER']['GROUNDING']['MAX_LEN'],\n            \"tokenizer\": tokenizer,\n            \"data_args\": data_args,\n            \"preprocess\": preprocess,\n        }\n        return ret\n\n    def __call__(self, dataset_dict):\n        \"\"\"\n        Args:\n            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.\n\n        Returns:\n            dict: a format that builtin models in detectron2 accept\n        \"\"\"\n  
      dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below\n        image = utils.read_image(dataset_dict[\"file_name\"], format=self.img_format)\n        utils.check_image_size(dataset_dict, image)\n        #########llava image processing\n\n        if self.data_args.image_aspect_ratio == 'pad':\n            def expand2square(pil_img, background_color):\n                width, height = pil_img.size\n                if width == height:\n                    return pil_img\n                elif width > height:\n                    result = Image.new(pil_img.mode, (width, width), background_color)\n                    result.paste(pil_img, (0, (width - height) // 2))\n                    return result\n                else:\n                    result = Image.new(pil_img.mode, (height, height), background_color)\n                    result.paste(pil_img, ((height - width) // 2, 0))\n                    return result\n\n            image_clip = expand2square(image, tuple(int(x * 255) for x in self.processor.image_mean))\n            image_clip = self.processor.preprocess(image_clip, return_tensors='pt')['pixel_values'][0]\n        else:\n            image_clip = self.processor.preprocess(image, return_tensors='pt')['pixel_values'][0]\n        dataset_dict[\"image_clip\"] = image_clip\n\n        ##################\n        # TODO: get padding mask\n        # by feeding a \"segmentation mask\" to the same transforms\n        padding_mask = np.ones(image.shape[:2])\n\n        image, transforms = T.apply_transform_gens(self.tfm_gens, image)\n        # the crop transformation has default padding value 0 for segmentation\n        padding_mask = transforms.apply_segmentation(padding_mask)\n        padding_mask = ~ padding_mask.astype(bool)\n        dataset_dict[\"image_ori\"]=image\n        image_shape = image.shape[:2]  # h, w\n\n        # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,\n        # but not efficient on large 
generic data structures due to the use of pickle & mp.Queue.\n        # Therefore it's important to use torch.Tensor.\n        dataset_dict[\"image\"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))\n        dataset_dict[\"padding_mask\"] = torch.as_tensor(np.ascontiguousarray(padding_mask))\n\n        # if not self.is_train:\n        #     # USER: Modify this if you want to keep them for some reason.\n        #     dataset_dict.pop(\"annotations\", None)\n        #     return dataset_dict\n\n        assert \"annotations\" in dataset_dict\n\n        for obj in dataset_dict['annotations']:\n            obj[\"bbox_mode\"] = BoxMode.XYWH_ABS\n        # USER: Implement additional transformations if you have other types of data\n        annos = [\n            utils.transform_instance_annotations(obj, transforms, image_shape)\n            for obj in dataset_dict[\"annotations\"]\n        ]\n        # NOTE: does not support BitMask due to augmentation\n        # Current BitMask cannot handle empty objects\n        assert len(annos)>0\n        # assert  \"segmentation\" in annos[0]\n        instances = utils.annotations_to_instances(annos, image_shape,mask_format=\"bitmask\")\n        instances.captions=[ann['caption'] for ann in dataset_dict[\"annotations\"]]\n        # After transforms such as cropping are applied, the bounding box may no longer\n        # tightly bound the object. As an example, imagine a triangle object\n        # [(0,0), (2,0), (0,2)] cropped by a box [(1,0),(2,2)] (XYXY format). 
The tight\n        # bounding box of the cropped triangle should be [(1,0),(2,1)], which is not equal to\n        # the intersection of original bounding box and the cropping box.\n        # instances.gt_boxes = instances.gt_masks.get_bounding_boxes()\n        # Need to filter empty instances first (due to augmentation)\n        # instances = utils.filter_empty_instances(instances)\n        # Generate masks from polygon\n        h, w = instances.image_size\n        # image_size_xyxy = torch.as_tensor([w, h, w, h], dtype=torch.float)\n        if hasattr(instances, 'gt_masks'):\n            gt_masks = instances.gt_masks\n            # gt_masks = convert_coco_poly_to_mask(gt_masks.polygons, h, w)\n            instances.gt_masks = gt_masks.tensor\n        # dataset_dict[\"instances\"] = instances\n        num_instances = len(instances)\n        indices = list(range(num_instances))\n        import random\n        if self.is_train:\n            grounding_len = random.randint(1, self.max_grounding_num - 1)\n        else:\n            grounding_len = 1\n        random.shuffle(indices)\n        indices = indices[:grounding_len]\n        texts_grd = [instances.captions[i] for i in indices]\n        gt_classes = list(range(len(texts_grd)))\n        gt_classes = [[lb] for lb in gt_classes]\n        label_set = texts_grd\n        grounding_instances = Instances(image_size=(h, w))\n        grounding_instances.gt_boxes = instances.gt_boxes[indices]\n        grounding_instances.gt_classes = gt_classes\n        dataset_dict[\"instances\"]=grounding_instances\n        conversations=[]\n        for i in range(len(label_set)):\n            if i==0:\n                question={'from': 'human', 'value': f\"<image>\\n Please detect the object according to the text {label_set[i]} (referring).\"}\n            else:\n                question={'from': 'human', 'value': f\"Please detect the object according to the text {label_set[i]} (referring).\"}\n            answer={'from': 'gpt', 'value': 
'<seg> .'}\n            conversations.append(question)\n            conversations.append(answer)\n\n        dataset_dict['conversation'] = [conversations]\n\n        data_dict_conversation = self.preprocess(\n            dataset_dict['conversation'],\n            self.tokenizer,\n            has_image=True)\n        data_dict_conversation = dict(input_ids=data_dict_conversation[\"input_ids\"][0],\n                         labels=data_dict_conversation[\"labels\"][0])\n        dataset_dict.update(data_dict_conversation)\n        dataset_dict['tokenizer']=self.tokenizer\n\n        return dataset_dict\n"
  },
  {
    "path": "datasets_os/refer.py",
    "content": "__author__ = 'licheng'\n\n\"\"\"\nThis interface provides access to four datasets:\n1) refclef\n2) refcoco\n3) refcoco+\n4) refcocog\nsplit by unc and google\n\nThe following API functions are defined:\nREFER      - REFER api class\ngetRefIds  - get ref ids that satisfy given filter conditions.\ngetAnnIds  - get ann ids that satisfy given filter conditions.\ngetImgIds  - get image ids that satisfy given filter conditions.\ngetCatIds  - get category ids that satisfy given filter conditions.\nloadRefs   - load refs with the specified ref ids.\nloadAnns   - load anns with the specified ann ids.\nloadImgs   - load images with the specified image ids.\nloadCats   - load category names with the specified category ids.\ngetRefBox  - get ref's bounding box [x, y, w, h] given the ref_id\nshowRef    - show image, segmentation or box of the referred object with the ref\ngetMask    - get mask and area of the referred object given ref\nshowMask   - show mask of the referred object given ref\n\"\"\"\n\nfrom doctest import REPORT_ONLY_FIRST_FAILURE\nimport sys\nimport os.path as osp\nimport json\nimport pickle\nimport time\nimport itertools\nimport skimage.io as io\nimport matplotlib.pyplot as plt\nfrom matplotlib.collections import PatchCollection\nfrom matplotlib.patches import Polygon, Rectangle\nfrom pprint import pprint\nimport numpy as np\nfrom pycocotools import mask\n# import cv2\n# from skimage.measure import label, regionprops\n\n\nclass REFER:\n    def __init__(self, data_root, dataset='refcoco', splitBy='unc'):\n        # provide data_root folder which contains refclef, refcoco, refcoco+ and refcocog\n        # also provide dataset name and splitBy information\n        # e.g., dataset = 'refcoco', splitBy = 'unc'\n        print('loading dataset {} into memory...'.format(dataset))\n        self.ROOT_DIR = osp.abspath(osp.dirname(__file__))\n        self.DATA_DIR = osp.join(data_root, dataset)\n        if dataset in ['refcoco', 'refcoco+', 'refcocog']:\n 
           self.IMAGE_DIR = osp.join(data_root, 'images/mscoco/images/train2014')\n        elif dataset == 'refclef':\n            self.IMAGE_DIR = osp.join(data_root, 'images/saiapr_tc-12')\n        else:\n            print('No refer dataset is called [{}]'.format(dataset))\n            sys.exit()\n\n        # load refs from data/dataset/refs(dataset).json\n        tic = time.time()\n        ref_file = osp.join(self.DATA_DIR, 'refs('+splitBy+').p')        \n        self.data = {}\n        self.data['dataset'] = dataset\n        self.data['refs'] = pickle.load(open(ref_file, 'rb'))\n\n        # load annotations from data/dataset/instances.json\n        instances_file = osp.join(self.DATA_DIR, 'instances.json')\n        instances = json.load(open(instances_file, 'r'))\n        self.data['images'] = instances['images']\n        self.data['annotations'] = instances['annotations']\n        self.data['categories'] = instances['categories']\n\n        # create index\n        self.createIndex()\n        print('DONE (t=%.2fs)'.format(time.time()-tic))\n\n    def createIndex(self):\n        # create sets of mapping\n        # 1)  Refs: \t \t{ref_id: ref}\n        # 2)  Anns: \t \t{ann_id: ann}\n        # 3)  Imgs:\t\t \t{image_id: image}\n        # 4)  Cats: \t \t{category_id: category_name}\n        # 5)  Sents:     \t{sent_id: sent}\n        # 6)  imgToRefs: \t{image_id: refs}\n        # 7)  imgToAnns: \t{image_id: anns}\n        # 8)  refToAnn:  \t{ref_id: ann}\n        # 9)  annToRef:  \t{ann_id: ref}\n        # 10) catToRefs: \t{category_id: refs}\n        # 11) sentToRef: \t{sent_id: ref}\n        # 12) sentToTokens: {sent_id: tokens}\n        print('creating index...')\n        # fetch info from instances\n        Anns, Imgs, Cats, imgToAnns = {}, {}, {}, {}\n        for ann in self.data['annotations']:\n            Anns[ann['id']] = ann\n            imgToAnns[ann['image_id']] = imgToAnns.get(\n                ann['image_id'], []) + [ann]\n        for img in 
self.data['images']:\n            Imgs[img['id']] = img\n        for cat in self.data['categories']:\n            Cats[cat['id']] = cat['name']\n\n        # fetch info from refs\n        Refs, imgToRefs, refToAnn, annToRef, catToRefs = {}, {}, {}, {}, {}\n        Sents, sentToRef, sentToTokens = {}, {}, {}\n        for ref in self.data['refs']:\n            # ids\n            ref_id = ref['ref_id']\n            ann_id = ref['ann_id']\n            category_id = ref['category_id']\n            image_id = ref['image_id']\n\n            # add mapping related to ref\n            Refs[ref_id] = ref\n            imgToRefs[image_id] = imgToRefs.get(image_id, []) + [ref]\n            catToRefs[category_id] = catToRefs.get(category_id, []) + [ref]\n            refToAnn[ref_id] = Anns[ann_id]\n            annToRef[ann_id] = ref\n\n            # add mapping of sent\n            for sent in ref['sentences']:\n                Sents[sent['sent_id']] = sent\n                sentToRef[sent['sent_id']] = ref\n                sentToTokens[sent['sent_id']] = sent['tokens']\n\n        # create class members\n        self.Refs = Refs\n        self.Anns = Anns\n        self.Imgs = Imgs\n        self.Cats = Cats\n        self.Sents = Sents\n        self.imgToRefs = imgToRefs\n        self.imgToAnns = imgToAnns\n        self.refToAnn = refToAnn\n        self.annToRef = annToRef\n        self.catToRefs = catToRefs\n        self.sentToRef = sentToRef\n        self.sentToTokens = sentToTokens\n        print('index created.')\n\n    def getRefIds(self, image_ids=[], cat_ids=[], ref_ids=[], split=''):\n        image_ids = image_ids if type(image_ids) == list else [image_ids]\n        cat_ids = cat_ids if type(cat_ids) == list else [cat_ids]\n        ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]\n\n        if len(image_ids) == len(cat_ids) == len(ref_ids) == len(split) == 0:\n            refs = self.data['refs']\n        else:\n            if not len(image_ids) == 0:\n                
lists = [self.imgToRefs[image_id]\n                         for image_id in image_ids if image_id in self.imgToRefs]  # list of [refs]\n                refs = list(itertools.chain.from_iterable(lists))\n            else:\n                refs = self.data['refs']\n            if not len(cat_ids) == 0:\n                refs = [ref for ref in refs if ref['category_id'] in cat_ids]\n            if not len(ref_ids) == 0:\n                refs = [ref for ref in refs if ref['ref_id'] in ref_ids]\n            if not len(split) == 0:\n                if split in ['testA', 'testB', 'testC']:\n                    # we also consider testAB, testBC, ...\n                    refs = [ref for ref in refs if split[-1] in ref['split']]\n                elif split in ['testAB', 'testBC', 'testAC']:\n                    # rarely used I guess...\n                    refs = [ref for ref in refs if ref['split'] == split]\n                elif split == 'test':\n                    refs = [ref for ref in refs if 'test' in ref['split']]\n                elif split == 'train' or split == 'val':\n                    refs = [ref for ref in refs if ref['split'] == split]\n                else:\n                    print('No such split [{}]'.format(split))\n                    sys.exit()\n        ref_ids = [ref['ref_id'] for ref in refs]\n        return ref_ids\n\n    def getAnnIds(self, image_ids=[], cat_ids=[], ref_ids=[]):\n        image_ids = image_ids if type(image_ids) == list else [image_ids]\n        cat_ids = cat_ids if type(cat_ids) == list else [cat_ids]\n        ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]\n\n        if len(image_ids) == len(cat_ids) == len(ref_ids) == 0:\n            ann_ids = [ann['id'] for ann in self.data['annotations']]\n        else:\n            if not len(image_ids) == 0:\n                lists = [self.imgToAnns[image_id]\n                         for image_id in image_ids if image_id in self.imgToAnns]  # list of [anns]\n                anns = list(itertools.chain.from_iterable(lists))\n            else:\n                anns = self.data['annotations']\n            if not len(cat_ids) 
== 0:\n                anns = [ann for ann in anns if ann['category_id'] in cat_ids]\n            ann_ids = [ann['id'] for ann in anns]\n            if not len(ref_ids) == 0:\n                # keep only annotations referred to by the requested refs (order preserved)\n                ref_ann_ids = set([self.Refs[ref_id]['ann_id'] for ref_id in ref_ids])\n                ann_ids = [ann_id for ann_id in ann_ids if ann_id in ref_ann_ids]\n        return ann_ids\n\n    def getImgIds(self, ref_ids=[]):\n        ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]\n\n        if not len(ref_ids) == 0:\n            image_ids = list(set([self.Refs[ref_id]['image_id']\n                             for ref_id in ref_ids]))\n        else:\n            image_ids = list(self.Imgs.keys())\n        return image_ids\n\n    def getCatIds(self):\n        return list(self.Cats.keys())\n\n    def loadRefs(self, ref_ids=[]):\n        if type(ref_ids) == list:\n            return [self.Refs[ref_id] for ref_id in ref_ids]\n        elif type(ref_ids) == int:\n            return [self.Refs[ref_ids]]\n\n    def loadAnns(self, ann_ids=[]):\n        if type(ann_ids) == list:\n            return [self.Anns[ann_id] for ann_id in ann_ids]\n        elif type(ann_ids) == int or type(ann_ids) == str:\n            return [self.Anns[ann_ids]]\n\n    def loadImgs(self, image_ids=[]):\n        if type(image_ids) == list:\n            return [self.Imgs[image_id] for image_id in image_ids]\n        elif type(image_ids) == int:\n            return [self.Imgs[image_ids]]\n\n    def loadCats(self, cat_ids=[]):\n        if type(cat_ids) == list:\n            return [self.Cats[cat_id] for cat_id in cat_ids]\n        elif type(cat_ids) == int:\n            return [self.Cats[cat_ids]]\n\n    def getRefBox(self, ref_id):\n        ref = self.Refs[ref_id]\n        ann = self.refToAnn[ref_id]\n        return ann['bbox']  # [x, y, w, h]\n\n    def showRef(self, ref, seg_box='seg'):\n        ax = plt.gca()\n        # show image\n        image = self.Imgs[ref['image_id']]\n        I = io.imread(osp.join(self.IMAGE_DIR, image['file_name']))\n        
ax.imshow(I)\n        # show refer expression\n        for sid, sent in enumerate(ref['sentences']):\n            print('{}. {}'.format(sid+1, sent['sent']))\n        # show segmentations\n        if seg_box == 'seg':\n            ann_id = ref['ann_id']\n            ann = self.Anns[ann_id]\n            polygons = []\n            color = []\n            c = 'none'\n            if type(ann['segmentation'][0]) == list:\n                # polygon used for refcoco*\n                for seg in ann['segmentation']:\n                    poly = np.array(seg).reshape((len(seg)//2, 2))\n                    polygons.append(Polygon(poly, closed=True, alpha=0.4))\n                    color.append(c)\n                p = PatchCollection(polygons, facecolors=color, edgecolors=(\n                    1, 1, 0, 0), linewidths=3, alpha=1)\n                ax.add_collection(p)  # thick yellow polygon\n                p = PatchCollection(polygons, facecolors=color, edgecolors=(\n                    1, 0, 0, 0), linewidths=1, alpha=1)\n                ax.add_collection(p)  # thin red polygon\n            else:\n                # mask used for refclef\n                rle = ann['segmentation']\n                m = mask.decode(rle)\n                img = np.ones((m.shape[0], m.shape[1], 3))\n                color_mask = np.array([2.0, 166.0, 101.0])/255\n                for i in range(3):\n                    img[:, :, i] = color_mask[i]\n                ax.imshow(np.dstack((img, m*0.5)))\n        # show bounding-box\n        elif seg_box == 'box':\n            ann_id = ref['ann_id']\n            ann = self.Anns[ann_id]\n            bbox = self.getRefBox(ref['ref_id'])\n            box_plot = Rectangle(\n                (bbox[0], bbox[1]), bbox[2], bbox[3], fill=False, edgecolor='green', linewidth=3)\n            ax.add_patch(box_plot)\n\n    def getMask(self, ref):\n        # return mask, area and mask-center\n        ann = self.refToAnn[ref['ref_id']]\n        image = self.Imgs[ref['image_id']]\n 
       if type(ann['segmentation'][0]) == list:  # polygon\n            rle = mask.frPyObjects(\n                ann['segmentation'], image['height'], image['width'])\n        else:\n            rle = ann['segmentation']\n        m = mask.decode(rle)\n        # sometimes there are multiple binary map (corresponding to multiple segs)\n        m = np.sum(m, axis=2)\n        m = m.astype(np.uint8)  # convert to np.uint8\n        # compute area\n        area = sum(mask.area(rle))  # should be close to ann['area']\n        return {'mask': m, 'area': area}\n        # # position\n        # position_x = np.mean(np.where(m==1)[1]) # [1] means columns (matlab style) -> x (c style)\n        # position_y = np.mean(np.where(m==1)[0]) # [0] means rows (matlab style)    -> y (c style)\n        # # mass position (if there were multiple regions, we use the largest one.)\n        # label_m = label(m, connectivity=m.ndim)\n        # regions = regionprops(label_m)\n        # if len(regions) > 0:\n        # \tlargest_id = np.argmax(np.array([props.filled_area for props in regions]))\n        # \tlargest_props = regions[largest_id]\n        # \tmass_y, mass_x = largest_props.centroid\n        # else:\n        # \tmass_x, mass_y = position_x, position_y\n        # # if centroid is not in mask, we find the closest point to it from mask\n        # if m[mass_y, mass_x] != 1:\n        # \tprint 'Finding closes mask point ...'\n        # \tkernel = np.ones((10, 10),np.uint8)\n        # \tme = cv2.erode(m, kernel, iterations = 1)\n        # \tpoints = zip(np.where(me == 1)[0].tolist(), np.where(me == 1)[1].tolist())  # row, col style\n        # \tpoints = np.array(points)\n        # \tdist   = np.sum((points - (mass_y, mass_x))**2, axis=1)\n        # \tid     = np.argsort(dist)[0]\n        # \tmass_y, mass_x = points[id]\n        # \t# return\n        # return {'mask': m, 'area': area, 'position_x': position_x, 'position_y': position_y, 'mass_x': mass_x, 'mass_y': mass_y}\n        # # show 
image and mask\n        # I = io.imread(osp.join(self.IMAGE_DIR, image['file_name']))\n        # plt.figure()\n        # plt.imshow(I)\n        # ax = plt.gca()\n        # img = np.ones( (m.shape[0], m.shape[1], 3) )\n        # color_mask = np.array([2.0,166.0,101.0])/255\n        # for i in range(3):\n        #     img[:,:,i] = color_mask[i]\n        # ax.imshow(np.dstack( (img, m*0.5) ))\n        # plt.show()\n\n    def showMask(self, ref):\n        M = self.getMask(ref)\n        msk = M['mask']\n        ax = plt.gca()\n        ax.imshow(msk)\n\n\nif __name__ == '__main__':\n    refer = REFER(data_root='/home/xueyanz/code/dataset/refcocoseg',\n                  dataset='refcocog', splitBy='google')\n    ref_ids = refer.getRefIds()\n    print(len(ref_ids))\n\n    print(len(refer.Imgs))\n    print(len(refer.imgToRefs))\n\n    ref_ids = refer.getRefIds(split='train')\n    print('There are {} training referred objects.'.format(len(ref_ids)))\n\n    for ref_id in ref_ids:\n        ref = refer.loadRefs(ref_id)[0]\n        if len(ref['sentences']) < 2:\n            continue\n\n        pprint(ref)\n        print('The label is {}.'.format(refer.Cats[ref['category_id']]))\n\n        # plt.figure()\n        # refer.showRef(ref, seg_box='box')\n        # plt.show()\n\n        # plt.figure()\n        # refer.showMask(ref)\n        # plt.show()\n"
  },
  {
    "path": "datasets_os/registration/__init__.py",
    "content": "# Copyright (c) Facebook, Inc. and its affiliates.\nfrom . import (\n    register_coco_panoptic_annos_grounding_interactive,\n    register_coco_instruct_grounding_dataset,\n    register_flickr_dataset,\n    # register_vg_dataset,\n)"
  },
  {
    "path": "datasets_os/registration/register_coco_instruct_grounding_dataset.py",
    "content": "# Copyright (c) Facebook, Inc. and its affiliates.\n# --------------------------------------------------------\n# X-Decoder -- Generalized Decoding for Pixel, Image, and Language\n# Copyright (c) 2022 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# Modified by Xueyan Zou (xueyan@cs.wisc.edu)\n# --------------------------------------------------------\nimport json\nimport os\nimport collections\n\nfrom detectron2.data import DatasetCatalog, MetadataCatalog\nfrom detectron2.data.datasets import load_sem_seg\nfrom detectron2.data.datasets.builtin_meta import COCO_CATEGORIES\nfrom detectron2.utils.file_io import PathManager\nimport pycocotools.mask as mask_util\n\n\n_PREDEFINED_SPLITS_COCO_PANOPTIC_CAPTION = {\n\n\n    \"coco_instruct_train_v3\": (\n        \"coco/train2014\", # image_root\n        \"coco/annotations/instances_train2017_gvc.json\", # annot_root\n        \"llava/annotations/grounded_visual_chat_data.json\",\n    ),\n\n    \"coco_interactive\": (\n        \"coco/train2014\", # image_root\n        \"coco/annotations/instances_train2014_filter.json\", # annot_root\n        \"llava/annotations/llava_instruct_150k_visual_prompt.json\",\n    ),\n    \"coco_interactive_refcoco\": (\n        \"coco/train2017\", # image_root\n        \"coco/annotations/instances_train2017_refcoco.json\", # annot_root\n        \"coco/annotations/grounding_train2017_instruct.json\",\n    ),\n}\n\n\ndef get_metadata():\n    meta = {}\n    return meta\n\n\ndef load_coco_json(image_root, annot_json,conversation, metadata):\n    \"\"\"\n    Args:\n        image_dir (str): path to the raw dataset. e.g., \"~/coco/train2017\".\n        gt_dir (str): path to the raw annotations. e.g., \"~/coco/panoptic_train2017\".\n        json_file (str): path to the json file. e.g., \"~/coco/annotations/panoptic_train2017.json\".\n    Returns:\n        list[dict]: a list of dicts in Detectron2 standard format. 
(See\n        `Using Custom Datasets </tutorials/datasets.html>`_ )\n    \"\"\"\n\n    with PathManager.open(annot_json) as f:\n        json_info = json.load(f)\n        \n    # build dictionary for grounding\n    grd_dict = collections.defaultdict(list)\n\n    imgid2image = {}\n    for image in json_info[\"images\"]:\n        image_id = image[\"id\"]\n        imgid2image[image_id] = image\n    for grd_ann in json_info['annotations']:\n        image_id = int(grd_ann[\"image_id\"])\n        segm = grd_ann.get(\"segmentation\", None)\n        if segm:  # either list[list[float]] or dict(RLE)\n            if isinstance(segm, dict):\n                if isinstance(segm[\"counts\"], list):\n                    # convert to compressed RLE\n                    segm = mask_util.frPyObjects(segm, *segm[\"size\"])\n\n            grd_ann[\"segmentation\"] = segm\n        grd_dict[image_id].append(grd_ann)\n\n    conv_dict = collections.defaultdict(list)\n    with open(conversation) as f:\n        data = json.load(f)\n    for d in data:\n        image_id = int(d['id'])\n        if 'gd_ls' not in d:\n            d['gd_ls']=None\n        if 'q_gd_ls' in d:\n            conv_dict[image_id].append((d['conversations'],d['q_gd_ls']))\n        else:\n            conv_dict[image_id].append((d['conversations'], d['gd_ls']))\n\n    ret = []\n    for d in data:\n        image_id = int(d['id'])\n        image= imgid2image[image_id]\n        image_file = os.path.join(image_root, image['file_name'])\n        grounding_anno = grd_dict[image_id]\n        if image_id in conv_dict and len(conv_dict[image_id])>0:\n            ret.append(\n                {\n                    \"file_name\": image_file,\n                    \"image_id\": image_id,\n                    \"grounding_info\": grounding_anno,\n                    \"conversations\": conv_dict[image_id],\n                }\n            )\n\n    assert len(ret), f\"No images found in {image_root}!\"\n    assert 
PathManager.isfile(ret[0][\"file_name\"]), ret[0][\"file_name\"]\n    return ret\n\n\ndef register_coco(\n    name, metadata, image_root, annot_json,conversation):\n    DatasetCatalog.register(\n        name,\n        lambda: load_coco_json(image_root, annot_json,conversation, metadata),\n    )\n    MetadataCatalog.get(name).set(\n        image_root=image_root,\n        json_file=annot_json,\n        evaluator_type=\"grounding_refcoco\",\n        ignore_label=255,\n        label_divisor=1000,\n        **metadata,\n    )\n\n\ndef register_all_coco(root):\n    for (\n        prefix,\n        (image_root, annot_root,conversation_path),\n    ) in _PREDEFINED_SPLITS_COCO_PANOPTIC_CAPTION.items():\n        register_coco(\n            prefix,\n            get_metadata(),\n            os.path.join(root, image_root),\n            os.path.join(root, annot_root),\n            conversation_path,\n        )\n\n_root = os.getenv(\"DATASET\", \"datasets\")\nregister_all_coco(_root)\n"
  },
  {
    "path": "datasets_os/registration/register_coco_panoptic_annos_grounding_interactive.py",
    "content": "# Copyright (c) Facebook, Inc. and its affiliates.\nimport json\nimport os\nimport collections\n\nfrom detectron2.data import DatasetCatalog, MetadataCatalog\nfrom detectron2.data.datasets import load_sem_seg\nfrom detectron2.data.datasets.builtin_meta import COCO_CATEGORIES\nfrom detectron2.utils.file_io import PathManager\n\n\n_PREDEFINED_SPLITS_COCO_PANOPTIC_CAPTION = {\n\n\n    \"coco_2017_train_panoptic_ref_full\": (\n        # This is the original panoptic annotation directory\n        \"coco/panoptic_train2017\",\n        \"coco/annotations/panoptic_train2017_filter.json\",\n        \"coco/panoptic_semseg_train2017\",\n        \"coco/annotations/grounding_train2017.json\",\n    ),\n\n\n}\n\n\ndef get_metadata():\n    meta = {}\n    # The following metadata maps contiguous id from [0, #thing categories +\n    # #stuff categories) to their names and colors. We have to replica of the\n    # same name and color under \"thing_*\" and \"stuff_*\" because the current\n    # visualization function in D2 handles thing and class classes differently\n    # due to some heuristic used in Panoptic FPN. We keep the same naming to\n    # enable reusing existing visualization functions.\n    thing_classes = [k[\"name\"] for k in COCO_CATEGORIES if k[\"isthing\"] == 1]\n    thing_colors = [k[\"color\"] for k in COCO_CATEGORIES if k[\"isthing\"] == 1]\n    stuff_classes = [k[\"name\"] for k in COCO_CATEGORIES]\n    stuff_colors = [k[\"color\"] for k in COCO_CATEGORIES]\n\n    meta[\"thing_classes\"] = thing_classes\n    meta[\"thing_colors\"] = thing_colors\n    meta[\"stuff_classes\"] = stuff_classes\n    meta[\"stuff_colors\"] = stuff_colors\n\n    # Convert category id for training:\n    #   category id: like semantic segmentation, it is the class id for each\n    #   pixel. 
Since there are some classes not used in evaluation, the category\n    #   id is not always contiguous and thus we have two set of category ids:\n    #       - original category id: category id in the original dataset, mainly\n    #           used for evaluation.\n    #       - contiguous category id: [0, #classes), in order to train the linear\n    #           softmax classifier.\n    thing_dataset_id_to_contiguous_id = {}\n    stuff_dataset_id_to_contiguous_id = {}\n\n    for i, cat in enumerate(COCO_CATEGORIES):\n        if cat[\"isthing\"]:\n            thing_dataset_id_to_contiguous_id[cat[\"id\"]] = i\n        # else:\n        #     stuff_dataset_id_to_contiguous_id[cat[\"id\"]] = i\n\n        # in order to use sem_seg evaluator\n        stuff_dataset_id_to_contiguous_id[cat[\"id\"]] = i\n\n    meta[\"thing_dataset_id_to_contiguous_id\"] = thing_dataset_id_to_contiguous_id\n    meta[\"stuff_dataset_id_to_contiguous_id\"] = stuff_dataset_id_to_contiguous_id\n\n    return meta\n\n\ndef load_coco_panoptic_json(json_file, image_dir, gt_dir, semseg_dir, grounding_file, meta):\n    \"\"\"\n    Args:\n        image_dir (str): path to the raw dataset. e.g., \"~/coco/train2017\".\n        gt_dir (str): path to the raw annotations. e.g., \"~/coco/panoptic_train2017\".\n        json_file (str): path to the json file. e.g., \"~/coco/annotations/panoptic_train2017.json\".\n    Returns:\n        list[dict]: a list of dicts in Detectron2 standard format. 
(See\n        `Using Custom Datasets </tutorials/datasets.html>`_ )\n    \"\"\"\n\n    def _convert_category_id(segment_info, meta):\n        if segment_info[\"category_id\"] in meta[\"thing_dataset_id_to_contiguous_id\"]:\n            segment_info[\"category_id\"] = meta[\"thing_dataset_id_to_contiguous_id\"][\n                segment_info[\"category_id\"]\n            ]\n            segment_info[\"isthing\"] = True\n        else:\n            segment_info[\"category_id\"] = meta[\"stuff_dataset_id_to_contiguous_id\"][\n                segment_info[\"category_id\"]\n            ]\n            segment_info[\"isthing\"] = False\n        return segment_info\n\n    with PathManager.open(json_file) as f:\n        json_info = json.load(f)\n    \n    with PathManager.open(grounding_file) as f:\n        grounding_info = json.load(f)\n\n    # build dictionary for grounding\n    grd_dict = collections.defaultdict(list)\n    for grd_ann in grounding_info['annotations']:\n        image_id = int(grd_ann[\"image_id\"])\n        grd_dict[image_id].append(grd_ann)\n    \n    ret = []\n    for ann in json_info[\"annotations\"]:\n        image_id = int(ann[\"image_id\"])\n        # TODO: currently we assume image and label has the same filename but\n        # different extension, and images have extension \".jpg\" for COCO. 
Need\n        # to make image extension a user-provided argument if we extend this\n        # function to support other COCO-like datasets.\n        image_file = os.path.join(image_dir, os.path.splitext(ann[\"file_name\"])[0] + \".jpg\")\n        label_file = os.path.join(gt_dir, ann[\"file_name\"])\n        sem_label_file = os.path.join(semseg_dir, ann[\"file_name\"])\n        segments_info = [_convert_category_id(x, meta) for x in ann[\"segments_info\"]]\n\n        grounding_anno = grd_dict[image_id] if image_id in grd_dict else []\n        ret.append(\n            {\n                \"file_name\": image_file,\n                \"image_id\": image_id,\n                \"grounding_info\": grounding_anno,\n                \"pan_seg_file_name\": label_file,\n                \"sem_seg_file_name\": sem_label_file,\n                \"segments_info\": segments_info,\n            }\n        )\n    assert len(ret), f\"No images found in {image_dir}!\"\n    assert PathManager.isfile(ret[0][\"file_name\"]), ret[0][\"file_name\"]\n    assert PathManager.isfile(ret[0][\"pan_seg_file_name\"]), ret[0][\"pan_seg_file_name\"]\n    assert PathManager.isfile(ret[0][\"sem_seg_file_name\"]), ret[0][\"sem_seg_file_name\"]\n    return ret\n\n\ndef register_coco_panoptic_annos_caption_grounding_sem_seg(\n    name, metadata, image_root, panoptic_root, panoptic_json, sem_seg_root, grounding_root, instances_json\n):\n    panoptic_name = '_'.join(name.split('_')[0:4])\n    delattr(MetadataCatalog.get(panoptic_name), \"thing_classes\")\n    delattr(MetadataCatalog.get(panoptic_name), \"thing_colors\")\n    MetadataCatalog.get(panoptic_name).set(\n        thing_classes=metadata[\"thing_classes\"],\n        thing_colors=metadata[\"thing_colors\"],\n        # thing_dataset_id_to_contiguous_id=metadata[\"thing_dataset_id_to_contiguous_id\"],\n    )\n    \n    # the name is \"coco_2017_train_panoptic_with_sem_seg\" and \"coco_2017_val_panoptic_with_sem_seg\"\n    semantic_name = name + 
\"_with_sem_seg_caption_grounding\"\n    DatasetCatalog.register(\n        semantic_name,\n        lambda: load_coco_panoptic_json(panoptic_json, image_root, panoptic_root, sem_seg_root, grounding_root, metadata),\n    )\n\n    MetadataCatalog.get(semantic_name).set(\n        sem_seg_root=sem_seg_root,\n        panoptic_root=panoptic_root,\n        image_root=image_root,\n        panoptic_json=panoptic_json,\n        json_file=instances_json,\n        evaluator_type=\"coco_panoptic_seg_interactive\",\n        ignore_label=255,\n        label_divisor=1000,\n        **metadata,\n    )\n\n\ndef register_all_coco_panoptic_annos_caption_grounding_sem_seg(root):\n    for (\n        prefix,\n        (panoptic_root, panoptic_json, semantic_root, grounding_root),\n    ) in _PREDEFINED_SPLITS_COCO_PANOPTIC_CAPTION.items():\n        prefix_instances = '_'.join(prefix.split('_')[0:3])\n        instances_meta = MetadataCatalog.get(prefix_instances)\n        image_root, instances_json = instances_meta.image_root, instances_meta.json_file\n        # image_root = image_root.replace('datasets', root)\n\n        register_coco_panoptic_annos_caption_grounding_sem_seg(\n            prefix,\n            get_metadata(),\n            image_root,\n            os.path.join(root, panoptic_root),\n            os.path.join(root, panoptic_json),\n            os.path.join(root, semantic_root),\n            os.path.join(root, grounding_root),\n            \n            os.path.join(root, instances_json),\n        )\n\n\n_root = os.getenv(\"DATASET\", \"datasets\")\nregister_all_coco_panoptic_annos_caption_grounding_sem_seg(_root)"
  },
  {
    "path": "datasets_os/registration/register_flickr_dataset.py",
    "content": "# Copyright (c) Facebook, Inc. and its affiliates.\n# --------------------------------------------------------\n# X-Decoder -- Generalized Decoding for Pixel, Image, and Language\n# Copyright (c) 2022 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# Modified by Xueyan Zou (xueyan@cs.wisc.edu)\n# --------------------------------------------------------\nimport json\nimport os\nimport collections\n\nfrom detectron2.data import DatasetCatalog, MetadataCatalog\nfrom detectron2.utils.file_io import PathManager\n\n\n_PREDEFINED_SPLITS = {\n\n    \"flickr_val\": (\n        \"flickr30k_entities/val\", # image_root\n        \"final_flickr_separateGT_val.json\", # # anno_path\n    ),\n    \"flickr_train\": (\n        \"flickr30k_entities/train\", # image_root\n        \"final_flickr_separateGT_train.json\", # # anno_path\n    ),\n}\n\n\ndef get_metadata():\n    meta = {}\n    return meta\n\n\ndef load_flickr_json(image_root, annot_json, metadata):\n\n\n    with PathManager.open(annot_json) as f:\n        json_info = json.load(f)\n        \n    # build dictionary for grounding\n    grd_dict = collections.defaultdict(list)\n    # for grd_ann in json_info['annotations']:\n    #     image_id = int(grd_ann[\"image_id\"])\n    #     grd_dict[image_id].append(grd_ann)\n    for grd_ann in json_info['annotations']:\n        image_id = int(grd_ann[\"image_id\"])\n        grd_dict[image_id].append(grd_ann)\n\n    ret = []\n    for image in json_info[\"images\"]:\n        image_id = int(image[\"id\"])\n        caption=image['caption']\n        image_file = os.path.join(image_root, image['file_name'])\n        grounding_anno = grd_dict[image_id]\n        ret.append(\n            {\n                \"file_name\": image_file,\n                \"image_id\": image_id,\n                \"grounding_info\": grounding_anno,\n                \"caption\": caption,\n            }\n        )\n    assert len(ret), f\"No images found in {image_root}!\"\n    
assert PathManager.isfile(ret[0][\"file_name\"]), ret[0][\"file_name\"]\n    return ret\n\n\ndef register_flickr(\n    name, metadata, image_root, annot_json):\n    DatasetCatalog.register(\n        name,\n        lambda: load_flickr_json(image_root, annot_json, metadata),\n    )\n    MetadataCatalog.get(name).set(\n        image_root=image_root,\n        json_file=annot_json,\n        evaluator_type=\"grounding_refcoco\",\n        ignore_label=255,\n        label_divisor=1000,\n        **metadata,\n    )\n\n\ndef register_all_flickr(root,anno_root):\n    for (\n        prefix,\n        (image_root, anno_path),\n    ) in _PREDEFINED_SPLITS.items():\n        register_flickr(\n            prefix,\n            get_metadata(),\n            os.path.join(root, image_root),\n            os.path.join(root,anno_root, anno_path),\n        )\n\n_root = os.getenv(\"DATASET\", \"datasets\")\nann_root = os.getenv(\"Flickr\", \"flickr30k_entities/annotations\")\nregister_all_flickr(_root,ann_root)\n"
  },
  {
    "path": "datasets_os/registration/register_vg_dataset.py",
    "content": "# Copyright (c) Facebook, Inc. and its affiliates.\n# --------------------------------------------------------\n# X-Decoder -- Generalized Decoding for Pixel, Image, and Language\n# Copyright (c) 2022 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# Modified by Xueyan Zou (xueyan@cs.wisc.edu)\n# --------------------------------------------------------\nimport json\nimport os\nimport collections\n\nfrom detectron2.data import DatasetCatalog, MetadataCatalog\nfrom detectron2.utils.file_io import PathManager\n\n\n_PREDEFINED_SPLITS = {\n\n    \"vg_train\": (\n        \"vg/images/\", # image_root\n        \"train.json\", # anno_path\n    ),\n}\n\n\ndef get_metadata():\n    meta = {}\n    return meta\n\n\ndef load_vg_json(image_root, annot_json, metadata):\n\n\n    with PathManager.open(annot_json) as f:\n        json_info = json.load(f)\n        \n    # build dictionary for grounding\n    grd_dict = collections.defaultdict(list)\n    for grd_ann in json_info['annotations']:\n        image_id = int(grd_ann[\"image_id\"])\n        grd_dict[image_id].append(grd_ann)\n\n    ret = []\n    for image in json_info[\"images\"]:\n        image_id = int(image[\"id\"])\n        image_file = os.path.join(image_root, image['file_name'])\n        grounding_anno = grd_dict[image_id]\n        ret.append(\n            {\n                \"file_name\": image_file,\n                \"image_id\": image_id,\n                \"annotations\": grounding_anno,\n            }\n        )\n    assert len(ret), f\"No images found in {image_root}!\"\n    assert PathManager.isfile(ret[0][\"file_name\"]), ret[0][\"file_name\"]\n    return ret\n\n\ndef register_vg(\n    name, metadata, image_root, annot_json):\n    DatasetCatalog.register(\n        name,\n        lambda: load_vg_json(image_root, annot_json, metadata),\n    )\n    MetadataCatalog.get(name).set(\n        image_root=image_root,\n        json_file=annot_json,\n        
evaluator_type=\"grounding_refcoco\",\n        ignore_label=255,\n        label_divisor=1000,\n        **metadata,\n    )\n\n\ndef register_all_vg(root,anno_root):\n    for (\n        prefix,\n        (image_root, anno_path),\n    ) in _PREDEFINED_SPLITS.items():\n        register_vg(\n            prefix,\n            get_metadata(),\n            os.path.join(root, image_root),\n            os.path.join(root,anno_root, anno_path),\n        )\n\n_root = os.getenv(\"DATASET\", \"datasets\")\nanno_root = os.getenv(\"VG\", \"vg/annotations/\")\nregister_all_vg(_root,anno_root)\n"
  },
  {
    "path": "datasets_os/semseg_loader.py",
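    "content": "from PIL import Image\nimport scipy.io\nimport numpy as np\n\ndef load_semseg(filename, loader_type):\n    if loader_type == 'PIL':\n        # np.int was a deprecated alias of the builtin int and was removed in NumPy 1.24\n        semseg = np.array(Image.open(filename), dtype=int)\n    elif loader_type == 'MAT':\n        semseg = scipy.io.loadmat(filename)['LabelMap']\n    else:\n        raise ValueError('Unknown loader_type: {}'.format(loader_type))\n    return semseg"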
  },
  {
    "path": "docs/MODEL_ZOO.md",
    "content": "# LLaVA-Grounding Checkpoints\n\nWe will continuously update the model zoo.\n\n| Model Name | LLM version | Model Config | Weights |\n|------------|:---------------:|:-------------:|:-----------:|\n| LLaVA_Grounding_v0_7b | vicuna-v0-7b | [[grounding-module-cfg](https://github.com/UX-Decoder/LLaVA-Grounding/blob/main/configs/openseed/openseed_swint_lang_joint_2st_visual_prompt.yaml), [visual-prompt-module-cfg](https://github.com/UX-Decoder/LLaVA-Grounding/blob/main/configs/semsam/visual_prompt_encoder.yaml)]<br>(0.3B in total) | [HuggingFace](https://huggingface.co/Haozhangcx/llava_grounding_gd_vp) |\n"
  },
  {
    "path": "gradio_demo/LLaVA_G_Demo.py",
    "content": "\nimport gradio as gr\nimport os\nimport cv2\n\nimport torch\nimport numpy as np\nfrom llava.eval.LLaVA_G_Eval import Evaluator_MM_Inter\nfrom llava import conversation as conversation_lib\nfrom llava.mm_utils import tokenizer_image_token\nfrom llava.constants import DEFAULT_IMAGE_TOKEN\n\ndef get_image_name(dir_save=\"./gradio_demo/tmp_files\", prefix=\"click_img_\"):\n    import os\n    files = os.listdir(dir_save)\n    file_orders = [int(file.split(\".\")[0][len(prefix):]) for file in files if file.endswith(\".jpg\") and file.startswith(prefix)]\n    if len(file_orders) == 0:\n        return os.path.join(dir_save, prefix + \"0.jpg\")\n    else:\n        return os.path.join(dir_save, prefix + str(max(file_orders) + 1) + \".jpg\")\ndef preprocess_multi_conv(\n    sources,\n    tokenizer,\n    has_image = False\n):\n    conv = conversation_lib.default_conversation.copy()\n    roles = {\"human\": conv.roles[0], \"gpt\": conv.roles[1]}\n    conv.messages = []\n    for i, source in enumerate(sources):\n        if roles[source[0][\"from\"]] != conv.roles[0]:\n            # Skip the first one if it is not from human\n            source = source[1:]\n\n        for j, sentence in enumerate(source):\n            role = roles[sentence[\"from\"]]\n            assert role == conv.roles[j % 2], f\"{i}\"\n            conv.append_message(role, sentence[\"value\"])\n    conv_prompt = conv.get_prompt()\n    conv_prompt = \"ASSISTANT: \".join(conv_prompt.split(\"ASSISTANT: \")[:-1]) + \"ASSISTANT:\"\n    conv_prompt = conv_prompt.replace(\"</s>\", \"\")\n    conversations = [conv_prompt]\n    print(\"Input Prompt: \", conv_prompt)\n\n    # Tokenize conversations\n\n    if has_image:\n        input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)\n    else:\n        input_ids = tokenizer(\n            conversations,\n            return_tensors=\"pt\",\n            padding=\"longest\",\n            
max_length=tokenizer.model_max_length,\n            truncation=True,\n        ).input_ids\n\n    targets = input_ids.clone()\n\n    assert conv.sep_style == conversation_lib.SeparatorStyle.TWO\n\n    return dict(\n        input_ids=input_ids,\n        labels=targets,\n    )\ndef filter_empty_box_mask(text, boxes_image, masks_image):\n    def extract_text(sentence):\n        # Use regular expression to find and extract the text and number\n        import re\n        pattern = r\"<g_s>|<g_e> <seg>\"\n        cleaned_text = re.sub(pattern, '', sentence)\n        return cleaned_text\n    if len(boxes_image) == 0:\n        return text, boxes_image, masks_image\n    else:\n        sub_texts = text.split(\" <seg>\")\n        sub_texts_filtered = []\n        boxes_image_filtered = []\n        masks_image_filtered = []\n        for box_per_gd, mask_per_gd, text_per_gd in zip(boxes_image, masks_image, sub_texts):\n            text_per_gd += \" <seg>\"\n            ind_nonempty_box = torch.where(box_per_gd.abs().sum(dim=1)>0)\n            if len(ind_nonempty_box[0]) < box_per_gd.shape[0]:  # empty box encountered\n                if len(ind_nonempty_box[0]) == 0:\n                    text_per_gd = \" \" + \" \".join(extract_text(text_per_gd).split())\n                    sub_texts_filtered.append(text_per_gd)  # box is desperated\n                    continue\n                else:\n                    box_per_gd = box_per_gd[ind_nonempty_box]\n                    mask_per_gd = mask_per_gd[ind_nonempty_box]\n                    boxes_image_filtered.append(box_per_gd)\n                    masks_image_filtered.append(mask_per_gd)\n                    sub_texts_filtered.append(text_per_gd)\n            else:\n                boxes_image_filtered.append(box_per_gd)\n                masks_image_filtered.append(mask_per_gd)\n                sub_texts_filtered.append(text_per_gd)\n        sub_texts_filtered.append(sub_texts[-1])\n        text_filtered = 
\"\".join(sub_texts_filtered)\n        return text_filtered, boxes_image_filtered, masks_image_filtered\n\nclass InferenceDemo(object):\n    def __init__(self, \n                 model_path, \n                 path_vision_cfg, \n                 path_inter_cfg, \n    ) -> None:\n        self.model_backend = Evaluator_MM_Inter(\n            model_path=model_path,\n            path_vision_model_cfg=path_vision_cfg,\n            path_inter_model_cfg =path_inter_cfg,\n        )\n        self.model_backend.data_mapper.preprocess = preprocess_multi_conv\n\n    def hitory2datadict(self, history, text):\n        def filter_valid_conversations(history):\n            def delete_color(text):\n                import re\n                pattern = re.compile(r'<span style=.*?>(.*?)</span>', re.DOTALL)  \n                \n                clean_text = pattern.sub(r'\\1', text)  \n                \n                return clean_text\n\n            valid_conversations = history[3:]\n            valid_conversations = [aa for aa in valid_conversations if not (None in aa)]\n            valid_conversations = [[delete_color(aa[0]), delete_color(aa[1])] for aa in valid_conversations]\n            return valid_conversations\n        valid_conversations = filter_valid_conversations(history)\n        dataset_dict = {\n            \"file_name\": history[1][0][0],\n            \"image_id\": 0,\n            \"question_id\": 0,\n        }\n        dataset_dict['conversations'] = []\n        for valid_conv in valid_conversations:\n            conv = [\n                {\n                    \"from\": \"human\", \n                    \"value\": valid_conv[0]\n                },\n                {\n                    \"from\": \"gpt\", \n                    \"value\": valid_conv[1]\n                }\n            ]\n            dataset_dict['conversations'].append([conv, None])\n        conv = [\n            {\n                \"from\": \"human\", \n                \"value\": text\n            
},\n            {\n                \"from\": \"gpt\", \n                \"value\": \"Placeholder.\"\n            }\n        ]\n        dataset_dict['conversations'].append([conv, None])\n        dataset_dict['conversations'][0][0][0][\"value\"] = DEFAULT_IMAGE_TOKEN + \" \" + dataset_dict['conversations'][0][0][0][\"value\"]\n\n        return dataset_dict\n    def inference(self, data_dict):\n        # TODO: Implement data_mapper.\n        data_dict = self.model_backend.data_mapper(data_dict)[0]\n        #\n        device = self.model_backend.model.device\n        for key, value in data_dict.items():\n            if isinstance(value, torch.Tensor):\n                data_dict[key] = value.to(device)\n        \n        response_text, response_boxes, response_mask, mask_inter = self.model_backend.evaluate_sample([data_dict])\n        #\n        response_text, response_boxes, response_mask = filter_empty_box_mask(response_text, response_boxes, response_mask)\n        return response_text, response_boxes, response_mask, mask_inter\n\n\ndef generate_distinct_colors(count):\n    import colorsys\n    import random\n    random.seed(0)\n    hues = [i/count for i in range(count)]\n    random.shuffle(hues)\n\n    colors = []\n    for hue in hues:\n        rgb = colorsys.hsv_to_rgb(hue, 1, 1)\n        rgb = tuple(int(val * 255) for val in rgb)\n        colors.append(rgb)\n\n    return colors\n\n\ndef add_text(history, text, image, threshold_slider, temporature_slider, interaction_selector):\n    # add a text to history stream. 
and leave the response as None for you to fill in bot.\n    def response2stream(response, question):\n        return [[question, response]]\n    def post_process_text_response(text):\n        def find_start_idxes(sentence, word):\n            window_size = len(word)\n            start_indexes = []\n            assert len(sentence) > window_size\n            if sentence == window_size:\n                return [0]\n            for start_index in range(len(sentence) - window_size):\n                if sentence[start_index: start_index + window_size] == word:\n                    start_indexes.append(start_index)\n            return start_indexes\n        def add_color_to_text(obj_id, text):\n            color = colors[obj_id]\n            text = f\"<span style='color: rgb{color};'>{text}</span>\"\n            return text\n        def format_sentence(splitted_sentence):\n            joint_sentence = \" \".join(splitted_sentence)\n            return joint_sentence\n        def extract_text(sentence):\n            import re\n            pattern = r\"<g_s>|<g_e>\"\n            cleaned_text = re.sub(pattern, '', sentence)\n            return cleaned_text\n        \n        text_pure = \"\"\n        seg_start_index = find_start_idxes(text, \"<seg>\")\n        if len(seg_start_index) > 0:\n            count_obj = 0\n            subtexts = text.split(\" <seg>\")\n            for subtext in subtexts:\n                if \"<g_s>\" in subtext:\n                    start_idx = find_start_idxes(subtext, \"<g_s>\")[0]\n                    text_pure = format_sentence([text_pure, format_sentence(subtext[:start_idx].split())])\n                    text_ = extract_text(subtext[start_idx:])\n                    text_pure += add_color_to_text(count_obj, text_)\n                    count_obj += 1\n                else:\n                    text_pure = format_sentence([text_pure, format_sentence(subtext.split())])\n        else:\n            text_pure = text\n        return text_pure\n    
def post_process_gd_response(path_ori_image, gd_results_per_image):\n        def unresize_box(box, width, height):\n            ratio = min(width, height) / max(width, height)\n            if width > height:  # then the height dimension is padded, the y coordinates should be divided by ratio\n                box[:, 1] = box[:, 1] / ratio\n                box[:, 3] = box[:, 3] / ratio\n            elif width < height:  # then the height dimension is padded, the y coordinates should be divided by ratio\n                box[:, 0] = box[:, 0] / ratio\n                box[:, 2] = box[:, 2] / ratio\n            return box\n        image = cv2.imread(path_ori_image)\n        height, width = image.shape[:2]\n        gd_results_per_image = [unresize_box(aa.detach().cpu(), width, height) for aa in gd_results_per_image]\n        for gd_id, gd_result in enumerate(gd_results_per_image):\n            bboxes = gd_result.cpu().tolist()\n            for bbox in bboxes:\n                bbox = [int(bbox[0]*width), int(bbox[1]*height), int(bbox[2]*width), int(bbox[3]*height)]\n                cv2.rectangle(image, (bbox[0], bbox[1]), (bbox[2], bbox[3]), colors[gd_id][::-1], 2)\n        path_save = get_image_name(prefix=\"grounding_img_\")\n        cv2.imwrite(path_save, image)\n        return (path_save, )\n    def post_process_masks(path_ori_image, mask_inter, path_gd_image, masks_gd, loc_inter, inter_type):\n        def unresize_mask(mask, width, height):\n            import torch.nn.functional as F\n            if width >= height:  # then the height dimension is padded, the y coordinates should be divided by ratio\n                mask = F.interpolate((mask[None, ...]).float(), size=[width, width], mode=\"nearest\")[0]\n                mask = mask[:, :height]\n            elif width < height:  # then the height dimension is padded, the y coordinates should be divided by ratio\n                mask = F.interpolate((mask[None, ...]).float(), size=[height, height], 
mode=\"nearest\")[0]\n                mask = mask[:, :, :width]\n            return mask\n        def unnormalize_inter(mask, loc):\n            height, width, _ = mask.shape\n            loc_x_mean, loc_y_mean, loc_w, loc_h = loc\n            if height >= width:\n                loc_x_mean = loc_x_mean / (width/height)\n                loc_w = loc_w / (width/height)\n            else:\n                loc_y_mean = loc_y_mean / (height/width)\n                loc_h = loc_h / (height/width)\n            return [loc_x_mean-loc_w/2, loc_y_mean-loc_h/2, loc_x_mean+loc_w/2, loc_y_mean+loc_h/2]\n        image = cv2.imread(path_ori_image)\n        gd_image = cv2.imread(path_gd_image)\n        returns = []\n        if not (mask_inter is None):\n            mask_ = (mask_inter[0][..., None] > 0).float().cpu().numpy()\n            mask_ = cv2.resize(mask_, (max(image.shape[0], image.shape[1]), max(image.shape[0], image.shape[1])))\n            mask_ = mask_[:image.shape[0], :image.shape[1]]\n            mask_ = mask_[..., None] * np.array([155, 155, 155])[None, None, :]\n            image = (image * 0.5 + mask_ * 0.5).astype(np.uint8)\n            if inter_type.lower() == \"box\":\n                loc_inter_unnormalized = [unnormalize_inter(mask_, loc_inter[0])]\n                thickness = max(int(max(image.shape) / 1000 * 5), 1)\n                print(thickness)\n                cv2.rectangle(image, (int(loc_inter_unnormalized[0][0]*image.shape[1]), int(loc_inter_unnormalized[0][1]*image.shape[0])), (int((loc_inter_unnormalized[0][2])*image.shape[1]), int((loc_inter_unnormalized[0][3])*image.shape[0])), (255, 255, 255), thickness)\n            elif inter_type.lower() == \"click\":\n                loc_inter_unnormalized = [unnormalize_inter(mask_, loc_inter[0])]\n                thickness = max(int(max(image.shape) / 1000 * 10), 1)\n                print(thickness)\n                cv2.circle(image, (int(((loc_inter_unnormalized[0][0] + 
loc_inter_unnormalized[0][2])/2)*image.shape[1]), int(((loc_inter_unnormalized[0][1] + loc_inter_unnormalized[0][3])/2)*image.shape[0])), thickness, (255, 255, 255), thickness=-1)\n            path_save = get_image_name(prefix=\"seg_inter_\")\n            cv2.imwrite(path_save, image)\n            returns.append((path_save, ))\n        else:\n            returns.append(None)\n        if not (masks_gd is None):\n            height, width = image.shape[:2]\n            masks_gd = [unresize_mask(aa, width, height) for aa in masks_gd]\n            colored_mask = torch.zeros((3, height, width), dtype=torch.long)\n            for gd_id, gd_mask_result in enumerate(masks_gd):\n                gd_mask_result = gd_mask_result.sum(dim=0, keepdim=True)\n                colored_mask[:, gd_mask_result[0] > 0.5] = torch.tensor(colors[gd_id][::-1])[:, None]\n            gd_image = (gd_image * 0.6 + colored_mask.permute(1,2,0).numpy() * 0.4).astype(np.uint8)\n            path_save_gd = get_image_name(prefix=\"seg_gd_\")\n            cv2.imwrite(path_save_gd, gd_image)\n            returns.append((path_save_gd, ))\n        else:\n            returns.append(None)\n        return returns\n    def mask2point(mask, inter_type):\n        height, width = mask.shape[:2]\n        ys, xs = np.where(mask[..., 0] == 255)\n        if inter_type.lower() == \"click\":\n            loc_x = xs.mean()\n            loc_y = ys.mean()\n            loc_x = loc_x / width\n            loc_y = loc_y / height\n            if height >= width:\n                loc_x = loc_x * (width/height)\n            else:\n                loc_y = loc_y * (height/width)\n            return torch.tensor([[loc_x, loc_y, 0.006, 0.006]])\n        elif inter_type.lower() == \"box\":\n            loc_x_min = xs.min() / width\n            loc_x_max = xs.max() / width\n            loc_y_min = ys.min() / height\n            loc_y_max = ys.max() / height\n            if height >= width:\n                loc_x_min = loc_x_min * 
(width/height)\n                loc_x_max = loc_x_max * (width/height)\n            else:\n                loc_y_min = loc_y_min * (height/width)\n                loc_y_max = loc_y_max * (height/width)\n            width = loc_x_max - loc_x_min\n            height = loc_y_max - loc_y_min\n            return torch.tensor([[(loc_x_min + loc_x_max)/2, (loc_y_min + loc_y_max)/2, width, height]])\n    if len(history) < 3:\n        response_text = \"Please upload an image first.\"\n    else:\n        loc_inter = mask2point(image[\"mask\"], interaction_selector)\n        is_interactive = torch.isnan(loc_inter[0]).sum() == 0\n        if is_interactive:\n            input_data_dict = our_chatbot.hitory2datadict(history, text)\n            input_data_dict[\"points\"] = loc_inter\n            input_data_dict[\"mode_inter\"] = interaction_selector\n        else:\n            input_data_dict = our_chatbot.hitory2datadict(history, text)\n            input_data_dict[\"points\"] = None\n            input_data_dict[\"mode_inter\"] = None\n        input_data_dict[\"matching_threshold\"] = threshold_slider\n        input_data_dict[\"temporature\"] = temporature_slider\n        response_text, response_gd, response_mask, mask_inter  = our_chatbot.inference(input_data_dict)\n    response_msks  = post_process_masks(history[1][0][0], mask_inter, history[1][0][0], response_mask, loc_inter, interaction_selector)\n    if \"<seg>\" in response_text:\n        response_gd = post_process_gd_response(response_msks[1][0], response_gd)\n        response_msks[1] = list(response_msks[1])\n        response_msks[1][0] = response_gd[0]\n        response_msks[1] = tuple(response_msks[1])\n    response_text  = post_process_text_response(response_text)\n    history += response2stream(response_text, text)\n    for response_msk in response_msks:\n        if not (response_msk is None):\n            history += response2stream(response_msk, None)\n    return history, None\ndef add_image(history, image):\n    
print(\"LOG. Add Image Function is called.\")\n    path_input_img = get_image_name(prefix=\"tmp_input_img_\")\n    cv2.imwrite(path_input_img, image[\"image\"][..., ::-1])\n    if len(history) > 0:\n        history = [(None, \"A new image recieved, I will clear the history conversations.\")]\n    else:\n        history = [(None, None)]  # just to align with the above one, to determin where the image_path is.\n    history = history + [((path_input_img, ), None)]\n    history = history + [(None, \"Let't talk about this image!\")]\n    return history\n\ndef add_interaction_click(history, image, interaction_selector):\n    print(\"LOG. Add Interaction Function is called.\")\n    if interaction_selector.lower() == \"box\":\n        history = history + [(None, \"A more detailed box is specified, lets further talk about the region inside the box.\")]\n    elif interaction_selector.lower() == \"click\":\n        history = history + [(None, \"A more detailed click is specified, lets further talk about the region around the click.\")]\n    \n    mask = image[\"mask\"][..., :3] * np.array([234, 176, 113])\n    image_rgb = image[\"image\"][..., ::-1]\n    image_clicked = (image_rgb * 0.6 + mask * 0.4).astype(np.uint8)\n    path_save = get_image_name(prefix=\"click_img_\")\n    cv2.imwrite(path_save, image_clicked)\n    return history\n\ndef bot(history):\n    yield history\n\ndef clear_history(history, txt, img):\n    return None, None, None\ndef clear_response(history):\n    for index_conv in range(1, len(history)):\n        # loop until get a text response from our model.\n        conv = history[-index_conv]\n        if not (conv[0] is None):\n            break\n    question = history[-index_conv][0]\n    history = history[:-index_conv]\n    return history, question\n    \ndef upvote_one(history):\n    print(\"TODO: Implement upvote_one function.\")\n    pass\ndef downvote_one(history):\n    print(\"TODO: Implement downvote_one function.\")\n    pass\ndef 
flag_one(history):\n    print(\"TODO: Implement flag_one function.\")\n    pass\n#? defined here for later renderring.\ntxt = gr.Textbox(\n    scale=4,\n    show_label=False,\n    placeholder=\"Enter text and press enter, or upload an image. Append '(with grounding)' if you want to do grounding.\",\n    container=False,\n)\nwith gr.Blocks() as demo:\n    # Informations\n    title_markdown = (\"\"\"\n        # LLaVA-Grounding: Grounded Visual Chat with Large Multimodal Models\n        [[Project Page](https://llava-vl.github.io/llava-grounding)] [[Arxiv](https://arxiv.org/abs/2312.02949)]  [[Demo](http://llava-grounding.xyzou.net:6084)]  [[Model](https://huggingface.co/Haozhangcx/llava_grounding_gd_vp)] \n    \"\"\")\n    tips_markdown = (\"\"\"\n    **Tips for better results**\n    1. Adjust 'Threshold' according to the results or change your expression and click 'Regenerate' may help get better results. \n    2. Set temporature to 0.0 get reproducible results.\n    \"\"\")\n    tos_markdown = (\"\"\"\n    **Terms of use:**   By using this service, users are required to agree to the following terms:\n    The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.\n    Please click the \"Flag\" button if you get any inappropriate answer! 
We will collect those to keep improving our moderator.\n    For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.\n    \"\"\")\n    learn_more_markdown = (\"\"\"\n    **License:**   The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.\n    \"\"\")\n    models = [\n        \"LLaVA-Grounding-7B\",\n    ]\n    interactions = [\n        \"Click\",\n        \"Box\"\n    ]\n    cur_dir = os.path.dirname(os.path.abspath(__file__))\n    gr.Markdown(title_markdown)\n    with gr.Row():\n        with gr.Column(min_width=300, scale=0.4):\n            model_selector = gr.Dropdown(\n                choices=models,\n                value=models[0] if len(models) > 0 else \"\",\n                interactive=True,\n                show_label=False,\n                container=False)\n            img = gr.Image(\n                type=\"numpy\",\n                # label=\"Image\",\n                height=220,\n                tool=\"sketch\", \n                interactive=True\n            )\n            img_upload_btn = gr.Button(\"Submit Image\")\n            with gr.Row():\n                inter_upload_btn = gr.Button(\"Submit Interaction\")\n                interaction_selector = gr.Dropdown(\n                    choices=interactions,\n                    value=interactions[0] if len(interactions) > 0 else \"\",\n                    interactive=True,\n                    show_label=False,\n                    container=False,\n                )\n\n            with gr.Row():\n                temperature = 
gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, interactive=True, label=\"Temperature\")\n                threshold = gr.Slider(minimum=0.0, maximum=1.0, value=0.05, step=0.1, interactive=True, label=\"Threshold\")\n    \n            # possibly a bug in gradio: https://github.com/gradio-app/gradio/issues/3623\n            gr.Examples(examples=[\n                [f\"{cur_dir}/examples/meeting.jpg\", \"Describe the scene in detail. (with grounding)\"],\n                [f\"{cur_dir}/examples/pizza.jpg\", \"Describe the scene in detail. (with grounding)\"],\n                \n            ], inputs=[img, txt], label=\"Grounded Description Examples: \")\n            gr.Examples(examples=[\n                [f\"{cur_dir}/examples/cow_motor.jpg\", \"Where is the object % and what is it doing?\"],\n                [f\"{cur_dir}/examples/dog_sleep.jpg\", \"What is the object % doing and why?\"],\n                \n            ], inputs=[img, txt], label=\"Visual Prompt Examples (Please draw clicks or boxes on the woman and dog for the two examples, respectively.): \")\n\n        with gr.Column():\n            chatbot = gr.Chatbot(\n                [],\n                elem_id=\"chatbot\",\n                bubble_full_width=False,\n                height=598\n                # avatar_images=(None, (os.path.join(os.path.dirname(__file__), \"avatar.png\"))),\n            )\n\n            with gr.Row():\n                with gr.Column(scale=8):\n                    txt.render()\n                with gr.Column(scale=1, min_width=60):\n                    submit_btn = gr.Button(value=\"Send\")\n            #TODO: Enable these buttons.\n            with gr.Row():\n                upvote_btn = gr.Button(value=\"👍  Upvote\", interactive=True)\n                downvote_btn = gr.Button(value=\"👎  Downvote\", interactive=True)\n                flag_btn = gr.Button(value=\"⚠️  Flag\", interactive=True)\n                #stop_btn = gr.Button(value=\"⏹️  Stop Generation\", 
interactive=True)\n                regenerate_btn = gr.Button(value=\"🔄  Regenerate\", interactive=True)\n                clear_btn = gr.Button(value=\"🗑️  Clear history\", interactive=True)\n            gr.Markdown(tips_markdown)\n            gr.Markdown(tos_markdown)\n            gr.Markdown(learn_more_markdown)\n    if os.path.isfile(\"gradio_demo/examples/demo_grounding.mp4\"):  # only online demo\n        gr.Markdown(\"-----------------------------------\")\n        gr.Markdown(\"## User's Guidance\")\n        with gr.Row():\n            with gr.Column():\n                gr.Markdown(\"### Grounded Visual Chat\")\n                gr.Video(value=\"gradio_demo/examples/demo_grounding.mp4\")\n            with gr.Column():\n                gr.Markdown(\"### Visual Prompt (Click)\")\n                gr.Video(value=\"gradio_demo/examples/demo_inter_click.mp4\")\n            with gr.Column():\n                gr.Markdown(\"### Visual Prompt (Box)\")\n                gr.Video(value=\"gradio_demo/examples/demo_inter_box.mp4\")\n    txt.submit(add_text, [chatbot, txt, img, threshold, temperature, interaction_selector], [chatbot, txt], queue=False).then(\n        bot, \n        chatbot, chatbot, \n        api_name=\"bot_text_response\"\n    )\n    submit_btn.click(fn=add_text, inputs=[chatbot, txt, img, threshold, temperature, interaction_selector], outputs=[chatbot, txt]).then(\n        bot, \n        chatbot, chatbot, \n        api_name=\"submit_text\"\n    )\n    img_upload_btn.click(fn=add_image, inputs=[chatbot, img], outputs=[chatbot], api_name=\"upload_image\")\n    inter_upload_btn.click(fn=add_interaction_click, inputs=[chatbot, img, interaction_selector], outputs=[chatbot], api_name=\"upload_inter\")\n    # buttons\n    clear_btn.click(fn=clear_history, inputs=[chatbot, txt, img], outputs=[chatbot, txt, img], api_name=\"clear_all\")\n    regenerate_btn.click(fn=clear_response, inputs=[chatbot], outputs=[chatbot, txt], api_name=\"clear_last_response\").then(\n   
     add_text, [chatbot, txt, img, threshold, temperature, interaction_selector], [chatbot, txt], queue=False).then(\n        bot, \n        chatbot, chatbot, \n        api_name=\"regenerate_response\"\n    )\n    upvote_btn.click(fn=upvote_one, inputs=[], outputs=[], api_name=\"upvote_one\")\n    downvote_btn.click(fn=downvote_one, inputs=[], outputs=[], api_name=\"downvote_one\")\n    flag_btn.click(fn=flag_one, inputs=[], outputs=[], api_name=\"flag_one\")\n\ndemo.queue()\nif __name__ == \"__main__\":\n    import argparse\n    argparser = argparse.ArgumentParser()\n    argparser.add_argument(\"--server_name\", default=\"0.0.0.0\", type=str)\n    argparser.add_argument(\"--port\", default=12124, type=str)\n    argparser.add_argument(\"--model_path\", default=\"\", type=str)\n    argparser.add_argument(\"--path_vision_cfg\", default=\"configs/openseed/openseed_swint_lang_joint_2st_v2_data_end_with_interaction.yaml\", type=str)\n    argparser.add_argument(\"--path_inter_cfg\", default=\"configs/semsam/idino_swint_1_part_data_llm_ref_feat_all_16_det_pretrainv1.yaml\", type=str)\n    args = argparser.parse_args()\n    model_path = args.model_path\n    colors = generate_distinct_colors(20)\n    if not os.path.exists(\"./gradio_demo/tmp_files\"):\n        os.makedirs(\"./gradio_demo/tmp_files\")\n    our_chatbot = InferenceDemo(args.model_path, args.path_vision_cfg, args.path_inter_cfg)\n    demo.launch(server_name=args.server_name, server_port=int(args.port))\n    "
  },
  {
    "path": "gradio_demo/__init__.py",
    "content": ""
  },
  {
    "path": "llava/__init__.py",
    "content": "from .model import LlavaLlamaForCausalLM\n"
  },
  {
    "path": "llava/constants.py",
    "content": "CONTROLLER_HEART_BEAT_EXPIRATION = 30\nWORKER_HEART_BEAT_INTERVAL = 15\n\nLOGDIR = \".\"\n\n# Model Constants\nIGNORE_INDEX = -100\nIMAGE_TOKEN_INDEX = -200\nDEFAULT_IMAGE_TOKEN = \"<image>\"\nDEFAULT_IMAGE_PATCH_TOKEN = \"<im_patch>\"\nDEFAULT_IM_START_TOKEN = \"<im_start>\"\nDEFAULT_IM_END_TOKEN = \"<im_end>\"\n"
  },
  {
    "path": "llava/conversation.py",
    "content": "import dataclasses\nfrom enum import auto, Enum\nfrom typing import List, Tuple\n\n\nclass SeparatorStyle(Enum):\n    \"\"\"Different separator style.\"\"\"\n    SINGLE = auto()\n    TWO = auto()\n    MPT = auto()\n    PLAIN = auto()\n    LLAMA_2 = auto()\n\n\n@dataclasses.dataclass\nclass Conversation:\n    \"\"\"A class that keeps all conversation history.\"\"\"\n    system: str\n    roles: List[str]\n    messages: List[List[str]]\n    offset: int\n    sep_style: SeparatorStyle = SeparatorStyle.SINGLE\n    sep: str = \"###\"\n    sep2: str = None\n    version: str = \"Unknown\"\n\n    skip_next: bool = False\n\n    def get_prompt(self):\n        messages = self.messages\n        if len(messages) > 0 and type(messages[0][1]) is tuple:\n            messages = self.messages.copy()\n            init_role, init_msg = messages[0].copy()\n            init_msg = init_msg[0].replace(\"<image>\", \"\").strip()\n            if 'mmtag' in self.version:\n                messages[0] = (init_role, init_msg)\n                messages.insert(0, (self.roles[0], \"<Image><image></Image>\"))\n                messages.insert(1, (self.roles[1], \"Received.\"))\n            else:\n                messages[0] = (init_role, \"<image>\\n\" + init_msg)\n\n        if self.sep_style == SeparatorStyle.SINGLE:\n            ret = self.system + self.sep\n            for role, message in messages:\n                if message:\n                    if type(message) is tuple:\n                        message, _, _ = message\n                    ret += role + \": \" + message + self.sep\n                else:\n                    ret += role + \":\"\n        elif self.sep_style == SeparatorStyle.TWO:\n            seps = [self.sep, self.sep2]\n            ret = self.system + seps[0]\n            for i, (role, message) in enumerate(messages):\n                if message:\n                    if type(message) is tuple:\n                        message, _, _ = message\n                    
ret += role + \": \" + message + seps[i % 2]\n                else:\n                    ret += role + \":\"\n        elif self.sep_style == SeparatorStyle.MPT:\n            ret = self.system + self.sep\n            for role, message in messages:\n                if message:\n                    if type(message) is tuple:\n                        message, _, _ = message\n                    ret += role + message + self.sep\n                else:\n                    ret += role\n        elif self.sep_style == SeparatorStyle.LLAMA_2:\n            wrap_sys = lambda msg: f\"<<SYS>>\\n{msg}\\n<</SYS>>\\n\\n\"\n            wrap_inst = lambda msg: f\"[INST] {msg} [/INST]\"\n            ret = \"\"\n\n            for i, (role, message) in enumerate(messages):\n                if i == 0:\n                    assert message, \"first message should not be none\"\n                    assert role == self.roles[0], \"first message should come from user\"\n                if message:\n                    if type(message) is tuple:\n                        message, _, _ = message\n                    if i == 0: message = wrap_sys(self.system) + message\n                    if i % 2 == 0:\n                        message = wrap_inst(message)\n                        ret += self.sep + message\n                    else:\n                        ret += \" \" + message + \" \" + self.sep2\n                else:\n                    ret += \"\"\n            ret = ret.lstrip(self.sep)\n        elif self.sep_style == SeparatorStyle.PLAIN:\n            seps = [self.sep, self.sep2]\n            ret = self.system\n            for i, (role, message) in enumerate(messages):\n                if message:\n                    if type(message) is tuple:\n                        message, _, _ = message\n                    ret += message + seps[i % 2]\n                else:\n                    ret += \"\"\n        else:\n            raise ValueError(f\"Invalid style: {self.sep_style}\")\n\n        
return ret\n\n    def append_message(self, role, message):\n        self.messages.append([role, message])\n\n    def get_images(self, return_pil=False):\n        images = []\n        for i, (role, msg) in enumerate(self.messages[self.offset:]):\n            if i % 2 == 0:\n                if type(msg) is tuple:\n                    import base64\n                    from io import BytesIO\n                    from PIL import Image\n                    msg, image, image_process_mode = msg\n                    if image_process_mode == \"Pad\":\n                        def expand2square(pil_img, background_color=(122, 116, 104)):\n                            width, height = pil_img.size\n                            if width == height:\n                                return pil_img\n                            elif width > height:\n                                result = Image.new(pil_img.mode, (width, width), background_color)\n                                result.paste(pil_img, (0, (width - height) // 2))\n                                return result\n                            else:\n                                result = Image.new(pil_img.mode, (height, height), background_color)\n                                result.paste(pil_img, ((height - width) // 2, 0))\n                                return result\n                        image = expand2square(image)\n                    elif image_process_mode == \"Crop\":\n                        pass\n                    elif image_process_mode == \"Resize\":\n                        image = image.resize((336, 336))\n                    else:\n                        raise ValueError(f\"Invalid image_process_mode: {image_process_mode}\")\n                    max_hw, min_hw = max(image.size), min(image.size)\n                    aspect_ratio = max_hw / min_hw\n                    max_len, min_len = 800, 400\n                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))\n                   
 longest_edge = int(shortest_edge * aspect_ratio)\n                    W, H = image.size\n                    if H > W:\n                        H, W = longest_edge, shortest_edge\n                    else:\n                        H, W = shortest_edge, longest_edge\n                    image = image.resize((W, H))\n                    if return_pil:\n                        images.append(image)\n                    else:\n                        buffered = BytesIO()\n                        image.save(buffered, format=\"PNG\")\n                        img_b64_str = base64.b64encode(buffered.getvalue()).decode()\n                        images.append(img_b64_str)\n        return images\n\n    def to_gradio_chatbot(self):\n        ret = []\n        for i, (role, msg) in enumerate(self.messages[self.offset:]):\n            if i % 2 == 0:\n                if type(msg) is tuple:\n                    import base64\n                    from io import BytesIO\n                    msg, image, image_process_mode = msg\n                    max_hw, min_hw = max(image.size), min(image.size)\n                    aspect_ratio = max_hw / min_hw\n                    max_len, min_len = 800, 400\n                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))\n                    longest_edge = int(shortest_edge * aspect_ratio)\n                    W, H = image.size\n                    if H > W:\n                        H, W = longest_edge, shortest_edge\n                    else:\n                        H, W = shortest_edge, longest_edge\n                    image = image.resize((W, H))\n                    buffered = BytesIO()\n                    image.save(buffered, format=\"JPEG\")\n                    img_b64_str = base64.b64encode(buffered.getvalue()).decode()\n                    img_str = f'<img src=\"data:image/png;base64,{img_b64_str}\" alt=\"user upload image\" />'\n                    ret.append([img_str, None])\n                    msg = 
msg.replace('<image>', '').strip()\n                    if len(msg) > 0:\n                        ret.append([msg, None])\n                else:\n                    ret.append([msg, None])\n            else:\n                ret[-1][-1] = msg\n        return ret\n\n    def copy(self):\n        return Conversation(\n            system=self.system,\n            roles=self.roles,\n            messages=[[x, y] for x, y in self.messages],\n            offset=self.offset,\n            sep_style=self.sep_style,\n            sep=self.sep,\n            sep2=self.sep2,\n            version=self.version)\n\n    def dict(self):\n        if len(self.get_images()) > 0:\n            return {\n                \"system\": self.system,\n                \"roles\": self.roles,\n                \"messages\": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],\n                \"offset\": self.offset,\n                \"sep\": self.sep,\n                \"sep2\": self.sep2,\n            }\n        return {\n            \"system\": self.system,\n            \"roles\": self.roles,\n            \"messages\": self.messages,\n            \"offset\": self.offset,\n            \"sep\": self.sep,\n            \"sep2\": self.sep2,\n        }\n\n\nconv_vicuna_v0 = Conversation(\n    system=\"A chat between a curious human and an artificial intelligence assistant. \"\n           \"The assistant gives helpful, detailed, and polite answers to the human's questions.\",\n    roles=(\"Human\", \"Assistant\"),\n    messages=(\n        (\"Human\", \"What are the key differences between renewable and non-renewable energy sources?\"),\n        (\"Assistant\",\n            \"Renewable energy sources are those that can be replenished naturally in a relatively \"\n            \"short amount of time, such as solar, wind, hydro, geothermal, and biomass. 
\"\n            \"Non-renewable energy sources, on the other hand, are finite and will eventually be \"\n            \"depleted, such as coal, oil, and natural gas. Here are some key differences between \"\n            \"renewable and non-renewable energy sources:\\n\"\n            \"1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable \"\n            \"energy sources are finite and will eventually run out.\\n\"\n            \"2. Environmental impact: Renewable energy sources have a much lower environmental impact \"\n            \"than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, \"\n            \"and other negative effects.\\n\"\n            \"3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically \"\n            \"have lower operational costs than non-renewable sources.\\n\"\n            \"4. Reliability: Renewable energy sources are often more reliable and can be used in more remote \"\n            \"locations than non-renewable sources.\\n\"\n            \"5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different \"\n            \"situations and needs, while non-renewable sources are more rigid and inflexible.\\n\"\n            \"6. Sustainability: Renewable energy sources are more sustainable over the long term, while \"\n            \"non-renewable sources are not, and their depletion can lead to economic and social instability.\\n\")\n    ),\n    offset=2,\n    sep_style=SeparatorStyle.SINGLE,\n    sep=\"###\",\n)\n\nconv_vicuna_v1 = Conversation(\n    system=\"A chat between a curious user and an artificial intelligence assistant. 
\"\n    \"The assistant gives helpful, detailed, and polite answers to the user's questions.\",\n    roles=(\"USER\", \"ASSISTANT\"),\n    version=\"v1\",\n    messages=(),\n    offset=0,\n    sep_style=SeparatorStyle.TWO,\n    sep=\" \",\n    sep2=\"</s>\",\n)\n\nconv_llama_2 = Conversation(\n    system=\"\"\"You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.  Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\"\"\",\n    roles=(\"USER\", \"ASSISTANT\"),\n    version=\"llama_v2\",\n    messages=(),\n    offset=0,\n    sep_style=SeparatorStyle.LLAMA_2,\n    sep=\"<s>\",\n    sep2=\"</s>\",\n)\n\nconv_llava_llama_2 = Conversation(\n    system=\"You are a helpful language and vision assistant. \"\n           \"You are able to understand the visual content that the user provides, \"\n           \"and assist the user with a variety of tasks using natural language.\",\n    roles=(\"USER\", \"ASSISTANT\"),\n    version=\"llama_v2\",\n    messages=(),\n    offset=0,\n    sep_style=SeparatorStyle.LLAMA_2,\n    sep=\"<s>\",\n    sep2=\"</s>\",\n)\n\nconv_mpt = Conversation(\n    system=\"\"\"<|im_start|>system\nA conversation between a user and an LLM-based AI assistant. 
The assistant gives helpful and honest answers.\"\"\",\n    roles=(\"<|im_start|>user\\n\", \"<|im_start|>assistant\\n\"),\n    version=\"mpt\",\n    messages=(),\n    offset=0,\n    sep_style=SeparatorStyle.MPT,\n    sep=\"<|im_end|>\",\n)\n\nconv_llava_plain = Conversation(\n    system=\"\",\n    roles=(\"\", \"\"),\n    messages=(\n    ),\n    offset=0,\n    sep_style=SeparatorStyle.PLAIN,\n    sep=\"\\n\",\n)\n\nconv_llava_v0 = Conversation(\n    system=\"A chat between a curious human and an artificial intelligence assistant. \"\n           \"The assistant gives helpful, detailed, and polite answers to the human's questions.\",\n    roles=(\"Human\", \"Assistant\"),\n    messages=(\n        (\"Human\", \"Hi!\"),\n        (\"Assistant\", \"Hi there! How can I help you today?\")\n    ),\n    offset=2,\n    sep_style=SeparatorStyle.SINGLE,\n    sep=\"###\",\n)\n\nconv_llava_v0_mmtag = Conversation(\n    system=\"A chat between a curious user and an artificial intelligence assistant. \"\n           \"The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.\"\n           \"The visual content will be provided with the following format: <Image>visual content</Image>.\",\n    roles=(\"Human\", \"Assistant\"),\n    messages=(\n    ),\n    offset=0,\n    sep_style=SeparatorStyle.SINGLE,\n    sep=\"###\",\n    version=\"v0_mmtag\",\n)\n\nconv_llava_v1 = Conversation(\n    system=\"A chat between a curious human and an artificial intelligence assistant. \"\n           \"The assistant gives helpful, detailed, and polite answers to the human's questions.\",\n    roles=(\"USER\", \"ASSISTANT\"),\n    version=\"v1\",\n    messages=(),\n    offset=0,\n    sep_style=SeparatorStyle.TWO,\n    sep=\" \",\n    sep2=\"</s>\",\n)\n\nconv_llava_v1_mmtag = Conversation(\n    system=\"A chat between a curious user and an artificial intelligence assistant. 
\"\n           \"The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.\"\n           \"The visual content will be provided with the following format: <Image>visual content</Image>.\",\n    roles=(\"USER\", \"ASSISTANT\"),\n    messages=(),\n    offset=0,\n    sep_style=SeparatorStyle.TWO,\n    sep=\" \",\n    sep2=\"</s>\",\n    version=\"v1_mmtag\",\n)\n\ndefault_conversation = conv_vicuna_v0\nconv_templates = {\n    \"default\": conv_vicuna_v0,\n    \"v0\": conv_vicuna_v0,\n    \"v1\": conv_vicuna_v1,\n    \"vicuna_v1\": conv_vicuna_v1,\n    \"llama_2\": conv_llama_2,\n\n    \"plain\": conv_llava_plain,\n    \"v0_plain\": conv_llava_plain,\n    \"llava_v0\": conv_llava_v0,\n    \"v0_mmtag\": conv_llava_v0_mmtag,\n    \"llava_v1\": conv_llava_v1,\n    \"v1_mmtag\": conv_llava_v1_mmtag,\n    \"llava_llama_2\": conv_llava_llama_2,\n\n    \"mpt\": conv_mpt,\n}\n\n\nif __name__ == \"__main__\":\n    print(default_conversation.get_prompt())\n"
  },
  {
    "path": "llava/eval/LLaVA_G_Eval.py",
    "content": "import os\nimport cv2\nimport json\nimport torch\nimport collections\nimport transformers\nimport numpy as np\nfrom llava.model import *\nfrom typing import Dict\nfrom llava import conversation as conversation_lib\nfrom tqdm import tqdm\nfrom detectron2.utils.file_io import PathManager\nfrom llava.mm_utils import tokenizer_image_token\nfrom llava.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN\nfrom llava.eval.llava_mapper import COCOInstanceNewBaselineDatasetMapper as LLAVAInstanceNewBaselineDatasetMapper\ngrounding_start=\"<g_s>\"\ngrounding_end=\"<g_e>\"\nSEG_TOKEN=\"<seg>\"\nBOX_TOKEN=\"#B#\"\nMARKER_TOKEN=\"#M#\"\n\ndef load_jsonl_file(path_jsonl):\n        import jsonlines\n        data = []\n        with jsonlines.open(path_jsonl, \"r\") as reader:\n            for obj in reader:\n                data.append(obj)\n        return data\ndef save_jsonl_file(data, path_save):\n    import jsonlines\n    with jsonlines.open(path_save, \"w\") as writer:\n        for item in data: \n            writer.write(item) \ndef load_benchmark(image_root, path_benchmark):\n\n    data = load_jsonl_file(path_benchmark)\n    ret = []\n    for d in data:\n        image_name = d[\"image\"]\n        image_id = int(image_name.split(\".\")[0])\n        image_file = os.path.join(image_root, \"COCO_val2014_\" + image_name)\n        # conv = d[\"conversations\"]\n        conv = [\n        {\n            \"from\": \"human\", \n            \"value\": d[\"text\"]\n        },\n        {\n            \"from\": \"gpt\", \n            \"value\": \"Placeholder.\"\n        }\n        ]\n        conv[0][\"value\"] = DEFAULT_IMAGE_TOKEN + \" \" + conv[0][\"value\"] + \" (with grounding)\"\n        ret.append(\n            {\n                \"file_name\": image_file,\n                \"image_id\": image_id,\n                # \"grounding_info\": None,\n                \"conversations\": [[conv, None]],\n                \"question_id\": d[\"question_id\"]\n            }\n        
)\n    assert len(ret), f\"No images found in {image_root}!\"\n    assert PathManager.isfile(ret[0][\"file_name\"]), ret[0][\"file_name\"]\n    return ret\ndef preprocess_v1(\n    sources,\n    tokenizer: transformers.PreTrainedTokenizer,\n    has_image: bool = False\n) -> Dict:\n    conv = conversation_lib.default_conversation.copy()\n    roles = {\"human\": conv.roles[0], \"gpt\": conv.roles[1]}\n\n    # Apply prompt templates\n    conversations = []\n    for i, source in enumerate(sources):\n        if roles[source[0][\"from\"]] != conv.roles[0]:\n            # Skip the first one if it is not from human\n            source = source[1:]\n\n        conv.messages = []\n        for j, sentence in enumerate(source):\n            role = roles[sentence[\"from\"]]\n            assert role == conv.roles[j % 2], f\"{i}\"\n            conv.append_message(role, sentence[\"value\"])\n        conv_prompt = conv.get_prompt()\n        conv_prompt = conv_prompt.split(\"ASSISTANT: \")[0] + \"ASSISTANT:\"\n        conversations.append(conv_prompt)\n\n    # Tokenize conversations\n\n    if has_image:\n        input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)\n    else:\n        input_ids = tokenizer(\n            conversations,\n            return_tensors=\"pt\",\n            padding=\"longest\",\n            max_length=tokenizer.model_max_length,\n            truncation=True,\n        ).input_ids\n\n    targets = input_ids.clone()\n\n    assert conv.sep_style == conversation_lib.SeparatorStyle.TWO\n\n    # Mask targets\n    sep = conv.sep + conv.roles[1] + \": \"\n    for conversation, target in zip(conversations, targets):\n        total_len = int(target.ne(tokenizer.pad_token_id).sum())\n\n        rounds = conversation.split(conv.sep2)\n        cur_len = 1\n        target[:cur_len] = IGNORE_INDEX\n        for i, rou in enumerate(rounds):\n            if rou == \"\":\n                break\n\n            
parts = rou.split(sep)\n            if len(parts) != 2:\n                break\n            parts[0] += sep\n\n            if has_image:\n                round_len = len(tokenizer_image_token(rou, tokenizer))\n                instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2\n            else:\n                round_len = len(tokenizer(rou).input_ids)\n                instruction_len = len(tokenizer(parts[0]).input_ids) - 2\n\n            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX\n\n            cur_len += round_len\n        target[cur_len:] = IGNORE_INDEX\n\n        if cur_len < tokenizer.model_max_length:\n            if cur_len != total_len:\n                target[:] = IGNORE_INDEX\n                print(\n                    f\"WARNING: tokenization mismatch: {cur_len} vs. {total_len}.\"\n                    f\" (ignored)\"\n                )\n\n    return dict(\n        input_ids=input_ids,\n        labels=targets,\n    )\n\n\nclass Evaluator_MM:\n    def __init__(self, \n                 model_path, \n                 path_vision_model_cfg=None):\n        model_paths = model_path.split(\"/\")\n        if model_paths[-1].startswith('checkpoint-'):\n            self.model_name = model_paths[-2] + \"_\" + model_paths[-1]\n        else:\n            self.model_name = model_paths[-1]\n        print(\"1. Constructing model...\")\n        self.tokenizer, self.model, self.image_processor, self.context_len = self.construct_model(\n            model_path=model_path,\n            model_name=self.model_name,\n        )\n        print(\"   Continue...\")\n        self.construct_vision_model(path_vision_model_cfg)\n        print(\"Done.\")\n        self.image_processor=self.model.get_vision_tower().image_processor\n        print(\"2. 
Loading Parameters...\")\n        self.load_parameters(model_path)\n        print(\"Done.\")\n        self.model.eval()\n\n        conversation_lib.default_conversation = conversation_lib.conv_templates[\"v1\"]\n        self.data_mapper  = LLAVAInstanceNewBaselineDatasetMapper(self.cfg_vision_model, False, tokenizer=self.tokenizer, image_processor=self.image_processor, preprocess=preprocess_v1)\n    \n    def construct_model(self, model_path, model_base=None, model_name=None, load_8bit=False, load_4bit=False, device_map=\"auto\"):\n        import os\n        import shutil\n        from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig\n        from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN\n\n        kwargs = {\"device_map\": device_map}\n\n        if load_8bit:\n            kwargs['load_in_8bit'] = True\n        elif load_4bit:\n            kwargs['load_in_4bit'] = True\n            kwargs['quantization_config'] = BitsAndBytesConfig(\n                load_in_4bit=True,\n                bnb_4bit_compute_dtype=torch.float16,\n                bnb_4bit_use_double_quant=True,\n                bnb_4bit_quant_type='nf4'\n            )\n        else:\n            kwargs['torch_dtype'] = torch.float16\n\n        if 'llava' in model_name.lower():\n            # Load LLaVA model\n            if 'lora' in model_name.lower() and model_base is not None:\n                lora_cfg_pretrained = AutoConfig.from_pretrained(model_path)\n                tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)\n                print('Loading LLaVA from base model...')\n                model = LlavaLlamaForCausalLM_gd.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)\n                token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features\n                if model.lm_head.weight.shape[0] != token_num:\n                    
model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))\n                    model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))\n\n                print('Loading additional LLaVA weights...')\n                if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):\n                    non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')\n                else:\n                    # this is probably from HF Hub\n                    from huggingface_hub import hf_hub_download\n                    def load_from_hf(repo_id, filename, subfolder=None):\n                        cache_file = hf_hub_download(\n                            repo_id=repo_id,\n                            filename=filename,\n                            subfolder=subfolder)\n                        return torch.load(cache_file, map_location='cpu')\n                    non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin')\n                non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()}\n                if any(k.startswith('model.model.') for k in non_lora_trainables):\n                    non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}\n                model.load_state_dict(non_lora_trainables, strict=False)\n\n                from peft import PeftModel\n                print('Loading LoRA weights...')\n                model = PeftModel.from_pretrained(model, model_path)\n                print('Merging LoRA weights...')\n                model = model.merge_and_unload()\n                print('Model is loaded...')\n            elif model_base is not None:\n                # this may be mm projector only\n                print('Loading LLaVA from base 
model...')\n                if 'mpt' in model_name.lower():\n                    if not os.path.isfile(os.path.join(model_path, 'configuration_mpt.py')):\n                        shutil.copyfile(os.path.join(model_base, 'configuration_mpt.py'), os.path.join(model_path, 'configuration_mpt.py'))\n                    tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)\n                    cfg_pretrained = AutoConfig.from_pretrained(model_path, trust_remote_code=True)\n                    model = LlavaLlamaForCausalLM_gd.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)\n                else:\n                    tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)\n                    cfg_pretrained = AutoConfig.from_pretrained(model_path)\n                    model = LlavaLlamaForCausalLM_gd.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)\n\n                mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')\n                mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}\n                model.load_state_dict(mm_projector_weights, strict=False)\n            else:\n                if 'mpt' in model_name.lower():\n                    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)\n                    model = LlavaLlamaForCausalLM_gd.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)\n                else:\n                    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)\n                    model = LlavaLlamaForCausalLM_gd.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)\n        else:\n            # Load language model\n            if model_base is not None:\n                # PEFT model\n                from peft import PeftModel\n                tokenizer = AutoTokenizer.from_pretrained(model_base, 
use_fast=False)\n                model = AutoModelForCausalLM.from_pretrained(model_base, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map=\"auto\")\n                print(f\"Loading LoRA weights from {model_path}\")\n                model = PeftModel.from_pretrained(model, model_path)\n                print(f\"Merging weights\")\n                model = model.merge_and_unload()\n                print('Convert to FP16...')\n                model.to(torch.float16)\n            else:\n                use_fast = False\n                if 'mpt' in model_name.lower():\n                    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)\n                    model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs)\n                else:\n                    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)\n                    model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)\n\n        image_processor = None\n\n        if 'llava' in model_name.lower():\n            mm_use_im_start_end = getattr(model.config, \"mm_use_im_start_end\", False)\n            mm_use_im_patch_token = getattr(model.config, \"mm_use_im_patch_token\", True)\n            if mm_use_im_patch_token:\n                tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)\n            if mm_use_im_start_end:\n                tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)\n            model.resize_token_embeddings(len(tokenizer))\n\n            vision_tower = model.get_vision_tower()\n            if not vision_tower.is_loaded:\n                vision_tower.load_model()\n            vision_tower.to(device='cuda', dtype=torch.float16)\n            image_processor = vision_tower.image_processor\n\n        if hasattr(model.config, \"max_sequence_length\"):\n            context_len = 
model.config.max_sequence_length\n        else:\n            context_len = 2048\n\n        return tokenizer, model, image_processor, context_len\n    def construct_vision_model(self, path_vision_model_cfg):\n        from detectron2.config import LazyConfig\n        from llava.model.openseed import build_model\n        from llava.model.openseed.BaseModel import BaseModel\n\n        def get_config_from_name(cfg, dataset_name=\"flickr\"):\n            # adjust config according to dataset, flickr by default\n            if 'sam' in dataset_name:\n                cfg.update(cfg['SAM'])\n                return cfg\n            elif 'flickr' in dataset_name:\n                cfg.update(cfg['flickr'])\n                return cfg\n            elif 'coco_instruct_train' in dataset_name:\n                cfg.update(cfg['coco_instruct'])\n                return cfg\n            elif 'lisa' in dataset_name:\n                cfg.update(cfg['LISA_REF'])\n                return cfg\n            elif 'llava' in dataset_name:\n                cfg.update(cfg['llava'])\n                return cfg\n            elif 'vg' in dataset_name:\n                cfg.update(cfg['vg'])\n                return cfg\n            elif 'part' in dataset_name and 'pascal_part' not in dataset_name and 'partimagenet' not in dataset_name:\n                cfg.update(cfg['part'])\n                return cfg\n            elif 'pascal' in dataset_name or 'paco' in dataset_name or 'partimagenet' in dataset_name :\n                cfg.update(cfg['PSACAL_PART'])\n                return cfg\n            elif 'coco' in dataset_name and 'refonly' in dataset_name:\n                # if 'COCO' in cfg.keys():\n                cfg.update(cfg['COCO_REF'])\n                return cfg\n            elif 'refcoco' in dataset_name or \"flickr_val\" in dataset_name:\n                cfg.update(cfg['REF'])\n                return cfg\n            elif 'coco' in dataset_name:\n                if 'COCO' in cfg.keys():\n         
           cfg.update(cfg['COCO'])\n                return cfg\n            elif \"mapillary\" in dataset_name:\n                if 'MAPILLARY' in cfg.keys():\n                    cfg.update(cfg['MAPILLARY'])\n                return cfg\n            elif 'ade' in dataset_name:\n                if 'ADE20K' in cfg.keys():\n                    cfg.update(cfg['ADE20K'])\n                return cfg\n            elif 'imagenet' in dataset_name:\n                if 'IMAGENET' in cfg.keys():\n                    cfg.update(cfg['IMAGENET'])\n                return cfg\n            elif 'vlp' in dataset_name:\n                cfg.update(cfg['VLP'])\n                return cfg\n            elif 'sun' in dataset_name:\n                cfg.update(cfg['SUN'])\n                return cfg\n            elif 'object365' in dataset_name:\n                cfg.update(cfg['OBJECT365'])\n                return cfg\n            elif 'scan' in dataset_name:\n                cfg.update(cfg['SCAN'])\n                return cfg\n            elif 'cityscape' in dataset_name:\n                cfg.update(cfg['CITY'])\n                return cfg\n            elif 'bdd' in dataset_name:\n                cfg.update(cfg['BDD'])\n                return cfg\n            else:\n                assert False, \"dataset not support.\"\n        self.cfg_vision_model = LazyConfig.load(path_vision_model_cfg)\n        vision_model = BaseModel(self.cfg_vision_model, build_model(self.cfg_vision_model))\n        vision_model.eval()\n        self.model.seg_model = vision_model\n        self.model.seg_model.model = self.model.seg_model.model.to(self.model.device)\n        # print(\"Configuring for Dataset Mapper ...\")\n        self.cfg_vision_model = get_config_from_name(self.cfg_vision_model)\n    \n    def load_parameters(self, path_model):\n        print(\"Loading Whole Model ...\")\n        loaded_dict = dict()\n        for model_file in os.listdir(path_model):\n            if model_file.endswith('.bin') and 
model_file.startswith('pytorch_model'):\n                loaded_dict.update(torch.load(os.path.join(path_model, model_file), map_location='cpu'))\n        self.model.load_state_dict(loaded_dict, strict=True)\n    @torch.inference_mode()\n    def evaluate_sample(self, input_data, get_box=True, get_mask=False):\n        text, boxes, masks = self.model.forward_eval(input_data)\n        returns = [text,]\n        if get_box:\n            returns.append(boxes)\n        if get_mask:\n            returns.append(masks)\n        \n        return returns\n\nclass Evaluator_MM_Inter(Evaluator_MM):\n    def __init__(self, model_path, path_vision_model_cfg=None, path_inter_model_cfg=None):\n        self.path_inter_model_cfg = path_inter_model_cfg\n        super().__init__(model_path, path_vision_model_cfg)\n    def construct_model(self, model_path, model_base=None, model_name=None, load_8bit=False, load_4bit=False, device_map=\"auto\"):\n        import os\n        import shutil\n        from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig\n        from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN\n        kwargs = {\"device_map\": device_map}\n\n        if load_8bit:\n            kwargs['load_in_8bit'] = True\n        elif load_4bit:\n            kwargs['load_in_4bit'] = True\n            kwargs['quantization_config'] = BitsAndBytesConfig(\n                load_in_4bit=True,\n                bnb_4bit_compute_dtype=torch.float16,\n                bnb_4bit_use_double_quant=True,\n                bnb_4bit_quant_type='nf4'\n            )\n        else:\n            kwargs['torch_dtype'] = torch.float16\n\n        if 'llava' in model_name.lower():\n            # Load LLaVA model\n            if 'lora' in model_name.lower() and model_base is not None:\n                lora_cfg_pretrained = AutoConfig.from_pretrained(model_path)\n                tokenizer = 
AutoTokenizer.from_pretrained(model_base, use_fast=False)\n                print('Loading LLaVA from base model...')\n                model = LlavaLlamaForCausalLM_joint_2st_it_only_ref_instr.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)\n                token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features\n                if model.lm_head.weight.shape[0] != token_num:\n                    model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))\n                    model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))\n\n                print('Loading additional LLaVA weights...')\n                if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):\n                    non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')\n                else:\n                    # this is probably from HF Hub\n                    from huggingface_hub import hf_hub_download\n                    def load_from_hf(repo_id, filename, subfolder=None):\n                        cache_file = hf_hub_download(\n                            repo_id=repo_id,\n                            filename=filename,\n                            subfolder=subfolder)\n                        return torch.load(cache_file, map_location='cpu')\n                    non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin')\n                non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()}\n                if any(k.startswith('model.model.') for k in non_lora_trainables):\n                    non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}\n                model.load_state_dict(non_lora_trainables, 
strict=False)\n\n                from peft import PeftModel\n                print('Loading LoRA weights...')\n                model = PeftModel.from_pretrained(model, model_path)\n                print('Merging LoRA weights...')\n                model = model.merge_and_unload()\n                print('Model is loaded...')\n            elif model_base is not None:\n                # this may be mm projector only\n                print('Loading LLaVA from base model...')\n                if 'mpt' in model_name.lower():\n                    if not os.path.isfile(os.path.join(model_path, 'configuration_mpt.py')):\n                        shutil.copyfile(os.path.join(model_base, 'configuration_mpt.py'), os.path.join(model_path, 'configuration_mpt.py'))\n                    tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)\n                    cfg_pretrained = AutoConfig.from_pretrained(model_path, trust_remote_code=True)\n                    model = LlavaLlamaForCausalLM_joint_2st_it_only_ref_instr.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)\n                else:\n                    tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)\n                    cfg_pretrained = AutoConfig.from_pretrained(model_path)\n                    model = LlavaLlamaForCausalLM_joint_2st_it_only_ref_instr.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)\n\n                mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')\n                mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}\n                model.load_state_dict(mm_projector_weights, strict=False)\n            else:\n                if 'mpt' in model_name.lower():\n                    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)\n                    model = 
LlavaLlamaForCausalLM_joint_2st_it_only_ref_instr.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)\n                else:\n                    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)\n                    model = LlavaLlamaForCausalLM_joint_2st_it_only_ref_instr.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)\n        else:\n            # Load language model\n            if model_base is not None:\n                # PEFT model\n                from peft import PeftModel\n                tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)\n                model = AutoModelForCausalLM.from_pretrained(model_base, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map=\"auto\")\n                print(f\"Loading LoRA weights from {model_path}\")\n                model = PeftModel.from_pretrained(model, model_path)\n                print(f\"Merging weights\")\n                model = model.merge_and_unload()\n                print('Convert to FP16...')\n                model.to(torch.float16)\n            else:\n                use_fast = False\n                if 'mpt' in model_name.lower():\n                    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)\n                    model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs)\n                else:\n                    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)\n                    model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)\n\n        image_processor = None\n\n        if 'llava' in model_name.lower():\n            mm_use_im_start_end = getattr(model.config, \"mm_use_im_start_end\", False)\n            mm_use_im_patch_token = getattr(model.config, \"mm_use_im_patch_token\", True)\n            if mm_use_im_patch_token:\n                
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)\n            if mm_use_im_start_end:\n                tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)\n            model.resize_token_embeddings(len(tokenizer))\n\n            vision_tower = model.get_vision_tower()\n            if not vision_tower.is_loaded:\n                vision_tower.load_model()\n            vision_tower.to(device='cuda', dtype=torch.float16)\n            image_processor = vision_tower.image_processor\n\n        if hasattr(model.config, \"max_sequence_length\"):\n            context_len = model.config.max_sequence_length\n        else:\n            context_len = 2048\n\n        return tokenizer, model, image_processor, context_len\n    \n    def construct_vision_model(self, path_vision_model_cfg):\n        from detectron2.config import LazyConfig\n        from llava.model.openseed import build_model\n        from llava.model.openseed.BaseModel import BaseModel\n\n        def get_config_from_name(cfg, dataset_name=\"flickr\"):\n            # adjust config according to dataset, flickr by default\n            if 'sam' in dataset_name:\n                cfg.update(cfg['SAM'])\n                return cfg\n            elif 'flickr' in dataset_name:\n                cfg.update(cfg['flickr'])\n                return cfg\n            elif 'coco_instruct_train' in dataset_name:\n                cfg.update(cfg['coco_instruct'])\n                return cfg\n            elif 'lisa' in dataset_name:\n                cfg.update(cfg['LISA_REF'])\n                return cfg\n            elif 'llava' in dataset_name:\n                cfg.update(cfg['llava'])\n                return cfg\n            elif 'vg' in dataset_name:\n                cfg.update(cfg['vg'])\n                return cfg\n            elif 'part' in dataset_name and 'pascal_part' not in dataset_name and 'partimagenet' not in dataset_name:\n                cfg.update(cfg['part'])\n  
              return cfg\n            elif 'pascal' in dataset_name or 'paco' in dataset_name or 'partimagenet' in dataset_name :\n                cfg.update(cfg['PSACAL_PART'])\n                return cfg\n            elif 'coco' in dataset_name and 'refonly' in dataset_name:\n                # if 'COCO' in cfg.keys():\n                cfg.update(cfg['COCO_REF'])\n                return cfg\n            elif 'refcoco' in dataset_name or \"flickr_val\" in dataset_name:\n                cfg.update(cfg['REF'])\n                return cfg\n            elif 'coco' in dataset_name:\n                if 'COCO' in cfg.keys():\n                    cfg.update(cfg['COCO'])\n                return cfg\n            elif \"mapillary\" in dataset_name:\n                if 'MAPILLARY' in cfg.keys():\n                    cfg.update(cfg['MAPILLARY'])\n                return cfg\n            elif 'ade' in dataset_name:\n                if 'ADE20K' in cfg.keys():\n                    cfg.update(cfg['ADE20K'])\n                return cfg\n            elif 'imagenet' in dataset_name:\n                if 'IMAGENET' in cfg.keys():\n                    cfg.update(cfg['IMAGENET'])\n                return cfg\n            elif 'vlp' in dataset_name:\n                cfg.update(cfg['VLP'])\n                return cfg\n            elif 'sun' in dataset_name:\n                cfg.update(cfg['SUN'])\n                return cfg\n            elif 'object365' in dataset_name:\n                cfg.update(cfg['OBJECT365'])\n                return cfg\n            elif 'scan' in dataset_name:\n                cfg.update(cfg['SCAN'])\n                return cfg\n            elif 'cityscape' in dataset_name:\n                cfg.update(cfg['CITY'])\n                return cfg\n            elif 'bdd' in dataset_name:\n                cfg.update(cfg['BDD'])\n                return cfg\n            else:\n                assert False, \"dataset not support.\"\n        self.cfg_vision_model = 
LazyConfig.load(path_vision_model_cfg)\n        vision_model = BaseModel(self.cfg_vision_model, build_model(self.cfg_vision_model))\n        vision_model.eval()\n        self.model.seg_model = vision_model\n        self.model.seg_model.model = self.model.seg_model.model.to(self.model.device)\n\n        self.cfg_inter_model = LazyConfig.load(self.path_inter_model_cfg)\n        self.model.initialize_interactive_modules(self.cfg_inter_model)\n        self.model.interactive_model.model = self.model.interactive_model.model.to(self.model.device)\n        # print(\"Configuring for Dataset Mapper ...\")\n        self.cfg_vision_model = get_config_from_name(self.cfg_vision_model)\n    @torch.inference_mode()\n    def evaluate_sample(self, input_data):\n        text, boxes, masks, mask_inter = self.model.forward_eval(input_data)\n\n        return text, boxes, masks, mask_inter\n    \ndef formatting(text, boxes, question_id):\n    def find_start_idxes(sentence, word):\n        window_size = len(word)\n        start_indexes = []\n        assert len(sentence) > window_size\n        if sentence == window_size:\n            return [0]\n        for start_index in range(len(sentence) - window_size+1):\n            if sentence[start_index: start_index + window_size] == word:\n                start_indexes.append(start_index)\n        return start_indexes\n    def extract_text(sentence):\n        # Use regular expression to find and extract the text and number\n        import re\n        pattern = r\"<g_s>|<g_e>\"\n        cleaned_text = re.sub(pattern, '', sentence)\n        return cleaned_text\n    def multiboxes_to_str(boxes):\n        boxes_text = []\n        for box in boxes:\n            boxes_text.append(list_to_str(box))\n        output_string = \";\".join(boxes_text)\n        return output_string.replace(\"];[\", \";\")\n    def list_to_str(list_):\n        list_str = [str(round(aa, 3)) for aa in list_]\n        return \"[\" + \",\".join(list_str) + \"]\"\n    def 
format_sentence(splitted_sentence):\n        joint_sentence = \" \".join(splitted_sentence)\n        return joint_sentence\n    \n    text_pure = \"\"\n    text_boxes = \"\"\n    boxes_pure = []\n    \n    number = 0\n    seg_start_index = find_start_idxes(text, \"<seg>\")\n    if len(seg_start_index) > 0:\n        # text = text[:tail_start_index[0]]\n        subtexts = text.split(\" <seg>\")\n        for subtext in subtexts:\n            if \"<g_s>\" in subtext:\n                # subtext += \"<g_e>\"\n                start_idx = find_start_idxes(subtext, \"<g_s>\")[0]\n                text_pure = format_sentence([text_pure, format_sentence(subtext[:start_idx].split())])\n                text_boxes = format_sentence([text_boxes, format_sentence(subtext[:start_idx].split())])\n                text_ = extract_text(subtext[start_idx:])\n                text_pure = format_sentence([text_pure, format_sentence(text_.split())])\n                if number >= len(boxes):\n                    print(\"Error, There should be a wrong prediction.\")\n                    text_boxes = format_sentence([text_boxes, format_sentence(text_.split())])\n                    number += 1\n                    continue\n                text_boxes = format_sentence([text_boxes, format_sentence(text_.split()) + multiboxes_to_str(boxes[number].cpu().tolist())])\n                boxes_pure.append(multiboxes_to_str(boxes[number].cpu().tolist()))\n                number += 1\n            else:\n                text_pure = format_sentence([text_pure, format_sentence(subtext.split())])\n                text_boxes = format_sentence([text_boxes, format_sentence(subtext.split())])\n        return {\n            \"question_id\": question_id,\n            \"text\": text_pure, \n            \"text_boxes\": text_boxes, \n            \"boxes\": boxes_pure,\n        }\n    else:\n        return {\n            \"question_id\": question_id,\n            \"text\": text, \n            \"text_boxes\": text, \n    
        \"boxes\": []\n        }\n\ndef evaluate_(path_benchmarks, dir_image, evaluator, matching_threshold):\n    def unresize_box(box, width, height, size):\n        # ori_size = max(width, height)\n        # ratio = ori_size / size\n        ratio = min(width, height) / max(width, height)\n        if width > height:  # then the height dimension is padded, the y coordinates should be divided by ratio\n            box[:, 1] = box[:, 1] / ratio\n            box[:, 3] = box[:, 3] / ratio\n        elif width < height:  # then the height dimension is padded, the y coordinates should be divided by ratio\n            box[:, 0] = box[:, 0] / ratio\n            box[:, 2] = box[:, 2] / ratio\n        return box\n    def filter_empty_box(text, boxes_image):\n        def extract_text(sentence):\n            # Use regular expression to find and extract the text and number\n            import re\n            if \" <seg>\" in sentence:\n                pattern = r\"<g_s>|<g_e> <seg>\"\n                cleaned_text = re.sub(pattern, '', sentence)\n                return cleaned_text\n            else:\n                cleaned_text = re.sub(r'<g_s> \\d+', '', sentence)\n                cleaned_text = re.sub(r' <g_e>', '', cleaned_text)\n                return cleaned_text\n        \n        has_gd = True if \"<seg>\" in text else False\n        if len(boxes_image) == 0:\n            return text, boxes_image\n        else:\n            if has_gd:\n                sub_texts = text.split(\" <seg>\")\n                sub_texts_filtered = []\n                boxes_image_filtered = []\n                for box_per_gd, text_per_gd in zip(boxes_image, sub_texts):\n                    text_per_gd += \" <seg>\"\n                    ind_nonempty_box = torch.where(box_per_gd.abs().sum(dim=1)>0)\n                    if len(ind_nonempty_box[0]) < box_per_gd.shape[0]:  # empty box encountered\n                        if len(ind_nonempty_box[0]) == 0:\n                            text_per_gd = \" 
\" + \" \".join(extract_text(text_per_gd).split())\n                            sub_texts_filtered.append(text_per_gd)  # box is desperated\n                            continue\n                        else:\n                            box_per_gd = box_per_gd[ind_nonempty_box]\n                            boxes_image_filtered.append(box_per_gd)\n                            sub_texts_filtered.append(text_per_gd)\n                    else:\n                        boxes_image_filtered.append(box_per_gd)\n                        sub_texts_filtered.append(text_per_gd)\n                sub_texts_filtered.append(sub_texts[-1])\n                text_filtered = \"\".join(sub_texts_filtered)\n                return text_filtered, boxes_image_filtered\n            else:\n                text_filtered = \" \".join(extract_text(text).split())\n                boxes_image_filtered = []\n                \n                return text_filtered, boxes_image_filtered\n    def debug(image, boxes, prefix):\n        import cv2\n        # image = cv2.imread(path_image)\n        def transform_str2numpy(boxes_str):\n            boxes_str = boxes_str.replace(\";\", \"];[\")\n            boxes_list = []\n            for box_str in boxes_str.split(\";\"):\n                box_list = [float(aa) for aa in box_str[1:-1].split(\",\")]\n                boxes_list.append(box_list)\n            return boxes_list\n        boxes_ = []\n        for box in boxes: boxes_.extend(transform_str2numpy(box))\n        boxes = torch.tensor(boxes_)\n        image = image[..., ::-1]\n        image = np.ascontiguousarray(image, dtype=np.uint8)\n        height,width,_ = image.shape\n        for box in boxes:\n            box = (box.cpu() * torch.tensor([width, height, width, height])).int().squeeze()\n            box = box.tolist()\n            image = cv2.rectangle(image, (box[0], box[1]), (box[2], box[3]), (0, 255, 0))\n        cv2.imwrite(f\"{prefix}_debug.jpg\", image)\n    \n    datas = 
load_benchmark(dir_image, path_benchmarks)[:20] #! use first 20 samples for debug.\n    data_mapper = evaluator.data_mapper\n    device = evaluator.model.device\n    outputs = []\n    for data in tqdm(datas):\n        input_data = data_mapper(data)[0]\n        for key, value in input_data.items():\n            if isinstance(value, torch.Tensor):\n                input_data[key] = value.to(device)\n        input_data[\"matching_threshold\"] = matching_threshold\n        text, boxes = evaluator.evaluate_sample([input_data])\n        text, boxes = filter_empty_box(text, boxes)\n        boxes = [unresize_box(bb.detach().cpu(), input_data[\"width\"], input_data[\"height\"], 1024) for bb in boxes]\n        output = formatting(text, boxes, input_data[\"question_id\"])\n        # from ipdb import set_trace; set_trace()\n        # debug(cv2.imread(input_data[\"file_name\"]), output[\"boxes\"], prefix=str(input_data[\"question_id\"]))\n        outputs.append(output)\n    return outputs\n\ndef evaluate(args=None):\n    evaluator = Evaluator_MM(\n        model_path=args.model_path,\n        path_vision_model_cfg=args.vision_model_cfg,\n    )\n    results = evaluate_(args.path_benchmark, dir_image=args.image_root, evaluator=evaluator, matching_threshold=args.matching_threshold)\n    return results\n    \nif __name__ == \"__main__\":\n    import argparse\n    args = argparse.ArgumentParser()\n    args.add_argument(\"--model_path\", type=str, default=\"xx\")\n    args.add_argument(\"--vision_model_cfg\", type=str, default=\"xx\")\n    args.add_argument(\"--matching_threshold\", type=float, default=0.2)\n    args.add_argument(\"--path_benchmark\", default=\"./dataset/qa1000_questions.jsonl\")\n    args.add_argument(\"--image_root\", default=\"./dataset/coco/val2014\")\n    args = args.parse_args()\n    results = evaluate(args)\n    path_save = f\"./LLaVA_G_{args.path_benchmark.split('/')[-1].split('.')[0]}_t{args.matching_threshold}.jsonl\"  #! 
sync\n    print(\"Writing at: \", path_save)\n    save_jsonl_file(results, path_save)\n    # CUDA_VISIBLE_DEVICES=0 python llava/eval/LLaVA_G_Eval.py --model_path ./checkpoints/llava_stage2_new_joint_seg0.1_data_v3/checkpoint-8000 --vision_model_cfg configs/openseed/openseed_swint_lang_coco_instruct_coco_end_llava_bench.yaml --path_benchmark ./dataset/qa90_questions.jsonl --image_root ./dataset/coco/val2014 --matching_threshold 0.2"
  },
  {
    "path": "llava/eval/eval_gpt_review.py",
    "content": "import argparse\nimport json\nimport os\n\nimport openai\nimport tqdm\nimport ray\nimport time\n\nNUM_SECONDS_TO_SLEEP = 3\n\n@ray.remote(num_cpus=4)\ndef get_eval(content: str, max_tokens: int):\n    while True:\n        try:\n            response = openai.ChatCompletion.create(\n                model='gpt-4',\n                messages=[{\n                    'role': 'system',\n                    'content': 'You are a helpful and precise assistant for checking the quality of the answer.'\n                }, {\n                    'role': 'user',\n                    'content': content,\n                }],\n                temperature=0.2,  # TODO: figure out which temperature is best for evaluation\n                max_tokens=max_tokens,\n            )\n            break\n        except openai.error.RateLimitError:\n            pass\n        except Exception as e:\n            print(e)\n        time.sleep(NUM_SECONDS_TO_SLEEP)\n\n    print('success!')\n    return response['choices'][0]['message']['content']\n\n\ndef parse_score(review):\n    try:\n        score_pair = review.split('\\n')[0]\n        score_pair = score_pair.replace(',', ' ')\n        # split() with no argument collapses repeated whitespace, so '8, 9' -> ['8', '9'].\n        sp = score_pair.split()\n        if len(sp) == 2:\n            return [float(sp[0]), float(sp[1])]\n        else:\n            print('error', review)\n            return [-1, -1]\n    except Exception as e:\n        print(e)\n        print('error', review)\n        return [-1, -1]\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')\n    parser.add_argument('-q', '--question')\n    # parser.add_argument('-a', '--answer')\n    parser.add_argument('-a', '--answer-list', nargs='+', default=[])\n    parser.add_argument('-r', '--rule')\n    parser.add_argument('-o', '--output')\n    parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')\n    args = parser.parse_args()\n\n    ray.init()\n\n    f_q = open(os.path.expanduser(args.question))\n    f_ans1 = open(os.path.expanduser(args.answer_list[0]))\n    f_ans2 = open(os.path.expanduser(args.answer_list[1]))\n    rule_dict = json.load(open(os.path.expanduser(args.rule), 'r'))\n\n    review_file = open(f'{args.output}', 'w')\n\n    js_list = []\n    handles = []\n    idx = 0\n    for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2):\n        # if idx == 1:\n        #     break\n\n        ques = json.loads(ques_js)\n        ans1 = json.loads(ans1_js)\n        ans2 = json.loads(ans2_js)\n\n        category = json.loads(ques_js)['category']\n        if category in rule_dict:\n            rule = rule_dict[category]\n        else:\n            rule = rule_dict['default']\n        prompt = rule['prompt']\n        role = rule['role']\n        content = (f'[Question]\\n{ques[\"text\"]}\\n\\n'\n                   f'[{role} 1]\\n{ans1[\"text\"]}\\n\\n[End of {role} 1]\\n\\n'\n                   f'[{role} 2]\\n{ans2[\"text\"]}\\n\\n[End of {role} 2]\\n\\n'\n                   f'[System]\\n{prompt}\\n\\n')\n        js_list.append({\n            'id': idx+1,\n            'question_id': ques['question_id'],\n            'answer1_id': ans1['answer_id'],\n            'answer2_id': ans2['answer_id'],\n            'category': category})\n        idx += 1\n        handles.append(get_eval.remote(content, args.max_tokens))\n        # To avoid the rate limit set by OpenAI\n        time.sleep(NUM_SECONDS_TO_SLEEP)\n\n    reviews = ray.get(handles)\n    for idx, review in enumerate(reviews):\n        scores = parse_score(review)\n        js_list[idx]['content'] = review\n        js_list[idx]['tuple'] = scores\n        review_file.write(json.dumps(js_list[idx]) + '\\n')\n    review_file.close()\n"
  },
  {
    "path": "llava/eval/eval_gpt_review_bench.py",
    "content": "import argparse\nimport json\nimport os\n\nimport openai\nimport time\n\nNUM_SECONDS_TO_SLEEP = 0.5\n# Azure OpenAI configuration. Credentials are read from the environment only;\n# never commit API keys to the repository.\nopenai.api_type = \"azure\"\nopenai.api_base = os.getenv(\"OPENAI_API_BASE\", \"https://xdecoder.openai.azure.com/\")\nopenai.api_version = \"2023-03-15-preview\"\nopenai.api_key = os.getenv(\"OPENAI_API_KEY\")\n\n\ndef get_eval(content: str, max_tokens: int):\n    while True:\n        try:\n            response = openai.ChatCompletion.create(\n                engine='gpt4a',\n                messages=[{\n                    'role': 'system',\n                    'content': 'You are a helpful and precise assistant for checking the quality of the answer.'\n                }, {\n                    'role': 'user',\n                    'content': content,\n                }],\n                temperature=0.2,  # TODO: figure out which temperature is best for evaluation\n                max_tokens=max_tokens,\n            )\n            break\n        except openai.error.RateLimitError:\n            pass\n        except Exception as e:\n            print(e)\n        time.sleep(NUM_SECONDS_TO_SLEEP)\n\n    return response['choices'][0]['message']['content']\n\n\ndef parse_score(review):\n    try:\n        score_pair = review.split('\\n')[0]\n        score_pair = score_pair.replace(',', ' ')\n        # split() with no argument collapses repeated whitespace, so '8, 9' -> ['8', '9'].\n        sp = score_pair.split()\n        if len(sp) == 2:\n            return [float(sp[0]), float(sp[1])]\n        else:\n            print('error', review)\n            return [-1, -1]\n    except Exception as e:\n        print(e)\n        print('error', review)\n        return [-1, -1]\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')\n    parser.add_argument('-q', '--question')\n    parser.add_argument('-c', '--context')\n    parser.add_argument('-a', '--answer-list', nargs='+', default=[])\n    parser.add_argument('-r', '--rule')\n    parser.add_argument('-o', '--output')\n    parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')\n    args = parser.parse_args()\n    if not openai.api_key:\n        parser.error('the OPENAI_API_KEY environment variable must be set')\n\n    f_q = open(os.path.expanduser(args.question))\n    f_ans1 = open(os.path.expanduser(args.answer_list[0]))\n    f_ans2 = open(os.path.expanduser(args.answer_list[1]))\n    rule_dict = json.load(open(os.path.expanduser(args.rule), 'r'))\n\n    if os.path.isfile(os.path.expanduser(args.output)):\n        cur_reviews = [json.loads(line) for line in open(os.path.expanduser(args.output))]\n    else:\n        cur_reviews = []\n\n    review_file = open(f'{args.output}', 'a')\n\n    context_list = [json.loads(line) for line in open(os.path.expanduser(args.context))]\n    image_to_context = {context['image']: context for context in context_list}\n\n    handles = []\n    idx = 0\n    for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2):\n        ques = json.loads(ques_js)\n        ans1 = json.loads(ans1_js)\n        ans2 = json.loads(ans2_js)\n\n        inst = image_to_context[ques['image']]\n        cap_str = '\\n'.join(inst['caption'])\n\n        category = 'llava_bench_' + json.loads(ques_js)['category']\n        if category in rule_dict:\n            rule = rule_dict[category]\n        else:\n            assert False, f\"Visual QA category not found in rule file: {category}.\"\n        prompt = rule['prompt']\n        role = rule['role']\n        content = (f'[Context]\\n{cap_str}\\n\\n'\n                   f'[Question]\\n{ques[\"text\"]}\\n\\n'\n                   f'[{role} 1]\\n{ans1[\"text\"]}\\n\\n[End of {role} 1]\\n\\n'\n                   f'[{role} 2]\\n{ans2[\"text\"]}\\n\\n[End of {role} 2]\\n\\n'\n                   f'[System]\\n{prompt}\\n\\n')\n        cur_js = {\n            'id': idx+1,\n            'question_id': ques['question_id'],\n            'answer1_id': ans1.get('answer_id', ans1['question_id']),\n            # dict.get evaluates its default eagerly, so the fallback must be a key that always exists.\n            'answer2_id': ans2.get('answer_id', ans2['question_id']),\n            'category': category\n        }\n        if idx >= len(cur_reviews):\n            review = get_eval(content, args.max_tokens)\n            scores = parse_score(review)\n            cur_js['content'] = review\n            cur_js['tuple'] = scores\n            review_file.write(json.dumps(cur_js) + '\\n')\n            review_file.flush()\n        else:\n            print(f'Skipping {idx} as we already have it.')\n        idx += 1\n        print(idx)\n    review_file.close()\n"
  },
  {
    "path": "llava/eval/eval_gpt_review_visual.py",
    "content": "import argparse\nimport json\nimport os\n\nimport openai\nimport time\n\nNUM_SECONDS_TO_SLEEP = 0.5\n# Azure OpenAI configuration. Credentials are read from the environment only;\n# never commit API keys to the repository.\nopenai.api_type = \"azure\"\nopenai.api_base = os.getenv(\"OPENAI_API_BASE\", \"https://xdecoder.openai.azure.com/\")\nopenai.api_version = \"2023-03-15-preview\"\nopenai.api_key = os.getenv(\"OPENAI_API_KEY\")\n\ndef get_eval(content: str, max_tokens: int):\n    while True:\n        try:\n            response = openai.ChatCompletion.create(\n                engine='gpt4a',\n                messages=[{\n                    'role': 'system',\n                    'content': 'You are a helpful and precise assistant for checking the quality of the answer.'\n                }, {\n                    'role': 'user',\n                    'content': content,\n                }],\n                temperature=0.2,  # TODO: figure out which temperature is best for evaluation\n                max_tokens=max_tokens,\n            )\n            break\n        except openai.error.RateLimitError:\n            pass\n        except Exception as e:\n            print(e)\n        time.sleep(NUM_SECONDS_TO_SLEEP)\n\n    return response['choices'][0]['message']['content']\n\n\ndef parse_score(review):\n    try:\n        score_pair = review.split('\\n')[0]\n        score_pair = score_pair.replace(',', ' ')\n        # split() with no argument collapses repeated whitespace, so '8, 9' -> ['8', '9'].\n        sp = score_pair.split()\n        if len(sp) == 2:\n            return [float(sp[0]), float(sp[1])]\n        else:\n            print('error', review)\n            return [-1, -1]\n    except Exception as e:\n        print(e)\n        print('error', review)\n        return [-1, -1]\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')\n    parser.add_argument('-q', '--question')\n    parser.add_argument('-c', '--context')\n    parser.add_argument('-a', '--answer-list', nargs='+', default=[])\n    parser.add_argument('-r', '--rule')\n    parser.add_argument('-o', '--output')\n    parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')\n    args = parser.parse_args()\n    if not openai.api_key:\n        parser.error('the OPENAI_API_KEY environment variable must be set')\n\n    f_q = open(os.path.expanduser(args.question))\n    f_ans1 = open(os.path.expanduser(args.answer_list[0]))\n    f_ans2 = open(os.path.expanduser(args.answer_list[1]))\n    rule_dict = json.load(open(os.path.expanduser(args.rule), 'r'))\n\n    if os.path.isfile(os.path.expanduser(args.output)):\n        cur_reviews = [json.loads(line) for line in open(os.path.expanduser(args.output))]\n    else:\n        cur_reviews = []\n\n    review_file = open(f'{args.output}', 'a')\n\n    context_list = [json.loads(line) for line in open(os.path.expanduser(args.context))]\n    image_to_context = {context['image']: context for context in context_list}\n\n    handles = []\n    idx = 0\n    for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2):\n        ques = json.loads(ques_js)\n        ans1 = json.loads(ans1_js)\n        ans2 = json.loads(ans2_js)\n\n        inst = image_to_context[ques['image']]\n        cap_str = '\\n'.join(inst['captions'])\n        box_str = '\\n'.join([f'{instance[\"category\"]}: {instance[\"bbox\"]}' for instance in inst['instances']])\n\n        category = json.loads(ques_js)['category']\n        if category in rule_dict:\n            rule = rule_dict[category]\n        else:\n            assert False, f\"Visual QA category not found in rule file: {category}.\"\n        prompt = rule['prompt']\n        role = rule['role']\n        content = (f'[Context]\\n{cap_str}\\n\\n{box_str}\\n\\n'\n                   f'[Question]\\n{ques[\"text\"]}\\n\\n'\n                   f'[{role} 1]\\n{ans1[\"text\"]}\\n\\n[End of {role} 1]\\n\\n'\n                   f'[{role} 2]\\n{ans2[\"text\"]}\\n\\n[End of {role} 2]\\n\\n'\n                   f'[System]\\n{prompt}\\n\\n')\n        cur_js = {\n            'id': idx+1,\n            'question_id': ques['question_id'],\n            'answer1_id': ans1.get('answer_id', ans1['question_id']),\n            'answer2_id': ans2.get('answer_id', ans2['question_id']),\n            'category': category\n        }\n        if idx >= len(cur_reviews):\n            review = get_eval(content, args.max_tokens)\n            scores = parse_score(review)\n            cur_js['content'] = review\n            cur_js['tuple'] = scores\n            review_file.write(json.dumps(cur_js) + '\\n')\n            review_file.flush()\n        else:\n            print(f'Skipping {idx} as we already have it.')\n        idx += 1\n        print(idx)\n    review_file.close()\n"
  },
  {
    "path": "llava/eval/eval_gpt_review_visual2.py",
    "content": "import argparse\nimport json\nimport os\n\nimport openai\nimport time\n\nNUM_SECONDS_TO_SLEEP = 0.5\n\n# Azure OpenAI configuration. Credentials are read from the environment only;\n# never commit API keys to the repository.\nopenai.api_type = \"azure\"\nopenai.api_base = os.getenv('OPENAI_API_BASE', 'https://azureopenaifiahmedeastus.openai.azure.com/')\nopenai.api_version = '2023-03-15-preview'\nopenai.api_key = os.getenv('OPENAI_API_KEY')\n\n\ndef get_eval(content: str, max_tokens: int):\n    while True:\n        try:\n            response = openai.ChatCompletion.create(\n                engine='gpt-4-32k-0314',\n                messages=[{\n                    'role': 'system',\n                    'content': 'You are a helpful and precise assistant for checking the quality of the answer.'\n                }, {\n                    'role': 'user',\n                    'content': content,\n                }],\n                temperature=0.2,  # TODO: figure out which temperature is best for evaluation\n                max_tokens=max_tokens,\n            )\n            break\n        except openai.error.RateLimitError:\n            pass\n        except Exception as e:\n            print(e)\n        time.sleep(NUM_SECONDS_TO_SLEEP)\n\n    return response['choices'][0]['message']['content']\n\n\ndef parse_score(review):\n    try:\n        score_pair = review.split('\\n')[0]\n        score_pair = score_pair.replace(',', ' ')\n        # split() with no argument collapses repeated whitespace, so '8, 9' -> ['8', '9'].\n        sp = score_pair.split()\n        if len(sp) == 2:\n            return [float(sp[0]), float(sp[1])]\n        else:\n            print('error', review)\n            return [-1, -1]\n    except Exception as e:\n        print(e)\n        print('error', review)\n        return [-1, -1]\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')\n    parser.add_argument('-q', '--question')\n    parser.add_argument('-c', '--context')\n    parser.add_argument('-a', '--answer-list', nargs='+', default=[])\n    parser.add_argument('-r', '--rule')\n    parser.add_argument('-o', '--output')\n    parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')\n    args = parser.parse_args()\n    if not openai.api_key:\n        parser.error('the OPENAI_API_KEY environment variable must be set')\n\n    f_q = open(os.path.expanduser(args.question))\n    f_ans1 = open(os.path.expanduser(args.answer_list[0]))\n    f_ans2 = open(os.path.expanduser(args.answer_list[1]))\n    rule_dict = json.load(open(os.path.expanduser(args.rule), 'r'))\n\n    if os.path.isfile(os.path.expanduser(args.output)):\n        cur_reviews = [json.loads(line) for line in open(os.path.expanduser(args.output))]\n    else:\n        cur_reviews = []\n\n    review_file = open(f'{args.output}', 'a')\n\n    context_list = [json.loads(line) for line in open(os.path.expanduser(args.context))]\n    image_to_context = {context['image']: context for context in context_list}\n\n    handles = []\n    idx = 0\n    for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2):\n        ques = json.loads(ques_js)\n        ans1 = json.loads(ans1_js)\n        ans2 = json.loads(ans2_js)\n\n        inst = image_to_context[ques['image']]\n        cap_str = '\\n'.join(inst['captions'])\n        box_str = '\\n'.join([f'{instance[\"category\"]}: {instance[\"bbox\"]}' for instance in inst['instances']])\n\n        category = json.loads(ques_js)['category']\n        if category in rule_dict:\n            rule = rule_dict[category]\n        else:\n            assert False, f\"Visual QA category not found in rule file: {category}.\"\n        prompt = rule['prompt']\n        role = rule['role']\n        content = (f'[Context]\\n{cap_str}\\n\\n{box_str}\\n\\n'\n                   f'[Question]\\n{ques[\"text\"]}\\n\\n'\n                   f'[{role} 1]\\n{ans1[\"text\"]}\\n\\n[End of {role} 1]\\n\\n'\n                   f'[{role} 2]\\n{ans2[\"text\"]}\\n\\n[End of {role} 2]\\n\\n'\n                   f'[System]\\n{prompt}\\n\\n')\n        cur_js = {\n            'id': idx+1,\n            'question_id': ques['question_id'],\n            'answer1_id': ans1.get('answer_id', ans1['question_id']),\n            # dict.get evaluates its default eagerly, so the fallback must be a key that always exists.\n            'answer2_id': ans2.get('answer_id', ans2['question_id']),\n            'category': category\n        }\n        if idx >= len(cur_reviews):\n            review = get_eval(content, args.max_tokens)\n            scores = parse_score(review)\n            cur_js['content'] = review\n            cur_js['tuple'] = scores\n            review_file.write(json.dumps(cur_js) + '\\n')\n            review_file.flush()\n        else:\n            print(f'Skipping {idx} as we already have it.')\n        idx += 1\n        print(idx)\n    review_file.close()\n"
  },
  {
    "path": "llava/eval/eval_science_qa.py",
    "content": "import argparse\nimport json\nimport os\nimport re\nimport random\n\n\ndef get_args():\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--base-dir', type=str)\n    parser.add_argument('--result-file', type=str)\n    parser.add_argument('--output-file', type=str)\n    parser.add_argument('--output-result', type=str)\n    parser.add_argument('--split', type=str, default='test')\n    parser.add_argument('--options', type=list, default=[\"A\", \"B\", \"C\", \"D\", \"E\"])\n    return parser.parse_args()\n\n\ndef convert_caps(results):\n    fakecaps = []\n    for result in results:\n        image_id = result['question_id']\n        caption = result['text']\n        fakecaps.append({\"image_id\": int(image_id), \"caption\": caption})\n    return fakecaps\n\n\ndef get_pred_idx(prediction, choices, options):\n    \"\"\"\n    Get the index (e.g. 2) from the prediction (e.g. 'C')\n    \"\"\"\n    if prediction in options[:len(choices)]:\n        return options.index(prediction)\n    else:\n        return random.choice(range(len(choices)))\n\n\nif __name__ == \"__main__\":\n    args = get_args()\n\n    base_dir = args.base_dir\n    split_indices = json.load(open(os.path.join(base_dir, \"pid_splits.json\")))[args.split]\n    problems = json.load(open(os.path.join(base_dir, \"problems.json\")))\n    predictions = [json.loads(line) for line in open(args.result_file)]\n    predictions = {pred['question_id']: pred for pred in predictions}\n    split_problems = {idx: problems[idx] for idx in split_indices}\n\n    results = {'correct': [], 'incorrect': []}\n    sqa_results = {}\n    sqa_results['acc'] = None\n    sqa_results['correct'] = None\n    sqa_results['count'] = None\n    sqa_results['results'] = {}\n    sqa_results['outputs'] = {}\n\n    for prob_id, prob in split_problems.items():\n        if prob_id not in predictions:\n            continue\n        pred = predictions[prob_id]\n        pred_text = pred['text']\n\n        pattern = 
re.compile(r'The answer is ([A-Z]).')\n        res = pattern.findall(pred_text)\n        if len(res) == 1:\n            answer = res[0]  # 'A', 'B', ...\n        else:\n            answer = \"FAILED\"\n\n        pred_idx = get_pred_idx(answer, prob['choices'], args.options)\n\n        analysis = {\n            'question_id': prob_id,\n            'parsed_ans': answer,\n            'ground_truth': args.options[prob['answer']],\n            'question': pred['prompt'],\n            'pred': pred_text,\n            'is_multimodal': '<image>' in pred['prompt'],\n        }\n\n        sqa_results['results'][prob_id] = get_pred_idx(answer, prob['choices'], args.options)\n        sqa_results['outputs'][prob_id] = pred_text\n\n        if pred_idx == prob['answer']:\n            results['correct'].append(analysis)\n        else:\n            results['incorrect'].append(analysis)\n\n    correct = len(results['correct'])\n    total = len(results['correct']) + len(results['incorrect'])\n    print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%')\n\n    sqa_results['acc'] = correct / total * 100\n    sqa_results['correct'] = correct\n    sqa_results['count'] = total\n\n    with open(args.output_file, 'w') as f:\n        json.dump(results, f, indent=2)\n    with open(args.output_result, 'w') as f:\n        json.dump(sqa_results, f, indent=2)\n"
  },
  {
    "path": "llava/eval/eval_science_qa_gpt4.py",
    "content": "import argparse\nimport json\nimport os\nimport re\nimport random\nfrom collections import defaultdict\n\n\ndef get_args():\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--base-dir', type=str)\n    parser.add_argument('--gpt4-result', type=str)\n    parser.add_argument('--our-result', type=str)\n    parser.add_argument('--split', type=str, default='test')\n    parser.add_argument('--options', type=list, default=[\"A\", \"B\", \"C\", \"D\", \"E\"])\n    return parser.parse_args()\n\n\ndef convert_caps(results):\n    fakecaps = []\n    for result in results:\n        image_id = result['question_id']\n        caption = result['text']\n        fakecaps.append({\"image_id\": int(image_id), \"caption\": caption})\n    return fakecaps\n\n\ndef get_pred_idx(prediction, choices, options):\n    \"\"\"\n    Get the index (e.g. 2) from the prediction (e.g. 'C')\n    \"\"\"\n    if prediction in options[:len(choices)]:\n        return options.index(prediction)\n    else:\n        return random.choice(range(len(choices)))\n\n\nif __name__ == \"__main__\":\n    args = get_args()\n\n    base_dir = args.base_dir\n    split_indices = json.load(open(os.path.join(base_dir, \"pid_splits.json\")))[args.split]\n    problems = json.load(open(os.path.join(base_dir, \"problems.json\")))\n    our_predictions = [json.loads(line) for line in open(args.our_result)]\n    our_predictions = {pred['question_id']: pred for pred in our_predictions}\n    split_problems = {idx: problems[idx] for idx in split_indices}\n\n    gpt4_predictions = json.load(open(args.gpt4_result))['outputs']\n\n    results = defaultdict(lambda: 0)\n\n    for prob_id, prob in split_problems.items():\n        if prob_id not in our_predictions:\n            continue\n        if prob_id not in gpt4_predictions:\n            continue\n        our_pred = our_predictions[prob_id]['text']\n        gpt4_pred = gpt4_predictions[prob_id]\n\n        pattern = re.compile(r'The answer is ([A-Z]).')\n   
     our_res = pattern.findall(our_pred)\n        if len(our_res) == 1:\n            our_answer = our_res[0]  # 'A', 'B', ...\n        else:\n            our_answer = \"FAILED\"\n        gpt4_res = pattern.findall(gpt4_pred)\n        if len(gpt4_res) == 1:\n            gpt4_answer = gpt4_res[0]  # 'A', 'B', ...\n        else:\n            gpt4_answer = \"FAILED\"\n\n        our_pred_idx = get_pred_idx(our_answer, prob['choices'], args.options)\n        gpt4_pred_idx = get_pred_idx(gpt4_answer, prob['choices'], args.options)\n\n        if gpt4_answer == 'FAILED':\n            results['gpt4_failed'] += 1\n            # continue\n            gpt4_pred_idx = our_pred_idx\n            # if our_pred_idx != prob['answer']:\n            #     print(our_predictions[prob_id]['prompt'])\n            #     print('-----------------')\n            #     print(f'LECTURE: {prob[\"lecture\"]}')\n            #     print(f'SOLUTION: {prob[\"solution\"]}')\n            #     print('=====================')\n        else:\n            # continue\n            pass\n        # gpt4_pred_idx = our_pred_idx\n\n        if gpt4_pred_idx == prob['answer']:\n            results['correct'] += 1\n        else:\n            results['incorrect'] += 1\n\n\n        if gpt4_pred_idx == prob['answer'] or our_pred_idx == prob['answer']:\n            results['correct_upperbound'] += 1\n\n    correct = results['correct']\n    total = results['correct'] + results['incorrect']\n    print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%')\n    print(f'Total: {total}, Correct (upper): {results[\"correct_upperbound\"]}, Accuracy: {results[\"correct_upperbound\"] / total * 100:.2f}%')\n    print(f'Total: {total}, GPT-4 NO-ANS (RANDOM): {results[\"gpt4_failed\"]}, Percentage: {results[\"gpt4_failed\"] / total * 100:.2f}%')\n\n"
  },
  {
    "path": "llava/eval/eval_science_qa_gpt4_requery.py",
    "content": "import argparse\nimport json\nimport os\nimport re\nimport random\nfrom collections import defaultdict\n\n\ndef get_args():\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--base-dir', type=str)\n    parser.add_argument('--gpt4-result', type=str)\n    parser.add_argument('--requery-result', type=str)\n    parser.add_argument('--our-result', type=str)\n    parser.add_argument('--output-result', type=str)\n    parser.add_argument('--split', type=str, default='test')\n    parser.add_argument('--options', type=list, default=[\"A\", \"B\", \"C\", \"D\", \"E\"])\n    return parser.parse_args()\n\n\ndef convert_caps(results):\n    fakecaps = []\n    for result in results:\n        image_id = result['question_id']\n        caption = result['text']\n        fakecaps.append({\"image_id\": int(image_id), \"caption\": caption})\n    return fakecaps\n\n\ndef get_pred_idx(prediction, choices, options):\n    \"\"\"\n    Get the index (e.g. 2) from the prediction (e.g. 
'C')\n    \"\"\"\n    if prediction in options[:len(choices)]:\n        return options.index(prediction)\n    else:\n        return random.choice(range(len(choices)))\n\n\nif __name__ == \"__main__\":\n    args = get_args()\n\n    base_dir = args.base_dir\n    split_indices = json.load(open(os.path.join(base_dir, \"pid_splits.json\")))[args.split]\n    problems = json.load(open(os.path.join(base_dir, \"problems.json\")))\n    our_predictions = [json.loads(line) for line in open(args.our_result)]\n    our_predictions = {pred['question_id']: pred for pred in our_predictions}\n    split_problems = {idx: problems[idx] for idx in split_indices}\n\n    requery_predictions = [json.loads(line) for line in open(args.requery_result)]\n    requery_predictions = {pred['question_id']: pred for pred in requery_predictions}\n\n    gpt4_predictions = json.load(open(args.gpt4_result))['outputs']\n\n    results = defaultdict(lambda: 0)\n\n    sqa_results = {}\n    sqa_results['acc'] = None\n    sqa_results['correct'] = None\n    sqa_results['count'] = None\n    sqa_results['results'] = {}\n    sqa_results['outputs'] = {}\n\n    for prob_id, prob in split_problems.items():\n        if prob_id not in our_predictions:\n            assert False\n        if prob_id not in gpt4_predictions:\n            assert False\n        our_pred = our_predictions[prob_id]['text']\n        gpt4_pred = gpt4_predictions[prob_id]\n        if prob_id not in requery_predictions:\n            results['missing_requery'] += 1\n            requery_pred = \"MISSING\"\n        else:\n            requery_pred = requery_predictions[prob_id]['text']\n\n        pattern = re.compile(r'The answer is ([A-Z]).')\n        our_res = pattern.findall(our_pred)\n        if len(our_res) == 1:\n            our_answer = our_res[0]  # 'A', 'B', ...\n        else:\n            our_answer = \"FAILED\"\n\n        requery_res = pattern.findall(requery_pred)\n        if len(requery_res) == 1:\n            requery_answer = 
requery_res[0]  # 'A', 'B', ...\n        else:\n            requery_answer = \"FAILED\"\n\n        gpt4_res = pattern.findall(gpt4_pred)\n        if len(gpt4_res) == 1:\n            gpt4_answer = gpt4_res[0]  # 'A', 'B', ...\n        else:\n            gpt4_answer = \"FAILED\"\n\n        our_pred_idx = get_pred_idx(our_answer, prob['choices'], args.options)\n        gpt4_pred_idx = get_pred_idx(gpt4_answer, prob['choices'], args.options)\n        requery_pred_idx = get_pred_idx(requery_answer, prob['choices'], args.options)\n\n        results['total'] += 1\n\n        if gpt4_answer == 'FAILED':\n            results['gpt4_failed'] += 1\n            if gpt4_pred_idx == prob['answer']:\n                results['gpt4_correct'] += 1\n            if our_pred_idx == prob['answer']:\n                results['gpt4_ourvisual_correct'] += 1\n        elif gpt4_pred_idx == prob['answer']:\n            results['gpt4_correct'] += 1\n            results['gpt4_ourvisual_correct'] += 1\n\n        if our_pred_idx == prob['answer']:\n            results['our_correct'] += 1\n\n        if requery_answer == 'FAILED':\n            sqa_results['results'][prob_id] = our_pred_idx\n            if our_pred_idx == prob['answer']:\n                results['requery_correct'] += 1\n        else:\n            sqa_results['results'][prob_id] = requery_pred_idx\n            if requery_pred_idx == prob['answer']:\n                results['requery_correct'] += 1\n            else:\n                print(f\"\"\"\nQuestion ({args.options[prob['answer']]}): {our_predictions[prob_id]['prompt']}\nOur ({our_answer}): {our_pred}\nGPT-4 ({gpt4_answer}): {gpt4_pred}\nRequery ({requery_answer}): {requery_pred}\nprint(\"=====================================\")\n\"\"\")\n\n        if gpt4_pred_idx == prob['answer'] or our_pred_idx == prob['answer']:\n            results['correct_upperbound'] += 1\n\n    total = results['total']\n    print(f'Total: {total}, Our-Correct: {results[\"our_correct\"]}, Accuracy: 
{results[\"our_correct\"] / total * 100:.2f}%')\n    print(f'Total: {total}, GPT-4-Correct: {results[\"gpt4_correct\"]}, Accuracy: {results[\"gpt4_correct\"] / total * 100:.2f}%')\n    print(f'Total: {total}, GPT-4 NO-ANS (RANDOM): {results[\"gpt4_failed\"]}, Percentage: {results[\"gpt4_failed\"] / total * 100:.2f}%')\n    print(f'Total: {total}, GPT-4-OursVisual-Correct: {results[\"gpt4_ourvisual_correct\"]}, Accuracy: {results[\"gpt4_ourvisual_correct\"] / total * 100:.2f}%')\n    print(f'Total: {total}, Requery-Correct: {results[\"requery_correct\"]}, Accuracy: {results[\"requery_correct\"] / total * 100:.2f}%')\n    print(f'Total: {total}, Correct upper: {results[\"correct_upperbound\"]}, Accuracy: {results[\"correct_upperbound\"] / total * 100:.2f}%')\n\n    sqa_results['acc'] = results[\"requery_correct\"] / total * 100\n    sqa_results['correct'] = results[\"requery_correct\"]\n    sqa_results['count'] = total\n\n    with open(args.output_result, 'w') as f:\n        json.dump(sqa_results, f, indent=2)\n\n"
  },
  {
    "path": "llava/eval/generate_webpage_data_from_table.py",
    "content": "\"\"\"Generate json file for webpage.\"\"\"\nimport json\nimport os\nimport re\n\n# models = ['llama', 'alpaca', 'gpt35', 'bard']\nmodels = ['vicuna']\n\n\ndef read_jsonl(path: str, key: str=None):\n    data = []\n    with open(os.path.expanduser(path)) as f:\n        for line in f:\n            if not line:\n                continue\n            data.append(json.loads(line))\n    if key is not None:\n        data.sort(key=lambda x: x[key])\n        data = {item[key]: item for item in data}\n    return data\n\n\ndef trim_hanging_lines(s: str, n: int) -> str:\n    s = s.strip()\n    for _ in range(n):\n        s = s.split('\\n', 1)[1].strip()\n    return s\n\n\nif __name__ == '__main__':\n    questions = read_jsonl('table/question.jsonl', key='question_id')\n\n    # alpaca_answers = read_jsonl('table/answer/answer_alpaca-13b.jsonl', key='question_id')\n    # bard_answers = read_jsonl('table/answer/answer_bard.jsonl', key='question_id')\n    # gpt35_answers = read_jsonl('table/answer/answer_gpt35.jsonl', key='question_id')\n    # llama_answers = read_jsonl('table/answer/answer_llama-13b.jsonl', key='question_id')\n    vicuna_answers = read_jsonl('table/answer/answer_vicuna-13b.jsonl', key='question_id')\n    ours_answers = read_jsonl('table/results/llama-13b-hf-alpaca.jsonl', key='question_id')\n\n    review_vicuna = read_jsonl('table/review/review_vicuna-13b_llama-13b-hf-alpaca.jsonl', key='question_id')\n    # review_alpaca = read_jsonl('table/review/review_alpaca-13b_vicuna-13b.jsonl', key='question_id')\n    # review_bard = read_jsonl('table/review/review_bard_vicuna-13b.jsonl', key='question_id')\n    # review_gpt35 = read_jsonl('table/review/review_gpt35_vicuna-13b.jsonl', key='question_id')\n    # review_llama = read_jsonl('table/review/review_llama-13b_vicuna-13b.jsonl', key='question_id')\n\n    records = []\n    for qid in questions.keys():\n        r = {\n            'id': qid,\n            'category': questions[qid]['category'],\n          
  'question': questions[qid]['text'],\n            'answers': {\n                # 'alpaca': alpaca_answers[qid]['text'],\n                # 'llama': llama_answers[qid]['text'],\n                # 'bard': bard_answers[qid]['text'],\n                # 'gpt35': gpt35_answers[qid]['text'],\n                'vicuna': vicuna_answers[qid]['text'],\n                'ours': ours_answers[qid]['text'],\n            },\n            'evaluations': {\n                # 'alpaca': review_alpaca[qid]['text'],\n                # 'llama': review_llama[qid]['text'],\n                # 'bard': review_bard[qid]['text'],\n                'vicuna': review_vicuna[qid]['content'],\n                # 'gpt35': review_gpt35[qid]['text'],\n            },\n            'scores': {\n                'vicuna': review_vicuna[qid]['tuple'],\n                # 'alpaca': review_alpaca[qid]['score'],\n                # 'llama': review_llama[qid]['score'],\n                # 'bard': review_bard[qid]['score'],\n                # 'gpt35': review_gpt35[qid]['score'],\n            },\n        }\n\n        # cleanup data\n        cleaned_evals = {}\n        for k, v in r['evaluations'].items():\n            v = v.strip()\n            lines = v.split('\\n')\n            # trim the first line if it's a pair of numbers\n            if re.match(r'\\d+[, ]+\\d+', lines[0]):\n                lines = lines[1:]\n            v = '\\n'.join(lines)\n            cleaned_evals[k] = v.replace('Assistant 1', \"**Assistant 1**\").replace('Assistant 2', '**Assistant 2**')\n\n        r['evaluations'] = cleaned_evals\n        records.append(r)\n\n    # Reorder the records, this is optional\n    for r in records:\n        if r['id'] <= 20:\n            r['id'] += 60\n        else:\n            r['id'] -= 20\n    for r in records:\n        if r['id'] <= 50:\n            r['id'] += 10\n        elif 50 < r['id'] <= 60:\n            r['id'] -= 50\n    for r in records:\n        if r['id'] == 7:\n            r['id'] = 1\n        elif 
r['id'] < 7:\n            r['id'] += 1 \n\n    records.sort(key=lambda x: x['id'])\n\n    # Write to file\n    with open('webpage/data.json', 'w') as f:\n        json.dump({'questions': records, 'models': models}, f, indent=2)\n"
  },
  {
    "path": "llava/eval/llava_mapper.py",
    "content": "# Copyright (c) Facebook, Inc. and its affiliates.\n# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/d2/detr/dataset_mapper.py\nimport copy\nimport logging\n\nimport numpy as np\nimport torch\nimport PIL.Image as Image\nfrom detectron2.data import detection_utils as utils\nfrom detectron2.data import transforms as T\nfrom detectron2.data.transforms import TransformGen\nfrom detectron2.structures import BitMasks, Instances\n\nfrom pycocotools import mask as coco_mask\n\nfrom llava.model.openseed.utils import configurable\nfrom detectron2.structures import Boxes, BoxMode, PolygonMasks, RotatedBoxes\nfrom llava import conversation as conversation_lib\nfrom llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN\n\n# from llava.train.train_hao_seg_flickr import ,preprocess\n__all__ = [\"COCOInstanceNewBaselineDatasetMapper\"]\n\n\ndef convert_coco_poly_to_mask(segmentations, height, width):\n    masks = []\n    for polygons in segmentations:\n        rles = coco_mask.frPyObjects(polygons, height, width)\n        mask = coco_mask.decode(rles)\n        if len(mask.shape) < 3:\n            mask = mask[..., None]\n        mask = torch.as_tensor(mask, dtype=torch.uint8)\n        mask = mask.any(dim=2)\n        masks.append(mask)\n    if masks:\n        masks = torch.stack(masks, dim=0)\n    else:\n        masks = torch.zeros((0, height, width), dtype=torch.uint8)\n    return masks\n\ndef preprocess_multimodal(\n    sources,\n    is_multimodal\n):\n    if not is_multimodal:\n        return sources\n\n    for source in sources:\n        for sentence in source:\n            if DEFAULT_IMAGE_TOKEN in sentence['value']:\n                sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip()\n                sentence['value'] = DEFAULT_IMAGE_TOKEN + '\\n' + sentence['value']\n                sentence['value'] = sentence['value'].strip()\n    
            if \"mmtag\" in conversation_lib.default_conversation.version:\n                    sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '<Image>' + DEFAULT_IMAGE_TOKEN + '</Image>')\n            replace_token = DEFAULT_IMAGE_TOKEN\n            if False:\n                replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN\n            sentence[\"value\"] = sentence[\"value\"].replace(DEFAULT_IMAGE_TOKEN, replace_token)\n\n    return sources\n\ndef build_transform_gen(cfg, is_train):\n    \"\"\"\n    Create a list of default :class:`Augmentation` from config.\n    Now it includes resizing and flipping.\n    Returns:\n        list[Augmentation]\n    \"\"\"\n    if is_train:\n        cfg_input = cfg['INPUT']\n        image_size = cfg_input['IMAGE_SIZE']\n        min_scale = cfg_input['MIN_SCALE']\n        max_scale = cfg_input['MAX_SCALE']\n\n        augmentation = []\n\n        # if cfg_input['RANDOM_FLIP'] != \"none\":\n        #     augmentation.append(\n        #         T.RandomFlip(\n        #             horizontal=cfg_input['RANDOM_FLIP'] == \"horizontal\",\n        #             vertical=cfg_input['RANDOM_FLIP'] == \"vertical\",\n        #         )\n        #     )\n\n        augmentation.extend([\n            T.ResizeScale(\n                min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size\n            ),\n            T.FixedSizeCrop(crop_size=(image_size, image_size)),\n        ])\n    else:\n        cfg_input = cfg['INPUT']\n        image_size = cfg_input['IMAGE_SIZE']\n        min_scale = cfg_input['MIN_SCALE']\n        max_scale = cfg_input['MAX_SCALE']\n\n        augmentation = []\n\n        # if cfg_input['RANDOM_FLIP'] != \"none\":\n        #     augmentation.append(\n        #         T.RandomFlip(\n        #             horizontal=cfg_input['RANDOM_FLIP'] == \"horizontal\",\n        #             vertical=cfg_input['RANDOM_FLIP'] == \"vertical\",\n        
#         )\n        #     )\n\n        augmentation.extend([\n            T.ResizeScale(\n                min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size\n            ),\n            T.FixedSizeCrop(crop_size=(image_size, image_size)),\n        ])\n\n    return augmentation\n\n\n# This is specifically designed for the COCO dataset.\nclass COCOInstanceNewBaselineDatasetMapper:\n    \"\"\"\n    A callable which takes a dataset dict in Detectron2 Dataset format,\n    and map it into a format used by MaskFormer.\n\n    This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation.\n\n    The callable currently does the following:\n\n    1. Read the image from \"file_name\"\n    2. Applies geometric transforms to the image and annotation\n    3. Find and applies suitable cropping to the image and annotation\n    4. Prepare image and annotation to Tensors\n    \"\"\"\n\n    @configurable\n    def __init__(\n        self,\n        is_train=True,\n        *,\n        tfm_gens,\n        image_format,\n        tokenizer,\n        image_processor,\n        preprocess,\n    ):\n        \"\"\"\n        NOTE: this interface is experimental.\n        Args:\n            is_train: for training or inference\n            augmentations: a list of augmentations or deterministic transforms to apply\n            tfm_gens: data augmentation\n            image_format: an image format supported by :func:`detection_utils.read_image`.\n        \"\"\"\n        self.tfm_gens = tfm_gens\n        logging.getLogger(__name__).info(\n            \"[COCOInstanceNewBaselineDatasetMapper] Full TransformGens used in training: {}\".format(str(self.tfm_gens))\n        )\n\n        self.img_format = image_format\n        self.is_train = is_train\n        self.tokenizer = tokenizer\n        self.processor = image_processor\n        self.preprocess = preprocess\n    \n    @classmethod\n    def from_config(cls, cfg, 
is_train=True,tokenizer=None,image_processor=None,preprocess=None):\n        # Build augmentation\n        tfm_gens = build_transform_gen(cfg, is_train)\n\n        ret = {\n            \"is_train\": is_train,\n            \"tfm_gens\": tfm_gens,\n            \"image_format\": cfg['INPUT']['FORMAT'],\n            \"tokenizer\": tokenizer,\n            \"image_processor\": image_processor,\n            \"preprocess\": preprocess,\n        }\n        return ret\n\n    def __call__(self, dataset_dict):\n        \"\"\"\n        Args:\n            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.\n\n        Returns:\n            dict: a format that builtin models in detectron2 accept\n        \"\"\"\n        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below\n        image = utils.read_image(dataset_dict[\"file_name\"], format=self.img_format)\n        utils.check_image_size(dataset_dict, image)\n\n        #########llava image processing\n\n        image_clip = self.processor.preprocess(image, return_tensors='pt')['pixel_values'][0]\n        dataset_dict[\"image_clip\"] = image_clip\n\n        ##################\n\n        # TODO: get padding mask\n        # by feeding a \"segmentation mask\" to the same transforms\n        padding_mask = np.ones(image.shape[:2])\n\n        image, transforms = T.apply_transform_gens(self.tfm_gens, image)\n        dataset_dict[\"image_ori\"]=image\n        # the crop transformation has default padding value 0 for segmentation\n        padding_mask = transforms.apply_segmentation(padding_mask)\n        padding_mask = ~ padding_mask.astype(bool)\n\n        image_shape = image.shape[:2]  # h, w\n\n        dataset_dict[\"image\"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))\n        dataset_dict[\"padding_mask\"] = torch.as_tensor(np.ascontiguousarray(padding_mask))\n        num_conversations = len(dataset_dict['conversations'])\n        rd = 
np.random.choice(num_conversations)\n        # selected_conversation, grounding_list = dataset_dict['conversations'][rd]\n        # dataset_dict['conversation'] = [selected_conversation]\n        selected_conversation = [aa[0] for aa in dataset_dict['conversations']]\n        dataset_dict['conversation'] = selected_conversation\n        sources = preprocess_multimodal(\n            copy.deepcopy(dataset_dict['conversation']),\n            True)  #! Debug here\n        # sources = copy.deepcopy(dataset_dict['conversation'])\n        data_dict_conversation = self.preprocess(\n            sources,\n            self.tokenizer,\n            has_image=True)\n        data_dict_conversation = dict(input_ids=data_dict_conversation[\"input_ids\"][0],\n                                      labels=data_dict_conversation[\"labels\"][0])\n        dataset_dict.update(data_dict_conversation)\n        dataset_dict['tokenizer'] = self.tokenizer\n        num_segs = 1 # sum([conv['value'].count('<seg>') for conv in selected_conversation])\n        # grounding_list=\n        if \"grounding_info\" in dataset_dict and len(dataset_dict['grounding_info'])>0:\n            anno_id2id=dict()\n            for id,obj in enumerate(dataset_dict['grounding_info']):\n                obj[\"bbox_mode\"] = BoxMode.XYWH_ABS\n                anno_id2id[obj['id']]=id\n            id2class=[[] for _ in range(len(dataset_dict['grounding_info']))]\n\n            annos = [\n                utils.transform_instance_annotations(obj, transforms, image_shape)\n                for obj in dataset_dict[\"grounding_info\"]\n            ]\n            # assert  \"segmentation\" in annos[0]\n            instances = utils.annotations_to_instances(annos, image_shape,mask_format=\"bitmask\")\n\n            h, w = instances.image_size\n            # image_size_xyxy = torch.as_tensor([w, h, w, h], dtype=torch.float)\n            if hasattr(instances, 'gt_masks'):\n                gt_masks = instances.gt_masks\n             
   # gt_masks = convert_coco_poly_to_mask(gt_masks.polygons, h, w)\n                instances.gt_masks = gt_masks.tensor\n\n            if grounding_list is None:\n                dataset_dict['grounding']=False\n                grounding_mask=[False for _ in range(num_segs)]\n                dataset_dict['grounding_mask']=grounding_mask\n            else:\n                grounding_mask=[True if g is not None else False for g in grounding_list]\n                dataset_dict['grounding_mask']=grounding_mask\n                new_grounding_list=[g for g in grounding_list if g is not None]\n                if sum(grounding_mask)==0:\n                    dataset_dict['grounding']=False\n                else:\n                    dataset_dict['grounding']=True\n            if dataset_dict['grounding']:\n                # assert num_segs == len(grounding_list)\n                for grounding_id,grounding in enumerate(new_grounding_list):\n                    if grounding is not None:\n                        for annid in grounding:\n                            id2class[anno_id2id[annid]].append(grounding_id)\n\n                instances.gt_classes=id2class\n            dataset_dict[\"instances\"] = instances\n\n        else:\n            dataset_dict['grounding'] = False\n            grounding_mask = [False for _ in range(num_segs)]\n            dataset_dict['grounding_mask'] = grounding_mask\n\n        return [dataset_dict]\n"
  },
  {
    "path": "llava/eval/model_qa.py",
    "content": "import argparse\nfrom transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria\nimport torch\nimport os\nimport json\nfrom tqdm import tqdm\nimport shortuuid\n\nfrom llava.conversation import default_conversation\nfrom llava.utils import disable_torch_init\n\n\n# new stopping implementation\nclass KeywordsStoppingCriteria(StoppingCriteria):\n    def __init__(self, keywords, tokenizer, input_ids):\n        self.keywords = keywords\n        self.tokenizer = tokenizer\n        self.start_len = None\n        self.input_ids = input_ids\n\n    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:\n        if self.start_len is None:\n            self.start_len = self.input_ids.shape[1]\n        else:\n            outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len:], skip_special_tokens=True)[0]\n            for keyword in self.keywords:\n                if keyword in outputs:\n                    return True\n        return False\n\n\n@torch.inference_mode()\ndef eval_model(model_name, questions_file, answers_file):\n    # Model\n    disable_torch_init()\n    model_name = os.path.expanduser(model_name)\n    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)\n    model = AutoModelForCausalLM.from_pretrained(model_name,\n        torch_dtype=torch.float16).cuda()\n\n\n    ques_file = open(os.path.expanduser(questions_file), \"r\")\n    ans_file = open(os.path.expanduser(answers_file), \"w\")\n    for i, line in enumerate(tqdm(ques_file)):\n        idx = json.loads(line)[\"question_id\"]\n        qs = json.loads(line)[\"text\"]\n        cat = json.loads(line)[\"category\"]\n        conv = default_conversation.copy()\n        conv.append_message(conv.roles[0], qs)\n        prompt = conv.get_prompt()\n        inputs = tokenizer([prompt])\n        input_ids = torch.as_tensor(inputs.input_ids).cuda()\n        stopping_criteria = KeywordsStoppingCriteria([conv.sep], 
tokenizer, input_ids)\n        output_ids = model.generate(\n            input_ids,\n            do_sample=True,\n            use_cache=True,\n            temperature=0.7,\n            max_new_tokens=1024,\n            stopping_criteria=[stopping_criteria])\n        outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]\n        try:\n            index = outputs.index(conv.sep, len(prompt))\n        except ValueError:\n            outputs += conv.sep\n            index = outputs.index(conv.sep, len(prompt))\n\n        outputs = outputs[len(prompt) + len(conv.roles[1]) + 2:index].strip()\n        ans_id = shortuuid.uuid()\n        ans_file.write(json.dumps({\"question_id\": idx,\n                                   \"text\": outputs,\n                                   \"answer_id\": ans_id,\n                                   \"model_id\": model_name,\n                                   \"metadata\": {}}) + \"\\n\")\n        ans_file.flush()\n    ans_file.close()\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--model-name\", type=str, default=\"facebook/opt-350m\")\n    parser.add_argument(\"--question-file\", type=str, default=\"tables/question.jsonl\")\n    parser.add_argument(\"--answers-file\", type=str, default=\"answer.jsonl\")\n    args = parser.parse_args()\n\n    eval_model(args.model_name, args.question_file, args.answers_file)\n"
  },
  {
    "path": "llava/eval/model_vqa.py",
    "content": "import argparse\nimport torch\nimport os\nimport json\nfrom tqdm import tqdm\nimport shortuuid\n\nfrom llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN\nfrom llava.conversation import conv_templates, SeparatorStyle\nfrom llava.model.builder import load_pretrained_model\nfrom llava.utils import disable_torch_init\nfrom llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria\n\nfrom PIL import Image\nimport math\n\n\ndef split_list(lst, n):\n    \"\"\"Split a list into n (roughly) equal-sized chunks\"\"\"\n    chunk_size = math.ceil(len(lst) / n)  # integer division\n    return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]\n\n\ndef get_chunk(lst, n, k):\n    chunks = split_list(lst, n)\n    return chunks[k]\n\n\ndef eval_model(args):\n    # Model\n    disable_torch_init()\n    model_path = os.path.expanduser(args.model_path)\n    model_name = get_model_name_from_path(model_path)\n    tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)\n    questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), \"r\")]\n    questions = get_chunk(questions, args.num_chunks, args.chunk_idx)\n    answers_file = os.path.expanduser(args.answers_file)\n    os.makedirs(os.path.dirname(answers_file), exist_ok=True)\n    ans_file = open(answers_file, \"w\")\n    for line in tqdm(questions):\n        idx = line[\"question_id\"]\n        image_file = line[\"image\"]\n        qs = line[\"text\"]\n        cur_prompt = qs\n        if model.config.mm_use_im_start_end:\n            qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\\n' + qs\n        else:\n            qs = DEFAULT_IMAGE_TOKEN + '\\n' + qs\n\n        conv = conv_templates[args.conv_mode].copy()\n        conv.append_message(conv.roles[0], qs)\n        conv.append_message(conv.roles[1], None)\n     
   prompt = conv.get_prompt()\n\n        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()\n        try:\n            image = Image.open(os.path.join(args.image_folder, \"COCO_val2014_\"+image_file))\n        except Exception:\n            image = Image.open(os.path.join(args.image_folder, image_file))\n\n        image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]\n\n        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2\n        keywords = [stop_str]\n        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)\n        # import pdb; pdb.set_trace()\n        with torch.inference_mode():\n            output_ids = model.generate(\n                input_ids,\n                images=image_tensor.unsqueeze(0).half().cuda(),\n                do_sample=True,\n                temperature=args.temperature,\n                top_p=args.top_p,\n                num_beams=args.num_beams,\n                # no_repeat_ngram_size=3,\n                max_new_tokens=2048,\n                use_cache=True)\n\n        input_token_len = input_ids.shape[1]\n        n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()\n        if n_diff_input_output > 0:\n            print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')\n        outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]\n        outputs = outputs.strip()\n        if outputs.endswith(stop_str):\n            outputs = outputs[:-len(stop_str)]\n        outputs = outputs.strip()\n\n        ans_id = shortuuid.uuid()\n        ans_file.write(json.dumps({\"question_id\": idx,\n                                   \"prompt\": cur_prompt,\n                                   \"text\": outputs,\n                                   \"answer_id\": ans_id,\n                                   
\"model_id\": model_name,\n                                   \"metadata\": {}}) + \"\\n\")\n        ans_file.flush()\n    ans_file.close()\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--model-path\", type=str, default=\"facebook/opt-350m\")\n    parser.add_argument(\"--model-base\", type=str, default=None)\n    parser.add_argument(\"--image-folder\", type=str, default=\"\")\n    parser.add_argument(\"--question-file\", type=str, default=\"tables/question.jsonl\")\n    parser.add_argument(\"--answers-file\", type=str, default=\"answer.jsonl\")\n    parser.add_argument(\"--conv-mode\", type=str, default=\"llava_v1\")\n    parser.add_argument(\"--num-chunks\", type=int, default=1)\n    parser.add_argument(\"--chunk-idx\", type=int, default=0)\n    parser.add_argument(\"--temperature\", type=float, default=0.2)\n    parser.add_argument(\"--top_p\", type=float, default=None)\n    parser.add_argument(\"--num_beams\", type=int, default=1)\n    args = parser.parse_args()\n\n    eval_model(args)\n"
  },
  {
    "path": "llava/eval/model_vqa_science.py",
    "content": "import argparse\nimport torch\nimport os\nimport json\nfrom tqdm import tqdm\nimport shortuuid\n\nfrom llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN\nfrom llava.conversation import conv_templates, SeparatorStyle\nfrom llava.model.builder import load_pretrained_model\nfrom llava.utils import disable_torch_init\nfrom llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria\n\nfrom PIL import Image\nimport math\n\n\ndef split_list(lst, n):\n    \"\"\"Split a list into n (roughly) equal-sized chunks\"\"\"\n    chunk_size = math.ceil(len(lst) / n)  # integer division\n    return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]\n\n\ndef get_chunk(lst, n, k):\n    chunks = split_list(lst, n)\n    return chunks[k]\n\n\ndef eval_model(args):\n    # Model\n    disable_torch_init()\n    model_path = os.path.expanduser(args.model_path)\n    model_name = get_model_name_from_path(model_path)\n    tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)\n\n    questions = json.load(open(os.path.expanduser(args.question_file), \"r\"))\n    questions = get_chunk(questions, args.num_chunks, args.chunk_idx)\n    answers_file = os.path.expanduser(args.answers_file)\n    os.makedirs(os.path.dirname(answers_file), exist_ok=True)\n    ans_file = open(answers_file, \"w\")\n    for i, line in enumerate(tqdm(questions)):\n        idx = line[\"id\"]\n        question = line['conversations'][0]\n        gt_ans = line[\"conversations\"][1]\n        qs = question['value'].replace('<image>', '').strip()\n        cur_prompt = qs\n\n        if 'image' in line:\n            image_file = line[\"image\"]\n            image = Image.open(os.path.join(args.image_folder, image_file))\n            image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]\n            images = 
image_tensor.unsqueeze(0).half().cuda()\n            if getattr(model.config, 'mm_use_im_start_end', False):\n                qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\\n' + qs\n            else:\n                qs = DEFAULT_IMAGE_TOKEN + '\\n' + qs\n            cur_prompt = '<image>' + '\\n' + cur_prompt\n        else:\n            images = None\n\n        conv = conv_templates[args.conv_mode].copy()\n        conv.append_message(conv.roles[0], qs)\n        conv.append_message(conv.roles[1], None)\n        prompt = conv.get_prompt()\n\n        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()\n\n        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2\n        keywords = [stop_str]\n        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)\n\n        with torch.inference_mode():\n            output_ids = model.generate(\n                input_ids,\n                images=images,\n                do_sample=True,\n                temperature=0.2,\n                max_new_tokens=1024,\n                use_cache=True,\n                stopping_criteria=[stopping_criteria])\n\n        input_token_len = input_ids.shape[1]\n        n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()\n        if n_diff_input_output > 0:\n            print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')\n        outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]\n        outputs = outputs.strip()\n        if outputs.endswith(stop_str):\n            outputs = outputs[:-len(stop_str)]\n        outputs = outputs.strip()\n\n        # prompt for answer\n        if args.answer_prompter:\n            outputs_reasoning = outputs\n            input_ids = tokenizer_image_token(prompt + outputs_reasoning + ' ###\\nANSWER:', tokenizer, 
IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()\n\n            with torch.inference_mode():\n                output_ids = model.generate(\n                    input_ids,\n                    images=images,\n                    do_sample=True,\n                    temperature=0.2,\n                    max_new_tokens=64,\n                    use_cache=True,\n                    stopping_criteria=[stopping_criteria])\n\n            input_token_len = input_ids.shape[1]\n            n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()\n            if n_diff_input_output > 0:\n                print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')\n            outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]\n            outputs = outputs.strip()\n            if outputs.endswith(stop_str):\n                outputs = outputs[:-len(stop_str)]\n            outputs = outputs.strip()\n            outputs = outputs_reasoning + '\\n The answer is ' + outputs\n\n        ans_id = shortuuid.uuid()\n        ans_file.write(json.dumps({\"question_id\": idx,\n                                   \"prompt\": cur_prompt,\n                                   \"text\": outputs,\n                                   \"answer_id\": ans_id,\n                                   \"model_id\": model_name,\n                                   \"metadata\": {}}) + \"\\n\")\n        ans_file.flush()\n    ans_file.close()\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--model-path\", type=str, default=\"facebook/opt-350m\")\n    parser.add_argument(\"--model-base\", type=str, default=None)\n    parser.add_argument(\"--image-folder\", type=str, default=\"\")\n    parser.add_argument(\"--question-file\", type=str, default=\"tables/question.json\")\n    parser.add_argument(\"--answers-file\", type=str, default=\"answer.jsonl\")\n    
parser.add_argument(\"--conv-mode\", type=str, default=\"llava_v0\")\n    parser.add_argument(\"--num-chunks\", type=int, default=1)\n    parser.add_argument(\"--chunk-idx\", type=int, default=0)\n    parser.add_argument(\"--answer-prompter\", action=\"store_true\")\n    args = parser.parse_args()\n\n    eval_model(args)\n"
  },
  {
    "path": "llava/eval/qa_baseline_gpt35.py",
    "content": "\"\"\"Generate answers with GPT-3.5\"\"\"\n# Note: you need to be using OpenAI Python v0.27.0 for the code below to work\nimport argparse\nimport json\nimport os\nimport time\nimport concurrent.futures\n\nimport openai\nimport tqdm\nimport shortuuid\n\nMODEL = 'gpt-3.5-turbo'\nMODEL_ID = 'gpt-3.5-turbo:20230327'\n\ndef get_answer(question_id: int, question: str, max_tokens: int):\n    ans = {\n        'answer_id': shortuuid.uuid(),\n        'question_id': question_id,\n        'model_id': MODEL_ID,\n    }\n    for _ in range(3):\n        try:\n            response = openai.ChatCompletion.create(\n                model=MODEL,\n                messages=[{\n                    'role': 'system',\n                    'content': 'You are a helpful assistant.'\n                }, {\n                    'role': 'user',\n                    'content': question,\n                }],\n                max_tokens=max_tokens,\n            )\n            ans['text'] = response['choices'][0]['message']['content']\n            return ans\n        except Exception as e:\n            print('[ERROR]', e)\n            ans['text'] = '#ERROR#'\n            time.sleep(1)\n    return ans\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='ChatGPT answer generation.')\n    parser.add_argument('-q', '--question')\n    parser.add_argument('-o', '--output')\n    parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')\n    args = parser.parse_args()\n\n    questions_dict = {}\n    with open(os.path.expanduser(args.question)) as f:\n        for line in f:\n            if not line:\n                continue\n            q = json.loads(line)\n            questions_dict[q['question_id']] = q['text']\n\n    answers = []\n\n    with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:\n        futures = []\n        for qid, question in questions_dict.items():\n            future = 
executor.submit(get_answer, qid, question, args.max_tokens)\n            futures.append(future)\n\n        for future in tqdm.tqdm(concurrent.futures.as_completed(futures), total=len(futures)):\n            answers.append(future.result())\n\n    answers.sort(key=lambda x: x['question_id'])\n\n    with open(os.path.expanduser(args.output), 'w') as f:\n        table = [json.dumps(ans) for ans in answers]\n        f.write('\\n'.join(table))\n"
  },
  {
    "path": "llava/eval/run_llava.py",
    "content": "import argparse\nimport torch\n\nfrom llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN\nfrom llava.conversation import conv_templates, SeparatorStyle\nfrom llava.model.builder import load_pretrained_model\nfrom llava.utils import disable_torch_init\nfrom llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria\n\nfrom PIL import Image\n\nimport requests\nfrom PIL import Image\nfrom io import BytesIO\n\n\ndef load_image(image_file):\n    if image_file.startswith('http') or image_file.startswith('https'):\n        response = requests.get(image_file)\n        image = Image.open(BytesIO(response.content)).convert('RGB')\n    else:\n        image = Image.open(image_file).convert('RGB')\n    return image\n\n\ndef eval_model(args):\n    # Model\n    disable_torch_init()\n\n    model_name = get_model_name_from_path(args.model_path)\n    tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name)\n\n    qs = args.query\n    if model.config.mm_use_im_start_end:\n        qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\\n' + qs\n    else:\n        qs = DEFAULT_IMAGE_TOKEN + '\\n' + qs\n\n    if 'llama-2' in model_name.lower():\n        conv_mode = \"llava_llama_2\"\n    elif \"v1\" in model_name.lower():\n        conv_mode = \"llava_v1\"\n    elif \"mpt\" in model_name.lower():\n        conv_mode = \"mpt\"\n    else:\n        conv_mode = \"llava_v0\"\n\n    if args.conv_mode is not None and conv_mode != args.conv_mode:\n        print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode))\n    else:\n        args.conv_mode = conv_mode\n\n    conv = conv_templates[args.conv_mode].copy()\n    conv.append_message(conv.roles[0], qs)\n    conv.append_message(conv.roles[1], None)\n    prompt = 
conv.get_prompt()\n\n    image = load_image(args.image_file)\n    image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().cuda()\n\n    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()\n\n    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2\n    keywords = [stop_str]\n    stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)\n\n    with torch.inference_mode():\n        output_ids = model.generate(\n            input_ids,\n            images=image_tensor,\n            do_sample=True,\n            temperature=0.2,\n            max_new_tokens=1024,\n            use_cache=True,\n            stopping_criteria=[stopping_criteria])\n\n    input_token_len = input_ids.shape[1]\n    n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()\n    if n_diff_input_output > 0:\n        print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')\n    outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]\n    outputs = outputs.strip()\n    if outputs.endswith(stop_str):\n        outputs = outputs[:-len(stop_str)]\n    outputs = outputs.strip()\n    print(outputs)\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--model-path\", type=str, default=\"facebook/opt-350m\")\n    parser.add_argument(\"--model-base\", type=str, default=None)\n    parser.add_argument(\"--image-file\", type=str, required=True)\n    parser.add_argument(\"--query\", type=str, required=True)\n    parser.add_argument(\"--conv-mode\", type=str, default=None)\n    args = parser.parse_args()\n\n    eval_model(args)\n"
  },
  {
    "path": "llava/eval/summarize_gpt_review.py",
    "content": "import json\nimport os\nfrom collections import defaultdict\n\nimport numpy as np\n\nimport argparse\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')\n    parser.add_argument('-d', '--dir', default=None)\n    parser.add_argument('-f', '--files', nargs='*', default=None)\n    parser.add_argument('-i', '--ignore', nargs='*', default=None)\n    return parser.parse_args()\n\n\nif __name__ == '__main__':\n    args = parse_args()\n\n    if args.ignore is not None:\n        args.ignore = [int(x) for x in args.ignore]\n\n    if args.files is not None and len(args.files) > 0:\n        review_files = args.files\n    else:\n        review_files = [x for x in os.listdir(args.dir) if x.endswith('.jsonl') and (x.startswith('gpt4_text') or x.startswith('reviews_') or x.startswith('review_'))]\n\n    for review_file in sorted(review_files):\n        config = os.path.basename(review_file).replace('gpt4_text_', '').replace('.jsonl', '')\n        scores = defaultdict(list)\n        print(config)\n        with open(os.path.join(args.dir, review_file) if args.dir is not None else review_file) as f:\n            for review_str in f:\n                review = json.loads(review_str)\n                if args.ignore is not None and review['question_id'] in args.ignore:\n                    continue\n                if 'category' in review:\n                    scores[review['category']].append(review['tuple'])\n                    scores['all'].append(review['tuple'])\n                else:\n                    if 'tuple' in review:\n                        scores['all'].append(review['tuple'])\n                    else:\n                        scores['all'].append(review['score'])\n        for k, v in sorted(scores.items()):\n            stats = np.asarray(v).mean(0).tolist()\n            stats = [round(x, 3) for x in stats]\n            # print(k, stats, round(stats[1]/stats[0]*100, 1))\n            print(k, 
round(stats[1]/stats[0]*100, 1))\n        print('=================================')\n"
  },
  {
    "path": "llava/eval/webpage/index.html",
    "content": "<!DOCTYPE html>\n<html lang=\"en\">\n<head>\n    <meta charset=\"UTF-8\">\n    <meta name=\"viewport\" content=\"width=device-width, initial-scale=1.0\">\n    <title>Who's GPT-4's favorite? Battles between State-of-the-Art Chatbots</title>\n    <link rel=\"stylesheet\" href=\"https://maxcdn.bootstrapcdn.com/bootstrap/4.5.2/css/bootstrap.min.css\">\n    <link rel=\"stylesheet\" href=\"https://fonts.googleapis.com/icon?family=Material+Icons\">\n    <link rel=\"stylesheet\" href=\"styles.css\">\n</head>\n\n<body>\n    <nav class=\"navbar navbar-expand-lg navbar-dark bg-dark\">\n        <a class=\"navbar-brand\" href=\"#\">🏔️ Vicuna Evaluation Examples</a>\n        <button class=\"navbar-toggler\" type=\"button\" data-toggle=\"collapse\" data-target=\"#navbarNav\" aria-controls=\"navbarNav\" aria-expanded=\"false\" aria-label=\"Toggle navigation\">\n          <span class=\"navbar-toggler-icon\"></span>\n        </button>\n        <div class=\"collapse navbar-collapse\" id=\"navbarNav\">\n          <ul class=\"navbar-nav mr-auto\">\n            <li class=\"nav-item\">\n                <a class=\"nav-link\" href=\"https://chat.lmsys.org/\">Demo</a>\n              </li>\n              <li class=\"nav-item\">\n                <a class=\"nav-link\" href=\"https://vicuna.lmsys.org\">Blog</a>\n              </li>\n              <li class=\"nav-item\">\n                <a class=\"nav-link\" href=\"https://github.com/lm-sys/FastChat\">Github</a>\n              </li>\n          </ul>\n        </div>\n    </nav>\n\n    <div class=\"container mt-5\">\n        <h2 class=\"text-center mb-5\">Who's GPT-4's favorite? 
Battles between State-of-the-Art Chatbots</h2>\n\n        <!-- Selection -->\n        <div class=\"form-row\">\n            <div class=\"form-group col-md-2\">\n                <label for=\"category-select\">Category</label>\n                <select class=\"form-control\" id=\"category-select\"></select>\n            </div>\n            <div class=\"form-group col-md-8\">\n                <label for=\"question-select\">Question</label>\n                <select class=\"form-control\" id=\"question-select\"></select>\n            </div>\n            <div class=\"form-group col-md-2\">\n                <div class=\"col-md-2\"><label>&nbsp;</label></div>\n                <div class=\"btn-group\" role=\"group\" aria-label=\"Left and Right Controller\">\n                    <button type=\"button\" class=\"form-control btn btn-primary\" id=\"prev-question\"><i class=\"material-icons\">keyboard_arrow_left</i></button>\n                    <button type=\"button\" class=\"form-control btn btn-primary\" id=\"next-question\"><i class=\"material-icons\">keyboard_arrow_right</i></button>\n                </div>\n            </div>\n        </div>\n\n        <!-- \"Battle\" -->\n        <div class=\"row mb-4\" style=\"justify-content: center;\">\n            <div class=\"col\" style=\"display: flex; justify-content: center; align-items: center;\">\n                <label class=\"adjustable-font-size\" id=\"other-score-label\">*/10</label>\n            </div>\n            <div class=\"col\">\n                <div class=\"vertical-flex-layout\">\n                    <img class=\"shadow figure-img img-fluid\" src=\"\" alt=\"other logo\" width=\"150\" id=\"other-model-figure\">\n                </div>\n            </div>\n            <div class=\"col\">\n                <div class=\"vertical-flex-layout\">\n                    <!-- from: https://fonts.google.com/icons?icon.query=battle&selected=Material+Symbols+Outlined:swords:FILL@0;wght@300;GRAD@0;opsz@48&icon.style=Outlined -->\n  
                  <img class=\"figure-img img-fluid\" src=\"figures/swords_FILL0_wght300_GRAD0_opsz48.svg\" width=\"60\" height=\"60\">\n                </div>\n            </div>\n            <div class=\"col\">\n                <div class=\"vertical-flex-layout\">\n                    <img class=\"shadow figure-img img-fluid\" src=\"figures/vicuna.jpeg\" alt=\"vicuna logo\" width=\"150\" id=\"our-model-figure\">\n                </div>\n            </div>\n            <div class=\"col\" style=\"display: flex; justify-content: center; align-items: center;\">\n                <label class=\"adjustable-font-size\" id=\"our-score-label\">*/10</label>\n            </div>\n        </div>\n\n        <!-- Question Card -->\n        <div class=\"card mb-4\">\n            <div class=\"card-body\" id=\"selected-question\"></div>\n        </div>\n\n        <!-- Answer Cards -->\n        <div class=\"row\">\n            <div class=\"col-md-6\">\n                <div class=\"card mb-4 expandable-card\">\n                    <div class=\"card-header\" style=\"padding-bottom: 0.2rem\" id=\"other-model-header-bg\">\n                        <div class=\"row\">\n                            <div class=\"col-md-5\" style=\"align-items: center; display: flex;\">\n                                <label id=\"other-model-header\">Assistant #1</label>\n                            </div>\n                            <div class=\"col-md-7\">\n                                <select class=\"form-control\" id=\"model-select\" style=\"height: fit-content; margin-top: -0.3rem;\"></select>\n                            </div>\n                        </div>\n                    </div>\n                    <div class=\"card-body\">\n                        <div class=\"card-text-container\">\n                            <div class=\"card-text\" id=\"other-model-answer\"></div>\n                        </div>\n                        <div class=\"btn btn-primary expand-btn\" 
style=\"display:flex;\"></div>\n                    </div>\n                </div>\n            </div>\n            <div class=\"col-md-6\">\n                <div class=\"card mb-4 expandable-card\">\n                    <div class=\"card-header\" id=\"our-model-header\">\n                        Assistant #2 (Vicuna, our model)\n                    </div>\n                    <div class=\"card-body\">\n                        <div class=\"card-text-container\">\n                            <div class=\"card-text\" id=\"our-model-answer\"></div>\n                        </div>\n                        <div class=\"btn btn-primary expand-btn\" style=\"display:flex;\"></div>\n                    </div>\n                </div>\n            </div>\n        </div>\n\n        <!-- Evaluation -->\n        <div class=\"card expandable-card\">\n            <div class=\"card-header\" style=\"background-color: #c9c9f2;\" id=\"evaluation-header\">GPT-4 Evaluation</div>\n            <div class=\"card-body\">\n                <div class=\"card-text-container\">\n                    <div class=\"card-text\" id=\"evaluation-result\"></div>\n                </div>\n                <div class=\"btn btn-primary expand-btn\" style=\"display:flex;\"></div>\n            </div>\n        </div>\n    </div>\n\n    <div class=\"container-fluid bg-light py-2\">\n        <div class=\"text-center\">\n            <small class=\"text-muted\">This website is co-authored with <a href=\"https://openai.com\" target=\"_blank\">GPT-4</a>.</small>\n        </div>\n    </div>\n\n    <!-- Marked.js -->\n    <script src=\"https://cdn.jsdelivr.net/npm/marked@4.3.0/lib/marked.umd.min.js\"></script>\n    <!-- Bootstrap and Popper.js JavaScript dependencies -->\n    <script src=\"https://code.jquery.com/jquery-3.5.1.slim.min.js\"></script>\n    <script src=\"https://cdn.jsdelivr.net/npm/@popperjs/core@2.11.6/dist/umd/popper.min.js\"></script>\n    <script 
src=\"https://maxcdn.bootstrapcdn.com/bootstrap/4.5.2/js/bootstrap.min.js\"></script>\n\n    <script src=\"script.js\"></script>\n    <script>\n      // Fetch the JSON file\n      fetch('data.json')\n        .then(response => response.json())\n        .then(json_data => {\n            // Populate the models and questions.\n            populateModels(json_data.models);\n            populateQuestions(json_data.questions);\n            displayQuestion(currentQuestionIndex);\n        }).catch(error => console.error(error));\n    </script>\n</body>\n\n</html>\n"
  },
  {
    "path": "llava/eval/webpage/script.js",
    "content": "// Description: Script for the evaluation webpage.\n\nlet currentQuestionIndex = 1;\n\n// Store the model name mapping for later use.\nmodelNameMapping = {\n    \"gpt35\": \"ChatGPT-3.5\",\n    \"gpt4\": \"GPT-4\",\n    \"alpaca\": \"Alpaca-13b\",\n    \"vicuna\": \"Vicuna-13b\",\n    \"llama\": \"LLaMA-13b\",\n    \"bard\": \"Bard\",\n};\n\nmodelFigureMapping = {\n    \"vicuna\": \"figures/vicuna.jpeg\",\n    // Image from: https://commons.wikimedia.org/wiki/File:ChatGPT_logo.svg\n    \"gpt35\": \"figures/chatgpt.svg\",\n    // Image from: https://www.reddit.com/r/logodesign/comments/1128aat/google_ai_bard_logo_design/\n    \"bard\": \"figures/bard.jpg\",\n    // Image from: https://crfm.stanford.edu/2023/03/13/alpaca.html\n    \"alpaca\": \"figures/alpaca.png\",\n    // Image adapted from https://commons.wikimedia.org/wiki/File:Llama_on_Machu_Picchu.jpg\n    \"llama\": \"figures/llama.jpg\",\n}\n\n// Store the question data in a mapping for later use.\nquestionMapping = {};\n// Store the question ids in a mapping for later use.\ncategoryMapping = {};\n// Store the number of questions for later use.\nquestionsCount = 0;\n\n\nfunction text2Markdown(text) {\n    // Normalize the text for markdown rendering.\n    text = text.trim().replaceAll('\\n\\n', '\\n').replaceAll('\\n', '\\n\\n');\n    return marked.parse(text);\n}\n\nfunction capitalizeFirstChar(str) {\n    if (!str || str.length === 0) {\n      return str;\n    }\n    return str.charAt(0).toUpperCase() + str.slice(1);\n}\n\nfunction updateQuestionSelect(question_id) {\n    const select = document.getElementById('question-select');\n    // Clear the question select.\n    select.innerHTML = '';\n    // Populate the question select.\n    category = questionMapping[question_id].category;\n    categoryMapping[category].forEach(question_id => {\n        const question = questionMapping[question_id];\n        const option = document.createElement('option');\n        option.value = question_id;\n     
   option.textContent = 'Q' + question_id.toString() + ': ' + question.question;\n        select.appendChild(option);\n    });\n    select.value = question_id;\n}\n\nfunction updateModelSelect() {\n    const select = document.getElementById('model-select');\n    img_path = modelFigureMapping[select.value];\n    document.getElementById('other-model-figure').src = img_path;\n}\n\nfunction populateModels(models) {\n    const select = document.getElementById('model-select');\n    models.forEach(model => {\n        const option = document.createElement('option');\n        option.value = model;\n        option.textContent = modelNameMapping[model];\n        select.appendChild(option);\n    });\n    updateModelSelect();\n}\n\nfunction populateQuestions(questions) {\n    const category_select = document.getElementById('category-select');\n\n    questionsCount = questions.length;\n    questions.forEach(question => {\n        const option = document.createElement('option');\n        // Store the question data in a mapping for later use.\n        questionMapping[question.id] = {\n            category: question.category,\n            question: question.question,\n            answers: question.answers,\n            evaluations: question.evaluations,\n            scores: question.scores,\n        };\n        // Store the question id in the category mapping.\n        if (question.category in categoryMapping) {\n            categoryMapping[question.category].push(question.id);\n        } else {\n            categoryMapping[question.category] = [question.id];\n            const category_option = document.createElement('option');\n            category_option.value = question.category;\n            category_option.textContent = capitalizeFirstChar(question.category);\n            category_select.appendChild(category_option);\n        }\n    });\n    // Set the default category.\n    updateQuestionSelect(currentQuestionIndex);\n}\n\nfunction displayQuestion(index) {\n    const 
question = questionMapping[index].question;\n    document.getElementById('selected-question').innerHTML = text2Markdown('**Question:** ' + question);\n    displayAnswers(index);\n}\n\nfunction displayAnswers(index) {\n    const question = questionMapping[index];\n    const otherModel = document.getElementById('model-select').value;\n    // render the answers with markdown\n    document.getElementById('other-model-answer').innerHTML = text2Markdown(question.answers[otherModel]);\n    document.getElementById('our-model-answer').innerHTML = text2Markdown(question.answers.vicuna);\n\n    // Display evaluation\n    score = question.scores[otherModel];\n    score_text = modelNameMapping[otherModel] + \" \" + score[0] + \"/10, Vicuna-13b \" + score[1] + \"/10\";\n    document.getElementById('evaluation-header').textContent = \"GPT-4 Evaluation\" + \" (Score: \" + score_text + \")\";\n    document.getElementById('evaluation-result').innerHTML = text2Markdown(question.evaluations[otherModel]);\n\n    // Update model names\n    let assistant1_title = \"Assistant #1\"; // (\" + modelNameMapping[otherModel] + \")\";\n    let assistant2_title = \"Assistant #2 (Vicuna-13b, our model)\";\n    // Update scores/labels.\n    let assistant1_score_label = score[0].toString() + '/10';\n    let assistant2_score_label = score[1].toString() + '/10';\n\n    const colorRed ='#fa9'; // '#eb978d';\n    // const colorGreen = '#c9f2c9';\n    const colorBlue = '#8ef'; // '#71dbf9';\n    const colorYellow = '#fe7'; // '#fada57';\n    let otherModelHeaderColor = '';\n    let ourModelHeaderColor = '';\n    // Update the winner.\n    if (score[0] == score[1]) {\n        assistant1_title = '🏆 ' + assistant1_title;\n        assistant1_score_label = '🏆 ' + assistant1_score_label;\n        assistant2_title = '🏆 ' + assistant2_title;\n        assistant2_score_label = '🏆 ' + assistant2_score_label;\n        otherModelHeaderColor = colorYellow;\n        ourModelHeaderColor = colorYellow;\n    } else if 
(score[0] > score[1]) {\n        assistant1_title = '🏆 ' + assistant1_title;\n        assistant1_score_label = '🏆 ' + assistant1_score_label;\n        otherModelHeaderColor = colorBlue;\n        ourModelHeaderColor = colorRed;\n    } else if (score[0] < score[1]) {\n        assistant2_title = '🏆 ' + assistant2_title;\n        assistant2_score_label = '🏆 ' + assistant2_score_label;\n        otherModelHeaderColor = colorRed;\n        ourModelHeaderColor = colorBlue;\n    }\n\n    document.getElementById('other-model-header-bg').style.backgroundColor = otherModelHeaderColor;\n    document.getElementById('our-model-header').style.backgroundColor = ourModelHeaderColor;\n\n    document.getElementById('other-model-header').textContent = assistant1_title;\n    document.getElementById('our-model-header').textContent = assistant2_title;\n\n    document.getElementById('other-score-label').textContent = assistant1_score_label;\n    document.getElementById('our-score-label').textContent = assistant2_score_label;\n\n    // Update expand buttons visibility for both cards after displaying answers\n    // Reset the expanded state and update expand buttons visibility for both cards after displaying answers\n    document.querySelectorAll('.expandable-card').forEach(card => {\n        card.classList.remove('expanded');\n        updateExpandButtonVisibility(card);\n        const expandBtn = card.querySelector('.expand-btn');\n        expandBtn.innerHTML = '<i class=\"material-icons\" style=\"pointer-events: none\">keyboard_arrow_down</i> Show more';   // .textContent = 'Show more';\n    });\n}\n\ndocument.getElementById('question-select').addEventListener('change', e => {\n    currentQuestionIndex = parseInt(e.target.value);\n    displayQuestion(currentQuestionIndex);\n});\n\ndocument.getElementById('category-select').addEventListener('change', e => {\n    let currentCategory = e.target.value;\n    const questionIds = categoryMapping[currentCategory];\n    currentQuestionIndex = 
questionIds[0];\n    updateQuestionSelect(currentQuestionIndex);\n    displayQuestion(currentQuestionIndex);\n});\n\n// Update expand buttons whenever the model is changed\ndocument.getElementById('model-select').addEventListener('change', () => {\n    displayAnswers(currentQuestionIndex);\n    document.querySelectorAll('.expandable-card').forEach(card => {\n        updateExpandButtonVisibility(card);\n    });\n    updateModelSelect();\n});\n\nfunction switchQuestionAndCategory() {\n    document.getElementById('question-select').value = currentQuestionIndex;\n    old_category = document.getElementById('category-select').value;\n    new_category = questionMapping[currentQuestionIndex].category;\n    if (old_category != new_category) {\n        document.getElementById('category-select').value = new_category;\n        updateQuestionSelect(currentQuestionIndex);\n    }\n    displayQuestion(currentQuestionIndex);\n}\n\ndocument.getElementById('prev-question').addEventListener('click', () => {\n    // Question index starts from 1.\n    currentQuestionIndex = Math.max(1, currentQuestionIndex - 1);\n    switchQuestionAndCategory();\n});\n\ndocument.getElementById('next-question').addEventListener('click', () => {\n    // Question index starts from 1.\n    currentQuestionIndex = Math.min(questionsCount, currentQuestionIndex + 1);\n    switchQuestionAndCategory();\n});\n\nfunction updateExpandButtonVisibility(card) {\n    const cardTextContainer = card.querySelector('.card-text-container');\n    const expandBtn = card.querySelector('.expand-btn');\n    if (cardTextContainer.scrollHeight > cardTextContainer.offsetHeight) {\n        expandBtn.style.display = 'flex';\n    } else {\n        expandBtn.style.display = 'none';\n        card.classList.add('expanded');\n    }\n}\n\ndocument.querySelectorAll('.expand-btn').forEach(btn => {\n    btn.addEventListener('click', e => {\n        const card = e.target.closest('.expandable-card');\n        card.classList.toggle('expanded');\n 
       const more = '<i class=\"material-icons\" style=\"pointer-events: none\">keyboard_arrow_down</i> Show more';\n        const less = '<i class=\"material-icons\" style=\"pointer-events: none\">keyboard_arrow_up</i> Show less';\n        e.target.innerHTML = card.classList.contains('expanded') ? less : more;\n    });\n});\n"
  },
  {
    "path": "llava/eval/webpage/styles.css",
    "content": "body {\n    font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;\n    background-color: #f8f9fa;\n}\n\n.navbar-dark .navbar-nav .nav-link {\n    color: #f1cf68;\n    font-size: 1.1rem;\n    padding: 0.5rem 0.6rem;\n}\n\n.card-header {\n    font-weight: bold;\n}\n\n.card {\n    box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);\n    transition: 0.3s;\n}\n\n.card:hover {\n    box-shadow: 0 8px 16px rgba(0, 0, 0, 0.2);\n}\n\nbutton {\n    transition: background-color 0.3s;\n}\n\nbutton:hover {\n    background-color: #007bff;\n}\n\n@media (max-width: 767px) {\n    .form-row .form-group {\n        margin-bottom: 10px;\n    }\n}\n\n/* Extra styles */\n\n.expandable-card .card-text-container {\n    max-height: 200px;\n    overflow-y: hidden;\n    position: relative;\n}\n\n.expandable-card.expanded .card-text-container {\n    max-height: none;\n}\n\n.expand-btn {\n    position: relative;\n    display: none;\n    background-color: rgba(255, 255, 255, 0.8);\n    color: #510c75;\n    border-color: transparent;\n}\n\n.expand-btn:hover {\n    background-color: rgba(200, 200, 200, 0.8);\n    text-decoration: none;\n    border-color: transparent;\n    color: #510c75;\n}\n\n.expand-btn:focus {\n    outline: none;\n    text-decoration: none;\n}\n\n.expandable-card:not(.expanded) .card-text-container:after {\n    content: \"\";\n    position: absolute;\n    bottom: 0;\n    left: 0;\n    width: 100%;\n    height: 90px;\n    background: linear-gradient(rgba(255, 255, 255, 0.2), rgba(255, 255, 255, 1));\n}\n\n.expandable-card:not(.expanded) .expand-btn {\n    margin-top: -40px;\n}\n\n.card-body {\n    padding-bottom: 5px;\n}\n\n.vertical-flex-layout {\n    justify-content: center;\n    align-items: center;\n    height: 100%;\n    display: flex;\n    flex-direction: column;\n    gap: 5px;\n}\n\n.figure-img {\n    max-width: 100%;\n    height: auto;\n}\n\n.adjustable-font-size {\n    font-size: calc(0.5rem + 2vw);\n}\n"
  },
  {
    "path": "llava/mm_utils.py",
    "content": "from PIL import Image\nfrom io import BytesIO\nimport base64\n\nimport torch\nfrom transformers import StoppingCriteria\nfrom llava.constants import IMAGE_TOKEN_INDEX\n\n\ndef load_image_from_base64(image):\n    return Image.open(BytesIO(base64.b64decode(image)))\n\n\ndef process_images(images, image_processor, model_cfg):\n    return image_processor(images, return_tensors='pt')['pixel_values']\n\n\ndef tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):\n    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]\n\n    def insert_separator(X, sep):\n        return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]\n\n    input_ids = []\n    offset = 0\n    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:\n        offset = 1\n        input_ids.append(prompt_chunks[0][0])\n\n    for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):\n        input_ids.extend(x[offset:])\n\n    if return_tensors is not None:\n        if return_tensors == 'pt':\n            return torch.tensor(input_ids, dtype=torch.long)\n        raise ValueError(f'Unsupported tensor type: {return_tensors}')\n    return input_ids\n\ndef tokenizer_image_token_inter(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):\n    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]\n\n    def insert_separator(X, sep):\n        return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]\n\n    input_ids = []\n    offset = 0\n    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:\n        offset = 1\n        input_ids.append(prompt_chunks[0][0])\n\n    for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):\n        input_ids.extend(x[offset:])\n\n    if return_tensors is not None:\n   
     if return_tensors == 'pt':\n            return torch.tensor(input_ids, dtype=torch.long)\n        raise ValueError(f'Unsupported tensor type: {return_tensors}')\n    return input_ids\n\ndef get_model_name_from_path(model_path):\n    model_path = model_path.strip(\"/\")\n    model_paths = model_path.split(\"/\")\n    if model_paths[-1].startswith('checkpoint-'):\n        return model_paths[-2] + \"_\" + model_paths[-1]\n    else:\n        return model_paths[-1]\n\n\n\n\nclass KeywordsStoppingCriteria(StoppingCriteria):\n    def __init__(self, keywords, tokenizer, input_ids):\n        self.keywords = keywords\n        self.keyword_ids = []\n        for keyword in keywords:\n            cur_keyword_ids = tokenizer(keyword).input_ids\n            if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:\n                cur_keyword_ids = cur_keyword_ids[1:]\n            self.keyword_ids.append(torch.tensor(cur_keyword_ids))\n        self.tokenizer = tokenizer\n        self.start_len = input_ids.shape[1]\n\n    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:\n        assert output_ids.shape[0] == 1, \"Only support batch size 1 (yet)\"  # TODO\n        offset = min(output_ids.shape[1] - self.start_len, 3)\n        self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]\n        for keyword_id in self.keyword_ids:\n            if output_ids[0, -keyword_id.shape[0]:] == keyword_id:\n                return True\n        outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]\n        for keyword in self.keywords:\n            if keyword in outputs:\n                return True\n        return False\n"
  },
  {
    "path": "llava/model/__init__.py",
    "content": "from .language_model.llava_llama_gd import LlavaLlamaForCausalLM,LlavaLlamaForCausalLM_gd,LlavaLlamaForCausalLM_joint,LlavaLlamaForCausalLM_joint_2st, LlavaConfig\\\n,LlavaLlamaForCausalLM_joint_2st_it_only_ref_instr\nfrom .language_model.llava_mpt import LlavaMPTForCausalLM, LlavaMPTConfig"
  },
  {
    "path": "llava/model/apply_delta.py",
    "content": "\"\"\"\nUsage:\npython3 -m llava.model.apply_delta --base-model-path ~/model_weights/llama-7b --target-model-path ~/model_weights/llava-7b --delta-path /path/to/llava-7b-delta\n\"\"\"\nimport argparse\n\nimport torch\nfrom tqdm import tqdm\nfrom transformers import AutoTokenizer, AutoModelForCausalLM\nfrom llava import LlavaLlamaForCausalLM\n\n\ndef apply_delta(base_model_path, target_model_path, delta_path):\n    print(\"Loading base model\")\n    base = AutoModelForCausalLM.from_pretrained(\n        base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)\n\n    print(\"Loading delta\")\n    delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)\n    delta_tokenizer = AutoTokenizer.from_pretrained(delta_path)\n\n    print(\"Applying delta\")\n    for name, param in tqdm(delta.state_dict().items(), desc=\"Applying delta\"):\n        if name not in base.state_dict():\n            assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'\n            continue\n        if param.data.shape == base.state_dict()[name].shape:\n            param.data += base.state_dict()[name]\n        else:\n            assert name in ['model.embed_tokens.weight', 'lm_head.weight'], \\\n                f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'\n            bparam = base.state_dict()[name]\n            param.data[:bparam.shape[0], :bparam.shape[1]] += bparam\n\n    print(\"Saving target model\")\n    delta.save_pretrained(target_model_path)\n    delta_tokenizer.save_pretrained(target_model_path)\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--base-model-path\", type=str, required=True)\n    parser.add_argument(\"--target-model-path\", type=str, required=True)\n    parser.add_argument(\"--delta-path\", type=str, required=True)\n\n    args = parser.parse_args()\n\n    
apply_delta(args.base_model_path, args.target_model_path, args.delta_path)\n"
  },
  {
    "path": "llava/model/builder.py",
    "content": "#    Copyright 2023 Haotian Liu\n#\n#    Licensed under the Apache License, Version 2.0 (the \"License\");\n#    you may not use this file except in compliance with the License.\n#    You may obtain a copy of the License at\n#\n#        http://www.apache.org/licenses/LICENSE-2.0\n#\n#    Unless required by applicable law or agreed to in writing, software\n#    distributed under the License is distributed on an \"AS IS\" BASIS,\n#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#    See the License for the specific language governing permissions and\n#    limitations under the License.\n\n\nimport os\nimport shutil\n\nfrom transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig\nimport torch\nfrom llava.model import *\nfrom llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN\n\n\ndef load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map=\"auto\"):\n    kwargs = {\"device_map\": device_map}\n\n    if load_8bit:\n        kwargs['load_in_8bit'] = True\n    elif load_4bit:\n        kwargs['load_in_4bit'] = True\n        kwargs['quantization_config'] = BitsAndBytesConfig(\n            load_in_4bit=True,\n            bnb_4bit_compute_dtype=torch.float16,\n            bnb_4bit_use_double_quant=True,\n            bnb_4bit_quant_type='nf4'\n        )\n    else:\n        kwargs['torch_dtype'] = torch.float16\n\n    if 'llava' in model_name.lower():\n        # Load LLaVA model\n        if 'lora' in model_name.lower() and model_base is not None:\n            lora_cfg_pretrained = AutoConfig.from_pretrained(model_path)\n            tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)\n            print('Loading LLaVA from base model...')\n            model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)\n            token_num, 
tokem_dim = model.lm_head.out_features, model.lm_head.in_features\n            if model.lm_head.weight.shape[0] != token_num:\n                model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))\n                model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))\n\n            print('Loading additional LLaVA weights...')\n            if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):\n                non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')\n            else:\n                # this is probably from HF Hub\n                from huggingface_hub import hf_hub_download\n                def load_from_hf(repo_id, filename, subfolder=None):\n                    cache_file = hf_hub_download(\n                        repo_id=repo_id,\n                        filename=filename,\n                        subfolder=subfolder)\n                    return torch.load(cache_file, map_location='cpu')\n                non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin')\n            non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()}\n            if any(k.startswith('model.model.') for k in non_lora_trainables):\n                non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}\n            model.load_state_dict(non_lora_trainables, strict=False)\n\n            from peft import PeftModel\n            print('Loading LoRA weights...')\n            model = PeftModel.from_pretrained(model, model_path)\n            print('Merging LoRA weights...')\n            model = model.merge_and_unload()\n            print('Model is loaded...')\n        elif model_base is not None:\n            # this may be mm projector only\n           
 print('Loading LLaVA from base model...')\n            if 'mpt' in model_name.lower():\n                if not os.path.isfile(os.path.join(model_path, 'configuration_mpt.py')):\n                    shutil.copyfile(os.path.join(model_base, 'configuration_mpt.py'), os.path.join(model_path, 'configuration_mpt.py'))\n                tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)\n                cfg_pretrained = AutoConfig.from_pretrained(model_path, trust_remote_code=True)\n                model = LlavaMPTForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)\n            else:\n                tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)\n                cfg_pretrained = AutoConfig.from_pretrained(model_path)\n                model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)\n\n            mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')\n            mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}\n            model.load_state_dict(mm_projector_weights, strict=False)\n        else:\n            if 'mpt' in model_name.lower():\n                tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)\n                model = LlavaMPTForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)\n            else:\n                tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)\n                model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)\n    else:\n        # Load language model\n        if model_base is not None:\n            # PEFT model\n            from peft import PeftModel\n            tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)\n            model = AutoModelForCausalLM.from_pretrained(model_base, 
torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map=\"auto\")\n            print(f\"Loading LoRA weights from {model_path}\")\n            model = PeftModel.from_pretrained(model, model_path)\n            print(f\"Merging weights\")\n            model = model.merge_and_unload()\n            print('Convert to FP16...')\n            model.to(torch.float16)\n        else:\n            use_fast = False\n            if 'mpt' in model_name.lower():\n                tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)\n                model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs)\n            else:\n                tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)\n                model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)\n\n    image_processor = None\n\n    if 'llava' in model_name.lower():\n        mm_use_im_start_end = getattr(model.config, \"mm_use_im_start_end\", False)\n        mm_use_im_patch_token = getattr(model.config, \"mm_use_im_patch_token\", True)\n        if mm_use_im_patch_token:\n            tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)\n        if mm_use_im_start_end:\n            tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)\n        model.resize_token_embeddings(len(tokenizer))\n\n        vision_tower = model.get_vision_tower()\n        if not vision_tower.is_loaded:\n            vision_tower.load_model()\n        vision_tower.to(device='cuda', dtype=torch.float16)\n        image_processor = vision_tower.image_processor\n\n    if hasattr(model.config, \"max_sequence_length\"):\n        context_len = model.config.max_sequence_length\n    else:\n        context_len = 2048\n\n    return tokenizer, model, image_processor, context_len\n"
  },
  {
    "path": "llava/model/consolidate.py",
    "content": "\"\"\"\nUsage:\npython3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate\n\"\"\"\nimport argparse\n\nimport torch\nfrom transformers import AutoTokenizer, AutoModelForCausalLM\nfrom llava.model import *\nfrom llava.model.utils import auto_upgrade\n\n\ndef consolidate_ckpt(src_path, dst_path):\n    print(\"Loading model\")\n    auto_upgrade(src_path)\n    src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)\n    src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False)\n    src_model.save_pretrained(dst_path)\n    src_tokenizer.save_pretrained(dst_path)\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--src\", type=str, required=True)\n    parser.add_argument(\"--dst\", type=str, required=True)\n\n    args = parser.parse_args()\n\n    consolidate_ckpt(args.src, args.dst)\n"
  },
  {
    "path": "llava/model/language_model/llava_llama.py",
    "content": "#    Copyright 2023 Haotian Liu\n#\n#    Licensed under the Apache License, Version 2.0 (the \"License\");\n#    you may not use this file except in compliance with the License.\n#    You may obtain a copy of the License at\n#\n#        http://www.apache.org/licenses/LICENSE-2.0\n#\n#    Unless required by applicable law or agreed to in writing, software\n#    distributed under the License is distributed on an \"AS IS\" BASIS,\n#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#    See the License for the specific language governing permissions and\n#    limitations under the License.\n\nimport detectron2.utils.comm as comm\nfrom typing import List, Optional, Tuple, Union\n\nimport torch\nimport torch.nn as nn\nfrom torch.nn import CrossEntropyLoss\n\nfrom transformers import AutoConfig, AutoModelForCausalLM, \\\n                         LlamaConfig, LlamaModel, LlamaForCausalLM\n\nfrom transformers.modeling_outputs import CausalLMOutputWithPast\n\nfrom ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM\n\n\nclass LlavaConfig(LlamaConfig):\n    model_type = \"llava\"\n\n\nclass LlavaLlamaModel(LlavaMetaModel, LlamaModel):\n    config_class = LlavaConfig\n\n    def __init__(self, config: LlamaConfig):\n        super(LlavaLlamaModel, self).__init__(config)\n\n\nclass LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):\n    config_class = LlavaConfig\n\n    def __init__(self, config):\n        super(LlamaForCausalLM, self).__init__(config)\n        self.model = LlavaLlamaModel(config)\n\n        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_model(self):\n        return self.model\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = 
None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        images: Optional[torch.FloatTensor] = None,\n        return_dict: Optional[bool] = None,\n    ) -> Union[Tuple, CausalLMOutputWithPast]:\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images)\n        print(f\"rank: {comm.get_rank()}\",1)\n        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\n        outputs = self.model(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict\n        )\n        print(f\"rank: {comm.get_rank()}\",2)\n\n        hidden_states = outputs[0]\n        logits = self.lm_head(hidden_states)\n\n        loss = None\n        if labels is not None:\n            # Shift so that tokens < n predict n\n            shift_logits = logits[..., :-1, :].contiguous()\n            shift_labels = labels[..., 1:].contiguous()\n            # Flatten the tokens\n            loss_fct = CrossEntropyLoss()\n            shift_logits = shift_logits.view(-1, self.config.vocab_size)\n            
shift_labels = shift_labels.view(-1)\n            # Enable model/pipeline parallelism\n            shift_labels = shift_labels.to(shift_logits.device)\n            loss = loss_fct(shift_logits, shift_labels)\n        print(f\"rank: {comm.get_rank()}\",2)\n\n        if not return_dict:\n            print(f\"rank: {comm.get_rank()}\", 3)\n\n            output = (logits,) + outputs[1:]\n            return (loss,) + output if loss is not None else output\n        print(f\"rank: {comm.get_rank()}\",4)\n\n        return CausalLMOutputWithPast(\n            loss=loss,\n            logits=logits,\n            past_key_values=outputs.past_key_values,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n\n    def prepare_inputs_for_generation(\n        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs\n    ):\n        if past_key_values:\n            input_ids = input_ids[:, -1:]\n\n        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step\n        if inputs_embeds is not None and past_key_values is None:\n            model_inputs = {\"inputs_embeds\": inputs_embeds}\n        else:\n            model_inputs = {\"input_ids\": input_ids}\n\n        model_inputs.update(\n            {\n                \"past_key_values\": past_key_values,\n                \"use_cache\": kwargs.get(\"use_cache\"),\n                \"attention_mask\": attention_mask,\n                \"images\": kwargs.get(\"images\", None),\n            }\n        )\n        return model_inputs\n\nAutoConfig.register(\"llava\", LlavaConfig)\nAutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)\n"
  },
  {
    "path": "llava/model/language_model/llava_llama_gd.py",
    "content": "#    Copyright 2023 Haotian Liu\n#\n#    Licensed under the Apache License, Version 2.0 (the \"License\");\n#    you may not use this file except in compliance with the License.\n#    You may obtain a copy of the License at\n#\n#        http://www.apache.org/licenses/LICENSE-2.0\n#\n#    Unless required by applicable law or agreed to in writing, software\n#    distributed under the License is distributed on an \"AS IS\" BASIS,\n#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#    See the License for the specific language governing permissions and\n#    limitations under the License.\n\n\nfrom typing import List, Optional, Tuple, Union\nIGNORE_INDEX=-100\nimport torch\nimport torch.nn as nn\nfrom torch.nn import CrossEntropyLoss\n\nfrom transformers import AutoConfig, AutoModelForCausalLM, \\\n                         LlamaConfig, LlamaModel, LlamaForCausalLM\n\nfrom transformers.modeling_outputs import CausalLMOutputWithPast\n\nfrom ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM, LlavaMetaForCausalLM_gd,LlavaMetaForCausalLM_gd_interactive\n\nimport transformers\n# @dataclass\nclass DataCollatorForSupervisedDataset(object):\n    \"\"\"Collate examples for supervised fine-tuning.\"\"\"\n\n    # tokenizer: transformers.PreTrainedTokenizer\n\n    def __call__(self, instances,tokenizer):\n        input_ids, labels = tuple([instance[key] for instance in instances]\n                                  for key in (\"input_ids\", \"labels\"))\n        input_ids = torch.nn.utils.rnn.pad_sequence(\n            input_ids,\n            batch_first=True,\n            padding_value=tokenizer.pad_token_id)\n        labels = torch.nn.utils.rnn.pad_sequence(labels,\n                                                 batch_first=True,\n                                                 padding_value=IGNORE_INDEX)\n        input_ids = input_ids[:, :tokenizer.model_max_length]\n        labels = labels[:, :tokenizer.model_max_length]\n  
      batch = dict(\n            input_ids=input_ids,\n            labels=labels,\n            attention_mask=input_ids.ne(tokenizer.pad_token_id),\n        )\n\n        if 'image_clip' in instances[0]:\n            images = [instance['image_clip'] for instance in instances]\n            if all(x is not None and x.shape == images[0].shape for x in images):\n                batch['images'] = torch.stack(images)\n            else:\n                batch['images'] = images\n\n        return batch\n\nclass LlavaConfig(LlamaConfig):\n    model_type = \"llava\"\n\n\nclass LlavaLlamaModel(LlavaMetaModel, LlamaModel):\n    config_class = LlavaConfig\n\n    def __init__(self, config: LlamaConfig):\n        super(LlavaLlamaModel, self).__init__(config)\n\n\nclass LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):\n    config_class = LlavaConfig\n\n    def __init__(self, config):\n        super(LlamaForCausalLM, self).__init__(config)\n        self.model = LlavaLlamaModel(config)\n\n        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_model(self):\n        return self.model\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        images: Optional[torch.FloatTensor] = None,\n        return_dict: Optional[bool] = None,\n    ) -> Union[Tuple, CausalLMOutputWithPast]:\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_hidden_states = (\n            
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images)\n\n        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\n        outputs = self.model(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict\n        )\n\n        hidden_states = outputs[0]\n        logits = self.lm_head(hidden_states)\n\n        loss = None\n        if labels is not None:\n            # Shift so that tokens < n predict n\n            shift_logits = logits[..., :-1, :].contiguous()\n            shift_labels = labels[..., 1:].contiguous()\n            # Flatten the tokens\n            loss_fct = CrossEntropyLoss()\n            shift_logits = shift_logits.view(-1, self.config.vocab_size)\n            shift_labels = shift_labels.view(-1)\n            # Enable model/pipeline parallelism\n            shift_labels = shift_labels.to(shift_logits.device)\n            loss = loss_fct(shift_logits, shift_labels)\n\n        if not return_dict:\n            output = (logits,) + outputs[1:]\n            return (loss,) + output if loss is not None else output\n\n        return CausalLMOutputWithPast(\n            loss=loss,\n            logits=logits,\n            past_key_values=outputs.past_key_values,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n\n    def prepare_inputs_for_generation(\n        self, 
input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs\n    ):\n        if past_key_values:\n            input_ids = input_ids[:, -1:]\n\n        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step\n        if inputs_embeds is not None and past_key_values is None:\n            model_inputs = {\"inputs_embeds\": inputs_embeds}\n        else:\n            model_inputs = {\"input_ids\": input_ids}\n\n        model_inputs.update(\n            {\n                \"past_key_values\": past_key_values,\n                \"use_cache\": kwargs.get(\"use_cache\"),\n                \"attention_mask\": attention_mask,\n                \"images\": kwargs.get(\"images\", None),\n            }\n        )\n        return model_inputs\n\nclass LlavaLlamaForCausalLM_gd(LlamaForCausalLM, LlavaMetaForCausalLM_gd):\n    config_class = LlavaConfig\n\n    def __init__(self, config):\n        super(LlamaForCausalLM, self).__init__(config)\n        self.model = LlavaLlamaModel(config)\n\n        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_model(self):\n        return self.model\n\n    def forward(self,**batched_inputs):\n        # print(kwargs.keys())\n        # images_for_llava=torch.stack([inp['image_clip'] for inp in batched_inputs['flickr']])\n        collator=DataCollatorForSupervisedDataset()\n\n        if 'refcoco' in batched_inputs:\n            if 'vg' in batched_inputs:\n                llava_inputs = collator(batched_inputs['vg']+batched_inputs['refcoco'],\n                                        tokenizer=batched_inputs['refcoco'][0]['tokenizer'])\n            else:\n                llava_inputs = collator( batched_inputs['refcoco'],\n                                    tokenizer=batched_inputs['refcoco'][0]['tokenizer'])\n        elif 'coco' in batched_inputs:\n            
llava_inputs=collator(batched_inputs['flickr']+batched_inputs['coco'],tokenizer=batched_inputs['flickr'][0]['tokenizer'])\n        else:\n            llava_inputs=collator(batched_inputs['flickr'],tokenizer=batched_inputs['flickr'][0]['tokenizer'])\n        llava_inputs['seg_inputs']=batched_inputs\n        return self.forward_inner(**llava_inputs)\n\n    def forward_inner(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        images: Optional[torch.FloatTensor] = None,\n        seg_inputs: Optional[torch.FloatTensor] = None,\n        return_dict: Optional[bool] = None,\n    ) -> Union[Tuple, CausalLMOutputWithPast]:\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        _, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images)\n\n        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\n        outputs = self.model(\n            input_ids=None,\n            attention_mask=attention_mask,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            
return_dict=return_dict\n        )\n        ground_idx_coco=[]\n        ground_idx = [torch.argwhere(lb == 32002)[:, 0] for lb in labels]\n        if 'refcoco' in seg_inputs:\n            if 'vg' in seg_inputs:\n                vg_len=len(seg_inputs['vg'])\n                ground_idx_flickr = ground_idx[:vg_len]\n                padded_ground_idx_flickr = torch.nn.utils.rnn.pad_sequence(ground_idx_flickr,\n                                                                           batch_first=True,\n                                                                           padding_value=-1)\n                padded_mask_flickr = padded_ground_idx_flickr != -1\n                padded_ground_idx_flickr[padded_ground_idx_flickr == -1] = 0\n                # ground_idx=[[-1] if len(idx)==0 else idx for idx in ground_idx]\n                hidden_states = outputs[0]\n                hidden_states_flickr = hidden_states[:vg_len]\n                ground_hs_flickr = torch.gather(hidden_states_flickr, 1,\n                                                padded_ground_idx_flickr[..., None].repeat(1, 1,\n                                                                                           hidden_states_flickr.shape[\n                                                                                               -1]))\n                seg_inputs['vg_text_embeddings'] = (ground_hs_flickr, padded_mask_flickr)\n            flickr_len = len(seg_inputs['refcoco'])\n            ##########flickr\n            # if self.seg_model.model.coco_only:\n            ground_idx_flickr = ground_idx[vg_len:vg_len+flickr_len] if 'vg' in seg_inputs else ground_idx[:flickr_len]\n            padded_ground_idx_flickr = torch.nn.utils.rnn.pad_sequence(ground_idx_flickr,\n                                                                       batch_first=True,\n                                                                       padding_value=-1)\n            padded_mask_flickr = 
padded_ground_idx_flickr != -1\n            padded_ground_idx_flickr[padded_ground_idx_flickr == -1] = 0\n            # ground_idx=[[-1] if len(idx)==0 else idx for idx in ground_idx]\n            hidden_states = outputs[0]\n            hidden_states_flickr = hidden_states[vg_len:vg_len+flickr_len] if 'vg' in seg_inputs else hidden_states[:flickr_len]\n            ground_hs_flickr = torch.gather(hidden_states_flickr, 1, padded_ground_idx_flickr[..., None].repeat(1, 1,\n                                                                                                                hidden_states_flickr.shape[\n                                                                                                                    -1]))\n            seg_inputs['refcoco_text_embeddings'] = (ground_hs_flickr, padded_mask_flickr)\n            # seg_inputs['flickr']=seg_inputs['refcoco']\n        else:\n            flickr_len=len(seg_inputs['flickr'])\n            ground_idx = [torch.argwhere(lb == 32002)[:, 0] for lb in labels]\n            zero_mask = [0 if len(idx) == 0 else 1 for idx in ground_idx]\n            ##########flickr\n            # if self.seg_model.model.coco_only:\n            ground_idx_flickr=ground_idx[:flickr_len]\n            padded_ground_idx_flickr = torch.nn.utils.rnn.pad_sequence(ground_idx_flickr,\n                                                     batch_first=True,\n                                                     padding_value=-1)\n            padded_mask_flickr=padded_ground_idx_flickr!=-1\n            padded_ground_idx_flickr[padded_ground_idx_flickr==-1]=0\n            # ground_idx=[[-1] if len(idx)==0 else idx for idx in ground_idx]\n            hidden_states = outputs[0]\n            hidden_states_flickr=hidden_states[:flickr_len]\n            ground_hs_flickr=torch.gather(hidden_states_flickr,1,padded_ground_idx_flickr[...,None].repeat(1,1,hidden_states_flickr.shape[-1]))\n            
seg_inputs['flickr_text_embeddings']=(ground_hs_flickr,padded_mask_flickr)\n\n            ##########coco\n            ground_idx_coco = ground_idx[flickr_len:]\n            if len(ground_idx_coco)>0:\n                for i,(idx,data) in enumerate(zip(ground_idx_coco,seg_inputs['coco'])):\n                    mask=data['grounding_mask']\n                    ground_idx_coco[i]=idx[mask[:len(idx)]]\n                padded_ground_idx_coco = torch.nn.utils.rnn.pad_sequence(ground_idx_coco,\n                                                                           batch_first=True,\n                                                                           padding_value=-1)\n                padded_mask_coco = padded_ground_idx_coco != -1\n\n                padded_ground_idx_coco[padded_ground_idx_coco == -1] = 0\n\n                hidden_states = outputs[0]\n                hidden_states_coco = hidden_states[flickr_len:]\n                ground_hs_coco = torch.gather(hidden_states_coco, 1, padded_ground_idx_coco[..., None].repeat(1, 1,\n                                                                                                                    hidden_states_coco.shape[\n                                                                                                                        -1]))\n                seg_inputs['coco_text_embeddings'] = (ground_hs_coco, padded_mask_coco)\n\n        ground_loss=self.seg_model(seg_inputs)\n        if self.seg_model.model.coco_only and len(ground_idx_coco)>0:\n            logits = self.lm_head(hidden_states_coco)\n        else:\n            logits = self.lm_head(hidden_states)\n\n        loss = None\n        if labels is not None:\n            # Shift so that tokens < n predict n\n            shift_logits = logits[..., :-1, :].contiguous()\n            if self.seg_model.model.coco_only and len(ground_idx_coco) > 0:\n                shift_labels = labels[..., 1:][flickr_len:].contiguous()\n            else:\n                
shift_labels = labels[..., 1:].contiguous()\n            # Flatten the tokens\n            loss_fct = CrossEntropyLoss()\n            shift_logits = shift_logits.view(-1, self.config.vocab_size)\n            shift_labels = shift_labels.view(-1)\n            # Enable model/pipeline parallelism\n            shift_labels = shift_labels.to(shift_logits.device)\n            loss = loss_fct(shift_logits, shift_labels)\n\n        if not return_dict:\n            output = (logits,) + outputs[1:]\n            return (loss,) + output if loss is not None else output\n        ground_loss['llava']=loss\n        ground_loss['loss_total']=sum(ground_loss.values())\n        return CausalLMOutputWithPast(\n            loss=ground_loss,\n            logits=logits,\n            past_key_values=outputs.past_key_values,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n\n    def prepare_inputs_for_generation(\n        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs\n    ):\n        if past_key_values:\n            input_ids = input_ids[:, -1:]\n\n        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step\n        if inputs_embeds is not None and past_key_values is None:\n            model_inputs = {\"inputs_embeds\": inputs_embeds}\n        else:\n            model_inputs = {\"input_ids\": input_ids}\n\n        model_inputs.update(\n            {\n                \"past_key_values\": past_key_values,\n                \"use_cache\": kwargs.get(\"use_cache\"),\n                \"attention_mask\": attention_mask,\n                \"images\": kwargs.get(\"images\", None),\n            }\n        )\n        return model_inputs\n\n    def forward_eval(self, inputs):\n        collator=DataCollatorForSupervisedDataset()\n        llava_inputs=collator(inputs,tokenizer=inputs[0]['tokenizer'])\n        llava_inputs['seg_inputs']=inputs\n        return 
self.forward_inner_eval(**llava_inputs)\n\n    def forward_inner_eval(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        images: Optional[torch.FloatTensor] = None,\n        seg_inputs: Optional[torch.FloatTensor] = None,\n        return_dict: Optional[bool] = None,\n    ) -> Union[Tuple, CausalLMOutputWithPast]:\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n        _, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images)\n\n        output_ids, seg_hidden_states = self.auto_regressive_generate(attention_mask, past_key_values, inputs_embeds, output_attentions, seg_inputs[0][\"tokenizer\"], return_dict)\n        output_text = seg_inputs[0][\"tokenizer\"].batch_decode([output_ids], skip_special_tokens=True)[0]\n        if len(seg_hidden_states)==0:\n            return output_text, [], []\n        seg_tokens = torch.cat(seg_hidden_states, dim=1)\n        padded_mask = seg_tokens.new_ones(seg_tokens.shape[:2]) > 0\n        predicted_boxes, predicted_masks=self.seg_model.model.forward_eval(seg_inputs, (seg_tokens,padded_mask))\n\n        return output_text, predicted_boxes, predicted_masks\n    \n    def auto_regressive_generate(self, \n                        
attention_mask,\n                        past_key_values,\n                        inputs_embeds,\n                        output_attentions,\n                        tokenizer,\n                        return_dict,\n                        temporature=0.0\n        ):\n        ########\n        # llm_inputs['obj_num'] = False\n        seg_token = tokenizer.encode(\"<seg>\")[1]\n        seg_token_list = []\n        output_ids = []\n        output_logits = []\n        length = inputs_embeds.shape[1]\n        for i in range(1000):\n            # import pdb;pdb.set_trace()\n            if i == 0:\n                results = self.model(\n                    input_ids=None,\n                    past_key_values=past_key_values,\n                    inputs_embeds=inputs_embeds,\n                    use_cache=True,\n                    output_attentions=output_attentions,\n                    output_hidden_states=True,\n                    return_dict=return_dict\n                )\n            else:\n                attention_mask = cur_hidden.new_ones(\n                    1, past_key_values[0][0].shape[-2] + 1, device=\"cuda\")\n                # print(\"Attention mask shape: \", attention_mask.shape)\n                results = self.model(\n                    input_ids=torch.as_tensor([[cur_id]], device=inputs_embeds.device),\n                    attention_mask=attention_mask,\n                    past_key_values=past_key_values,\n                    # inputs_embeds=cur_hidden,\n                    use_cache=True,\n                    output_attentions=output_attentions,\n                    output_hidden_states=True,\n                    return_dict=return_dict\n                )\n            cur_hidden = results.hidden_states[-1][:, -1:]  # last layer last token\n            logits = self.lm_head(results[0])\n            cur_logits = logits[0][-1]\n            cur_id = int(torch.argmax(cur_logits))\n            if temporature < 1e-4:\n                cur_id = 
int(torch.argmax(cur_logits))\n            else:\n                probs = torch.softmax(cur_logits / temporature, dim=-1)\n                cur_id = int(torch.multinomial(probs, num_samples=1))\n                            \n            past_key_values = results.past_key_values\n            length += 1\n\n            if cur_id==seg_token:\n                seg_token_list.append(cur_hidden)\n            output_ids.append(cur_id)\n            output_logits.append(cur_logits)\n            if tokenizer.decode(output_ids).find(\"</s>\")!=-1:\n                break\n        return output_ids,seg_token_list\n    \nclass LlavaLlamaForCausalLM_joint(LlavaLlamaForCausalLM_gd):\n    def forward(self,**batched_inputs):\n        # print(kwargs.keys())\n        # images_for_llava=torch.stack([inp['image_clip'] for inp in batched_inputs['flickr']])\n        collator=DataCollatorForSupervisedDataset()\n        assert 'refcoco' in batched_inputs and 'flickr' in batched_inputs and 'llava' in batched_inputs\n        for data in batched_inputs['llava']:\n            data['image_clip']=data['image']\n        llava_inputs = collator( batched_inputs['flickr']+batched_inputs['refcoco']+batched_inputs['llava'],\n                                tokenizer=batched_inputs['refcoco'][0]['tokenizer'])\n\n        # if 'refcoco' in batched_inputs:\n        #     llava_inputs = collator( batched_inputs['refcoco'],\n        #                             tokenizer=batched_inputs['refcoco'][0]['tokenizer'])\n        # elif 'coco' in batched_inputs:\n        #     llava_inputs=collator(batched_inputs['flickr']+batched_inputs['coco'],tokenizer=batched_inputs['flickr'][0]['tokenizer'])\n        # else:\n        #     llava_inputs=collator(batched_inputs['flickr'],tokenizer=batched_inputs['flickr'][0]['tokenizer'])\n        llava_inputs['seg_inputs']=batched_inputs\n        return self.forward_inner(**llava_inputs)\n\n    def forward_inner(\n        self,\n        input_ids: torch.LongTensor = None,\n       
 attention_mask: Optional[torch.Tensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        images: Optional[torch.FloatTensor] = None,\n        seg_inputs: Optional[torch.FloatTensor] = None,\n        return_dict: Optional[bool] = None,\n    ) -> Union[Tuple, CausalLMOutputWithPast]:\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        _, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images)\n\n        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\n        outputs = self.model(\n            input_ids=None,\n            attention_mask=attention_mask,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict\n        )\n        ground_idx_coco=[]\n        # if 'refcoco' in seg_inputs:\n        flickr_len = len(seg_inputs['flickr'])\n        ground_idx = [torch.argwhere(lb == 32002)[:, 0] for lb in labels]\n        ##########flickr\n        # if self.seg_model.model.coco_only:\n        ground_idx_flickr = ground_idx[:flickr_len]\n        padded_ground_idx_flickr = torch.nn.utils.rnn.pad_sequence(ground_idx_flickr,\n           
                                                        batch_first=True,\n                                                                   padding_value=-1)\n        padded_mask_flickr = padded_ground_idx_flickr != -1\n        padded_ground_idx_flickr[padded_ground_idx_flickr == -1] = 0\n        # ground_idx=[[-1] if len(idx)==0 else idx for idx in ground_idx]\n        hidden_states = outputs[0]\n        hidden_states_flickr = hidden_states[:flickr_len]\n        ground_hs_flickr = torch.gather(hidden_states_flickr, 1, padded_ground_idx_flickr[..., None].repeat(1, 1,\n                                                                                                            hidden_states_flickr.shape[\n                                                                                                                -1]))\n        seg_inputs['flickr_text_embeddings'] = (ground_hs_flickr, padded_mask_flickr)\n        # seg_inputs['flickr']=seg_inputs['refcoco']\n        # else:\n        #################################################\n        #################################################\n        refcoco_len=len(seg_inputs['refcoco'])\n        ground_idx = [torch.argwhere(lb == 32002)[:, 0] for lb in labels]\n        ##########flickr\n        ground_idx_refcoco=ground_idx[flickr_len:flickr_len+refcoco_len]\n        padded_ground_idx_refcoco = torch.nn.utils.rnn.pad_sequence(ground_idx_refcoco,\n                                                 batch_first=True,\n                                                 padding_value=-1)\n        padded_mask_refcoco=padded_ground_idx_refcoco!=-1\n        padded_ground_idx_refcoco[padded_ground_idx_refcoco==-1]=0\n        # ground_idx=[[-1] if len(idx)==0 else idx for idx in ground_idx]\n        # hidden_states = outputs[0]\n        hidden_states_refcoco=hidden_states[flickr_len:flickr_len+refcoco_len]\n        
ground_hs_refcoco=torch.gather(hidden_states_refcoco,1,padded_ground_idx_refcoco[...,None].repeat(1,1,hidden_states_refcoco.shape[-1]))\n        seg_inputs['refcoco_text_embeddings']=(ground_hs_refcoco,padded_mask_refcoco)\n\n\n\n        ground_loss=self.seg_model(seg_inputs)\n        # if self.seg_model.model.coco_only and len(ground_idx_coco)>0:\n        #     logits = self.lm_head(hidden_states_coco)\n        # else:\n        logits = self.lm_head(hidden_states)\n\n        loss = None\n        if labels is not None:\n            # Shift so that tokens < n predict n\n            shift_logits = logits[..., :-1, :].contiguous()\n            if self.seg_model.model.coco_only and len(ground_idx_coco) > 0:\n                shift_labels = labels[..., 1:][flickr_len:].contiguous()\n            else:\n                shift_labels = labels[..., 1:].contiguous()\n            # Flatten the tokens\n            loss_fct = CrossEntropyLoss()\n            shift_logits = shift_logits.view(-1, self.config.vocab_size)\n            shift_labels = shift_labels.view(-1)\n            # Enable model/pipeline parallelism\n            shift_labels = shift_labels.to(shift_logits.device)\n            loss = loss_fct(shift_logits, shift_labels)\n\n        if not return_dict:\n            output = (logits,) + outputs[1:]\n            return (loss,) + output if loss is not None else output\n        ground_loss['llava']=loss\n        ground_loss['loss_total']=sum(ground_loss.values())\n        return CausalLMOutputWithPast(\n            loss=ground_loss,\n            logits=logits,\n            past_key_values=outputs.past_key_values,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n\nclass LlavaLlamaForCausalLM_joint_2st(LlavaLlamaForCausalLM_gd):\n    def forward(self,**batched_inputs):\n        # print(kwargs.keys())\n        # images_for_llava=torch.stack([inp['image_clip'] for inp in batched_inputs['flickr']])\n        
collator=DataCollatorForSupervisedDataset()\n        assert 'coco' in batched_inputs and 'flickr' in batched_inputs and 'llava' in batched_inputs\n        for data in batched_inputs['llava']:\n            data['image_clip']=data['image']\n        llava_inputs = collator( batched_inputs['flickr']+batched_inputs['coco']+batched_inputs['llava'],\n                                tokenizer=batched_inputs['coco'][0]['tokenizer'])\n\n        # if 'refcoco' in batched_inputs:\n        #     llava_inputs = collator( batched_inputs['refcoco'],\n        #                             tokenizer=batched_inputs['refcoco'][0]['tokenizer'])\n        # elif 'coco' in batched_inputs:\n        #     llava_inputs=collator(batched_inputs['flickr']+batched_inputs['coco'],tokenizer=batched_inputs['flickr'][0]['tokenizer'])\n        # else:\n        #     llava_inputs=collator(batched_inputs['flickr'],tokenizer=batched_inputs['flickr'][0]['tokenizer'])\n        llava_inputs['seg_inputs']=batched_inputs\n        return self.forward_inner(**llava_inputs)\n\n    def forward_inner(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        images: Optional[torch.FloatTensor] = None,\n        seg_inputs: Optional[torch.FloatTensor] = None,\n        return_dict: Optional[bool] = None,\n    ) -> Union[Tuple, CausalLMOutputWithPast]:\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n        
return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        _, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images)\n\n        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\n        outputs = self.model(\n            input_ids=None,\n            attention_mask=attention_mask,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict\n        )\n\n        flickr_len = len(seg_inputs['flickr'])\n        ground_idx = [torch.argwhere(lb == 32002)[:, 0] for lb in labels]\n        ##########flickr\n        # if self.seg_model.model.coco_only:\n        ground_idx_flickr = ground_idx[:flickr_len]\n        padded_ground_idx_flickr = torch.nn.utils.rnn.pad_sequence(ground_idx_flickr,\n                                                                   batch_first=True,\n                                                                   padding_value=-1)\n        padded_mask_flickr = padded_ground_idx_flickr != -1\n        padded_ground_idx_flickr[padded_ground_idx_flickr == -1] = 0\n        # ground_idx=[[-1] if len(idx)==0 else idx for idx in ground_idx]\n        if self.seg_model.model.detach_seg:\n            hidden_states = outputs[0].detach()\n        else:\n            hidden_states = outputs[0]\n        hidden_states_flickr = hidden_states[:flickr_len]\n        ground_hs_flickr = torch.gather(hidden_states_flickr, 1, padded_ground_idx_flickr[..., None].repeat(1, 1,\n                                                                                                            hidden_states_flickr.shape[\n                                                                                              
                  -1]))\n        seg_inputs['flickr_text_embeddings'] = (ground_hs_flickr, padded_mask_flickr)\n\n        ##########coco\n        coco_len = len(seg_inputs['coco'])\n        ground_idx_coco = ground_idx[flickr_len:flickr_len+coco_len]\n        if len(ground_idx_coco) > 0:\n            for i, (idx, data) in enumerate(zip(ground_idx_coco, seg_inputs['coco'])):\n                mask = data['grounding_mask']\n                ground_idx_coco[i] = idx[mask[:len(idx)]]\n            padded_ground_idx_coco = torch.nn.utils.rnn.pad_sequence(ground_idx_coco,\n                                                                     batch_first=True,\n                                                                     padding_value=-1)\n            padded_mask_coco = padded_ground_idx_coco != -1\n\n            padded_ground_idx_coco[padded_ground_idx_coco == -1] = 0\n\n            # hidden_states = outputs[0]\n            hidden_states_coco = hidden_states[flickr_len:flickr_len+coco_len]\n            ground_hs_coco = torch.gather(hidden_states_coco, 1, padded_ground_idx_coco[..., None].repeat(1, 1,\n                                                                                                          hidden_states_coco.shape[\n                                                                                                              -1]))\n            seg_inputs['coco_text_embeddings'] = (ground_hs_coco, padded_mask_coco)\n        ground_loss = self.seg_model(seg_inputs)\n        hidden_states_ = outputs[0]\n        if self.seg_model.model.coco_only and len(ground_idx_coco) > 0:\n            logits = self.lm_head(hidden_states_[flickr_len:])\n        else:\n            logits = self.lm_head(hidden_states_)\n        ############################################################\n\n        loss = None\n        if labels is not None:\n            # Shift so that tokens < n predict n\n            shift_logits = logits[..., :-1, :].contiguous()\n            if 
self.seg_model.model.coco_only and len(ground_idx_coco) > 0:\n                shift_labels = labels[..., 1:][flickr_len:].contiguous()\n            else:\n                shift_labels = labels[..., 1:].contiguous()\n            # Flatten the tokens\n            loss_fct = CrossEntropyLoss()\n            shift_logits = shift_logits.view(-1, self.config.vocab_size)\n            shift_labels = shift_labels.view(-1)\n            # Enable model/pipeline parallelism\n            shift_labels = shift_labels.to(shift_logits.device)\n            loss = loss_fct(shift_logits, shift_labels)\n\n        if not return_dict:\n            output = (logits,) + outputs[1:]\n            return (loss,) + output if loss is not None else output\n        ground_loss['llava']=loss\n        ground_loss['loss_total']=sum(ground_loss.values())\n        ignore_list=[f'_{i}' for i in range(1,10)]\n        ignore_list.append('interm')\n        for key in list(ground_loss.keys()):\n            if not key.endswith('_0') and key!='llava' and key !='loss_total':\n                ground_loss.pop(key)\n        return CausalLMOutputWithPast(\n            loss=ground_loss,\n            logits=logits,\n            past_key_values=outputs.past_key_values,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n\nclass LlavaLlamaForCausalLM_joint_2st_it_only_ref_instr(LlamaForCausalLM, LlavaMetaForCausalLM_gd_interactive):\n    config_class = LlavaConfig\n\n    def __init__(self, config):\n        super(LlamaForCausalLM, self).__init__(config)\n        self.model = LlavaLlamaModel(config)\n\n        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_model(self):\n        return self.model\n\n    def forward(self,**batched_inputs):\n        # print(kwargs.keys())\n        # images_for_llava=torch.stack([inp['image_clip'] for inp in 
batched_inputs['flickr']])\n        collator=DataCollatorForSupervisedDataset()\n        # assert 'coco' in batched_inputs and 'flickr' in batched_inputs and 'llava' in batched_inputs and 'interactive' in batched_inputs\n        # for data in batched_inputs['llava']:\n        #     data['image_clip']=data['image']\n        llava_inputs = collator( batched_inputs['interactive'],\n                                tokenizer=batched_inputs['interactive'][0]['tokenizer'])\n\n        # if 'refcoco' in batched_inputs:\n        #     llava_inputs = collator( batched_inputs['refcoco'],\n        #                             tokenizer=batched_inputs['refcoco'][0]['tokenizer'])\n        # elif 'coco' in batched_inputs:\n        #     llava_inputs=collator(batched_inputs['flickr']+batched_inputs['coco'],tokenizer=batched_inputs['flickr'][0]['tokenizer'])\n        # else:\n        #     llava_inputs=collator(batched_inputs['flickr'],tokenizer=batched_inputs['flickr'][0]['tokenizer'])\n        llava_inputs['seg_inputs']=batched_inputs\n        res1= self.forward_inner(**llava_inputs)\n        loss_dict1=res1.loss\n        prefix1='coco.'\n        res1.loss=res1['loss']={prefix1+k:v for k,v in loss_dict1.items()}\n        if 'interactiveref' in batched_inputs:\n            llava_inputs = collator( batched_inputs['interactiveref'],\n                                    tokenizer=batched_inputs['interactive'][0]['tokenizer'])\n            batched_inputs['interactive']=batched_inputs['interactiveref']\n            llava_inputs['seg_inputs']=batched_inputs\n            res2= self.forward_inner(**llava_inputs)\n            loss_dict2=res2.loss\n            prefix2='refcoco.'\n            res2.loss=res2['loss']={prefix2+k:v for k,v in loss_dict2.items()}\n            res1.loss.update(res2.loss)\n            res1.loss['loss_total']=res1.loss['coco.loss_total']+res1.loss['refcoco.loss_total']\n        else:\n            res1.loss['loss_total'] = res1.loss['coco.loss_total']\n        return 
res1\n\n\n    def forward_inner(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        images: Optional[torch.FloatTensor] = None,\n        seg_inputs: Optional[torch.FloatTensor] = None,\n        return_dict: Optional[bool] = None,\n    ) -> Union[Tuple, CausalLMOutputWithPast]:\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n        obj_feats,inter_losses=self.interactive_model.float().forward(seg_inputs['interactive'],detach=False)\n        obj_feats=[obj_feats[i][seg_inputs['interactive'][i]['grounding_index']][None] for i in range(len(obj_feats))]\n        num_it=len(seg_inputs['interactive'])\n        _, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images,obj_feats=obj_feats,num_it=num_it)\n\n        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\n        outputs = self.model(\n            input_ids=None,\n            attention_mask=attention_mask,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict\n        
)\n\n  \n        hidden_states_ = outputs[0]\n\n        logits = self.lm_head(hidden_states_)\n\n        loss = None\n        if labels is not None:\n            # Shift so that tokens < n predict n\n            shift_logits = logits[..., :-1, :].contiguous()\n            # if self.seg_model.model.coco_only and len(ground_idx_coco) > 0:\n            #     shift_labels = labels[..., 1:][flickr_len:].contiguous()\n            # else:\n            shift_labels = labels[..., 1:].contiguous()\n            # Flatten the tokens\n            loss_fct = CrossEntropyLoss()\n            shift_logits = shift_logits.view(-1, self.config.vocab_size)\n            shift_labels = shift_labels.view(-1)\n            # Enable model/pipeline parallelism\n            shift_labels = shift_labels.to(shift_logits.device)\n            loss = loss_fct(shift_logits, shift_labels)\n\n        if not return_dict:\n            output = (logits,) + outputs[1:]\n            return (loss,) + output if loss is not None else output\n        ground_loss=dict()\n        ground_loss['llava']=loss\n        # for k,v in inter_losses.items():\n        #     print(v.dtype)\n        inter_losses={k:inter_losses[k].to(float) for k in inter_losses.keys()}\n        ground_loss.update(inter_losses)\n        # import pdb;pdb.set_trace()\n        ground_loss['loss_total']=sum(ground_loss.values())\n        ignore_list=[f'_{i}' for i in range(1,10)]\n        ignore_list.append('interm')\n        for key in list(ground_loss.keys()):\n            if not key.endswith('_0') and key!='llava' and key !='loss_total':\n                ground_loss.pop(key)\n        return CausalLMOutputWithPast(\n            loss=ground_loss,\n            logits=logits,\n            past_key_values=outputs.past_key_values,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n    def forward_eval(self, batched_inputs):\n        if not (batched_inputs[0][\"points\"] is None):\n            
print(\"Get Interactive Data.\")\n            collator=DataCollatorForSupervisedDataset()\n            llava_inputs=collator(batched_inputs,tokenizer=batched_inputs[0]['tokenizer'])\n            llava_inputs['seg_inputs']=batched_inputs\n            if \"temporature\" in batched_inputs[0].keys():\n                llava_inputs[\"temporature\"] = batched_inputs[0][\"temporature\"]\n            else:\n                llava_inputs[\"temporature\"] = 0\n            return self.forward_inner_eval_interactive(**llava_inputs)\n        else:\n            print(\"Do not Get Interactive Data.\")\n            collator=DataCollatorForSupervisedDataset()\n            llava_inputs=collator(batched_inputs,tokenizer=batched_inputs[0]['tokenizer'])\n            llava_inputs['seg_inputs']=batched_inputs\n            if \"temporature\" in batched_inputs[0].keys():\n                llava_inputs[\"temporature\"] = batched_inputs[0][\"temporature\"]\n            else:\n                llava_inputs[\"temporature\"] = 0\n            return self.forward_inner_eval(**llava_inputs)\n    def forward_inner_eval(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        images: Optional[torch.FloatTensor] = None,\n        seg_inputs: Optional[torch.FloatTensor] = None,\n        return_dict: Optional[bool] = None,\n        temporature=0\n    ) -> Union[Tuple, CausalLMOutputWithPast]:\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else 
self.config.output_hidden_states\n        )\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n        _, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal_NoInter(input_ids, attention_mask, past_key_values, labels, images)\n\n        output_ids, seg_hidden_states = self.auto_regressive_generate(attention_mask, past_key_values, inputs_embeds, output_attentions, seg_inputs[0][\"tokenizer\"], return_dict, temporature)\n        output_text = seg_inputs[0][\"tokenizer\"].batch_decode([output_ids], skip_special_tokens=True)[0]\n        if len(seg_hidden_states)==0:\n            return output_text, [], [], None\n        seg_tokens = torch.cat(seg_hidden_states, dim=1)\n        padded_mask = seg_tokens.new_ones(seg_tokens.shape[:2]) > 0\n        predicted_boxes, predicted_masks=self.seg_model.model.forward_eval(seg_inputs, (seg_tokens,padded_mask))\n\n        return output_text, predicted_boxes, predicted_masks, None\n    def forward_inner_eval_interactive(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        images: Optional[torch.FloatTensor] = None,\n        seg_inputs: Optional[torch.FloatTensor] = None,\n        return_dict: Optional[bool] = None,\n        temporature=0\n    ) -> Union[Tuple, CausalLMOutputWithPast]:\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n  
      return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n        #! extra interaction part\n        boxes = seg_inputs[0]['points']\n        seg_inputs[0]['targets'] = [dict()]\n        seg_inputs[0]['targets'][0]['points'] = boxes\n        if seg_inputs[0]['mode_inter'].lower() == \"click\":\n            seg_inputs[0]['targets'][0]['pb'] = boxes.new_tensor([0.0])\n        elif seg_inputs[0]['mode_inter'].lower() == \"box\":\n            seg_inputs[0]['targets'][0]['pb'] = boxes.new_tensor([1.0])\n\n        seg_inputs[0]['targets'][0]['is_part'] = [0]\n        inter_masks, _, obj_feats =self.interactive_model.forward(seg_inputs)\n        num_it=len(seg_inputs)\n        #\n        _, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images, obj_feats=obj_feats,num_it=num_it)\n        \n        output_ids, seg_hidden_states = self.auto_regressive_generate(attention_mask, past_key_values, inputs_embeds, output_attentions, seg_inputs[0][\"tokenizer\"], return_dict, temporature)\n        \n        output_text = seg_inputs[0][\"tokenizer\"].batch_decode([output_ids], skip_special_tokens=True)\n        if len(seg_hidden_states)==0:\n            return output_text[0], [], None, inter_masks\n        seg_tokens = torch.cat(seg_hidden_states, dim=1)\n        padded_mask = seg_tokens.new_ones(seg_tokens.shape[:2]) > 0\n        predicted_boxes, predicted_masks=self.seg_model.model.forward_eval(seg_inputs, (seg_tokens,padded_mask))\n\n        return output_text[0], predicted_boxes, predicted_masks, inter_masks\n    def auto_regressive_generate(self, \n                        attention_mask,\n                        past_key_values,\n                        inputs_embeds,\n                        output_attentions,\n                        tokenizer,\n                        return_dict,\n                        temporature=0.0\n        ):\n     
   ########\n        # llm_inputs['obj_num'] = False\n        seg_token = tokenizer.encode(\"<seg>\")[1]\n        seg_token_list = []\n        output_ids = []\n        output_logits = []\n        length = inputs_embeds.shape[1]\n        for i in range(1000):\n            # import pdb;pdb.set_trace()\n            if i == 0:\n                results = self.model(\n                    input_ids=None,\n                    past_key_values=past_key_values,\n                    inputs_embeds=inputs_embeds,\n                    use_cache=True,\n                    output_attentions=output_attentions,\n                    output_hidden_states=True,\n                    return_dict=return_dict\n                )\n            else:\n                attention_mask = cur_hidden.new_ones(\n                    1, past_key_values[0][0].shape[-2] + 1, device=\"cuda\")\n                # print(\"Attention mask shape: \", attention_mask.shape)\n                results = self.model(\n                    input_ids=torch.as_tensor([[cur_id]], device=inputs_embeds.device),\n                    attention_mask=attention_mask,\n                    past_key_values=past_key_values,\n                    # inputs_embeds=cur_hidden,\n                    use_cache=True,\n                    output_attentions=output_attentions,\n                    output_hidden_states=True,\n                    return_dict=return_dict\n                )\n            cur_hidden = results.hidden_states[-1][:, -1:]  # last layer last token\n            logits = self.lm_head(results[0])\n            cur_logits = logits[0][-1]\n            cur_id = int(torch.argmax(cur_logits))\n            if temporature < 1e-4:\n                cur_id = int(torch.argmax(cur_logits))\n            else:\n                probs = torch.softmax(cur_logits / temporature, dim=-1)\n                cur_id = int(torch.multinomial(probs, num_samples=1))\n                            \n            past_key_values = results.past_key_values\n      
      length += 1\n\n            if cur_id==seg_token:\n                seg_token_list.append(cur_hidden)\n            output_ids.append(cur_id)\n            output_logits.append(cur_logits)\n            if tokenizer.decode(output_ids).find(\"</s>\")!=-1:\n                break\n        return output_ids,seg_token_list\n   \nAutoConfig.register(\"llava\", LlavaConfig)\nAutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)\nAutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM_gd)\nAutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM_joint)\nAutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM_joint_2st)\nAutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM_joint_2st_it_only_ref_instr)"
  },
  {
    "path": "llava/model/language_model/llava_mpt.py",
    "content": "#    Copyright 2023 Haotian Liu\n#\n#    Licensed under the Apache License, Version 2.0 (the \"License\");\n#    you may not use this file except in compliance with the License.\n#    You may obtain a copy of the License at\n#\n#        http://www.apache.org/licenses/LICENSE-2.0\n#\n#    Unless required by applicable law or agreed to in writing, software\n#    distributed under the License is distributed on an \"AS IS\" BASIS,\n#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#    See the License for the specific language governing permissions and\n#    limitations under the License.\n\n\nfrom typing import List, Optional, Tuple\nimport warnings\n\nimport torch\nimport torch.nn.functional as F\nimport math\n\nfrom transformers import AutoConfig, AutoModelForCausalLM\nfrom transformers.modeling_outputs import CausalLMOutputWithPast\n\nfrom .mpt.modeling_mpt import MPTConfig, MPTForCausalLM, MPTModel\nfrom llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM\n\n\nclass LlavaMPTConfig(MPTConfig):\n    model_type = \"llava_mpt\"\n\n\nclass LlavaMPTModel(LlavaMetaModel, MPTModel):\n    config_class = LlavaMPTConfig\n\n    def __init__(self, config: MPTConfig):\n        config.hidden_size = config.d_model\n        super(LlavaMPTModel, self).__init__(config)\n    \n    def embed_tokens(self, x):\n        return self.wte(x)\n\n\nclass LlavaMPTForCausalLM(MPTForCausalLM, LlavaMetaForCausalLM):\n    config_class = LlavaMPTConfig\n    supports_gradient_checkpointing = True\n\n    def __init__(self, config):\n        super(MPTForCausalLM, self).__init__(config)\n\n        if not config.tie_word_embeddings:\n            raise ValueError('MPTForCausalLM only supports tied word embeddings')\n        self.transformer = LlavaMPTModel(config)\n        self.logit_scale = None\n        if config.logit_scale is not None:\n            logit_scale = config.logit_scale\n            if isinstance(logit_scale, str):\n               
 if logit_scale == 'inv_sqrt_d_model':\n                    logit_scale = 1 / math.sqrt(config.d_model)\n                else:\n                    raise ValueError(f\"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.\")\n            self.logit_scale = logit_scale\n\n    def get_model(self):\n        return self.transformer\n\n    def _set_gradient_checkpointing(self, module, value=False):\n        if isinstance(module, LlavaMPTModel):\n            module.gradient_checkpointing = value\n\n    def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, images=None):\n        return_dict = return_dict if return_dict is not None else self.config.return_dict\n        use_cache = use_cache if use_cache is not None else self.config.use_cache\n\n        input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images)\n        outputs = self.transformer(input_ids=input_ids, inputs_embeds=inputs_embeds, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)\n        # FIXME: this is a hack to fix the multiple gpu inference issue in https://github.com/haotian-liu/LLaVA/issues/338\n        logits = F.linear(outputs.last_hidden_state.to(self.transformer.wte.weight.device), self.transformer.wte.weight)\n        if self.logit_scale is not None:\n            if 
self.logit_scale == 0:\n                warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs.')\n            logits *= self.logit_scale\n        loss = None\n        if labels is not None:\n            labels = torch.roll(labels, shifts=-1)\n            labels[:, -1] = -100\n            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1))\n        return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states)\n\n    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):\n        if inputs_embeds is not None:\n            raise NotImplementedError('inputs_embeds is not implemented for MPT yet')\n        attention_mask = kwargs['attention_mask'].bool()\n        if attention_mask[:, -1].sum() != attention_mask.shape[0]:\n            raise NotImplementedError('MPT does not support generation with right padding.')\n        if self.transformer.attn_uses_sequence_id and self.training:\n            sequence_id = torch.zeros_like(input_ids[:1])\n        else:\n            sequence_id = None\n        if past_key_values is not None:\n            input_ids = input_ids[:, -1].unsqueeze(-1)\n        if self.transformer.prefix_lm:\n            prefix_mask = torch.ones_like(attention_mask)\n            if kwargs.get('use_cache') == False:\n                raise NotImplementedError('MPT with prefix_lm=True does not support use_cache=False.')\n        else:\n            prefix_mask = None\n        return {'input_ids': input_ids, 'attention_mask': attention_mask, 'prefix_mask': prefix_mask, 'sequence_id': sequence_id, 'past_key_values': past_key_values, 'use_cache': kwargs.get('use_cache', True), \"images\": kwargs.get(\"images\", None)}\n\n\nAutoConfig.register(\"llava_mpt\", LlavaMPTConfig)\nAutoModelForCausalLM.register(LlavaMPTConfig, 
LlavaMPTForCausalLM)\n"
  },
  {
    "path": "llava/model/language_model/mpt/adapt_tokenizer.py",
    "content": "from typing import Union\nfrom transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast\nTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]\nNUM_SENTINEL_TOKENS: int = 100\n\ndef adapt_tokenizer_for_denoising(tokenizer: Tokenizer):\n    \"\"\"Adds sentinel tokens and padding token (if missing).\n\n    Expands the tokenizer vocabulary to include sentinel tokens\n    used in mixture-of-denoiser tasks as well as a padding token.\n\n    All added tokens are added as special tokens. No tokens are\n    added if sentinel tokens and padding token already exist.\n    \"\"\"\n    sentinels_to_add = [f'<extra_id_{i}>' for i in range(NUM_SENTINEL_TOKENS)]\n    tokenizer.add_tokens(sentinels_to_add, special_tokens=True)\n    if tokenizer.pad_token is None:\n        tokenizer.add_tokens('<pad>', special_tokens=True)\n        tokenizer.pad_token = '<pad>'\n        assert tokenizer.pad_token_id is not None\n    sentinels = ''.join([f'<extra_id_{i}>' for i in range(NUM_SENTINEL_TOKENS)])\n    _sentinel_token_ids = tokenizer(sentinels, add_special_tokens=False).input_ids\n    tokenizer.sentinel_token_ids = _sentinel_token_ids\n\nclass AutoTokenizerForMOD(AutoTokenizer):\n    \"\"\"AutoTokenizer + Adaptation for MOD.\n\n    A simple wrapper around AutoTokenizer to make instantiating\n    an MOD-adapted tokenizer a bit easier.\n\n    MOD-adapted tokenizers have sentinel tokens (e.g., <extra_id_0>),\n    a padding token, and a property to get the token ids of the\n    sentinel tokens.\n    \"\"\"\n\n    @classmethod\n    def from_pretrained(cls, *args, **kwargs):\n        \"\"\"See `AutoTokenizer.from_pretrained` docstring.\"\"\"\n        tokenizer = super().from_pretrained(*args, **kwargs)\n        adapt_tokenizer_for_denoising(tokenizer)\n        return tokenizer"
  },
  {
    "path": "llava/model/language_model/mpt/attention.py",
    "content": "\"\"\"Attention layers.\"\"\"\nimport math\nimport warnings\nfrom typing import Optional\nimport torch\nimport torch.nn as nn\nfrom einops import rearrange\nfrom packaging import version\nfrom torch import nn\nfrom .norm import LPLayerNorm\n\ndef _reset_is_causal(num_query_tokens: int, num_key_tokens: int, original_is_causal: bool):\n    if original_is_causal and num_query_tokens != num_key_tokens:\n        if num_query_tokens != 1:\n            raise NotImplementedError('MPT does not support query and key with different number of tokens, unless number of query tokens is 1.')\n        else:\n            return False\n    return original_is_causal\n\ndef scaled_multihead_dot_product_attention(query, key, value, n_heads, past_key_value=None, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):\n    q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads)\n    kv_n_heads = 1 if multiquery else n_heads\n    k = rearrange(key, 'b s (h d) -> b h d s', h=kv_n_heads)\n    v = rearrange(value, 'b s (h d) -> b h s d', h=kv_n_heads)\n    if past_key_value is not None:\n        if len(past_key_value) != 0:\n            k = torch.cat([past_key_value[0], k], dim=3)\n            v = torch.cat([past_key_value[1], v], dim=2)\n        past_key_value = (k, v)\n    (b, _, s_q, d) = q.shape\n    s_k = k.size(-1)\n    if softmax_scale is None:\n        softmax_scale = 1 / math.sqrt(d)\n    attn_weight = q.matmul(k) * softmax_scale\n    if attn_bias is not None:\n        _s_q = max(0, attn_bias.size(2) - s_q)\n        _s_k = max(0, attn_bias.size(3) - s_k)\n        attn_bias = attn_bias[:, :, _s_q:, _s_k:]\n        if attn_bias.size(-1) != 1 and attn_bias.size(-1) != s_k or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q):\n            raise RuntimeError(f'attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}.')\n        attn_weight = 
attn_weight + attn_bias\n    min_val = torch.finfo(q.dtype).min\n    if key_padding_mask is not None:\n        if attn_bias is not None:\n            warnings.warn('Propogating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unneccessary computation/memory usage. Consider integrating ' + 'into attn_bias once and passing that to each attention ' + 'module instead.')\n        attn_weight = attn_weight.masked_fill(~key_padding_mask.view((b, 1, 1, s_k)), min_val)\n    if is_causal and (not q.size(2) == 1):\n        s = max(s_q, s_k)\n        causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16)\n        causal_mask = causal_mask.tril()\n        causal_mask = causal_mask.to(torch.bool)\n        causal_mask = ~causal_mask\n        causal_mask = causal_mask[-s_q:, -s_k:]\n        attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k), min_val)\n    attn_weight = torch.softmax(attn_weight, dim=-1)\n    if dropout_p:\n        attn_weight = torch.nn.functional.dropout(attn_weight, p=dropout_p, training=training, inplace=True)\n    out = attn_weight.to(v.dtype).matmul(v)\n    out = rearrange(out, 'b h s d -> b s (h d)')\n    if needs_weights:\n        return (out, attn_weight, past_key_value)\n    return (out, None, past_key_value)\n\ndef check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]):\n    for tensor in tensors:\n        if tensor.dtype not in valid_dtypes:\n            raise TypeError(f'tensor.dtype={tensor.dtype!r} must be in valid_dtypes={valid_dtypes!r}.')\n        if not tensor.is_cuda:\n            raise TypeError(f'Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r}).')\n\ndef flash_attn_fn(query, key, value, n_heads, past_key_value=None, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):\n    try:\n        from flash_attn import bert_padding, 
flash_attn_interface\n    except:\n        raise RuntimeError('Please install flash-attn==1.0.3.post0')\n    check_valid_inputs(query, key, value)\n    if past_key_value is not None:\n        if len(past_key_value) != 0:\n            key = torch.cat([past_key_value[0], key], dim=1)\n            value = torch.cat([past_key_value[1], value], dim=1)\n        past_key_value = (key, value)\n    if attn_bias is not None:\n        _s_q = max(0, attn_bias.size(2) - query.size(1))\n        _s_k = max(0, attn_bias.size(3) - key.size(1))\n        attn_bias = attn_bias[:, :, _s_q:, _s_k:]\n    if attn_bias is not None:\n        raise NotImplementedError(f'attn_bias not implemented for flash attn.')\n    (batch_size, seqlen) = query.shape[:2]\n    if key_padding_mask is None:\n        key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool)\n    query_padding_mask = key_padding_mask[:, -query.size(1):]\n    (query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input(query, query_padding_mask)\n    query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads)\n    (key_unpad, _, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input(key, key_padding_mask)\n    key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=1 if multiquery else n_heads)\n    (value_unpad, _, _, _) = bert_padding.unpad_input(value, key_padding_mask)\n    value_unpad = rearrange(value_unpad, 'nnz (h d) -> nnz h d', h=1 if multiquery else n_heads)\n    if multiquery:\n        key_unpad = key_unpad.expand(key_unpad.size(0), n_heads, key_unpad.size(-1))\n        value_unpad = value_unpad.expand(value_unpad.size(0), n_heads, value_unpad.size(-1))\n    dropout_p = dropout_p if training else 0.0\n    reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)\n    output_unpad = flash_attn_interface.flash_attn_unpadded_func(query_unpad, key_unpad, value_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale=softmax_scale, 
causal=reset_is_causal, return_attn_probs=needs_weights)\n    output = bert_padding.pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size, seqlen)\n    return (output, None, past_key_value)\n\ndef triton_flash_attn_fn(query, key, value, n_heads, past_key_value=None, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):\n    try:\n        from .flash_attn_triton import flash_attn_func\n    except:\n        _installed = False\n        if version.parse(torch.__version__) < version.parse('2.0.0'):\n            _installed = True\n            try:\n                from flash_attn.flash_attn_triton import flash_attn_func\n            except:\n                _installed = False\n        if not _installed:\n            raise RuntimeError('Requirements for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU and `pip install .[gpu]` if installing from llm-foundry source or `pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). 
Note: (1) requires you have CMake and PyTorch already installed.')\n    check_valid_inputs(query, key, value)\n    if past_key_value is not None:\n        if len(past_key_value) != 0:\n            key = torch.cat([past_key_value[0], key], dim=1)\n            value = torch.cat([past_key_value[1], value], dim=1)\n        past_key_value = (key, value)\n    if attn_bias is not None:\n        _s_q = max(0, attn_bias.size(2) - query.size(1))\n        _s_k = max(0, attn_bias.size(3) - key.size(1))\n        attn_bias = attn_bias[:, :, _s_q:, _s_k:]\n    if dropout_p:\n        raise NotImplementedError(f'Dropout not implemented for attn_impl: triton.')\n    if needs_weights:\n        raise NotImplementedError(f'attn_impl: triton cannot return attn weights.')\n    if key_padding_mask is not None:\n        warnings.warn('Propagating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unnecessary computation/memory usage. Consider integrating ' + 'into attn_bias once and passing that to each attention ' + 'module instead.')\n        (b_size, s_k) = key_padding_mask.shape[:2]\n        if attn_bias is None:\n            attn_bias = query.new_zeros(b_size, 1, 1, s_k)\n        attn_bias = attn_bias.masked_fill(~key_padding_mask.view((b_size, 1, 1, s_k)), torch.finfo(query.dtype).min)\n    query = rearrange(query, 'b s (h d) -> b s h d', h=n_heads)\n    key = rearrange(key, 'b s (h d) -> b s h d', h=1 if multiquery else n_heads)\n    value = rearrange(value, 'b s (h d) -> b s h d', h=1 if multiquery else n_heads)\n    if multiquery:\n        key = key.expand(*key.shape[:2], n_heads, key.size(-1))\n        value = value.expand(*value.shape[:2], n_heads, value.size(-1))\n    reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)\n    attn_output = flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale)\n    output = attn_output.view(*attn_output.shape[:2], -1)\n    return (output, None, 
past_key_value)\n\nclass MultiheadAttention(nn.Module):\n    \"\"\"Multi-head self attention.\n\n    Using torch or triton attention implemetation enables user to also use\n    additive bias.\n    \"\"\"\n\n    def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, low_precision_layernorm: bool=False, verbose: int=0, device: Optional[str]=None):\n        super().__init__()\n        self.attn_impl = attn_impl\n        self.clip_qkv = clip_qkv\n        self.qk_ln = qk_ln\n        self.d_model = d_model\n        self.n_heads = n_heads\n        self.softmax_scale = softmax_scale\n        if self.softmax_scale is None:\n            self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)\n        self.attn_dropout_p = attn_pdrop\n        self.Wqkv = nn.Linear(self.d_model, 3 * self.d_model, device=device)\n        fuse_splits = (d_model, 2 * d_model)\n        self.Wqkv._fused = (0, fuse_splits)\n        if self.qk_ln:\n            layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm\n            self.q_ln = layernorm_class(self.d_model, device=device)\n            self.k_ln = layernorm_class(self.d_model, device=device)\n        if self.attn_impl == 'flash':\n            self.attn_fn = flash_attn_fn\n        elif self.attn_impl == 'triton':\n            self.attn_fn = triton_flash_attn_fn\n            if verbose:\n                warnings.warn('While `attn_impl: triton` can be faster than `attn_impl: flash` ' + 'it uses more memory. When training larger models this can trigger ' + 'alloc retries which hurts performance. 
If encountered, we recommend ' + 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.')\n        elif self.attn_impl == 'torch':\n            self.attn_fn = scaled_multihead_dot_product_attention\n            if torch.cuda.is_available() and verbose:\n                warnings.warn('Using `attn_impl: torch`. If your model does not use `alibi` or ' + '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' + 'we recommend using `attn_impl: triton`.')\n        else:\n            raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')\n        self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)\n        self.out_proj._is_residual = True\n\n    def forward(self, x, past_key_value=None, attn_bias=None, attention_mask=None, is_causal=True, needs_weights=False):\n        qkv = self.Wqkv(x)\n        if self.clip_qkv:\n            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)\n        (query, key, value) = qkv.chunk(3, dim=2)\n        key_padding_mask = attention_mask\n        if self.qk_ln:\n            dtype = query.dtype\n            query = self.q_ln(query).to(dtype)\n            key = self.k_ln(key).to(dtype)\n        (context, attn_weights, past_key_value) = self.attn_fn(query, key, value, self.n_heads, past_key_value=past_key_value, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights)\n        return (self.out_proj(context), attn_weights, past_key_value)\n\nclass MultiQueryAttention(nn.Module):\n    \"\"\"Multi-Query self attention.\n\n    Using torch or triton attention implemetation enables user to also use\n    additive bias.\n    \"\"\"\n\n    def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, low_precision_layernorm: bool=False, 
verbose: int=0, device: Optional[str]=None):\n        super().__init__()\n        self.attn_impl = attn_impl\n        self.clip_qkv = clip_qkv\n        self.qk_ln = qk_ln\n        self.d_model = d_model\n        self.n_heads = n_heads\n        self.head_dim = d_model // n_heads\n        self.softmax_scale = softmax_scale\n        if self.softmax_scale is None:\n            self.softmax_scale = 1 / math.sqrt(self.head_dim)\n        self.attn_dropout_p = attn_pdrop\n        self.Wqkv = nn.Linear(d_model, d_model + 2 * self.head_dim, device=device)\n        fuse_splits = (d_model, d_model + self.head_dim)\n        self.Wqkv._fused = (0, fuse_splits)\n        if self.qk_ln:\n            layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm\n            self.q_ln = layernorm_class(d_model, device=device)\n            self.k_ln = layernorm_class(self.head_dim, device=device)\n        if self.attn_impl == 'flash':\n            self.attn_fn = flash_attn_fn\n        elif self.attn_impl == 'triton':\n            self.attn_fn = triton_flash_attn_fn\n            if verbose:\n                warnings.warn('While `attn_impl: triton` can be faster than `attn_impl: flash` ' + 'it uses more memory. When training larger models this can trigger ' + 'alloc retries which hurts performance. If encountered, we recommend ' + 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.')\n        elif self.attn_impl == 'torch':\n            self.attn_fn = scaled_multihead_dot_product_attention\n            if torch.cuda.is_available() and verbose:\n                warnings.warn('Using `attn_impl: torch`. 
If your model does not use `alibi` or ' + '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' + 'we recommend using `attn_impl: triton`.')\n        else:\n            raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')\n        self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)\n        self.out_proj._is_residual = True\n\n    def forward(self, x, past_key_value=None, attn_bias=None, attention_mask=None, is_causal=True, needs_weights=False):\n        qkv = self.Wqkv(x)\n        if self.clip_qkv:\n            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)\n        (query, key, value) = qkv.split([self.d_model, self.head_dim, self.head_dim], dim=2)\n        key_padding_mask = attention_mask\n        if self.qk_ln:\n            dtype = query.dtype\n            query = self.q_ln(query).to(dtype)\n            key = self.k_ln(key).to(dtype)\n        (context, attn_weights, past_key_value) = self.attn_fn(query, key, value, self.n_heads, past_key_value=past_key_value, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights, multiquery=True)\n        return (self.out_proj(context), attn_weights, past_key_value)\n\ndef attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id):\n    if attn_impl == 'flash':\n        return None\n    elif attn_impl in ['torch', 'triton']:\n        if alibi:\n            if (prefix_lm or not causal) or use_sequence_id:\n                return (1, n_heads, seq_len, seq_len)\n            return (1, n_heads, 1, seq_len)\n        elif prefix_lm or use_sequence_id:\n            return (1, 1, seq_len, seq_len)\n        return None\n    else:\n        raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')\n\ndef build_attn_bias(attn_impl, attn_bias, n_heads, seq_len, causal=False, alibi=False, alibi_bias_max=8):\n    if 
attn_impl == 'flash':\n        return None\n    elif attn_impl in ['torch', 'triton']:\n        if alibi:\n            (device, dtype) = (attn_bias.device, attn_bias.dtype)\n            attn_bias = attn_bias.add(build_alibi_bias(n_heads, seq_len, full=not causal, alibi_bias_max=alibi_bias_max, device=device, dtype=dtype))\n        return attn_bias\n    else:\n        raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')\n\ndef gen_slopes(n_heads, alibi_bias_max=8, device=None):\n    _n_heads = 2 ** math.ceil(math.log2(n_heads))\n    m = torch.arange(1, _n_heads + 1, dtype=torch.float32, device=device)\n    m = m.mul(alibi_bias_max / _n_heads)\n    slopes = 1.0 / torch.pow(2, m)\n    if _n_heads != n_heads:\n        slopes = torch.concat([slopes[1::2], slopes[::2]])[:n_heads]\n    return slopes.view(1, n_heads, 1, 1)\n\ndef build_alibi_bias(n_heads, seq_len, full=False, alibi_bias_max=8, device=None, dtype=None):\n    alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device).view(1, 1, 1, seq_len)\n    if full:\n        alibi_bias = alibi_bias - torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device).view(1, 1, seq_len, 1)\n        alibi_bias = alibi_bias.abs().mul(-1)\n    slopes = gen_slopes(n_heads, alibi_bias_max, device=device)\n    alibi_bias = alibi_bias * slopes\n    return alibi_bias.to(dtype=dtype)\nATTN_CLASS_REGISTRY = {'multihead_attention': MultiheadAttention, 'multiquery_attention': MultiQueryAttention}"
  },
  {
    "path": "llava/model/language_model/mpt/blocks.py",
    "content": "\"\"\"GPT Blocks used for the GPT Model.\"\"\"\nfrom typing import Dict, Optional, Tuple\nimport torch\nimport torch.nn as nn\nfrom .attention import ATTN_CLASS_REGISTRY\nfrom .norm import NORM_CLASS_REGISTRY\n\nclass MPTMLP(nn.Module):\n\n    def __init__(self, d_model: int, expansion_ratio: int, device: Optional[str]=None):\n        super().__init__()\n        self.up_proj = nn.Linear(d_model, expansion_ratio * d_model, device=device)\n        self.act = nn.GELU(approximate='none')\n        self.down_proj = nn.Linear(expansion_ratio * d_model, d_model, device=device)\n        self.down_proj._is_residual = True\n\n    def forward(self, x):\n        return self.down_proj(self.act(self.up_proj(x)))\n\nclass MPTBlock(nn.Module):\n\n    def __init__(self, d_model: int, n_heads: int, expansion_ratio: int, attn_config: Dict={'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}, resid_pdrop: float=0.0, norm_type: str='low_precision_layernorm', verbose: int=0, device: Optional[str]=None, **kwargs):\n        del kwargs\n        super().__init__()\n        norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]\n        attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']]\n        self.norm_1 = norm_class(d_model, device=device)\n        self.attn = attn_class(attn_impl=attn_config['attn_impl'], clip_qkv=attn_config['clip_qkv'], qk_ln=attn_config['qk_ln'], softmax_scale=attn_config['softmax_scale'], attn_pdrop=attn_config['attn_pdrop'], d_model=d_model, n_heads=n_heads, verbose=verbose, device=device)\n        self.norm_2 = norm_class(d_model, device=device)\n        self.ffn = MPTMLP(d_model=d_model, expansion_ratio=expansion_ratio, device=device)\n        self.resid_attn_dropout = nn.Dropout(resid_pdrop)\n        self.resid_ffn_dropout = nn.Dropout(resid_pdrop)\n\n    def forward(self, 
x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:\n        a = self.norm_1(x)\n        (b, attn_weights, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal)\n        x = x + self.resid_attn_dropout(b)\n        m = self.norm_2(x)\n        n = self.ffn(m)\n        x = x + self.resid_ffn_dropout(n)\n        return (x, attn_weights, past_key_value)"
  },
  {
    "path": "llava/model/language_model/mpt/configuration_mpt.py",
    "content": "\"\"\"A HuggingFace-style model configuration.\"\"\"\nfrom typing import Dict, Optional, Union\nfrom transformers import PretrainedConfig\nattn_config_defaults: Dict = {'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}\ninit_config_defaults: Dict = {'name': 'kaiming_normal_', 'fan_mode': 'fan_in', 'init_nonlinearity': 'relu', 'init_div_is_residual': True, 'emb_init_std': None, 'emb_init_uniform_lim': None, 'init_std': None, 'init_gain': 0.0}\n\nclass MPTConfig(PretrainedConfig):\n    model_type = 'mpt'\n\n    def __init__(self, d_model: int=2048, n_heads: int=16, n_layers: int=24, expansion_ratio: int=4, max_seq_len: int=2048, vocab_size: int=50368, resid_pdrop: float=0.0, emb_pdrop: float=0.0, learned_pos_emb: bool=True, attn_config: Dict=attn_config_defaults, init_device: str='cpu', logit_scale: Optional[Union[float, str]]=None, no_bias: bool=False, verbose: int=0, embedding_fraction: float=1.0, norm_type: str='low_precision_layernorm', use_cache: bool=False, init_config: Dict=init_config_defaults, **kwargs):\n        \"\"\"The MPT configuration class.\n\n        Args:\n            d_model (int): The size of the embedding dimension of the model.\n            n_heads (int): The number of attention heads.\n            n_layers (int): The number of layers in the model.\n            expansion_ratio (int): The ratio of the up/down scale in the MLP.\n            max_seq_len (int): The maximum sequence length of the model.\n            vocab_size (int): The size of the vocabulary.\n            resid_pdrop (float): The dropout probability applied to the attention output before combining with residual.\n            emb_pdrop (float): The dropout probability for the embedding layer.\n            learned_pos_emb (bool): Whether to use learned positional embeddings\n            
attn_config (Dict):  A dictionary used to configure the model's attention module:\n                attn_type (str): type of attention to use. Options: multihead_attention, multiquery_attention\n                attn_pdrop (float): The dropout probability for the attention layers.\n                attn_impl (str): The attention implementation to use. One of 'torch', 'flash', or 'triton'.\n                qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer.\n                clip_qkv (Optional[float]): If not None, clip the queries, keys, and values in the attention layer to\n                    this value.\n                softmax_scale (Optional[float]): If not None, scale the softmax in the attention layer by this value. If None,\n                    use the default scale of ``1/sqrt(d_keys)``.\n                prefix_lm (Optional[bool]): Whether the model should operate as a Prefix LM. This requires passing an\n                    extra `prefix_mask` argument which indicates which tokens belong to the prefix. Tokens in the prefix\n                    can attend to one another bi-directionally. 
Tokens outside the prefix use causal attention.\n                attn_uses_sequence_id (Optional[bool]): Whether to restrict attention to tokens that have the same sequence_id.\n                    When the model is in `train` mode, this requires passing an extra `sequence_id` argument which indicates\n                    which sub-sequence each token belongs to.\n                    Defaults to ``False`` meaning any provided `sequence_id` will be ignored.\n                alibi (bool): Whether to use the alibi bias instead of position embeddings.\n                alibi_bias_max (int): The maximum value of the alibi bias.\n            init_device (str): The device to use for parameter initialization.\n            logit_scale (Optional[Union[float, str]]): If not None, scale the logits by this value.\n            no_bias (bool): Whether to use bias in all layers.\n            verbose (int): The verbosity level. 0 is silent.\n            embedding_fraction (float): The fraction to scale the gradients of the embedding layer by.\n            norm_type (str): choose type of norm to use\n            multiquery_attention (bool): Whether to use multiquery attention implementation.\n            use_cache (bool): Whether or not the model should return the last key/values attentions\n            init_config (Dict): A dictionary used to configure the model initialization:\n                init_config.name: The parameter initialization scheme to use. Options: 'default_', 'baseline_',\n                    'kaiming_uniform_', 'kaiming_normal_', 'neox_init_', 'small_init_', 'xavier_uniform_', or\n                    'xavier_normal_'. 
These mimic the parameter initialization methods in PyTorch.\n                init_div_is_residual (Union[int, float, str, bool]): Value to divide initial weights by if ``module._is_residual`` is True.\n                emb_init_std (Optional[float]): The standard deviation of the normal distribution used to initialize the embedding layer.\n                emb_init_uniform_lim (Optional[Union[Tuple[float, float], float]]): The lower and upper limits of the uniform distribution\n                    used to initialize the embedding layer. Mutually exclusive with ``emb_init_std``.\n                init_std (float): The standard deviation of the normal distribution used to initialize the model,\n                    if using the baseline_ parameter initialization scheme.\n                init_gain (float): The gain to use for parameter initialization with kaiming or xavier initialization schemes.\n                fan_mode (str): The fan mode to use for parameter initialization with kaiming initialization schemes.\n                init_nonlinearity (str): The nonlinearity to use for parameter initialization with kaiming initialization schemes.\n                ---\n                See llmfoundry.models.utils.param_init_fns.py for info on other param init config options\n        \"\"\"\n        self.d_model = d_model\n        self.n_heads = n_heads\n        self.n_layers = n_layers\n        self.expansion_ratio = expansion_ratio\n        self.max_seq_len = max_seq_len\n        self.vocab_size = vocab_size\n        self.resid_pdrop = resid_pdrop\n        self.emb_pdrop = emb_pdrop\n        self.learned_pos_emb = learned_pos_emb\n        self.attn_config = attn_config\n        self.init_device = init_device\n        self.logit_scale = logit_scale\n        self.no_bias = no_bias\n        self.verbose = verbose\n        self.embedding_fraction = embedding_fraction\n        self.norm_type = norm_type\n        self.use_cache = use_cache\n        self.init_config = init_config\n  
      if 'name' in kwargs:\n            del kwargs['name']\n        if 'loss_fn' in kwargs:\n            del kwargs['loss_fn']\n        super().__init__(**kwargs)\n        self._validate_config()\n\n    def _set_config_defaults(self, config, config_defaults):\n        for (k, v) in config_defaults.items():\n            if k not in config:\n                config[k] = v\n        return config\n\n    def _validate_config(self):\n        self.attn_config = self._set_config_defaults(self.attn_config, attn_config_defaults)\n        self.init_config = self._set_config_defaults(self.init_config, init_config_defaults)\n        if self.d_model % self.n_heads != 0:\n            raise ValueError('d_model must be divisible by n_heads')\n        if any((prob < 0 or prob > 1 for prob in [self.attn_config['attn_pdrop'], self.resid_pdrop, self.emb_pdrop])):\n            raise ValueError(\"self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1\")\n        if self.attn_config['attn_impl'] not in ['torch', 'flash', 'triton']:\n            raise ValueError(f\"Unknown attn_impl={self.attn_config['attn_impl']}\")\n        if self.attn_config['prefix_lm'] and self.attn_config['attn_impl'] not in ['torch', 'triton']:\n            raise NotImplementedError('prefix_lm only implemented with torch and triton attention.')\n        if self.attn_config['alibi'] and self.attn_config['attn_impl'] not in ['torch', 'triton']:\n            raise NotImplementedError('alibi only implemented with torch and triton attention.')\n        if self.attn_config['attn_uses_sequence_id'] and self.attn_config['attn_impl'] not in ['torch', 'triton']:\n            raise NotImplementedError('attn_uses_sequence_id only implemented with torch and triton attention.')\n        if self.embedding_fraction > 1 or self.embedding_fraction <= 0:\n            raise ValueError('model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!')\n        if 
isinstance(self.logit_scale, str) and self.logit_scale != 'inv_sqrt_d_model':\n            raise ValueError(f\"self.logit_scale={self.logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.\")\n        if self.init_config.get('name', None) is None:\n            raise ValueError(f\"self.init_config={self.init_config!r} 'name' needs to be set.\")\n        if not self.learned_pos_emb and (not self.attn_config['alibi']):\n            raise ValueError(f'Positional information must be provided to the model using either learned_pos_emb or alibi.')"
  },
  {
    "path": "llava/model/language_model/mpt/custom_embedding.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch import Tensor\n\nclass SharedEmbedding(nn.Embedding):\n\n    def forward(self, input: Tensor, unembed: bool=False) -> Tensor:\n        if unembed:\n            return F.linear(input, self.weight)\n        return super().forward(input)"
  },
  {
    "path": "llava/model/language_model/mpt/flash_attn_triton.py",
    "content": "\"\"\"\nCopied from https://github.com/HazyResearch/flash-attention/blob/eff9fe6b8076df59d64d7a3f464696738a3c7c24/flash_attn/flash_attn_triton.py\nupdate imports to use 'triton_pre_mlir'\n\n*Experimental* implementation of FlashAttention in Triton.\nTested with triton==2.0.0.dev20221202.\nTriton 2.0 has a new backend (MLIR) but seems like it doesn't yet work for head dimensions\nother than 64:\nhttps://github.com/openai/triton/blob/d376020f90002757eea3ea9475d4f7cfc2ec5ead/python/triton/ops/flash_attention.py#L207\nWe'll update this implementation with the new Triton backend once this is fixed.\n\nWe use the FlashAttention implementation from Phil Tillet a starting point.\nhttps://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py\n\nChanges:\n- Implement both causal and non-causal attention.\n- Implement both self-attention and cross-attention.\n- Support arbitrary seqlens (not just multiples of 128), for both forward and backward.\n- Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both forward and backward.\n- Support attention bias.\n- Speed up the forward pass a bit, and only store the LSE instead of m and l.\n- Make the backward for d=128 much faster by reducing register spilling.\n- Optionally parallelize the backward pass across seqlen_k, to deal with the case of\nsmall batch size * nheads.\n\nCaution:\n- This is an *experimental* implementation. The forward pass should be quite robust but\nI'm not 100% sure that the backward pass doesn't have race conditions (due to the Triton compiler).\n- This implementation has only been tested on A100.\n- If you plan to use headdim other than 64 and 128, you should test for race conditions\n(due to the Triton compiler), as done in tests/test_flash_attn.py\n\"test_flash_attn_triton_race_condition\". 
I've tested and fixed many race conditions\nfor different head dimensions (40, 48, 64, 128, 80, 88, 96), but I'm still not 100% confident\nthat there are none left for other head dimensions.\n\nDifferences between this Triton version and the CUDA version:\n- Triton version doesn't support dropout.\n- Triton forward is generally faster than CUDA forward, while Triton backward is\ngenerally slower than CUDA backward. Overall Triton forward + backward is slightly slower\nthan CUDA forward + backward.\n- Triton version doesn't support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor).\n- Triton version supports attention bias, while CUDA version doesn't.\n\"\"\"\nimport math\nimport torch\nimport triton_pre_mlir as triton\nimport triton_pre_mlir.language as tl\n\n@triton.heuristics({'EVEN_M': lambda args: args['seqlen_q'] % args['BLOCK_M'] == 0, 'EVEN_N': lambda args: args['seqlen_k'] % args['BLOCK_N'] == 0, 'EVEN_HEADDIM': lambda args: args['headdim'] == args['BLOCK_HEADDIM']})\n@triton.jit\ndef _fwd_kernel(Q, K, V, Bias, Out, Lse, TMP, softmax_scale, stride_qb, stride_qh, stride_qm, stride_kb, stride_kh, stride_kn, stride_vb, stride_vh, stride_vn, stride_bb, stride_bh, stride_bm, stride_ob, stride_oh, stride_om, nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim, CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K, BIAS_TYPE: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_HEADDIM: tl.constexpr, EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):\n    start_m = tl.program_id(0)\n    off_hb = tl.program_id(1)\n    off_b = off_hb // nheads\n    off_h = off_hb % nheads\n    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n    offs_n = tl.arange(0, BLOCK_N)\n    offs_d = tl.arange(0, BLOCK_HEADDIM)\n    q_ptrs = Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])\n    k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + 
offs_d[None, :])\n    v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])\n    if BIAS_TYPE == 'vector':\n        b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n\n    elif BIAS_TYPE == 'matrix':\n        b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + (offs_m[:, None] * stride_bm + offs_n[None, :])\n    t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m\n    lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')\n    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')\n    acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)\n    if EVEN_M & EVEN_N:\n        if EVEN_HEADDIM:\n            q = tl.load(q_ptrs)\n        else:\n            q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)\n    elif EVEN_HEADDIM:\n        q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)\n    else:\n        q = tl.load(q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0)\n    end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)\n    for start_n in range(0, end_n, BLOCK_N):\n        start_n = tl.multiple_of(start_n, BLOCK_N)\n        if EVEN_N & EVEN_M:\n            if EVEN_HEADDIM:\n                k = tl.load(k_ptrs + start_n * stride_kn)\n            else:\n                k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0)\n        elif EVEN_HEADDIM:\n            k = tl.load(k_ptrs + start_n * stride_kn, mask=(start_n + offs_n)[:, None] < seqlen_k, other=0.0)\n        else:\n            k = tl.load(k_ptrs + start_n * stride_kn, mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0)\n        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n        qk += tl.dot(q, k, trans_b=True)\n        if not EVEN_N:\n            qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float('-inf'))\n        if IS_CAUSAL:\n            qk += 
tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float('-inf'))\n        if BIAS_TYPE != 'none':\n            if BIAS_TYPE == 'vector':\n                if EVEN_N:\n                    bias = tl.load(b_ptrs + start_n).to(tl.float32)\n                else:\n                    bias = tl.load(b_ptrs + start_n, mask=start_n + offs_n < seqlen_k, other=0.0).to(tl.float32)\n                bias = bias[None, :]\n            elif BIAS_TYPE == 'matrix':\n                if EVEN_M & EVEN_N:\n                    bias = tl.load(b_ptrs + start_n).to(tl.float32)\n                else:\n                    bias = tl.load(b_ptrs + start_n, mask=(offs_m[:, None] < seqlen_q) & ((start_n + offs_n)[None, :] < seqlen_k), other=0.0).to(tl.float32)\n            qk = qk * softmax_scale + bias\n            m_ij = tl.maximum(tl.max(qk, 1), lse_i)\n            p = tl.exp(qk - m_ij[:, None])\n        else:\n            m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)\n            p = tl.exp(qk * softmax_scale - m_ij[:, None])\n        l_ij = tl.sum(p, 1)\n        acc_o_scale = tl.exp(m_i - m_ij)\n        tl.store(t_ptrs, acc_o_scale)\n        acc_o_scale = tl.load(t_ptrs)\n        acc_o = acc_o * acc_o_scale[:, None]\n        if EVEN_N & EVEN_M:\n            if EVEN_HEADDIM:\n                v = tl.load(v_ptrs + start_n * stride_vn)\n            else:\n                v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0)\n        elif EVEN_HEADDIM:\n            v = tl.load(v_ptrs + start_n * stride_vn, mask=(start_n + offs_n)[:, None] < seqlen_k, other=0.0)\n        else:\n            v = tl.load(v_ptrs + start_n * stride_vn, mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0)\n        p = p.to(v.dtype)\n        acc_o += tl.dot(p, v)\n        m_i = m_ij\n        l_i_new = tl.exp(lse_i - m_ij) + l_ij\n        lse_i = m_ij + tl.log(l_i_new)\n    o_scale = tl.exp(m_i - lse_i)\n    tl.store(t_ptrs, o_scale)\n    
o_scale = tl.load(t_ptrs)\n    acc_o = acc_o * o_scale[:, None]\n    start_m = tl.program_id(0)\n    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n    lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m\n    tl.store(lse_ptrs, lse_i)\n    offs_d = tl.arange(0, BLOCK_HEADDIM)\n    out_ptrs = Out + off_b * stride_ob + off_h * stride_oh + (offs_m[:, None] * stride_om + offs_d[None, :])\n    if EVEN_M:\n        if EVEN_HEADDIM:\n            tl.store(out_ptrs, acc_o)\n        else:\n            tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim)\n    elif EVEN_HEADDIM:\n        tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)\n    else:\n        tl.store(out_ptrs, acc_o, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim))\n\n@triton.jit\ndef _bwd_preprocess_do_o_dot(Out, DO, Delta, stride_ob, stride_oh, stride_om, stride_dob, stride_doh, stride_dom, nheads, seqlen_q, seqlen_q_rounded, headdim, BLOCK_M: tl.constexpr, BLOCK_HEADDIM: tl.constexpr):\n    start_m = tl.program_id(0)\n    off_hb = tl.program_id(1)\n    off_b = off_hb // nheads\n    off_h = off_hb % nheads\n    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n    offs_d = tl.arange(0, BLOCK_HEADDIM)\n    o = tl.load(Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :], mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0).to(tl.float32)\n    do = tl.load(DO + off_b * stride_dob + off_h * stride_doh + offs_m[:, None] * stride_dom + offs_d[None, :], mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0).to(tl.float32)\n    delta = tl.sum(o * do, axis=1)\n    tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta)\n\n@triton.jit\ndef _bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim, EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr):\n    if EVEN_N & EVEN_M:\n        if EVEN_HEADDIM:\n            tl.store(dv_ptrs, dv)\n            
tl.store(dk_ptrs, dk)\n        else:\n            tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim)\n            tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim)\n    elif EVEN_HEADDIM:\n        tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k)\n        tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k)\n    else:\n        tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))\n        tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))\n\n@triton.jit\ndef _bwd_kernel_one_col_block(start_n, Q, K, V, Bias, DO, DQ, DK, DV, LSE, D, softmax_scale, stride_qm, stride_kn, stride_vn, stride_bm, stride_dom, stride_dqm, stride_dkn, stride_dvn, seqlen_q, seqlen_k, headdim, ATOMIC_ADD: tl.constexpr, BIAS_TYPE: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_HEADDIM: tl.constexpr, EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):\n    begin_m = 0 if not IS_CAUSAL else start_n * BLOCK_N // BLOCK_M * BLOCK_M\n    offs_qm = begin_m + tl.arange(0, BLOCK_M)\n    offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)\n    offs_m = tl.arange(0, BLOCK_M)\n    offs_d = tl.arange(0, BLOCK_HEADDIM)\n    q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_d[None, :])\n    k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :])\n    v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :])\n    do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :])\n    dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :])\n    if BIAS_TYPE == 'vector':\n        b_ptrs = Bias + offs_n\n    elif BIAS_TYPE == 'matrix':\n        b_ptrs = Bias + (offs_qm[:, None] * stride_bm + offs_n[None, :])\n    dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)\n    dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)\n    if begin_m >= seqlen_q:\n        dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])\n        dk_ptrs = DK + 
(offs_n[:, None] * stride_dkn + offs_d[None, :])\n        _bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim, EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM)\n        return\n    if EVEN_N & EVEN_M:\n        if EVEN_HEADDIM:\n            k = tl.load(k_ptrs)\n            v = tl.load(v_ptrs)\n        else:\n            k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0)\n            v = tl.load(v_ptrs, mask=offs_d[None, :] < headdim, other=0.0)\n    elif EVEN_HEADDIM:\n        k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)\n        v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)\n    else:\n        k = tl.load(k_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0)\n        v = tl.load(v_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0)\n    num_block_m = tl.cdiv(seqlen_q, BLOCK_M)\n    for start_m in range(begin_m, num_block_m * BLOCK_M, BLOCK_M):\n        start_m = tl.multiple_of(start_m, BLOCK_M)\n        offs_m_curr = start_m + offs_m\n        if EVEN_M & EVEN_HEADDIM:\n            q = tl.load(q_ptrs)\n        elif EVEN_HEADDIM:\n            q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)\n        else:\n            q = tl.load(q_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0)\n        qk = tl.dot(q, k, trans_b=True)\n        if not EVEN_N:\n            qk = tl.where(offs_n[None, :] < seqlen_k, qk, float('-inf'))\n        if IS_CAUSAL:\n            qk = tl.where(offs_m_curr[:, None] >= offs_n[None, :], qk, float('-inf'))\n        if BIAS_TYPE != 'none':\n            tl.debug_barrier()\n            if BIAS_TYPE == 'vector':\n                if EVEN_N:\n                    bias = tl.load(b_ptrs).to(tl.float32)\n                else:\n                    bias = tl.load(b_ptrs, mask=offs_n < seqlen_k, other=0.0).to(tl.float32)\n                bias = bias[None, :]\n       
     elif BIAS_TYPE == 'matrix':\n                if EVEN_M & EVEN_N:\n                    bias = tl.load(b_ptrs).to(tl.float32)\n                else:\n                    bias = tl.load(b_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k), other=0.0).to(tl.float32)\n            qk = qk * softmax_scale + bias\n        if not EVEN_M & EVEN_HEADDIM:\n            tl.debug_barrier()\n        lse_i = tl.load(LSE + offs_m_curr)\n        if BIAS_TYPE == 'none':\n            p = tl.exp(qk * softmax_scale - lse_i[:, None])\n        else:\n            p = tl.exp(qk - lse_i[:, None])\n        if EVEN_M & EVEN_HEADDIM:\n            do = tl.load(do_ptrs)\n        else:\n            do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0)\n        dv += tl.dot(p.to(do.dtype), do, trans_a=True)\n        if not EVEN_M & EVEN_HEADDIM:\n            tl.debug_barrier()\n        dp = tl.dot(do, v, trans_b=True)\n        if not EVEN_HEADDIM:\n            tl.debug_barrier()\n        Di = tl.load(D + offs_m_curr)\n        ds = (p * (dp - Di[:, None]) * softmax_scale).to(q.dtype)\n        dk += tl.dot(ds, q, trans_a=True)\n        if not EVEN_M & EVEN_HEADDIM:\n            tl.debug_barrier()\n        if not ATOMIC_ADD:\n            if EVEN_M & EVEN_HEADDIM:\n                dq = tl.load(dq_ptrs, eviction_policy='evict_last')\n                dq += tl.dot(ds, k)\n                tl.store(dq_ptrs, dq, eviction_policy='evict_last')\n            elif EVEN_HEADDIM:\n                dq = tl.load(dq_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0, eviction_policy='evict_last')\n                dq += tl.dot(ds, k)\n                tl.store(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q, eviction_policy='evict_last')\n            else:\n                dq = tl.load(dq_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0, eviction_policy='evict_last')\n                dq += tl.dot(ds, 
k)\n                tl.store(dq_ptrs, dq, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), eviction_policy='evict_last')\n        else:\n            dq = tl.dot(ds, k)\n            if EVEN_M & EVEN_HEADDIM:\n                tl.atomic_add(dq_ptrs, dq)\n            elif EVEN_HEADDIM:\n                tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q)\n            else:\n                tl.atomic_add(dq_ptrs, dq, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim))\n        dq_ptrs += BLOCK_M * stride_dqm\n        q_ptrs += BLOCK_M * stride_qm\n        do_ptrs += BLOCK_M * stride_dom\n        if BIAS_TYPE == 'matrix':\n            b_ptrs += BLOCK_M * stride_bm\n    dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])\n    dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])\n    _bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim, EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM)\n\ndef init_to_zero(name):\n    return lambda nargs: nargs[name].zero_()\n\n@triton.autotune(configs=[triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'SEQUENCE_PARALLEL': False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')), triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'SEQUENCE_PARALLEL': True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ'))], key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL', 'BLOCK_HEADDIM'])\n@triton.heuristics({'EVEN_M': lambda args: args['seqlen_q'] % args['BLOCK_M'] == 0, 'EVEN_N': lambda args: args['seqlen_k'] % args['BLOCK_N'] == 0, 'EVEN_HEADDIM': lambda args: args['headdim'] == args['BLOCK_HEADDIM']})\n@triton.jit\ndef _bwd_kernel(Q, K, V, Bias, DO, DQ, DK, DV, LSE, D, softmax_scale, stride_qb, stride_qh, stride_qm, stride_kb, stride_kh, stride_kn, stride_vb, stride_vh, stride_vn, stride_bb, stride_bh, stride_bm, stride_dob, stride_doh, stride_dom, stride_dqb, stride_dqh, stride_dqm, stride_dkb, stride_dkh, 
stride_dkn, stride_dvb, stride_dvh, stride_dvn, nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim, CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K, BIAS_TYPE: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_HEADDIM: tl.constexpr, SEQUENCE_PARALLEL: tl.constexpr, EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):\n    off_hb = tl.program_id(1)\n    off_b = off_hb // nheads\n    off_h = off_hb % nheads\n    Q += off_b * stride_qb + off_h * stride_qh\n    K += off_b * stride_kb + off_h * stride_kh\n    V += off_b * stride_vb + off_h * stride_vh\n    DO += off_b * stride_dob + off_h * stride_doh\n    DQ += off_b * stride_dqb + off_h * stride_dqh\n    DK += off_b * stride_dkb + off_h * stride_dkh\n    DV += off_b * stride_dvb + off_h * stride_dvh\n    if BIAS_TYPE != 'none':\n        Bias += off_b * stride_bb + off_h * stride_bh\n    D += off_hb * seqlen_q_rounded\n    LSE += off_hb * seqlen_q_rounded\n    if not SEQUENCE_PARALLEL:\n        num_block_n = tl.cdiv(seqlen_k, BLOCK_N)\n        for start_n in range(0, num_block_n):\n            _bwd_kernel_one_col_block(start_n, Q, K, V, Bias, DO, DQ, DK, DV, LSE, D, softmax_scale, stride_qm, stride_kn, stride_vn, stride_bm, stride_dom, stride_dqm, stride_dkn, stride_dvn, seqlen_q, seqlen_k, headdim, ATOMIC_ADD=False, BIAS_TYPE=BIAS_TYPE, IS_CAUSAL=IS_CAUSAL, BLOCK_HEADDIM=BLOCK_HEADDIM, EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N)\n    else:\n        start_n = tl.program_id(0)\n        _bwd_kernel_one_col_block(start_n, Q, K, V, Bias, DO, DQ, DK, DV, LSE, D, softmax_scale, stride_qm, stride_kn, stride_vn, stride_bm, stride_dom, stride_dqm, stride_dkn, stride_dvn, seqlen_q, seqlen_k, headdim, ATOMIC_ADD=True, BIAS_TYPE=BIAS_TYPE, IS_CAUSAL=IS_CAUSAL, BLOCK_HEADDIM=BLOCK_HEADDIM, EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N)\n\ndef _flash_attn_forward(q, k, v, bias=None, 
causal=False, softmax_scale=None):\n    (batch, seqlen_q, nheads, d) = q.shape\n    (_, seqlen_k, _, _) = k.shape\n    assert k.shape == (batch, seqlen_k, nheads, d)\n    assert v.shape == (batch, seqlen_k, nheads, d)\n    assert d <= 128, 'FlashAttention only support head dimensions up to 128'\n    assert q.dtype == k.dtype == v.dtype, 'All tensors must have the same type'\n    assert q.dtype in [torch.float16, torch.bfloat16], 'Only support fp16 and bf16'\n    assert q.is_cuda and k.is_cuda and v.is_cuda\n    softmax_scale = softmax_scale or 1.0 / math.sqrt(d)\n    has_bias = bias is not None\n    bias_type = 'none'\n    if has_bias:\n        assert bias.dtype in [q.dtype, torch.float]\n        assert bias.is_cuda\n        assert bias.dim() == 4\n        if bias.stride(-1) != 1:\n            bias = bias.contiguous()\n        if bias.shape[2:] == (1, seqlen_k):\n            bias_type = 'vector'\n        elif bias.shape[2:] == (seqlen_q, seqlen_k):\n            bias_type = 'matrix'\n        else:\n            raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k) or (seqlen_q, seqlen_k)')\n        bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)\n    bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)\n    seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128\n    lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)\n    tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)\n    o = torch.empty_like(q)\n    BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)\n    BLOCK = 128\n    num_warps = 4 if d <= 64 else 8\n    grid = lambda META: (triton.cdiv(seqlen_q, META['BLOCK_M']), batch * nheads)\n    _fwd_kernel[grid](q, k, v, bias, o, lse, tmp, softmax_scale, q.stride(0), q.stride(2), q.stride(1), k.stride(0), k.stride(2), k.stride(1), v.stride(0), v.stride(2), v.stride(1), *bias_strides, o.stride(0), o.stride(2), o.stride(1), nheads, 
seqlen_q, seqlen_k, seqlen_q_rounded, d, seqlen_q // 32, seqlen_k // 32, bias_type, causal, BLOCK_HEADDIM, BLOCK_M=BLOCK, BLOCK_N=BLOCK, num_warps=num_warps, num_stages=1)\n    return (o, lse, softmax_scale)\n\ndef _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=False, softmax_scale=None):\n    if do.stride(-1) != 1:\n        do = do.contiguous()\n    (batch, seqlen_q, nheads, d) = q.shape\n    (_, seqlen_k, _, _) = k.shape\n    assert d <= 128\n    seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128\n    assert lse.shape == (batch, nheads, seqlen_q_rounded)\n    assert q.stride(-1) == k.stride(-1) == v.stride(-1) == o.stride(-1) == 1\n    assert dq.stride(-1) == dk.stride(-1) == dv.stride(-1) == 1\n    softmax_scale = softmax_scale or 1.0 / math.sqrt(d)\n    dq_accum = torch.empty_like(q, dtype=torch.float32)\n    delta = torch.empty_like(lse)\n    BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)\n    grid = lambda META: (triton.cdiv(seqlen_q, META['BLOCK_M']), batch * nheads)\n    _bwd_preprocess_do_o_dot[grid](o, do, delta, o.stride(0), o.stride(2), o.stride(1), do.stride(0), do.stride(2), do.stride(1), nheads, seqlen_q, seqlen_q_rounded, d, BLOCK_M=128, BLOCK_HEADDIM=BLOCK_HEADDIM)\n    has_bias = bias is not None\n    bias_type = 'none'\n    if has_bias:\n        assert bias.dtype in [q.dtype, torch.float]\n        assert bias.is_cuda\n        assert bias.dim() == 4\n        assert bias.stride(-1) == 1\n        if bias.shape[2:] == (1, seqlen_k):\n            bias_type = 'vector'\n        elif bias.shape[2:] == (seqlen_q, seqlen_k):\n            bias_type = 'matrix'\n        else:\n            raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k) or (seqlen_q, seqlen_k)')\n        bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)\n    bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)\n    grid = lambda META: (triton.cdiv(seqlen_k, META['BLOCK_N']) if META['SEQUENCE_PARALLEL'] 
else 1, batch * nheads)\n    _bwd_kernel[grid](q, k, v, bias, do, dq_accum, dk, dv, lse, delta, softmax_scale, q.stride(0), q.stride(2), q.stride(1), k.stride(0), k.stride(2), k.stride(1), v.stride(0), v.stride(2), v.stride(1), *bias_strides, do.stride(0), do.stride(2), do.stride(1), dq_accum.stride(0), dq_accum.stride(2), dq_accum.stride(1), dk.stride(0), dk.stride(2), dk.stride(1), dv.stride(0), dv.stride(2), dv.stride(1), nheads, seqlen_q, seqlen_k, seqlen_q_rounded, d, seqlen_q // 32, seqlen_k // 32, bias_type, causal, BLOCK_HEADDIM)\n    dq.copy_(dq_accum)\n\nclass FlashAttnQKVPackedFunc(torch.autograd.Function):\n\n    @staticmethod\n    def forward(ctx, qkv, bias=None, causal=False, softmax_scale=None):\n        \"\"\"\n            qkv: (batch, seqlen, 3, nheads, headdim)\n            bias: optional, shape broadcastible to (batch, nheads, seqlen, seqlen).\n                For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen).\n                ALiBi mask for non-causal would have shape (1, nheads, seqlen, seqlen)\n        \"\"\"\n        if qkv.stride(-1) != 1:\n            qkv = qkv.contiguous()\n        (o, lse, ctx.softmax_scale) = _flash_attn_forward(qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], bias=bias, causal=causal, softmax_scale=softmax_scale)\n        ctx.save_for_backward(qkv, o, lse, bias)\n        ctx.causal = causal\n        return o\n\n    @staticmethod\n    def backward(ctx, do):\n        (qkv, o, lse, bias) = ctx.saved_tensors\n        assert not ctx.needs_input_grad[1], 'FlashAttention does not support bias gradient yet'\n        with torch.inference_mode():\n            dqkv = torch.empty_like(qkv)\n            _flash_attn_backward(do, qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], o, lse, dqkv[:, :, 0], dqkv[:, :, 1], dqkv[:, :, 2], bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale)\n        return (dqkv, None, None, None)\nflash_attn_qkvpacked_func = FlashAttnQKVPackedFunc.apply\n\nclass 
FlashAttnKVPackedFunc(torch.autograd.Function):\n\n    @staticmethod\n    def forward(ctx, q, kv, bias=None, causal=False, softmax_scale=None):\n        \"\"\"\n            q: (batch, seqlen_q, nheads, headdim)\n            kv: (batch, seqlen_k, 2, nheads, headdim)\n            bias: optional, shape broadcastible to (batch, nheads, seqlen_q, seqlen_k).\n                For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).\n                ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k)\n        \"\"\"\n        (q, kv) = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, kv]]\n        (o, lse, ctx.softmax_scale) = _flash_attn_forward(q, kv[:, :, 0], kv[:, :, 1], bias=bias, causal=causal, softmax_scale=softmax_scale)\n        ctx.save_for_backward(q, kv, o, lse, bias)\n        ctx.causal = causal\n        return o\n\n    @staticmethod\n    def backward(ctx, do):\n        (q, kv, o, lse, bias) = ctx.saved_tensors\n        if len(ctx.needs_input_grad) >= 3:\n            assert not ctx.needs_input_grad[2], 'FlashAttention does not support bias gradient yet'\n        with torch.inference_mode():\n            dq = torch.empty_like(q)\n            dkv = torch.empty_like(kv)\n            _flash_attn_backward(do, q, kv[:, :, 0], kv[:, :, 1], o, lse, dq, dkv[:, :, 0], dkv[:, :, 1], bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale)\n        return (dq, dkv, None, None, None)\nflash_attn_kvpacked_func = FlashAttnKVPackedFunc.apply\n\nclass FlashAttnFunc(torch.autograd.Function):\n\n    @staticmethod\n    def forward(ctx, q, k, v, bias=None, causal=False, softmax_scale=None):\n        \"\"\"\n            q: (batch_size, seqlen_q, nheads, headdim)\n            k, v: (batch_size, seqlen_k, nheads, headdim)\n            bias: optional, shape broadcastible to (batch, nheads, seqlen_q, seqlen_k).\n                For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).\n                ALiBi 
mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k)\n        \"\"\"\n        (q, k, v) = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v]]\n        (o, lse, ctx.softmax_scale) = _flash_attn_forward(q, k, v, bias=bias, causal=causal, softmax_scale=softmax_scale)\n        ctx.save_for_backward(q, k, v, o, lse, bias)\n        ctx.causal = causal\n        return o\n\n    @staticmethod\n    def backward(ctx, do):\n        (q, k, v, o, lse, bias) = ctx.saved_tensors\n        assert not ctx.needs_input_grad[3], 'FlashAttention does not support bias gradient yet'\n        with torch.inference_mode():\n            dq = torch.empty_like(q)\n            dk = torch.empty_like(k)\n            dv = torch.empty_like(v)\n            _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale)\n        return (dq, dk, dv, None, None, None)\nflash_attn_func = FlashAttnFunc.apply"
  },
  {
    "path": "llava/model/language_model/mpt/hf_prefixlm_converter.py",
    "content": "\"\"\"Converts Huggingface Causal LM to Prefix LM.\n\nConversion does lightweight surgery on a HuggingFace\nCausal LM to convert it to a Prefix LM.\n\nPrefix LMs accepts a `bidirectional_mask` input in `forward`\nand treat the input prompt as the prefix in `generate`.\n\"\"\"\nimport math\nimport warnings\nfrom types import MethodType\nfrom typing import Any, Dict, List, Optional, Tuple, Union\nimport torch\nfrom transformers.models.bloom.modeling_bloom import BaseModelOutputWithPastAndCrossAttentions, BloomForCausalLM, BloomModel, CausalLMOutputWithCrossAttentions, CrossEntropyLoss\nfrom transformers.models.bloom.modeling_bloom import _expand_mask as _expand_mask_bloom\nfrom transformers.models.bloom.modeling_bloom import _make_causal_mask as _make_causal_mask_bloom\nfrom transformers.models.bloom.modeling_bloom import logging\nfrom transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel\nfrom transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoForCausalLM\nfrom transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM\nfrom transformers.models.gptj.modeling_gptj import GPTJForCausalLM\nfrom transformers.models.opt.modeling_opt import OPTForCausalLM\nfrom transformers.models.opt.modeling_opt import _expand_mask as _expand_mask_opt\nfrom transformers.models.opt.modeling_opt import _make_causal_mask as _make_causal_mask_opt\nlogger = logging.get_logger(__name__)\n_SUPPORTED_GPT_MODELS = (GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM)\nCAUSAL_GPT_TYPES = Union[GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM]\n\ndef _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_TYPES:\n    \"\"\"Converts a GPT-style Causal LM to a Prefix LM.\n\n    Supported HuggingFace model classes:\n        - `GPT2LMHeadModel`\n        - `GPTNeoForCausalLM`\n        - `GPTNeoXForCausalLM`\n        - `GPTJForCausalLM`\n\n    See `convert_hf_causal_lm_to_prefix_lm` for more 
details.\n    \"\"\"\n    if hasattr(model, '_prefix_lm_converted'):\n        return model\n    assert isinstance(model, _SUPPORTED_GPT_MODELS)\n    assert model.config.add_cross_attention == False, 'Only supports GPT-style decoder-only models'\n\n    def _get_attn_modules(model: CAUSAL_GPT_TYPES) -> List[torch.nn.Module]:\n        \"\"\"Helper that gets a list of the model's attention modules.\n\n        Each module has a `bias` buffer used for causal masking. The Prefix LM\n        conversion adds logic to dynamically manipulate these biases to support\n        Prefix LM attention masking.\n        \"\"\"\n        attn_modules = []\n        if isinstance(model, GPTNeoXForCausalLM):\n            blocks = model.gpt_neox.layers\n        else:\n            blocks = model.transformer.h\n        for block in blocks:\n            if isinstance(model, GPTNeoForCausalLM):\n                if block.attn.attention_type != 'global':\n                    continue\n                attn_module = block.attn.attention\n            elif isinstance(model, GPTNeoXForCausalLM):\n                attn_module = block.attention\n            else:\n                attn_module = block.attn\n            attn_modules.append(attn_module)\n        return attn_modules\n    setattr(model, '_original_forward', getattr(model, 'forward'))\n    setattr(model, '_original_generate', getattr(model, 'generate'))\n\n    def forward(self: CAUSAL_GPT_TYPES, input_ids: Optional[torch.LongTensor]=None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]]=None, attention_mask: Optional[torch.FloatTensor]=None, bidirectional_mask: Optional[torch.Tensor]=None, token_type_ids: Optional[torch.LongTensor]=None, position_ids: Optional[torch.LongTensor]=None, head_mask: Optional[torch.FloatTensor]=None, inputs_embeds: Optional[torch.FloatTensor]=None, labels: Optional[torch.LongTensor]=None, use_cache: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, 
return_dict: Optional[bool]=None):\n        \"\"\"Wraps original forward to enable PrefixLM attention.\"\"\"\n\n        def call_og_forward():\n            if isinstance(self, GPTNeoXForCausalLM):\n                return self._original_forward(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, head_mask=head_mask, inputs_embeds=inputs_embeds, labels=labels, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict)\n            else:\n                return self._original_forward(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds, labels=labels, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict)\n        if bidirectional_mask is None:\n            return call_og_forward()\n        assert isinstance(bidirectional_mask, torch.Tensor)\n        attn_modules = _get_attn_modules(model)\n        (b, s) = bidirectional_mask.shape\n        max_length = attn_modules[0].bias.shape[-1]\n        if s > max_length:\n            raise ValueError(f'bidirectional_mask sequence length (={s}) exceeds the ' + f'max length allowed by the model ({max_length}).')\n        assert s <= max_length\n        if s < max_length:\n            pad = torch.zeros((int(b), int(max_length - s)), dtype=bidirectional_mask.dtype, device=bidirectional_mask.device)\n            bidirectional_mask = torch.cat([bidirectional_mask, pad], dim=1)\n        bidirectional = bidirectional_mask.unsqueeze(1).unsqueeze(1)\n        for attn_module in attn_modules:\n            attn_module.bias.data = torch.logical_or(attn_module.bias.data, bidirectional)\n        output = call_og_forward()\n        for attn_module in attn_modules:\n            attn_module.bias.data = torch.tril(attn_module.bias.data[0, 
0])[None, None]\n        return output\n\n    def generate(self: CAUSAL_GPT_TYPES, *args: tuple, **kwargs: Dict[str, Any]):\n        \"\"\"Wraps original generate to enable PrefixLM attention.\"\"\"\n        attn_modules = _get_attn_modules(model)\n        for attn_module in attn_modules:\n            attn_module.bias.data[:] = 1\n        output = self._original_generate(*args, **kwargs)\n        for attn_module in attn_modules:\n            attn_module.bias.data = torch.tril(attn_module.bias.data[0, 0])[None, None]\n        return output\n    setattr(model, 'forward', MethodType(forward, model))\n    setattr(model, 'generate', MethodType(generate, model))\n    setattr(model, '_prefix_lm_converted', True)\n    return model\n\ndef _convert_bloom_causal_lm_to_prefix_lm(model: BloomForCausalLM) -> BloomForCausalLM:\n    \"\"\"Converts a BLOOM Causal LM to a Prefix LM.\n\n    Supported HuggingFace model classes:\n        - `BloomForCausalLM`\n\n    See `convert_hf_causal_lm_to_prefix_lm` for more details.\n    \"\"\"\n    if hasattr(model, '_prefix_lm_converted'):\n        return model\n    assert isinstance(model, BloomForCausalLM)\n    assert model.config.add_cross_attention == False, 'Only supports BLOOM decoder-only models'\n\n    def _prepare_attn_mask(self: BloomModel, attention_mask: torch.Tensor, bidirectional_mask: Optional[torch.Tensor], input_shape: Tuple[int, int], past_key_values_length: int) -> torch.BoolTensor:\n        combined_attention_mask = None\n        device = attention_mask.device\n        (_, src_length) = input_shape\n        if src_length > 1:\n            combined_attention_mask = _make_causal_mask_bloom(input_shape, device=device, past_key_values_length=past_key_values_length)\n            if bidirectional_mask is not None:\n                assert attention_mask.shape == bidirectional_mask.shape\n                expanded_bidirectional_mask = _expand_mask_bloom(bidirectional_mask, tgt_length=src_length)\n                
combined_attention_mask = torch.logical_and(combined_attention_mask, expanded_bidirectional_mask)\n        expanded_attn_mask = _expand_mask_bloom(attention_mask, tgt_length=src_length)\n        combined_attention_mask = expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask\n        return combined_attention_mask\n\n    def _build_alibi_tensor(self: BloomModel, batch_size: int, query_length: int, key_length: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor:\n        num_heads = self.config.n_head\n        closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))\n        base = torch.tensor(2 ** (-2 ** (-(math.log2(closest_power_of_2) - 3))), device=device, dtype=torch.float32)\n        powers = torch.arange(1, 1 + closest_power_of_2, device=device, dtype=torch.int32)\n        slopes = torch.pow(base, powers)\n        if closest_power_of_2 != num_heads:\n            extra_base = torch.tensor(2 ** (-2 ** (-(math.log2(2 * closest_power_of_2) - 3))), device=device, dtype=torch.float32)\n            num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)\n            extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=device, dtype=torch.int32)\n            slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)\n        qa = torch.arange(query_length, device=device, dtype=torch.int32).view(-1, 1)\n        ka = torch.arange(key_length, device=device, dtype=torch.int32).view(1, -1)\n        diffs = qa - ka + key_length - query_length\n        diffs = -diffs.abs()\n        alibi = slopes.view(1, num_heads, 1, 1) * diffs.view(1, 1, query_length, key_length)\n        alibi = alibi.expand(batch_size, -1, -1, -1).reshape(-1, query_length, key_length)\n        return alibi.to(dtype)\n    KeyValueT = Tuple[torch.Tensor, torch.Tensor]\n\n    def forward(self: BloomModel, input_ids: Optional[torch.LongTensor]=None, past_key_values: Optional[Tuple[KeyValueT, 
...]]=None, attention_mask: Optional[torch.Tensor]=None, bidirectional_mask: Optional[torch.Tensor]=None, head_mask: Optional[torch.LongTensor]=None, inputs_embeds: Optional[torch.LongTensor]=None, use_cache: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, return_dict: Optional[bool]=None, **deprecated_arguments) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:\n        if deprecated_arguments.pop('position_ids', False) is not False:\n            warnings.warn('`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. ' + 'You can safely ignore passing `position_ids`.', FutureWarning)\n        if len(deprecated_arguments) > 0:\n            raise ValueError(f'Got unexpected arguments: {deprecated_arguments}')\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        use_cache = use_cache if use_cache is not None else self.config.use_cache\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n        if input_ids is not None and inputs_embeds is not None:\n            raise ValueError('You cannot specify both input_ids and inputs_embeds at the same time')\n        elif input_ids is not None:\n            (batch_size, seq_length) = input_ids.shape\n        elif inputs_embeds is not None:\n            (batch_size, seq_length, _) = inputs_embeds.shape\n        else:\n            raise ValueError('You have to specify either input_ids or inputs_embeds')\n        if past_key_values is None:\n            past_key_values = tuple([None] * len(self.h))\n        head_mask = self.get_head_mask(head_mask, self.config.n_layer)\n        if inputs_embeds is None:\n            inputs_embeds = self.word_embeddings(input_ids)\n        hidden_states = 
self.word_embeddings_layernorm(inputs_embeds)\n        presents = () if use_cache else None\n        all_self_attentions = () if output_attentions else None\n        all_hidden_states = () if output_hidden_states else None\n        seq_length_with_past = seq_length\n        past_key_values_length = 0\n        if past_key_values[0] is not None:\n            tmp = past_key_values[0][0]\n            past_key_values_length = tmp.shape[2]\n            seq_length_with_past = seq_length_with_past + past_key_values_length\n        if attention_mask is None:\n            attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)\n        else:\n            attention_mask = attention_mask.to(hidden_states.device)\n        alibi = self._build_alibi_tensor(batch_size=batch_size, query_length=seq_length, key_length=seq_length_with_past, dtype=hidden_states.dtype, device=hidden_states.device)\n        causal_mask = self._prepare_attn_mask(attention_mask, bidirectional_mask, input_shape=(batch_size, seq_length), past_key_values_length=past_key_values_length)\n        for (i, (block, layer_past)) in enumerate(zip(self.h, past_key_values)):\n            if output_hidden_states:\n                hst = (hidden_states,)\n                all_hidden_states = all_hidden_states + hst\n            if self.gradient_checkpointing and self.training:\n                if use_cache:\n                    logger.warning('`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`...')\n                    use_cache = False\n\n                def create_custom_forward(module):\n\n                    def custom_forward(*inputs):\n                        return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)\n                    return custom_forward\n                outputs = torch.utils.checkpoint.checkpoint(create_custom_forward(block), hidden_states, alibi, causal_mask, head_mask[i])\n            else:\n                outputs = block(hidden_states, layer_past=layer_past, attention_mask=causal_mask, head_mask=head_mask[i], use_cache=use_cache, output_attentions=output_attentions, alibi=alibi)\n            hidden_states = outputs[0]\n            if use_cache is True:\n                presents = presents + (outputs[1],)\n            if output_attentions:\n                oa = (outputs[2 if use_cache else 1],)\n                all_self_attentions = all_self_attentions + oa\n        hidden_states = self.ln_f(hidden_states)\n        if output_hidden_states:\n            hst = (hidden_states,)\n            all_hidden_states = all_hidden_states + hst\n        if not return_dict:\n            return tuple((v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None))\n        return BaseModelOutputWithPastAndCrossAttentions(last_hidden_state=hidden_states, past_key_values=presents, hidden_states=all_hidden_states, attentions=all_self_attentions)\n    setattr(model.transformer, '_prepare_attn_mask', MethodType(_prepare_attn_mask, model.transformer))\n    setattr(model.transformer, '_build_alibi_tensor', MethodType(_build_alibi_tensor, model.transformer))\n    setattr(model.transformer, 'forward', MethodType(forward, model.transformer))\n    KeyValueT = Tuple[torch.Tensor, torch.Tensor]\n\n    def forward(self: BloomForCausalLM, input_ids: Optional[torch.LongTensor]=None, past_key_values: Optional[Tuple[KeyValueT, ...]]=None, attention_mask: 
Optional[torch.Tensor]=None, bidirectional_mask: Optional[torch.Tensor]=None, head_mask: Optional[torch.Tensor]=None, inputs_embeds: Optional[torch.Tensor]=None, labels: Optional[torch.Tensor]=None, use_cache: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, return_dict: Optional[bool]=None, **deprecated_arguments) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:\n        \"\"\"Replacement forward method for BloomCausalLM.\"\"\"\n        if deprecated_arguments.pop('position_ids', False) is not False:\n            warnings.warn('`position_ids` have no functionality in BLOOM and will be removed ' + 'in v5.0.0. You can safely ignore passing `position_ids`.', FutureWarning)\n        if len(deprecated_arguments) > 0:\n            raise ValueError(f'Got unexpected arguments: {deprecated_arguments}')\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n        transformer_outputs = self.transformer(input_ids, past_key_values=past_key_values, attention_mask=attention_mask, bidirectional_mask=bidirectional_mask, head_mask=head_mask, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict)\n        hidden_states = transformer_outputs[0]\n        lm_logits = self.lm_head(hidden_states)\n        loss = None\n        if labels is not None:\n            shift_logits = lm_logits[..., :-1, :].contiguous()\n            shift_labels = labels[..., 1:].contiguous()\n            (batch_size, seq_length, vocab_size) = shift_logits.shape\n            loss_fct = CrossEntropyLoss()\n            loss = loss_fct(shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length))\n        if not return_dict:\n            output = (lm_logits,) + transformer_outputs[1:]\n            return (loss,) + output if loss is not None else output\n        return 
CausalLMOutputWithCrossAttentions(loss=loss, logits=lm_logits, past_key_values=transformer_outputs.past_key_values, hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions)\n\n    def prepare_inputs_for_generation(self: BloomForCausalLM, input_ids: torch.LongTensor, past: Optional[torch.Tensor]=None, attention_mask: Optional[torch.Tensor]=None, **kwargs) -> dict:\n        if past:\n            input_ids = input_ids[:, -1].unsqueeze(-1)\n            bidirectional_mask = None\n            if past[0][0].shape[0] == input_ids.shape[0]:\n                past = self._convert_to_bloom_cache(past)\n        else:\n            bidirectional_mask = torch.ones_like(input_ids)\n        return {'input_ids': input_ids, 'past_key_values': past, 'use_cache': True, 'attention_mask': attention_mask, 'bidirectional_mask': bidirectional_mask}\n    setattr(model, 'forward', MethodType(forward, model))\n    setattr(model, 'prepare_inputs_for_generation', MethodType(prepare_inputs_for_generation, model))\n    setattr(model, '_prefix_lm_converted', True)\n    return model\n\ndef _convert_opt_causal_lm_to_prefix_lm(model: OPTForCausalLM) -> OPTForCausalLM:\n    \"\"\"Converts an OPT Causal LM to a Prefix LM.\n\n    Supported HuggingFace model classes:\n        - `OPTForCausalLM`\n\n    See `convert_hf_causal_lm_to_prefix_lm` for more details.\n    \"\"\"\n    if hasattr(model, '_prefix_lm_converted'):\n        return model\n    assert isinstance(model, OPTForCausalLM)\n    assert model.config.add_cross_attention == False, 'Only supports OPT decoder-only models'\n    setattr(model, '_original_forward', getattr(model, 'forward'))\n    setattr(model, '_original_generate', getattr(model, 'generate'))\n    model.model.decoder.bidirectional_mask = None\n\n    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):\n        combined_attention_mask = None\n        if input_shape[-1] > 1:\n            if 
self.bidirectional_mask == 'g':\n                (bsz, src_length) = input_shape\n                combined_attention_mask = torch.zeros((bsz, 1, src_length, src_length + past_key_values_length), dtype=inputs_embeds.dtype, device=inputs_embeds.device)\n            else:\n                combined_attention_mask = _make_causal_mask_opt(input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length).to(inputs_embeds.device)\n                if self.bidirectional_mask is not None:\n                    assert attention_mask.shape == self.bidirectional_mask.shape\n                    expanded_bidirectional_mask = _expand_mask_opt(self.bidirectional_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(inputs_embeds.device)\n                    combined_attention_mask = torch.maximum(expanded_bidirectional_mask, combined_attention_mask)\n        if attention_mask is not None:\n            expanded_attn_mask = _expand_mask_opt(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(inputs_embeds.device)\n            combined_attention_mask = expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask\n        return combined_attention_mask\n    setattr(model.model.decoder, '_prepare_decoder_attention_mask', MethodType(_prepare_decoder_attention_mask, model.model.decoder))\n\n    def forward(self: OPTForCausalLM, input_ids: Optional[torch.LongTensor]=None, attention_mask: Optional[torch.Tensor]=None, bidirectional_mask: Optional[torch.ByteTensor]=None, head_mask: Optional[torch.Tensor]=None, past_key_values: Optional[List[torch.FloatTensor]]=None, inputs_embeds: Optional[torch.FloatTensor]=None, labels: Optional[torch.LongTensor]=None, use_cache: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, return_dict: Optional[bool]=None):\n\n        def call_og_forward():\n            return self._original_forward(input_ids=input_ids, 
attention_mask=attention_mask, head_mask=head_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, labels=labels, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict)\n        if bidirectional_mask is None:\n            return call_og_forward()\n        self.model.decoder.bidirectional_mask = bidirectional_mask\n        try:\n            outputs = call_og_forward()\n        except:\n            self.model.decoder.bidirectional_mask = None\n            raise\n        self.model.decoder.bidirectional_mask = None\n        return outputs\n\n    def generate(self: OPTForCausalLM, *args: tuple, **kwargs: Dict[str, Any]):\n        \"\"\"Wraps original generate to enable PrefixLM-style attention.\"\"\"\n        self.model.decoder.bidirectional_mask = 'g'\n        try:\n            output = self._original_generate(*args, **kwargs)\n        except:\n            self.model.decoder.bidirectional_mask = None\n            raise\n        self.model.decoder.bidirectional_mask = None\n        return output\n    setattr(model, 'forward', MethodType(forward, model))\n    setattr(model, 'generate', MethodType(generate, model))\n    setattr(model, '_prefix_lm_converted', True)\n    return model\n_SUPPORTED_HF_MODELS = _SUPPORTED_GPT_MODELS + (BloomForCausalLM, OPTForCausalLM)\nCAUSAL_LM_TYPES = Union[GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM, BloomForCausalLM, OPTForCausalLM]\n\ndef convert_hf_causal_lm_to_prefix_lm(model: CAUSAL_LM_TYPES) -> CAUSAL_LM_TYPES:\n    \"\"\"Converts a HuggingFace Causal LM to a Prefix LM.\n\n    Supported HuggingFace model classes:\n        - `GPT2LMHeadModel`\n        - `GPTNeoForCausalLM`\n        - `GPTNeoXForCausalLM`\n        - `GPTJForCausalLM`\n        - `BloomForCausalLM`\n        - `OPTForCausalLM`\n\n    Conversion to a Prefix LM is done by modifying the `forward` method, and possibly also the\n    `generate` method and/or select 
underlying methods depending on the model class.\n\n    These changes preserve the model API, but add a new input to `forward`: \"bidirectional_mask\".\n\n    Notes on training:\n        To actually train the converted model as a Prefix LM, training batches will need to indicate\n        the prefix/target structure by including `bidirectional_mask` as part of the batch inputs.\n\n        **This is not a standard input and requires custom layers either within or after your dataloader.**\n\n        In addition to adding `bidirectional_mask` to the batch, this custom code should modify `labels`\n        such that `batch['labels'][batch['bidirectional_mask'] == 1] == -100`.\n        That is, the prefix portion of the sequence should not generate any loss. Loss should only be\n        generated by the target portion of the sequence.\n\n    Notes on `GPTNeoForCausalLM`:\n        To simplify the implementation, \"global\" and \"local\" attention layers are handled differently.\n        For \"global\" layers, we handle conversion as described above. 
For \"local\" layers, which use a\n        causal attention mask within a restricted local window, we do not alter the masking.\n\n    Notes on `forward` method conversion:\n        After conversion, the `forward` method will handle a new input, `bidirectional_mask`,\n        which should be a [batch_size, seq_length] byte tensor, where 1 indicates token positions\n        belonging to the prefix (prefix tokens can attend to one another bidirectionally), and\n        0 indicates token positions belonging to the target.\n\n        The new `forward` method will incorporate `bidirectional_mask` (if supplied) into the existing\n        causal mask, call the original `forward` method, and (if the causal mask is a buffer) reset\n        the causal masks before returning the result.\n\n    Notes on `generate` method conversion:\n        After conversion, the `generate` method will have the same signature but will internally\n        convert all causal masks to be purely bidirectional, call the original `generate` method, and\n        (where appropriate) reset the causal masks before returning the result.\n\n        This works thanks to the logic of the HuggingFace `generate` API, which first encodes the token\n        \"prompt\" passed to `generate` (which is treated as the prefix) and then sequentially generates\n        each new token. Encodings are cached as generation happens, so all prefix tokens can attend to one\n        another (as expected in a Prefix LM) and generated tokens can only attend to prefix tokens and\n        previously-generated tokens (also as expected in a Prefix LM).\n\n    To preserve the API, the original methods are renamed to `_original_forward` and\n    `_original_generate`, and replaced with new `forward` and `generate` methods that wrap\n    them, respectively. 
Although implementation details vary by model class.\n    \"\"\"\n    if isinstance(model, _SUPPORTED_GPT_MODELS):\n        return _convert_gpt_causal_lm_to_prefix_lm(model)\n    elif isinstance(model, BloomForCausalLM):\n        return _convert_bloom_causal_lm_to_prefix_lm(model)\n    elif isinstance(model, OPTForCausalLM):\n        return _convert_opt_causal_lm_to_prefix_lm(model)\n    else:\n        raise TypeError(f'Cannot convert model to Prefix LM. ' + f'Model does not belong to set of supported HF models:' + f'\\n{_SUPPORTED_HF_MODELS}')\n\ndef add_bidirectional_mask_if_missing(batch: Dict[str, Any]):\n    \"\"\"Attempts to add bidirectional_mask to batch if missing.\n\n    Raises:\n        KeyError if bidirectional_mask is missing and can't be inferred\n    \"\"\"\n    if 'bidirectional_mask' not in batch:\n        if batch.get('mode', None) == 'icl_task':\n            batch['bidirectional_mask'] = batch['attention_mask'].clone()\n            for (i, continuation_indices) in enumerate(batch['continuation_indices']):\n                batch['bidirectional_mask'][i, continuation_indices] = 0\n        elif 'labels' in batch and 'attention_mask' in batch:\n            batch['bidirectional_mask'] = torch.logical_and(torch.eq(batch['attention_mask'], 1), torch.eq(batch['labels'], -100)).type_as(batch['attention_mask'])\n        else:\n            raise KeyError('No bidirectional_mask in batch and not sure how to construct one.')"
  },
  {
    "path": "llava/model/language_model/mpt/meta_init_context.py",
    "content": "from contextlib import contextmanager\nimport torch\nimport torch.nn as nn\n\n@contextmanager\ndef init_empty_weights(include_buffers: bool=False):\n    \"\"\"Meta initialization context manager.\n\n    A context manager under which models are initialized with all parameters\n    on the meta device, therefore creating an empty model. Useful when just\n    initializing the model would blow the available RAM.\n\n    Args:\n        include_buffers (`bool`, *optional*, defaults to `False`): Whether or\n            not to also put all buffers on the meta device while initializing.\n\n    Example:\n    ```python\n    import torch.nn as nn\n\n    # Initialize a model with 100 billions parameters in no time and without using any RAM.\n    with init_empty_weights():\n        tst = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)])\n    ```\n\n    <Tip warning={true}>\n\n    Any model created under this context manager has no weights. As such you can't do something like\n    `model.to(some_device)` with it. 
To load weights inside your empty model, see [`load_checkpoint_and_dispatch`].\n\n    </Tip>\n    \"\"\"\n    with init_on_device(torch.device('meta'), include_buffers=include_buffers) as f:\n        yield f\n\n@contextmanager\ndef init_on_device(device: torch.device, include_buffers: bool=False):\n    \"\"\"Device initialization context manager.\n\n    A context manager under which models are initialized with all parameters\n    on the specified device.\n\n    Args:\n        device (`torch.device`): Device to initialize all parameters on.\n        include_buffers (`bool`, *optional*, defaults to `False`): Whether or\n            not to also put all buffers on the specified device while initializing.\n\n    Example:\n    ```python\n    import torch.nn as nn\n\n    with init_on_device(device=torch.device(\"cuda\")):\n        tst = nn.Linear(100, 100)  # on `cuda` device\n    ```\n    \"\"\"\n    old_register_parameter = nn.Module.register_parameter\n    if include_buffers:\n        old_register_buffer = nn.Module.register_buffer\n\n    def register_empty_parameter(module, name, param):\n        old_register_parameter(module, name, param)\n        if param is not None:\n            param_cls = type(module._parameters[name])\n            kwargs = module._parameters[name].__dict__\n            module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)\n\n    def register_empty_buffer(module, name, buffer):\n        old_register_buffer(module, name, buffer)\n        if buffer is not None:\n            module._buffers[name] = module._buffers[name].to(device)\n    if include_buffers:\n        tensor_constructors_to_patch = {torch_function_name: getattr(torch, torch_function_name) for torch_function_name in ['empty', 'zeros', 'ones', 'full']}\n    else:\n        tensor_constructors_to_patch = {}\n\n    def patch_tensor_constructor(fn):\n\n        def wrapper(*args, **kwargs):\n            kwargs['device'] = device\n            return fn(*args, 
**kwargs)\n        return wrapper\n    try:\n        nn.Module.register_parameter = register_empty_parameter\n        if include_buffers:\n            nn.Module.register_buffer = register_empty_buffer\n        for torch_function_name in tensor_constructors_to_patch.keys():\n            setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))\n        yield\n    finally:\n        nn.Module.register_parameter = old_register_parameter\n        if include_buffers:\n            nn.Module.register_buffer = old_register_buffer\n        for (torch_function_name, old_torch_function) in tensor_constructors_to_patch.items():\n            setattr(torch, torch_function_name, old_torch_function)"
  },
  {
    "path": "llava/model/language_model/mpt/modeling_mpt.py",
    "content": "\"\"\"A simple, flexible implementation of a GPT model.\n\nInspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py\n\"\"\"\nimport math\nimport warnings\nfrom typing import List, Optional, Tuple, Union\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast\nfrom transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast\nfrom .attention import attn_bias_shape, build_attn_bias\nfrom .blocks import MPTBlock\nfrom .custom_embedding import SharedEmbedding\nfrom .norm import NORM_CLASS_REGISTRY\nfrom .configuration_mpt import MPTConfig\nfrom .adapt_tokenizer import AutoTokenizerForMOD, adapt_tokenizer_for_denoising\nfrom .hf_prefixlm_converter import add_bidirectional_mask_if_missing, convert_hf_causal_lm_to_prefix_lm\nfrom .meta_init_context import init_empty_weights\nfrom .param_init_fns import MODEL_INIT_REGISTRY, generic_param_init_fn_\ntry:\n    from .flash_attn_triton import flash_attn_func\nexcept:\n    pass\nTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]\n\nclass MPTPreTrainedModel(PreTrainedModel):\n    config_class = MPTConfig\n    base_model_prefix = 'model'\n    _no_split_modules = ['MPTBlock']\n\nclass MPTModel(MPTPreTrainedModel):\n\n    def __init__(self, config: MPTConfig):\n        config._validate_config()\n        super().__init__(config)\n        self.attn_impl = config.attn_config['attn_impl']\n        self.prefix_lm = config.attn_config['prefix_lm']\n        self.attn_uses_sequence_id = config.attn_config['attn_uses_sequence_id']\n        self.alibi = config.attn_config['alibi']\n        self.alibi_bias_max = config.attn_config['alibi_bias_max']\n        if config.init_device == 'mixed':\n            if dist.get_local_rank() == 0:\n                config.init_device = 'cpu'\n            else:\n                config.init_device = 'meta'\n        if 
config.norm_type.lower() not in NORM_CLASS_REGISTRY.keys():\n            norm_options = ' | '.join(NORM_CLASS_REGISTRY.keys())\n            raise NotImplementedError(f'Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options}).')\n        norm_class = NORM_CLASS_REGISTRY[config.norm_type.lower()]\n        self.embedding_fraction = config.embedding_fraction\n        self.wte = SharedEmbedding(config.vocab_size, config.d_model, device=config.init_device)\n        if not self.alibi:\n            self.wpe = torch.nn.Embedding(config.max_seq_len, config.d_model, device=config.init_device)\n        self.emb_drop = nn.Dropout(config.emb_pdrop)\n        self.blocks = nn.ModuleList([MPTBlock(device=config.init_device, **config.to_dict()) for _ in range(config.n_layers)])\n        self.norm_f = norm_class(config.d_model, device=config.init_device)\n        if config.init_device != 'meta':\n            print(f'You are using config.init_device={config.init_device!r}, but you can also use config.init_device=\"meta\" with Composer + FSDP for fast initialization.')\n            self.apply(self.param_init_fn)\n        self.is_causal = not self.prefix_lm\n        self._attn_bias_initialized = False\n        self.attn_bias = None\n        self.attn_bias_shape = attn_bias_shape(self.attn_impl, config.n_heads, config.max_seq_len, self.alibi, prefix_lm=self.prefix_lm, causal=self.is_causal, use_sequence_id=self.attn_uses_sequence_id)\n        if config.no_bias:\n            for module in self.modules():\n                if hasattr(module, 'bias') and isinstance(module.bias, nn.Parameter):\n                    if config.verbose:\n                        warnings.warn(f'Removing bias ({module.bias}) from {module}.')\n                    module.register_parameter('bias', None)\n        if config.verbose and config.verbose > 2:\n            print(self)\n        if 'verbose' not in self.config.init_config:\n            
self.config.init_config['verbose'] = self.config.verbose\n        if self.config.init_config['verbose'] > 1:\n            init_fn_name = self.config.init_config['name']\n            warnings.warn(f'Using {init_fn_name} initialization.')\n        self.gradient_checkpointing = False\n\n    def get_input_embeddings(self):\n        return self.wte\n\n    def set_input_embeddings(self, value):\n        self.wte = value\n\n    @torch.no_grad()\n    def _attn_bias(self, device, dtype, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None):\n        if not self._attn_bias_initialized:\n            if self.attn_bias_shape:\n                self.attn_bias = torch.zeros(self.attn_bias_shape, device=device, dtype=dtype)\n                self.attn_bias = build_attn_bias(self.attn_impl, self.attn_bias, self.config.n_heads, self.config.max_seq_len, causal=self.is_causal, alibi=self.alibi, alibi_bias_max=self.alibi_bias_max)\n            self._attn_bias_initialized = True\n        if self.attn_impl == 'flash':\n            return (self.attn_bias, attention_mask)\n        if self.attn_bias is not None:\n            self.attn_bias = self.attn_bias.to(dtype=dtype, device=device)\n        attn_bias = self.attn_bias\n        if self.prefix_lm:\n            assert isinstance(attn_bias, torch.Tensor)\n            assert isinstance(prefix_mask, torch.Tensor)\n            attn_bias = self._apply_prefix_mask(attn_bias, prefix_mask)\n        if self.attn_uses_sequence_id and sequence_id is not None:\n            assert isinstance(attn_bias, torch.Tensor)\n            attn_bias = self._apply_sequence_id(attn_bias, sequence_id)\n        if attention_mask is not None:\n            s_k = attention_mask.shape[-1]\n            if attn_bias is None:\n                attn_bias = torch.zeros((1, 1, 1, s_k), device=device, dtype=dtype)\n            else:\n                _s_k = max(0, attn_bias.size(-1) - s_k)\n      
          attn_bias = attn_bias[:, :, :, _s_k:]\n            if prefix_mask is not None and attention_mask.shape != prefix_mask.shape:\n                raise ValueError(f'attention_mask shape={attention_mask.shape} ' + f'and prefix_mask shape={prefix_mask.shape} are not equal.')\n            min_val = torch.finfo(attn_bias.dtype).min\n            attn_bias = attn_bias.masked_fill(~attention_mask.view(-1, 1, 1, s_k), min_val)\n        return (attn_bias, None)\n\n    def _apply_prefix_mask(self, attn_bias: torch.Tensor, prefix_mask: torch.Tensor):\n        (s_k, s_q) = attn_bias.shape[-2:]\n        if s_k != self.config.max_seq_len or s_q != self.config.max_seq_len:\n            raise ValueError('attn_bias does not match the expected shape. ' + f'The last two dimensions should both be {self.config.max_length} ' + f'but are {s_k} and {s_q}.')\n        seq_len = prefix_mask.shape[-1]\n        if seq_len > self.config.max_seq_len:\n            raise ValueError(f'prefix_mask sequence length cannot exceed max_seq_len={self.config.max_seq_len}')\n        attn_bias = attn_bias[..., :seq_len, :seq_len]\n        causal = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool, device=prefix_mask.device)).view(1, 1, seq_len, seq_len)\n        prefix = prefix_mask.view(-1, 1, 1, seq_len)\n        cannot_attend = ~torch.logical_or(causal, prefix.bool())\n        min_val = torch.finfo(attn_bias.dtype).min\n        attn_bias = attn_bias.masked_fill(cannot_attend, min_val)\n        return attn_bias\n\n    def _apply_sequence_id(self, attn_bias: torch.Tensor, sequence_id: torch.LongTensor):\n        seq_len = sequence_id.shape[-1]\n        if seq_len > self.config.max_seq_len:\n            raise ValueError(f'sequence_id sequence length cannot exceed max_seq_len={self.config.max_seq_len}')\n        attn_bias = attn_bias[..., :seq_len, :seq_len]\n        cannot_attend = torch.logical_not(torch.eq(sequence_id.view(-1, seq_len, 1), sequence_id.view(-1, 1, seq_len))).unsqueeze(1)\n    
    min_val = torch.finfo(attn_bias.dtype).min\n        attn_bias = attn_bias.masked_fill(cannot_attend, min_val)\n        return attn_bias\n\n    def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, inputs_embeds: Optional[torch.Tensor]=None):\n        return_dict = return_dict if return_dict is not None else self.config.return_dict\n        use_cache = use_cache if use_cache is not None else self.config.use_cache\n        if attention_mask is not None:\n            attention_mask = attention_mask.bool()\n        if prefix_mask is not None:\n            prefix_mask = prefix_mask.bool()\n        if not return_dict:\n            raise NotImplementedError('return_dict False is not implemented yet for MPT')\n        if output_attentions:\n            if self.attn_impl != 'torch':\n                raise NotImplementedError('output_attentions is not implemented for MPT when using attn_impl `flash` or `triton`.')\n        if attention_mask is not None and attention_mask[:, 0].sum() != attention_mask.shape[0] and self.training:\n            raise NotImplementedError('MPT does not support training with left padding.')\n        if self.prefix_lm and prefix_mask is None:\n            raise ValueError('prefix_mask is a required argument when MPT is configured with prefix_lm=True.')\n        if self.training:\n            if self.attn_uses_sequence_id and sequence_id is None:\n                raise ValueError('sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True ' + 'and the model is in train mode.')\n            elif self.attn_uses_sequence_id is False and sequence_id is not None:\n                
warnings.warn('MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. ' + 'This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True.')\n        if input_ids is not None:\n            S = input_ids.size(1)\n            assert S <= self.config.max_seq_len, f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}'\n            tok_emb = self.wte(input_ids)\n        else:\n            assert inputs_embeds is not None\n            assert self.alibi, 'inputs_embeds is not implemented for MPT unless for alibi.'\n            S = inputs_embeds.size(1)\n            tok_emb = inputs_embeds\n        if self.alibi:\n            x = tok_emb\n        else:\n            past_position = 0\n            if past_key_values is not None:\n                if len(past_key_values) != self.config.n_layers:\n                    raise ValueError(f'past_key_values must provide a past_key_value for each attention ' + f'layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r}).')\n                past_position = past_key_values[0][0].size(1)\n                if self.attn_impl == 'torch':\n                    past_position = past_key_values[0][0].size(3)\n            if S + past_position > self.config.max_seq_len:\n                raise ValueError(f'Cannot forward input with past sequence length {past_position} and current sequence length {S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.')\n            pos = torch.arange(past_position, S + past_position, dtype=torch.long, device=input_ids.device).unsqueeze(0)\n            if attention_mask is not None:\n                pos = torch.clamp(pos - torch.cumsum((~attention_mask).to(torch.int32), dim=1)[:, past_position:], min=0)\n            pos_emb = self.wpe(pos)\n            x = tok_emb + pos_emb\n        if 
self.embedding_fraction == 1:\n            x = self.emb_drop(x)\n        else:\n            x_shrunk = x * self.embedding_fraction + x.detach() * (1 - self.embedding_fraction)\n            assert isinstance(self.emb_drop, nn.Module)\n            x = self.emb_drop(x_shrunk)\n        (attn_bias, attention_mask) = self._attn_bias(device=x.device, dtype=torch.float32, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id)\n        if use_cache and past_key_values is None:\n            past_key_values = [() for _ in range(self.config.n_layers)]\n        all_hidden_states = () if output_hidden_states else None\n        all_self_attns = () if output_attentions else None\n        for (b_idx, block) in enumerate(self.blocks):\n            if output_hidden_states:\n                assert all_hidden_states is not None\n                all_hidden_states = all_hidden_states + (x,)\n            past_key_value = past_key_values[b_idx] if past_key_values is not None else None\n            if self.gradient_checkpointing and self.training:\n                (x, attn_weights, past_key_value) = torch.utils.checkpoint.checkpoint(block, x, past_key_value, attn_bias, attention_mask, self.is_causal)\n            else:\n                (x, attn_weights, past_key_value) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal)\n            if past_key_values is not None:\n                past_key_values[b_idx] = past_key_value\n            if output_attentions:\n                assert all_self_attns is not None\n                all_self_attns = all_self_attns + (attn_weights,)\n        x = self.norm_f(x)\n        if output_hidden_states:\n            assert all_hidden_states is not None\n            all_hidden_states = all_hidden_states + (x,)\n        return BaseModelOutputWithPast(last_hidden_state=x, past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns)\n\n    def 
param_init_fn(self, module):\n        init_fn_name = self.config.init_config['name']\n        MODEL_INIT_REGISTRY[init_fn_name](module=module, n_layers=self.config.n_layers, d_model=self.config.d_model, **self.config.init_config)\n\n    def fsdp_wrap_fn(self, module):\n        return isinstance(module, MPTBlock)\n\n    def activation_checkpointing_fn(self, module):\n        return isinstance(module, MPTBlock)\n\nclass MPTForCausalLM(MPTPreTrainedModel):\n\n    def __init__(self, config: MPTConfig):\n        super().__init__(config)\n        if not config.tie_word_embeddings:\n            raise ValueError('MPTForCausalLM only supports tied word embeddings')\n        print(f'Instantiating an MPTForCausalLM model from {__file__}')\n        self.transformer = MPTModel(config)\n        for child in self.transformer.children():\n            if isinstance(child, torch.nn.ModuleList):\n                continue\n            if isinstance(child, torch.nn.Module):\n                child._fsdp_wrap = True\n        self.logit_scale = None\n        if config.logit_scale is not None:\n            logit_scale = config.logit_scale\n            if isinstance(logit_scale, str):\n                if logit_scale == 'inv_sqrt_d_model':\n                    logit_scale = 1 / math.sqrt(config.d_model)\n                else:\n                    raise ValueError(f\"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.\")\n            self.logit_scale = logit_scale\n\n    def get_input_embeddings(self):\n        return self.transformer.wte\n\n    def set_input_embeddings(self, value):\n        self.transformer.wte = value\n\n    def get_output_embeddings(self):\n        return self.transformer.wte\n\n    def set_output_embeddings(self, new_embeddings):\n        self.transformer.wte = new_embeddings\n\n    def set_decoder(self, decoder):\n        self.transformer = decoder\n\n    def get_decoder(self):\n        return self.transformer\n\n    def 
forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, inputs_embeds: Optional[torch.FloatTensor]=None):\n        return_dict = return_dict if return_dict is not None else self.config.return_dict\n        use_cache = use_cache if use_cache is not None else self.config.use_cache\n        if inputs_embeds is not None:\n            raise NotImplementedError('inputs_embeds has to be None (for hf/peft support).')\n        outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)\n        logits = self.transformer.wte(outputs.last_hidden_state.to(self.transformer.wte.weight.device), True)\n        if self.logit_scale is not None:\n            if self.logit_scale == 0:\n                warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. 
This will produce uniform (uninformative) outputs.')\n            logits *= self.logit_scale\n        loss = None\n        if labels is not None:\n            labels = torch.roll(labels, shifts=-1)\n            labels[:, -1] = -100\n            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1))\n        return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions)\n\n    def param_init_fn(self, module):\n        init_fn_name = self.config.init_config['name']\n        MODEL_INIT_REGISTRY[init_fn_name](module=module, n_layers=self.config.n_layers, d_model=self.config.d_model, **self.config.init_config)\n\n    def fsdp_wrap_fn(self, module):\n        return isinstance(module, MPTBlock)\n\n    def activation_checkpointing_fn(self, module):\n        return isinstance(module, MPTBlock)\n\n    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):\n        if inputs_embeds is not None:\n            raise NotImplementedError('inputs_embeds is not implemented for MPT yet')\n        attention_mask = kwargs['attention_mask'].bool()\n        if attention_mask[:, -1].sum() != attention_mask.shape[0]:\n            raise NotImplementedError('MPT does not support generation with right padding.')\n        if self.transformer.attn_uses_sequence_id and self.training:\n            sequence_id = torch.zeros_like(input_ids[:1])\n        else:\n            sequence_id = None\n        if past_key_values is not None:\n            input_ids = input_ids[:, -1].unsqueeze(-1)\n        if self.transformer.prefix_lm:\n            prefix_mask = torch.ones_like(attention_mask)\n            if kwargs.get('use_cache') == False:\n                raise NotImplementedError('MPT with prefix_lm=True does not support use_cache=False.')\n        else:\n            prefix_mask = None\n        return {'input_ids': 
input_ids, 'attention_mask': attention_mask, 'prefix_mask': prefix_mask, 'sequence_id': sequence_id, 'past_key_values': past_key_values, 'use_cache': kwargs.get('use_cache', True)}\n\n    @staticmethod\n    def _reorder_cache(past_key_values, beam_idx):\n        \"\"\"Used by HuggingFace generate when using beam search with kv-caching.\n\n        See https://github.com/huggingface/transformers/blob/3ec7a47664ebe40c40f4b722f6bb1cd30c3821ec/src/transformers/models/gpt2/modeling_gpt2.py#L1122-L1133\n        for an example in transformers.\n        \"\"\"\n        reordered_past = []\n        for layer_past in past_key_values:\n            reordered_past += [tuple((past_state.index_select(0, beam_idx) for past_state in layer_past))]\n        return reordered_past"
  },
  {
    "path": "llava/model/language_model/mpt/norm.py",
    "content": "import torch\n\ndef _cast_if_autocast_enabled(tensor):\n    if torch.is_autocast_enabled():\n        if tensor.device.type == 'cuda':\n            dtype = torch.get_autocast_gpu_dtype()\n        elif tensor.device.type == 'cpu':\n            dtype = torch.get_autocast_cpu_dtype()\n        else:\n            raise NotImplementedError()\n        return tensor.to(dtype=dtype)\n    return tensor\n\nclass LPLayerNorm(torch.nn.LayerNorm):\n\n    def __init__(self, normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None):\n        super().__init__(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine, device=device, dtype=dtype)\n\n    def forward(self, x):\n        module_device = x.device\n        downcast_x = _cast_if_autocast_enabled(x)\n        downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight\n        downcast_bias = _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias\n        with torch.autocast(enabled=False, device_type=module_device.type):\n            return torch.nn.functional.layer_norm(downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps)\n\ndef rms_norm(x, weight=None, eps=1e-05):\n    output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)\n    if weight is not None:\n        return output * weight\n    return output\n\nclass RMSNorm(torch.nn.Module):\n\n    def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None):\n        super().__init__()\n        self.eps = eps\n        if weight:\n            self.weight = torch.nn.Parameter(torch.ones(normalized_shape, dtype=dtype, device=device))\n        else:\n            self.register_parameter('weight', None)\n\n    def forward(self, x):\n        return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype)\n\nclass LPRMSNorm(RMSNorm):\n\n    def __init__(self, normalized_shape, eps=1e-05, weight=True, 
dtype=None, device=None):\n        super().__init__(normalized_shape=normalized_shape, eps=eps, weight=weight, dtype=dtype, device=device)\n\n    def forward(self, x):\n        downcast_x = _cast_if_autocast_enabled(x)\n        downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight\n        with torch.autocast(enabled=False, device_type=x.device.type):\n            return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype)\nNORM_CLASS_REGISTRY = {'layernorm': torch.nn.LayerNorm, 'low_precision_layernorm': LPLayerNorm, 'rmsnorm': RMSNorm, 'low_precision_rmsnorm': LPRMSNorm}"
  },
  {
    "path": "llava/model/language_model/mpt/param_init_fns.py",
    "content": "import math\nimport warnings\nfrom collections.abc import Sequence\nfrom functools import partial\nfrom typing import Optional, Tuple, Union\nimport torch\nfrom torch import nn\nfrom .norm import NORM_CLASS_REGISTRY\n\ndef torch_default_param_init_fn_(module: nn.Module, verbose: int=0, **kwargs):\n    del kwargs\n    if verbose > 1:\n        warnings.warn(f\"Initializing network using module's reset_parameters attribute\")\n    if hasattr(module, 'reset_parameters'):\n        module.reset_parameters()\n\ndef fused_init_helper_(module: nn.Module, init_fn_):\n    _fused = getattr(module, '_fused', None)\n    if _fused is None:\n        raise RuntimeError(f'Internal logic error')\n    (dim, splits) = _fused\n    splits = (0, *splits, module.weight.size(dim))\n    for (s, e) in zip(splits[:-1], splits[1:]):\n        slice_indices = [slice(None)] * module.weight.ndim\n        slice_indices[dim] = slice(s, e)\n        init_fn_(module.weight[slice_indices])\n\ndef generic_param_init_fn_(module: nn.Module, init_fn_, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):\n    del kwargs\n    if verbose > 1:\n        warnings.warn(f'If model has bias parameters they are initialized to 0.')\n    init_div_is_residual = init_div_is_residual\n    if init_div_is_residual is False:\n        div_is_residual = 1.0\n    elif init_div_is_residual is True:\n        div_is_residual = math.sqrt(2 * n_layers)\n    elif isinstance(init_div_is_residual, float) or isinstance(init_div_is_residual, int):\n        div_is_residual = init_div_is_residual\n    elif isinstance(init_div_is_residual, str) and init_div_is_residual.isnumeric():\n        div_is_residual = float(init_div_is_residual)\n    else:\n        div_is_residual = 1.0\n        raise ValueError(f'Expected init_div_is_residual to be 
boolean or numeric, got {init_div_is_residual}')\n    if init_div_is_residual is not False:\n        if verbose > 1:\n            warnings.warn(f'Initializing _is_residual layers then dividing them by {div_is_residual:.3f}. ' + f'Set `init_div_is_residual: false` in init config to disable this.')\n    if isinstance(module, nn.Linear):\n        if hasattr(module, '_fused'):\n            fused_init_helper_(module, init_fn_)\n        else:\n            init_fn_(module.weight)\n        if module.bias is not None:\n            torch.nn.init.zeros_(module.bias)\n        if init_div_is_residual is not False and getattr(module, '_is_residual', False):\n            with torch.no_grad():\n                module.weight.div_(div_is_residual)\n    elif isinstance(module, nn.Embedding):\n        if emb_init_std is not None:\n            std = emb_init_std\n            if std == 0:\n                warnings.warn(f'Embedding layer initialized to 0.')\n            emb_init_fn_ = partial(torch.nn.init.normal_, mean=0.0, std=std)\n            if verbose > 1:\n                warnings.warn(f'Embedding layer initialized using normal distribution with mean=0 and std={std!r}.')\n        elif emb_init_uniform_lim is not None:\n            lim = emb_init_uniform_lim\n            if isinstance(lim, Sequence):\n                if len(lim) > 2:\n                    raise ValueError(f'Uniform init requires a min and a max limit. 
User input: {lim}.')\n                if lim[0] == lim[1]:\n                    warnings.warn(f'Embedding layer initialized to {lim[0]}.')\n            else:\n                if lim == 0:\n                    warnings.warn(f'Embedding layer initialized to 0.')\n                lim = [-lim, lim]\n            (a, b) = lim\n            emb_init_fn_ = partial(torch.nn.init.uniform_, a=a, b=b)\n            if verbose > 1:\n                warnings.warn(f'Embedding layer initialized using uniform distribution in range {lim}.')\n        else:\n            emb_init_fn_ = init_fn_\n        emb_init_fn_(module.weight)\n    elif isinstance(module, tuple(set(NORM_CLASS_REGISTRY.values()))):\n        if verbose > 1:\n            warnings.warn(f'Norm weights are set to 1. If norm layer has a bias it is initialized to 0.')\n        if hasattr(module, 'weight') and module.weight is not None:\n            torch.nn.init.ones_(module.weight)\n        if hasattr(module, 'bias') and module.bias is not None:\n            torch.nn.init.zeros_(module.bias)\n    elif isinstance(module, nn.MultiheadAttention):\n        if module._qkv_same_embed_dim:\n            assert module.in_proj_weight is not None\n            assert module.q_proj_weight is None and module.k_proj_weight is None and (module.v_proj_weight is None)\n            assert d_model is not None\n            _d = d_model\n            splits = (0, _d, 2 * _d, 3 * _d)\n            for (s, e) in zip(splits[:-1], splits[1:]):\n                init_fn_(module.in_proj_weight[s:e])\n        else:\n            assert module.q_proj_weight is not None and module.k_proj_weight is not None and (module.v_proj_weight is not None)\n            assert module.in_proj_weight is None\n            init_fn_(module.q_proj_weight)\n            init_fn_(module.k_proj_weight)\n            init_fn_(module.v_proj_weight)\n        if module.in_proj_bias is not None:\n            torch.nn.init.zeros_(module.in_proj_bias)\n        if module.bias_k is not 
None:\n            torch.nn.init.zeros_(module.bias_k)\n        if module.bias_v is not None:\n            torch.nn.init.zeros_(module.bias_v)\n        init_fn_(module.out_proj.weight)\n        if init_div_is_residual is not False and getattr(module.out_proj, '_is_residual', False):\n            with torch.no_grad():\n                module.out_proj.weight.div_(div_is_residual)\n        if module.out_proj.bias is not None:\n            torch.nn.init.zeros_(module.out_proj.bias)\n    else:\n        for _ in module.parameters(recurse=False):\n            raise NotImplementedError(f'{module.__class__.__name__} parameters are not initialized by param_init_fn.')\n\ndef _normal_init_(std, mean=0.0):\n    return partial(torch.nn.init.normal_, mean=mean, std=std)\n\ndef _normal_param_init_fn_(module: nn.Module, std: float, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):\n    del kwargs\n    init_fn_ = _normal_init_(std=std)\n    if verbose > 1:\n        warnings.warn(f'Using torch.nn.init.normal_ init fn mean=0.0, std={std}')\n    generic_param_init_fn_(module=module, init_fn_=init_fn_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)\n\ndef baseline_param_init_fn_(module: nn.Module, init_std: float, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):\n    del kwargs\n    if init_std is None:\n        raise ValueError(\"You must set model.init_config['init_std'] to a float value to use the default initialization scheme.\")\n    _normal_param_init_fn_(module=module, std=init_std, 
d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)\n\ndef small_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):\n    del kwargs\n    std = math.sqrt(2 / (5 * d_model))\n    _normal_param_init_fn_(module=module, std=std, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)\n\ndef neox_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):\n    \"\"\"From section 2.3.1 of GPT-NeoX-20B:\n\n    An Open-Source AutoregressiveLanguage Model — Black et. al. 
(2022)\n    see https://github.com/EleutherAI/gpt-neox/blob/9610391ab319403cef079b438edd016a2443af54/megatron/model/init_functions.py#L151\n    and https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/transformer.py\n    \"\"\"\n    del kwargs\n    residual_div = n_layers / math.sqrt(10)\n    if verbose > 1:\n        warnings.warn(f'setting init_div_is_residual to {residual_div}')\n    small_param_init_fn_(module=module, d_model=d_model, n_layers=n_layers, init_div_is_residual=residual_div, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)\n\ndef kaiming_uniform_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, fan_mode: str='fan_in', init_nonlinearity: str='leaky_relu', verbose: int=0, **kwargs):\n    del kwargs\n    if verbose > 1:\n        warnings.warn(f'Using nn.init.kaiming_uniform_ init fn with parameters: ' + f'a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}')\n    kaiming_uniform_ = partial(nn.init.kaiming_uniform_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity)\n    generic_param_init_fn_(module=module, init_fn_=kaiming_uniform_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)\n\ndef kaiming_normal_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, fan_mode: str='fan_in', init_nonlinearity: str='leaky_relu', verbose: int=0, **kwargs):\n    del kwargs\n    if verbose > 1:\n        warnings.warn(f'Using nn.init.kaiming_normal_ init fn with 
parameters: ' + f'a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}')\n    kaiming_normal_ = partial(torch.nn.init.kaiming_normal_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity)\n    generic_param_init_fn_(module=module, init_fn_=kaiming_normal_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)\n\ndef xavier_uniform_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, verbose: int=0, **kwargs):\n    del kwargs\n    xavier_uniform_ = partial(torch.nn.init.xavier_uniform_, gain=init_gain)\n    if verbose > 1:\n        warnings.warn(f'Using torch.nn.init.xavier_uniform_ init fn with parameters: ' + f'gain={init_gain}')\n    generic_param_init_fn_(module=module, init_fn_=xavier_uniform_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)\n\ndef xavier_normal_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, verbose: int=0, **kwargs):\n    xavier_normal_ = partial(torch.nn.init.xavier_normal_, gain=init_gain)\n    if verbose > 1:\n        warnings.warn(f'Using torch.nn.init.xavier_normal_ init fn with parameters: ' + f'gain={init_gain}')\n    generic_param_init_fn_(module=module, init_fn_=xavier_normal_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)\nMODEL_INIT_REGISTRY = 
{'default_': torch_default_param_init_fn_, 'baseline_': baseline_param_init_fn_, 'kaiming_uniform_': kaiming_uniform_param_init_fn_, 'kaiming_normal_': kaiming_normal_param_init_fn_, 'neox_init_': neox_param_init_fn_, 'small_init_': small_param_init_fn_, 'xavier_uniform_': xavier_uniform_param_init_fn_, 'xavier_normal_': xavier_normal_param_init_fn_}"
  },
  {
    "path": "llava/model/llava_arch.py",
    "content": "#    Copyright 2023 Haotian Liu\n#\n#    Licensed under the Apache License, Version 2.0 (the \"License\");\n#    you may not use this file except in compliance with the License.\n#    You may obtain a copy of the License at\n#\n#        http://www.apache.org/licenses/LICENSE-2.0\n#\n#    Unless required by applicable law or agreed to in writing, software\n#    distributed under the License is distributed on an \"AS IS\" BASIS,\n#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#    See the License for the specific language governing permissions and\n#    limitations under the License.\n\n\nfrom abc import ABC, abstractmethod\n\nimport torch\nimport torch.nn as nn\n\nfrom .multimodal_encoder.builder import build_vision_tower\nfrom .openseed import build_model\nfrom .openseed.BaseModel import BaseModel\n\ngrounding_start=\"<g_s>\"\ngrounding_end=\"<g_e>\"\nSEG_TOKEN=\"<seg>\"\nfrom llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN\n\n\nclass LlavaMetaModel:\n\n    def __init__(self, config):\n        super(LlavaMetaModel, self).__init__(config)\n\n        if hasattr(config, \"mm_vision_tower\"):\n            self.vision_tower = build_vision_tower(config, delay_load=True)\n            self.mm_projector = nn.Linear(config.mm_hidden_size, config.hidden_size)\n\n    def get_vision_tower(self):\n        vision_tower = getattr(self, 'vision_tower', None)\n        if type(vision_tower) is list:\n            vision_tower = vision_tower[0]\n        return vision_tower\n\n    def initialize_vision_modules(self, model_args, fsdp=None):\n        vision_tower = model_args.vision_tower\n        mm_vision_select_layer = model_args.mm_vision_select_layer\n        mm_vision_select_feature = model_args.mm_vision_select_feature\n        pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter\n\n        self.config.mm_vision_tower = vision_tower\n\n        
vision_tower = build_vision_tower(model_args)\n\n        if fsdp is not None and len(fsdp) > 0:\n            self.vision_tower = [vision_tower]\n        else:\n            self.vision_tower = vision_tower\n\n        self.config.use_mm_proj = True\n        self.config.mm_hidden_size = vision_tower.hidden_size\n        self.config.mm_vision_select_layer = mm_vision_select_layer\n        self.config.mm_vision_select_feature = mm_vision_select_feature\n\n        if not hasattr(self, 'mm_projector'):\n            self.mm_projector = nn.Linear(self.config.mm_hidden_size, self.config.hidden_size)\n\n        if pretrain_mm_mlp_adapter is not None:\n            mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')\n            def get_w(weights, keyword):\n                return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}\n\n            # self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))\n            self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))\n\n\nclass LlavaMetaForCausalLM(ABC):\n\n    @abstractmethod\n    def get_model(self):\n        pass\n\n    def get_vision_tower(self):\n        return self.get_model().get_vision_tower()\n\n    def encode_images(self, images):\n        image_features = self.get_model().get_vision_tower()(images)\n        image_features = self.get_model().mm_projector(image_features)\n        return image_features\n\n    def prepare_inputs_labels_for_multimodal(\n        self, input_ids, attention_mask, past_key_values, labels, images\n    ):\n        vision_tower = self.get_vision_tower()\n        if vision_tower is None or images is None or input_ids.shape[1] == 1:\n            if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[1] == 1:\n                attention_mask = torch.ones((attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1), dtype=attention_mask.dtype, 
device=attention_mask.device)\n            return input_ids, attention_mask, past_key_values, None, labels\n\n        if type(images) is list or images.ndim == 5:\n            concat_images = torch.cat([image for image in images], dim=0)\n            image_features = self.encode_images(concat_images)\n            split_sizes = [image.shape[0] for image in images]\n            image_features = torch.split(image_features, split_sizes, dim=0)\n            image_features = [x.flatten(0, 1) for x in image_features]\n        else:\n            image_features = self.encode_images(images)\n\n        new_input_embeds = []\n        new_labels = [] if labels is not None else None\n        cur_image_idx = 0\n        orig_embeds_params = getattr(self, 'orig_embeds_params', None)\n        if orig_embeds_params is not None:\n            orig_embeds_params_in = orig_embeds_params[0]\n            orig_embeds_params_out = orig_embeds_params[1]\n            # st_inp=self.tokenizer.encode(grounding_start)[1]\n            # st_out=self.tokenizer.encode(grounding_start)[1]\n            with torch.no_grad():\n                self.get_input_embeddings().weight[:-3] = orig_embeds_params_in[:-3].data\n                # if self.tokenizer.decode([len(self.tokenizer)-1])=='<seg>':\n                self.get_output_embeddings().weight[:-3] = orig_embeds_params_out[:-3].data\n        for batch_idx, cur_input_ids in enumerate(input_ids):\n            if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0:\n                # multimodal LLM, but the current sample is not multimodal\n                cur_input_embeds = self.get_model().embed_tokens(cur_input_ids)\n                cur_input_embeds = cur_input_embeds + (0. 
* self.get_model().mm_projector(vision_tower.dummy_feature)).sum()\n                new_input_embeds.append(cur_input_embeds)\n                if labels is not None:\n                    new_labels.append(labels[batch_idx])\n                cur_image_idx += 1\n                continue\n            image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]\n            cur_new_input_embeds = []\n            if labels is not None:\n                cur_labels = labels[batch_idx]\n                cur_new_labels = []\n                assert cur_labels.shape == cur_input_ids.shape\n            while image_token_indices.numel() > 0:\n                cur_image_features = image_features[cur_image_idx]\n                image_token_start = image_token_indices[0]\n                if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):\n                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start-1]).detach())\n                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_start-1:image_token_start]))\n                    cur_new_input_embeds.append(cur_image_features)\n                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_start+1:image_token_start+2]))\n                    if labels is not None:\n                        cur_new_labels.append(cur_labels[:image_token_start])\n                        cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))\n                        cur_new_labels.append(cur_labels[image_token_start:image_token_start+1])\n                        cur_labels = cur_labels[image_token_start+2:]\n                else:\n                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start]))\n                    
cur_new_input_embeds.append(cur_image_features)\n                    if labels is not None:\n                        cur_new_labels.append(cur_labels[:image_token_start])\n                        cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))\n                        cur_labels = cur_labels[image_token_start+1:]\n                cur_image_idx += 1\n                if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):\n                    cur_input_ids = cur_input_ids[image_token_start+2:]\n                else:\n                    cur_input_ids = cur_input_ids[image_token_start+1:]\n                image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]\n            if cur_input_ids.numel() > 0:\n                if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):\n                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids).detach())\n                else:\n                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids))\n                if labels is not None:\n                    cur_new_labels.append(cur_labels)\n            cur_new_input_embeds = [x.to(device=self.device) for x in cur_new_input_embeds]\n            cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0)\n            new_input_embeds.append(cur_new_input_embeds)\n            if labels is not None:\n                cur_new_labels = torch.cat(cur_new_labels, dim=0)\n                new_labels.append(cur_new_labels)\n\n        if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds):\n            max_len = max(x.shape[0] for x in new_input_embeds)\n\n            new_input_embeds_align = []\n            for cur_new_embed in new_input_embeds:\n                cur_new_embed = torch.cat((cur_new_embed, 
torch.zeros((max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0)\n                new_input_embeds_align.append(cur_new_embed)\n            new_input_embeds = torch.stack(new_input_embeds_align, dim=0)\n\n            if labels is not None:\n                new_labels_align = []\n                _new_labels = new_labels\n                for cur_new_label in new_labels:\n                    cur_new_label = torch.cat((cur_new_label, torch.full((max_len - cur_new_label.shape[0],), IGNORE_INDEX, dtype=cur_new_label.dtype, device=cur_new_label.device)), dim=0)\n                    new_labels_align.append(cur_new_label)\n                new_labels = torch.stack(new_labels_align, dim=0)\n\n            if attention_mask is not None:\n                new_attention_mask = []\n                for cur_attention_mask, cur_new_labels, cur_new_labels_align in zip(attention_mask, _new_labels, new_labels):\n                    new_attn_mask_pad_left = torch.full((cur_new_labels.shape[0] - labels.shape[1],), True, dtype=attention_mask.dtype, device=attention_mask.device)\n                    new_attn_mask_pad_right = torch.full((cur_new_labels_align.shape[0] - cur_new_labels.shape[0],), False, dtype=attention_mask.dtype, device=attention_mask.device)\n                    cur_new_attention_mask = torch.cat((new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right), dim=0)\n                    new_attention_mask.append(cur_new_attention_mask)\n                attention_mask = torch.stack(new_attention_mask, dim=0)\n                assert attention_mask.shape == new_labels.shape\n        else:\n            new_input_embeds = torch.stack(new_input_embeds, dim=0)\n            if labels is not None:\n                new_labels  = torch.stack(new_labels, dim=0)\n\n            if attention_mask is not None:\n                new_attn_mask_pad_left = torch.full((attention_mask.shape[0], new_input_embeds.shape[1] 
- input_ids.shape[1]), True, dtype=attention_mask.dtype, device=attention_mask.device)\n                attention_mask = torch.cat((new_attn_mask_pad_left, attention_mask), dim=1)\n                assert attention_mask.shape == new_input_embeds.shape[:2]\n\n        return None, attention_mask, past_key_values, new_input_embeds, new_labels\n\n    def initialize_vision_tokenizer(self, model_args, tokenizer):\n        if model_args.mm_use_im_patch_token:\n            tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)\n            self.resize_token_embeddings(len(tokenizer))\n\n        if model_args.mm_use_im_start_end:\n            num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, grounding_start, grounding_end, SEG_TOKEN], special_tokens=True)\n            self.resize_token_embeddings(len(tokenizer))\n\n            if num_new_tokens > 0:\n                input_embeddings = self.get_input_embeddings().weight.data\n                output_embeddings = self.get_output_embeddings().weight.data\n\n                input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(\n                    dim=0, keepdim=True)\n                output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(\n                    dim=0, keepdim=True)\n\n                input_embeddings[-num_new_tokens:] = input_embeddings_avg\n                output_embeddings[-num_new_tokens:] = output_embeddings_avg\n\n            if model_args.tune_mm_mlp_adapter:\n                self.orig_embeds_params = [self.get_input_embeddings().weight.data.clone().cuda(),\n                                                  self.get_output_embeddings().weight.data.clone().cuda()]\n\n                for p in self.get_input_embeddings().parameters():\n                    p.requires_grad = True\n                for p in self.get_output_embeddings().parameters():\n                    p.requires_grad = True\n\n            if 
model_args.pretrain_mm_mlp_adapter:\n                mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')\n                embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']\n                assert num_new_tokens == 2\n                if input_embeddings.shape == embed_tokens_weight.shape:\n                    input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]\n                elif embed_tokens_weight.shape[0] == num_new_tokens:\n                    input_embeddings[-num_new_tokens:] = embed_tokens_weight\n                else:\n                    raise ValueError(f\"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.\")\n        elif model_args.mm_use_im_patch_token:\n            if model_args.tune_mm_mlp_adapter:\n                for p in self.get_input_embeddings().parameters():\n                    p.requires_grad = False\n                for p in self.get_output_embeddings().parameters():\n                    p.requires_grad = False\n        else:\n            # import pdb; pdb.set_trace()\n            num_new_tokens = tokenizer.add_tokens([grounding_start, grounding_end, SEG_TOKEN], special_tokens=True)\n            inits=['[',']','.']\n            nums=[tokenizer.encode(init)[1] for init in inits]\n            # inp_embs = self.get_input_embeddings().weight.data[nums]\n            # out_embs = self.get_output_embeddings().weight.data[nums]\n            self.resize_token_embeddings(len(tokenizer))\n\n            if num_new_tokens > 0:\n                # print(\"Emb length:\", len(self.get_input_embeddings().weight.data))\n                # if len(self.get_input_embeddings().weight.data) > 0:\n                # if len(self.get_input_embeddings().weight.data) > 0:\n                # self.get_input_embeddings().weight.data[-num_new_tokens:] = inp_embs\n                # 
self.get_output_embeddings().weight.data[-num_new_tokens:] = out_embs\n                input_embeddings = self.get_input_embeddings().weight.data\n                output_embeddings = self.get_output_embeddings().weight.data\n                #\n                input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(\n                    dim=0, keepdim=True)\n                output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(\n                    dim=0, keepdim=True)\n                #\n                input_embeddings[-num_new_tokens:] = input_embeddings_avg\n                output_embeddings[-num_new_tokens:] = output_embeddings_avg\n\n            if model_args.tune_mm_mlp_adapter:\n                self.orig_embeds_params = [self.get_input_embeddings().weight.data.clone().cuda(),\n                                                  self.get_output_embeddings().weight.data.clone().cuda()]\n\n                for p in self.get_input_embeddings().parameters():\n                    p.requires_grad = True\n                for p in self.get_output_embeddings().parameters():\n                    p.requires_grad = True\n\nclass LlavaMetaForCausalLM_gd(ABC):\n\n    @abstractmethod\n    def get_model(self):\n        pass\n\n    def get_vision_tower(self):\n        return self.get_model().get_vision_tower()\n\n    def encode_images(self, images):\n        image_features = self.get_model().get_vision_tower()(images)\n        image_features = self.get_model().mm_projector(image_features.to(self.get_model().mm_projector.state_dict()[\"weight\"].dtype))\n        return image_features\n\n    def prepare_inputs_labels_for_multimodal(\n        self, input_ids, attention_mask, past_key_values, labels, images\n    ):\n        vision_tower = self.get_vision_tower()\n        if vision_tower is None or images is None or input_ids.shape[1] == 1:\n            if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[1] == 1:\n  
              attention_mask = torch.ones((attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1), dtype=attention_mask.dtype, device=attention_mask.device)\n            return input_ids, attention_mask, past_key_values, None, labels\n\n        if type(images) is list or images.ndim == 5:\n            concat_images = torch.cat([image for image in images], dim=0)\n            image_features = self.encode_images(concat_images)\n            split_sizes = [image.shape[0] for image in images]\n            image_features = torch.split(image_features, split_sizes, dim=0)\n            image_features = [x.flatten(0, 1) for x in image_features]\n        else:\n            image_features = self.encode_images(images)\n\n        new_input_embeds = []\n        new_labels = [] if labels is not None else None\n        cur_image_idx = 0\n        orig_embeds_params = getattr(self, 'orig_embeds_params', None)\n        if orig_embeds_params is not None:\n            orig_embeds_params_in = orig_embeds_params[0]\n            orig_embeds_params_out = orig_embeds_params[1]\n            # st_inp=self.tokenizer.encode(grounding_start)[1]\n            # st_out=self.tokenizer.encode(grounding_start)[1]\n            with torch.no_grad():\n                self.get_input_embeddings().weight[:-3] = orig_embeds_params_in[:-3].data\n                # if self.tokenizer.decode([len(self.tokenizer)-1])=='<seg>':\n                self.get_output_embeddings().weight[:-3] = orig_embeds_params_out[:-3].data\n\n        for batch_idx, cur_input_ids in enumerate(input_ids):\n            if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0:\n                # multimodal LLM, but the current sample is not multimodal\n                cur_input_embeds = self.get_model().embed_tokens(cur_input_ids)\n                cur_input_embeds = cur_input_embeds + (0. 
* self.get_model().mm_projector(vision_tower.dummy_feature)).sum()\n                new_input_embeds.append(cur_input_embeds)\n                if labels is not None:\n                    new_labels.append(labels[batch_idx])\n                cur_image_idx += 1\n                continue\n            image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]\n            cur_new_input_embeds = []\n            if labels is not None:\n                cur_labels = labels[batch_idx]\n                cur_new_labels = []\n                assert cur_labels.shape == cur_input_ids.shape\n            while image_token_indices.numel() > 0:\n                cur_image_features = image_features[cur_image_idx]\n                image_token_start = image_token_indices[0]\n                if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):\n                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start-1]))\n                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_start-1:image_token_start]))\n                    cur_new_input_embeds.append(cur_image_features)\n                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_start+1:image_token_start+2]))\n                    if labels is not None:\n                        cur_new_labels.append(cur_labels[:image_token_start])\n                        cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))\n                        cur_new_labels.append(cur_labels[image_token_start:image_token_start+1])\n                        cur_labels = cur_labels[image_token_start+2:]\n                else:\n                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start]))\n                    
cur_new_input_embeds.append(cur_image_features)\n                    if labels is not None:\n                        cur_new_labels.append(cur_labels[:image_token_start])\n                        cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))\n                        cur_labels = cur_labels[image_token_start+1:]\n                cur_image_idx += 1\n                if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):\n                    cur_input_ids = cur_input_ids[image_token_start+2:]\n                else:\n                    cur_input_ids = cur_input_ids[image_token_start+1:]\n                image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]\n            if cur_input_ids.numel() > 0:\n                if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):\n                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids))\n                else:\n                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids))\n                if labels is not None:\n                    cur_new_labels.append(cur_labels)\n            cur_new_input_embeds = [x.to(device=self.device) for x in cur_new_input_embeds]\n            cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0)\n            new_input_embeds.append(cur_new_input_embeds)\n            if labels is not None:\n                cur_new_labels = torch.cat(cur_new_labels, dim=0)\n                new_labels.append(cur_new_labels)\n\n        if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds):\n            max_len = max(x.shape[0] for x in new_input_embeds)\n\n            new_input_embeds_align = []\n            for cur_new_embed in new_input_embeds:\n                cur_new_embed = torch.cat((cur_new_embed, torch.zeros((max_len - 
cur_new_embed.shape[0], cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0)\n                new_input_embeds_align.append(cur_new_embed)\n            new_input_embeds = torch.stack(new_input_embeds_align, dim=0)\n\n            if labels is not None:\n                new_labels_align = []\n                _new_labels = new_labels\n                for cur_new_label in new_labels:\n                    cur_new_label = torch.cat((cur_new_label, torch.full((max_len - cur_new_label.shape[0],), IGNORE_INDEX, dtype=cur_new_label.dtype, device=cur_new_label.device)), dim=0)\n                    new_labels_align.append(cur_new_label)\n                new_labels = torch.stack(new_labels_align, dim=0)\n\n            if attention_mask is not None:\n                new_attention_mask = []\n                for cur_attention_mask, cur_new_labels, cur_new_labels_align in zip(attention_mask, _new_labels, new_labels):\n                    new_attn_mask_pad_left = torch.full((cur_new_labels.shape[0] - labels.shape[1],), True, dtype=attention_mask.dtype, device=attention_mask.device)\n                    new_attn_mask_pad_right = torch.full((cur_new_labels_align.shape[0] - cur_new_labels.shape[0],), False, dtype=attention_mask.dtype, device=attention_mask.device)\n                    cur_new_attention_mask = torch.cat((new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right), dim=0)\n                    new_attention_mask.append(cur_new_attention_mask)\n                attention_mask = torch.stack(new_attention_mask, dim=0)\n                assert attention_mask.shape == new_labels.shape\n        else:\n            new_input_embeds = torch.stack(new_input_embeds, dim=0)\n            if labels is not None:\n                new_labels  = torch.stack(new_labels, dim=0)\n\n            if attention_mask is not None:\n                new_attn_mask_pad_left = torch.full((attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1]), 
True, dtype=attention_mask.dtype, device=attention_mask.device)\n                attention_mask = torch.cat((new_attn_mask_pad_left, attention_mask), dim=1)\n                assert attention_mask.shape == new_input_embeds.shape[:2]\n\n        return None, attention_mask, past_key_values, new_input_embeds, new_labels\n\n    def initialize_vision_tokenizer(self, model_args, tokenizer):\n        if model_args.mm_use_im_patch_token:\n            tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)\n            self.resize_token_embeddings(len(tokenizer))\n\n        if model_args.mm_use_im_start_end:\n            num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, grounding_start, grounding_end, SEG_TOKEN], special_tokens=True)\n            self.resize_token_embeddings(len(tokenizer))\n\n            if num_new_tokens > 0:\n                input_embeddings = self.get_input_embeddings().weight.data\n                output_embeddings = self.get_output_embeddings().weight.data\n\n                input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(\n                    dim=0, keepdim=True)\n                output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(\n                    dim=0, keepdim=True)\n\n                input_embeddings[-num_new_tokens:] = input_embeddings_avg\n                output_embeddings[-num_new_tokens:] = output_embeddings_avg\n\n            if model_args.tune_mm_mlp_adapter:\n                self.orig_embeds_params = [self.get_input_embeddings().weight.data.clone().cuda(),\n                                                  self.get_output_embeddings().weight.data.clone().cuda()]\n\n                for p in self.get_input_embeddings().parameters():\n                    p.requires_grad = True\n                for p in self.get_output_embeddings().parameters():\n                    p.requires_grad = True\n\n            if model_args.pretrain_mm_mlp_adapter:\n                
mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')\n                embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']\n                assert num_new_tokens == 2\n                if input_embeddings.shape == embed_tokens_weight.shape:\n                    input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]\n                elif embed_tokens_weight.shape[0] == num_new_tokens:\n                    input_embeddings[-num_new_tokens:] = embed_tokens_weight\n                else:\n                    raise ValueError(f\"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.\")\n        elif model_args.mm_use_im_patch_token:\n            if model_args.tune_mm_mlp_adapter:\n                for p in self.get_input_embeddings().parameters():\n                    p.requires_grad = False\n                for p in self.get_output_embeddings().parameters():\n                    p.requires_grad = False\n        else:\n            # import pdb; pdb.set_trace()\n            num_new_tokens = tokenizer.add_tokens([grounding_start, grounding_end, SEG_TOKEN], special_tokens=True)\n            inits=['[',']','.']\n            nums=[tokenizer.encode(init)[1] for init in inits]\n\n            self.resize_token_embeddings(len(tokenizer))\n\n            if num_new_tokens > 0:\n                input_embeddings = self.get_input_embeddings().weight.data\n                output_embeddings = self.get_output_embeddings().weight.data\n                #\n                input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(\n                    dim=0, keepdim=True)\n                output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(\n                    dim=0, keepdim=True)\n                #\n                input_embeddings[-num_new_tokens:] = input_embeddings_avg\n                
output_embeddings[-num_new_tokens:] = output_embeddings_avg\n\n            if model_args.tune_mm_mlp_adapter:\n                self.orig_embeds_params = [self.get_input_embeddings().weight.data.clone().cuda(),\n                                                  self.get_output_embeddings().weight.data.clone().cuda()]\n\n                for p in self.get_input_embeddings().parameters():\n                    p.requires_grad = True\n                for p in self.get_output_embeddings().parameters():\n                    p.requires_grad = True\n\n    def initialize_seg_modules(self, cfg):\n        seg_model = BaseModel(cfg, build_model(cfg))\n        seg_model = seg_model.from_pretrained(cfg.MODEL.WEIGHTS)\n        self.seg_model = seg_model\n\n    def freeze_seg_modules(self):\n        for p in self.seg_model.parameters():\n            p.requires_grad = False\n\n\nclass LlavaMetaForCausalLM_gd_interactive(ABC):\n\n    @abstractmethod\n    def get_model(self):\n        pass\n\n    def get_vision_tower(self):\n        return self.get_model().get_vision_tower()\n\n    def encode_images(self, images):\n        image_features = self.get_model().get_vision_tower()(images)\n        image_features = self.get_model().mm_projector(image_features.to(self.get_model().mm_projector.state_dict()[\"weight\"].dtype))\n        return image_features\n\n    def prepare_inputs_labels_for_multimodal(\n        self, input_ids, attention_mask, past_key_values, labels, images,obj_feats=None,num_it=0\n    ):\n        vision_tower = self.get_vision_tower()\n        if vision_tower is None or images is None or input_ids.shape[1] == 1:\n            if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[1] == 1:\n                attention_mask = torch.ones((attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1), dtype=attention_mask.dtype, device=attention_mask.device)\n            return input_ids, attention_mask, past_key_values, None, 
labels\n\n        if type(images) is list or images.ndim == 5:\n            concat_images = torch.cat([image for image in images], dim=0)\n            image_features = self.encode_images(concat_images)\n            split_sizes = [image.shape[0] for image in images]\n            image_features = torch.split(image_features, split_sizes, dim=0)\n            image_features = [x.flatten(0, 1) for x in image_features]\n        else:\n            image_features = self.encode_images(images)\n\n        new_input_embeds = []\n        new_labels = [] if labels is not None else None\n        cur_image_idx = 0\n        orig_embeds_params = getattr(self, 'orig_embeds_params', None)\n        if orig_embeds_params is not None:\n            orig_embeds_params_in = orig_embeds_params[0]\n            orig_embeds_params_out = orig_embeds_params[1]\n            # st_inp=self.tokenizer.encode(grounding_start)[1]\n            # st_out=self.tokenizer.encode(grounding_start)[1]\n            with torch.no_grad():\n                self.get_input_embeddings().weight[:-3] = orig_embeds_params_in[:-3].data\n                # if self.tokenizer.decode([len(self.tokenizer)-1])=='<seg>':\n                self.get_output_embeddings().weight[:-3] = orig_embeds_params_out[:-3].data\n\n        for batch_idx, cur_input_ids in enumerate(input_ids):\n            if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0:\n                # multimodal LLM, but the current sample is not multimodal\n                cur_input_embeds = self.get_model().embed_tokens(cur_input_ids)\n                cur_input_embeds = cur_input_embeds + (0. 
* self.get_model().mm_projector(vision_tower.dummy_feature)).sum()\n                new_input_embeds.append(cur_input_embeds)\n                if labels is not None:\n                    new_labels.append(labels[batch_idx])\n                cur_image_idx += 1\n                continue\n            image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]\n            cur_new_input_embeds = []\n            if labels is not None:\n                cur_labels = labels[batch_idx]\n                cur_new_labels = []\n                assert cur_labels.shape == cur_input_ids.shape\n            while image_token_indices.numel() > 0:\n                cur_image_features = image_features[cur_image_idx]\n                image_token_start = image_token_indices[0]\n                if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):\n                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start-1]))\n                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_start-1:image_token_start]))\n                    cur_new_input_embeds.append(cur_image_features)\n                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_start+1:image_token_start+2]))\n                    if labels is not None:\n                        cur_new_labels.append(cur_labels[:image_token_start])\n                        cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))\n                        cur_new_labels.append(cur_labels[image_token_start:image_token_start+1])\n                        cur_labels = cur_labels[image_token_start+2:]\n                else:\n                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start]))\n                    
cur_new_input_embeds.append(cur_image_features)\n                    if labels is not None:\n                        cur_new_labels.append(cur_labels[:image_token_start])\n                        cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))\n                        cur_labels = cur_labels[image_token_start+1:]\n                cur_image_idx += 1\n                if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):\n                    cur_input_ids = cur_input_ids[image_token_start+2:]\n                else:\n                    cur_input_ids = cur_input_ids[image_token_start+1:]\n                image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]\n            if cur_input_ids.numel() > 0:\n                if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):\n                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids))\n                else:\n                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids))\n                if batch_idx >= len(input_ids) - num_it:\n                    obj_idx = cur_input_ids == 1273\n                    idx_in_inter=batch_idx-(len(input_ids)-num_it)\n                    cur_new_input_embeds[-1][obj_idx] = obj_feats[idx_in_inter].to(cur_new_input_embeds[-1].dtype)\n                if labels is not None:\n                    cur_labels[cur_labels==1273]=IGNORE_INDEX\n                    cur_new_labels.append(cur_labels)\n            cur_new_input_embeds = [x.to(device=self.device) for x in cur_new_input_embeds]\n            cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0)\n            new_input_embeds.append(cur_new_input_embeds)\n            if labels is not None:\n                cur_new_labels = torch.cat(cur_new_labels, dim=0)\n                
new_labels.append(cur_new_labels)\n\n        if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds):\n            max_len = max(x.shape[0] for x in new_input_embeds)\n\n            new_input_embeds_align = []\n            for cur_new_embed in new_input_embeds:\n                cur_new_embed = torch.cat((cur_new_embed, torch.zeros((max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0)\n                new_input_embeds_align.append(cur_new_embed)\n            new_input_embeds = torch.stack(new_input_embeds_align, dim=0)\n\n            if labels is not None:\n                new_labels_align = []\n                _new_labels = new_labels\n                for cur_new_label in new_labels:\n                    cur_new_label = torch.cat((cur_new_label, torch.full((max_len - cur_new_label.shape[0],), IGNORE_INDEX, dtype=cur_new_label.dtype, device=cur_new_label.device)), dim=0)\n                    new_labels_align.append(cur_new_label)\n                new_labels = torch.stack(new_labels_align, dim=0)\n\n            if attention_mask is not None:\n                new_attention_mask = []\n                for cur_attention_mask, cur_new_labels, cur_new_labels_align in zip(attention_mask, _new_labels, new_labels):\n                    new_attn_mask_pad_left = torch.full((cur_new_labels.shape[0] - labels.shape[1],), True, dtype=attention_mask.dtype, device=attention_mask.device)\n                    new_attn_mask_pad_right = torch.full((cur_new_labels_align.shape[0] - cur_new_labels.shape[0],), False, dtype=attention_mask.dtype, device=attention_mask.device)\n                    cur_new_attention_mask = torch.cat((new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right), dim=0)\n                    new_attention_mask.append(cur_new_attention_mask)\n                attention_mask = torch.stack(new_attention_mask, dim=0)\n                assert attention_mask.shape == 
new_labels.shape\n        else:\n            new_input_embeds = torch.stack(new_input_embeds, dim=0)\n            if labels is not None:\n                new_labels  = torch.stack(new_labels, dim=0)\n\n            if attention_mask is not None:\n                new_attn_mask_pad_left = torch.full((attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1]), True, dtype=attention_mask.dtype, device=attention_mask.device)\n                attention_mask = torch.cat((new_attn_mask_pad_left, attention_mask), dim=1)\n                assert attention_mask.shape == new_input_embeds.shape[:2]\n\n        return None, attention_mask, past_key_values, new_input_embeds, new_labels\n    def prepare_inputs_labels_for_multimodal_NoInter(\n        self, input_ids, attention_mask, past_key_values, labels, images\n    ):\n        vision_tower = self.get_vision_tower()\n        if vision_tower is None or images is None or input_ids.shape[1] == 1:\n            if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[1] == 1:\n                attention_mask = torch.ones((attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1), dtype=attention_mask.dtype, device=attention_mask.device)\n            return input_ids, attention_mask, past_key_values, None, labels\n\n        if type(images) is list or images.ndim == 5:\n            concat_images = torch.cat([image for image in images], dim=0)\n            image_features = self.encode_images(concat_images)\n            split_sizes = [image.shape[0] for image in images]\n            image_features = torch.split(image_features, split_sizes, dim=0)\n            image_features = [x.flatten(0, 1) for x in image_features]\n        else:\n            image_features = self.encode_images(images)\n\n        new_input_embeds = []\n        new_labels = [] if labels is not None else None\n        cur_image_idx = 0\n        orig_embeds_params = getattr(self, 'orig_embeds_params', None)\n   
     if orig_embeds_params is not None:\n            orig_embeds_params_in = orig_embeds_params[0]\n            orig_embeds_params_out = orig_embeds_params[1]\n            # st_inp=self.tokenizer.encode(grounding_start)[1]\n            # st_out=self.tokenizer.encode(grounding_start)[1]\n            with torch.no_grad():\n                self.get_input_embeddings().weight[:-3] = orig_embeds_params_in[:-3].data\n                # if self.tokenizer.decode([len(self.tokenizer)-1])=='<seg>':\n                self.get_output_embeddings().weight[:-3] = orig_embeds_params_out[:-3].data\n\n        for batch_idx, cur_input_ids in enumerate(input_ids):\n            if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0:\n                # multimodal LLM, but the current sample is not multimodal\n                cur_input_embeds = self.get_model().embed_tokens(cur_input_ids)\n                cur_input_embeds = cur_input_embeds + (0. * self.get_model().mm_projector(vision_tower.dummy_feature)).sum()\n                new_input_embeds.append(cur_input_embeds)\n                if labels is not None:\n                    new_labels.append(labels[batch_idx])\n                cur_image_idx += 1\n                continue\n            image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]\n            cur_new_input_embeds = []\n            if labels is not None:\n                cur_labels = labels[batch_idx]\n                cur_new_labels = []\n                assert cur_labels.shape == cur_input_ids.shape\n            while image_token_indices.numel() > 0:\n                cur_image_features = image_features[cur_image_idx]\n                image_token_start = image_token_indices[0]\n                if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):\n                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start-1]))\n                    
cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_start-1:image_token_start]))\n                    cur_new_input_embeds.append(cur_image_features)\n                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_start+1:image_token_start+2]))\n                    if labels is not None:\n                        cur_new_labels.append(cur_labels[:image_token_start])\n                        cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))\n                        cur_new_labels.append(cur_labels[image_token_start:image_token_start+1])\n                        cur_labels = cur_labels[image_token_start+2:]\n                else:\n                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start]))\n                    cur_new_input_embeds.append(cur_image_features)\n                    if labels is not None:\n                        cur_new_labels.append(cur_labels[:image_token_start])\n                        cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))\n                        cur_labels = cur_labels[image_token_start+1:]\n                cur_image_idx += 1\n                if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):\n                    cur_input_ids = cur_input_ids[image_token_start+2:]\n                else:\n                    cur_input_ids = cur_input_ids[image_token_start+1:]\n                image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]\n            if cur_input_ids.numel() > 0:\n                if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):\n                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids))\n       
         else:\n                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids))\n                if labels is not None:\n                    cur_new_labels.append(cur_labels)\n            cur_new_input_embeds = [x.to(device=self.device) for x in cur_new_input_embeds]\n            cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0)\n            new_input_embeds.append(cur_new_input_embeds)\n            if labels is not None:\n                cur_new_labels = torch.cat(cur_new_labels, dim=0)\n                new_labels.append(cur_new_labels)\n\n        if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds):\n            max_len = max(x.shape[0] for x in new_input_embeds)\n\n            new_input_embeds_align = []\n            for cur_new_embed in new_input_embeds:\n                cur_new_embed = torch.cat((cur_new_embed, torch.zeros((max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0)\n                new_input_embeds_align.append(cur_new_embed)\n            new_input_embeds = torch.stack(new_input_embeds_align, dim=0)\n\n            if labels is not None:\n                new_labels_align = []\n                _new_labels = new_labels\n                for cur_new_label in new_labels:\n                    cur_new_label = torch.cat((cur_new_label, torch.full((max_len - cur_new_label.shape[0],), IGNORE_INDEX, dtype=cur_new_label.dtype, device=cur_new_label.device)), dim=0)\n                    new_labels_align.append(cur_new_label)\n                new_labels = torch.stack(new_labels_align, dim=0)\n\n            if attention_mask is not None:\n                new_attention_mask = []\n                for cur_attention_mask, cur_new_labels, cur_new_labels_align in zip(attention_mask, _new_labels, new_labels):\n                    new_attn_mask_pad_left = torch.full((cur_new_labels.shape[0] - labels.shape[1],), True, dtype=attention_mask.dtype, 
device=attention_mask.device)\n                    new_attn_mask_pad_right = torch.full((cur_new_labels_align.shape[0] - cur_new_labels.shape[0],), False, dtype=attention_mask.dtype, device=attention_mask.device)\n                    cur_new_attention_mask = torch.cat((new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right), dim=0)\n                    new_attention_mask.append(cur_new_attention_mask)\n                attention_mask = torch.stack(new_attention_mask, dim=0)\n                assert attention_mask.shape == new_labels.shape\n        else:\n            new_input_embeds = torch.stack(new_input_embeds, dim=0)\n            if labels is not None:\n                new_labels  = torch.stack(new_labels, dim=0)\n\n            if attention_mask is not None:\n                new_attn_mask_pad_left = torch.full((attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1]), True, dtype=attention_mask.dtype, device=attention_mask.device)\n                attention_mask = torch.cat((new_attn_mask_pad_left, attention_mask), dim=1)\n                assert attention_mask.shape == new_input_embeds.shape[:2]\n\n        return None, attention_mask, past_key_values, new_input_embeds, new_labels\n\n    def initialize_vision_tokenizer(self, model_args, tokenizer):\n        if model_args.mm_use_im_patch_token:\n            tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)\n            self.resize_token_embeddings(len(tokenizer))\n\n        if model_args.mm_use_im_start_end:\n            num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, grounding_start, grounding_end, SEG_TOKEN], special_tokens=True)\n            self.resize_token_embeddings(len(tokenizer))\n\n            if num_new_tokens > 0:\n                input_embeddings = self.get_input_embeddings().weight.data\n                output_embeddings = self.get_output_embeddings().weight.data\n\n                input_embeddings_avg = 
input_embeddings[:-num_new_tokens].mean(\n                    dim=0, keepdim=True)\n                output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(\n                    dim=0, keepdim=True)\n\n                input_embeddings[-num_new_tokens:] = input_embeddings_avg\n                output_embeddings[-num_new_tokens:] = output_embeddings_avg\n\n            if model_args.tune_mm_mlp_adapter:\n                self.orig_embeds_params = [self.get_input_embeddings().weight.data.clone().cuda(),\n                                                  self.get_output_embeddings().weight.data.clone().cuda()]\n\n                for p in self.get_input_embeddings().parameters():\n                    p.requires_grad = True\n                for p in self.get_output_embeddings().parameters():\n                    p.requires_grad = True\n\n            if model_args.pretrain_mm_mlp_adapter:\n                mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')\n                embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']\n                assert num_new_tokens == 2\n                if input_embeddings.shape == embed_tokens_weight.shape:\n                    input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]\n                elif embed_tokens_weight.shape[0] == num_new_tokens:\n                    input_embeddings[-num_new_tokens:] = embed_tokens_weight\n                else:\n                    raise ValueError(f\"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. 
Numer of new tokens: {num_new_tokens}.\")\n        elif model_args.mm_use_im_patch_token:\n            if model_args.tune_mm_mlp_adapter:\n                for p in self.get_input_embeddings().parameters():\n                    p.requires_grad = False\n                for p in self.get_output_embeddings().parameters():\n                    p.requires_grad = False\n        else:\n            # import pdb; pdb.set_trace()\n            num_new_tokens = tokenizer.add_tokens([grounding_start, grounding_end, SEG_TOKEN], special_tokens=True)\n            inits=['[',']','.']\n            nums=[tokenizer.encode(init)[1] for init in inits]\n            # inp_embs = self.get_input_embeddings().weight.data[nums]\n            # out_embs = self.get_output_embeddings().weight.data[nums]\n            self.resize_token_embeddings(len(tokenizer))\n\n            if num_new_tokens > 0:\n                input_embeddings = self.get_input_embeddings().weight.data\n                output_embeddings = self.get_output_embeddings().weight.data\n                #\n                input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(\n                    dim=0, keepdim=True)\n                output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(\n                    dim=0, keepdim=True)\n                #\n                input_embeddings[-num_new_tokens:] = input_embeddings_avg\n                output_embeddings[-num_new_tokens:] = output_embeddings_avg\n\n            if model_args.tune_mm_mlp_adapter:\n                self.orig_embeds_params = [self.get_input_embeddings().weight.data.clone().cuda(),\n                                                  self.get_output_embeddings().weight.data.clone().cuda()]\n\n                for p in self.get_input_embeddings().parameters():\n                    p.requires_grad = True\n                for p in self.get_output_embeddings().parameters():\n                    p.requires_grad = True\n\n    def initialize_seg_modules(self, 
cfg):\n        seg_model = BaseModel(cfg, build_model(cfg))\n        seg_model = seg_model.from_pretrained(cfg.MODEL.WEIGHTS)\n        self.seg_model = seg_model\n\n    def initialize_interactive_modules(self, cfg):\n        from .semsam.BaseModel import BaseModel as SemSamBaseModel\n        from .semsam import build_model as build_semsam_model\n\n        seg_model = SemSamBaseModel(cfg, build_semsam_model(cfg))\n        if not (cfg.MODEL.WEIGHTS == \"None\"):\n            seg_model = seg_model.from_pretrained(cfg.MODEL.WEIGHTS)\n        self.interactive_model = seg_model\n    def freeze_seg_modules(self):\n        for p in self.seg_model.parameters():\n            p.requires_grad = False"
  },
  {
    "path": "llava/model/make_delta.py",
    "content": "\"\"\"\nUsage:\npython3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta\n\"\"\"\nimport argparse\n\nimport torch\nfrom tqdm import tqdm\nfrom transformers import AutoTokenizer, AutoModelForCausalLM\nfrom llava.model.utils import auto_upgrade\n\n\ndef make_delta(base_model_path, target_model_path, delta_path, hub_repo_id):\n    print(\"Loading base model\")\n    base = AutoModelForCausalLM.from_pretrained(\n        base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)\n\n    print(\"Loading target model\")\n    auto_upgrade(target_model_path)\n    target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)\n\n    print(\"Calculating delta\")\n    for name, param in tqdm(target.state_dict().items(), desc=\"Calculating delta\"):\n        if name not in base.state_dict():\n            assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'\n            continue\n        if param.data.shape == base.state_dict()[name].shape:\n            param.data -= base.state_dict()[name]\n        else:\n            assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'\n            bparam = base.state_dict()[name]\n            param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam\n\n    print(\"Saving delta\")\n    if hub_repo_id:\n        kwargs = {\"push_to_hub\": True, \"repo_id\": hub_repo_id}\n    else:\n        kwargs = {}\n    target.save_pretrained(delta_path, **kwargs)\n    target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)\n    target_tokenizer.save_pretrained(delta_path, **kwargs)\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--base-model-path\", type=str, 
required=True)\n    parser.add_argument(\"--target-model-path\", type=str, required=True)\n    parser.add_argument(\"--delta-path\", type=str, required=True)\n    parser.add_argument(\"--hub-repo-id\", type=str, default=None)\n    args = parser.parse_args()\n\n    make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id)\n"
  },
  {
    "path": "llava/model/multimodal_encoder/builder.py",
    "content": "from .clip_encoder import CLIPVisionTower\n\n\ndef build_vision_tower(vision_tower_cfg, **kwargs):\n    vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))\n    if vision_tower.startswith(\"openai\") or vision_tower.startswith(\"laion\"):\n        return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)\n\n    raise ValueError(f'Unknown vision tower: {vision_tower}')\n"
  },
  {
    "path": "llava/model/multimodal_encoder/clip_encoder.py",
    "content": "import torch\nimport torch.nn as nn\n\nfrom transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig\n\n\nclass CLIPVisionTower(nn.Module):\n    def __init__(self, vision_tower, args, delay_load=False):\n        super().__init__()\n\n        self.is_loaded = False\n\n        self.vision_tower_name = vision_tower\n        self.select_layer = args.mm_vision_select_layer\n        self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')\n\n        if not delay_load:\n            self.load_model()\n        else:\n            self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)\n\n    def load_model(self):\n        self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)\n        self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name,cache_dir=\"/comp_robot/zhanghao/.cache/hugging_face/\")\n        self.vision_tower.requires_grad_(False)\n\n        self.is_loaded = True\n\n    def feature_select(self, image_forward_outs):\n        image_features = image_forward_outs.hidden_states[self.select_layer]\n        if self.select_feature == 'patch':\n            image_features = image_features[:, 1:]\n        elif self.select_feature == 'cls_patch':\n            image_features = image_features\n        else:\n            raise ValueError(f'Unexpected select feature: {self.select_feature}')\n        return image_features\n\n    @torch.no_grad()\n    def forward(self, images):\n        if type(images) is list:\n            image_features = []\n            for image in images:\n                image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)\n                image_feature = self.feature_select(image_forward_out).to(image.dtype)\n                image_features.append(image_feature)\n        else:\n            image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), 
output_hidden_states=True)\n            image_features = self.feature_select(image_forward_outs).to(images.dtype)\n\n        return image_features\n\n    @property\n    def dummy_feature(self):\n        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)\n\n    @property\n    def dtype(self):\n        return self.vision_tower.dtype\n\n    @property\n    def device(self):\n        return self.vision_tower.device\n\n    @property\n    def config(self):\n        if self.is_loaded:\n            return self.vision_tower.config\n        else:\n            return self.cfg_only\n\n    @property\n    def hidden_size(self):\n        return self.config.hidden_size\n\n    @property\n    def num_patches(self):\n        return (self.config.image_size // self.config.patch_size) ** 2\n"
  },
  {
    "path": "llava/model/openseed/BaseModel.py",
    "content": "import os\nimport logging\n\nimport torch\nimport torch.nn as nn\n\n# from utils.model import align_and_update_state_dicts\n\nlogger = logging.getLogger(__name__)\n\n\ndef align_and_update_state_dicts(model_state_dict, ckpt_state_dict):\n    model_keys = sorted(model_state_dict.keys())\n    ckpt_keys = sorted(ckpt_state_dict.keys())\n    result_dicts = {}\n    matched_log = []\n    unmatched_log = []\n    unloaded_log = []\n    for model_key in model_keys:\n        model_weight = model_state_dict[model_key]\n        if model_key in ckpt_keys:\n            ckpt_weight = ckpt_state_dict[model_key]\n            if model_weight.shape == ckpt_weight.shape:\n                result_dicts[model_key] = ckpt_weight\n                ckpt_keys.pop(ckpt_keys.index(model_key))\n                matched_log.append(\"Loaded {}, Model Shape: {} <-> Ckpt Shape: {}\".format(model_key, model_weight.shape,\n                                                                                          ckpt_weight.shape))\n            else:\n                unmatched_log.append(\n                    \"*UNMATCHED* {}, Model Shape: {} <-> Ckpt Shape: {}\".format(model_key, model_weight.shape,\n                                                                                ckpt_weight.shape))\n        else:\n            unloaded_log.append(\"*UNLOADED* {}, Model Shape: {}\".format(model_key, model_weight.shape))\n\n    # if is_main_process():\n    #     for info in matched_log:\n    #         logger.info(info)\n    #     for info in unloaded_log:\n    #         logger.warning(info)\n    #     for key in ckpt_keys:\n    #         logger.warning(\"$UNUSED$ {}, Ckpt Shape: {}\".format(key, ckpt_state_dict[key].shape))\n    #     for info in unmatched_log:\n    #         logger.warning(info)\n    return result_dicts\n\nclass BaseModel(nn.Module):\n    def __init__(self, opt, module: nn.Module):\n        super(BaseModel, self).__init__()\n        self.opt = opt\n        self.model = 
module\n\n    def forward(self, *inputs, **kwargs):\n        outputs = self.model(*inputs, **kwargs)\n        return outputs\n\n    def save_pretrained(self, save_dir):\n        torch.save(self.model.state_dict(), save_dir)\n\n    def from_pretrained(self, load_dir):\n        state_dict = torch.load(load_dir, map_location='cpu')\n        state_dict = align_and_update_state_dicts(self.model.state_dict(), state_dict)\n        self.model.load_state_dict(state_dict, strict=False)\n        return self"
  },
  {
    "path": "llava/model/openseed/__init__.py",
    "content": "from __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nfrom .architectures import build_model"
  },
  {
    "path": "llava/model/openseed/architectures/__init__.py",
    "content": "from .openseed_model import *\n# from .openseed_model_decouple_train import *\nfrom .build import build_model"
  },
  {
    "path": "llava/model/openseed/architectures/build.py",
    "content": "from .registry import model_entrypoints\nfrom .registry import is_model\n\ndef build_model(config, **kwargs):\n    model_name = config['MODEL']['NAME']\n\n    if not is_model(model_name):\n        raise ValueError(f'Unkown model: {model_name}')\n\n    return model_entrypoints(model_name)(config, **kwargs)"
  },
  {
    "path": "llava/model/openseed/architectures/openseed_model.py",
    "content": "# ------------------------------------------------------------------------\n# DINO\n# Copyright (c) 2023 IDEA. All Rights Reserved.\n# Licensed under the Apache License, Version 2.0 [see LICENSE for details]\n# ------------------------------------------------------------------------\n# Modified from MaskDINO https://github.com/IDEA-Research/MaskDINO by Feng Li and Hao Zhang.\n# ------------------------------------------------------------------------\nfrom typing import Tuple\n\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nfrom .registry import register_model\nfrom ..utils import configurable, box_ops #, get_class_names\nfrom ..backbone import build_backbone, Backbone\nfrom ..body import build_openseed_head\nfrom ..modules import sem_seg_postprocess, HungarianMatcher, SetCriterion\n\nfrom detectron2.structures import Boxes, ImageList, Instances, BitMasks\nfrom detectron2.utils.memory import retry_if_cuda_oom\nfrom detectron2.data import MetadataCatalog\nimport random\n\nclass OpenSeeD(nn.Module):\n    \"\"\"\n    Main class for mask classification semantic segmentation architectures.\n    \"\"\"\n\n    @configurable\n    def __init__(\n        self,\n        *,\n        backbone: Backbone,\n        sem_seg_head: nn.Module,\n        criterion: nn.Module,\n        num_queries: int,\n        object_mask_threshold: float,\n        overlap_threshold: float,\n        metadata,\n        size_divisibility: int,\n        sem_seg_postprocess_before_inference: bool,\n        pixel_mean: Tuple[float],\n        pixel_std: Tuple[float],\n        # inference\n        semantic_on: bool,\n        panoptic_on: bool,\n        instance_on: bool,\n        test_topk_per_image: int,\n        data_loader: str,\n        pano_temp: float,\n        focus_on_box: bool = False,\n        transform_eval: bool = False,\n        semantic_ce_loss: bool = False,\n        train_dataset_name: str,\n        background: bool,\n        coco_on=True,\n        
coco_mask_on=True,\n        o365_on=True,\n        merge_class=False,\n        coco_only=False,\n        detach_seg=False,\n        eval_train=False,\n    ):\n        \"\"\"\n        Args:\n            backbone: a backbone module, must follow detectron2's backbone interface\n            sem_seg_head: a module that predicts semantic segmentation from backbone features\n            criterion: a module that defines the loss\n            num_queries: int, number of queries\n            object_mask_threshold: float, threshold to filter query based on classification score\n                for panoptic segmentation inference\n            overlap_threshold: overlap threshold used in general inference for panoptic segmentation\n            metadata: dataset meta, get `thing` and `stuff` category names for panoptic\n                segmentation inference\n            size_divisibility: Some backbones require the input height and width to be divisible by a\n                specific integer. We can use this to override such requirement.\n            sem_seg_postprocess_before_inference: whether to resize the prediction back\n                to original input size before semantic segmentation inference or after.\n                For high-resolution dataset like Mapillary, resizing predictions before\n                inference will cause OOM error.\n            pixel_mean, pixel_std: list or tuple with #channels element, representing\n                the per-channel mean and std to be used to normalize the input image\n            semantic_on: bool, whether to output semantic segmentation prediction\n            instance_on: bool, whether to output instance segmentation prediction\n            panoptic_on: bool, whether to output panoptic segmentation prediction\n            test_topk_per_image: int, instance segmentation parameter, keep topk instances per image\n        \"\"\"\n        super().__init__()\n        self.backbone = backbone\n        self.pano_temp = pano_temp\n    
    self.sem_seg_head = sem_seg_head\n        self.criterion = criterion\n        self.num_queries = num_queries\n        self.overlap_threshold = overlap_threshold\n        self.object_mask_threshold = object_mask_threshold\n        self.metadata = metadata\n        if size_divisibility < 0:\n            # use backbone size_divisibility if not set\n            size_divisibility = self.backbone.size_divisibility\n        self.size_divisibility = size_divisibility\n        self.sem_seg_postprocess_before_inference = sem_seg_postprocess_before_inference\n        self.register_buffer(\"pixel_mean\", torch.Tensor(pixel_mean).view(-1, 1, 1), False)\n        self.register_buffer(\"pixel_std\", torch.Tensor(pixel_std).view(-1, 1, 1), False)\n        self.detach_seg=detach_seg\n        self.eval_train=eval_train\n        # additional args\n        self.semantic_on = semantic_on\n        self.instance_on = instance_on\n        self.panoptic_on = panoptic_on\n        self.test_topk_per_image = test_topk_per_image\n\n        self.data_loader = data_loader\n        self.focus_on_box = focus_on_box\n        self.transform_eval = transform_eval\n        self.semantic_ce_loss = semantic_ce_loss\n\n        self.train_class_names = dict()\n        self.train_dataset_name = train_dataset_name\n        self.coco_mask_on = coco_mask_on\n        self.task_switch = {'coco': coco_on, 'o365': o365_on}\n        self.num_correct_gd=0\n        self.num_total_gd=0\n        self.num_correct_ref = 0\n        self.num_total_ref = 0\n        self.num_correct_coco = 0\n        self.num_total_coco = 0\n        self.coco_only=coco_only\n        self.loss_dict=None\n        self.mean_iou=0.0\n        ########\n        self.total_union=0.0\n        self.total_intersection=0.0\n        # self.cIoU=0.0\n        print(\"self.task_switch \", self.task_switch)\n        # HACK for only two datasets for seg and det\n        if not self.semantic_on:\n            assert 
self.sem_seg_postprocess_before_inference\n\n    @classmethod\n    def from_config(cls, cfg):\n        enc_cfg = cfg['MODEL']['ENCODER']\n        dec_cfg = cfg['MODEL']['DECODER']\n\n        # Loss parameters:\n        deep_supervision = dec_cfg['DEEP_SUPERVISION']\n        no_object_weight = dec_cfg['NO_OBJECT_WEIGHT']\n\n        # loss weights\n        class_weight = dec_cfg['CLASS_WEIGHT']\n        cost_class_weight = dec_cfg['COST_CLASS_WEIGHT']\n        cost_dice_weight = dec_cfg['COST_DICE_WEIGHT']\n        dice_weight = dec_cfg['DICE_WEIGHT']\n        cost_mask_weight = dec_cfg['COST_MASK_WEIGHT']\n        mask_weight = dec_cfg['MASK_WEIGHT']\n        cost_box_weight = dec_cfg['COST_BOX_WEIGHT']\n        box_weight = dec_cfg['BOX_WEIGHT']\n        cost_giou_weight = dec_cfg['COST_GIOU_WEIGHT']\n        giou_weight = dec_cfg['GIOU_WEIGHT']\n\n        # building matcher\n        matcher = HungarianMatcher(\n            cost_class=cost_class_weight,\n            cost_mask=cost_mask_weight,\n            cost_dice=cost_dice_weight,\n            cost_box=cost_box_weight,\n            cost_giou=cost_giou_weight,\n            num_points=dec_cfg['TRAIN_NUM_POINTS'],\n        )\n\n        # MaskDINO losses and weight_dict\n        weight_dict = {\"loss_mask_cls_0\": class_weight}\n        weight_dict.update({\"loss_mask_bce_0\": mask_weight, \"loss_mask_dice_0\": dice_weight})\n        weight_dict.update({\"loss_bbox_0\":box_weight,\"loss_giou_0\":giou_weight})\n        # two stage is the query selection scheme\n        if dec_cfg['TWO_STAGE']:\n            interm_weight_dict = {}\n            interm_weight_dict.update({k + f'_interm': v for k, v in weight_dict.items()})\n            weight_dict.update(interm_weight_dict)\n        # denoising training\n        dn = dec_cfg['DN']\n        dn = 'no'\n        # TODO hack for dn lable loss\n        if dn == \"standard\":\n            weight_dict.update({k + f\"_dn\": v for k, v in weight_dict.items() if k!=\"loss_mask\" 
and k!=\"loss_dice\" })\n            dn_losses=[\"dn_labels\", \"boxes\"]\n        elif dn == \"seg\":\n            weight_dict.update({k + f\"_dn\": v for k, v in weight_dict.items()})\n            dn_losses=[\"dn_labels\", \"masks\", \"boxes\"]\n        else:\n            dn_losses=[]\n        if deep_supervision:\n            dec_layers = dec_cfg['DEC_LAYERS']\n            aux_weight_dict = {}\n            for i in range(dec_layers):\n                aux_weight_dict.update({k.replace('_0', '_{}'.format(i+1)): v for k, v in weight_dict.items()})\n            weight_dict.update(aux_weight_dict)\n        if dec_cfg['BOX']:\n            losses = [\"labels\", \"masks\",\"boxes\"]\n        else:\n            losses = [\"labels\", \"masks\"]\n\n        # update task switch\n        task_switch = {}\n        task_switch.update({'bbox': dec_cfg.get('DETECTION', True), 'mask': dec_cfg.get('MASK', True)})\n        top_x_layers = {'mask': dec_cfg.get('TOP_MASK_LAYERS', 10),\n                        'box': dec_cfg.get('TOP_DETECTION_LAYERS', 10)}\n        weight_multiplier= dec_cfg.get('WEIGHT_MULTIPLIER', 1.0)\n        weight_dict={k:v*weight_multiplier for k,v in weight_dict.items()}\n        # building criterion\n        criterion = SetCriterion(\n            enc_cfg['NUM_CLASSES'],\n            matcher=matcher,\n            weight_dict=weight_dict,\n            top_x_layers=top_x_layers,\n            eos_coef=no_object_weight,\n            losses=losses,\n            num_points=dec_cfg['TRAIN_NUM_POINTS'],\n            oversample_ratio=dec_cfg['OVERSAMPLE_RATIO'],\n            importance_sample_ratio=dec_cfg['IMPORTANCE_SAMPLE_RATIO'],\n            grounding_weight=None,\n            dn=dec_cfg['DN'],\n            dn_losses=dn_losses,\n            panoptic_on=dec_cfg['PANO_BOX_LOSS'],\n            semantic_ce_loss=dec_cfg['TEST']['SEMANTIC_ON'] and dec_cfg['SEMANTIC_CE_LOSS'] and not dec_cfg['TEST']['PANOPTIC_ON'],\n        )\n\n        # build model\n        extra = 
{'task_switch': task_switch}\n        backbone = build_backbone(cfg)\n        # lang_encoder = build_language_encoder(cfg)\n        sem_seg_head = build_openseed_head(cfg, backbone.output_shape(), None, extra=extra)\n\n        return {\n            \"backbone\": backbone,\n            \"sem_seg_head\": sem_seg_head,\n            \"criterion\": criterion,\n            \"num_queries\": dec_cfg['NUM_OBJECT_QUERIES'],\n            \"object_mask_threshold\": dec_cfg['TEST']['OBJECT_MASK_THRESHOLD'],\n            \"overlap_threshold\": dec_cfg['TEST']['OVERLAP_THRESHOLD'],\n            \"metadata\": MetadataCatalog.get(cfg['DATASETS']['TRAIN'][0]),\n            \"size_divisibility\": dec_cfg['SIZE_DIVISIBILITY'],\n            \"sem_seg_postprocess_before_inference\": (\n                dec_cfg['TEST']['SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE']\n                or dec_cfg['TEST']['PANOPTIC_ON']\n                or dec_cfg['TEST']['INSTANCE_ON']\n            ),\n            \"pixel_mean\": cfg['INPUT']['PIXEL_MEAN'],\n            \"pixel_std\": cfg['INPUT']['PIXEL_STD'],\n            # inference\n            \"semantic_on\": dec_cfg['TEST']['SEMANTIC_ON'],\n            \"instance_on\": dec_cfg['TEST']['INSTANCE_ON'],\n            \"panoptic_on\": dec_cfg['TEST']['PANOPTIC_ON'],\n            \"test_topk_per_image\": cfg['COCO']['TEST']['DETECTIONS_PER_IMAGE'],\n            \"data_loader\": None,\n            \"focus_on_box\": cfg['MODEL']['DECODER']['TEST']['TEST_FOUCUS_ON_BOX'],\n            \"transform_eval\": cfg['MODEL']['DECODER']['TEST']['PANO_TRANSFORM_EVAL'],\n            \"pano_temp\": cfg['MODEL']['DECODER']['TEST']['PANO_TEMPERATURE'],\n            \"semantic_ce_loss\": cfg['MODEL']['DECODER']['TEST']['SEMANTIC_ON'] and cfg['MODEL']['DECODER']['SEMANTIC_CE_LOSS'] and not cfg['MODEL']['DECODER']['TEST']['PANOPTIC_ON'],\n            \"train_dataset_name\": cfg['DATASETS']['TRAIN'], # HACK for only two training set\n            \"background\": 
cfg['MODEL'].get('BACKGROUND', True),\n            \"coco_on\": dec_cfg.get('COCO', True),\n            \"coco_mask_on\": dec_cfg.get('COCO_MASK', True),\n            \"o365_on\": dec_cfg.get('O365', True),\n            \"coco_only\": dec_cfg.get('COCO_ONLY', False),\n            \"detach_seg\": cfg.get('detach_seg', False),\n            \"eval_train\": cfg.get('eval_train', False),\n        }\n\n    @property\n    def device(self):\n        return self.pixel_mean.device\n\n    def forward(self, batched_inputs, inference_task='seg'):\n        # import ipdb; ipdb.set_trace()\n        # print(\"Num images per batch:\",len(batched_inputs['flickr']))\n        if self.training:\n            losses = {}\n            losses_ = dict()\n            if 'flickr' in batched_inputs and not self.coco_only:\n                self.criterion.conversation=False\n                losses_flickr = self.forward_seg(batched_inputs['flickr'], task='seg',default_text_embeddings=batched_inputs['flickr_text_embeddings'],data_type='gd')\n                for key, value in losses_flickr.items():\n                    losses_['flickr.'+str(key)] = losses_flickr[key]\n                self.loss_dict=losses_flickr\n            if 'refcoco' in batched_inputs and not self.coco_only:\n                self.criterion.conversation=False\n                losses_ref = self.forward_seg(batched_inputs['refcoco'], task='seg',default_text_embeddings=batched_inputs['refcoco_text_embeddings'],data_type='ref')\n                for key, value in losses_ref.items():\n                    losses_['refcoco.'+str(key)] = losses_ref[key]\n            if 'vg' in batched_inputs and not self.coco_only:\n                self.criterion.conversation = False\n                losses_ref = self.forward_seg(batched_inputs['vg'], task='det',\n                                              default_text_embeddings=batched_inputs['vg_text_embeddings'],\n                                              data_type='ref')\n                for 
key, value in losses_ref.items():\n                    losses_['vg.' + str(key)] = losses_ref[key]\n                # self.loss_dict=losses_flickr\n            if 'coco' in batched_inputs:\n                # if self.loss_dict is None:\n                #\n                # else:\n                valid_idx=[]\n                for idx,input in enumerate(batched_inputs['coco']):\n                    if input['grounding']:\n                        valid_idx.append(idx)\n                if len(valid_idx)==0:\n                    self.criterion.conversation = True\n                    losses_flickr = self.forward_seg(batched_inputs['flickr'], task='seg',\n                                                     default_text_embeddings=batched_inputs[\n                                                         'flickr_text_embeddings'], data_type='coco')\n                    self.loss_dict = losses_flickr\n                    for key, value in self.loss_dict.items():\n                        losses['coco.' 
+ str(key)] = self.loss_dict[key] * 0.0\n                else:\n                    batched_inputs['coco']=[batched_inputs['coco'][idx] for idx in valid_idx]\n                    text_embed=batched_inputs['coco_text_embeddings']\n                    text_embed=text_embed[0][valid_idx],text_embed[1][valid_idx]\n                    self.criterion.conversation = True\n                    losses_coco_instruct = self.forward_seg(batched_inputs['coco'], task='seg',default_text_embeddings=text_embed)\n                    for key, value in losses_coco_instruct.items():\n                        losses['coco.'+str(key)] = losses_coco_instruct[key]\n            losses.update(losses_)\n            # if self.task_switch['coco']:\n            #     self.criterion.num_classes = 133 if 'pano' in self.train_dataset_name[0] else 80\n            #     # self.criterion.num_classes = 133\n            #     task = 'seg'\n            #     if not self.coco_mask_on:\n            #         task = 'det'\n            #     # import ipdb; ipdb.set_trace()\n            #     losses_coco = self.forward_seg(batched_inputs['coco'], task=task)\n            #     new_losses_coco = {}\n            #     for key, value in losses_coco.items():\n            #         new_losses_coco['coco.'+str(key)] = losses_coco[key]\n            #     losses.update(new_losses_coco)\n            # if self.task_switch['o365']:\n            #     self.criterion.num_classes = 365\n            #     losses_o365 = self.forward_seg(batched_inputs['o365'], task='det')\n            #     new_losses_o365 = {}\n            #     for key, value in losses_o365.items():\n            #         new_losses_o365['o365.'+str(key)] = losses_o365[key]\n            #     losses.update(new_losses_o365)\n            return losses\n        else:\n            processed_results = self.forward_seg(batched_inputs, task=inference_task)\n            return processed_results\n\n    def forward_seg(self, batched_inputs, 
task='seg',default_text_embeddings=None,data_type='gd'):\n\n        images = [x[\"image\"].to(self.device) for x in batched_inputs]\n        images = [(x - self.pixel_mean) / self.pixel_std for x in images]\n        images = ImageList.from_tensors(images, self.size_divisibility)\n\n        features = self.backbone(images.tensor)\n        # features={k:v.to(torch.bfloat16) for k,v in features.items()}\n        if self.training:\n            # mask classification target\n            if \"instances\" in batched_inputs[0]:\n                gt_instances = [x[\"instances\"].to(self.device) for x in batched_inputs]\n                targets = self.prepare_targets(gt_instances, images, task=task)\n            else:\n                targets = None\n            outputs, mask_dict = self.sem_seg_head(features, targets=None, task=task,default_text_embeddings=default_text_embeddings)\n            ##########eval training\n            if self.eval_train:\n                pred_logits=outputs[\"pred_logits\"]\n                pred_boxes=outputs[\"pred_boxes\"]\n                pred_masks=outputs[\"pred_masks\"]>0\n                num_total=0\n                num_correct=0\n                mask_iou=0.0\n                # total_union=0.0\n                # total_intersection=0.0\n                scale_factor=[1024./max(data['height'],data['width']) for data in batched_inputs]\n                for i in range(len(pred_logits)):\n                    matched_idx=torch.argmax(pred_logits[i],dim=0)\n                    matched_boxes=pred_boxes[i][matched_idx]\n                    matched_masks=pred_masks[i][matched_idx]\n                    gt_boxes_=targets[i]['boxes']\n                    gt_masks_=targets[i]['masks']\n\n                    gt_ground_labels=targets[i]['labels']\n                    gt_ground_labels_=[]\n                    for lb in gt_ground_labels:\n                        gt_ground_labels_.extend(lb)\n                    max_lb=max(gt_ground_labels_)\n                  
  lb2gt_idx=dict()\n                    for lb in range(max_lb+1):\n                        lb2gt_idx[lb]=[]\n                    for idx,lbs in enumerate(gt_ground_labels):\n                        for lb in lbs:\n                            lb2gt_idx[lb].append(idx)\n                    for lb in range(max_lb+1):\n                        pred_box=box_ops.box_cxcywh_to_xyxy(matched_boxes[lb][None])\n                        gt_boxes=box_ops.box_cxcywh_to_xyxy(gt_boxes_[lb2gt_idx[lb]])\n                        pred_mask=matched_masks[lb]\n\n                        gt_mask=gt_masks_[lb2gt_idx[lb]][0]\n                        pred_mask = F.interpolate(\n                            pred_mask[None,None].float(),\n                            size=(gt_mask.shape[-2], gt_mask.shape[-1]),\n                            mode=\"bilinear\",\n                            align_corners=False,\n                        )[0,0]>0.5\n                        if len(gt_boxes)==0:\n                            continue\n                        mask_iou+=float(torch.sum(pred_mask*gt_mask)/torch.sum(torch.logical_or(pred_mask,gt_mask)))\n                        self.total_union+=float(torch.sum(torch.logical_or(pred_mask,gt_mask)))/scale_factor[i]**2\n                        self.total_intersection+=float(torch.sum(pred_mask*gt_mask))/scale_factor[i]**2\n                        # self.mask_iou+=mask_iou\n\n                        if box_ops.box_iou(pred_box,gt_boxes)[0].max()>0.5:\n                            num_correct+=1\n                        else:\n                            pass\n                        num_total+=1\n                print(f\"{data_type} cIoU:\" ,self.total_intersection/self.total_union)\n                name_correct='num_correct_'+data_type\n                name_total='num_total_'+data_type\n                try:\n                    gathered_list=[None for _ in range(torch.distributed.get_world_size())]\n                    
torch.distributed.all_gather_object(gathered_list,num_correct)\n                    # self.num_correct+=sum(gathered_list)\n                    num_correct_value=getattr(self,name_correct)\n                    setattr(self,name_correct,num_correct_value+sum(gathered_list))\n                    gathered_list=[None for _ in range(torch.distributed.get_world_size())]\n                    torch.distributed.all_gather_object(gathered_list,num_total)\n                    # self.num_total+=sum(gathered_list)\n                    num_total_value=getattr(self,name_total)\n                    setattr(self,name_total,num_total_value+sum(gathered_list))\n                    gathered_list=[None for _ in range(torch.distributed.get_world_size())]\n                    torch.distributed.all_gather_object(gathered_list,mask_iou)\n                    # self.mask_iou+=sum(gathered_list)\n                    if torch.distributed.get_rank()==0:\n                        print(f\"{data_type} acc: \",getattr(self, name_correct) / getattr(self, name_total))\n                        # print(\"mask_iou: \",self.mask_iou/self.num_total)\n                except Exception as e:\n                    # self.num_correct+=num_correct\n                    # self.num_total+=num_total\n                    num_correct_value = getattr(self, name_correct)\n                    setattr(self, name_correct, num_correct_value + num_correct)\n                    num_total_value = getattr(self, name_total)\n                    setattr(self, name_total, num_total_value + num_total)\n                    try:\n                        print(f\"{data_type} rank{torch.distributed.get_rank()} acc: \", getattr(self, name_correct) / getattr(self, name_total))\n                    except Exception as e:\n                        print(f\"{data_type} acc: \", getattr(self, name_correct) / getattr(self, name_total))\n                ###########################\n            # bipartite matching-based loss\n            
self.criterion.default_text_embeddings = default_text_embeddings\n            losses = self.criterion(outputs, targets, mask_dict, task=task)\n\n            for k in list(losses.keys()):\n                if k in self.criterion.weight_dict:\n                    losses[k] *= self.criterion.weight_dict[k]\n                else:\n                    # remove this loss if not specified in `weight_dict`\n                    losses.pop(k)\n            return losses\n        else:\n            outputs, _ = self.sem_seg_head(features)\n            mask_cls_results = outputs[\"pred_logits\"]\n            mask_box_results = outputs[\"pred_boxes\"]\n            if 'seg' in task:\n                if task == 'seg':\n                    self.semantic_on = self.panoptic_on = self.sem_seg_postprocess_before_inference = self.instance_on = True\n                if task == 'inst_seg':\n                    self.semantic_on = self.panoptic_on = False\n                    self.instance_on = True\n                    self.sem_seg_postprocess_before_inference = True\n                if task == 'sem_pan_seg':\n                    self.semantic_on = self.panoptic_on = True\n                    self.instance_on = False\n                    self.sem_seg_postprocess_before_inference = True\n                if task == 'inst_pan_seg':\n                    self.instance_on = self.panoptic_on = True\n                    self.semantic_on = False\n                    self.sem_seg_postprocess_before_inference = True\n                if task == 'sem_seg':\n                    self.instance_on = self.panoptic_on = False\n                    self.semantic_on = True\n                    self.sem_seg_postprocess_before_inference = True\n                mask_pred_results = outputs[\"pred_masks\"]\n                # upsample masks\n                mask_pred_results = F.interpolate(\n                    mask_pred_results,\n                    size=(images.tensor.shape[-2], images.tensor.shape[-1]),\n          
          mode=\"bilinear\",\n                    align_corners=False,\n                )\n\n            else:\n                self.semantic_on = self.panoptic_on = self.sem_seg_postprocess_before_inference = False\n                self.instance_on = True\n                mask_pred_results = torch.zeros(mask_box_results.shape[0], mask_box_results.shape[1],2, 2).to(mask_box_results)\n\n            del outputs\n\n            processed_results = []\n\n            for mask_cls_result, mask_pred_result, mask_box_result, input_per_image, image_size in zip(\n                mask_cls_results, mask_pred_results, mask_box_results, batched_inputs, images.image_sizes\n            ):\n                height = input_per_image.get(\"height\", image_size[0])\n                width = input_per_image.get(\"width\", image_size[1])\n                processed_results.append({})\n                new_size = (images.tensor.shape[-2], images.tensor.shape[-1])  # padded size (divisible to 32)\n\n\n                if self.sem_seg_postprocess_before_inference:\n                    mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(\n                        mask_pred_result, image_size, height, width\n                    )\n                    mask_cls_result = mask_cls_result.to(mask_pred_result)\n\n                # semantic segmentation inference\n                if self.semantic_on:\n                    r = retry_if_cuda_oom(self.semantic_inference)(mask_cls_result, mask_pred_result)\n                    if not self.sem_seg_postprocess_before_inference:\n                        r = retry_if_cuda_oom(sem_seg_postprocess)(r, image_size, height, width)\n                    processed_results[-1][\"sem_seg\"] = r\n\n                # panoptic segmentation inference\n                if self.panoptic_on:\n                    panoptic_r = retry_if_cuda_oom(self.panoptic_inference)(mask_cls_result, mask_pred_result)\n                    processed_results[-1][\"panoptic_seg\"] = 
panoptic_r\n\n                # instance segmentation inference\n\n                if self.instance_on:\n                    mask_box_result = mask_box_result.to(mask_pred_result)\n                    height = new_size[0]/image_size[0]*height\n                    width = new_size[1]/image_size[1]*width\n                    mask_box_result = self.box_postprocess(mask_box_result, height, width)\n\n                    instance_r = retry_if_cuda_oom(self.instance_inference)(mask_cls_result, mask_pred_result, mask_box_result)\n                    processed_results[-1][\"instances\"] = instance_r\n            del mask_pred_results\n            return processed_results\n\n    def prepare_targets(self, targets, images, task='seg'):\n        h_pad, w_pad = images.tensor.shape[-2:]\n        new_targets = []\n        for targets_per_image in targets:\n            # pad gt\n            h, w = targets_per_image.image_size\n            image_size_xyxy = torch.as_tensor([w, h, w, h], dtype=torch.float, device=self.device)\n\n            if task != 'det':\n                gt_masks = targets_per_image.gt_masks\n                padded_masks = torch.zeros((gt_masks.shape[0], h_pad, w_pad), dtype=gt_masks.dtype, device=gt_masks.device)\n                padded_masks[:, : gt_masks.shape[1], : gt_masks.shape[2]] = gt_masks\n            else:\n                padded_masks = None\n            new_targets.append(\n                {\n                    \"labels\": targets_per_image.gt_classes,\n                    \"masks\": padded_masks,\n                    \"boxes\":box_ops.box_xyxy_to_cxcywh(targets_per_image.gt_boxes.tensor)/image_size_xyxy\n                }\n            )\n        return new_targets\n\n    def semantic_inference(self, mask_cls, mask_pred):\n        # if use cross-entropy loss in training, evaluate with softmax\n        if self.semantic_ce_loss:\n            mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1]\n            mask_pred = mask_pred.sigmoid()\n            
semseg = torch.einsum(\"qc,qhw->chw\", mask_cls, mask_pred)\n            return semseg\n        # if use focal loss in training, evaluate with sigmoid. As sigmoid is mainly for detection and not sharp\n        # enough for semantic and panoptic segmentation, we additionally use use softmax with a temperature to\n        # make the score sharper.\n        else:\n            T = self.pano_temp\n            mask_cls = mask_cls.sigmoid()\n            if self.transform_eval:\n                mask_cls = F.softmax(mask_cls / T, dim=-1)  # already sigmoid\n            mask_pred = mask_pred.sigmoid()\n            semseg = torch.einsum(\"qc,qhw->chw\", mask_cls, mask_pred)\n            return semseg\n\n    def panoptic_inference(self, mask_cls, mask_pred):\n        # As we use focal loss in training, evaluate with sigmoid. As sigmoid is mainly for detection and not sharp\n        # enough for semantic and panoptic segmentation, we additionally use use softmax with a temperature to\n        # make the score sharper.\n        prob = 0.5\n        T = self.pano_temp\n        scores, labels = mask_cls.sigmoid().max(-1)\n        mask_pred = mask_pred.sigmoid()\n        keep = labels.ne(self.sem_seg_head.num_classes) & (scores > self.object_mask_threshold)\n        # added process\n        if self.transform_eval:\n            scores, labels = F.softmax(mask_cls.sigmoid() / T, dim=-1).max(-1)\n        cur_scores = scores[keep]\n        cur_classes = labels[keep]\n        cur_masks = mask_pred[keep]\n        cur_prob_masks = cur_scores.view(-1, 1, 1) * cur_masks\n\n        h, w = cur_masks.shape[-2:]\n        panoptic_seg = torch.zeros((h, w), dtype=torch.int32, device=cur_masks.device)\n        segments_info = []\n\n        current_segment_id = 0\n\n        if cur_masks.shape[0] == 0:\n            # We didn't detect any mask :(\n            return panoptic_seg, segments_info\n        else:\n            # take argmax\n            cur_mask_ids = cur_prob_masks.argmax(0)\n            
stuff_memory_list = {}\n            for k in range(cur_classes.shape[0]):\n                pred_class = cur_classes[k].item()\n                isthing = pred_class in self.metadata.thing_dataset_id_to_contiguous_id.values()\n                mask_area = (cur_mask_ids == k).sum().item()\n                original_area = (cur_masks[k] >= prob).sum().item()\n                mask = (cur_mask_ids == k) & (cur_masks[k] >= prob)\n\n                if mask_area > 0 and original_area > 0 and mask.sum().item() > 0:\n                    if mask_area / original_area < self.overlap_threshold:\n                        continue\n\n                    # merge stuff regions\n                    if not isthing:\n                        if int(pred_class) in stuff_memory_list.keys():\n                            panoptic_seg[mask] = stuff_memory_list[int(pred_class)]\n                            continue\n                        else:\n                            stuff_memory_list[int(pred_class)] = current_segment_id + 1\n\n                    current_segment_id += 1\n                    panoptic_seg[mask] = current_segment_id\n\n                    segments_info.append(\n                        {\n                            \"id\": current_segment_id,\n                            \"isthing\": bool(isthing),\n                            \"category_id\": int(pred_class),\n                        }\n                    )\n\n            return panoptic_seg, segments_info\n\n    def instance_inference(self, mask_cls, mask_pred, mask_box_result):\n        # mask_pred is already processed to have the same shape as original input\n        image_size = mask_pred.shape[-2:]\n        scores = mask_cls.sigmoid()  # [100, 80]\n        labels = torch.arange(self.sem_seg_head.num_classes, device=self.device).unsqueeze(0).repeat(self.num_queries, 1).flatten(0, 1)\n        scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.test_topk_per_image, sorted=False)  # select 100\n        
labels_per_image = labels[topk_indices]\n        topk_indices = topk_indices // self.sem_seg_head.num_classes\n        mask_pred = mask_pred[topk_indices]\n        # if this is panoptic segmentation, we only keep the \"thing\" classes\n        if self.panoptic_on:\n            keep = torch.zeros_like(scores_per_image).bool()\n            for i, lab in enumerate(labels_per_image):\n                # print(i, len(keep), lab)\n                keep[i] = lab in self.metadata.thing_dataset_id_to_contiguous_id.values()\n            scores_per_image = scores_per_image[keep]\n            labels_per_image = labels_per_image[keep]\n            mask_pred = mask_pred[keep]\n        result = Instances(image_size)\n        # mask (before sigmoid)\n        result.pred_masks = (mask_pred > 0).float()\n        # half mask box half pred box\n        mask_box_result = mask_box_result[topk_indices]\n        if self.panoptic_on:\n            mask_box_result = mask_box_result[keep]\n        result.pred_boxes = Boxes(mask_box_result)\n        # Uncomment the following to get boxes from masks (this is slow)\n        # result.pred_boxes = BitMasks(mask_pred > 0).get_bounding_boxes()\n\n        # calculate average mask prob\n        if self.sem_seg_postprocess_before_inference:\n            mask_scores_per_image = (mask_pred.sigmoid().flatten(1) * result.pred_masks.flatten(1)).sum(1) / (result.pred_masks.flatten(1).sum(1) + 1e-6)\n        else:\n            mask_scores_per_image = 1.0\n            # labels_per_image = labels_per_image + 1  # HACK for o365 classification\n        if self.focus_on_box:\n            mask_scores_per_image = 1.0\n        result.scores = scores_per_image * mask_scores_per_image\n        result.pred_classes = labels_per_image\n        return result\n\n    def box_postprocess(self, out_bbox, img_h, img_w):\n        # postprocess box height and width\n        boxes = box_ops.box_cxcywh_to_xyxy(out_bbox)\n        scale_fct = torch.tensor([img_w, img_h, img_w, 
img_h])\n        scale_fct = scale_fct.to(out_bbox)\n        boxes = boxes * scale_fct\n        return boxes\n\n    def forward_eval(self, batched_inputs, text_embeddings):\n        # import ipdb; ipdb.set_trace()\n        # print(\"Num images per batch:\",len(batched_inputs['flickr']))\n        if self.training:\n            raise NotImplementedError\n        else:\n            self.criterion.conversation=False\n            box_results, seg_results = self.forward_inner_eval(\n                batched_inputs, \n                task='seg',\n                default_text_embeddings=text_embeddings,\n            )\n            return box_results, seg_results\n\n    def forward_inner_eval(self, batched_inputs, task='seg',default_text_embeddings=None):\n        images = [x[\"image\"].to(self.device) for x in batched_inputs]\n        images = [(x - self.pixel_mean) / self.pixel_std for x in images]\n        images = ImageList.from_tensors(images, self.size_divisibility)\n        matching_threshold = batched_inputs[0][\"matching_threshold\"] if \"matching_threshold\" in batched_inputs[0].keys() else None\n\n        features = self.backbone(images.tensor)\n        # features={k:v.to(torch.bfloat16) for k,v in features.items()}\n        # mask classification target\n        if \"instances\" in batched_inputs[0]:\n            gt_instances = [x[\"instances\"].to(self.device) for x in batched_inputs]\n            targets = self.prepare_targets(gt_instances, images, task=task)\n        else:\n            targets = None\n        default_text_embeddings_ = [default_text_embeddings[0].float(), default_text_embeddings[1]]\n        outputs, mask_dict = self.sem_seg_head(features, targets=None, task=task,default_text_embeddings=default_text_embeddings_)\n        ##########eval training\n        pred_logits=outputs[\"pred_logits\"]\n        pred_boxes=outputs[\"pred_boxes\"]\n        pred_masks=outputs[\"pred_masks\"]>0\n        # scale_factor=[1024./max(data['height'],data['width']) 
for data in batched_inputs]\n        matched_pred_boxes = []\n        matched_pred_masks = []\n        for i in range(len(pred_logits)):\n            if len(pred_logits) > 1:\n                raise NotImplementedError\n            num_grounding = pred_logits.shape[2]\n            for gd_idx in range(num_grounding):\n                if matching_threshold is None:\n                    matched_idx = torch.argmax(pred_logits[i, :, gd_idx],dim=0)\n                    matched_boxes = pred_boxes[i][matched_idx]\n                    matched_boxes = matched_boxes[None, :]\n                else:\n                    matched_idx = torch.where(pred_logits[i, :, gd_idx].softmax(dim=0) > matching_threshold)[0]\n                    # print(matched_idx, pred_logits[i, :, gd_idx].softmax(dim=0)[matched_idx])\n                    if matched_idx.shape[0] == 0:  #* if there is no one object satisfy threshold, then select the best matched one.\n                        matched_boxes = pred_boxes.new_zeros((1, 4))\n                        matched_masks = pred_boxes.new_zeros((1, 256, 256))\n                    else:\n                        matched_boxes = pred_boxes[i][matched_idx]\n                        matched_masks = pred_masks[i][matched_idx]\n                # matched_masks=pred_masks[i][matched_idx]\n                matched_boxes_processed = []\n                for lb in range(matched_boxes.shape[0]):\n                    pred_box=box_ops.box_cxcywh_to_xyxy(matched_boxes[lb][None])\n                    matched_boxes_processed.append(pred_box)\n                matched_pred_boxes.append(torch.cat(matched_boxes_processed, dim=0))\n                matched_pred_masks.append(matched_masks)\n        return matched_pred_boxes, matched_pred_masks\n    \n@register_model\ndef get_segmentation_model(cfg, **kwargs):\n    return OpenSeeD(cfg)"
  },
  {
    "path": "llava/model/openseed/architectures/openseed_model_decouple_train.py",
    "content": "# ------------------------------------------------------------------------\n# DINO\n# Copyright (c) 2023 IDEA. All Rights Reserved.\n# Licensed under the Apache License, Version 2.0 [see LICENSE for details]\n# ------------------------------------------------------------------------\n# Modified from MaskDINO https://github.com/IDEA-Research/MaskDINO by Hao Zhang and Feng Li.\n# ------------------------------------------------------------------------\nfrom typing import Tuple\n\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nfrom .registry import register_model\nfrom ..utils import configurable, box_ops, get_class_names\nfrom ..backbone import build_backbone, Backbone\nfrom ..body import build_openseed_head\nfrom ..modules import sem_seg_postprocess, HungarianMatcher, SetCriterion\nfrom ..language import build_language_encoder\n\nfrom detectron2.structures import Boxes, ImageList, Instances, BitMasks\nfrom detectron2.utils.memory import retry_if_cuda_oom\nfrom detectron2.data import MetadataCatalog\nimport random\nimport json\nclass OpenSeeD(nn.Module):\n    \"\"\"\n    Main class for mask classification semantic segmentation architectures.\n    \"\"\"\n\n    @configurable\n    def __init__(\n        self,\n        *,\n        backbone: Backbone,\n        sem_seg_head: nn.Module,\n        num_queries: int,\n        object_mask_threshold: float,\n        overlap_threshold: float,\n        metadata,\n        size_divisibility: int,\n        sem_seg_postprocess_before_inference: bool,\n        pixel_mean: Tuple[float],\n        pixel_std: Tuple[float],\n        # inference\n        semantic_on: bool,\n        panoptic_on: bool,\n        instance_on: bool,\n        test_topk_per_image: int,\n        data_loader: str,\n        pano_temp: float,\n        focus_on_box: bool = False,\n        transform_eval: bool = False,\n        semantic_ce_loss: bool = False,\n        train_dataset_name: str,\n        background: bool,\n       
 coco_on=True,\n        coco_mask_on=True,\n        o365_on=True,\n        criterion_coco=None,\n        criterion_o365=None,\n        split_panno=False,\n    ):\n        \"\"\"\n        Args:\n            backbone: a backbone module, must follow detectron2's backbone interface\n            sem_seg_head: a module that predicts semantic segmentation from backbone features\n            criterion: a module that defines the loss\n            num_queries: int, number of queries\n            object_mask_threshold: float, threshold to filter query based on classification score\n                for panoptic segmentation inference\n            overlap_threshold: overlap threshold used in general inference for panoptic segmentation\n            metadata: dataset meta, get `thing` and `stuff` category names for panoptic\n                segmentation inference\n            size_divisibility: Some backbones require the input height and width to be divisible by a\n                specific integer. We can use this to override such requirement.\n            sem_seg_postprocess_before_inference: whether to resize the prediction back\n                to original input size before semantic segmentation inference or after.\n                For high-resolution dataset like Mapillary, resizing predictions before\n                inference will cause OOM error.\n            pixel_mean, pixel_std: list or tuple with #channels element, representing\n                the per-channel mean and std to be used to normalize the input image\n            semantic_on: bool, whether to output semantic segmentation prediction\n            instance_on: bool, whether to output instance segmentation prediction\n            panoptic_on: bool, whether to output panoptic segmentation prediction\n            test_topk_per_image: int, instance segmentation parameter, keep topk instances per image\n        \"\"\"\n        super().__init__()\n        self.backbone = backbone\n        self.pano_temp = 
pano_temp\n        self.sem_seg_head = sem_seg_head\n        self.num_queries = num_queries\n        self.overlap_threshold = overlap_threshold\n        self.object_mask_threshold = object_mask_threshold\n        self.metadata = metadata\n        if size_divisibility < 0:\n            # use backbone size_divisibility if not set\n            size_divisibility = self.backbone.size_divisibility\n        self.size_divisibility = size_divisibility\n        self.sem_seg_postprocess_before_inference = sem_seg_postprocess_before_inference\n        self.register_buffer(\"pixel_mean\", torch.Tensor(pixel_mean).view(-1, 1, 1), False)\n        self.register_buffer(\"pixel_std\", torch.Tensor(pixel_std).view(-1, 1, 1), False)\n        self.split_panno=split_panno\n        # additional args\n        self.semantic_on = semantic_on\n        self.instance_on = instance_on\n        self.panoptic_on = panoptic_on\n        self.test_topk_per_image = test_topk_per_image\n\n        self.data_loader = data_loader\n        self.focus_on_box = focus_on_box\n        self.transform_eval = transform_eval\n        self.semantic_ce_loss = semantic_ce_loss\n\n        self.train_class_names = dict()\n        self.train_dataset_name = train_dataset_name\n        self.coco_mask_on = coco_mask_on\n        self.task_switch = {'coco': coco_on, 'o365': o365_on}\n        self.criterion_coco=criterion_coco\n        self.criterion_o365=criterion_o365\n\n        print(\"self.task_switch \", self.task_switch)\n        # HACK for only two datasets for seg and det\n        if coco_on:\n            task = 'seg'\n            if not coco_mask_on:\n                task = 'det'\n            self.train_class_names[task] = get_class_names(train_dataset_name[0], background=background)\n            self.train_class_names[task] = [a.replace(\"-merged\", \"\").replace(\"-other\", \"\").replace(\"-stuff\", \"\") for a\n                                             in self.train_class_names[task]]\n            
train_class_names = []\n            for name in self.train_class_names[task]:\n                names = name.split('-')\n                if len(names) > 1:\n                    assert len(names) == 2\n                    train_class_names.append(names[1] + ' ' + names[0])\n                else:\n                    train_class_names.append(name)\n            self.train_class_names[task] = train_class_names\n\n        if o365_on and len(train_dataset_name)>1:\n            for dt in train_dataset_name:\n                if \"o365\" in train_dataset_name or \"object365\" in train_dataset_name:\n                    break\n            self.train_class_names['det'] = get_class_names(dt, background=background)\n            self.train_class_names['det'] = [a.lower().split('/') for a in self.train_class_names['det']]\n\n        if not self.semantic_on:\n            assert self.sem_seg_postprocess_before_inference\n\n    @classmethod\n    def from_config(cls, cfg):\n        enc_cfg = cfg['MODEL']['ENCODER']\n        dec_cfg = cfg['MODEL']['DECODER']\n\n        # Loss parameters:\n        deep_supervision = dec_cfg['DEEP_SUPERVISION']\n        no_object_weight = dec_cfg['NO_OBJECT_WEIGHT']\n\n        # loss weights\n        class_weight = dec_cfg['CLASS_WEIGHT']\n        cost_class_weight = dec_cfg['COST_CLASS_WEIGHT']\n        cost_dice_weight = dec_cfg['COST_DICE_WEIGHT']\n        dice_weight = dec_cfg['DICE_WEIGHT']\n        cost_mask_weight = dec_cfg['COST_MASK_WEIGHT']\n        mask_weight = dec_cfg['MASK_WEIGHT']\n        cost_box_weight = dec_cfg['COST_BOX_WEIGHT']\n        box_weight = dec_cfg['BOX_WEIGHT']\n        cost_giou_weight = dec_cfg['COST_GIOU_WEIGHT']\n        giou_weight = dec_cfg['GIOU_WEIGHT']\n\n        # building matcher\n        matcher = HungarianMatcher(\n            cost_class=cost_class_weight,\n            cost_mask=cost_mask_weight,\n            cost_dice=cost_dice_weight,\n            cost_box=cost_box_weight,\n            
cost_giou=cost_giou_weight,\n            num_points=dec_cfg['TRAIN_NUM_POINTS'],\n        )\n\n        # MaskDINO losses and weight_dict\n        weight_dict = {\"loss_mask_cls_0\": class_weight}\n        weight_dict.update({\"loss_mask_bce_0\": mask_weight, \"loss_mask_dice_0\": dice_weight})\n        weight_dict.update({\"loss_bbox_0\":box_weight,\"loss_giou_0\":giou_weight})\n        # two stage is the query selection scheme\n        if dec_cfg['TWO_STAGE']:\n            interm_weight_dict = {}\n            interm_weight_dict.update({k + f'_interm': v for k, v in weight_dict.items()})\n            weight_dict.update(interm_weight_dict)\n        # denoising training\n        dn = dec_cfg['DN']\n        # TODO hack for dn lable loss\n        if dn == \"standard\":\n            weight_dict.update({k + f\"_dn\": v for k, v in weight_dict.items() if k!=\"loss_mask\" and k!=\"loss_dice\" })\n            dn_losses=[\"dn_labels\", \"boxes\"]\n        elif dn == \"seg\":\n            weight_dict.update({k + f\"_dn\": v for k, v in weight_dict.items()})\n            dn_losses=[\"labels\", \"masks\", \"boxes\"]\n        else:\n            dn_losses=[]\n        if deep_supervision:\n            dec_layers = dec_cfg['DEC_LAYERS']\n            aux_weight_dict = {}\n            for i in range(dec_layers):\n                aux_weight_dict.update({k.replace('_0', '_{}'.format(i+1)): v for k, v in weight_dict.items()})\n            weight_dict.update(aux_weight_dict)\n        if dec_cfg['BOX']:\n            losses = [\"labels\", \"masks\",\"boxes\"]\n        else:\n            losses = [\"labels\", \"masks\"]\n\n        # update task switch\n        task_switch = {}\n        task_switch.update({'bbox': dec_cfg.get('DETECTION', True), 'mask': dec_cfg.get('MASK', True)})\n        top_x_layers = {'mask': dec_cfg.get('TOP_MASK_LAYERS', 10),\n                        'box': dec_cfg.get('TOP_DETECTION_LAYERS', 10)}\n\n        # building criterion\n        criterion_coco = 
SetCriterion(\n            enc_cfg['NUM_CLASSES'],\n            matcher=matcher,\n            weight_dict=weight_dict,\n            top_x_layers=top_x_layers,\n            eos_coef=no_object_weight,\n            losses=losses,\n            num_points=dec_cfg['TRAIN_NUM_POINTS'],\n            oversample_ratio=dec_cfg['OVERSAMPLE_RATIO'],\n            importance_sample_ratio=dec_cfg['IMPORTANCE_SAMPLE_RATIO'],\n            grounding_weight=None,\n            dn=dec_cfg['DN'],\n            dn_losses=dn_losses,\n            panoptic_on=dec_cfg['PANO_BOX_LOSS'],\n            semantic_ce_loss=dec_cfg['TEST']['SEMANTIC_ON'] and dec_cfg['SEMANTIC_CE_LOSS'] and not dec_cfg['TEST']['PANOPTIC_ON'],\n        )\n\n        criterion_o365 = SetCriterion(\n            enc_cfg.get('NUM_CLASSES_O365', 365),\n            matcher=matcher,\n            weight_dict=weight_dict,\n            top_x_layers=top_x_layers,\n            eos_coef=no_object_weight,\n            losses=losses,\n            num_points=dec_cfg['TRAIN_NUM_POINTS'],\n            oversample_ratio=dec_cfg['OVERSAMPLE_RATIO'],\n            importance_sample_ratio=dec_cfg['IMPORTANCE_SAMPLE_RATIO'],\n            grounding_weight=None,\n            dn=dec_cfg['DN'],\n            dn_losses=dn_losses,\n            panoptic_on=dec_cfg['PANO_BOX_LOSS'],\n            semantic_ce_loss=dec_cfg['TEST']['SEMANTIC_ON'] and dec_cfg['SEMANTIC_CE_LOSS'] and not dec_cfg['TEST'][\n                'PANOPTIC_ON'],\n        )\n\n        # build model\n        extra = {'task_switch': task_switch}\n        backbone = build_backbone(cfg)\n        lang_encoder = build_language_encoder(cfg)\n        sem_seg_head = build_openseed_head(cfg, backbone.output_shape(), lang_encoder, extra=extra)\n\n        return {\n            \"backbone\": backbone,\n            \"sem_seg_head\": sem_seg_head,\n            \"criterion_coco\": criterion_coco,\n            \"criterion_o365\": criterion_o365,\n            \"num_queries\": 
dec_cfg['NUM_OBJECT_QUERIES'],\n            \"object_mask_threshold\": dec_cfg['TEST']['OBJECT_MASK_THRESHOLD'],\n            \"overlap_threshold\": dec_cfg['TEST']['OVERLAP_THRESHOLD'],\n            \"metadata\": MetadataCatalog.get(cfg['DATASETS']['TRAIN'][0]),\n            \"size_divisibility\": dec_cfg['SIZE_DIVISIBILITY'],\n            \"sem_seg_postprocess_before_inference\": (\n                dec_cfg['TEST']['SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE']\n                or dec_cfg['TEST']['PANOPTIC_ON']\n                or dec_cfg['TEST']['INSTANCE_ON']\n            ),\n            \"pixel_mean\": cfg['INPUT']['PIXEL_MEAN'],\n            \"pixel_std\": cfg['INPUT']['PIXEL_STD'],\n            # inference\n            \"semantic_on\": dec_cfg['TEST']['SEMANTIC_ON'],\n            \"instance_on\": dec_cfg['TEST']['INSTANCE_ON'],\n            \"panoptic_on\": dec_cfg['TEST']['PANOPTIC_ON'],\n            \"test_topk_per_image\": cfg['COCO']['TEST']['DETECTIONS_PER_IMAGE'],\n            \"data_loader\": None,\n            \"focus_on_box\": cfg['MODEL']['DECODER']['TEST']['TEST_FOUCUS_ON_BOX'],\n            \"transform_eval\": cfg['MODEL']['DECODER']['TEST']['PANO_TRANSFORM_EVAL'],\n            \"pano_temp\": cfg['MODEL']['DECODER']['TEST']['PANO_TEMPERATURE'],\n            \"semantic_ce_loss\": cfg['MODEL']['DECODER']['TEST']['SEMANTIC_ON'] and cfg['MODEL']['DECODER']['SEMANTIC_CE_LOSS'] and not cfg['MODEL']['DECODER']['TEST']['PANOPTIC_ON'],\n            \"train_dataset_name\": cfg['DATASETS']['TRAIN'], # HACK for only two training set\n            \"background\": cfg['MODEL'].get('BACKGROUND', True),\n            \"coco_on\": dec_cfg.get('COCO', True),\n            \"coco_mask_on\": dec_cfg.get('COCO_MASK', True),\n            \"o365_on\": dec_cfg.get('O365', True),\n            \"split_panno\": dec_cfg.get('PANO_CRITERION', True),\n\n        }\n\n    @property\n    def device(self):\n        return self.pixel_mean.device\n\n    def forward(self, batched_inputs, 
inference_task='seg'):\n        # import ipdb; ipdb.set_trace()\n        if self.training:\n            losses = {}\n            if self.task_switch['coco'] and 'coco' in batched_inputs:\n                self.criterion_coco.num_classes = 133 if 'pano' in self.train_dataset_name[0] else 80\n                # self.criterion.num_classes = 133\n                task = 'seg'\n                if not self.coco_mask_on:\n                    task = 'det'\n                # import ipdb; ipdb.set_trace()\n                losses_coco = self.forward_seg(batched_inputs['coco'], task=task)\n                new_losses_coco = {}\n                for key, value in losses_coco.items():\n                    new_losses_coco['coco.'+str(key)] = losses_coco[key]\n                losses.update(new_losses_coco)\n            if self.task_switch['o365'] and 'o365' in batched_inputs:\n                self.criterion_o365.num_classes = 365\n                losses_o365 = self.forward_seg(batched_inputs['o365'], task='det')\n                new_losses_o365 = {}\n                for key, value in losses_o365.items():\n                    new_losses_o365['o365.'+str(key)] = losses_o365[key]\n                losses.update(new_losses_o365)\n            return losses\n        else:\n            processed_results = self.forward_seg(batched_inputs, task=inference_task)\n            return processed_results\n\n    def forward_seg(self, batched_inputs, task='seg'):\n        \"\"\"\n        Args:\n            batched_inputs: a list, batched outputs of :class:`DatasetMapper`.\n                Each item in the list contains the inputs for one image.\n                For now, each item in the list is a dict that contains:\n                   * \"image\": Tensor, image in (C, H, W) format.\n                   * \"instances\": per-region ground truth\n                   * Other information that's included in the original dicts, such as:\n                     \"height\", \"width\" (int): the output resolution of 
the model (may be different\n                     from input resolution), used in inference.\n        Returns:\n            list[dict]:\n                each dict has the results for one image. The dict contains the following keys:\n\n                * \"sem_seg\":\n                    A Tensor that represents the\n                    per-pixel segmentation prediced by the head.\n                    The prediction has shape KxHxW that represents the logits of\n                    each class for each pixel.\n                * \"panoptic_seg\":\n                    A tuple that represent panoptic output\n                    panoptic_seg (Tensor): of shape (height, width) where the values are ids for each segment.\n                    segments_info (list[dict]): Describe each segment in `panoptic_seg`.\n                        Each dict contains keys \"id\", \"category_id\", \"isthing\".\n        \"\"\"\n        images = [x[\"image\"].to(self.device) for x in batched_inputs]\n        images = [(x - self.pixel_mean) / self.pixel_std for x in images]\n        images = ImageList.from_tensors(images, self.size_divisibility)\n\n        features = self.backbone(images.tensor)\n\n        if self.training:\n            if task == \"det\" and self.task_switch['o365']:\n                train_class_names = [random.sample(name, 1)[0] for name in self.train_class_names['det']]\n            else:\n                train_class_names = self.train_class_names[task]\n            self.sem_seg_head.predictor.lang_encoder.get_text_embeddings(train_class_names, is_eval=False)\n\n            # mask classification target\n            if \"instances\" in batched_inputs[0]:\n                gt_instances = [x[\"instances\"].to(self.device) for x in batched_inputs]\n                targets = self.prepare_targets(gt_instances, images, task=task)\n            else:\n                targets = None\n            outputs, mask_dict = self.sem_seg_head(features, targets=targets, task=task)\n            
# bipartite matching-based loss\n            if task=='det':\n                criterion=self.criterion_o365\n                losses = self.criterion_o365(outputs, targets, mask_dict, task=task)\n            else:\n                criterion=self.criterion_coco\n                losses = self.criterion_coco(outputs, targets, mask_dict, task=task)\n            # else\n            for k in list(losses.keys()):\n                if k in criterion.weight_dict:\n                    losses[k] *= criterion.weight_dict[k]\n                else:\n                    # remove this loss if not specified in `weight_dict`\n                    losses.pop(k)\n            return losses\n        else:\n            outputs, _ = self.sem_seg_head(features)\n            mask_cls_results = outputs[\"pred_logits\"]\n            mask_box_results = outputs[\"pred_boxes\"]\n            if 'seg' in task:\n                if task == 'seg':\n                    self.semantic_on = self.panoptic_on = self.sem_seg_postprocess_before_inference = self.instance_on = True\n                if task == 'pan_seg':\n                    self.semantic_on = self.instance_on = False\n                    self.panoptic_on = True\n                    self.sem_seg_postprocess_before_inference = True\n                if task == 'inst_seg':\n                    self.semantic_on = self.panoptic_on = False\n                    self.instance_on = True\n                    self.sem_seg_postprocess_before_inference = True\n                if task == 'sem_pan_seg':\n                    self.semantic_on = self.panoptic_on = True\n                    self.instance_on = False\n                    self.sem_seg_postprocess_before_inference = True\n                if task == 'inst_pan_seg':\n                    self.instance_on = self.panoptic_on = True\n                    self.semantic_on = False\n                    self.sem_seg_postprocess_before_inference = True\n                if task == 'sem_seg':\n                    
self.instance_on = self.panoptic_on = False\n                    self.semantic_on = True\n                    self.sem_seg_postprocess_before_inference = True\n                mask_pred_results = outputs[\"pred_masks\"]\n                # upsample masks\n                mask_pred_results = F.interpolate(\n                    mask_pred_results,\n                    size=(images.tensor.shape[-2], images.tensor.shape[-1]),\n                    mode=\"bilinear\",\n                    align_corners=False,\n                )\n\n            else:\n                self.semantic_on = self.panoptic_on = self.sem_seg_postprocess_before_inference = False\n                self.instance_on = True\n                mask_pred_results = torch.zeros(mask_box_results.shape[0], mask_box_results.shape[1],2, 2).to(mask_box_results)\n\n            del outputs\n\n            processed_results = []\n\n            for mask_cls_result, mask_pred_result, mask_box_result, input_per_image, image_size in zip(\n                mask_cls_results, mask_pred_results, mask_box_results, batched_inputs, images.image_sizes\n            ):\n                height = input_per_image.get(\"height\", image_size[0])\n                width = input_per_image.get(\"width\", image_size[1])\n                processed_results.append({})\n                new_size = (images.tensor.shape[-2], images.tensor.shape[-1])  # padded size (divisible to 32)\n\n\n                if self.sem_seg_postprocess_before_inference:\n                    mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(\n                        mask_pred_result, image_size, height, width\n                    )\n                    mask_cls_result = mask_cls_result.to(mask_pred_result)\n\n                # semantic segmentation inference\n                if self.semantic_on:\n                    r = retry_if_cuda_oom(self.semantic_inference)(mask_cls_result, mask_pred_result)\n                    if not self.sem_seg_postprocess_before_inference:\n 
                       r = retry_if_cuda_oom(sem_seg_postprocess)(r, image_size, height, width)\n                    processed_results[-1][\"sem_seg\"] = r\n\n                # panoptic segmentation inference\n                if self.panoptic_on:\n                    panoptic_r = retry_if_cuda_oom(self.panoptic_inference)(mask_cls_result, mask_pred_result)\n                    processed_results[-1][\"panoptic_seg\"] = panoptic_r\n\n                # instance segmentation inference\n\n                if self.instance_on:\n                    mask_box_result = mask_box_result.to(mask_pred_result)\n                    height = new_size[0]/image_size[0]*height\n                    width = new_size[1]/image_size[1]*width\n                    mask_box_result = self.box_postprocess(mask_box_result, height, width)\n                    instance_r = retry_if_cuda_oom(self.instance_inference)(\n                        mask_cls_result[:self.sem_seg_head.predictor.num_queries_test],\n                        mask_pred_result[:self.sem_seg_head.predictor.num_queries_test],\n                        mask_box_result[:self.sem_seg_head.predictor.num_queries_test], True)\n                    # instance_r = retry_if_cuda_oom(self.instance_inference)(mask_cls_result, mask_pred_result, mask_box_result)\n                    processed_results[-1][\"instances\"] = instance_r\n            del mask_pred_results\n            return processed_results\n\n    def prepare_targets(self, targets, images, task='seg'):\n        h_pad, w_pad = images.tensor.shape[-2:]\n        new_targets = []\n        for targets_per_image in targets:\n            # pad gt\n            h, w = targets_per_image.image_size\n            image_size_xyxy = torch.as_tensor([w, h, w, h], dtype=torch.float, device=self.device)\n\n            if task != 'det':\n                gt_masks = targets_per_image.gt_masks\n                padded_masks = torch.zeros((gt_masks.shape[0], h_pad, w_pad), dtype=gt_masks.dtype, 
device=gt_masks.device)\n                padded_masks[:, : gt_masks.shape[1], : gt_masks.shape[2]] = gt_masks\n            else:\n                padded_masks = None\n            new_targets.append(\n                {\n                    \"labels\": targets_per_image.gt_classes,\n                    \"masks\": padded_masks,\n                    \"boxes\":box_ops.box_xyxy_to_cxcywh(targets_per_image.gt_boxes.tensor)/image_size_xyxy\n                }\n            )\n        return new_targets\n\n    def semantic_inference(self, mask_cls, mask_pred):\n        # if use cross-entropy loss in training, evaluate with softmax\n        if self.semantic_ce_loss:\n            mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1]\n            mask_pred = mask_pred.sigmoid()\n            semseg = torch.einsum(\"qc,qhw->chw\", mask_cls, mask_pred)\n            return semseg\n        # if use focal loss in training, evaluate with sigmoid. As sigmoid is mainly for detection and not sharp\n        # enough for semantic and panoptic segmentation, we additionally use use softmax with a temperature to\n        # make the score sharper.\n        else:\n            T = self.pano_temp\n            mask_cls = mask_cls.sigmoid()\n            if self.transform_eval:\n                mask_cls = F.softmax(mask_cls / T, dim=-1)  # already sigmoid\n            mask_pred = mask_pred.sigmoid()\n            semseg = torch.einsum(\"qc,qhw->chw\", mask_cls, mask_pred)\n            return semseg\n\n    def panoptic_inference(self, mask_cls, mask_pred):\n        # As we use focal loss in training, evaluate with sigmoid. 
As sigmoid is mainly for detection and not sharp\n        # enough for semantic and panoptic segmentation, we additionally use use softmax with a temperature to\n        # make the score sharper.\n        prob = 0.5\n        T = self.pano_temp\n        scores, labels = mask_cls.sigmoid().max(-1)\n        mask_pred = mask_pred.sigmoid()\n        keep = labels.ne(self.sem_seg_head.num_classes) & (scores > self.object_mask_threshold)\n        # added process\n        if self.transform_eval:\n            scores, labels = F.softmax(mask_cls.sigmoid() / T, dim=-1).max(-1)\n        cur_scores = scores[keep]\n        cur_classes = labels[keep]\n        cur_masks = mask_pred[keep]\n        cur_prob_masks = cur_scores.view(-1, 1, 1) * cur_masks\n\n        h, w = cur_masks.shape[-2:]\n        panoptic_seg = torch.zeros((h, w), dtype=torch.int32, device=cur_masks.device)\n        segments_info = []\n\n        current_segment_id = 0\n\n        if cur_masks.shape[0] == 0:\n            # We didn't detect any mask :(\n            return panoptic_seg, segments_info\n        else:\n            # take argmax\n            cur_mask_ids = cur_prob_masks.argmax(0)\n            stuff_memory_list = {}\n            for k in range(cur_classes.shape[0]):\n                pred_class = cur_classes[k].item()\n                isthing = pred_class in self.metadata.thing_dataset_id_to_contiguous_id.values()\n                mask_area = (cur_mask_ids == k).sum().item()\n                original_area = (cur_masks[k] >= prob).sum().item()\n                mask = (cur_mask_ids == k) & (cur_masks[k] >= prob)\n\n                if mask_area > 0 and original_area > 0 and mask.sum().item() > 0:\n                    if mask_area / original_area < self.overlap_threshold:\n                        continue\n\n                    # merge stuff regions\n                    if not isthing:\n                        if int(pred_class) in stuff_memory_list.keys():\n                            panoptic_seg[mask] = 
stuff_memory_list[int(pred_class)]\n                            continue\n                        else:\n                            stuff_memory_list[int(pred_class)] = current_segment_id + 1\n\n                    current_segment_id += 1\n                    panoptic_seg[mask] = current_segment_id\n\n                    segments_info.append(\n                        {\n                            \"id\": current_segment_id,\n                            \"isthing\": bool(isthing),\n                            \"category_id\": int(pred_class),\n                        }\n                    )\n\n            return panoptic_seg, segments_info\n\n    def instance_inference(self, mask_cls, mask_pred, mask_box_result,split_anno):\n        # mask_pred is already processed to have the same shape as original input\n        image_size = mask_pred.shape[-2:]\n        scores = mask_cls.sigmoid()  # [100, 80]\n        if split_anno:\n            labels = torch.arange(self.sem_seg_head.num_classes, device=self.device).unsqueeze(0).repeat(\n                self.sem_seg_head.predictor.num_queries_test, 1).flatten(0, 1)\n        else:\n            labels = torch.arange(self.sem_seg_head.num_classes, device=self.device).unsqueeze(0).repeat(self.num_queries, 1).flatten(0, 1)\n        scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.test_topk_per_image, sorted=False)  # select 100\n        labels_per_image = labels[topk_indices]\n        topk_indices = topk_indices // self.sem_seg_head.num_classes\n        mask_pred = mask_pred[topk_indices]\n        # if this is panoptic segmentation, we only keep the \"thing\" classes\n        if self.panoptic_on:\n            keep = torch.zeros_like(scores_per_image).bool()\n            for i, lab in enumerate(labels_per_image):\n                # print(i, len(keep), lab)\n                keep[i] = lab in self.metadata.thing_dataset_id_to_contiguous_id.values()\n            scores_per_image = scores_per_image[keep]\n            
labels_per_image = labels_per_image[keep]\n            mask_pred = mask_pred[keep]\n        result = Instances(image_size)\n        # mask (before sigmoid)\n        result.pred_masks = (mask_pred > 0).float()\n        # half mask box half pred box\n        mask_box_result = mask_box_result[topk_indices]\n        if self.panoptic_on:\n            mask_box_result = mask_box_result[keep]\n        result.pred_boxes = Boxes(mask_box_result)\n        # Uncomment the following to get boxes from masks (this is slow)\n        # result.pred_boxes = BitMasks(mask_pred > 0).get_bounding_boxes()\n\n        # calculate average mask prob\n        if self.sem_seg_postprocess_before_inference:\n            mask_scores_per_image = (mask_pred.sigmoid().flatten(1) * result.pred_masks.flatten(1)).sum(1) / (result.pred_masks.flatten(1).sum(1) + 1e-6)\n        else:\n            mask_scores_per_image = 1.0\n            # labels_per_image = labels_per_image + 1  # HACK for o365 classification\n        if self.focus_on_box:\n            mask_scores_per_image = 1.0\n        result.scores = scores_per_image * mask_scores_per_image\n        result.pred_classes = labels_per_image\n        return result\n\n    def box_postprocess(self, out_bbox, img_h, img_w):\n        # postprocess box height and width\n        boxes = box_ops.box_cxcywh_to_xyxy(out_bbox)\n        scale_fct = torch.tensor([img_w, img_h, img_w, img_h])\n        scale_fct = scale_fct.to(out_bbox)\n        boxes = boxes * scale_fct\n        return boxes\n\n@register_model\ndef get_segmentation_model(cfg, **kwargs):\n    return OpenSeeD(cfg)"
  },
  {
    "path": "llava/model/openseed/architectures/registry.py",
    "content": "_model_entrypoints = {}\n\ndef register_model(fn):\n    module_name_split = fn.__module__.split('.')\n    model_name = module_name_split[-1]\n    _model_entrypoints[model_name] = fn\n    return fn\n\ndef model_entrypoints(model_name):\n    return _model_entrypoints[model_name]\n\ndef is_model(model_name):\n    return model_name in _model_entrypoints"
  },
  {
    "path": "llava/model/openseed/backbone/__init__.py",
    "content": "from .build import build_backbone\n\nfrom .focal import *\nfrom .focal_dw import *\nfrom .swin import *\nfrom .backbone import *"
  },
  {
    "path": "llava/model/openseed/backbone/backbone.py",
    "content": "# Copyright (c) Facebook, Inc. and its affiliates.\nimport torch.nn as nn\n\nfrom detectron2.modeling import ShapeSpec\n\n# from ..layers import ShapeSpec\n\n__all__ = [\"Backbone\"]\n\n\nclass Backbone(nn.Module):\n    \"\"\"\n    Abstract base class for network backbones.\n    \"\"\"\n\n    def __init__(self):\n        \"\"\"\n        The `__init__` method of any subclass can specify its own set of arguments.\n        \"\"\"\n        super().__init__()\n\n    def forward(self):\n        \"\"\"\n        Subclasses must override this method, but adhere to the same return type.\n\n        Returns:\n            dict[str->Tensor]: mapping from feature name (e.g., \"res2\") to tensor\n        \"\"\"\n        pass\n\n    @property\n    def size_divisibility(self) -> int:\n        \"\"\"\n        Some backbones require the input height and width to be divisible by a\n        specific integer. This is typically true for encoder / decoder type networks\n        with lateral connection (e.g., FPN) for which feature maps need to match\n        dimension in the \"bottom up\" and \"top down\" paths. Set to 0 if no specific\n        input size divisibility is required.\n        \"\"\"\n        return 0\n\n    def output_shape(self):\n        \"\"\"\n        Returns:\n            dict[str->ShapeSpec]\n        \"\"\"\n        # this is a backward-compatible default\n        return {\n            name: ShapeSpec(\n                channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]\n            )\n            for name in self._out_features\n        }\n"
  },
  {
    "path": "llava/model/openseed/backbone/build.py",
    "content": "from .registry import model_entrypoints\nfrom .registry import is_model\n\nfrom .backbone import *\n\ndef build_backbone(config, **kwargs):\n    model_name = config['MODEL']['BACKBONE']['NAME']\n    if not is_model(model_name):\n        raise ValueError(f'Unkown model: {model_name}')\n\n    return model_entrypoints(model_name)(config, **kwargs)"
  },
  {
    "path": "llava/model/openseed/backbone/focal.py",
    "content": "# --------------------------------------------------------\n# FocalNet for Semantic Segmentation\n# Copyright (c) 2022 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# Written by Jianwei Yang\n# --------------------------------------------------------\nimport math\nimport time\nimport numpy as np\nimport logging\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.utils.checkpoint as checkpoint\nfrom timm.models.layers import DropPath, to_2tuple, trunc_normal_\n\nfrom detectron2.utils.file_io import PathManager\nfrom detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec\n\nfrom .registry import register_backbone\n\nlogger = logging.getLogger(__name__)\n\nclass Mlp(nn.Module):\n    \"\"\" Multilayer perceptron.\"\"\"\n\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features, out_features)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\nclass FocalModulation(nn.Module):\n    \"\"\" Focal Modulation\n\n    Args:\n        dim (int): Number of input channels.\n        proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0\n        focal_level (int): Number of focal levels\n        focal_window (int): Focal window size at focal level 1\n        focal_factor (int, default=2): Step to increase the focal window\n        use_postln (bool, default=False): Whether use post-modulation layernorm\n    \"\"\"\n\n    def __init__(self, dim, proj_drop=0., focal_level=2, focal_window=7, focal_factor=2, use_postln=False, use_postln_in_modulation=False, scaling_modulator=False):\n\n        super().__init__()\n        self.dim = dim\n\n        # specific args for focalv3\n        self.focal_level = focal_level\n        self.focal_window = focal_window\n        self.focal_factor = focal_factor\n        self.use_postln_in_modulation = use_postln_in_modulation\n        self.scaling_modulator = scaling_modulator\n\n        self.f = nn.Linear(dim, 2*dim+(self.focal_level+1), bias=True)\n        self.h = nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0, groups=1, bias=True)\n\n        self.act = nn.GELU()\n        self.proj = nn.Linear(dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n        self.focal_layers = nn.ModuleList()\n\n        if self.use_postln_in_modulation:\n            self.ln = nn.LayerNorm(dim)\n\n        for k in range(self.focal_level):\n            kernel_size = self.focal_factor*k + self.focal_window\n            self.focal_layers.append(\n                nn.Sequential(\n                    nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=1, groups=dim, \n                        padding=kernel_size//2, bias=False),\n                    nn.GELU(),\n                    )\n                )\n\n    def forward(self, x):\n        \"\"\" Forward function.\n\n        Args:\n            x: input features with shape of (B, H, W, C)\n        \"\"\"\n        B, nH, nW, C = x.shape\n        x = self.f(x)\n        x = x.permute(0, 3, 1, 2).contiguous()\n        q, ctx, gates = torch.split(x, (C, C, self.focal_level+1), 1)\n        \n        ctx_all = 0\n        for 
l in range(self.focal_level):                     \n            ctx = self.focal_layers[l](ctx)\n            ctx_all = ctx_all + ctx*gates[:, l:l+1]\n        ctx_global = self.act(ctx.mean(2, keepdim=True).mean(3, keepdim=True))\n        ctx_all = ctx_all + ctx_global*gates[:,self.focal_level:]\n\n        if self.scaling_modulator:\n            ctx_all = ctx_all / (self.focal_level + 1)\n\n        x_out = q * self.h(ctx_all)\n        x_out = x_out.permute(0, 2, 3, 1).contiguous()\n        if self.use_postln_in_modulation:\n            x_out = self.ln(x_out)            \n        x_out = self.proj(x_out)\n        x_out = self.proj_drop(x_out)\n        return x_out\n\nclass FocalModulationBlock(nn.Module):\n    \"\"\" Focal Modulation Block.\n\n    Args:\n        dim (int): Number of input channels.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n        drop (float, optional): Dropout rate. Default: 0.0\n        drop_path (float, optional): Stochastic depth rate. Default: 0.0\n        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU\n        norm_layer (nn.Module, optional): Normalization layer.  
Default: nn.LayerNorm\n        focal_level (int): number of focal levels\n        focal_window (int): focal kernel size at level 1\n    \"\"\"\n\n    def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0., \n                 act_layer=nn.GELU, norm_layer=nn.LayerNorm,\n                 focal_level=2, focal_window=9, \n                 use_postln=False, use_postln_in_modulation=False,\n                 scaling_modulator=False, \n                 use_layerscale=False, \n                 layerscale_value=1e-4):\n        super().__init__()\n        self.dim = dim\n        self.mlp_ratio = mlp_ratio\n        self.focal_window = focal_window\n        self.focal_level = focal_level\n        self.use_postln = use_postln\n        self.use_layerscale = use_layerscale\n\n        self.norm1 = norm_layer(dim)\n        self.modulation = FocalModulation(\n            dim, focal_window=self.focal_window, focal_level=self.focal_level, proj_drop=drop, use_postln_in_modulation=use_postln_in_modulation, scaling_modulator=scaling_modulator\n        )            \n\n        self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)\n\n        self.H = None\n        self.W = None\n\n        self.gamma_1 = 1.0\n        self.gamma_2 = 1.0\n        if self.use_layerscale:\n            self.gamma_1 = nn.Parameter(layerscale_value * torch.ones((dim)), requires_grad=True)\n            self.gamma_2 = nn.Parameter(layerscale_value * torch.ones((dim)), requires_grad=True)\n\n    def forward(self, x):\n        \"\"\" Forward function.\n\n        Args:\n            x: Input feature, tensor size (B, H*W, C).\n            H, W: Spatial resolution of the input feature.\n        \"\"\"\n        B, L, C = x.shape\n        H, W = self.H, self.W\n        assert L == H * W, \"input feature has wrong size\"\n\n        shortcut = x\n        if not self.use_postln:\n            x = self.norm1(x)\n        x = x.view(B, H, W, C)\n        \n        # FM\n        x = self.modulation(x).view(B, H * W, C)\n        if self.use_postln:\n            x = self.norm1(x)\n\n        # FFN\n        x = shortcut + self.drop_path(self.gamma_1 * x)\n\n        if self.use_postln:\n            x = x + self.drop_path(self.gamma_2 * self.norm2(self.mlp(x)))\n        else:\n            x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))\n\n        return x\n\nclass BasicLayer(nn.Module):\n    \"\"\" A basic focal modulation layer for one stage.\n\n    Args:\n        dim (int): Number of feature channels\n        depth (int): Depths of this stage.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.\n        drop (float, optional): Dropout rate. Default: 0.0\n        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0\n        norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm\n        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None\n        focal_level (int): Number of focal levels\n        focal_window (int): Focal window size at focal level 1\n        use_conv_embed (bool): Use overlapped convolution for patch embedding or now. Default: False\n        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False\n    \"\"\"\n\n    def __init__(self,\n                 dim,\n                 depth,\n                 mlp_ratio=4.,\n                 drop=0.,\n                 drop_path=0.,\n                 norm_layer=nn.LayerNorm,\n                 downsample=None,\n                 focal_window=9, \n                 focal_level=2, \n                 use_conv_embed=False,     \n                 use_postln=False,          \n                 use_postln_in_modulation=False, \n                 scaling_modulator=False,\n                 use_layerscale=False,                   \n                 use_checkpoint=False\n        ):\n        super().__init__()\n        self.depth = depth\n        self.use_checkpoint = use_checkpoint\n\n        # build blocks\n        self.blocks = nn.ModuleList([\n            FocalModulationBlock(\n                dim=dim,\n                mlp_ratio=mlp_ratio,\n                drop=drop,\n                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,\n                focal_window=focal_window, \n                focal_level=focal_level, \n                use_postln=use_postln, \n                use_postln_in_modulation=use_postln_in_modulation, \n                scaling_modulator=scaling_modulator,\n                use_layerscale=use_layerscale, \n                norm_layer=norm_layer)\n            for i in range(depth)])\n\n        # patch merging layer\n        if downsample is not None:\n            self.downsample = downsample(\n                patch_size=2,\n                in_chans=dim, 
embed_dim=2*dim, \n                use_conv_embed=use_conv_embed, \n                norm_layer=norm_layer, \n                is_stem=False\n            )\n\n        else:\n            self.downsample = None\n\n    def forward(self, x, H, W):\n        \"\"\" Forward function.\n\n        Args:\n            x: Input feature, tensor size (B, H*W, C).\n            H, W: Spatial resolution of the input feature.\n        \"\"\"\n        for blk in self.blocks:\n            blk.H, blk.W = H, W\n            if self.use_checkpoint:\n                x = checkpoint.checkpoint(blk, x)\n            else:\n                x = blk(x)\n        if self.downsample is not None:\n            x_reshaped = x.transpose(1, 2).view(x.shape[0], x.shape[-1], H, W)\n            x_down = self.downsample(x_reshaped)   \n            x_down = x_down.flatten(2).transpose(1, 2)            \n            Wh, Ww = (H + 1) // 2, (W + 1) // 2\n            return x, H, W, x_down, Wh, Ww\n        else:\n            return x, H, W, x, H, W\n\n\nclass PatchEmbed(nn.Module):\n    \"\"\" Image to Patch Embedding\n\n    Args:\n        patch_size (int): Patch token size. Default: 4.\n        in_chans (int): Number of input image channels. Default: 3.\n        embed_dim (int): Number of linear projection output channels. Default: 96.\n        norm_layer (nn.Module, optional): Normalization layer. Default: None\n        use_conv_embed (bool): Whether use overlapped convolution for patch embedding. Default: False\n        is_stem (bool): Is the stem block or not. 
\n    \"\"\"\n\n    def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None, use_conv_embed=False, is_stem=False):\n        super().__init__()\n        patch_size = to_2tuple(patch_size)\n        self.patch_size = patch_size\n\n        self.in_chans = in_chans\n        self.embed_dim = embed_dim\n\n        if use_conv_embed:\n            # if we choose to use conv embedding, then we treat the stem and non-stem differently\n            if is_stem:\n                kernel_size = 7; padding = 2; stride = 4\n            else:\n                kernel_size = 3; padding = 1; stride = 2\n            self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)                    \n        else:\n            self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)\n\n        if norm_layer is not None:\n            self.norm = norm_layer(embed_dim)\n        else:\n            self.norm = None\n\n    def forward(self, x):\n        \"\"\"Forward function.\"\"\"\n        _, _, H, W = x.size()\n        if W % self.patch_size[1] != 0:\n            x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))\n        if H % self.patch_size[0] != 0:\n            x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))\n\n        x = self.proj(x)  # B C Wh Ww\n        if self.norm is not None:\n            Wh, Ww = x.size(2), x.size(3)\n            x = x.flatten(2).transpose(1, 2)\n            x = self.norm(x)\n            x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)\n\n        return x\n\n\nclass FocalNet(nn.Module):\n    \"\"\" FocalNet backbone.\n\n    Args:\n        pretrain_img_size (int): Input image size for training the pretrained model,\n            used in absolute postion embedding. Default 224.\n        patch_size (int | tuple(int)): Patch size. Default: 4.\n        in_chans (int): Number of input image channels. 
Default: 3.\n        embed_dim (int): Number of linear projection output channels. Default: 96.\n        depths (tuple[int]): Depths of each Swin Transformer stage.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.\n        drop_rate (float): Dropout rate.\n        drop_path_rate (float): Stochastic depth rate. Default: 0.2.\n        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.\n        patch_norm (bool): If True, add normalization after patch embedding. Default: True.\n        out_indices (Sequence[int]): Output from which stages.\n        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).\n            -1 means not freezing any parameters.\n        focal_levels (Sequence[int]): Number of focal levels at four stages\n        focal_windows (Sequence[int]): Focal window sizes at first focal level at four stages\n        use_conv_embed (bool): Whether use overlapped convolution for patch embedding\n        use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False.\n    \"\"\"\n\n    def __init__(self,\n                 pretrain_img_size=1600,\n                 patch_size=4,\n                 in_chans=3,\n                 embed_dim=96,\n                 depths=[2, 2, 6, 2],\n                 mlp_ratio=4.,\n                 drop_rate=0.,\n                 drop_path_rate=0.2,\n                 norm_layer=nn.LayerNorm,\n                 patch_norm=True,\n                 out_indices=[0, 1, 2, 3],\n                 frozen_stages=-1,\n                 focal_levels=[2,2,2,2], \n                 focal_windows=[9,9,9,9],\n                 use_conv_embed=False, \n                 use_postln=False, \n                 use_postln_in_modulation=False, \n                 scaling_modulator=False,\n                 use_layerscale=False, \n                 use_checkpoint=False, \n        ):\n        super().__init__()\n\n        self.pretrain_img_size = pretrain_img_size\n        self.num_layers = len(depths)\n        self.embed_dim = embed_dim\n        self.patch_norm = patch_norm\n        self.out_indices = out_indices\n        self.frozen_stages = frozen_stages\n\n        # split image into non-overlapping patches\n        self.patch_embed = PatchEmbed(\n            patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,\n            norm_layer=norm_layer if self.patch_norm else None, \n            use_conv_embed=use_conv_embed, is_stem=True)\n\n        self.pos_drop = nn.Dropout(p=drop_rate)\n\n        # stochastic depth\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule\n\n        # build layers\n        self.layers = nn.ModuleList()\n        for i_layer in range(self.num_layers):\n            layer = BasicLayer(\n                dim=int(embed_dim * 2 ** i_layer),\n                depth=depths[i_layer],\n                mlp_ratio=mlp_ratio,\n                drop=drop_rate,\n                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],\n 
               norm_layer=norm_layer,\n                downsample=PatchEmbed if (i_layer < self.num_layers - 1) else None,\n                focal_window=focal_windows[i_layer], \n                focal_level=focal_levels[i_layer], \n                use_conv_embed=use_conv_embed,\n                use_postln=use_postln, \n                use_postln_in_modulation=use_postln_in_modulation,\n                scaling_modulator=scaling_modulator,\n                use_layerscale=use_layerscale, \n                use_checkpoint=use_checkpoint)\n            self.layers.append(layer)\n\n        num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]\n        self.num_features = num_features\n\n        # add a norm layer for each output\n        for i_layer in out_indices:\n            layer = norm_layer(num_features[i_layer])\n            layer_name = f'norm{i_layer}'\n            self.add_module(layer_name, layer)\n\n        self._freeze_stages()\n\n    def _freeze_stages(self):\n        if self.frozen_stages >= 0:\n            self.patch_embed.eval()\n            for param in self.patch_embed.parameters():\n                param.requires_grad = False\n\n        if self.frozen_stages >= 2:\n            self.pos_drop.eval()\n            for i in range(0, self.frozen_stages - 1):\n                m = self.layers[i]\n                m.eval()\n                for param in m.parameters():\n                    param.requires_grad = False\n\n    def init_weights(self, pretrained=None):\n        \"\"\"Initialize the weights in backbone.\n\n        Args:\n            pretrained (str, optional): Path to pre-trained weights.\n                Defaults to None.\n        \"\"\"\n\n        def _init_weights(m):\n            if isinstance(m, nn.Linear):\n                trunc_normal_(m.weight, std=.02)\n                if isinstance(m, nn.Linear) and m.bias is not None:\n                    nn.init.constant_(m.bias, 0)\n            elif isinstance(m, nn.LayerNorm):\n          
      nn.init.constant_(m.bias, 0)\n                nn.init.constant_(m.weight, 1.0)\n\n        if isinstance(pretrained, str):\n            self.apply(_init_weights)\n            logger = get_root_logger()\n            load_checkpoint(self, pretrained, strict=False, logger=logger)\n        elif pretrained is None:\n            self.apply(_init_weights)\n        else:\n            raise TypeError('pretrained must be a str or None')\n\n    def load_weights(self, pretrained_dict=None, pretrained_layers=[], verbose=True):\n        model_dict = self.state_dict()\n\n        missed_dict = [k for k in model_dict.keys() if k not in pretrained_dict]\n        logger.info(f'=> Missed keys {missed_dict}')\n        unexpected_dict = [k for k in pretrained_dict.keys() if k not in model_dict]\n        logger.info(f'=> Unexpected keys {unexpected_dict}')\n\n        pretrained_dict = {\n            k: v for k, v in pretrained_dict.items()\n            if k in model_dict.keys()\n        }\n        \n        need_init_state_dict = {}\n        for k, v in pretrained_dict.items():\n            need_init = (\n                (\n                    k.split('.')[0] in pretrained_layers\n                    or pretrained_layers[0] == '*'\n                )\n                and 'relative_position_index' not in k\n                and 'attn_mask' not in k\n            )\n\n            if need_init:\n                # if verbose:\n                #     logger.info(f'=> init {k} from {pretrained}')\n\n                if ('pool_layers' in k) or ('focal_layers' in k) and v.size() != model_dict[k].size():\n                    table_pretrained = v\n                    table_current = model_dict[k]\n                    fsize1 = table_pretrained.shape[2]\n                    fsize2 = table_current.shape[2]\n\n                    # NOTE: different from interpolation used in self-attention, we use padding or clipping for focal conv\n                    if fsize1 < fsize2:\n                        
table_pretrained_resized = torch.zeros(table_current.shape)\n                        table_pretrained_resized[:, :, (fsize2-fsize1)//2:-(fsize2-fsize1)//2, (fsize2-fsize1)//2:-(fsize2-fsize1)//2] = table_pretrained\n                        v = table_pretrained_resized\n                    elif fsize1 > fsize2:\n                        table_pretrained_resized = table_pretrained[:, :, (fsize1-fsize2)//2:-(fsize1-fsize2)//2, (fsize1-fsize2)//2:-(fsize1-fsize2)//2]\n                        v = table_pretrained_resized\n\n\n                if (\"modulation.f\" in k or \"pre_conv\" in k): \n                    table_pretrained = v\n                    table_current = model_dict[k]\n                    if table_pretrained.shape != table_current.shape:\n                        if len(table_pretrained.shape) == 2:\n                            dim = table_pretrained.shape[1]\n                            assert table_current.shape[1] == dim\n                            L1 = table_pretrained.shape[0]\n                            L2 = table_current.shape[0]\n\n                            if L1 < L2:\n                                table_pretrained_resized = torch.zeros(table_current.shape)\n                                # copy for linear project\n                                table_pretrained_resized[:2*dim] = table_pretrained[:2*dim]\n                                # copy for global token gating\n                                table_pretrained_resized[-1] = table_pretrained[-1]\n                                # copy for first multiple focal levels\n                                table_pretrained_resized[2*dim:2*dim+(L1-2*dim-1)] = table_pretrained[2*dim:-1]\n                                # reassign pretrained weights\n                                v = table_pretrained_resized\n                            elif L1 > L2:\n                                raise NotImplementedError\n                        elif len(table_pretrained.shape) == 1:\n                        
    dim = table_pretrained.shape[0]\n                            L1 = table_pretrained.shape[0]\n                            L2 = table_current.shape[0]\n                            if L1 < L2:\n                                table_pretrained_resized = torch.zeros(table_current.shape)\n                                # copy for linear project\n                                table_pretrained_resized[:dim] = table_pretrained[:dim]\n                                # copy for global token gating\n                                table_pretrained_resized[-1] = table_pretrained[-1]\n                                # copy for first multiple focal levels\n                                # table_pretrained_resized[dim:2*dim+(L1-2*dim-1)] = table_pretrained[2*dim:-1]\n                                # reassign pretrained weights\n                                v = table_pretrained_resized\n                            elif L1 > L2:\n                                raise NotImplementedError    \n\n                need_init_state_dict[k] = v\n        \n        self.load_state_dict(need_init_state_dict, strict=False)\n\n\n    def forward(self, x):\n        \"\"\"Forward function.\"\"\"\n        tic = time.time()\n        x = self.patch_embed(x)\n        Wh, Ww = x.size(2), x.size(3)\n\n        x = x.flatten(2).transpose(1, 2)\n        x = self.pos_drop(x)\n\n        outs = {}\n        for i in range(self.num_layers):\n            layer = self.layers[i]\n            x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)\n            if i in self.out_indices:\n                norm_layer = getattr(self, f'norm{i}')\n                x_out = norm_layer(x_out)\n\n                out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()\n                outs[\"res{}\".format(i + 2)] = out\n                \n        if len(self.out_indices) == 0:\n            outs[\"res5\"] = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()\n\n        toc = 
time.time()\n        return outs\n\n    def train(self, mode=True):\n        \"\"\"Convert the model into training mode while keep layers freezed.\"\"\"\n        super(FocalNet, self).train(mode)\n        self._freeze_stages()\n\n\nclass D2FocalNet(FocalNet, Backbone):\n    def __init__(self, cfg, input_shape):\n\n        pretrain_img_size = cfg['BACKBONE']['FOCAL']['PRETRAIN_IMG_SIZE']\n        patch_size = cfg['BACKBONE']['FOCAL']['PATCH_SIZE']\n        in_chans = 3\n        embed_dim = cfg['BACKBONE']['FOCAL']['EMBED_DIM']\n        depths = cfg['BACKBONE']['FOCAL']['DEPTHS']\n        mlp_ratio = cfg['BACKBONE']['FOCAL']['MLP_RATIO']\n        drop_rate = cfg['BACKBONE']['FOCAL']['DROP_RATE']\n        drop_path_rate = cfg['BACKBONE']['FOCAL']['DROP_PATH_RATE']\n        norm_layer = nn.LayerNorm\n        patch_norm = cfg['BACKBONE']['FOCAL']['PATCH_NORM']\n        use_checkpoint = cfg['BACKBONE']['FOCAL']['USE_CHECKPOINT']\n        out_indices = cfg['BACKBONE']['FOCAL']['OUT_INDICES']\n        scaling_modulator = cfg['BACKBONE']['FOCAL'].get('SCALING_MODULATOR', False)\n\n        super().__init__(\n            pretrain_img_size,\n            patch_size,\n            in_chans,\n            embed_dim,\n            depths,\n            mlp_ratio,\n            drop_rate,\n            drop_path_rate,\n            norm_layer,\n            patch_norm,\n            out_indices,\n            focal_levels=cfg['BACKBONE']['FOCAL']['FOCAL_LEVELS'],\n            focal_windows=cfg['BACKBONE']['FOCAL']['FOCAL_WINDOWS'],   \n            use_conv_embed=cfg['BACKBONE']['FOCAL']['USE_CONV_EMBED'],    \n            use_postln=cfg['BACKBONE']['FOCAL']['USE_POSTLN'],       \n            use_postln_in_modulation=cfg['BACKBONE']['FOCAL']['USE_POSTLN_IN_MODULATION'], \n            scaling_modulator=scaling_modulator,\n            use_layerscale=cfg['BACKBONE']['FOCAL']['USE_LAYERSCALE'], \n            use_checkpoint=use_checkpoint,\n        )\n\n        self._out_features = 
cfg['BACKBONE']['FOCAL']['OUT_FEATURES']\n\n        self._out_feature_strides = {\n            \"res2\": 4,\n            \"res3\": 8,\n            \"res4\": 16,\n            \"res5\": 32,\n        }\n        self._out_feature_channels = {\n            \"res2\": self.num_features[0],\n            \"res3\": self.num_features[1],\n            \"res4\": self.num_features[2],\n            \"res5\": self.num_features[3],\n        }\n\n    def forward(self, x):\n        \"\"\"\n        Args:\n            x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.\n        Returns:\n            dict[str->Tensor]: names and the corresponding features\n        \"\"\"\n        assert (\n            x.dim() == 4\n        ), f\"SwinTransformer takes an input of shape (N, C, H, W). Got {x.shape} instead!\"\n        outputs = {}\n        y = super().forward(x)\n        for k in y.keys():\n            if k in self._out_features:\n                outputs[k] = y[k]\n        return outputs\n\n    def output_shape(self):\n        return {\n            name: ShapeSpec(\n                channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]\n            )\n            for name in self._out_features\n        }\n\n    @property\n    def size_divisibility(self):\n        return 32\n\n@register_backbone\ndef get_focal_backbone(cfg):\n    focal = D2FocalNet(cfg['MODEL'], 224)    \n\n    if cfg['MODEL']['BACKBONE']['LOAD_PRETRAINED'] is True:\n        filename = cfg['MODEL']['BACKBONE']['PRETRAINED']\n        logger.info(f'=> init from {filename}')\n        with PathManager.open(filename, \"rb\") as f:\n            ckpt = torch.load(f)['model']\n        focal.load_weights(ckpt, cfg['MODEL']['BACKBONE']['FOCAL'].get('PRETRAINED_LAYERS', ['*']), cfg['VERBOSE'])\n\n    return focal"
  },
  {
    "path": "llava/model/openseed/backbone/focal_dw.py",
    "content": "# --------------------------------------------------------\n# FocalNet for Semantic Segmentation\n# Copyright (c) 2022 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# Written by Jianwei Yang\n# --------------------------------------------------------\nimport math\nimport time\nimport numpy as np\nimport logging\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.utils.checkpoint as checkpoint\nfrom timm.models.layers import DropPath, to_2tuple, trunc_normal_\n\nfrom detectron2.utils.file_io import PathManager\nfrom detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec\n\nfrom .registry import register_backbone\n\nlogger = logging.getLogger(__name__)\n\nclass Mlp(nn.Module):\n    \"\"\" Multilayer perceptron.\"\"\"\n\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features, out_features)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\nclass FocalModulation(nn.Module):\n    \"\"\" Focal Modulation\n\n    Args:\n        dim (int): Number of input channels.\n        proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0\n        focal_level (int): Number of focal levels\n        focal_window (int): Focal window size at focal level 1\n        focal_factor (int, default=2): Step to increase the focal window\n        use_postln (bool, default=False): Whether use post-modulation layernorm\n    \"\"\"\n\n    def __init__(self, dim, proj_drop=0., focal_level=2, focal_window=7, focal_factor=2, use_postln=False, use_postln_in_modulation=False, scaling_modulator=False):\n\n        super().__init__()\n        self.dim = dim\n\n        # specific args for focalv3\n        self.focal_level = focal_level\n        self.focal_window = focal_window\n        self.focal_factor = focal_factor\n        self.use_postln_in_modulation = use_postln_in_modulation\n        self.scaling_modulator = scaling_modulator\n\n        self.f = nn.Linear(dim, 2*dim+(self.focal_level+1), bias=True)\n        self.h = nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0, groups=1, bias=True)\n\n        self.act = nn.GELU()\n        self.proj = nn.Linear(dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n        self.focal_layers = nn.ModuleList()\n\n        if self.use_postln_in_modulation:\n            self.ln = nn.LayerNorm(dim)\n\n        for k in range(self.focal_level):\n            kernel_size = self.focal_factor*k + self.focal_window\n            self.focal_layers.append(\n                nn.Sequential(\n                    nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=1, groups=dim, \n                        padding=kernel_size//2, bias=False),\n                    nn.GELU(),\n                    )\n                )\n\n    def forward(self, x):\n        \"\"\" Forward function.\n\n        Args:\n            x: input features with shape of (B, H, W, C)\n        \"\"\"\n        B, nH, nW, C = x.shape\n        x = self.f(x)\n        x = x.permute(0, 3, 1, 2).contiguous()\n        q, ctx, gates = torch.split(x, (C, C, self.focal_level+1), 1)\n        \n        ctx_all = 0\n        for 
l in range(self.focal_level):                     \n            ctx = self.focal_layers[l](ctx)\n            ctx_all = ctx_all + ctx*gates[:, l:l+1]\n        ctx_global = self.act(ctx.mean(2, keepdim=True).mean(3, keepdim=True))\n        ctx_all = ctx_all + ctx_global*gates[:,self.focal_level:]\n\n        if self.scaling_modulator:\n            ctx_all = ctx_all / (self.focal_level + 1)\n\n        x_out = q * self.h(ctx_all)\n        x_out = x_out.permute(0, 2, 3, 1).contiguous()\n        if self.use_postln_in_modulation:\n            x_out = self.ln(x_out)            \n        x_out = self.proj(x_out)\n        x_out = self.proj_drop(x_out)\n        return x_out\n\nclass FocalModulationBlock(nn.Module):\n    \"\"\" Focal Modulation Block.\n\n    Args:\n        dim (int): Number of input channels.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n        drop (float, optional): Dropout rate. Default: 0.0\n        drop_path (float, optional): Stochastic depth rate. Default: 0.0\n        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU\n        norm_layer (nn.Module, optional): Normalization layer.  
Default: nn.LayerNorm\n        focal_level (int): number of focal levels\n        focal_window (int): focal kernel size at level 1\n    \"\"\"\n\n    def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0., \n                 act_layer=nn.GELU, norm_layer=nn.LayerNorm,\n                 focal_level=2, focal_window=9, \n                 use_postln=False, use_postln_in_modulation=False,\n                 scaling_modulator=False, \n                 use_layerscale=False, \n                 layerscale_value=1e-4):\n        super().__init__()\n        self.dim = dim\n        self.mlp_ratio = mlp_ratio\n        self.focal_window = focal_window\n        self.focal_level = focal_level\n        self.use_postln = use_postln\n        self.use_layerscale = use_layerscale\n\n        self.dw1 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim)\n        self.norm1 = norm_layer(dim)\n        self.modulation = FocalModulation(\n            dim, focal_window=self.focal_window, focal_level=self.focal_level, proj_drop=drop, use_postln_in_modulation=use_postln_in_modulation, scaling_modulator=scaling_modulator\n        )            \n\n        self.dw2 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim)\n        self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)\n\n        self.H = None\n        self.W = None\n\n        self.gamma_1 = 1.0\n        self.gamma_2 = 1.0\n        if self.use_layerscale:\n            self.gamma_1 = nn.Parameter(layerscale_value * torch.ones((dim)), requires_grad=True)\n            self.gamma_2 = nn.Parameter(layerscale_value * torch.ones((dim)), requires_grad=True)\n\n    def forward(self, x):\n        \"\"\" Forward function.\n\n        Args:\n            x: Input feature, tensor size (B, H*W, C).\n            H, W: Spatial resolution of the input feature.\n        \"\"\"\n        B, L, C = x.shape\n        H, W = self.H, self.W\n        assert L == H * W, \"input feature has wrong size\"\n\n        x = x.view(B, H, W, C).permute(0, 3, 1, 2).contiguous()\n        x = x + self.dw1(x)\n        x = x.permute(0, 2, 3, 1).contiguous().view(B, L, C)\n\n        shortcut = x\n        if not self.use_postln:\n            x = self.norm1(x)\n        x = x.view(B, H, W, C)\n        \n        # FM\n        x = self.modulation(x).view(B, H * W, C)\n        x = shortcut + self.drop_path(self.gamma_1 * x)\n        if self.use_postln:\n            x = self.norm1(x)\n\n        x = x.view(B, H, W, C).permute(0, 3, 1, 2).contiguous()\n        x = x + self.dw2(x)\n        x = x.permute(0, 2, 3, 1).contiguous().view(B, L, C)\n\n        if not self.use_postln:\n            x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))        \n        else:\n            x = x + self.drop_path(self.gamma_2 * self.mlp(x))\n            x = self.norm2(x)\n\n        return x\n\nclass BasicLayer(nn.Module):\n    \"\"\" A basic focal modulation layer for one stage.\n\n    Args:\n        dim (int): Number of feature channels\n        depth (int): Depths of this stage.\n        mlp_ratio (float): Ratio of mlp hidden dim 
to embedding dim. Default: 4.\n        drop (float, optional): Dropout rate. Default: 0.0\n        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0\n        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm\n        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None\n        focal_level (int): Number of focal levels\n        focal_window (int): Focal window size at focal level 1\n        use_conv_embed (bool): Use overlapped convolution for patch embedding or now. Default: False\n        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False\n    \"\"\"\n\n    def __init__(self,\n                 dim,\n                 depth,\n                 mlp_ratio=4.,\n                 drop=0.,\n                 drop_path=0.,\n                 norm_layer=nn.LayerNorm,\n                 downsample=None,\n                 focal_window=9, \n                 focal_level=2, \n                 use_conv_embed=False,     \n                 use_postln=False,          \n                 use_postln_in_modulation=False, \n                 scaling_modulator=False,\n                 use_layerscale=False,                   \n                 use_checkpoint=False, \n                 use_pre_norm=False, \n        ):\n        super().__init__()\n        self.depth = depth\n        self.use_checkpoint = use_checkpoint\n\n        # build blocks\n        self.blocks = nn.ModuleList([\n            FocalModulationBlock(\n                dim=dim,\n                mlp_ratio=mlp_ratio,\n                drop=drop,\n                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,\n                focal_window=focal_window, \n                focal_level=focal_level, \n                use_postln=use_postln, \n                use_postln_in_modulation=use_postln_in_modulation, \n                scaling_modulator=scaling_modulator,\n                
use_layerscale=use_layerscale, \n                norm_layer=norm_layer)\n            for i in range(depth)])\n\n        # patch merging layer\n        if downsample is not None:\n            self.downsample = downsample(\n                patch_size=2,\n                in_chans=dim, embed_dim=2*dim, \n                use_conv_embed=use_conv_embed, \n                norm_layer=norm_layer, \n                is_stem=False, \n                use_pre_norm=use_pre_norm\n            )\n\n        else:\n            self.downsample = None\n\n    def forward(self, x, H, W):\n        \"\"\" Forward function.\n\n        Args:\n            x: Input feature, tensor size (B, H*W, C).\n            H, W: Spatial resolution of the input feature.\n        \"\"\"\n        for blk in self.blocks:\n            blk.H, blk.W = H, W\n            if self.use_checkpoint:\n                x = checkpoint.checkpoint(blk, x)\n            else:\n                x = blk(x)\n        if self.downsample is not None:\n            x_reshaped = x.transpose(1, 2).view(x.shape[0], x.shape[-1], H, W)\n            x_down = self.downsample(x_reshaped)   \n            x_down = x_down.flatten(2).transpose(1, 2)            \n            Wh, Ww = (H + 1) // 2, (W + 1) // 2\n            return x, H, W, x_down, Wh, Ww\n        else:\n            return x, H, W, x, H, W\n\n\n# class PatchEmbed(nn.Module):\n#     r\"\"\" Image to Patch Embedding\n\n#     Args:\n#         img_size (int): Image size.  Default: 224.\n#         patch_size (int): Patch token size. Default: 4.\n#         in_chans (int): Number of input image channels. Default: 3.\n#         embed_dim (int): Number of linear projection output channels. Default: 96.\n#         norm_layer (nn.Module, optional): Normalization layer. 
Default: None\n#     \"\"\"\n\n#     def __init__(self, img_size=(224, 224), patch_size=4, in_chans=3, embed_dim=96, \n#         use_conv_embed=False, norm_layer=None, is_stem=False, use_pre_norm=False):\n#         super().__init__()\n#         patch_size = to_2tuple(patch_size)\n#         patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]\n#         self.img_size = img_size\n#         self.patch_size = patch_size\n#         self.patches_resolution = patches_resolution\n#         self.num_patches = patches_resolution[0] * patches_resolution[1]\n\n#         self.in_chans = in_chans\n#         self.embed_dim = embed_dim\n#         self.use_pre_norm = use_pre_norm\n\n#         if use_conv_embed:\n#             # if we choose to use conv embedding, then we treat the stem and non-stem differently\n#             if is_stem:\n#                 kernel_size = 7; padding = 3; stride = 4\n#             else:\n#                 kernel_size = 3; padding = 1; stride = 2\n#             self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)\n#         else:\n#             self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)\n        \n#         if self.use_pre_norm:\n#             if norm_layer is not None:\n#                 self.norm = norm_layer(in_chans)\n#             else:\n#                 self.norm = None\n#         else:\n#             if norm_layer is not None:\n#                 self.norm = norm_layer(embed_dim)\n#             else:\n#                 self.norm = None\n\n#     def forward(self, x):\n#         B, C, H, W = x.shape\n#         # FIXME look at relaxing size constraints\n#         assert H == self.img_size[0] and W == self.img_size[1], \\\n#             f\"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).\"\n        \n#         if self.use_pre_norm:\n#             if self.norm is not None:\n#                 x = 
x.flatten(2).transpose(1, 2)  # B Ph*Pw C\n#                 x = self.norm(x).transpose(1, 2).view(B, C, H, W)\n#             x = self.proj(x).flatten(2).transpose(1, 2)\n#         else:\n#             x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C\n#             if self.norm is not None:\n#                 x = self.norm(x)\n#         return x\n\n#     def flops(self):\n#         Ho, Wo = self.patches_resolution\n#         flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])\n#         if self.norm is not None:\n#             flops += Ho * Wo * self.embed_dim\n#         return flops\n\nclass PatchEmbed(nn.Module):\n    \"\"\" Image to Patch Embedding\n\n    Args:\n        patch_size (int): Patch token size. Default: 4.\n        in_chans (int): Number of input image channels. Default: 3.\n        embed_dim (int): Number of linear projection output channels. Default: 96.\n        norm_layer (nn.Module, optional): Normalization layer. Default: None\n        use_conv_embed (bool): Whether use overlapped convolution for patch embedding. Default: False\n        is_stem (bool): Is the stem block or not. 
\n    \"\"\"\n\n    def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None, use_conv_embed=False, is_stem=False, use_pre_norm=False):\n        super().__init__()\n        patch_size = to_2tuple(patch_size)\n        self.patch_size = patch_size\n\n        self.in_chans = in_chans\n        self.embed_dim = embed_dim\n        self.use_pre_norm = use_pre_norm\n\n        if use_conv_embed:\n            # if we choose to use conv embedding, then we treat the stem and non-stem differently\n            if is_stem:\n                kernel_size = 7; padding = 3; stride = 4\n            else:\n                kernel_size = 3; padding = 1; stride = 2\n            self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)                    \n        else:\n            self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)\n\n        if self.use_pre_norm:\n            if norm_layer is not None:\n                self.norm = norm_layer(in_chans)\n            else:\n                self.norm = None       \n        else:\n            if norm_layer is not None:\n                self.norm = norm_layer(embed_dim)\n            else:\n                self.norm = None\n\n    def forward(self, x):\n        \"\"\"Forward function.\"\"\"\n        B, C, H, W = x.size()\n        if W % self.patch_size[1] != 0:\n            x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))\n        if H % self.patch_size[0] != 0:\n            x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))\n\n        if self.use_pre_norm:\n            if self.norm is not None:\n                x = x.flatten(2).transpose(1, 2)  # B Ph*Pw C\n                x = self.norm(x).transpose(1, 2).view(B, C, H, W)\n            x = self.proj(x)\n        else:\n            x = self.proj(x)  # B C Wh Ww\n            if self.norm is not None:\n                Wh, Ww = x.size(2), x.size(3)\n                x = 
x.flatten(2).transpose(1, 2)\n                x = self.norm(x)\n                x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)\n\n        return x\n\n\nclass FocalNet(nn.Module):\n    \"\"\" FocalNet backbone.\n\n    Args:\n        pretrain_img_size (int): Input image size for training the pretrained model,\n            used in absolute postion embedding. Default 224.\n        patch_size (int | tuple(int)): Patch size. Default: 4.\n        in_chans (int): Number of input image channels. Default: 3.\n        embed_dim (int): Number of linear projection output channels. Default: 96.\n        depths (tuple[int]): Depths of each Swin Transformer stage.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.\n        drop_rate (float): Dropout rate.\n        drop_path_rate (float): Stochastic depth rate. Default: 0.2.\n        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.\n        patch_norm (bool): If True, add normalization after patch embedding. Default: True.\n        out_indices (Sequence[int]): Output from which stages.\n        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).\n            -1 means not freezing any parameters.\n        focal_levels (Sequence[int]): Number of focal levels at four stages\n        focal_windows (Sequence[int]): Focal window sizes at first focal level at four stages\n        use_conv_embed (bool): Whether use overlapped convolution for patch embedding\n        use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False.\n    \"\"\"\n\n    def __init__(self,\n                 pretrain_img_size=1600,\n                 patch_size=4,\n                 in_chans=3,\n                 embed_dim=96,\n                 depths=[2, 2, 6, 2],\n                 mlp_ratio=4.,\n                 drop_rate=0.,\n                 drop_path_rate=0.2,\n                 norm_layer=nn.LayerNorm,\n                 patch_norm=True,\n                 out_indices=[0, 1, 2, 3],\n                 frozen_stages=-1,\n                 focal_levels=[2,2,2,2], \n                 focal_windows=[9,9,9,9],\n                 use_pre_norms=[False, False, False, False], \n                 use_conv_embed=False, \n                 use_postln=False, \n                 use_postln_in_modulation=False, \n                 scaling_modulator=False,\n                 use_layerscale=False, \n                 use_checkpoint=False, \n        ):\n        super().__init__()\n\n        self.pretrain_img_size = pretrain_img_size\n        self.num_layers = len(depths)\n        self.embed_dim = embed_dim\n        self.patch_norm = patch_norm\n        self.out_indices = out_indices\n        self.frozen_stages = frozen_stages\n\n        # split image into non-overlapping patches\n        self.patch_embed = PatchEmbed(\n            patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,\n            norm_layer=norm_layer if self.patch_norm else None, \n            use_conv_embed=use_conv_embed, is_stem=True, use_pre_norm=False)\n\n        self.pos_drop = nn.Dropout(p=drop_rate)\n\n        # stochastic depth\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule\n\n        # build layers\n        self.layers = nn.ModuleList()\n        for i_layer in range(self.num_layers):\n            layer = BasicLayer(\n                dim=int(embed_dim * 2 ** i_layer),\n                depth=depths[i_layer],\n                mlp_ratio=mlp_ratio,\n                
drop=drop_rate,\n                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],\n                norm_layer=norm_layer,\n                downsample=PatchEmbed if (i_layer < self.num_layers - 1) else None,\n                focal_window=focal_windows[i_layer], \n                focal_level=focal_levels[i_layer], \n                use_pre_norm=use_pre_norms[i_layer], \n                use_conv_embed=use_conv_embed,\n                use_postln=use_postln, \n                use_postln_in_modulation=use_postln_in_modulation,\n                scaling_modulator=scaling_modulator,\n                use_layerscale=use_layerscale, \n                use_checkpoint=use_checkpoint)\n            self.layers.append(layer)\n\n        num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]\n        self.num_features = num_features        \n        # self.norm = norm_layer(num_features[-1])\n\n        # add a norm layer for each output\n        for i_layer in self.out_indices:\n            layer = norm_layer(num_features[i_layer])\n            layer_name = f'norm{i_layer}'\n            self.add_module(layer_name, layer)\n\n        self._freeze_stages()\n\n    def _freeze_stages(self):\n        if self.frozen_stages >= 0:\n            self.patch_embed.eval()\n            for param in self.patch_embed.parameters():\n                param.requires_grad = False\n\n        if self.frozen_stages >= 2:\n            self.pos_drop.eval()\n            for i in range(0, self.frozen_stages - 1):\n                m = self.layers[i]\n                m.eval()\n                for param in m.parameters():\n                    param.requires_grad = False\n\n    def init_weights(self, pretrained=None):\n        \"\"\"Initialize the weights in backbone.\n\n        Args:\n            pretrained (str, optional): Path to pre-trained weights.\n                Defaults to None.\n        \"\"\"\n\n        def _init_weights(m):\n            if isinstance(m, nn.Linear):\n       
         trunc_normal_(m.weight, std=.02)\n                if isinstance(m, nn.Linear) and m.bias is not None:\n                    nn.init.constant_(m.bias, 0)\n            elif isinstance(m, nn.LayerNorm):\n                nn.init.constant_(m.bias, 0)\n                nn.init.constant_(m.weight, 1.0)\n\n        if isinstance(pretrained, str):\n            self.apply(_init_weights)\n            logger = get_root_logger()\n            load_checkpoint(self, pretrained, strict=False, logger=logger)\n        elif pretrained is None:\n            self.apply(_init_weights)\n        else:\n            raise TypeError('pretrained must be a str or None')\n\n    def load_weights(self, pretrained_dict=None, pretrained_layers=[], verbose=True):\n        model_dict = self.state_dict()\n\n        missed_dict = [k for k in model_dict.keys() if k not in pretrained_dict]\n        logger.info(f'=> Missed keys {missed_dict}')\n        unexpected_dict = [k for k in pretrained_dict.keys() if k not in model_dict]\n        logger.info(f'=> Unexpected keys {unexpected_dict}')\n\n        pretrained_dict = {\n            k: v for k, v in pretrained_dict.items()\n            if k in model_dict.keys()\n        }\n        \n        need_init_state_dict = {}\n        for k, v in pretrained_dict.items():\n            need_init = (\n                (\n                    k.split('.')[0] in pretrained_layers\n                    or pretrained_layers[0] == '*'\n                )\n                and 'relative_position_index' not in k\n                and 'attn_mask' not in k\n            )\n\n            if need_init:\n                # if verbose:\n                #     logger.info(f'=> init {k} from {pretrained}')\n\n                if ('pool_layers' in k) or ('focal_layers' in k) and v.size() != model_dict[k].size():\n                    table_pretrained = v\n                    table_current = model_dict[k]\n                    fsize1 = table_pretrained.shape[2]\n                    fsize2 = 
table_current.shape[2]\n\n                    # NOTE: different from interpolation used in self-attention, we use padding or clipping for focal conv\n                    if fsize1 < fsize2:\n                        table_pretrained_resized = torch.zeros(table_current.shape)\n                        table_pretrained_resized[:, :, (fsize2-fsize1)//2:-(fsize2-fsize1)//2, (fsize2-fsize1)//2:-(fsize2-fsize1)//2] = table_pretrained\n                        v = table_pretrained_resized\n                    elif fsize1 > fsize2:\n                        table_pretrained_resized = table_pretrained[:, :, (fsize1-fsize2)//2:-(fsize1-fsize2)//2, (fsize1-fsize2)//2:-(fsize1-fsize2)//2]\n                        v = table_pretrained_resized\n\n\n                if (\"modulation.f\" in k or \"pre_conv\" in k): \n                    table_pretrained = v\n                    table_current = model_dict[k]\n                    if table_pretrained.shape != table_current.shape:\n                        if len(table_pretrained.shape) == 2:\n                            dim = table_pretrained.shape[1]\n                            assert table_current.shape[1] == dim\n                            L1 = table_pretrained.shape[0]\n                            L2 = table_current.shape[0]\n\n                            if L1 < L2:\n                                table_pretrained_resized = torch.zeros(table_current.shape)\n                                # copy for linear project\n                                table_pretrained_resized[:2*dim] = table_pretrained[:2*dim]\n                                # copy for global token gating\n                                table_pretrained_resized[-1] = table_pretrained[-1]\n                                # copy for first multiple focal levels\n                                table_pretrained_resized[2*dim:2*dim+(L1-2*dim-1)] = table_pretrained[2*dim:-1]\n                                # reassign pretrained weights\n                                v = 
table_pretrained_resized\n                            elif L1 > L2:\n                                raise NotImplementedError\n                        elif len(table_pretrained.shape) == 1:\n                            dim = table_pretrained.shape[0]\n                            L1 = table_pretrained.shape[0]\n                            L2 = table_current.shape[0]\n                            if L1 < L2:\n                                table_pretrained_resized = torch.zeros(table_current.shape)\n                                # copy for linear project\n                                table_pretrained_resized[:dim] = table_pretrained[:dim]\n                                # copy for global token gating\n                                table_pretrained_resized[-1] = table_pretrained[-1]\n                                # copy for first multiple focal levels\n                                # table_pretrained_resized[dim:2*dim+(L1-2*dim-1)] = table_pretrained[2*dim:-1]\n                                # reassign pretrained weights\n                                v = table_pretrained_resized\n                            elif L1 > L2:\n                                raise NotImplementedError    \n\n                need_init_state_dict[k] = v\n        \n        self.load_state_dict(need_init_state_dict, strict=False)\n\n\n    def forward(self, x):\n        \"\"\"Forward function.\"\"\"\n        tic = time.time()\n        x = self.patch_embed(x)\n        Wh, Ww = x.size(2), x.size(3)\n\n        x = x.flatten(2).transpose(1, 2)\n        x = self.pos_drop(x)\n\n        outs = {}\n        for i in range(self.num_layers):\n            layer = self.layers[i]\n            x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)\n            if i in self.out_indices:\n                norm_layer = getattr(self, f'norm{i}')\n                x_out = norm_layer(x_out)\n\n                out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()\n                
outs[\"res{}\".format(i + 2)] = out\n                \n        if len(self.out_indices) == 0:\n            outs[\"res5\"] = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()\n\n        toc = time.time()\n        return outs\n\n    def train(self, mode=True):\n        \"\"\"Convert the model into training mode while keep layers freezed.\"\"\"\n        super(FocalNet, self).train(mode)\n        self._freeze_stages()\n\n\nclass D2FocalNet(FocalNet, Backbone):\n    def __init__(self, cfg, input_shape):\n\n        pretrain_img_size = cfg['BACKBONE']['FOCAL']['PRETRAIN_IMG_SIZE']\n        patch_size = cfg['BACKBONE']['FOCAL']['PATCH_SIZE']\n        in_chans = 3\n        embed_dim = cfg['BACKBONE']['FOCAL']['EMBED_DIM']\n        depths = cfg['BACKBONE']['FOCAL']['DEPTHS']\n        mlp_ratio = cfg['BACKBONE']['FOCAL']['MLP_RATIO']\n        drop_rate = cfg['BACKBONE']['FOCAL']['DROP_RATE']\n        drop_path_rate = cfg['BACKBONE']['FOCAL']['DROP_PATH_RATE']\n        norm_layer = nn.LayerNorm\n        patch_norm = cfg['BACKBONE']['FOCAL']['PATCH_NORM']\n        use_checkpoint = cfg['BACKBONE']['FOCAL']['USE_CHECKPOINT']\n        out_indices = cfg['BACKBONE']['FOCAL']['OUT_INDICES']\n        scaling_modulator = cfg['BACKBONE']['FOCAL'].get('SCALING_MODULATOR', False)\n\n        super().__init__(\n            pretrain_img_size,\n            patch_size,\n            in_chans,\n            embed_dim,\n            depths,\n            mlp_ratio,\n            drop_rate,\n            drop_path_rate,\n            norm_layer,\n            patch_norm,\n            out_indices,\n            focal_levels=cfg['BACKBONE']['FOCAL']['FOCAL_LEVELS'],\n            focal_windows=cfg['BACKBONE']['FOCAL']['FOCAL_WINDOWS'],   \n            use_conv_embed=cfg['BACKBONE']['FOCAL']['USE_CONV_EMBED'],    \n            use_postln=cfg['BACKBONE']['FOCAL']['USE_POSTLN'],       \n            use_postln_in_modulation=cfg['BACKBONE']['FOCAL']['USE_POSTLN_IN_MODULATION'], \n        
    scaling_modulator=scaling_modulator,\n            use_layerscale=cfg['BACKBONE']['FOCAL']['USE_LAYERSCALE'], \n            use_checkpoint=use_checkpoint,\n        )\n\n        self._out_features = cfg['BACKBONE']['FOCAL']['OUT_FEATURES']\n\n        self._out_feature_strides = {\n            \"res2\": 4,\n            \"res3\": 8,\n            \"res4\": 16,\n            \"res5\": 32,\n        }\n        self._out_feature_channels = {\n            \"res2\": self.num_features[0],\n            \"res3\": self.num_features[1],\n            \"res4\": self.num_features[2],\n            \"res5\": self.num_features[3],\n        }\n\n    def forward(self, x):\n        \"\"\"\n        Args:\n            x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.\n        Returns:\n            dict[str->Tensor]: names and the corresponding features\n        \"\"\"\n        assert (\n            x.dim() == 4\n        ), f\"SwinTransformer takes an input of shape (N, C, H, W). 
Got {x.shape} instead!\"\n        outputs = {}\n        y = super().forward(x)\n        for k in y.keys():\n            if k in self._out_features:\n                outputs[k] = y[k]\n        return outputs\n\n    def output_shape(self):\n        return {\n            name: ShapeSpec(\n                channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]\n            )\n            for name in self._out_features\n        }\n\n    @property\n    def size_divisibility(self):\n        return 32\n\n@register_backbone\ndef get_focal_backbone(cfg):\n    focal = D2FocalNet(cfg['MODEL'], 224)    \n\n    if cfg['MODEL']['BACKBONE']['LOAD_PRETRAINED'] is True:\n        filename = cfg['MODEL']['BACKBONE']['PRETRAINED']\n        logger.info(f'=> init from {filename}')\n        with PathManager.open(filename, \"rb\") as f:\n            ckpt = torch.load(f)['model']\n        focal.load_weights(ckpt, cfg['MODEL']['BACKBONE']['FOCAL'].get('PRETRAINED_LAYERS', ['*']), cfg['VERBOSE'])\n\n    return focal"
  },
  {
    "path": "llava/model/openseed/backbone/registry.py",
    "content": "\"\"\"Registry that maps backbone names to their builder functions.\n\nA backbone module decorates its builder with ``@register_backbone``; the\nregistry key is the last component of the defining module's dotted name,\ne.g. ``llava.model.openseed.backbone.swin`` is registered as ``swin``.\n\"\"\"\nimport warnings\n\n_model_entrypoints = {}\n\n\ndef register_backbone(fn):\n    \"\"\"Register ``fn`` under the short name of the module that defines it.\n\n    Returns ``fn`` unchanged so this can be used as a decorator. Two modules\n    with the same final name (e.g. ``a/swin.py`` and ``b/swin.py``) map to the\n    same key; the later registration wins, and a warning is emitted instead\n    of silently shadowing the earlier builder.\n    \"\"\"\n    model_name = fn.__module__.rsplit('.', 1)[-1]\n    existing = _model_entrypoints.get(model_name)\n    # Re-registering from the same module (e.g. importlib.reload) is expected;\n    # only a collision between different modules is worth flagging.\n    if existing is not None and existing.__module__ != fn.__module__:\n        warnings.warn(\n            f\"backbone '{model_name}' from {existing.__module__} is being \"\n            f\"overridden by {fn.__module__}\",\n            stacklevel=2,\n        )\n    _model_entrypoints[model_name] = fn\n    return fn\n\n\ndef model_entrypoints(model_name):\n    \"\"\"Return the builder registered under ``model_name``.\n\n    Raises:\n        KeyError: if no backbone with that name has been registered; the\n            message lists the registered names to aid debugging.\n    \"\"\"\n    try:\n        return _model_entrypoints[model_name]\n    except KeyError:\n        raise KeyError(\n            f\"unknown backbone '{model_name}'; registered backbones: \"\n            f\"{sorted(_model_entrypoints)}\"\n        ) from None\n\n\ndef is_model(model_name):\n    \"\"\"Return True if a backbone builder is registered under ``model_name``.\"\"\"\n    return model_name in _model_entrypoints\n"
  },
  {
    "path": "llava/model/openseed/backbone/swin.py",
    "content": "# --------------------------------------------------------\n# Swin Transformer\n# Copyright (c) 2021 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# Written by Ze Liu, Yutong Lin, Yixuan Wei\n# --------------------------------------------------------\n\n# Copyright (c) Facebook, Inc. and its affiliates.\n# Modified by Bowen Cheng from https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation/blob/main/mmseg/models/backbones/swin_transformer.py\nimport logging\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.utils.checkpoint as checkpoint\nfrom timm.models.layers import DropPath, to_2tuple, trunc_normal_\n\nfrom detectron2.modeling import Backbone, ShapeSpec\nfrom detectron2.utils.file_io import PathManager\n\nfrom .registry import register_backbone\n\nlogger = logging.getLogger(__name__)\n\n\nclass Mlp(nn.Module):\n    \"\"\"Multilayer perceptron.\"\"\"\n\n    def __init__(\n        self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0\n    ):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features, out_features)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\n\ndef window_partition(x, window_size):\n    \"\"\"\n    Args:\n        x: (B, H, W, C)\n        window_size (int): window size\n    Returns:\n        windows: (num_windows*B, window_size, window_size, C)\n    \"\"\"\n    B, H, W, C = x.shape\n    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)\n    windows = x.permute(0, 1, 3, 2, 4, 
5).contiguous().view(-1, window_size, window_size, C)\n    return windows\n\n\ndef window_reverse(windows, window_size, H, W):\n    \"\"\"\n    Args:\n        windows: (num_windows*B, window_size, window_size, C)\n        window_size (int): Window size\n        H (int): Height of image\n        W (int): Width of image\n    Returns:\n        x: (B, H, W, C)\n    \"\"\"\n    B = int(windows.shape[0] / (H * W / window_size / window_size))\n    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)\n    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)\n    return x\n\n\nclass WindowAttention(nn.Module):\n    \"\"\"Window based multi-head self attention (W-MSA) module with relative position bias.\n    It supports both of shifted and non-shifted window.\n    Args:\n        dim (int): Number of input channels.\n        window_size (tuple[int]): The height and width of the window.\n        num_heads (int): Number of attention heads.\n        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True\n        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set\n        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0\n        proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0\n    \"\"\"\n\n    def __init__(\n        self,\n        dim,\n        window_size,\n        num_heads,\n        qkv_bias=True,\n        qk_scale=None,\n        attn_drop=0.0,\n        proj_drop=0.0,\n    ):\n\n        super().__init__()\n        self.dim = dim\n        self.window_size = window_size  # Wh, Ww\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        self.scale = qk_scale or head_dim ** -0.5\n\n        # define a parameter table of relative position bias\n        self.relative_position_bias_table = nn.Parameter(\n            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)\n        )  # 2*Wh-1 * 2*Ww-1, nH\n\n        # get pair-wise relative position index for each token inside the window\n        coords_h = torch.arange(self.window_size[0])\n        coords_w = torch.arange(self.window_size[1])\n        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww\n        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww\n        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww\n        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2\n        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0\n        relative_coords[:, :, 1] += self.window_size[1] - 1\n        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1\n        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww\n        self.register_buffer(\"relative_position_index\", relative_position_index)\n\n        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n\n        trunc_normal_(self.relative_position_bias_table, std=0.02)\n        self.softmax = nn.Softmax(dim=-1)\n\n    def forward(self, x, mask=None):\n        \"\"\"Forward function.\n        Args:\n    
        x: input features with shape of (num_windows*B, N, C)\n            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None\n        \"\"\"\n        B_, N, C = x.shape\n        qkv = (\n            self.qkv(x)\n            .reshape(B_, N, 3, self.num_heads, C // self.num_heads)\n            .permute(2, 0, 3, 1, 4)\n        )\n        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)\n\n        q = q * self.scale\n        attn = q @ k.transpose(-2, -1)\n        \n        relative_position_bias = self.relative_position_bias_table[\n            self.relative_position_index.view(-1)\n        ].view(\n            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1\n        )  # Wh*Ww,Wh*Ww,nH\n        relative_position_bias = relative_position_bias.permute(\n            2, 0, 1\n        ).contiguous()  # nH, Wh*Ww, Wh*Ww\n        attn = attn + relative_position_bias.unsqueeze(0)\n\n        if mask is not None:\n            nW = mask.shape[0]\n            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)\n            attn = attn.view(-1, self.num_heads, N, N)\n            attn = self.softmax(attn)\n        else:\n            attn = self.softmax(attn)\n\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n\n        return x\n\n\nclass SwinTransformerBlock(nn.Module):\n    \"\"\"Swin Transformer Block.\n    Args:\n        dim (int): Number of input channels.\n        num_heads (int): Number of attention heads.\n        window_size (int): Window size.\n        shift_size (int): Shift size for SW-MSA.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. 
Default: True\n        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.\n        drop (float, optional): Dropout rate. Default: 0.0\n        attn_drop (float, optional): Attention dropout rate. Default: 0.0\n        drop_path (float, optional): Stochastic depth rate. Default: 0.0\n        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU\n        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm\n    \"\"\"\n\n    def __init__(\n        self,\n        dim,\n        num_heads,\n        window_size=7,\n        shift_size=0,\n        mlp_ratio=4.0,\n        qkv_bias=True,\n        qk_scale=None,\n        drop=0.0,\n        attn_drop=0.0,\n        drop_path=0.0,\n        act_layer=nn.GELU,\n        norm_layer=nn.LayerNorm,\n    ):\n        super().__init__()\n        self.dim = dim\n        self.num_heads = num_heads\n        self.window_size = window_size\n        self.shift_size = shift_size\n        self.mlp_ratio = mlp_ratio\n        assert 0 <= self.shift_size < self.window_size, \"shift_size must in 0-window_size\"\n\n        self.norm1 = norm_layer(dim)\n        self.attn = WindowAttention(\n            dim,\n            window_size=to_2tuple(self.window_size),\n            num_heads=num_heads,\n            qkv_bias=qkv_bias,\n            qk_scale=qk_scale,\n            attn_drop=attn_drop,\n            proj_drop=drop,\n        )\n\n        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = Mlp(\n            in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop\n        )\n\n        self.H = None\n        self.W = None\n\n    def forward(self, x, mask_matrix):\n        \"\"\"Forward function.\n        Args:\n            x: Input feature, tensor size (B, H*W, C).\n            H, W: Spatial resolution of the input feature.\n    
        mask_matrix: Attention mask for cyclic shift.\n        \"\"\"\n        B, L, C = x.shape\n        H, W = self.H, self.W\n        assert L == H * W, \"input feature has wrong size\"\n\n        # HACK model will not upsampling\n        # if min([H, W]) <= self.window_size:\n            # if window size is larger than input resolution, we don't partition windows\n            # self.shift_size = 0\n            # self.window_size = min([H,W])\n\n        shortcut = x\n        x = self.norm1(x)\n        x = x.view(B, H, W, C)\n\n        # pad feature maps to multiples of window size\n        pad_l = pad_t = 0\n        pad_r = (self.window_size - W % self.window_size) % self.window_size\n        pad_b = (self.window_size - H % self.window_size) % self.window_size\n        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))\n        _, Hp, Wp, _ = x.shape\n\n        # cyclic shift\n        if self.shift_size > 0:\n            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))\n            attn_mask = mask_matrix\n        else:\n            shifted_x = x\n            attn_mask = None\n\n        # partition windows\n        x_windows = window_partition(\n            shifted_x, self.window_size\n        )  # nW*B, window_size, window_size, C\n        x_windows = x_windows.view(\n            -1, self.window_size * self.window_size, C\n        )  # nW*B, window_size*window_size, C\n\n        # W-MSA/SW-MSA\n        attn_windows = self.attn(x_windows, mask=attn_mask)  # nW*B, window_size*window_size, C\n\n        # merge windows\n        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)\n        shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)  # B H' W' C\n\n        # reverse cyclic shift\n        if self.shift_size > 0:\n            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))\n        else:\n            x = shifted_x\n\n        if pad_r > 0 or pad_b > 0:\n     
       x = x[:, :H, :W, :].contiguous()\n\n        x = x.view(B, H * W, C)\n\n        # FFN\n        x = shortcut + self.drop_path(x)\n        x = x + self.drop_path(self.mlp(self.norm2(x)))\n        return x\n\n\nclass PatchMerging(nn.Module):\n    \"\"\"Patch Merging Layer\n    Args:\n        dim (int): Number of input channels.\n        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm\n    \"\"\"\n\n    def __init__(self, dim, norm_layer=nn.LayerNorm):\n        super().__init__()\n        self.dim = dim\n        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)\n        self.norm = norm_layer(4 * dim)\n\n    def forward(self, x, H, W):\n        \"\"\"Forward function.\n        Args:\n            x: Input feature, tensor size (B, H*W, C).\n            H, W: Spatial resolution of the input feature.\n        \"\"\"\n        B, L, C = x.shape\n        assert L == H * W, \"input feature has wrong size\"\n\n        x = x.view(B, H, W, C)\n\n        # padding\n        pad_input = (H % 2 == 1) or (W % 2 == 1)\n        if pad_input:\n            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))\n\n        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C\n        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C\n        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C\n        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C\n        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C\n        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C\n\n        x = self.norm(x)\n        x = self.reduction(x)\n\n        return x\n\n\nclass BasicLayer(nn.Module):\n    \"\"\"A basic Swin Transformer layer for one stage.\n    Args:\n        dim (int): Number of feature channels\n        depth (int): Depths of this stage.\n        num_heads (int): Number of attention head.\n        window_size (int): Local window size. Default: 7.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.\n        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. 
Default: True\n        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.\n        drop (float, optional): Dropout rate. Default: 0.0\n        attn_drop (float, optional): Attention dropout rate. Default: 0.0\n        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0\n        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm\n        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None\n        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.\n    \"\"\"\n\n    def __init__(\n        self,\n        dim,\n        depth,\n        num_heads,\n        window_size=7,\n        mlp_ratio=4.0,\n        qkv_bias=True,\n        qk_scale=None,\n        drop=0.0,\n        attn_drop=0.0,\n        drop_path=0.0,\n        norm_layer=nn.LayerNorm,\n        downsample=None,\n        use_checkpoint=False,\n    ):\n        super().__init__()\n        self.window_size = window_size\n        self.shift_size = window_size // 2\n        self.depth = depth\n        self.use_checkpoint = use_checkpoint\n\n        # build blocks\n        self.blocks = nn.ModuleList(\n            [\n                SwinTransformerBlock(\n                    dim=dim,\n                    num_heads=num_heads,\n                    window_size=window_size,\n                    shift_size=0 if (i % 2 == 0) else window_size // 2,\n                    mlp_ratio=mlp_ratio,\n                    qkv_bias=qkv_bias,\n                    qk_scale=qk_scale,\n                    drop=drop,\n                    attn_drop=attn_drop,\n                    drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,\n                    norm_layer=norm_layer,\n                )\n                for i in range(depth)\n            ]\n        )\n\n        # patch merging layer\n        if downsample is not None:\n            self.downsample = 
downsample(dim=dim, norm_layer=norm_layer)\n        else:\n            self.downsample = None\n\n    def forward(self, x, H, W):\n        \"\"\"Forward function.\n        Args:\n            x: Input feature, tensor size (B, H*W, C).\n            H, W: Spatial resolution of the input feature.\n        \"\"\"\n\n        # calculate attention mask for SW-MSA\n        Hp = int(np.ceil(H / self.window_size)) * self.window_size\n        Wp = int(np.ceil(W / self.window_size)) * self.window_size\n        img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # 1 Hp Wp 1\n        h_slices = (\n            slice(0, -self.window_size),\n            slice(-self.window_size, -self.shift_size),\n            slice(-self.shift_size, None),\n        )\n        w_slices = (\n            slice(0, -self.window_size),\n            slice(-self.window_size, -self.shift_size),\n            slice(-self.shift_size, None),\n        )\n        cnt = 0\n        for h in h_slices:\n            for w in w_slices:\n                img_mask[:, h, w, :] = cnt\n                cnt += 1\n\n        mask_windows = window_partition(\n            img_mask, self.window_size\n        )  # nW, window_size, window_size, 1\n        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)\n        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)\n        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(\n            attn_mask == 0, float(0.0)\n        ).type(x.dtype)\n        \n        for blk in self.blocks:\n            blk.H, blk.W = H, W\n            if self.use_checkpoint:\n                x = checkpoint.checkpoint(blk, x, attn_mask)\n            else:\n                x = blk(x, attn_mask)\n        if self.downsample is not None:\n            x_down = self.downsample(x, H, W)\n            Wh, Ww = (H + 1) // 2, (W + 1) // 2\n            return x, H, W, x_down, Wh, Ww\n        else:\n            return x, H, W, x, H, W\n\n\nclass 
PatchEmbed(nn.Module):\n    \"\"\"Image to Patch Embedding\n    Args:\n        patch_size (int): Patch token size. Default: 4.\n        in_chans (int): Number of input image channels. Default: 3.\n        embed_dim (int): Number of linear projection output channels. Default: 96.\n        norm_layer (nn.Module, optional): Normalization layer. Default: None\n    \"\"\"\n\n    def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):\n        super().__init__()\n        patch_size = to_2tuple(patch_size)\n        self.patch_size = patch_size\n\n        self.in_chans = in_chans\n        self.embed_dim = embed_dim\n\n        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)\n        if norm_layer is not None:\n            self.norm = norm_layer(embed_dim)\n        else:\n            self.norm = None\n\n    def forward(self, x):\n        \"\"\"Forward function.\"\"\"\n        # padding\n        _, _, H, W = x.size()\n        if W % self.patch_size[1] != 0:\n            x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))\n        if H % self.patch_size[0] != 0:\n            x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))\n\n        x = self.proj(x)  # B C Wh Ww\n        if self.norm is not None:\n            Wh, Ww = x.size(2), x.size(3)\n            x = x.flatten(2).transpose(1, 2)\n            x = self.norm(x)\n            x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)\n\n        return x\n\n\nclass SwinTransformer(nn.Module):\n    \"\"\"Swin Transformer backbone.\n        A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`  -\n          https://arxiv.org/pdf/2103.14030\n    Args:\n        pretrain_img_size (int): Input image size for training the pretrained model,\n            used in absolute postion embedding. Default 224.\n        patch_size (int | tuple(int)): Patch size. 
Default: 4.\n        in_chans (int): Number of input image channels. Default: 3.\n        embed_dim (int): Number of linear projection output channels. Default: 96.\n        depths (tuple[int]): Depths of each Swin Transformer stage.\n        num_heads (tuple[int]): Number of attention head of each stage.\n        window_size (int): Window size. Default: 7.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.\n        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True\n        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.\n        drop_rate (float): Dropout rate.\n        attn_drop_rate (float): Attention dropout rate. Default: 0.\n        drop_path_rate (float): Stochastic depth rate. Default: 0.2.\n        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.\n        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.\n        patch_norm (bool): If True, add normalization after patch embedding. Default: True.\n        out_indices (Sequence[int]): Output from which stages.\n        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).\n            -1 means not freezing any parameters.\n        use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False.\n    \"\"\"\n\n    def __init__(\n        self,\n        pretrain_img_size=224,\n        patch_size=4,\n        in_chans=3,\n        embed_dim=96,\n        depths=[2, 2, 6, 2],\n        num_heads=[3, 6, 12, 24],\n        window_size=7,\n        mlp_ratio=4.0,\n        qkv_bias=True,\n        qk_scale=None,\n        drop_rate=0.0,\n        attn_drop_rate=0.0,\n        drop_path_rate=0.2,\n        norm_layer=nn.LayerNorm,\n        ape=False,\n        patch_norm=True,\n        out_indices=(0, 1, 2, 3),\n        frozen_stages=-1,\n        use_checkpoint=False,\n    ):\n        super().__init__()\n\n        self.pretrain_img_size = pretrain_img_size\n        self.num_layers = len(depths)\n        self.embed_dim = embed_dim\n        self.ape = ape\n        self.patch_norm = patch_norm\n        self.out_indices = out_indices\n        self.frozen_stages = frozen_stages\n\n        # split image into non-overlapping patches\n        self.patch_embed = PatchEmbed(\n            patch_size=patch_size,\n            in_chans=in_chans,\n            embed_dim=embed_dim,\n            norm_layer=norm_layer if self.patch_norm else None,\n        )\n\n        # absolute position embedding\n        if self.ape:\n            pretrain_img_size = to_2tuple(pretrain_img_size)\n            patch_size = to_2tuple(patch_size)\n            patches_resolution = [\n                pretrain_img_size[0] // patch_size[0],\n                pretrain_img_size[1] // patch_size[1],\n            ]\n\n            self.absolute_pos_embed = nn.Parameter(\n                torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])\n            )\n            trunc_normal_(self.absolute_pos_embed, std=0.02)\n\n        self.pos_drop = nn.Dropout(p=drop_rate)\n\n        # stochastic depth\n        dpr = [\n            x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))\n        ]  # stochastic depth decay rule\n\n        # build layers\n        self.layers = 
nn.ModuleList()\n        for i_layer in range(self.num_layers):\n            layer = BasicLayer(\n                dim=int(embed_dim * 2 ** i_layer),\n                depth=depths[i_layer],\n                num_heads=num_heads[i_layer],\n                window_size=window_size,\n                mlp_ratio=mlp_ratio,\n                qkv_bias=qkv_bias,\n                qk_scale=qk_scale,\n                drop=drop_rate,\n                attn_drop=attn_drop_rate,\n                drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],\n                norm_layer=norm_layer,\n                downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,\n                use_checkpoint=use_checkpoint,\n            )\n            self.layers.append(layer)\n\n        num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]\n        self.num_features = num_features\n\n        # add a norm layer for each output\n        for i_layer in out_indices:\n            layer = norm_layer(num_features[i_layer])\n            layer_name = f\"norm{i_layer}\"\n            self.add_module(layer_name, layer)\n\n        self._freeze_stages()\n\n    def _freeze_stages(self):\n        if self.frozen_stages >= 0:\n            self.patch_embed.eval()\n            for param in self.patch_embed.parameters():\n                param.requires_grad = False\n\n        if self.frozen_stages >= 1 and self.ape:\n            self.absolute_pos_embed.requires_grad = False\n\n        if self.frozen_stages >= 2:\n            self.pos_drop.eval()\n            for i in range(0, self.frozen_stages - 1):\n                m = self.layers[i]\n                m.eval()\n                for param in m.parameters():\n                    param.requires_grad = False\n\n    def init_weights(self, pretrained=None):\n        \"\"\"Initialize the weights in backbone.\n        Args:\n            pretrained (str, optional): Path to pre-trained weights.\n                Defaults to None.\n       
 \"\"\"\n\n        def _init_weights(m):\n            if isinstance(m, nn.Linear):\n                trunc_normal_(m.weight, std=0.02)\n                if isinstance(m, nn.Linear) and m.bias is not None:\n                    nn.init.constant_(m.bias, 0)\n            elif isinstance(m, nn.LayerNorm):\n                nn.init.constant_(m.bias, 0)\n                nn.init.constant_(m.weight, 1.0)\n\n\n    def load_weights(self, pretrained_dict=None, pretrained_layers=[], verbose=True):\n        model_dict = self.state_dict()\n        pretrained_dict = {\n            k: v for k, v in pretrained_dict.items()\n            if k in model_dict.keys()\n        }\n        need_init_state_dict = {}\n        for k, v in pretrained_dict.items():\n            need_init = (\n                    (\n                            k.split('.')[0] in pretrained_layers\n                            or pretrained_layers[0] == '*'\n                    )\n                    and 'relative_position_index' not in k\n                    and 'attn_mask' not in k\n            )\n\n            if need_init:\n                # if verbose:\n                #     logger.info(f'=> init {k} from {pretrained}')\n\n                if 'relative_position_bias_table' in k and v.size() != model_dict[k].size():\n                    relative_position_bias_table_pretrained = v\n                    relative_position_bias_table_current = model_dict[k]\n                    L1, nH1 = relative_position_bias_table_pretrained.size()\n                    L2, nH2 = relative_position_bias_table_current.size()\n                    if nH1 != nH2:\n                        logger.info(f\"Error in loading {k}, passing\")\n                    else:\n                        if L1 != L2:\n                            logger.info(\n                                '=> load_pretrained: resized variant: {} to {}'\n                                    .format((L1, nH1), (L2, nH2))\n                            )\n                          
  S1 = int(L1 ** 0.5)\n                            S2 = int(L2 ** 0.5)\n                            relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate(\n                                relative_position_bias_table_pretrained.permute(1, 0).view(1, nH1, S1, S1),\n                                size=(S2, S2),\n                                mode='bicubic')\n                            v = relative_position_bias_table_pretrained_resized.view(nH2, L2).permute(1, 0)\n\n                if 'absolute_pos_embed' in k and v.size() != model_dict[k].size():\n                    absolute_pos_embed_pretrained = v\n                    absolute_pos_embed_current = model_dict[k]\n                    _, L1, C1 = absolute_pos_embed_pretrained.size()\n                    _, L2, C2 = absolute_pos_embed_current.size()\n                    if C1 != C1:\n                        logger.info(f\"Error in loading {k}, passing\")\n                    else:\n                        if L1 != L2:\n                            logger.info(\n                                '=> load_pretrained: resized variant: {} to {}'\n                                    .format((1, L1, C1), (1, L2, C2))\n                            )\n                            S1 = int(L1 ** 0.5)\n                            S2 = int(L2 ** 0.5)\n                            absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.reshape(-1, S1, S1, C1)\n                            absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.permute(0, 3, 1, 2)\n                            absolute_pos_embed_pretrained_resized = torch.nn.functional.interpolate(\n                                absolute_pos_embed_pretrained, size=(S2, S2), mode='bicubic')\n                            v = absolute_pos_embed_pretrained_resized.permute(0, 2, 3, 1).flatten(1, 2)\n\n                need_init_state_dict[k] = v\n        self.load_state_dict(need_init_state_dict, strict=False)\n\n\n    def 
forward(self, x):\n        \"\"\"Forward function.\"\"\"\n        x = self.patch_embed(x)\n\n        Wh, Ww = x.size(2), x.size(3)\n        if self.ape:\n            # interpolate the position embedding to the corresponding size\n            absolute_pos_embed = F.interpolate(\n                self.absolute_pos_embed, size=(Wh, Ww), mode=\"bicubic\"\n            )\n            x = (x + absolute_pos_embed).flatten(2).transpose(1, 2)  # B Wh*Ww C\n        else:\n            x = x.flatten(2).transpose(1, 2)\n        x = self.pos_drop(x)\n\n        outs = {}\n        for i in range(self.num_layers):\n            layer = self.layers[i]\n            x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)\n\n            if i in self.out_indices:\n                norm_layer = getattr(self, f\"norm{i}\")\n                x_out = norm_layer(x_out)\n\n                out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()\n                outs[\"res{}\".format(i + 2)] = out\n\n        if len(self.out_indices) == 0:\n            outs[\"res5\"] = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()\n        \n\n        return outs\n\n    def train(self, mode=True):\n        \"\"\"Convert the model into training mode while keep layers freezed.\"\"\"\n        super(SwinTransformer, self).train(mode)\n        self._freeze_stages()\n\n\nclass D2SwinTransformer(SwinTransformer, Backbone):\n    def __init__(self, cfg, pretrain_img_size, patch_size, in_chans, embed_dim, \n                 depths, num_heads, window_size, mlp_ratio, qkv_bias, qk_scale,\n                 drop_rate, attn_drop_rate, drop_path_rate, norm_layer, ape, \n                 patch_norm, out_indices, use_checkpoint):\n        super().__init__(\n            pretrain_img_size,\n            patch_size,\n            in_chans,\n            embed_dim,\n            depths,\n            num_heads,\n            window_size,\n            mlp_ratio,\n            qkv_bias,\n            
qk_scale,\n            drop_rate,\n            attn_drop_rate,\n            drop_path_rate,\n            norm_layer,\n            ape,\n            patch_norm,\n            out_indices,\n            use_checkpoint=use_checkpoint,\n        )\n\n        self._out_features = cfg['OUT_FEATURES']\n\n        self._out_feature_strides = {\n            \"res2\": 4,\n            \"res3\": 8,\n            \"res4\": 16,\n            \"res5\": 32,\n        }\n        self._out_feature_channels = {\n            \"res2\": self.num_features[0],\n            \"res3\": self.num_features[1],\n            \"res4\": self.num_features[2],\n            \"res5\": self.num_features[3],\n        }\n\n    def forward(self, x):\n        \"\"\"\n        Args:\n            x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.\n        Returns:\n            dict[str->Tensor]: names and the corresponding features\n        \"\"\"\n        assert (\n            x.dim() == 4\n        ), f\"SwinTransformer takes an input of shape (N, C, H, W). 
Got {x.shape} instead!\"\n        outputs = {}\n        y = super().forward(x)\n        for k in y.keys():\n            if k in self._out_features:\n                outputs[k] = y[k]\n        return outputs\n\n    def output_shape(self):\n        feature_names = list(set(self._out_feature_strides.keys()) & set(self._out_features))\n        return {\n            name: ShapeSpec(\n                channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]\n            )\n            for name in feature_names\n        }\n\n    @property\n    def size_divisibility(self):\n        return 32\n\n\n@register_backbone\ndef get_swin_backbone(cfg):\n    swin_cfg = cfg['MODEL']['BACKBONE']['SWIN']\n\n    pretrain_img_size = swin_cfg['PRETRAIN_IMG_SIZE']\n    patch_size = swin_cfg['PATCH_SIZE']\n    in_chans = 3\n    embed_dim = swin_cfg['EMBED_DIM']\n    depths = swin_cfg['DEPTHS']\n    num_heads = swin_cfg['NUM_HEADS']\n    window_size = swin_cfg['WINDOW_SIZE']\n    mlp_ratio = swin_cfg['MLP_RATIO']\n    qkv_bias = swin_cfg['QKV_BIAS']\n    qk_scale = swin_cfg['QK_SCALE']\n    drop_rate = swin_cfg['DROP_RATE']\n    attn_drop_rate = swin_cfg['ATTN_DROP_RATE']\n    drop_path_rate = swin_cfg['DROP_PATH_RATE']\n    norm_layer = nn.LayerNorm\n    ape = swin_cfg['APE']\n    patch_norm = swin_cfg['PATCH_NORM']\n    use_checkpoint = swin_cfg['USE_CHECKPOINT']\n    out_indices = swin_cfg.get('OUT_INDICES', [0,1,2,3])\n    \n    swin = D2SwinTransformer(\n        swin_cfg,\n        pretrain_img_size,\n        patch_size,\n        in_chans,\n        embed_dim,\n        depths,\n        num_heads,\n        window_size,\n        mlp_ratio,\n        qkv_bias,\n        qk_scale,\n        drop_rate,\n        attn_drop_rate,\n        drop_path_rate,\n        norm_layer,\n        ape,\n        patch_norm,\n        out_indices,\n        use_checkpoint=use_checkpoint,\n    )    \n\n    if cfg['MODEL']['BACKBONE']['LOAD_PRETRAINED'] is True:\n        filename = 
cfg['MODEL']['BACKBONE']['PRETRAINED']\n        with PathManager.open(filename, \"rb\") as f:\n            ckpt = torch.load(f, map_location='cpu')['model']\n        swin.load_weights(ckpt, swin_cfg.get('PRETRAINED_LAYERS', ['*']), cfg['VERBOSE'])\n\n    return swin"
  },
  {
    "path": "llava/model/openseed/body/__init__.py",
    "content": "from .build import build_openseed_head"
  },
  {
    "path": "llava/model/openseed/body/build.py",
    "content": "from .registry import model_entrypoints\nfrom .registry import is_model\nfrom .openseed_head import *\n\n\ndef build_openseed_head(config, *args, **kwargs):\n    model_name = config['MODEL']['HEAD']\n    if not is_model(model_name):\n        raise ValueError(f'Unkown model: {model_name}')\n\n    body = model_entrypoints(model_name)(config, *args, **kwargs)\n    return body"
  },
  {
    "path": "llava/model/openseed/body/decoder/__init__.py",
    "content": "from .build import build_decoder\nfrom .openseed_decoder import *\nfrom .openseed_decoder_decouple import *"
  },
  {
    "path": "llava/model/openseed/body/decoder/build.py",
    "content": "from .registry import model_entrypoints\nfrom .registry import is_model\n\n\ndef build_decoder(config, *args, **kwargs):\n    model_name = config['MODEL']['DECODER']['NAME']\n\n    if not is_model(model_name):\n        raise ValueError(f'Unkown model: {model_name}')\n\n    return model_entrypoints(model_name)(config, *args, **kwargs)"
  },
  {
    "path": "llava/model/openseed/body/decoder/modules.py",
    "content": "from typing import Optional\n\nimport torch\nfrom torch import nn, Tensor\nfrom torch.nn import functional as F\n\nfrom timm.models.layers import trunc_normal_\nfrom detectron2.layers import Conv2d\nimport fvcore.nn.weight_init as weight_init\n\n\nclass SelfAttentionLayer(nn.Module):\n\n    def __init__(self, d_model, nhead, dropout=0.0,\n                 activation=\"relu\", normalize_before=False):\n        super().__init__()\n        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)\n\n        self.norm = nn.LayerNorm(d_model)\n        self.dropout = nn.Dropout(dropout)\n\n        self.activation = _get_activation_fn(activation)\n        self.normalize_before = normalize_before\n\n        self._reset_parameters()\n    \n    def _reset_parameters(self):\n        for p in self.parameters():\n            if p.dim() > 1:\n                nn.init.xavier_uniform_(p)\n\n    def with_pos_embed(self, tensor, pos: Optional[Tensor]):\n        return tensor if pos is None else tensor + pos\n\n    def forward_post(self, tgt,\n                     tgt_mask: Optional[Tensor] = None,\n                     tgt_key_padding_mask: Optional[Tensor] = None,\n                     query_pos: Optional[Tensor] = None):\n        q = k = self.with_pos_embed(tgt, query_pos)\n        tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,\n                              key_padding_mask=tgt_key_padding_mask)[0]\n        tgt = tgt + self.dropout(tgt2)\n        tgt = self.norm(tgt)\n\n        return tgt\n\n    def forward_pre(self, tgt,\n                    tgt_mask: Optional[Tensor] = None,\n                    tgt_key_padding_mask: Optional[Tensor] = None,\n                    query_pos: Optional[Tensor] = None):\n        tgt2 = self.norm(tgt)\n        q = k = self.with_pos_embed(tgt2, query_pos)\n        tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,\n                              key_padding_mask=tgt_key_padding_mask)[0]\n        tgt = 
tgt + self.dropout(tgt2)\n        \n        return tgt\n\n    def forward(self, tgt,\n                tgt_mask: Optional[Tensor] = None,\n                tgt_key_padding_mask: Optional[Tensor] = None,\n                query_pos: Optional[Tensor] = None):\n        if self.normalize_before:\n            return self.forward_pre(tgt, tgt_mask,\n                                    tgt_key_padding_mask, query_pos)\n        return self.forward_post(tgt, tgt_mask,\n                                 tgt_key_padding_mask, query_pos)\n\n\nclass CrossAttentionLayer(nn.Module):\n\n    def __init__(self, d_model, nhead, dropout=0.0,\n                 activation=\"relu\", normalize_before=False):\n        super().__init__()\n        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)\n\n        self.norm = nn.LayerNorm(d_model)\n        self.dropout = nn.Dropout(dropout)\n\n        self.activation = _get_activation_fn(activation)\n        self.normalize_before = normalize_before\n\n        self._reset_parameters()\n    \n    def _reset_parameters(self):\n        for p in self.parameters():\n            if p.dim() > 1:\n                nn.init.xavier_uniform_(p)\n\n    def with_pos_embed(self, tensor, pos: Optional[Tensor]):\n        return tensor if pos is None else tensor + pos\n\n    def forward_post(self, tgt, memory,\n                     memory_mask: Optional[Tensor] = None,\n                     memory_key_padding_mask: Optional[Tensor] = None,\n                     pos: Optional[Tensor] = None,\n                     query_pos: Optional[Tensor] = None):\n        tgt2, avg_attn = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),\n                                   key=self.with_pos_embed(memory, pos),\n                                   value=memory, attn_mask=memory_mask,\n                                   key_padding_mask=memory_key_padding_mask)\n        tgt = tgt + self.dropout(tgt2)\n        tgt = self.norm(tgt)\n        return tgt, 
avg_attn\n\n    def forward_pre(self, tgt, memory,\n                    memory_mask: Optional[Tensor] = None,\n                    memory_key_padding_mask: Optional[Tensor] = None,\n                    pos: Optional[Tensor] = None,\n                    query_pos: Optional[Tensor] = None):\n        tgt2 = self.norm(tgt)\n        tgt2, avg_attn = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),\n                                   key=self.with_pos_embed(memory, pos),\n                                   value=memory, attn_mask=memory_mask,\n                                   key_padding_mask=memory_key_padding_mask)\n        tgt = tgt + self.dropout(tgt2)\n\n        return tgt, avg_attn\n\n    def forward(self, tgt, memory,\n                memory_mask: Optional[Tensor] = None,\n                memory_key_padding_mask: Optional[Tensor] = None,\n                pos: Optional[Tensor] = None,\n                query_pos: Optional[Tensor] = None):\n        if self.normalize_before:\n            return self.forward_pre(tgt, memory, memory_mask,\n                                    memory_key_padding_mask, pos, query_pos)\n        return self.forward_post(tgt, memory, memory_mask,\n                                 memory_key_padding_mask, pos, query_pos)\n\n\nclass FFNLayer(nn.Module):\n\n    def __init__(self, d_model, dim_feedforward=2048, dropout=0.0,\n                 activation=\"relu\", normalize_before=False):\n        super().__init__()\n        # Implementation of Feedforward model\n        self.linear1 = nn.Linear(d_model, dim_feedforward)\n        self.dropout = nn.Dropout(dropout)\n        self.linear2 = nn.Linear(dim_feedforward, d_model)\n\n        self.norm = nn.LayerNorm(d_model)\n\n        self.activation = _get_activation_fn(activation)\n        self.normalize_before = normalize_before\n\n        self._reset_parameters()\n    \n    def _reset_parameters(self):\n        for p in self.parameters():\n            if p.dim() > 1:\n                
nn.init.xavier_uniform_(p)\n\n    def with_pos_embed(self, tensor, pos: Optional[Tensor]):\n        return tensor if pos is None else tensor + pos\n\n    def forward_post(self, tgt):\n        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))\n        tgt = tgt + self.dropout(tgt2)\n        tgt = self.norm(tgt)\n        return tgt\n\n    def forward_pre(self, tgt):\n        tgt2 = self.norm(tgt)\n        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))\n        tgt = tgt + self.dropout(tgt2)\n        return tgt\n\n    def forward(self, tgt):\n        if self.normalize_before:\n            return self.forward_pre(tgt)\n        return self.forward_post(tgt)\n\n\ndef _get_activation_fn(activation):\n    \"\"\"Return an activation function given a string\"\"\"\n    if activation == \"relu\":\n        return F.relu\n    if activation == \"gelu\":\n        return F.gelu\n    if activation == \"glu\":\n        return F.glu\n    raise RuntimeError(F\"activation should be relu/gelu, not {activation}.\")\n\n\nclass MLP(nn.Module):\n    \"\"\" Very simple multi-layer perceptron (also called FFN)\"\"\"\n\n    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):\n        super().__init__()\n        self.num_layers = num_layers\n        h = [hidden_dim] * (num_layers - 1)\n        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))\n\n    def forward(self, x):\n        for i, layer in enumerate(self.layers):\n            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)\n        return x\n"
  },
  {
    "path": "llava/model/openseed/body/decoder/openseed_decoder.py",
    "content": "# ------------------------------------------------------------------------\n# DINO\n# Copyright (c) 2023 IDEA. All Rights Reserved.\n# Licensed under the Apache License, Version 2.0 [see LICENSE for details]\n# ------------------------------------------------------------------------\n# Modified from MaskDINO https://github.com/IDEA-Research/MaskDINO by Feng Li and Hao Zhang.\n# ------------------------------------------------------------------------\nimport logging\nimport fvcore.nn.weight_init as weight_init\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nfrom detectron2.layers import Conv2d\nfrom detectron2.utils.registry import Registry\nfrom detectron2.structures import BitMasks\nfrom timm.models.layers import trunc_normal_\n\nfrom .registry import register_decoder\nfrom .utils.dino_decoder import TransformerDecoder, DeformableTransformerDecoderLayer\nfrom .utils import MLP, gen_encoder_output_proposals, inverse_sigmoid\nfrom ...utils import box_ops\nfrom ...utils import configurable\n\n\nclass OpenSeeDDecoder(nn.Module):\n    @configurable\n    def __init__(\n            self,\n            # lang_encoder: nn.Module,\n            in_channels,\n            mask_classification=True,\n            *,\n            num_classes: int,\n            hidden_dim: int,\n            dim_proj: int,\n            num_queries: int,\n            nheads: int,\n            dim_feedforward: int,\n            dec_layers: int,\n            mask_dim: int,\n            enforce_input_project: bool,\n            two_stage: bool,\n            dn: str,\n            noise_scale:float,\n            dn_num:int,\n            initialize_box_type:bool,\n            initial_pred:bool,\n            learn_tgt: bool,\n            total_num_feature_levels: int = 4,\n            dropout: float = 0.0,\n            activation: str = 'relu',\n            nhead: int = 8,\n            dec_n_points: int = 4,\n            return_intermediate_dec: bool = True,\n     
       query_dim: int = 4,\n            dec_layer_share: bool = False,\n            semantic_ce_loss: bool = False,\n    ):\n        \"\"\"\n        NOTE: this interface is experimental.\n        Args:\n            in_channels: channels of the input features\n            mask_classification: whether to add mask classifier or not\n            num_classes: number of classes\n            hidden_dim: Transformer feature dimension\n            num_queries: number of queries\n            nheads: number of heads\n            dim_feedforward: feature dimension in feedforward network\n            enc_layers: number of Transformer encoder layers\n            dec_layers: number of Transformer decoder layers\n            pre_norm: whether to use pre-LayerNorm or not\n            mask_dim: mask feature dimension\n            enforce_input_project: add input project 1x1 conv even if input\n                channels and hidden dim is identical\n            d_model: transformer dimension\n            dropout: dropout rate\n            activation: activation function\n            nhead: num heads in multi-head attention\n            dec_n_points: number of sampling points in decoder\n            return_intermediate_dec: return the intermediate results of decoder\n            query_dim: 4 -> (x, y, w, h)\n            dec_layer_share: whether to share each decoder layer\n            semantic_ce_loss: use ce loss for semantic segmentation\n        \"\"\"\n        super().__init__()\n\n        assert mask_classification, \"Only support mask classification model\"\n        self.mask_classification = mask_classification\n        self.num_feature_levels = total_num_feature_levels\n        self.initial_pred = initial_pred\n\n        # define Transformer decoder here\n        self.dn=dn\n        self.learn_tgt = learn_tgt\n        self.noise_scale=noise_scale\n        self.dn_num=dn_num\n        self.num_heads = nheads\n        self.num_layers = dec_layers\n        self.two_stage=two_stage\n 
       self.initialize_box_type = initialize_box_type\n        self.total_num_feature_levels = total_num_feature_levels\n\n        self.num_queries = num_queries\n        self.semantic_ce_loss = semantic_ce_loss\n        # learnable query features\n        if not two_stage or self.learn_tgt:\n            self.query_feat = nn.Embedding(num_queries, hidden_dim)\n        if not two_stage and initialize_box_type == 'no':\n            self.query_embed = nn.Embedding(num_queries, 4)\n        if two_stage:\n            self.enc_output = nn.Linear(hidden_dim, hidden_dim)\n            self.enc_output_norm = nn.LayerNorm(hidden_dim)\n\n        self.input_proj = nn.ModuleList()\n        for _ in range(self.num_feature_levels):\n            if in_channels != hidden_dim or enforce_input_project:\n                self.input_proj.append(Conv2d(in_channels, hidden_dim, kernel_size=1))\n                weight_init.c2_xavier_fill(self.input_proj[-1])\n            else:\n                self.input_proj.append(nn.Sequential())\n        self.num_classes=num_classes\n        # output FFNs\n        assert self.mask_classification, \"why not class embedding?\"\n        # self.label_enc=nn.Embedding(505, hidden_dim)  # this is a hack for o365+coco (365+133=498)\n        self.dim_proj = dim_proj\n        # self.lang_encoder = lang_encoder\n        self.lang_mapper = nn.Parameter(torch.empty(dim_proj, hidden_dim))\n        trunc_normal_(self.lang_mapper, std=.02)\n        self.class_embed = nn.Parameter(torch.empty(hidden_dim, dim_proj))\n        trunc_normal_(self.class_embed, std=.02)\n\n        self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)\n\n        # init decoder\n        self.decoder_norm = decoder_norm = nn.LayerNorm(hidden_dim)\n        decoder_layer = DeformableTransformerDecoderLayer(hidden_dim, dim_feedforward,\n                                                          dropout, activation,\n                                                          
self.num_feature_levels, nhead, dec_n_points)\n        self.decoder = TransformerDecoder(decoder_layer, self.num_layers, decoder_norm,\n                                          return_intermediate=return_intermediate_dec,\n                                          d_model=hidden_dim, query_dim=query_dim,\n                                          num_feature_levels=self.num_feature_levels,\n                                          dec_layer_share=dec_layer_share,\n                                          )\n\n        self.hidden_dim = hidden_dim\n        self._bbox_embed = _bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)\n        nn.init.constant_(_bbox_embed.layers[-1].weight.data, 0)\n        nn.init.constant_(_bbox_embed.layers[-1].bias.data, 0)\n        box_embed_layerlist = [_bbox_embed for i in range(self.num_layers)]  # share box prediction each layer\n        self.bbox_embed = nn.ModuleList(box_embed_layerlist)\n        self.decoder.bbox_embed = self.bbox_embed\n        self.logit_scale = nn.Parameter(torch.ones([]))\n        self.default_text_embeddings = None #for grounding tokens\n        self.default_text_embeddings_mask = None #for grounding tokens\n\n    @classmethod\n    def from_config(cls, cfg, in_channels, mask_classification, extra):\n        ret = {}\n        ret[\"in_channels\"] = in_channels\n        # ret[\"lang_encoder\"] = lang_encoder\n        ret[\"mask_classification\"] = mask_classification\n\n        enc_cfg = cfg['MODEL']['ENCODER']\n        dec_cfg = cfg['MODEL']['DECODER']\n\n        ret[\"num_classes\"] = enc_cfg['NUM_CLASSES']\n        ret[\"hidden_dim\"] = dec_cfg['HIDDEN_DIM']\n        ret[\"dim_proj\"] = cfg['MODEL']['DIM_PROJ']\n        ret[\"num_queries\"] = dec_cfg['NUM_OBJECT_QUERIES']\n\n        # Transformer parameters:\n        ret[\"nheads\"] = dec_cfg['NHEADS']\n        ret[\"dim_feedforward\"] = dec_cfg['DIM_FEEDFORWARD']\n        ret[\"dec_layers\"] = dec_cfg['DEC_LAYERS']\n        ret[\"enforce_input_project\"] = 
dec_cfg['ENFORCE_INPUT_PROJ']\n        ret[\"mask_dim\"] = enc_cfg['MASK_DIM']\n        ret[\"two_stage\"] = dec_cfg['TWO_STAGE']\n        ret[\"initialize_box_type\"] = dec_cfg['INITIALIZE_BOX_TYPE']  # ['no', 'bitmask', 'mask2box']\n        ret[\"dn\"] = dec_cfg['DN']\n        ret[\"noise_scale\"] = dec_cfg['DN_NOISE_SCALE']\n        ret[\"dn_num\"] = dec_cfg['DN_NUM']\n        ret[\"initial_pred\"] = dec_cfg['INITIAL_PRED']\n        ret[\"learn_tgt\"] = dec_cfg['LEARN_TGT']\n        ret[\"total_num_feature_levels\"] = dec_cfg['TOTAL_NUM_FEATURE_LEVELS']\n        ret[\"semantic_ce_loss\"] = dec_cfg['TEST']['SEMANTIC_ON'] and dec_cfg['SEMANTIC_CE_LOSS'] and not dec_cfg['TEST']['PANOPTIC_ON']\n\n        return ret\n\n    def prepare_for_dn(self, targets, tgt, refpoint_emb, batch_size):\n        \"\"\"\n        modified from dn-detr. You can refer to dn-detr\n        https://github.com/IDEA-Research/DN-DETR/blob/main/models/dn_dab_deformable_detr/dn_components.py\n        for more details\n            :param dn_args: scalar, noise_scale\n            :param tgt: original tgt (content) in the matching part\n            :param refpoint_emb: positional anchor queries in the matching part\n            :param batch_size: bs\n            \"\"\"\n        if self.training:\n            scalar, noise_scale = self.dn_num, self.noise_scale\n\n            known = [(torch.ones_like(t['labels'])).cuda() for t in targets]\n            know_idx = [torch.nonzero(t) for t in known]\n            known_num = [sum(k) for k in known]\n\n            # use fix number of dn queries\n            if max(known_num) > 0:\n                scalar = scalar // (int(max(known_num)))\n            else:\n                scalar = 0\n            if scalar == 0:\n                input_query_label = None\n                input_query_bbox = None\n                attn_mask = None\n                mask_dict = None\n                return input_query_label, input_query_bbox, attn_mask, mask_dict\n\n            
# can be modified to selectively denosie some label or boxes; also known label prediction\n            unmask_bbox = unmask_label = torch.cat(known)\n            labels = torch.cat([t['labels'] for t in targets])\n            # use languge as denosing content queries.\n            # if task == 'det':\n            #     labels = labels  # o365 start from 133 class\n            boxes = torch.cat([t['boxes'] for t in targets])\n            batch_idx = torch.cat([torch.full_like(t['labels'].long(), i) for i, t in enumerate(targets)])\n            # known\n            known_indice = torch.nonzero(unmask_label + unmask_bbox)\n            known_indice = known_indice.view(-1)\n\n            # noise\n            known_indice = known_indice.repeat(scalar, 1).view(-1)\n            known_labels = labels.repeat(scalar, 1).view(-1)\n            known_bid = batch_idx.repeat(scalar, 1).view(-1)\n            known_bboxs = boxes.repeat(scalar, 1)\n            known_labels_expaned = known_labels.clone()\n            known_bbox_expand = known_bboxs.clone()\n\n            if noise_scale > 0:\n                diff = torch.zeros_like(known_bbox_expand)\n                diff[:, :2] = known_bbox_expand[:, 2:] / 2\n                diff[:, 2:] = known_bbox_expand[:, 2:]\n                known_bbox_expand += torch.mul((torch.rand_like(known_bbox_expand) * 2 - 1.0),\n                                               diff).cuda() * noise_scale\n                known_bbox_expand = known_bbox_expand.clamp(min=0.0, max=1.0)\n\n            m = known_labels_expaned.long().to('cuda')\n            # import ipdb; ipdb.set_trace()\n            input_label_embed = torch.gather(self.default_text_embeddings, 0,\n                                             m[:, None].repeat(1, self.dim_proj)) @ self.lang_mapper\n\n            input_bbox_embed = inverse_sigmoid(known_bbox_expand)\n            single_pad = int(max(known_num))\n            pad_size = int(single_pad * scalar)\n\n            padding_label = 
input_label_embed.new_zeros(pad_size, self.hidden_dim)\n            padding_bbox = input_bbox_embed.new_zeros(pad_size, 4)\n\n            if not refpoint_emb is None:\n                input_query_label = torch.cat([padding_label, tgt], dim=0).repeat(batch_size, 1, 1)\n                input_query_bbox = torch.cat([padding_bbox, refpoint_emb], dim=0).repeat(batch_size, 1, 1)\n            else:\n                input_query_label = padding_label.repeat(batch_size, 1, 1)\n                input_query_bbox = padding_bbox.repeat(batch_size, 1, 1)\n\n            # map\n            map_known_indice = input_label_embed.new_tensor([])\n            if len(known_num):\n                map_known_indice = torch.cat(\n                    [input_label_embed.new_tensor(range(num)) for num in known_num])  # [1,2, 1,2,3]\n                map_known_indice = torch.cat([map_known_indice + single_pad * i for i in range(scalar)]).long()\n            if len(known_bid):\n                input_query_label[(known_bid.long(), map_known_indice)] = input_label_embed\n                input_query_bbox[(known_bid.long(), map_known_indice)] = input_bbox_embed\n\n            tgt_size = pad_size + self.num_queries\n            attn_mask = input_label_embed.new_ones(tgt_size, tgt_size) < 0\n            # match query cannot see the reconstruct\n            attn_mask[pad_size:, :pad_size] = True\n            # reconstruct cannot see each other\n            for i in range(scalar):\n                if i == 0:\n                    attn_mask[single_pad * i:single_pad * (i + 1), single_pad * (i + 1):pad_size] = True\n                if i == scalar - 1:\n                    attn_mask[single_pad * i:single_pad * (i + 1), :single_pad * i] = True\n                else:\n                    attn_mask[single_pad * i:single_pad * (i + 1), single_pad * (i + 1):pad_size] = True\n                    attn_mask[single_pad * i:single_pad * (i + 1), :single_pad * i] = True\n            mask_dict = {\n                
'known_indice': torch.as_tensor(known_indice).long(),\n                'batch_idx': torch.as_tensor(batch_idx).long(),\n                'map_known_indice': torch.as_tensor(map_known_indice).long(),\n                'known_lbs_bboxes': (known_labels, known_bboxs),\n                'know_idx': know_idx,\n                'pad_size': pad_size,\n                'scalar': scalar,\n            }\n        else:\n            if not refpoint_emb is None:\n                input_query_label = tgt.repeat(batch_size, 1, 1)\n                input_query_bbox = refpoint_emb.repeat(batch_size, 1, 1)\n            else:\n                input_query_label = None\n                input_query_bbox = None\n            attn_mask = None\n            mask_dict = None\n\n        # 100*batch*256\n        if not input_query_bbox is None:\n            input_query_label = input_query_label\n            input_query_bbox = input_query_bbox\n\n        return input_query_label, input_query_bbox, attn_mask, mask_dict\n\n    def dn_post_process(self,outputs_class,outputs_coord,mask_dict,outputs_mask):\n        \"\"\"\n            post process of dn after output from the transformer\n            put the dn part in the mask_dict\n            \"\"\"\n        assert mask_dict['pad_size'] > 0\n        output_known_class = outputs_class[:, :, :mask_dict['pad_size'], :]\n        outputs_class = outputs_class[:, :, mask_dict['pad_size']:, :]\n        output_known_coord = outputs_coord[:, :, :mask_dict['pad_size'], :]\n        outputs_coord = outputs_coord[:, :, mask_dict['pad_size']:, :]\n        output_known_mask = None\n        if outputs_mask is not None:\n            output_known_mask = outputs_mask[:, :, :mask_dict['pad_size'], :]\n            outputs_mask = outputs_mask[:, :, mask_dict['pad_size']:, :]\n        out = {'pred_logits': output_known_class[-1], 'pred_boxes': output_known_coord[-1],'pred_masks': None if output_known_mask is None else output_known_mask[-1]}\n\n        out['aux_outputs'] = 
self._set_aux_loss(output_known_class, output_known_mask,output_known_coord)\n        mask_dict['output_known_lbs_bboxes']=out\n        return outputs_class, outputs_coord, outputs_mask\n\n    def get_valid_ratio(self, mask):\n        _, H, W = mask.shape\n        valid_H = torch.sum(~mask[:, :, 0], 1)\n        valid_W = torch.sum(~mask[:, 0, :], 1)\n        valid_ratio_h = valid_H.float() / H\n        valid_ratio_w = valid_W.float() / W\n        valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)\n        return valid_ratio\n\n    def pred_box(self, reference, hs, ref0=None):\n        \"\"\"\n        :param reference: reference box coordinates from each decoder layer\n        :param hs: content\n        :param ref0: whether there are prediction from the first layer\n        \"\"\"\n        if ref0 is None:\n            outputs_coord_list = []\n        else:\n            outputs_coord_list = [ref0]\n        for dec_lid, (layer_ref_sig, layer_bbox_embed, layer_hs) in enumerate(zip(reference[:-1], self.bbox_embed, hs)):\n            layer_delta_unsig = layer_bbox_embed(layer_hs)\n            layer_outputs_unsig = layer_delta_unsig + inverse_sigmoid(layer_ref_sig)\n            layer_outputs_unsig = layer_outputs_unsig.sigmoid()\n            outputs_coord_list.append(layer_outputs_unsig)\n        outputs_coord_list = torch.stack(outputs_coord_list)\n        return outputs_coord_list\n\n    def compute_similarity(self, v_emb,name='default'):\n        v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)\n        t_emb = self.default_text_embeddings\n        output = self.logit_scale.exp() * v_emb @ t_emb.transpose(1, 2)\n        output[~self.default_text_embeddings_mask[:,None].repeat(1,output.shape[1],1)] = -100.\n        # output = v_emb @ t_emb.unsqueeze(0).transpose(1, 2)\n        return output\n\n    def forward(self, x, mask_features, masks, targets=None, target_queries=None, target_vlp=None, task='seg',default_text_embeddings=None, extra={}):\n      
  \"\"\"\n        task: seg/det\n        \"\"\"\n        self.default_text_embeddings,self.default_text_embeddings_mask=default_text_embeddings\n        self.dn=\"no\"\n        assert len(x) == self.num_feature_levels\n        do_seg = (task != 'det')   # if task is det, not do segmentation training\n        size_list = []\n        # disable mask, it does not affect performance\n        enable_mask = 0\n        if masks is not None:\n            for src in x:\n                if src.size(2) % 32 or src.size(3) % 32:\n                    enable_mask = 1\n        if enable_mask == 0:\n            masks = [torch.zeros((src.size(0), src.size(2), src.size(3)), device=src.device, dtype=torch.bool) for src in x]\n        src_flatten = []\n        mask_flatten = []\n        spatial_shapes = []\n        for i in range(self.num_feature_levels):\n            idx=self.num_feature_levels-1-i\n            bs, c , h, w=x[idx].shape\n            size_list.append(x[i].shape[-2:])\n            spatial_shapes.append(x[idx].shape[-2:])\n            src_flatten.append(self.input_proj[idx](x[idx]).flatten(2).transpose(1, 2))\n            mask_flatten.append(masks[i].flatten(1))\n        src_flatten = torch.cat(src_flatten, 1)  # bs, \\sum{hxw}, c\n        mask_flatten = torch.cat(mask_flatten, 1)  # bs, \\sum{hxw}\n        spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device)\n        level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))\n        valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)\n\n        predictions_class = []\n        predictions_mask = []\n        if self.two_stage:\n            output_memory, output_proposals = gen_encoder_output_proposals(src_flatten, mask_flatten, spatial_shapes)\n            output_memory = self.enc_output_norm(self.enc_output(output_memory))\n            output_memory_ = output_memory @ self.class_embed\n            
enc_outputs_class_unselected = self.compute_similarity(output_memory_,default_text_embeddings)\n            enc_outputs_class_unselected[output_proposals.sum(-1).isinf()] = float(\"-inf\")\n            # enc_outputs_class_unselected = self.class_embed(output_memory)\n            enc_outputs_coord_unselected = self._bbox_embed(\n                output_memory) + output_proposals  # (bs, \\sum{hw}, 4) unsigmoid\n            topk = self.num_queries\n            topk_proposals = torch.topk(enc_outputs_class_unselected.max(-1)[0], topk, dim=1)[1]\n            refpoint_embed_undetach = torch.gather(enc_outputs_coord_unselected, 1,\n                                                   topk_proposals.unsqueeze(-1).repeat(1, 1, 4))  # unsigmoid\n            refpoint_embed = refpoint_embed_undetach.detach()\n\n            tgt_undetach = torch.gather(output_memory, 1,\n                                  topk_proposals.unsqueeze(-1).repeat(1, 1, self.hidden_dim))  # unsigmoid\n            outputs_class, outputs_mask = self.forward_prediction_heads(tgt_undetach.transpose(0, 1), mask_features, do_seg)\n            tgt = tgt_undetach.detach()\n            if self.learn_tgt:\n                tgt = self.query_feat.weight[None].repeat(bs, 1, 1)\n            interm_outputs=dict()\n            interm_outputs['pred_logits'] = outputs_class\n            interm_outputs['pred_boxes'] = refpoint_embed_undetach.sigmoid()\n            interm_outputs['pred_masks'] = outputs_mask\n\n            if self.initialize_box_type != 'no' and do_seg:\n                # convert masks into boxes to better initialize box in the decoder\n                assert self.initial_pred\n                flaten_mask = outputs_mask.detach().flatten(0, 1)\n                h, w = outputs_mask.shape[-2:]\n                if self.initialize_box_type == 'bitmask':  # slower, but more accurate\n                    refpoint_embed = BitMasks(flaten_mask > 0).get_bounding_boxes().tensor.cuda()\n                elif 
self.initialize_box_type == 'mask2box':  # faster conversion\n                    refpoint_embed = box_ops.masks_to_boxes(flaten_mask > 0).cuda()\n                else:\n                    assert NotImplementedError\n                refpoint_embed = box_ops.box_xyxy_to_cxcywh(refpoint_embed) / torch.as_tensor([w, h, w, h],\n                                                                                              dtype=torch.float).cuda()\n                refpoint_embed = refpoint_embed.reshape(outputs_mask.shape[0], outputs_mask.shape[1], 4)\n                refpoint_embed = inverse_sigmoid(refpoint_embed)\n        elif not self.two_stage:\n            tgt = self.query_feat.weight[None].repeat(bs, 1, 1)\n            refpoint_embed = self.query_embed.weight[None].repeat(bs, 1, 1)\n\n        tgt_mask = None\n        mask_dict = None\n        if self.dn != \"no\" and self.training:\n            assert targets is not None\n            input_query_label, input_query_bbox, tgt_mask, mask_dict = \\\n                self.prepare_for_dn(targets, None, None, x[0].shape[0])\n            if mask_dict is not None:\n                tgt=torch.cat([input_query_label, tgt],dim=1)\n\n        # direct prediction from the matching and denoising part in the begining\n        if self.initial_pred:\n            outputs_class, outputs_mask = self.forward_prediction_heads(tgt.transpose(0, 1), mask_features, self.training and do_seg)\n            predictions_class.append(outputs_class)\n            predictions_mask.append(outputs_mask)\n        if self.dn != \"no\" and self.training and mask_dict is not None:\n            refpoint_embed=torch.cat([input_query_bbox,refpoint_embed],dim=1)\n        # print('tgt',tgt.dtype)\n        # print('src_flatten',src_flatten.dtype)\n        # print('refpoint',refpoint_embed.dtype)\n        tgt=tgt.to(src_flatten.dtype)\n        refpoint_embed=refpoint_embed.to(src_flatten.dtype)\n        hs, references = self.decoder(\n            
tgt=tgt.transpose(0, 1),\n            memory=src_flatten.transpose(0, 1),\n            memory_key_padding_mask=mask_flatten,\n            pos=None,\n            refpoints_unsigmoid=refpoint_embed.transpose(0, 1),\n            level_start_index=level_start_index,\n            spatial_shapes=spatial_shapes,\n            valid_ratios=valid_ratios,\n            tgt_mask=tgt_mask\n        )\n\n        for i, output in enumerate(hs):\n            outputs_class, outputs_mask = self.forward_prediction_heads(output.transpose(0, 1), mask_features, (self.training or (i == len(hs)-1)) and do_seg)\n            predictions_class.append(outputs_class)\n            predictions_mask.append(outputs_mask)\n\n        # iteratively box prediction\n        if self.initial_pred:\n            out_boxes = self.pred_box(references, hs, refpoint_embed.sigmoid())\n            assert len(predictions_class) == self.num_layers + 1\n        else:\n            out_boxes = self.pred_box(references, hs)\n        if mask_dict is not None:\n            predictions_mask = None if not do_seg else torch.stack(predictions_mask)\n            predictions_class =torch.stack(predictions_class)\n            predictions_class, out_boxes,predictions_mask=\\\n                self.dn_post_process(predictions_class, out_boxes, mask_dict, predictions_mask)\n            predictions_class = list(predictions_class)\n\n            if predictions_mask is None:\n                predictions_class[-1] = predictions_class[-1] + 0.0 * self.lang_mapper[0, 0]\n                for i in range(self.mask_embed.num_layers):\n                    predictions_class[-1] = predictions_class[-1] + 0.0 * (self.mask_embed.layers[i].weight[0][0] + self.mask_embed.layers[i].bias[0])  # avoid no mask loss\n                predictions_class[-1] = predictions_class[-1] + 0.0 * mask_features[0][0][0][0]  # avoid no mask loss\n\n            if do_seg:\n                predictions_mask = list(predictions_mask)\n        elif self.training:  # this 
is to insure self.label_enc participate in the model\n            predictions_class[-1] = predictions_class[-1] + 0.0 * self.lang_mapper[0, 0]\n            for i in range(self.mask_embed.num_layers):\n                predictions_class[-1] = predictions_class[-1] + 0.0 * (\n                            self.mask_embed.layers[i].weight[0][0] + self.mask_embed.layers[i].bias[\n                        0])  # avoid no mask loss\n            predictions_class[-1] = predictions_class[-1] + 0.0 * mask_features[0][0][0][0]  # avoid no mask loss\n\n        out = {\n            'pred_logits': predictions_class[-1],\n            'pred_masks': None if not do_seg else predictions_mask[-1],\n            'pred_boxes':out_boxes[-1],\n            'aux_outputs': self._set_aux_loss(\n                predictions_class if self.mask_classification else None, predictions_mask,out_boxes\n            )\n        }\n        if self.two_stage:\n            out['interm_outputs'] = interm_outputs\n\n        return out, mask_dict\n\n    def forward_prediction_heads(self, output, mask_features, pred_mask=True):\n        decoder_output = self.decoder_norm(output)\n        decoder_output = decoder_output.transpose(0, 1)\n\n        class_embed = decoder_output @ self.class_embed\n        outputs_class = self.compute_similarity(class_embed,self.default_text_embeddings)\n\n        outputs_mask = None\n        if pred_mask:\n            mask_embed = self.mask_embed(decoder_output)\n            outputs_mask = torch.einsum(\"bqc,bchw->bqhw\", mask_embed, mask_features)\n\n        return outputs_class, outputs_mask\n\n    @torch.jit.unused\n    def _set_aux_loss(self, outputs_class, outputs_seg_masks, out_boxes=None):\n        # this is a workaround to make torchscript happy, as torchscript\n        # doesn't support dictionary with non-homogeneous values, such\n        # as a dict having both a Tensor and a list.\n        # if self.mask_classification:\n        if out_boxes is None:\n            return [\n 
               {\"pred_logits\": a, \"pred_masks\": b}\n                for a, b in zip(outputs_class[:-1], outputs_seg_masks[:-1])\n            ]\n        elif outputs_seg_masks is None:\n            return [\n                {\"pred_logits\": a, \"pred_boxes\": c}\n                for a, c in zip(outputs_class[:-1], out_boxes[:-1])\n            ]\n        else:\n            return [\n                {\"pred_logits\": a, \"pred_masks\": b, \"pred_boxes\":c}\n                for a, b, c in zip(outputs_class[:-1], outputs_seg_masks[:-1], out_boxes[:-1])\n            ]\n\n@register_decoder\ndef get_maskdino_transformer_decoder(cfg, in_channels, mask_classification, extra):\n    return OpenSeeDDecoder(cfg, in_channels, mask_classification, extra)\n"
  },
  {
    "path": "llava/model/openseed/body/decoder/openseed_decoder_decouple.py",
    "content": "# ------------------------------------------------------------------------\n# DINO\n# Copyright (c) 2022 IDEA. All Rights Reserved.\n# Licensed under the Apache License, Version 2.0 [see LICENSE for details]\n# ------------------------------------------------------------------------\n# Modified from Mask2Former https://github.com/facebookresearch/Mask2Former by Feng Li and Hao Zhang.\nimport logging\nimport fvcore.nn.weight_init as weight_init\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nfrom detectron2.layers import Conv2d\nfrom detectron2.utils.registry import Registry\nfrom detectron2.structures import BitMasks\nfrom timm.models.layers import trunc_normal_\n\nfrom .registry import register_decoder\nfrom .utils.dino_decoder import TransformerDecoder, DeformableTransformerDecoderLayer\nfrom .utils import MLP, gen_encoder_output_proposals, inverse_sigmoid\nfrom ...utils import box_ops\nfrom ...utils import configurable\n\n\nclass MaskDINODecoder(nn.Module):\n    @configurable\n    def __init__(\n            self,\n            lang_encoder: nn.Module,\n            in_channels,\n            mask_classification=True,\n            *,\n            num_classes: int,\n            hidden_dim: int,\n            dim_proj: int,\n            num_queries: int,\n            nheads: int,\n            dim_feedforward: int,\n            dec_layers: int,\n            mask_dim: int,\n            enforce_input_project: bool,\n            two_stage: bool,\n            dn: str,\n            noise_scale:float,\n            dn_num:int,\n            initialize_box_type:bool,\n            initial_pred:bool,\n            learn_tgt: bool,\n            total_num_feature_levels: int = 4,\n            dropout: float = 0.0,\n            activation: str = 'relu',\n            nhead: int = 8,\n            dec_n_points: int = 4,\n            return_intermediate_dec: bool = True,\n            query_dim: int = 4,\n            dec_layer_share: bool = 
False,\n            semantic_ce_loss: bool = False,\n            no_update=False,\n            num_queries_stuff=100,\n            num_queries_test=300,\n\n    ):\n        \"\"\"\n        NOTE: this interface is experimental.\n        Args:\n            in_channels: channels of the input features\n            mask_classification: whether to add mask classifier or not\n            num_classes: number of classes\n            hidden_dim: Transformer feature dimension\n            num_queries: number of queries\n            nheads: number of heads\n            dim_feedforward: feature dimension in feedforward network\n            enc_layers: number of Transformer encoder layers\n            dec_layers: number of Transformer decoder layers\n            pre_norm: whether to use pre-LayerNorm or not\n            mask_dim: mask feature dimension\n            enforce_input_project: add input project 1x1 conv even if input\n                channels and hidden dim is identical\n            d_model: transformer dimension\n            dropout: dropout rate\n            activation: activation function\n            nhead: num heads in multi-head attention\n            dec_n_points: number of sampling points in decoder\n            return_intermediate_dec: return the intermediate results of decoder\n            query_dim: 4 -> (x, y, w, h)\n            dec_layer_share: whether to share each decoder layer\n            semantic_ce_loss: use ce loss for semantic segmentation\n        \"\"\"\n        super().__init__()\n\n        assert mask_classification, \"Only support mask classification model\"\n        self.mask_classification = mask_classification\n        self.num_feature_levels = total_num_feature_levels\n        self.initial_pred = initial_pred\n\n        # define Transformer decoder here\n        self.dn=dn\n        self.learn_tgt = learn_tgt\n        self.noise_scale=noise_scale\n        self.dn_num=dn_num\n        self.num_heads = nheads\n        self.num_layers = 
dec_layers\n        self.two_stage=two_stage\n        self.initialize_box_type = initialize_box_type\n        self.total_num_feature_levels = total_num_feature_levels\n\n        self.num_queries = num_queries\n        self.num_queries_test=num_queries_test\n        self.semantic_ce_loss = semantic_ce_loss\n        self.no_update=no_update\n        # learnable query features\n        # if not two_stage or self.learn_tgt:\n        self.num_queries_stuff=num_queries_stuff\n        self.query_feat = nn.Embedding(num_queries_stuff, hidden_dim)\n        # if not two_stage and initialize_box_type == 'no':\n        self.query_embed = nn.Embedding(num_queries_stuff, 4)\n        if two_stage:\n            self.enc_output = nn.Linear(hidden_dim, hidden_dim)\n            self.enc_output_norm = nn.LayerNorm(hidden_dim)\n\n        self.input_proj = nn.ModuleList()\n        for _ in range(self.num_feature_levels):\n            if in_channels != hidden_dim or enforce_input_project:\n                self.input_proj.append(Conv2d(in_channels, hidden_dim, kernel_size=1))\n                weight_init.c2_xavier_fill(self.input_proj[-1])\n            else:\n                self.input_proj.append(nn.Sequential())\n        self.num_classes=num_classes\n        # output FFNs\n        assert self.mask_classification, \"why not class embedding?\"\n        # self.label_enc=nn.Embedding(505, hidden_dim)  # this is a hack for o365+coco (365+133=498)\n        self.dim_proj = dim_proj\n        self.lang_encoder = lang_encoder\n        self.lang_mapper = nn.Parameter(torch.empty(dim_proj, hidden_dim))\n        trunc_normal_(self.lang_mapper, std=.02)\n        self.class_embed = nn.Parameter(torch.empty(hidden_dim, dim_proj))\n        trunc_normal_(self.class_embed, std=.02)\n\n        self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)\n\n        # init decoder\n        self.decoder_norm = decoder_norm = nn.LayerNorm(hidden_dim)\n        decoder_layer = 
DeformableTransformerDecoderLayer(hidden_dim, dim_feedforward,\n                                                          dropout, activation,\n                                                          self.num_feature_levels, nhead, dec_n_points)\n        self.decoder = TransformerDecoder(decoder_layer, self.num_layers, decoder_norm,\n                                          return_intermediate=return_intermediate_dec,\n                                          d_model=hidden_dim, query_dim=query_dim,\n                                          num_feature_levels=self.num_feature_levels,\n                                          dec_layer_share=dec_layer_share,\n                                          )\n\n        self.hidden_dim = hidden_dim\n        self._bbox_embed = _bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)\n        nn.init.constant_(_bbox_embed.layers[-1].weight.data, 0)\n        nn.init.constant_(_bbox_embed.layers[-1].bias.data, 0)\n        box_embed_layerlist = [_bbox_embed for i in range(self.num_layers)]  # share box prediction each layer\n        self.bbox_embed = nn.ModuleList(box_embed_layerlist)\n        self.decoder.bbox_embed = self.bbox_embed\n\n\n    @classmethod\n    def from_config(cls, cfg, in_channels, lang_encoder, mask_classification, extra):\n        ret = {}\n        ret[\"in_channels\"] = in_channels\n        ret[\"lang_encoder\"] = lang_encoder\n        ret[\"mask_classification\"] = mask_classification\n\n        enc_cfg = cfg['MODEL']['ENCODER']\n        dec_cfg = cfg['MODEL']['DECODER']\n\n        ret[\"num_classes\"] = enc_cfg['NUM_CLASSES']\n        ret[\"hidden_dim\"] = dec_cfg['HIDDEN_DIM']\n        ret[\"dim_proj\"] = cfg['MODEL']['DIM_PROJ']\n        ret[\"num_queries\"] = dec_cfg['NUM_OBJECT_QUERIES']\n        ret[\"num_queries_test\"] = dec_cfg.get('NUM_OBJECT_QUERIES_TEST',300)\n\n        # Transformer parameters:\n        ret[\"nheads\"] = dec_cfg['NHEADS']\n        ret[\"dim_feedforward\"] = 
dec_cfg['DIM_FEEDFORWARD']\n        ret[\"dec_layers\"] = dec_cfg['DEC_LAYERS']\n        ret[\"enforce_input_project\"] = dec_cfg['ENFORCE_INPUT_PROJ']\n        ret[\"mask_dim\"] = enc_cfg['MASK_DIM']\n        ret[\"two_stage\"] = dec_cfg['TWO_STAGE']\n        ret[\"initialize_box_type\"] = dec_cfg['INITIALIZE_BOX_TYPE']  # ['no', 'bitmask', 'mask2box']\n        ret[\"dn\"] = dec_cfg['DN']\n        ret[\"noise_scale\"] = dec_cfg['DN_NOISE_SCALE']\n        ret[\"dn_num\"] = dec_cfg['DN_NUM']\n        ret[\"initial_pred\"] = dec_cfg['INITIAL_PRED']\n        ret[\"learn_tgt\"] = dec_cfg['LEARN_TGT']\n        ret[\"total_num_feature_levels\"] = dec_cfg['TOTAL_NUM_FEATURE_LEVELS']\n        ret[\"semantic_ce_loss\"] = dec_cfg['TEST']['SEMANTIC_ON'] and dec_cfg['SEMANTIC_CE_LOSS'] and not dec_cfg['TEST']['PANOPTIC_ON']\n        ret[\"no_update\"]=dec_cfg.get(\"no_update\",False)\n\n        return ret\n\n    def prepare_for_dn(self, targets, tgt, refpoint_emb, batch_size,task=\"other\"):\n        \"\"\"\n        modified from dn-detr. 
You can refer to dn-detr\n        https://github.com/IDEA-Research/DN-DETR/blob/main/models/dn_dab_deformable_detr/dn_components.py\n        for more details\n            :param dn_args: scalar, noise_scale\n            :param tgt: original tgt (content) in the matching part\n            :param refpoint_emb: positional anchor queries in the matching part\n            :param batch_size: bs\n            \"\"\"\n        if self.training:\n            scalar, noise_scale = self.dn_num, self.noise_scale\n\n            known = [(torch.ones_like(t['labels'])).cuda() for t in targets]\n            know_idx = [torch.nonzero(t) for t in known]\n            known_num = [sum(k) for k in known]\n\n            # use fix number of dn queries\n            if max(known_num) > 0:\n                scalar = scalar // (int(max(known_num)))\n            else:\n                scalar = 0\n            if task==\"cls\":\n                scalar=1\n            if scalar == 0:\n                input_query_label = None\n                input_query_bbox = None\n                attn_mask = None\n                mask_dict = None\n                return input_query_label, input_query_bbox, attn_mask, mask_dict\n\n            # can be modified to selectively denosie some label or boxes; also known label prediction\n            unmask_bbox = unmask_label = torch.cat(known)\n            labels = torch.cat([t['labels'] for t in targets])\n            # use languge as denosing content queries.\n            # if task == 'det':\n            #     labels = labels  # o365 start from 133 class\n            boxes = torch.cat([t['boxes'] for t in targets])\n            batch_idx = torch.cat([torch.full_like(t['labels'].long(), i) for i, t in enumerate(targets)])\n            # known\n            known_indice = torch.nonzero(unmask_label + unmask_bbox)\n            known_indice = known_indice.view(-1)\n\n            # noise\n            known_indice = known_indice.repeat(scalar, 1).view(-1)\n            
known_labels = labels.repeat(scalar, 1).view(-1)\n            known_bid = batch_idx.repeat(scalar, 1).view(-1)\n            known_bboxs = boxes.repeat(scalar, 1)\n            known_labels_expaned = known_labels.clone()\n            known_bbox_expand = known_bboxs.clone()\n\n            if noise_scale > 0:\n                diff = torch.zeros_like(known_bbox_expand)\n                diff[:, :2] = known_bbox_expand[:, 2:] / 2\n                diff[:, 2:] = known_bbox_expand[:, 2:]\n                known_bbox_expand += torch.mul((torch.rand_like(known_bbox_expand) * 2 - 1.0),\n                                               diff).cuda() * noise_scale\n                known_bbox_expand = known_bbox_expand.clamp(min=0.0, max=1.0)\n            if task==\"cls\":\n                known_labels_expaned=torch.zeros_like(known_labels_expaned)\n            m = known_labels_expaned.long().to('cuda')\n            # import ipdb; ipdb.set_trace()\n            if task==\"cls\":\n                input_label_embed=self.cls_emb(m)\n            else:\n                input_label_embed = torch.gather(self.lang_encoder.default_text_embeddings, 0,\n                                             m[:, None].repeat(1, self.dim_proj)) @ self.lang_mapper\n\n            input_bbox_embed = inverse_sigmoid(known_bbox_expand)\n            single_pad = int(max(known_num))\n            pad_size = int(single_pad * scalar)\n\n            padding_label = input_label_embed.new_zeros(pad_size, self.hidden_dim)\n            padding_bbox = input_bbox_embed.new_zeros(pad_size, 4)\n\n            if not refpoint_emb is None:\n                input_query_label = torch.cat([padding_label, tgt], dim=0).repeat(batch_size, 1, 1)\n                input_query_bbox = torch.cat([padding_bbox, refpoint_emb], dim=0).repeat(batch_size, 1, 1)\n            else:\n                input_query_label = padding_label.repeat(batch_size, 1, 1)\n                input_query_bbox = padding_bbox.repeat(batch_size, 1, 1)\n\n            # 
map\n            map_known_indice = input_label_embed.new_tensor([])\n            if len(known_num):\n                map_known_indice = torch.cat(\n                    [input_label_embed.new_tensor(range(num)) for num in known_num])  # [1,2, 1,2,3]\n                map_known_indice = torch.cat([map_known_indice + single_pad * i for i in range(scalar)]).long()\n            if len(known_bid):\n                input_query_label[(known_bid.long(), map_known_indice)] = input_label_embed\n                input_query_bbox[(known_bid.long(), map_known_indice)] = input_bbox_embed\n\n            tgt_size = pad_size + self.num_queries+self.num_queries_stuff\n            attn_mask = input_label_embed.new_ones(tgt_size, tgt_size) < 0\n            # match query cannot see the reconstruct\n            attn_mask[pad_size:, :pad_size] = True\n            # reconstruct cannot see each other\n            for i in range(scalar):\n                if i == 0:\n                    attn_mask[single_pad * i:single_pad * (i + 1), single_pad * (i + 1):pad_size] = True\n                if i == scalar - 1:\n                    attn_mask[single_pad * i:single_pad * (i + 1), :single_pad * i] = True\n                else:\n                    attn_mask[single_pad * i:single_pad * (i + 1), single_pad * (i + 1):pad_size] = True\n                    attn_mask[single_pad * i:single_pad * (i + 1), :single_pad * i] = True\n            mask_dict = {\n                'known_indice': torch.as_tensor(known_indice).long(),\n                'batch_idx': torch.as_tensor(batch_idx).long(),\n                'map_known_indice': torch.as_tensor(map_known_indice).long(),\n                'known_lbs_bboxes': (known_labels, known_bboxs),\n                'know_idx': know_idx,\n                'pad_size': pad_size,\n                'scalar': scalar,\n            }\n        else:\n            if not refpoint_emb is None:\n                input_query_label = tgt.repeat(batch_size, 1, 1)\n                
input_query_bbox = refpoint_emb.repeat(batch_size, 1, 1)\n            else:\n                input_query_label = None\n                input_query_bbox = None\n            attn_mask = None\n            mask_dict = None\n\n        # 100*batch*256\n        if not input_query_bbox is None:\n            input_query_label = input_query_label\n            input_query_bbox = input_query_bbox\n\n        return input_query_label, input_query_bbox, attn_mask, mask_dict\n\n    def dn_post_process(self,outputs_class,outputs_coord,mask_dict,outputs_mask):\n        \"\"\"\n            post process of dn after output from the transformer\n            put the dn part in the mask_dict\n            \"\"\"\n        assert mask_dict['pad_size'] > 0\n        output_known_class = outputs_class[:, :, :mask_dict['pad_size'], :]\n        outputs_class = outputs_class[:, :, mask_dict['pad_size']:, :]\n        output_known_coord = outputs_coord[:, :, :mask_dict['pad_size'], :]\n        outputs_coord = outputs_coord[:, :, mask_dict['pad_size']:, :]\n        output_known_mask = None\n        if outputs_mask is not None:\n            output_known_mask = outputs_mask[:, :, :mask_dict['pad_size'], :]\n            outputs_mask = outputs_mask[:, :, mask_dict['pad_size']:, :]\n        out = {'pred_logits': output_known_class[-1], 'pred_boxes': output_known_coord[-1],'pred_masks': None if output_known_mask is None else output_known_mask[-1]}\n\n        out['aux_outputs'] = self._set_aux_loss(output_known_class, output_known_mask,output_known_coord)\n        mask_dict['output_known_lbs_bboxes']=out\n        return outputs_class, outputs_coord, outputs_mask\n\n    def get_valid_ratio(self, mask):\n        _, H, W = mask.shape\n        valid_H = torch.sum(~mask[:, :, 0], 1)\n        valid_W = torch.sum(~mask[:, 0, :], 1)\n        valid_ratio_h = valid_H.float() / H\n        valid_ratio_w = valid_W.float() / W\n        valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)\n        return 
valid_ratio\n\n    def pred_box(self, reference, hs, ref0=None):\n        \"\"\"\n        :param reference: reference box coordinates from each decoder layer\n        :param hs: content\n        :param ref0: whether there are prediction from the first layer\n        \"\"\"\n        if ref0 is None:\n            outputs_coord_list = []\n        else:\n            outputs_coord_list = [ref0]\n        for dec_lid, (layer_ref_sig, layer_bbox_embed, layer_hs) in enumerate(zip(reference[:-1], self.bbox_embed, hs)):\n            layer_delta_unsig = layer_bbox_embed(layer_hs)\n            layer_outputs_unsig = layer_delta_unsig + inverse_sigmoid(layer_ref_sig)\n            layer_outputs_unsig = layer_outputs_unsig.sigmoid()\n            outputs_coord_list.append(layer_outputs_unsig)\n        outputs_coord_list = torch.stack(outputs_coord_list)\n        return outputs_coord_list\n\n    def forward_cls(self, x, mask_features, masks, targets=None, target_queries=None, target_vlp=None,\n                extra={}):\n        \"\"\"\n        task: seg/det\n        \"\"\"\n        assert len(x) == self.num_feature_levels\n        do_seg = False# if task is det, not do segmentation training\n        size_list = []\n        # disable mask, it does not affect performance\n        enable_mask = 0\n        if masks is not None:\n            for src in x:\n                if src.size(2) % 32 or src.size(3) % 32:\n                    enable_mask = 1\n        if enable_mask == 0:\n            masks = [torch.zeros((src.size(0), src.size(2), src.size(3)), device=src.device, dtype=torch.bool) for src\n                     in x]\n        src_flatten = []\n        mask_flatten = []\n        spatial_shapes = []\n        for i in range(self.num_feature_levels):\n            idx = self.num_feature_levels - 1 - i\n            bs, c, h, w = x[idx].shape\n            size_list.append(x[i].shape[-2:])\n            spatial_shapes.append(x[idx].shape[-2:])\n            
src_flatten.append(self.input_proj[idx](x[idx]).flatten(2).transpose(1, 2))\n            mask_flatten.append(masks[i].flatten(1))\n        src_flatten = torch.cat(src_flatten, 1)  # bs, \\sum{hxw}, c\n        mask_flatten = torch.cat(mask_flatten, 1)  # bs, \\sum{hxw}\n        spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device)\n        level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))\n        valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)\n\n        predictions_class = []\n        predictions_mask = []\n        # if self.two_stage:\n        #     output_memory, output_proposals = gen_encoder_output_proposals(src_flatten, mask_flatten, spatial_shapes)\n        #     output_memory = self.enc_output_norm(self.enc_output(output_memory))\n        #     output_memory_ = output_memory @ self.class_embed\n        #     enc_outputs_class_unselected = self.lang_encoder.compute_similarity(output_memory_)\n        #     enc_outputs_class_unselected[output_proposals.sum(-1).isinf()] = float(\"-inf\")\n        #     # enc_outputs_class_unselected = self.class_embed(output_memory)\n        #     enc_outputs_coord_unselected = self._bbox_embed(\n        #         output_memory) + output_proposals  # (bs, \\sum{hw}, 4) unsigmoid\n        #     topk = self.num_queries\n        #     topk_proposals = torch.topk(enc_outputs_class_unselected.max(-1)[0], topk, dim=1)[1]\n        #     refpoint_embed_undetach = torch.gather(enc_outputs_coord_unselected, 1,\n        #                                            topk_proposals.unsqueeze(-1).repeat(1, 1, 4))  # unsigmoid\n        #     refpoint_embed = refpoint_embed_undetach.detach()\n        #\n        #     tgt_undetach = torch.gather(output_memory, 1,\n        #                                 topk_proposals.unsqueeze(-1).repeat(1, 1, self.hidden_dim))  # unsigmoid\n        #     outputs_class, outputs_mask = 
self.forward_prediction_heads(tgt_undetach.transpose(0, 1), mask_features,\n        #                                                                 do_seg)\n        #     tgt = tgt_undetach.detach()\n        #     if self.learn_tgt:\n        #         tgt = self.query_feat.weight[None].repeat(bs, 1, 1)\n        #     interm_outputs = dict()\n        #     interm_outputs['pred_logits'] = outputs_class\n        #     interm_outputs['pred_boxes'] = refpoint_embed_undetach.sigmoid()\n        #     interm_outputs['pred_masks'] = outputs_mask\n        #\n        #     if self.initialize_box_type != 'no' and do_seg:\n        #         # convert masks into boxes to better initialize box in the decoder\n        #         assert self.initial_pred\n        #         flaten_mask = outputs_mask.detach().flatten(0, 1)\n        #         h, w = outputs_mask.shape[-2:]\n        #         if self.initialize_box_type == 'bitmask':  # slower, but more accurate\n        #             refpoint_embed = BitMasks(flaten_mask > 0).get_bounding_boxes().tensor.cuda()\n        #         elif self.initialize_box_type == 'mask2box':  # faster conversion\n        #             refpoint_embed = box_ops.masks_to_boxes(flaten_mask > 0).cuda()\n        #         else:\n        #             assert NotImplementedError\n        #         refpoint_embed = box_ops.box_xyxy_to_cxcywh(refpoint_embed) / torch.as_tensor([w, h, w, h],\n        #                                                                                       dtype=torch.float).cuda()\n        #         refpoint_embed = refpoint_embed.reshape(outputs_mask.shape[0], outputs_mask.shape[1], 4)\n        #         refpoint_embed = inverse_sigmoid(refpoint_embed)\n        # elif not self.two_stage:\n        #     tgt = self.query_feat.weight[None].repeat(bs, 1, 1)\n        #     refpoint_embed = self.query_embed.weight[None].repeat(bs, 1, 1)\n\n        tgt_mask = None\n        mask_dict = None\n        # if self.dn != \"no\" and 
self.training:\n        assert targets is not None\n        input_query_label, input_query_bbox, tgt_mask, mask_dict = \\\n            self.prepare_for_dn(targets, None, None, x[0].shape[0],task=\"cls\")\n        # if mask_dict is not None:\n        tgt = input_query_label\n        refpoint_embed = input_query_bbox\n\n        # direct prediction from the matching and denoising part in the begining\n        if self.initial_pred:\n            outputs_class, outputs_mask = self.forward_prediction_heads(tgt.transpose(0, 1), mask_features,\n                                                                        self.training and do_seg)\n            predictions_class.append(outputs_class)\n            predictions_mask.append(outputs_mask)\n        # if self.dn != \"no\" and self.training and mask_dict is not None:\n        # tgt=tgt.float()\n        hs, references = self.decoder(\n            tgt=tgt.transpose(0, 1),\n            memory=src_flatten.transpose(0, 1),\n            memory_key_padding_mask=mask_flatten,\n            pos=None,\n            refpoints_unsigmoid=refpoint_embed.transpose(0, 1),\n            level_start_index=level_start_index,\n            spatial_shapes=spatial_shapes,\n            valid_ratios=valid_ratios,\n            tgt_mask=None,\n            no_update=self.no_update,\n        )\n\n        for i, output in enumerate(hs):\n            outputs_class, outputs_mask = self.forward_prediction_heads(output.transpose(0, 1), mask_features, (\n                        self.training or (i == len(hs) - 1)) and do_seg)\n            predictions_class.append(outputs_class)\n            predictions_mask.append(outputs_mask)\n\n        # iteratively box prediction\n        # if self.initial_pred:\n        #     out_boxes = self.pred_box(references, hs, refpoint_embed.sigmoid())\n        #     assert len(predictions_class) == self.num_layers + 1\n        # else:\n        #     out_boxes = self.pred_box(references, hs)\n        if mask_dict is not None:\n     
       predictions_mask = None if not do_seg else torch.stack(predictions_mask)\n            predictions_class = torch.stack(predictions_class)\n            # predictions_class, out_boxes, predictions_mask = \\\n            #     self.dn_post_process(predictions_class, out_boxes, mask_dict, predictions_mask)\n            predictions_class = list(predictions_class)\n\n            if predictions_mask is None:\n                predictions_class[-1] = predictions_class[-1] + 0.0 * self.lang_mapper[0, 0]\n                for i in range(self.mask_embed.num_layers):\n                    predictions_class[-1] = predictions_class[-1] + 0.0 * (\n                                self.mask_embed.layers[i].weight[0][0] + self.mask_embed.layers[i].bias[\n                            0])  # avoid no mask loss\n                predictions_class[-1] = predictions_class[-1] + 0.0 * mask_features[0][0][0][0]  # avoid no mask loss\n\n            # if do_seg:\n            #     predictions_mask = list(predictions_mask)\n        elif self.training:  # this is to insure self.label_enc participate in the model\n            predictions_class[-1] = predictions_class[-1] + 0.0 * self.lang_mapper[0, 0]\n            for i in range(self.mask_embed.num_layers):\n                predictions_class[-1] = predictions_class[-1] + 0.0 * (\n                        self.mask_embed.layers[i].weight[0][0] + self.mask_embed.layers[i].bias[\n                    0])  # avoid no mask loss\n            predictions_class[-1] = predictions_class[-1] + 0.0 * mask_features[0][0][0][0]  # avoid no mask loss\n\n        out = {\n            'pred_logits': predictions_class[-1],\n            'pred_masks': None if not do_seg else predictions_mask[-1],\n            # 'pred_boxes': out_boxes[-1],\n            'aux_outputs': self._set_aux_loss(\n                predictions_class if self.mask_classification else None, predictions_mask\n            )\n        }\n        # if self.two_stage:\n        #     
out['interm_outputs'] = interm_outputs\n\n        return out, None\n\n    def forward(self, x, mask_features, masks, targets=None, target_queries=None, target_vlp=None, task='seg', extra={}):\n        \"\"\"\n        task: seg/det\n        \"\"\"\n        assert len(x) == self.num_feature_levels\n        do_seg = (task != 'det')   # if task is det, not do segmentation training\n        size_list = []\n        # disable mask, it does not affect performance\n        enable_mask = 0\n        if masks is not None:\n            for src in x:\n                if src.size(2) % 32 or src.size(3) % 32:\n                    enable_mask = 1\n        if enable_mask == 0:\n            masks = [torch.zeros((src.size(0), src.size(2), src.size(3)), device=src.device, dtype=torch.bool) for src in x]\n        src_flatten = []\n        mask_flatten = []\n        spatial_shapes = []\n        for i in range(self.num_feature_levels):\n            idx=self.num_feature_levels-1-i\n            bs, c , h, w=x[idx].shape\n            size_list.append(x[i].shape[-2:])\n            spatial_shapes.append(x[idx].shape[-2:])\n            src_flatten.append(self.input_proj[idx](x[idx]).flatten(2).transpose(1, 2))\n            mask_flatten.append(masks[i].flatten(1))\n        src_flatten = torch.cat(src_flatten, 1)  # bs, \\sum{hxw}, c\n        mask_flatten = torch.cat(mask_flatten, 1)  # bs, \\sum{hxw}\n        spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device)\n        level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))\n        valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)\n\n        predictions_class = []\n        predictions_mask = []\n        # if self.two_stage:\n        output_memory, output_proposals = gen_encoder_output_proposals(src_flatten, mask_flatten, spatial_shapes)\n        output_memory = self.enc_output_norm(self.enc_output(output_memory))\n        
output_memory_ = output_memory @ self.class_embed\n        enc_outputs_class_unselected = self.lang_encoder.compute_similarity(output_memory_)\n        enc_outputs_class_unselected[output_proposals.sum(-1).isinf()] = float(\"-inf\")\n        # enc_outputs_class_unselected = self.class_embed(output_memory)\n        enc_outputs_coord_unselected = self._bbox_embed(\n            output_memory) + output_proposals  # (bs, \\sum{hw}, 4) unsigmoid\n        topk = self.num_queries if self.training else self.num_queries_test\n        topk_proposals = torch.topk(enc_outputs_class_unselected.max(-1)[0], topk, dim=1)[1]\n        refpoint_embed_undetach = torch.gather(enc_outputs_coord_unselected, 1,\n                                               topk_proposals.unsqueeze(-1).repeat(1, 1, 4))  # unsigmoid\n\n        tgt_undetach = torch.gather(output_memory, 1,\n                              topk_proposals.unsqueeze(-1).repeat(1, 1, self.hidden_dim))  # unsigmoid\n        tgt_stuff = self.query_feat.weight[None].repeat(bs, 1, 1)\n        refpoint_embed_stuff = self.query_embed.weight[None].repeat(bs, 1, 1)\n        # if not (self.)\n        tgt_undetach=torch.cat([tgt_undetach,tgt_stuff],dim=1)\n        refpoint_embed_undetach=torch.cat([refpoint_embed_undetach,refpoint_embed_stuff],dim=1)\n        refpoint_embed = refpoint_embed_undetach.detach()\n        outputs_class, outputs_mask = self.forward_prediction_heads(tgt_undetach.transpose(0, 1), mask_features, do_seg)\n        tgt = tgt_undetach.detach()\n        if self.learn_tgt:\n            tgt = self.query_feat.weight[None].repeat(bs, 1, 1)\n        interm_outputs=dict()\n        interm_outputs['pred_logits'] = outputs_class\n        interm_outputs['pred_boxes'] = refpoint_embed_undetach.sigmoid()\n        interm_outputs['pred_masks'] = outputs_mask\n\n        if self.initialize_box_type != 'no' and do_seg:\n            # convert masks into boxes to better initialize box in the decoder\n            assert self.initial_pred\n 
           flaten_mask = outputs_mask.detach().flatten(0, 1)\n            h, w = outputs_mask.shape[-2:]\n            if self.initialize_box_type == 'bitmask':  # slower, but more accurate\n                refpoint_embed = BitMasks(flaten_mask > 0).get_bounding_boxes().tensor.cuda()\n            elif self.initialize_box_type == 'mask2box':  # faster conversion\n                refpoint_embed = box_ops.masks_to_boxes(flaten_mask > 0).cuda()\n            else:\n                assert NotImplementedError\n            refpoint_embed = box_ops.box_xyxy_to_cxcywh(refpoint_embed) / torch.as_tensor([w, h, w, h],\n                                                                                          dtype=torch.float).cuda()\n            refpoint_embed = refpoint_embed.reshape(outputs_mask.shape[0], outputs_mask.shape[1], 4)\n            refpoint_embed = inverse_sigmoid(refpoint_embed)\n        # elif not self.two_stage:\n\n        tgt_mask = None\n        mask_dict = None\n        if self.dn != \"no\" and self.training:\n            assert targets is not None\n            input_query_label, input_query_bbox, tgt_mask, mask_dict = \\\n                self.prepare_for_dn(targets, None, None, x[0].shape[0])\n            if mask_dict is not None:\n                tgt=torch.cat([input_query_label, tgt],dim=1)\n\n        # direct prediction from the matching and denoising part in the begining\n        if self.initial_pred:\n            outputs_class, outputs_mask = self.forward_prediction_heads(tgt.transpose(0, 1), mask_features, self.training and do_seg)\n            if not (task == 'seg' or not self.training):\n                outputs_class=outputs_class[:,:-self.num_queries_stuff]\n                # outputs_mask=outputs_mask[:,:-self.num_queries_stuff]\n            predictions_class.append(outputs_class)\n            predictions_mask.append(outputs_mask)\n        if self.dn != \"no\" and self.training and mask_dict is not None:\n            
refpoint_embed=torch.cat([input_query_bbox,refpoint_embed],dim=1)\n\n        hs, references = self.decoder(\n            tgt=tgt.transpose(0, 1),\n            memory=src_flatten.transpose(0, 1),\n            memory_key_padding_mask=mask_flatten,\n            pos=None,\n            refpoints_unsigmoid=refpoint_embed.transpose(0, 1),\n            level_start_index=level_start_index,\n            spatial_shapes=spatial_shapes,\n            valid_ratios=valid_ratios,\n            tgt_mask=tgt_mask\n        )\n        if not (task=='seg' or not self.training):\n            hs=[hs_[:,:-self.num_queries_stuff] for hs_ in hs]\n            references=[references_[:,:-self.num_queries_stuff] for references_ in references]\n            refpoint_embed=refpoint_embed[:,:-self.num_queries_stuff]\n\n        for i, output in enumerate(hs):\n            outputs_class, outputs_mask = self.forward_prediction_heads(output.transpose(0, 1), mask_features, (self.training or (i == len(hs)-1)) and do_seg)\n            predictions_class.append(outputs_class)\n            predictions_mask.append(outputs_mask)\n\n        # iteratively box prediction\n        if self.initial_pred:\n            out_boxes = self.pred_box(references, hs, refpoint_embed.sigmoid())\n            assert len(predictions_class) == self.num_layers + 1\n        else:\n            out_boxes = self.pred_box(references, hs)\n        if mask_dict is not None:\n            predictions_mask = None if not do_seg else torch.stack(predictions_mask)\n            predictions_class =torch.stack(predictions_class)\n            predictions_class, out_boxes,predictions_mask=\\\n                self.dn_post_process(predictions_class, out_boxes, mask_dict, predictions_mask)\n            predictions_class = list(predictions_class)\n\n            if predictions_mask is None:\n                predictions_class[-1] = predictions_class[-1] + 0.0 * self.lang_mapper[0, 0]\n                for i in range(self.mask_embed.num_layers):\n            
        predictions_class[-1] = predictions_class[-1] + 0.0 * (self.mask_embed.layers[i].weight[0][0] + self.mask_embed.layers[i].bias[0])  # avoid no mask loss\n                predictions_class[-1] = predictions_class[-1] + 0.0 * mask_features[0][0][0][0]  # avoid no mask loss\n\n            if do_seg:\n                predictions_mask = list(predictions_mask)\n        elif self.training:  # this is to insure self.label_enc participate in the model\n            predictions_class[-1] = predictions_class[-1] + 0.0 * self.lang_mapper[0, 0]\n            for i in range(self.mask_embed.num_layers):\n                predictions_class[-1] = predictions_class[-1] + 0.0 * (\n                            self.mask_embed.layers[i].weight[0][0] + self.mask_embed.layers[i].bias[\n                        0])  # avoid no mask loss\n            predictions_class[-1] = predictions_class[-1] + 0.0 * mask_features[0][0][0][0]  # avoid no mask loss\n\n        out = {\n            'pred_logits': predictions_class[-1],\n            'pred_masks': None if not do_seg else predictions_mask[-1],\n            'pred_boxes':out_boxes[-1],\n            'aux_outputs': self._set_aux_loss(\n                predictions_class if self.mask_classification else None, predictions_mask,out_boxes\n            )\n        }\n        if self.two_stage:\n            out['interm_outputs'] = interm_outputs\n\n        return out, mask_dict\n\n    def forward_prediction_heads(self, output, mask_features, pred_mask=True):\n        decoder_output = self.decoder_norm(output)\n        decoder_output = decoder_output.transpose(0, 1)\n\n        class_embed = decoder_output @ self.class_embed\n        outputs_class = self.lang_encoder.compute_similarity(class_embed)\n\n        outputs_mask = None\n        if pred_mask:\n            mask_embed = self.mask_embed(decoder_output)\n            outputs_mask = torch.einsum(\"bqc,bchw->bqhw\", mask_embed, mask_features)\n\n        return outputs_class, outputs_mask\n\n    
@torch.jit.unused\n    def _set_aux_loss(self, outputs_class, outputs_seg_masks, out_boxes=None):\n        # this is a workaround to make torchscript happy, as torchscript\n        # doesn't support dictionary with non-homogeneous values, such\n        # as a dict having both a Tensor and a list.\n        # if self.mask_classification:\n        if out_boxes is None:\n            if outputs_seg_masks is None:\n                return [\n                    {\"pred_logits\": a}\n                    for a in outputs_class[:-1]\n                ]\n            else:\n                return [\n                    {\"pred_logits\": a, \"pred_masks\": b}\n                    for a, b in zip(outputs_class[:-1], outputs_seg_masks[:-1])\n                ]\n        elif outputs_seg_masks is None:\n            return [\n                {\"pred_logits\": a, \"pred_boxes\": c}\n                for a, c in zip(outputs_class[:-1], out_boxes[:-1])\n            ]\n        else:\n            return [\n                {\"pred_logits\": a, \"pred_masks\": b, \"pred_boxes\":c}\n                for a, b, c in zip(outputs_class[:-1], outputs_seg_masks[:-1], out_boxes[:-1])\n            ]\n\n@register_decoder\ndef get_maskdino_transformer_decoder(cfg, in_channels, lang_encoder, mask_classification, extra):\n    return MaskDINODecoder(cfg, in_channels, lang_encoder, mask_classification, extra)\n"
  },
  {
    "path": "llava/model/openseed/body/decoder/registry.py",
    "content": "_model_entrypoints = {}\n\ndef register_decoder(fn):\n    module_name_split = fn.__module__.split('.')\n    model_name = module_name_split[-1]\n    _model_entrypoints[model_name] = fn\n    return fn\n\ndef model_entrypoints(model_name):\n    return _model_entrypoints[model_name]\n\ndef is_model(model_name):\n    return model_name in _model_entrypoints"
  },
  {
    "path": "llava/model/openseed/body/decoder/utils/__init__.py",
    "content": "from .utils import *"
  },
  {
    "path": "llava/model/openseed/body/decoder/utils/dino_decoder.py",
    "content": "# ------------------------------------------------------------------------\n# DINO\n# Copyright (c) 2022 IDEA. All Rights Reserved.\n# Licensed under the Apache License, Version 2.0 [see LICENSE for details]\n# ------------------------------------------------------------------------\n# Modified from DINO https://github.com/IDEA-Research/DINO by Feng Li and Hao Zhang.\n# ------------------------------------------------------------------------\n\nfrom typing import Optional, List, Union\nimport torch\nfrom torch import nn, Tensor\nfrom torch.cuda.amp import autocast\n\nfrom .utils import MLP, _get_clones, _get_activation_fn, gen_sineembed_for_position, inverse_sigmoid\nfrom ...encoder.ops.modules import MSDeformAttn\n\n\nclass TransformerDecoder(nn.Module):\n\n    def __init__(self, decoder_layer, num_layers, norm=None,\n                 return_intermediate=False,\n                 d_model=256, query_dim=4,\n                 modulate_hw_attn=True,\n                 num_feature_levels=1,\n                 deformable_decoder=True,\n                 decoder_query_perturber=None,\n                 dec_layer_number=None,  # number of queries each layer in decoder\n                 rm_dec_query_scale=True,\n                 dec_layer_share=False,\n                 dec_layer_dropout_prob=None,\n                 task_switch=None,\n                 ):\n        super().__init__()\n        if num_layers > 0:\n            self.layers = _get_clones(decoder_layer, num_layers, layer_share=dec_layer_share)\n        else:\n            self.layers = []\n        self.num_layers = num_layers\n        self.norm = norm\n        self.return_intermediate = return_intermediate\n        assert return_intermediate, \"support return_intermediate only\"\n        self.query_dim = query_dim\n        assert query_dim in [2, 4], \"query_dim should be 2/4 but {}\".format(query_dim)\n        self.num_feature_levels = num_feature_levels\n\n        self.ref_point_head = MLP(query_dim // 
2 * d_model, d_model, d_model, 2)\n        if not deformable_decoder:\n            self.query_pos_sine_scale = MLP(d_model, d_model, d_model, 2)\n        else:\n            self.query_pos_sine_scale = None\n\n        if rm_dec_query_scale:\n            self.query_scale = None\n        else:\n            raise NotImplementedError\n            self.query_scale = MLP(d_model, d_model, d_model, 2)\n        self.bbox_embed = None\n        self.class_embed = None\n\n        self.d_model = d_model\n        self.modulate_hw_attn = modulate_hw_attn\n        self.deformable_decoder = deformable_decoder\n\n        if not deformable_decoder and modulate_hw_attn:\n            self.ref_anchor_head = MLP(d_model, d_model, 2, 2)\n        else:\n            self.ref_anchor_head = None\n\n        self.decoder_query_perturber = decoder_query_perturber\n        self.box_pred_damping = None\n\n        self.dec_layer_number = dec_layer_number\n        if dec_layer_number is not None:\n            assert isinstance(dec_layer_number, list)\n            assert len(dec_layer_number) == num_layers\n            # assert dec_layer_number[0] ==\n\n        self.dec_layer_dropout_prob = dec_layer_dropout_prob\n        if dec_layer_dropout_prob is not None:\n            assert isinstance(dec_layer_dropout_prob, list)\n            assert len(dec_layer_dropout_prob) == num_layers\n            for i in dec_layer_dropout_prob:\n                assert 0.0 <= i <= 1.0\n\n        self.task_switch = task_switch\n        self._reset_parameters()\n\n    def _reset_parameters(self):\n        for p in self.parameters():\n            if p.dim() > 1:\n                nn.init.xavier_uniform_(p)\n        for m in self.modules():\n            if isinstance(m, MSDeformAttn):\n                m._reset_parameters()\n\n    def forward(self, tgt, memory,\n                tgt_mask: Optional[Tensor] = None,\n                memory_mask: Optional[Tensor] = None,\n                tgt_key_padding_mask: Optional[Tensor] = 
None,\n                memory_key_padding_mask: Optional[Tensor] = None,\n                pos: Optional[Tensor] = None,\n                refpoints_unsigmoid: Optional[Tensor] = None,  # num_queries, bs, 2\n                # for memory\n                level_start_index: Optional[Tensor] = None,  # num_levels\n                spatial_shapes: Optional[Tensor] = None,  # bs, num_levels, 2\n                valid_ratios: Optional[Tensor] = None,\n                # misc\n                extra: Optional[Tensor] = {}, # extra information\n                ):\n        \"\"\"\n        Input:\n            - tgt: nq, bs, d_model\n            - memory: hw, bs, d_model\n            - pos: hw, bs, d_model\n            - refpoints_unsigmoid: nq, bs, 2/4\n            - valid_ratios/spatial_shapes: bs, nlevel, 2\n        \"\"\"\n        output = tgt\n\n        intermediate = []\n        reference_points = refpoints_unsigmoid.sigmoid()\n        ref_points = [reference_points]\n\n        if 'lang_refpoint_embed' in extra.keys() and 'grounding_tokens' in extra.keys():\n            reference_points = torch.cat((reference_points, extra['lang_refpoint_embed'].transpose(0,1).sigmoid()), dim=0)\n            output = torch.cat((output, extra['grounding_tokens']), dim=0)\n\n        for layer_id, layer in enumerate(self.layers):            \n            # preprocess ref points\n            if self.training and self.decoder_query_perturber is not None and layer_id != 0:\n                reference_points = self.decoder_query_perturber(reference_points)\n\n            reference_points_input = reference_points[:, :, None] \\\n                                         * torch.cat([valid_ratios, valid_ratios], -1)[None, :]  # nq, bs, nlevel, 4\n            # print('reference_points_input', reference_points_input.dtype)\n            # print('memory', memory.dtype)\n            reference_points_input=reference_points_input.to(memory.dtype)\n            query_sine_embed = 
gen_sineembed_for_position(reference_points_input[:, :, 0, :], dim=output.shape[-1]//2) # nq, bs, 256*2\n            # print('query_sine_embed', query_sine_embed.dtype)\n\n            raw_query_pos = self.ref_point_head(query_sine_embed)  # nq, bs, 256\n            pos_scale = self.query_scale(output) if self.query_scale is not None else 1\n            query_pos = pos_scale * raw_query_pos\n\n            output = layer(\n                tgt=output,\n                tgt_query_pos=query_pos,\n                tgt_query_sine_embed=query_sine_embed,\n                tgt_key_padding_mask=tgt_key_padding_mask,\n                tgt_reference_points=reference_points_input,\n\n                memory=memory,\n                memory_key_padding_mask=memory_key_padding_mask,\n                memory_level_start_index=level_start_index,\n                memory_spatial_shapes=spatial_shapes,\n                memory_pos=pos,\n\n                self_attn_mask=tgt_mask,\n                cross_attn_mask=memory_mask,\n\n                task_switch=self.task_switch,\n                extra=extra,\n            )\n\n            # grounding language token reference point will not update and saved\n            if (self.task_switch is not None) and (extra is not None) and (self.task_switch['grounding']) and ('grounding_len' in extra) and extra['task']=='seg':\n                _reference_points = reference_points[-extra['grounding_len']:]\n                reference_points = reference_points[:-extra['grounding_len']]\n                _output = output[-extra['grounding_len']:]\n                output = output[:-extra['grounding_len']]\n\n            # iter update\n            if self.bbox_embed is not None:\n                reference_before_sigmoid = inverse_sigmoid(reference_points)\n                delta_unsig = self.bbox_embed[layer_id](output)\n                outputs_unsig = delta_unsig + reference_before_sigmoid\n                new_reference_points = outputs_unsig.sigmoid()\n\n            
    reference_points = new_reference_points.detach()\n                # if layer_id != self.num_layers - 1:\n                ref_points.append(new_reference_points)\n\n            intermediate.append(self.norm(output))\n\n            # add back grounding language token\n            if (self.task_switch is not None) and (extra is not None) and (self.task_switch['grounding']) and ('grounding_len' in extra) and extra['task']=='seg':\n                reference_points = torch.cat((reference_points, _reference_points))\n                output = torch.cat((output, _output))\n\n        return [\n            [itm_out.transpose(0, 1) for itm_out in intermediate],\n            [itm_refpoint.transpose(0, 1) for itm_refpoint in ref_points]\n        ]\n\n\nclass DeformableTransformerDecoderLayer(nn.Module):\n\n    def __init__(self, d_model=256, d_ffn=1024,\n                 dropout=0.1, activation=\"relu\",\n                 n_levels=4, n_heads=8, n_points=4,\n                 use_deformable_box_attn=False,\n                 key_aware_type=None,\n                 ):\n        super().__init__()\n\n        # cross attention\n        if use_deformable_box_attn:\n            raise NotImplementedError\n        else:\n            self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)\n        self.dropout1 = nn.Dropout(dropout)\n        self.norm1 = nn.LayerNorm(d_model)\n\n        # self attention\n        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)\n        self.dropout2 = nn.Dropout(dropout)\n        self.norm2 = nn.LayerNorm(d_model)\n\n        # ffn\n        self.linear1 = nn.Linear(d_model, d_ffn)\n        self.activation = _get_activation_fn(activation)\n        self.dropout3 = nn.Dropout(dropout)\n        self.linear2 = nn.Linear(d_ffn, d_model)\n        self.dropout4 = nn.Dropout(dropout)\n        self.norm3 = nn.LayerNorm(d_model)\n\n        self.key_aware_type = key_aware_type\n        self.key_aware_proj = None\n\n    def 
rm_self_attn_modules(self):\n        self.self_attn = None\n        self.dropout2 = None\n        self.norm2 = None\n\n    @staticmethod\n    def with_pos_embed(tensor, pos):\n        return tensor if pos is None else tensor + pos\n\n    def forward_ffn(self, tgt):\n        tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))\n        tgt = tgt + self.dropout4(tgt2)\n        tgt = self.norm3(tgt)\n        return tgt\n\n    @autocast(enabled=False)\n    def forward(self,\n                # for tgt\n                tgt: Optional[Tensor],  # nq, bs, d_model\n                tgt_query_pos: Optional[Tensor] = None,  # pos for query. MLP(Sine(pos))\n                tgt_query_sine_embed: Optional[Tensor] = None,  # pos for query. Sine(pos)\n                tgt_key_padding_mask: Optional[Tensor] = None,\n                tgt_reference_points: Optional[Tensor] = None,  # nq, bs, 4\n\n                # for memory\n                memory: Optional[Tensor] = None,  # hw, bs, d_model\n                memory_key_padding_mask: Optional[Tensor] = None,\n                memory_level_start_index: Optional[Tensor] = None,  # num_levels\n                memory_spatial_shapes: Optional[Tensor] = None,  # bs, num_levels, 2\n                memory_pos: Optional[Tensor] = None,  # pos for memory\n\n                # sa\n                self_attn_mask: Optional[Tensor] = None,  # mask used for self-attention\n                cross_attn_mask: Optional[Tensor] = None,  # mask used for cross-attention\n\n                # misc\n                task_switch: Optional[Tensor] = {}, # extra information                \n                extra: Optional[Tensor] = {}, # extra information\n                ):\n        \"\"\"\n        Input:\n            - tgt/tgt_query_pos: nq, bs, d_model\n            -\n        \"\"\"\n        # self attention\n        if self.self_attn is not None:\n            q = k = self.with_pos_embed(tgt, tgt_query_pos)\n            tgt2 = self.self_attn(q, k, 
tgt, attn_mask=self_attn_mask)[0]\n            tgt = tgt + self.dropout2(tgt2)\n            tgt = self.norm2(tgt)\n\n        # exclude grounding token for cross attention\n        if (task_switch is not None) and (extra is not None) and (task_switch['grounding']) and ('grounding_len' in extra) and extra['task']=='seg':\n            _grounding_lang_tokens = tgt[-extra['grounding_len']:,]\n            _grounding_lang_pos = tgt_query_pos[-extra['grounding_len']:,]\n            _grounding_ref_points = tgt_reference_points[-extra['grounding_len']:,]\n            tgt = tgt[:-extra['grounding_len'],]\n            tgt_query_pos = tgt_query_pos[:-extra['grounding_len'],]\n            tgt_reference_points = tgt_reference_points[:-extra['grounding_len'],]\n\n        # cross attention\n        if self.key_aware_type is not None:\n            if self.key_aware_type == 'mean':\n                tgt = tgt + memory.mean(0, keepdim=True)\n            elif self.key_aware_type == 'proj_mean':\n                tgt = tgt + self.key_aware_proj(memory).mean(0, keepdim=True)\n            else:\n                raise NotImplementedError(\"Unknown key_aware_type: {}\".format(self.key_aware_type))\n        tgt2 = self.cross_attn(self.with_pos_embed(tgt, tgt_query_pos).transpose(0, 1),\n                               tgt_reference_points.transpose(0, 1).contiguous(),\n                               memory.transpose(0, 1), memory_spatial_shapes, memory_level_start_index,\n                               memory_key_padding_mask).transpose(0, 1) # TODO: check whether add grounding lang token to cross attention is better\n        tgt = tgt + self.dropout1(tgt2)\n\n        # add back grounding token for self attention\n        if (task_switch is not None) and (extra is not None) and (task_switch['grounding']) and ('grounding_len' in extra) and extra['task']=='seg':\n            tgt = torch.cat((tgt, _grounding_lang_tokens))\n\n        tgt = self.norm1(tgt)\n        tgt = self.forward_ffn(tgt) # 
ffn\n        return tgt"
  },
  {
    "path": "llava/model/openseed/body/decoder/utils/utils.py",
    "content": "import torch\nimport copy\nfrom torch import nn, Tensor\nimport os\n\nimport math\nimport torch.nn.functional as F\nfrom torch import nn\n\n\nclass MLP(nn.Module):\n    \"\"\" Very simple multi-layer perceptron (also called FFN)\"\"\"\n\n    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):\n        super().__init__()\n        self.num_layers = num_layers\n        h = [hidden_dim] * (num_layers - 1)\n        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))\n\n    def forward(self, x):\n        for i, layer in enumerate(self.layers):\n            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)\n        return x\n\n\ndef inverse_sigmoid(x, eps=1e-5):\n    x = x.clamp(min=0, max=1)\n    x1 = x.clamp(min=eps)\n    x2 = (1 - x).clamp(min=eps)\n    return torch.log(x1/x2)\n\n\ndef gen_encoder_output_proposals(memory:Tensor, memory_padding_mask:Tensor, spatial_shapes:Tensor):\n    \"\"\"\n    Input:\n        - memory: bs, \\sum{hw}, d_model\n        - memory_padding_mask: bs, \\sum{hw}\n        - spatial_shapes: nlevel, 2\n    Output:\n        - output_memory: bs, \\sum{hw}, d_model\n        - output_proposals: bs, \\sum{hw}, 4\n    \"\"\"\n    N_, S_, C_ = memory.shape\n    base_scale = 4.0\n    proposals = []\n    _cur = 0\n    for lvl, (H_, W_) in enumerate(spatial_shapes):\n        mask_flatten_ = memory_padding_mask[:, _cur:(_cur + H_ * W_)].view(N_, H_, W_, 1)\n        valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)\n        valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)\n\n        grid_y, grid_x = torch.meshgrid(torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device),\n                                        torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device))\n        grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)\n\n        scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(N_, 1, 1, 
2)\n        grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale\n        wh = torch.ones_like(grid) * 0.05 * (2.0 ** lvl)\n        proposal = torch.cat((grid, wh), -1).view(N_, -1, 4)\n        proposals.append(proposal)\n        _cur += (H_ * W_)\n    output_proposals = torch.cat(proposals, 1)\n    output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True)\n    output_proposals = torch.log(output_proposals / (1 - output_proposals))\n    output_proposals = output_proposals.masked_fill(memory_padding_mask.unsqueeze(-1), float('inf'))\n    output_proposals = output_proposals.masked_fill(~output_proposals_valid, float('inf'))\n\n    output_memory = memory\n    output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0))\n    output_memory = output_memory.masked_fill(~output_proposals_valid, float(0))\n    return output_memory, output_proposals\n\n\ndef gen_sineembed_for_position(pos_tensor, dim=128):\n    # n_query, bs, _ = pos_tensor.size()\n    # sineembed_tensor = torch.zeros(n_query, bs, 256)\n    scale = 2 * math.pi\n    dim_t = torch.arange(dim, dtype=torch.float32, device=pos_tensor.device)\n    dim_t = 10000 ** (2 * (dim_t // 2) / dim)\n    x_embed = pos_tensor[:, :, 0] * scale\n    y_embed = pos_tensor[:, :, 1] * scale\n    pos_x = x_embed[:, :, None] / dim_t\n    pos_y = y_embed[:, :, None] / dim_t\n    pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)\n    pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2)\n    if pos_tensor.size(-1) == 2:\n        pos = torch.cat((pos_y, pos_x), dim=2)\n    elif pos_tensor.size(-1) == 4:\n        w_embed = pos_tensor[:, :, 2] * scale\n        pos_w = w_embed[:, :, None] / dim_t\n        pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2)\n\n        h_embed = pos_tensor[:, :, 3] * scale\n        pos_h = h_embed[:, :, None] / 
dim_t\n        pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2)\n\n        pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)\n    else:\n        raise ValueError(\"Unknown pos_tensor shape(-1):{}\".format(pos_tensor.size(-1)))\n    return pos.to(pos_tensor.dtype)\n\n\ndef _get_activation_fn(activation):\n    \"\"\"Return an activation function given a string\"\"\"\n    if activation == \"relu\":\n        return F.relu\n    if activation == \"gelu\":\n        return F.gelu\n    if activation == \"glu\":\n        return F.glu\n    if activation == \"prelu\":\n        return nn.PReLU()\n    if activation == \"selu\":\n        return F.selu\n    raise RuntimeError(F\"activation should be relu/gelu, not {activation}.\")\n\n\ndef _get_clones(module, N, layer_share=False):\n\n    if layer_share:\n        return nn.ModuleList([module for i in range(N)])\n    else:\n        return nn.ModuleList([copy.deepcopy(module) for i in range(N)])"
  },
  {
    "path": "llava/model/openseed/body/encoder/__init__.py",
    "content": "from .build import build_encoder"
  },
  {
    "path": "llava/model/openseed/body/encoder/build.py",
    "content": "from .registry import model_entrypoints\nfrom .registry import is_model\n\nfrom .transformer_encoder_fpn import *\nfrom .encoder_deform import *\n\ndef build_encoder(config, *args, **kwargs):\n    model_name = config['MODEL']['ENCODER']['NAME']\n\n    if not is_model(model_name):\n        raise ValueError(f'Unknown model: {model_name}')\n\n    return model_entrypoints(model_name)(config, *args, **kwargs)"
  },
  {
    "path": "llava/model/openseed/body/encoder/encoder_deform.py",
    "content": "# ------------------------------------------------------------------------\n# DINO\n# Copyright (c) 2023 IDEA. All Rights Reserved.\n# Licensed under the Apache License, Version 2.0 [see LICENSE for details]\n# ------------------------------------------------------------------------\n# Modified from MaskDINO https://github.com/IDEA-Research/MaskDINO by Feng Li and Hao Zhang.\n# ------------------------------------------------------------------------\nimport logging\nimport numpy as np\nfrom typing import Callable, Dict, List, Optional, Tuple, Union\nimport fvcore.nn.weight_init as weight_init\n\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom torch.nn.init import xavier_uniform_, constant_, uniform_, normal_\nfrom torch.cuda.amp import autocast\n\nfrom detectron2.layers import Conv2d, ShapeSpec, get_norm\n# from detectron2.modeling import SEM_SEG_HEADS_REGISTRY\n\nfrom .registry import register_encoder\nfrom ...utils import configurable\nfrom ...modules import PositionEmbeddingSine\nfrom ..transformer_blocks import _get_clones, _get_activation_fn\nfrom .ops.modules import MSDeformAttn\nfrom torch.utils import checkpoint\n\n# MSDeformAttn Transformer encoder in deformable detr\nclass MSDeformAttnTransformerEncoderOnly(nn.Module):\n    def __init__(self, d_model=256, nhead=8,\n                 num_encoder_layers=6, dim_feedforward=1024, dropout=0.1,\n                 activation=\"relu\",\n                 num_feature_levels=4, enc_n_points=4,):\n        super().__init__()\n\n        self.d_model = d_model\n        self.nhead = nhead\n\n        encoder_layer = MSDeformAttnTransformerEncoderLayer(d_model, dim_feedforward,\n                                                            dropout, activation,\n                                                            num_feature_levels, nhead, enc_n_points)\n        self.encoder = MSDeformAttnTransformerEncoder(encoder_layer, num_encoder_layers)\n\n        self.level_embed = 
nn.Parameter(torch.Tensor(num_feature_levels, d_model))\n\n        self._reset_parameters()\n\n    def _reset_parameters(self):\n        for p in self.parameters():\n            if p.dim() > 1:\n                nn.init.xavier_uniform_(p)\n        for m in self.modules():\n            if isinstance(m, MSDeformAttn):\n                m._reset_parameters()\n        normal_(self.level_embed)\n\n    def get_valid_ratio(self, mask):\n        _, H, W = mask.shape\n        valid_H = torch.sum(~mask[:, :, 0], 1)\n        valid_W = torch.sum(~mask[:, 0, :], 1)\n        valid_ratio_h = valid_H.float() / H\n        valid_ratio_w = valid_W.float() / W\n        valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)\n        return valid_ratio\n\n    def forward(self, srcs, masks, pos_embeds, use_ckpt=False):\n\n        enable_mask=0\n        if masks is not None:\n            for src in srcs:\n                if src.size(2)%32 or src.size(3)%32:\n                    enable_mask = 1\n        if enable_mask==0:\n            masks = [torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) for x in srcs]\n        # prepare input for encoder\n        src_flatten = []\n        mask_flatten = []\n        lvl_pos_embed_flatten = []\n        spatial_shapes = []\n        for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):\n            bs, c, h, w = src.shape\n            spatial_shape = (h, w)\n            spatial_shapes.append(spatial_shape)\n            src = src.flatten(2).transpose(1, 2)\n            mask = mask.flatten(1)\n            pos_embed = pos_embed.flatten(2).transpose(1, 2)\n            lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)\n            lvl_pos_embed_flatten.append(lvl_pos_embed)\n            src_flatten.append(src)\n            mask_flatten.append(mask)\n        src_flatten = torch.cat(src_flatten, 1)\n        mask_flatten = torch.cat(mask_flatten, 1)\n        lvl_pos_embed_flatten = 
torch.cat(lvl_pos_embed_flatten, 1)\n        spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device)\n        level_start_index = torch.cat((spatial_shapes.new_zeros((1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))\n        valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)\n\n        # encoder\n        memory = self.encoder(src_flatten, spatial_shapes, level_start_index, valid_ratios, lvl_pos_embed_flatten, mask_flatten, use_ckpt=use_ckpt)\n        return memory, spatial_shapes, level_start_index\n\n\nclass MSDeformAttnTransformerEncoderLayer(nn.Module):\n    def __init__(self,\n                 d_model=256, d_ffn=1024,\n                 dropout=0.1, activation=\"relu\",\n                 n_levels=4, n_heads=8, n_points=4):\n        super().__init__()\n\n        # self attention\n        self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)\n        self.dropout1 = nn.Dropout(dropout)\n        self.norm1 = nn.LayerNorm(d_model)\n\n        # ffn\n        self.linear1 = nn.Linear(d_model, d_ffn)\n        self.activation = _get_activation_fn(activation)\n        self.dropout2 = nn.Dropout(dropout)\n        self.linear2 = nn.Linear(d_ffn, d_model)\n        self.dropout3 = nn.Dropout(dropout)\n        self.norm2 = nn.LayerNorm(d_model)\n\n    @staticmethod\n    def with_pos_embed(tensor, pos):\n        return tensor if pos is None else tensor + pos\n\n    def forward_ffn(self, src):\n        src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))\n        src = src + self.dropout3(src2)\n        src = self.norm2(src)\n        return src\n\n    def forward(self, src, pos, reference_points, spatial_shapes, level_start_index, padding_mask=None):\n        # self attention\n        src2 = self.self_attn(self.with_pos_embed(src, pos), reference_points, src, spatial_shapes, level_start_index, padding_mask)\n        src = src + self.dropout1(src2)\n        src = self.norm1(src)\n\n        # 
ffn\n        src = self.forward_ffn(src)\n\n        return src\n\n\nclass MSDeformAttnTransformerEncoder(nn.Module):\n    def __init__(self, encoder_layer, num_layers):\n        super().__init__()\n        self.layers = _get_clones(encoder_layer, num_layers)\n        self.num_layers = num_layers\n\n    @staticmethod\n    def get_reference_points(spatial_shapes, valid_ratios, device):\n        reference_points_list = []\n        for lvl, (H_, W_) in enumerate(spatial_shapes):\n\n            ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),\n                                          torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device))\n            ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)\n            ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)\n            ref = torch.stack((ref_x, ref_y), -1)\n            reference_points_list.append(ref)\n        reference_points = torch.cat(reference_points_list, 1)\n        reference_points = reference_points[:, :, None] * valid_ratios[:, None]\n        return reference_points\n\n    def forward(self, src, spatial_shapes, level_start_index, valid_ratios, pos=None, padding_mask=None, use_ckpt=False):\n        output = src\n        reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=src.device)\n        for _, layer in enumerate(self.layers):\n            if use_ckpt:\n                output = checkpoint.checkpoint(layer,output, pos, reference_points, spatial_shapes, level_start_index, padding_mask)\n            else:\n                output = layer(output, pos, reference_points, spatial_shapes, level_start_index, padding_mask)\n\n        return output\n\n\nclass OpenSeeDEncoder(nn.Module):\n    \"\"\"\n    This is the multi-scale encoder in detection models, also named as pixel decoder in segmentation models.\n    \"\"\"\n    @configurable\n    def __init__(\n        self,\n    
    input_shape: Dict[str, ShapeSpec],\n        *,\n        transformer_dropout: float,\n        transformer_nheads: int,\n        transformer_dim_feedforward: int,\n        transformer_enc_layers: int,\n        conv_dim: int,\n        mask_dim: int,\n        norm: Optional[Union[str, Callable]] = None,\n        # deformable transformer encoder args\n        transformer_in_features: List[str],\n        common_stride: int,\n        num_feature_levels: int,\n        total_num_feature_levels: int,\n        feature_order: str,\n        use_ckpt=False,\n    ):\n        \"\"\"\n        NOTE: this interface is experimental.\n        Args:\n            input_shape: shapes (channels and stride) of the input features\n            transformer_dropout: dropout probability in transformer\n            transformer_nheads: number of heads in transformer\n            transformer_dim_feedforward: dimension of feedforward network\n            transformer_enc_layers: number of transformer encoder layers\n            conv_dim: number of output channels for the intermediate conv layers.\n            mask_dim: number of output channels for the final conv layer.\n            norm (str or callable): normalization for all conv layers\n            transformer_in_features: names of the input features fed to the deformable transformer encoder\n            num_feature_levels: feature scales used\n            total_num_feature_levels: total feature scales used (including the downsampled features)\n            feature_order: 'low2high' or 'high2low', i.e., 'low2high' means low-resolution features are put first.\n            use_ckpt: whether to run the encoder layers with gradient checkpointing\n        \"\"\"\n        super().__init__()\n        self.use_ckpt = use_ckpt\n        transformer_input_shape = {\n            k: v for k, v in input_shape.items() if k in transformer_in_features\n        }\n        # this is the input shape of pixel decoder\n        input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)\n        self.in_features = [k for k, v in input_shape]  # starting from \"res2\" to \"res5\"\n        self.feature_strides = [v.stride for k, v in 
input_shape]\n        self.feature_channels = [v.channels for k, v in input_shape]\n        self.feature_order = feature_order\n\n        if feature_order == \"low2high\":\n            transformer_input_shape = sorted(transformer_input_shape.items(), key=lambda x: -x[1].stride)\n        else:\n            transformer_input_shape = sorted(transformer_input_shape.items(), key=lambda x: x[1].stride)\n        self.transformer_in_features = [k for k, v in transformer_input_shape]  # 'low2high': descending stride (e.g. res5 -> res3); otherwise ascending stride\n        transformer_in_channels = [v.channels for k, v in transformer_input_shape]\n        self.transformer_feature_strides = [v.stride for k, v in transformer_input_shape]  # to decide extra FPN layers\n\n        self.maskdino_num_feature_levels = num_feature_levels  # always use 3 scales\n        self.total_num_feature_levels = total_num_feature_levels\n        self.common_stride = common_stride\n\n        self.transformer_num_feature_levels = len(self.transformer_in_features)\n        self.low_resolution_index = transformer_in_channels.index(max(transformer_in_channels))\n        self.high_resolution_index = 0 if self.feature_order == 'low2high' else -1\n        if self.transformer_num_feature_levels > 1:\n            input_proj_list = []\n            for in_channels in transformer_in_channels[::-1]:\n                input_proj_list.append(nn.Sequential(\n                    nn.Conv2d(in_channels, conv_dim, kernel_size=1),\n                    nn.GroupNorm(32, conv_dim),\n                ))\n            # input projections for the extra stride-2 downsampled levels\n            in_channels = max(transformer_in_channels)\n            for _ in range(self.total_num_feature_levels - self.transformer_num_feature_levels):  # exclude the res2\n                input_proj_list.append(nn.Sequential(\n                    nn.Conv2d(in_channels, conv_dim, kernel_size=3, stride=2, padding=1),\n                    nn.GroupNorm(32, conv_dim),\n                ))\n                in_channels = 
conv_dim\n            self.input_proj = nn.ModuleList(input_proj_list)\n        else:\n            self.input_proj = nn.ModuleList([\n                nn.Sequential(\n                    nn.Conv2d(transformer_in_channels[-1], conv_dim, kernel_size=1),\n                    nn.GroupNorm(32, conv_dim),\n                )])\n\n        for proj in self.input_proj:\n            nn.init.xavier_uniform_(proj[0].weight, gain=1)\n            nn.init.constant_(proj[0].bias, 0)\n\n        self.transformer = MSDeformAttnTransformerEncoderOnly(\n            d_model=conv_dim,\n            dropout=transformer_dropout,\n            nhead=transformer_nheads,\n            dim_feedforward=transformer_dim_feedforward,\n            num_encoder_layers=transformer_enc_layers,\n            num_feature_levels=self.total_num_feature_levels,\n        )\n        N_steps = conv_dim // 2\n        self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)\n\n        self.mask_dim = mask_dim\n        # use 1x1 conv instead\n        self.mask_features = Conv2d(\n            conv_dim,\n            mask_dim,\n            kernel_size=1,\n            stride=1,\n            padding=0,\n        )\n        weight_init.c2_xavier_fill(self.mask_features)\n        # extra fpn levels\n        stride = min(self.transformer_feature_strides)\n        self.num_fpn_levels = max(int(np.log2(stride) - np.log2(self.common_stride)), 1)\n\n        lateral_convs = []\n        output_convs = []\n\n        use_bias = norm == \"\"\n        for idx, in_channels in enumerate(self.feature_channels[:self.num_fpn_levels]):\n            lateral_norm = get_norm(norm, conv_dim)\n            output_norm = get_norm(norm, conv_dim)\n\n            lateral_conv = Conv2d(\n                in_channels, conv_dim, kernel_size=1, bias=use_bias, norm=lateral_norm\n            )\n            output_conv = Conv2d(\n                conv_dim,\n                conv_dim,\n                kernel_size=3,\n                stride=1,\n               
 padding=1,\n                bias=use_bias,\n                norm=output_norm,\n                activation=F.relu,\n            )\n            weight_init.c2_xavier_fill(lateral_conv)\n            weight_init.c2_xavier_fill(output_conv)\n            self.add_module(\"adapter_{}\".format(idx + 1), lateral_conv)\n            self.add_module(\"layer_{}\".format(idx + 1), output_conv)\n\n            lateral_convs.append(lateral_conv)\n            output_convs.append(output_conv)\n        # Place convs into top-down order (from low to high resolution)\n        # to make the top-down computation in forward clearer.\n        self.lateral_convs = lateral_convs[::-1]\n        self.output_convs = output_convs[::-1]\n\n    @classmethod\n    def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec], *args, **kwargs):\n        enc_cfg = cfg['MODEL']['ENCODER']\n        dec_cfg = cfg['MODEL']['DECODER']\n\n        ret = {}\n        ret[\"input_shape\"] = {\n            k: v for k, v in input_shape.items() if k in enc_cfg['IN_FEATURES']\n        }\n        ret[\"conv_dim\"] = enc_cfg['CONVS_DIM']\n        ret[\"mask_dim\"] = enc_cfg['MASK_DIM']\n        ret[\"norm\"] = enc_cfg['NORM']\n        ret[\"transformer_dropout\"] = dec_cfg['DROPOUT']\n        ret[\"transformer_nheads\"] = dec_cfg['NHEADS']\n        ret[\"transformer_dim_feedforward\"] = dec_cfg['DIM_FEEDFORWARD']  # deformable transformer encoder\n        ret[\n            \"transformer_enc_layers\"\n        ] = enc_cfg['TRANSFORMER_ENC_LAYERS']  # a separate config\n        ret[\"transformer_in_features\"] = enc_cfg['DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES']  # ['res3', 'res4', 'res5']\n        ret[\"common_stride\"] = enc_cfg['COMMON_STRIDE']\n        ret[\"total_num_feature_levels\"] = enc_cfg['TOTAL_NUM_FEATURE_LEVELS']\n        ret[\"num_feature_levels\"] = enc_cfg['NUM_FEATURE_LEVELS']\n        ret[\"feature_order\"] = enc_cfg['FEATURE_ORDER']\n        ret[\"use_ckpt\"] = enc_cfg.get('USE_CKPT', False)\n       
 return ret\n\n    @autocast(enabled=False)\n    def forward_features(self, features, masks):\n        \"\"\"\n        :param features: multi-scale features from the backbone\n        :param masks: image mask\n        :return: enhanced multi-scale features and mask feature (1/4 resolution) for the decoder to produce binary mask\n        \"\"\"\n        # backbone features\n        srcs = []\n        pos = []\n        # additional downsampled features\n        srcsl = []\n        posl = []\n        if self.total_num_feature_levels > self.transformer_num_feature_levels:\n            smallest_feat = features[self.transformer_in_features[self.low_resolution_index]]\n            _len_srcs = self.transformer_num_feature_levels\n            for l in range(_len_srcs, self.total_num_feature_levels):\n                if l == _len_srcs:\n                    src = self.input_proj[l](smallest_feat)\n                else:\n                    src = self.input_proj[l](srcsl[-1])\n                srcsl.append(src)\n                posl.append(self.pe_layer(src))\n        srcsl = srcsl[::-1]\n        # Reverse feature maps\n        for idx, f in enumerate(self.transformer_in_features[::-1]):\n            x = features[f] # deformable detr does not support half precision\n            srcs.append(self.input_proj[idx](x))\n            pos.append(self.pe_layer(x))\n        srcs.extend(srcsl) if self.feature_order == 'low2high' else srcsl.extend(srcs)\n        pos.extend(posl) if self.feature_order == 'low2high' else posl.extend(pos)\n        if self.feature_order != 'low2high':\n            srcs = srcsl\n            pos = posl\n        y, spatial_shapes, level_start_index = self.transformer(srcs, masks, pos, use_ckpt=self.use_ckpt)\n        bs = y.shape[0]\n\n        split_size_or_sections = [None] * self.total_num_feature_levels\n        for i in range(self.total_num_feature_levels):\n            if i < self.total_num_feature_levels - 1:\n                split_size_or_sections[i] = 
level_start_index[i + 1] - level_start_index[i]\n            else:\n                split_size_or_sections[i] = y.shape[1] - level_start_index[i]\n        y = torch.split(y, split_size_or_sections, dim=1)\n\n        out = []\n        multi_scale_features = []\n        num_cur_levels = 0\n        for i, z in enumerate(y):\n            out.append(z.transpose(1, 2).view(bs, -1, spatial_shapes[i][0], spatial_shapes[i][1]))\n\n        # append `out` with extra FPN levels\n        # Reverse feature maps into top-down order (from low to high resolution)\n        convert = False\n        if out[0].dtype == torch.bfloat16:\n            out = [out_.float() for out_ in out]\n            convert = True\n        for idx, f in enumerate(self.in_features[:self.num_fpn_levels][::-1]):\n            x = features[f]\n            lateral_conv = self.lateral_convs[idx]\n            output_conv = self.output_convs[idx]\n            cur_fpn = lateral_conv(x)\n            # Upsample the high-resolution encoder output to this level with bilinear interpolation (the original FPN uses nearest)\n            y = F.interpolate(out[self.high_resolution_index], size=cur_fpn.shape[-2:], mode=\"bilinear\", align_corners=False)\n            if convert:\n                y = y.bfloat16()\n            y=cur_fpn + y\n            y = output_conv(y)\n            out.append(y)\n        if convert:\n            out = [out_.bfloat16() for out_ in out]\n        for o in out:\n            if num_cur_levels < self.total_num_feature_levels:\n                multi_scale_features.append(o)\n                num_cur_levels += 1\n        return self.mask_features(out[-1]), out[0], multi_scale_features\n\n\n\n@register_encoder\ndef get_maskdino_encoder_deform(cfg, input_shape):\n    \"\"\"\n    Build the OpenSeeD deformable-attention pixel decoder (OpenSeeDEncoder) configured by `cfg['MODEL']['ENCODER']`.\n    \"\"\"\n    model = OpenSeeDEncoder(cfg, input_shape)\n    forward_features = getattr(model, \"forward_features\", None)\n    if not callable(forward_features):\n        raise ValueError(\n            \"Only 
SEM_SEG_HEADS with forward_features method can be used as pixel decoder. \"\n            f\"Please implement forward_features for {type(model).__name__} to only return mask features.\"\n        )\n    return model"
  },
  {
    "path": "llava/model/openseed/body/encoder/ops/functions/__init__.py",
    "content": "# ------------------------------------------------------------------------------------------------\n# Deformable DETR\n# Copyright (c) 2020 SenseTime. All Rights Reserved.\n# Licensed under the Apache License, Version 2.0 [see LICENSE for details]\n# ------------------------------------------------------------------------------------------------\n# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0\n# ------------------------------------------------------------------------------------------------\n\n# Copyright (c) Facebook, Inc. and its affiliates.\n# Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR\n\n\"\"\"Package entry point: re-exports the CUDA-backed autograd Function for multi-scale deformable attention.\"\"\"\n\nfrom .ms_deform_attn_func import MSDeformAttnFunction\n\n__all__ = [\"MSDeformAttnFunction\"]\n"
  },
  {
    "path": "llava/model/openseed/body/encoder/ops/functions/ms_deform_attn_func.py",
    "content": "# ------------------------------------------------------------------------------------------------\n# Deformable DETR\n# Copyright (c) 2020 SenseTime. All Rights Reserved.\n# Licensed under the Apache License, Version 2.0 [see LICENSE for details]\n# ------------------------------------------------------------------------------------------------\n# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0\n# ------------------------------------------------------------------------------------------------\n\n# Copyright (c) Facebook, Inc. and its affiliates.\n# Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR\n\nfrom __future__ import absolute_import\nfrom __future__ import print_function\nfrom __future__ import division\n\nimport torch\nimport torch.nn.functional as F\nfrom torch.autograd import Function\nfrom torch.autograd.function import once_differentiable\n\ntry:\n    import MultiScaleDeformableAttention as MSDA\nexcept ModuleNotFoundError as e:\n    info_string = (\n        \"\\n\\nPlease compile MultiScaleDeformableAttention CUDA op with the following commands:\\n\"\n        \"\\t`cd llava/model/openseed/body/encoder/ops`\\n\"\n        \"\\t`sh make.sh`\\n\"\n    )\n    # Chain the original error so the name of the module that failed to import stays in the traceback.\n    raise ModuleNotFoundError(info_string) from e\n\n\nclass MSDeformAttnFunction(Function):\n    \"\"\"Autograd wrapper around the compiled ``MultiScaleDeformableAttention`` CUDA op.\n\n    Gradients are returned for ``value``, ``sampling_locations`` and ``attention_weights`` only;\n    the spatial shapes, level start indices and ``im2col_step`` are treated as constants.\n    \"\"\"\n\n    @staticmethod\n    def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step):\n        ctx.im2col_step = im2col_step\n        output = MSDA.ms_deform_attn_forward(\n            value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step)\n        ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights)\n        return output\n\n    @staticmethod\n    @once_differentiable\n    def backward(ctx, grad_output):\n        value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors\n        grad_value, grad_sampling_loc, grad_attn_weight = \\\n            MSDA.ms_deform_attn_backward(\n                value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step)\n\n        # One entry per forward() input; the non-differentiable inputs get None.\n        return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None\n\n\ndef ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights):\n    \"\"\"Pure-PyTorch reference implementation of multi-scale deformable attention.\n\n    For debugging and testing only; use the CUDA op (``MSDeformAttnFunction``) in practice.\n\n    :param value                (N, S, M, D), S being the total number of positions over all levels\n    :param value_spatial_shapes iterable of (H_l, W_l), one pair per level; the H_l * W_l must sum to S\n    :param sampling_locations   (N, Lq, M, L, P, 2), coordinates normalized to [0, 1]\n    :param attention_weights    (N, Lq, M, L, P)\n    :return                     (N, Lq, M * D)\n    \"\"\"\n    N_, S_, M_, D_ = value.shape\n    _, Lq_, M_, L_, P_, _ = sampling_locations.shape\n    value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)\n    # grid_sample expects coordinates in [-1, 1]\n    sampling_grids = 2 * sampling_locations - 1\n    sampling_value_list = []\n    for lid_, (H_, W_) in enumerate(value_spatial_shapes):\n        # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_\n        value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_)\n        # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2\n        sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)\n        # N_*M_, D_, Lq_, P_\n        sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_,\n                                          mode='bilinear', padding_mode='zeros', align_corners=False)\n        sampling_value_list.append(sampling_value_l_)\n    # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_*M_, 1, Lq_, L_*P_)\n    attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_)\n    # stack levels: (N_*M_, D_, Lq_, L_, P_) -> (N_*M_, D_, Lq_, L_*P_); weight, then sum over L_*P_\n    output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_)\n    return output.transpose(1, 2).contiguous()\n"
  },
  {
    "path": "llava/model/openseed/body/encoder/ops/make.sh",
    "content": "#!/usr/bin/env bash\n# ------------------------------------------------------------------------------------------------\n# Deformable DETR\n# Copyright (c) 2020 SenseTime. All Rights Reserved.\n# Licensed under the Apache License, Version 2.0 [see LICENSE for details]\n# ------------------------------------------------------------------------------------------------\n# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0\n# ------------------------------------------------------------------------------------------------\n\n# Copyright (c) Facebook, Inc. and its affiliates.\n# Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR\n\n# Stop at the first failing step (e.g. the cd below) instead of carrying on.\nset -e\n\n# setup.py sits next to this script; cd there so the build works from any working directory.\ncd \"$(dirname \"$0\")\"\n\npython setup.py build install --user\n"
  },
  {
    "path": "llava/model/openseed/body/encoder/ops/modules/__init__.py",
    "content": "# ------------------------------------------------------------------------------------------------\n# Deformable DETR\n# Copyright (c) 2020 SenseTime. All Rights Reserved.\n# Licensed under the Apache License, Version 2.0 [see LICENSE for details]\n# ------------------------------------------------------------------------------------------------\n# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0\n# ------------------------------------------------------------------------------------------------\n\n# Copyright (c) Facebook, Inc. and its affiliates.\n# Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR\n\n\"\"\"Package entry point: re-exports the multi-scale deformable attention ``nn.Module``.\"\"\"\n\nfrom .ms_deform_attn import MSDeformAttn\n\n__all__ = [\"MSDeformAttn\"]\n"
  },
  {
    "path": "llava/model/openseed/body/encoder/ops/modules/ms_deform_attn.py",
    "content": "# ------------------------------------------------------------------------------------------------\n# Deformable DETR\n# Copyright (c) 2020 SenseTime. All Rights Reserved.\n# Licensed under the Apache License, Version 2.0 [see LICENSE for details]\n# ------------------------------------------------------------------------------------------------\n# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0\n# ------------------------------------------------------------------------------------------------\n\n# Copyright (c) Facebook, Inc. and its affiliates.\n# Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR\n\nfrom __future__ import absolute_import\nfrom __future__ import print_function\nfrom __future__ import division\n\nimport warnings\nimport math\n\nimport torch\nfrom torch import nn\nimport torch.nn.functional as F\nfrom torch.nn.init import xavier_uniform_, constant_\n\nfrom ..functions import MSDeformAttnFunction\n# Pure-PyTorch reference implementation; kept importable for CPU debugging / FLOPs counting.\nfrom ..functions.ms_deform_attn_func import ms_deform_attn_core_pytorch\n\n\ndef _is_power_of_2(n):\n    \"\"\"Return True if ``n`` is a positive power of two.\n\n    Raises ValueError if ``n`` is not an ``int`` or is negative.\n    \"\"\"\n    if (not isinstance(n, int)) or (n < 0):\n        raise ValueError(\"invalid input for _is_power_of_2: {} (type: {})\".format(n, type(n)))\n    return (n & (n-1) == 0) and n != 0\n\n\nclass MSDeformAttn(nn.Module):\n    def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4, im2col_step=128):\n        \"\"\"\n        Multi-Scale Deformable Attention Module\n        :param d_model      hidden dimension\n        :param n_levels     number of feature levels\n        :param n_heads      number of attention heads\n        :param n_points     number of sampling points per attention head per feature level\n        :param im2col_step  number of batch samples the CUDA op handles per kernel launch;\n                            min(batch, im2col_step) must divide the batch size\n        \"\"\"\n        super().__init__()\n        if d_model % n_heads != 0:\n            raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads))\n        _d_per_head = d_model // n_heads\n        # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation\n        if not _is_power_of_2(_d_per_head):\n            warnings.warn(\"You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 \"\n                          \"which is more efficient in our CUDA implementation.\")\n\n        self.im2col_step = im2col_step\n\n        self.d_model = d_model\n        self.n_levels = n_levels\n        self.n_heads = n_heads\n        self.n_points = n_points\n\n        self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)\n        self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)\n        self.value_proj = nn.Linear(d_model, d_model)\n        self.output_proj = nn.Linear(d_model, d_model)\n\n        self._reset_parameters()\n\n    def _reset_parameters(self):\n        \"\"\"Initialize the projections and the initial sampling-offset pattern.\n\n        Offset biases: head h points along angle 2*pi*h/n_heads (scaled so the larger coordinate\n        is +-1) and point i is placed (i + 1) times that far out. Attention logits start at zero,\n        i.e. a uniform softmax. value/output projections get Xavier-uniform weights and zero bias.\n        \"\"\"\n        constant_(self.sampling_offsets.weight.data, 0.)\n        thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)\n        grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)\n        grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1)\n        for i in range(self.n_points):\n            grid_init[:, :, i, :] *= i + 1\n        with torch.no_grad():\n            self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))\n        constant_(self.attention_weights.weight.data, 0.)\n        constant_(self.attention_weights.bias.data, 0.)\n        xavier_uniform_(self.value_proj.weight.data)\n        constant_(self.value_proj.bias.data, 0.)\n        xavier_uniform_(self.output_proj.weight.data)\n        constant_(self.output_proj.bias.data, 0.)\n\n    def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None):\n        r\"\"\"\n        :param query                       (N, Length_{query}, C)\n        :param reference_points            (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area\n                                        or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes\n        :param input_flatten               (N, \\sum_{l=0}^{L-1} H_l \\cdot W_l, C)\n        :param input_spatial_shapes        (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]\n        :param input_level_start_index     (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}]\n        :param input_padding_mask          (N, \\sum_{l=0}^{L-1} H_l \\cdot W_l), True for padding elements, False for non-padding elements\n\n        :return output                     (N, Length_{query}, C)\n\n        The CUDA op runs in float32 for reduced-precision (bfloat16 / float16) inputs; its result\n        is cast back to the dtype of the projected value before ``output_proj``.\n        \"\"\"\n        N, Len_q, _ = query.shape\n        N, Len_in, _ = input_flatten.shape\n        assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in\n\n        value = self.value_proj(input_flatten)\n        if input_padding_mask is not None:\n            value = value.masked_fill(input_padding_mask[..., None], float(0))\n        value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)\n        sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2)\n        attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points)\n        attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points)\n        # N, Len_q, n_heads, n_levels, n_points, 2\n        if reference_points.shape[-1] == 2:\n            offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1)\n            sampling_locations = reference_points[:, :, None, :, None, :] \\\n                                 + sampling_offsets / offset_normalizer[None, None, None, :, None, :]\n        elif reference_points.shape[-1] == 4:\n            sampling_locations = reference_points[:, :, None, :, None, :2] \\\n                                 + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5\n        else:\n            raise ValueError(\n                'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1]))\n\n        # The CUDA op dispatches only float32/float64 and reads value, sampling_locations and\n        # attention_weights with a single scalar type, so bring all three to one supported dtype:\n        # float64 stays float64, anything else (float32, bfloat16, float16) runs in float32.\n        orig_dtype = value.dtype\n        compute_dtype = torch.float64 if orig_dtype == torch.float64 else torch.float32\n        value = value.to(compute_dtype)\n        attention_weights = attention_weights.to(compute_dtype)\n        sampling_locations = sampling_locations.to(compute_dtype)\n        output = MSDeformAttnFunction.apply(\n            value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step)\n        output = output.to(orig_dtype)\n        # Pure-PyTorch equivalent (CPU debugging / FLOPs counting):\n        # output = ms_deform_attn_core_pytorch(value, input_spatial_shapes, sampling_locations, attention_weights)\n        output = self.output_proj(output)\n        return output\n"
  },
  {
    "path": "llava/model/openseed/body/encoder/ops/setup.py",
    "content": "# ------------------------------------------------------------------------------------------------\n# Deformable DETR\n# Copyright (c) 2020 SenseTime. All Rights Reserved.\n# Licensed under the Apache License, Version 2.0 [see LICENSE for details]\n# ------------------------------------------------------------------------------------------------\n# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0\n# ------------------------------------------------------------------------------------------------\n\n# Copyright (c) Facebook, Inc. and its affiliates.\n# Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR\n\nimport os\nimport glob\n\nimport torch\n\nfrom torch.utils.cpp_extension import CUDA_HOME\nfrom torch.utils.cpp_extension import CppExtension\nfrom torch.utils.cpp_extension import CUDAExtension\n\nfrom setuptools import find_packages\nfrom setuptools import setup\n\n# NOTE: informational only -- this list is not passed to setup(), so nothing is installed from it.\nrequirements = [\"torch\", \"torchvision\"]\n\n\ndef _force_cuda_requested():\n    \"\"\"Return True when FORCE_CUDA is set to a non-empty value other than '0'.\"\"\"\n    return os.environ.get(\"FORCE_CUDA\", \"0\") not in (\"\", \"0\")\n\n\ndef get_extensions():\n    \"\"\"Describe the MultiScaleDeformableAttention extension built from the sources under ``src/``.\n\n    Only a CUDA build is supported (the CPU entry points merely raise), so this raises\n    NotImplementedError when CUDA_HOME is unset, or when torch sees no GPU and FORCE_CUDA was\n    not requested (set FORCE_CUDA=1 to build without a visible GPU, e.g. inside a container build).\n    \"\"\"\n    this_dir = os.path.dirname(os.path.abspath(__file__))\n    extensions_dir = os.path.join(this_dir, \"src\")\n\n    main_file = glob.glob(os.path.join(extensions_dir, \"*.cpp\"))\n    source_cpu = glob.glob(os.path.join(extensions_dir, \"cpu\", \"*.cpp\"))\n    source_cuda = glob.glob(os.path.join(extensions_dir, \"cuda\", \"*.cu\"))\n\n    if CUDA_HOME is None:\n        raise NotImplementedError('CUDA_HOME is None. Please set environment variable CUDA_HOME.')\n    if not (_force_cuda_requested() or torch.cuda.is_available()):\n        raise NotImplementedError('No CUDA runtime is found. Please set FORCE_CUDA=1 or test it by running torch.cuda.is_available().')\n\n    # glob() already returns absolute paths here because extensions_dir is absolute.\n    sources = main_file + source_cpu + source_cuda\n    define_macros = [(\"WITH_CUDA\", None)]\n    extra_compile_args = {\n        \"cxx\": [],\n        \"nvcc\": [\n            \"-DCUDA_HAS_FP16=1\",\n            \"-D__CUDA_NO_HALF_OPERATORS__\",\n            \"-D__CUDA_NO_HALF_CONVERSIONS__\",\n            \"-D__CUDA_NO_HALF2_OPERATORS__\",\n        ],\n    }\n    include_dirs = [extensions_dir]\n    return [\n        CUDAExtension(\n            \"MultiScaleDeformableAttention\",\n            sources,\n            include_dirs=include_dirs,\n            define_macros=define_macros,\n            extra_compile_args=extra_compile_args,\n        )\n    ]\n\n\nsetup(\n    name=\"MultiScaleDeformableAttention\",\n    version=\"1.0\",\n    author=\"Weijie Su\",\n    url=\"https://github.com/fundamentalvision/Deformable-DETR\",\n    description=\"PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention\",\n    packages=find_packages(exclude=(\"configs\", \"tests\",)),\n    ext_modules=get_extensions(),\n    cmdclass={\"build_ext\": torch.utils.cpp_extension.BuildExtension},\n)\n"
  },
  {
    "path": "llava/model/openseed/body/encoder/ops/src/cpu/ms_deform_attn_cpu.cpp",
    "content": "/*!\n**************************************************************************************************\n* Deformable DETR\n* Copyright (c) 2020 SenseTime. All Rights Reserved.\n* Licensed under the Apache License, Version 2.0 [see LICENSE for details]\n**************************************************************************************************\n* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0\n**************************************************************************************************\n*/\n\n/*!\n* Copyright (c) Facebook, Inc. and its affiliates.\n* Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR\n*/\n\n// CPU entry points for multi-scale deformable attention. There is no CPU kernel: both functions\n// only raise, so callers need CUDA tensors (or the pure-PyTorch reference on the Python side).\n// No CUDA headers are included here, so this file also compiles in a CPU-only build.\n\n#include <vector>\n\n#include <ATen/ATen.h>\n\n#include \"cpu/ms_deform_attn_cpu.h\"\n\n\nat::Tensor\nms_deform_attn_cpu_forward(\n    const at::Tensor &value,\n    const at::Tensor &spatial_shapes,\n    const at::Tensor &level_start_index,\n    const at::Tensor &sampling_loc,\n    const at::Tensor &attn_weight,\n    const int im2col_step)\n{\n    TORCH_CHECK(false, \"Not implement on cpu\");\n}\n\nstd::vector<at::Tensor>\nms_deform_attn_cpu_backward(\n    const at::Tensor &value,\n    const at::Tensor &spatial_shapes,\n    const at::Tensor &level_start_index,\n    const at::Tensor &sampling_loc,\n    const at::Tensor &attn_weight,\n    const at::Tensor &grad_output,\n    const int im2col_step)\n{\n    TORCH_CHECK(false, \"Not implement on cpu\");\n}\n"
  },
  {
    "path": "llava/model/openseed/body/encoder/ops/src/cpu/ms_deform_attn_cpu.h",
    "content": "/*!\n**************************************************************************************************\n* Deformable DETR\n* Copyright (c) 2020 SenseTime. All Rights Reserved.\n* Licensed under the Apache License, Version 2.0 [see LICENSE for details]\n**************************************************************************************************\n* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0\n**************************************************************************************************\n*/\n\n/*!\n* Copyright (c) Facebook, Inc. and its affiliates.\n* Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR\n*/\n\n#pragma once\n\n#include <vector>\n\n#include <torch/extension.h>\n\n// CPU counterparts of the CUDA entry points declared in cuda/ms_deform_attn_cuda.h (same signatures).\n\nat::Tensor\nms_deform_attn_cpu_forward(\n    const at::Tensor &value,\n    const at::Tensor &spatial_shapes,\n    const at::Tensor &level_start_index,\n    const at::Tensor &sampling_loc,\n    const at::Tensor &attn_weight,\n    const int im2col_step);\n\nstd::vector<at::Tensor>\nms_deform_attn_cpu_backward(\n    const at::Tensor &value,\n    const at::Tensor &spatial_shapes,\n    const at::Tensor &level_start_index,\n    const at::Tensor &sampling_loc,\n    const at::Tensor &attn_weight,\n    const at::Tensor &grad_output,\n    const int im2col_step);\n"
  },
  {
    "path": "llava/model/openseed/body/encoder/ops/src/cuda/ms_deform_attn_cuda.cu",
    "content": "/*!\n**************************************************************************************************\n* Deformable DETR\n* Copyright (c) 2020 SenseTime. All Rights Reserved.\n* Licensed under the Apache License, Version 2.0 [see LICENSE for details]\n**************************************************************************************************\n* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0\n**************************************************************************************************\n*/\n\n/*!\n* Copyright (c) Facebook, Inc. and its affiliates.\n* Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR\n*/\n\n#include <algorithm>\n#include <vector>\n#include \"cuda/ms_deform_im2col_cuda.cuh\"\n\n#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <cuda.h>\n#include <cuda_runtime.h>\n\n\n// Forward pass of multi-scale deformable attention on CUDA.\n//\n// Layouts, as implied by the size() calls and per-sample strides below:\n//   value             (batch, spatial_size, num_heads, channels)\n//   spatial_shapes    (num_levels, 2), int64\n//   level_start_index int64\n//   sampling_loc      (batch, num_query, num_heads, num_levels, num_point, 2)\n//   attn_weight       (batch, num_query, num_heads, num_levels, num_point)\n// Returns a (batch, num_query, num_heads * channels) tensor. value, sampling_loc and attn_weight\n// must share one floating dtype (float32 or float64).\n//\n// The batch is processed in chunks of min(batch, im2col_step) samples, one kernel launch per\n// chunk, so that chunk size must divide batch.\nat::Tensor ms_deform_attn_cuda_forward(\n    const at::Tensor &value,\n    const at::Tensor &spatial_shapes,\n    const at::Tensor &level_start_index,\n    const at::Tensor &sampling_loc,\n    const at::Tensor &attn_weight,\n    const int im2col_step)\n{\n    TORCH_CHECK(value.is_contiguous(), \"value tensor has to be contiguous\");\n    TORCH_CHECK(spatial_shapes.is_contiguous(), \"spatial_shapes tensor has to be contiguous\");\n    TORCH_CHECK(level_start_index.is_contiguous(), \"level_start_index tensor has to be contiguous\");\n    TORCH_CHECK(sampling_loc.is_contiguous(), \"sampling_loc tensor has to be contiguous\");\n    TORCH_CHECK(attn_weight.is_contiguous(), \"attn_weight tensor has to be contiguous\");\n\n    TORCH_CHECK(value.is_cuda(), \"value must be a CUDA tensor\");\n    TORCH_CHECK(spatial_shapes.is_cuda(), \"spatial_shapes must be a CUDA tensor\");\n    TORCH_CHECK(level_start_index.is_cuda(), \"level_start_index must be a CUDA tensor\");\n    TORCH_CHECK(sampling_loc.is_cuda(), \"sampling_loc must be a CUDA tensor\");\n    TORCH_CHECK(attn_weight.is_cuda(), \"attn_weight must be a CUDA tensor\");\n    TORCH_CHECK(im2col_step > 0, \"im2col_step must be positive, got \", im2col_step);\n\n    const int batch = value.size(0);\n    const int spatial_size = value.size(1);\n    const int num_heads = value.size(2);\n    const int channels = value.size(3);\n\n    const int num_levels = spatial_shapes.size(0);\n\n    const int num_query = sampling_loc.size(1);\n    const int num_point = sampling_loc.size(4);\n\n    // Nothing to sample for an empty batch; returning here also avoids a modulo by zero below.\n    if (batch == 0)\n    {\n        return at::zeros({batch, num_query, num_heads * channels}, value.options());\n    }\n\n    const int im2col_step_ = std::min(batch, im2col_step);\n\n    TORCH_CHECK(batch % im2col_step_ == 0,\n                \"im2col_step (\", im2col_step_, \") must divide batch (\", batch, \")\");\n\n    auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());\n\n    // output_n[n] is the slice of output written by the n-th chunk of im2col_step_ samples.\n    const int batch_n = im2col_step_;\n    auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});\n    // Host-side pointer offsets are computed in 64 bits so they cannot overflow int.\n    const int64_t per_value_size = static_cast<int64_t>(spatial_size) * num_heads * channels;\n    const int64_t per_sample_loc_size = static_cast<int64_t>(num_query) * num_heads * num_levels * num_point * 2;\n    const int64_t per_attn_weight_size = static_cast<int64_t>(num_query) * num_heads * num_levels * num_point;\n    for (int n = 0; n < batch/im2col_step_; ++n)\n    {\n        auto columns = output_n.select(0, n);\n        AT_DISPATCH_FLOATING_TYPES(value.scalar_type(), \"ms_deform_attn_forward_cuda\", ([&] {\n            ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),\n                value.data_ptr<scalar_t>() + n * im2col_step_ * per_value_size,\n                spatial_shapes.data_ptr<int64_t>(),\n                level_start_index.data_ptr<int64_t>(),\n                sampling_loc.data_ptr<scalar_t>() + n * im2col_step_ * per_sample_loc_size,\n                attn_weight.data_ptr<scalar_t>() + n * im2col_step_ * per_attn_weight_size,\n                batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,\n                columns.data_ptr<scalar_t>());\n        }));\n    }\n\n    output = output.view({batch, num_query, num_heads*channels});\n\n    return output;\n}\n\n\n// Backward pass of multi-scale deformable attention on CUDA.\n//\n// Takes the forward inputs plus grad_output (batch * num_query * num_heads * channels elements,\n// i.e. the size of the forward result) and returns {grad_value, grad_sampling_loc,\n// grad_attn_weight}, each shaped like the matching input. The gradient buffers start at zero\n// because the col2im kernels accumulate into them (grad_value via atomicAdd).\nstd::vector<at::Tensor> ms_deform_attn_cuda_backward(\n    const at::Tensor &value,\n    const at::Tensor &spatial_shapes,\n    const at::Tensor &level_start_index,\n    const at::Tensor &sampling_loc,\n    const at::Tensor &attn_weight,\n    const at::Tensor &grad_output,\n    const int im2col_step)\n{\n    TORCH_CHECK(value.is_contiguous(), \"value tensor has to be contiguous\");\n    TORCH_CHECK(spatial_shapes.is_contiguous(), \"spatial_shapes tensor has to be contiguous\");\n    TORCH_CHECK(level_start_index.is_contiguous(), \"level_start_index tensor has to be contiguous\");\n    TORCH_CHECK(sampling_loc.is_contiguous(), \"sampling_loc tensor has to be contiguous\");\n    TORCH_CHECK(attn_weight.is_contiguous(), \"attn_weight tensor has to be contiguous\");\n    TORCH_CHECK(grad_output.is_contiguous(), \"grad_output tensor has to be contiguous\");\n\n    TORCH_CHECK(value.is_cuda(), \"value must be a CUDA tensor\");\n    TORCH_CHECK(spatial_shapes.is_cuda(), \"spatial_shapes must be a CUDA tensor\");\n    TORCH_CHECK(level_start_index.is_cuda(), \"level_start_index must be a CUDA tensor\");\n    TORCH_CHECK(sampling_loc.is_cuda(), \"sampling_loc must be a CUDA tensor\");\n    TORCH_CHECK(attn_weight.is_cuda(), \"attn_weight must be a CUDA tensor\");\n    TORCH_CHECK(grad_output.is_cuda(), \"grad_output must be a CUDA tensor\");\n    TORCH_CHECK(im2col_step > 0, \"im2col_step must be positive, got \", im2col_step);\n\n    const int batch = value.size(0);\n    const int spatial_size = value.size(1);\n    const int num_heads = value.size(2);\n    const int channels = value.size(3);\n\n    const int num_levels = spatial_shapes.size(0);\n\n    const int num_query = sampling_loc.size(1);\n    const int num_point = sampling_loc.size(4);\n\n    auto grad_value = at::zeros_like(value);\n    auto grad_sampling_loc = at::zeros_like(sampling_loc);\n    auto grad_attn_weight = at::zeros_like(attn_weight);\n\n    // Empty batch: the (empty) zero gradients are already final; this also avoids a modulo by zero.\n    if (batch == 0)\n    {\n        return {grad_value, grad_sampling_loc, grad_attn_weight};\n    }\n\n    const int im2col_step_ = std::min(batch, im2col_step);\n\n    TORCH_CHECK(batch % im2col_step_ == 0,\n                \"im2col_step (\", im2col_step_, \") must divide batch (\", batch, \")\");\n\n    const int batch_n = im2col_step_;\n    // Host-side pointer offsets are computed in 64 bits so they cannot overflow int.\n    const int64_t per_value_size = static_cast<int64_t>(spatial_size) * num_heads * channels;\n    const int64_t per_sample_loc_size = static_cast<int64_t>(num_query) * num_heads * num_levels * num_point * 2;\n    const int64_t per_attn_weight_size = static_cast<int64_t>(num_query) * num_heads * num_levels * num_point;\n    auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});\n\n    for (int n = 0; n < batch/im2col_step_; ++n)\n    {\n        auto grad_output_g = grad_output_n.select(0, n);\n        AT_DISPATCH_FLOATING_TYPES(value.scalar_type(), \"ms_deform_attn_backward_cuda\", ([&] {\n            ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),\n                                    grad_output_g.data_ptr<scalar_t>(),\n                                    value.data_ptr<scalar_t>() + n * im2col_step_ * per_value_size,\n                                    spatial_shapes.data_ptr<int64_t>(),\n                                    level_start_index.data_ptr<int64_t>(),\n                                    sampling_loc.data_ptr<scalar_t>() + n * im2col_step_ * per_sample_loc_size,\n                                    attn_weight.data_ptr<scalar_t>() + n * im2col_step_ * per_attn_weight_size,\n                                    batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,\n                                    grad_value.data_ptr<scalar_t>() + n * im2col_step_ * per_value_size,\n                                    grad_sampling_loc.data_ptr<scalar_t>() + n * im2col_step_ * per_sample_loc_size,\n                                    grad_attn_weight.data_ptr<scalar_t>() + n * im2col_step_ * per_attn_weight_size);\n        }));\n    }\n\n    return {\n        grad_value, grad_sampling_loc, grad_attn_weight\n    };\n}\n"
  },
  {
    "path": "llava/model/openseed/body/encoder/ops/src/cuda/ms_deform_attn_cuda.h",
    "content": "/*!\n**************************************************************************************************\n* Deformable DETR\n* Copyright (c) 2020 SenseTime. All Rights Reserved.\n* Licensed under the Apache License, Version 2.0 [see LICENSE for details]\n**************************************************************************************************\n* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0\n**************************************************************************************************\n*/\n\n/*!\n* Copyright (c) Facebook, Inc. and its affiliates.\n* Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR\n*/\n\n#pragma once\n\n#include <vector>\n\n#include <torch/extension.h>\n\n// CUDA entry points for multi-scale deformable attention; implemented in ms_deform_attn_cuda.cu.\n\nat::Tensor ms_deform_attn_cuda_forward(\n    const at::Tensor &value,\n    const at::Tensor &spatial_shapes,\n    const at::Tensor &level_start_index,\n    const at::Tensor &sampling_loc,\n    const at::Tensor &attn_weight,\n    const int im2col_step);\n\nstd::vector<at::Tensor> ms_deform_attn_cuda_backward(\n    const at::Tensor &value,\n    const at::Tensor &spatial_shapes,\n    const at::Tensor &level_start_index,\n    const at::Tensor &sampling_loc,\n    const at::Tensor &attn_weight,\n    const at::Tensor &grad_output,\n    const int im2col_step);\n"
  },
  {
    "path": "llava/model/openseed/body/encoder/ops/src/cuda/ms_deform_im2col_cuda.cuh",
    "content": "/*!\n**************************************************************************\n* Deformable DETR\n* Copyright (c) 2020 SenseTime. All Rights Reserved.\n* Licensed under the Apache License, Version 2.0 [see LICENSE for details]\n**************************************************************************\n* Modified from DCN (https://github.com/msracver/Deformable-ConvNets)\n* Copyright (c) 2018 Microsoft\n**************************************************************************\n*/\n\n/*!\n* Copyright (c) Facebook, Inc. and its affiliates.\n* Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR\n*/\n\n#include <cstdio>\n#include <algorithm>\n#include <cstring>\n\n#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext.h>\n\n#include <THC/THCAtomics.cuh>\n\n#define CUDA_KERNEL_LOOP(i, n)                          \\\n  for (int i = blockIdx.x * blockDim.x + threadIdx.x;   \\\n      i < (n);                                          \\\n      i += blockDim.x * gridDim.x)\n\nconst int CUDA_NUM_THREADS = 1024;\ninline int GET_BLOCKS(const int N, const int num_threads)\n{\n  return (N + num_threads - 1) / num_threads;\n}\n\n\ntemplate <typename scalar_t>\n__device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data, \n                                                   const int &height, const int &width, const int &nheads, const int &channels,\n                                                   const scalar_t &h, const scalar_t &w, const int &m, const int &c)\n{\n  const int h_low = floor(h);\n  const int w_low = floor(w);\n  const int h_high = h_low + 1;\n  const int w_high = w_low + 1;\n\n  const scalar_t lh = h - h_low;\n  const scalar_t lw = w - w_low;\n  const scalar_t hh = 1 - lh, hw = 1 - lw;\n\n  const int w_stride = nheads * channels;\n  const int h_stride = width * w_stride;\n  const int h_low_ptr_offset = h_low * h_stride;\n  const int h_high_ptr_offset = h_low_ptr_offset + h_stride;\n  const int 
w_low_ptr_offset = w_low * w_stride;\n  const int w_high_ptr_offset = w_low_ptr_offset + w_stride;\n  const int base_ptr = m * channels + c;\n\n  scalar_t v1 = 0;\n  if (h_low >= 0 && w_low >= 0)\n  {\n    const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;\n    v1 = bottom_data[ptr1];\n  }\n  scalar_t v2 = 0;\n  if (h_low >= 0 && w_high <= width - 1)\n  {\n    const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;\n    v2 = bottom_data[ptr2];\n  }\n  scalar_t v3 = 0;\n  if (h_high <= height - 1 && w_low >= 0)\n  {\n    const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;\n    v3 = bottom_data[ptr3];\n  }\n  scalar_t v4 = 0;\n  if (h_high <= height - 1 && w_high <= width - 1)\n  {\n    const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;\n    v4 = bottom_data[ptr4];\n  }\n\n  const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;\n\n  const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);\n  return val;\n}\n\n\ntemplate <typename scalar_t>\n__device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data, \n                                                   const int &height, const int &width, const int &nheads, const int &channels,\n                                                   const scalar_t &h, const scalar_t &w, const int &m, const int &c,\n                                                   const scalar_t &top_grad,\n                                                   const scalar_t &attn_weight,\n                                                   scalar_t* &grad_value, \n                                                   scalar_t* grad_sampling_loc,\n                                                   scalar_t* grad_attn_weight)\n{\n  const int h_low = floor(h);\n  const int w_low = floor(w);\n  const int h_high = h_low + 1;\n  const int w_high = w_low + 1;\n\n  const scalar_t lh = h - h_low;\n  const scalar_t lw = w - w_low;\n  const scalar_t hh = 1 - lh, hw = 1 - lw;\n\n  
const int w_stride = nheads * channels;\n  const int h_stride = width * w_stride;\n  const int h_low_ptr_offset = h_low * h_stride;\n  const int h_high_ptr_offset = h_low_ptr_offset + h_stride;\n  const int w_low_ptr_offset = w_low * w_stride;\n  const int w_high_ptr_offset = w_low_ptr_offset + w_stride;\n  const int base_ptr = m * channels + c;\n\n  const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;\n  const scalar_t top_grad_value = top_grad * attn_weight;\n  scalar_t grad_h_weight = 0, grad_w_weight = 0;\n\n  scalar_t v1 = 0;\n  if (h_low >= 0 && w_low >= 0)\n  {\n    const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;\n    v1 = bottom_data[ptr1];\n    grad_h_weight -= hw * v1;\n    grad_w_weight -= hh * v1;\n    atomicAdd(grad_value+ptr1, w1*top_grad_value);\n  }\n  scalar_t v2 = 0;\n  if (h_low >= 0 && w_high <= width - 1)\n  {\n    const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;\n    v2 = bottom_data[ptr2];\n    grad_h_weight -= lw * v2;\n    grad_w_weight += hh * v2;\n    atomicAdd(grad_value+ptr2, w2*top_grad_value);\n  }\n  scalar_t v3 = 0;\n  if (h_high <= height - 1 && w_low >= 0)\n  {\n    const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;\n    v3 = bottom_data[ptr3];\n    grad_h_weight += hw * v3;\n    grad_w_weight -= lh * v3;\n    atomicAdd(grad_value+ptr3, w3*top_grad_value); \n  }\n  scalar_t v4 = 0;\n  if (h_high <= height - 1 && w_high <= width - 1)\n  {\n    const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;\n    v4 = bottom_data[ptr4];\n    grad_h_weight += lw * v4;\n    grad_w_weight += lh * v4;\n    atomicAdd(grad_value+ptr4, w4*top_grad_value);\n  }\n\n  const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);\n  *grad_attn_weight = top_grad * val;\n  *grad_sampling_loc = width * grad_w_weight * top_grad_value;\n  *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value;\n}\n\n\ntemplate <typename scalar_t>\n__device__ void 
ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data, \n                                                   const int &height, const int &width, const int &nheads, const int &channels,\n                                                   const scalar_t &h, const scalar_t &w, const int &m, const int &c,\n                                                   const scalar_t &top_grad,\n                                                   const scalar_t &attn_weight,\n                                                   scalar_t* &grad_value, \n                                                   scalar_t* grad_sampling_loc,\n                                                   scalar_t* grad_attn_weight)\n{\n  const int h_low = floor(h);\n  const int w_low = floor(w);\n  const int h_high = h_low + 1;\n  const int w_high = w_low + 1;\n\n  const scalar_t lh = h - h_low;\n  const scalar_t lw = w - w_low;\n  const scalar_t hh = 1 - lh, hw = 1 - lw;\n\n  const int w_stride = nheads * channels;\n  const int h_stride = width * w_stride;\n  const int h_low_ptr_offset = h_low * h_stride;\n  const int h_high_ptr_offset = h_low_ptr_offset + h_stride;\n  const int w_low_ptr_offset = w_low * w_stride;\n  const int w_high_ptr_offset = w_low_ptr_offset + w_stride;\n  const int base_ptr = m * channels + c;\n\n  const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;\n  const scalar_t top_grad_value = top_grad * attn_weight;\n  scalar_t grad_h_weight = 0, grad_w_weight = 0;\n\n  scalar_t v1 = 0;\n  if (h_low >= 0 && w_low >= 0)\n  {\n    const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;\n    v1 = bottom_data[ptr1];\n    grad_h_weight -= hw * v1;\n    grad_w_weight -= hh * v1;\n    atomicAdd(grad_value+ptr1, w1*top_grad_value);\n  }\n  scalar_t v2 = 0;\n  if (h_low >= 0 && w_high <= width - 1)\n  {\n    const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;\n    v2 = bottom_data[ptr2];\n    grad_h_weight -= lw * v2;\n    grad_w_weight += hh * v2;\n 
   atomicAdd(grad_value+ptr2, w2*top_grad_value);\n  }\n  scalar_t v3 = 0;\n  if (h_high <= height - 1 && w_low >= 0)\n  {\n    const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;\n    v3 = bottom_data[ptr3];\n    grad_h_weight += hw * v3;\n    grad_w_weight -= lh * v3;\n    atomicAdd(grad_value+ptr3, w3*top_grad_value); \n  }\n  scalar_t v4 = 0;\n  if (h_high <= height - 1 && w_high <= width - 1)\n  {\n    const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;\n    v4 = bottom_data[ptr4];\n    grad_h_weight += lw * v4;\n    grad_w_weight += lh * v4;\n    atomicAdd(grad_value+ptr4, w4*top_grad_value);\n  }\n\n  const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);\n  atomicAdd(grad_attn_weight, top_grad * val); \n  atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value);\n  atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value);\n}\n\n\ntemplate <typename scalar_t>\n__global__ void ms_deformable_im2col_gpu_kernel(const int n,\n                                                const scalar_t *data_value, \n                                                const int64_t *data_spatial_shapes,\n                                                const int64_t *data_level_start_index, \n                                                const scalar_t *data_sampling_loc,\n                                                const scalar_t *data_attn_weight,\n                                                const int batch_size, \n                                                const int spatial_size, \n                                                const int num_heads,\n                                                const int channels, \n                                                const int num_levels,\n                                                const int num_query,\n                                                const int num_point,\n                                                scalar_t *data_col)\n{\n  
CUDA_KERNEL_LOOP(index, n)\n  {\n    int _temp = index;\n    const int c_col = _temp % channels;\n    _temp /= channels;\n    const int sampling_index = _temp; \n    const int m_col = _temp % num_heads;\n    _temp /= num_heads;\n    const int q_col = _temp % num_query;\n    _temp /= num_query;\n    const int b_col = _temp;\n\n    scalar_t *data_col_ptr = data_col + index;\n    int data_weight_ptr = sampling_index * num_levels * num_point;\n    int data_loc_w_ptr = data_weight_ptr << 1;\n    const int qid_stride = num_heads * channels;\n    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;\n    scalar_t col = 0;\n    \n    for (int l_col=0; l_col < num_levels; ++l_col)\n    {\n      const int level_start_id = data_level_start_index[l_col];\n      const int spatial_h_ptr = l_col << 1;\n      const int spatial_h = data_spatial_shapes[spatial_h_ptr];\n      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];\n      const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride);\n      for (int p_col=0; p_col < num_point; ++p_col)\n      {\n        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];\n        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];\n        const scalar_t weight = data_attn_weight[data_weight_ptr];\n\n        const scalar_t h_im = loc_h * spatial_h - 0.5;\n        const scalar_t w_im = loc_w * spatial_w - 0.5;\n\n        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)\n        {\n          col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight;\n        }\n\n        data_weight_ptr += 1;\n        data_loc_w_ptr += 2;\n      }\n    }\n    *data_col_ptr = col;\n  }\n}\n\ntemplate <typename scalar_t, unsigned int blockSize>\n__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n,\n                                              
  const scalar_t *grad_col,\n                                                const scalar_t *data_value,\n                                                const int64_t *data_spatial_shapes,\n                                                const int64_t *data_level_start_index, \n                                                const scalar_t *data_sampling_loc,\n                                                const scalar_t *data_attn_weight,\n                                                const int batch_size, \n                                                const int spatial_size, \n                                                const int num_heads,\n                                                const int channels, \n                                                const int num_levels,\n                                                const int num_query,\n                                                const int num_point,\n                                                scalar_t *grad_value,\n                                                scalar_t *grad_sampling_loc,\n                                                scalar_t *grad_attn_weight)\n{\n  CUDA_KERNEL_LOOP(index, n)\n  {\n    __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];\n    __shared__ scalar_t cache_grad_attn_weight[blockSize];\n    unsigned int tid = threadIdx.x;\n    int _temp = index;\n    const int c_col = _temp % channels;\n    _temp /= channels;\n    const int sampling_index = _temp; \n    const int m_col = _temp % num_heads;\n    _temp /= num_heads;\n    const int q_col = _temp % num_query;\n    _temp /= num_query;\n    const int b_col = _temp;\n\n    const scalar_t top_grad = grad_col[index];\n\n    int data_weight_ptr = sampling_index * num_levels * num_point;\n    int data_loc_w_ptr = data_weight_ptr << 1;\n    const int grad_sampling_ptr = data_weight_ptr;\n    grad_sampling_loc += grad_sampling_ptr << 1;\n    grad_attn_weight += grad_sampling_ptr;\n    const int 
grad_weight_stride = 1;\n    const int grad_loc_stride = 2;\n    const int qid_stride = num_heads * channels;\n    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;\n\n    for (int l_col=0; l_col < num_levels; ++l_col)\n    {\n      const int level_start_id = data_level_start_index[l_col];\n      const int spatial_h_ptr = l_col << 1;\n      const int spatial_h = data_spatial_shapes[spatial_h_ptr];\n      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];\n      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;\n      const scalar_t *data_value_ptr = data_value + value_ptr_offset;\n      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;\n\n      for (int p_col=0; p_col < num_point; ++p_col)\n      {\n        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];\n        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];\n        const scalar_t weight = data_attn_weight[data_weight_ptr];\n\n        const scalar_t h_im = loc_h * spatial_h - 0.5;\n        const scalar_t w_im = loc_w * spatial_w - 0.5;\n        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;\n        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;\n        *(cache_grad_attn_weight+threadIdx.x)=0;\n        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)\n        {\n          ms_deform_attn_col2im_bilinear(\n            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,\n            top_grad, weight, grad_value_ptr, \n            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);\n        }\n        \n        __syncthreads();\n        if (tid == 0)\n        {\n          scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];\n          int sid=2;\n          for (unsigned int tid = 1; tid < blockSize; ++tid)\n          {\n            _grad_w += 
cache_grad_sampling_loc[sid];\n            _grad_h += cache_grad_sampling_loc[sid + 1];\n            _grad_a += cache_grad_attn_weight[tid];\n            sid += 2;\n          }\n          \n          \n          *grad_sampling_loc = _grad_w;\n          *(grad_sampling_loc + 1) = _grad_h;\n          *grad_attn_weight = _grad_a;\n        }\n        __syncthreads();\n\n        data_weight_ptr += 1;\n        data_loc_w_ptr += 2;\n        grad_attn_weight += grad_weight_stride;\n        grad_sampling_loc += grad_loc_stride;\n      }\n    }\n  }\n}\n\n\ntemplate <typename scalar_t, unsigned int blockSize>\n__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n,\n                                                const scalar_t *grad_col,\n                                                const scalar_t *data_value,\n                                                const int64_t *data_spatial_shapes,\n                                                const int64_t *data_level_start_index, \n                                                const scalar_t *data_sampling_loc,\n                                                const scalar_t *data_attn_weight,\n                                                const int batch_size, \n                                                const int spatial_size, \n                                                const int num_heads,\n                                                const int channels, \n                                                const int num_levels,\n                                                const int num_query,\n                                                const int num_point,\n                                                scalar_t *grad_value,\n                                                scalar_t *grad_sampling_loc,\n                                                scalar_t *grad_attn_weight)\n{\n  CUDA_KERNEL_LOOP(index, n)\n  {\n    __shared__ scalar_t 
cache_grad_sampling_loc[blockSize * 2];\n    __shared__ scalar_t cache_grad_attn_weight[blockSize];\n    unsigned int tid = threadIdx.x;\n    int _temp = index;\n    const int c_col = _temp % channels;\n    _temp /= channels;\n    const int sampling_index = _temp; \n    const int m_col = _temp % num_heads;\n    _temp /= num_heads;\n    const int q_col = _temp % num_query;\n    _temp /= num_query;\n    const int b_col = _temp;\n\n    const scalar_t top_grad = grad_col[index];\n\n    int data_weight_ptr = sampling_index * num_levels * num_point;\n    int data_loc_w_ptr = data_weight_ptr << 1;\n    const int grad_sampling_ptr = data_weight_ptr;\n    grad_sampling_loc += grad_sampling_ptr << 1;\n    grad_attn_weight += grad_sampling_ptr;\n    const int grad_weight_stride = 1;\n    const int grad_loc_stride = 2;\n    const int qid_stride = num_heads * channels;\n    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;\n\n    for (int l_col=0; l_col < num_levels; ++l_col)\n    {\n      const int level_start_id = data_level_start_index[l_col];\n      const int spatial_h_ptr = l_col << 1;\n      const int spatial_h = data_spatial_shapes[spatial_h_ptr];\n      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];\n      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;\n      const scalar_t *data_value_ptr = data_value + value_ptr_offset;\n      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;\n\n      for (int p_col=0; p_col < num_point; ++p_col)\n      {\n        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];\n        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];\n        const scalar_t weight = data_attn_weight[data_weight_ptr];\n\n        const scalar_t h_im = loc_h * spatial_h - 0.5;\n        const scalar_t w_im = loc_w * spatial_w - 0.5;\n        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;\n        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;\n      
  *(cache_grad_attn_weight+threadIdx.x)=0;\n        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)\n        {\n          ms_deform_attn_col2im_bilinear(\n            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,\n            top_grad, weight, grad_value_ptr, \n            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);\n        }\n        \n        __syncthreads();\n\n        for (unsigned int s=blockSize/2; s>0; s>>=1)\n        {\n          if (tid < s) {\n            const unsigned int xid1 = tid << 1;\n            const unsigned int xid2 = (tid + s) << 1;\n            cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];\n            cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];\n            cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];\n          }\n          __syncthreads();\n        }\n\n        if (tid == 0)\n        { \n          *grad_sampling_loc = cache_grad_sampling_loc[0];\n          *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];\n          *grad_attn_weight = cache_grad_attn_weight[0];\n        }\n        __syncthreads();\n\n        data_weight_ptr += 1;\n        data_loc_w_ptr += 2;\n        grad_attn_weight += grad_weight_stride;\n        grad_sampling_loc += grad_loc_stride;\n      }\n    }\n  }\n}\n\n\ntemplate <typename scalar_t>\n__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n,\n                                                const scalar_t *grad_col,\n                                                const scalar_t *data_value,\n                                                const int64_t *data_spatial_shapes,\n                                                const int64_t *data_level_start_index, \n                                                const scalar_t *data_sampling_loc,\n                                                const scalar_t *data_attn_weight,\n      
                                          const int batch_size, \n                                                const int spatial_size, \n                                                const int num_heads,\n                                                const int channels, \n                                                const int num_levels,\n                                                const int num_query,\n                                                const int num_point,\n                                                scalar_t *grad_value,\n                                                scalar_t *grad_sampling_loc,\n                                                scalar_t *grad_attn_weight)\n{\n  CUDA_KERNEL_LOOP(index, n)\n  {\n    extern __shared__ int _s[];\n    scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;\n    scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;\n    unsigned int tid = threadIdx.x;\n    int _temp = index;\n    const int c_col = _temp % channels;\n    _temp /= channels;\n    const int sampling_index = _temp; \n    const int m_col = _temp % num_heads;\n    _temp /= num_heads;\n    const int q_col = _temp % num_query;\n    _temp /= num_query;\n    const int b_col = _temp;\n\n    const scalar_t top_grad = grad_col[index];\n\n    int data_weight_ptr = sampling_index * num_levels * num_point;\n    int data_loc_w_ptr = data_weight_ptr << 1;\n    const int grad_sampling_ptr = data_weight_ptr;\n    grad_sampling_loc += grad_sampling_ptr << 1;\n    grad_attn_weight += grad_sampling_ptr;\n    const int grad_weight_stride = 1;\n    const int grad_loc_stride = 2;\n    const int qid_stride = num_heads * channels;\n    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;\n\n    for (int l_col=0; l_col < num_levels; ++l_col)\n    {\n      const int level_start_id = data_level_start_index[l_col];\n      const int spatial_h_ptr = l_col << 1;\n      const int spatial_h = 
data_spatial_shapes[spatial_h_ptr];\n      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];\n      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;\n      const scalar_t *data_value_ptr = data_value + value_ptr_offset;\n      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;\n\n      for (int p_col=0; p_col < num_point; ++p_col)\n      {\n        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];\n        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];\n        const scalar_t weight = data_attn_weight[data_weight_ptr];\n\n        const scalar_t h_im = loc_h * spatial_h - 0.5;\n        const scalar_t w_im = loc_w * spatial_w - 0.5;\n        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;\n        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;\n        *(cache_grad_attn_weight+threadIdx.x)=0;\n        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)\n        {\n          ms_deform_attn_col2im_bilinear(\n            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,\n            top_grad, weight, grad_value_ptr, \n            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);\n        }\n        \n        __syncthreads();\n        if (tid == 0)\n        {\n          scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];\n          int sid=2;\n          for (unsigned int tid = 1; tid < blockDim.x; ++tid)\n          {\n            _grad_w += cache_grad_sampling_loc[sid];\n            _grad_h += cache_grad_sampling_loc[sid + 1];\n            _grad_a += cache_grad_attn_weight[tid];\n            sid += 2;\n          }\n          \n          \n          *grad_sampling_loc = _grad_w;\n          *(grad_sampling_loc + 1) = _grad_h;\n          *grad_attn_weight = _grad_a;\n        }\n        __syncthreads();\n\n        data_weight_ptr += 1;\n    
    data_loc_w_ptr += 2;\n        grad_attn_weight += grad_weight_stride;\n        grad_sampling_loc += grad_loc_stride;\n      }\n    }\n  }\n}\n\ntemplate <typename scalar_t>\n__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n,\n                                                const scalar_t *grad_col,\n                                                const scalar_t *data_value,\n                                                const int64_t *data_spatial_shapes,\n                                                const int64_t *data_level_start_index, \n                                                const scalar_t *data_sampling_loc,\n                                                const scalar_t *data_attn_weight,\n                                                const int batch_size, \n                                                const int spatial_size, \n                                                const int num_heads,\n                                                const int channels, \n                                                const int num_levels,\n                                                const int num_query,\n                                                const int num_point,\n                                                scalar_t *grad_value,\n                                                scalar_t *grad_sampling_loc,\n                                                scalar_t *grad_attn_weight)\n{\n  CUDA_KERNEL_LOOP(index, n)\n  {\n    extern __shared__ int _s[];\n    scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;\n    scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;\n    unsigned int tid = threadIdx.x;\n    int _temp = index;\n    const int c_col = _temp % channels;\n    _temp /= channels;\n    const int sampling_index = _temp; \n    const int m_col = _temp % num_heads;\n    _temp /= num_heads;\n    const int q_col = _temp % num_query;\n    _temp /= num_query;\n    const int b_col 
= _temp;\n\n    const scalar_t top_grad = grad_col[index];\n\n    int data_weight_ptr = sampling_index * num_levels * num_point;\n    int data_loc_w_ptr = data_weight_ptr << 1;\n    const int grad_sampling_ptr = data_weight_ptr;\n    grad_sampling_loc += grad_sampling_ptr << 1;\n    grad_attn_weight += grad_sampling_ptr;\n    const int grad_weight_stride = 1;\n    const int grad_loc_stride = 2;\n    const int qid_stride = num_heads * channels;\n    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;\n\n    for (int l_col=0; l_col < num_levels; ++l_col)\n    {\n      const int level_start_id = data_level_start_index[l_col];\n      const int spatial_h_ptr = l_col << 1;\n      const int spatial_h = data_spatial_shapes[spatial_h_ptr];\n      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];\n      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;\n      const scalar_t *data_value_ptr = data_value + value_ptr_offset;\n      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;\n\n      for (int p_col=0; p_col < num_point; ++p_col)\n      {\n        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];\n        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];\n        const scalar_t weight = data_attn_weight[data_weight_ptr];\n\n        const scalar_t h_im = loc_h * spatial_h - 0.5;\n        const scalar_t w_im = loc_w * spatial_w - 0.5;\n        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;\n        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;\n        *(cache_grad_attn_weight+threadIdx.x)=0;\n        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)\n        {\n          ms_deform_attn_col2im_bilinear(\n            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,\n            top_grad, weight, grad_value_ptr, \n            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);\n      
  }\n        \n        __syncthreads();\n\n        for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)\n        {\n          if (tid < s) {\n            const unsigned int xid1 = tid << 1;\n            const unsigned int xid2 = (tid + s) << 1;\n            cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];\n            cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];\n            cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];\n            if (tid + (s << 1) < spre)\n            {\n              cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];\n              cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];\n              cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];\n            } \n          }\n          __syncthreads();\n        }\n\n        if (tid == 0)\n        {\n          *grad_sampling_loc = cache_grad_sampling_loc[0];\n          *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];\n          *grad_attn_weight = cache_grad_attn_weight[0];\n        }\n        __syncthreads();\n\n        data_weight_ptr += 1;\n        data_loc_w_ptr += 2;\n        grad_attn_weight += grad_weight_stride;\n        grad_sampling_loc += grad_loc_stride;\n      }\n    }\n  }\n}\n\ntemplate <typename scalar_t>\n__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n,\n                                                const scalar_t *grad_col,\n                                                const scalar_t *data_value,\n                                                const int64_t *data_spatial_shapes,\n                                                const int64_t *data_level_start_index, \n                                                const scalar_t *data_sampling_loc,\n                                                const scalar_t *data_attn_weight,\n                                            
    const int batch_size, \n                                                const int spatial_size, \n                                                const int num_heads,\n                                                const int channels, \n                                                const int num_levels,\n                                                const int num_query,\n                                                const int num_point,\n                                                scalar_t *grad_value,\n                                                scalar_t *grad_sampling_loc,\n                                                scalar_t *grad_attn_weight)\n{\n  CUDA_KERNEL_LOOP(index, n)\n  {\n    extern __shared__ int _s[];\n    scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;\n    scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;\n    unsigned int tid = threadIdx.x;\n    int _temp = index;\n    const int c_col = _temp % channels;\n    _temp /= channels;\n    const int sampling_index = _temp; \n    const int m_col = _temp % num_heads;\n    _temp /= num_heads;\n    const int q_col = _temp % num_query;\n    _temp /= num_query;\n    const int b_col = _temp;\n\n    const scalar_t top_grad = grad_col[index];\n\n    int data_weight_ptr = sampling_index * num_levels * num_point;\n    int data_loc_w_ptr = data_weight_ptr << 1;\n    const int grad_sampling_ptr = data_weight_ptr;\n    grad_sampling_loc += grad_sampling_ptr << 1;\n    grad_attn_weight += grad_sampling_ptr;\n    const int grad_weight_stride = 1;\n    const int grad_loc_stride = 2;\n    const int qid_stride = num_heads * channels;\n    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;\n\n    for (int l_col=0; l_col < num_levels; ++l_col)\n    {\n      const int level_start_id = data_level_start_index[l_col];\n      const int spatial_h_ptr = l_col << 1;\n      const int spatial_h = data_spatial_shapes[spatial_h_ptr];\n      const int spatial_w = 
data_spatial_shapes[spatial_h_ptr + 1];\n      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;\n      const scalar_t *data_value_ptr = data_value + value_ptr_offset;\n      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;\n\n      for (int p_col=0; p_col < num_point; ++p_col)\n      {\n        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];\n        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];\n        const scalar_t weight = data_attn_weight[data_weight_ptr];\n\n        const scalar_t h_im = loc_h * spatial_h - 0.5;\n        const scalar_t w_im = loc_w * spatial_w - 0.5;\n        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;\n        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;\n        *(cache_grad_attn_weight+threadIdx.x)=0;\n        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)\n        {\n          ms_deform_attn_col2im_bilinear(\n            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,\n            top_grad, weight, grad_value_ptr, \n            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);\n        }\n        \n        __syncthreads();\n\n        for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)\n        {\n          if (tid < s) {\n            const unsigned int xid1 = tid << 1;\n            const unsigned int xid2 = (tid + s) << 1;\n            cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];\n            cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];\n            cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];\n            if (tid + (s << 1) < spre)\n            {\n              cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];\n              cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];\n              cache_grad_sampling_loc[xid1 + 1] += 
cache_grad_sampling_loc[xid2 + 1 + (s << 1)];\n            }\n          }\n          __syncthreads();\n        }\n\n        if (tid == 0)\n        {\n          atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]);\n          atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]);\n          atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]);\n        }\n        __syncthreads();\n\n        data_weight_ptr += 1;\n        data_loc_w_ptr += 2;\n        grad_attn_weight += grad_weight_stride;\n        grad_sampling_loc += grad_loc_stride;\n      }\n    }\n  }\n}\n\n\ntemplate <typename scalar_t>\n__global__ void ms_deformable_col2im_gpu_kernel_gm(const int n,\n                                                const scalar_t *grad_col,\n                                                const scalar_t *data_value,\n                                                const int64_t *data_spatial_shapes,\n                                                const int64_t *data_level_start_index, \n                                                const scalar_t *data_sampling_loc,\n                                                const scalar_t *data_attn_weight,\n                                                const int batch_size, \n                                                const int spatial_size, \n                                                const int num_heads,\n                                                const int channels, \n                                                const int num_levels,\n                                                const int num_query,\n                                                const int num_point,\n                                                scalar_t *grad_value,\n                                                scalar_t *grad_sampling_loc,\n                                                scalar_t *grad_attn_weight)\n{\n  CUDA_KERNEL_LOOP(index, n)\n  {\n    int _temp = index;\n    const int c_col = _temp % channels;\n   
 _temp /= channels;\n    const int sampling_index = _temp; \n    const int m_col = _temp % num_heads;\n    _temp /= num_heads;\n    const int q_col = _temp % num_query;\n    _temp /= num_query;\n    const int b_col = _temp;\n\n    const scalar_t top_grad = grad_col[index];\n\n    int data_weight_ptr = sampling_index * num_levels * num_point;\n    int data_loc_w_ptr = data_weight_ptr << 1;\n    const int grad_sampling_ptr = data_weight_ptr;\n    grad_sampling_loc += grad_sampling_ptr << 1;\n    grad_attn_weight += grad_sampling_ptr;\n    const int grad_weight_stride = 1;\n    const int grad_loc_stride = 2;\n    const int qid_stride = num_heads * channels;\n    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;\n\n    for (int l_col=0; l_col < num_levels; ++l_col)\n    {\n      const int level_start_id = data_level_start_index[l_col];\n      const int spatial_h_ptr = l_col << 1;\n      const int spatial_h = data_spatial_shapes[spatial_h_ptr];\n      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];\n      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;\n      const scalar_t *data_value_ptr = data_value + value_ptr_offset;\n      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;\n\n      for (int p_col=0; p_col < num_point; ++p_col)\n      {\n        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];\n        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];\n        const scalar_t weight = data_attn_weight[data_weight_ptr];\n\n        const scalar_t h_im = loc_h * spatial_h - 0.5;\n        const scalar_t w_im = loc_w * spatial_w - 0.5;\n        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)\n        {\n          ms_deform_attn_col2im_bilinear_gm(\n            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,\n            top_grad, weight, grad_value_ptr, \n            grad_sampling_loc, grad_attn_weight);\n       
 }\n        data_weight_ptr += 1;\n        data_loc_w_ptr += 2;\n        grad_attn_weight += grad_weight_stride;\n        grad_sampling_loc += grad_loc_stride;\n      }\n    }\n  }\n}\n\n\ntemplate <typename scalar_t>\nvoid ms_deformable_im2col_cuda(cudaStream_t stream,\n                              const scalar_t* data_value,\n                              const int64_t* data_spatial_shapes, \n                              const int64_t* data_level_start_index, \n                              const scalar_t* data_sampling_loc,\n                              const scalar_t* data_attn_weight,\n                              const int batch_size,\n                              const int spatial_size, \n                              const int num_heads, \n                              const int channels, \n                              const int num_levels, \n                              const int num_query,\n                              const int num_point,\n                              scalar_t* data_col)\n{\n  const int num_kernels = batch_size * num_query * num_heads * channels;\n  const int num_actual_kernels = batch_size * num_query * num_heads * channels;\n  const int num_threads = CUDA_NUM_THREADS;\n  ms_deformable_im2col_gpu_kernel<scalar_t>\n      <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,\n          0, stream>>>(\n      num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight, \n      batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col);\n  \n  cudaError_t err = cudaGetLastError();\n  if (err != cudaSuccess)\n  {\n    printf(\"error in ms_deformable_im2col_cuda: %s\\n\", cudaGetErrorString(err));\n  }\n\n}\n\ntemplate <typename scalar_t>\nvoid ms_deformable_col2im_cuda(cudaStream_t stream,\n                              const scalar_t* grad_col,\n                              const scalar_t* data_value,\n                              const int64_t * 
data_spatial_shapes,\n                              const int64_t * data_level_start_index,\n                              const scalar_t * data_sampling_loc,\n                              const scalar_t * data_attn_weight,\n                              const int batch_size, \n                              const int spatial_size, \n                              const int num_heads,\n                              const int channels, \n                              const int num_levels,\n                              const int num_query,\n                              const int num_point, \n                              scalar_t* grad_value,\n                              scalar_t* grad_sampling_loc,\n                              scalar_t* grad_attn_weight)\n{\n  const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels;\n  const int num_kernels = batch_size * num_query * num_heads * channels;\n  const int num_actual_kernels = batch_size * num_query * num_heads * channels;\n  if (channels > 1024)\n  {\n    if ((channels & 1023) == 0)\n    {\n      ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks<scalar_t>\n          <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,\n              num_threads*3*sizeof(scalar_t), stream>>>(\n                        num_kernels, \n                        grad_col,\n                        data_value,\n                        data_spatial_shapes,\n                        data_level_start_index, \n                        data_sampling_loc,\n                        data_attn_weight,\n                        batch_size, \n                        spatial_size, \n                        num_heads,\n                        channels, \n                        num_levels,\n                        num_query,\n                        num_point,\n                        grad_value,\n                        grad_sampling_loc,\n                        grad_attn_weight);\n    }\n    else\n    {\n      
ms_deformable_col2im_gpu_kernel_gm<scalar_t>\n        <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,\n            0, stream>>>(\n                      num_kernels, \n                      grad_col,\n                      data_value,\n                      data_spatial_shapes,\n                      data_level_start_index, \n                      data_sampling_loc,\n                      data_attn_weight,\n                      batch_size, \n                      spatial_size, \n                      num_heads,\n                      channels, \n                      num_levels,\n                      num_query,\n                      num_point,\n                      grad_value,\n                      grad_sampling_loc,\n                      grad_attn_weight);\n    }\n  }\n  else{\n    switch(channels)\n    {\n      case 1:\n        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 1>\n        <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,\n            0, stream>>>(\n                      num_kernels, \n                      grad_col,\n                      data_value,\n                      data_spatial_shapes,\n                      data_level_start_index, \n                      data_sampling_loc,\n                      data_attn_weight,\n                      batch_size, \n                      spatial_size, \n                      num_heads,\n                      channels, \n                      num_levels,\n                      num_query,\n                      num_point,\n                      grad_value,\n                      grad_sampling_loc,\n                      grad_attn_weight);\n        break;\n      case 2:\n        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 2>\n        <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,\n            0, stream>>>(\n                      num_kernels, \n                      grad_col,\n                      data_value,\n      
                data_spatial_shapes,\n                      data_level_start_index, \n                      data_sampling_loc,\n                      data_attn_weight,\n                      batch_size, \n                      spatial_size, \n                      num_heads,\n                      channels, \n                      num_levels,\n                      num_query,\n                      num_point,\n                      grad_value,\n                      grad_sampling_loc,\n                      grad_attn_weight);\n        break;\n      case 4:\n        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 4>\n        <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,\n            0, stream>>>(\n                      num_kernels, \n                      grad_col,\n                      data_value,\n                      data_spatial_shapes,\n                      data_level_start_index, \n                      data_sampling_loc,\n                      data_attn_weight,\n                      batch_size, \n                      spatial_size, \n                      num_heads,\n                      channels, \n                      num_levels,\n                      num_query,\n                      num_point,\n                      grad_value,\n                      grad_sampling_loc,\n                      grad_attn_weight);\n        break;\n      case 8:\n        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 8>\n        <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,\n            0, stream>>>(\n                      num_kernels, \n                      grad_col,\n                      data_value,\n                      data_spatial_shapes,\n                      data_level_start_index, \n                      data_sampling_loc,\n                      data_attn_weight,\n                      batch_size, \n                      spatial_size, \n                      num_heads,\n           
           channels, \n                      num_levels,\n                      num_query,\n                      num_point,\n                      grad_value,\n                      grad_sampling_loc,\n                      grad_attn_weight);\n        break;\n      case 16:\n        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 16>\n        <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,\n            0, stream>>>(\n                      num_kernels, \n                      grad_col,\n                      data_value,\n                      data_spatial_shapes,\n                      data_level_start_index, \n                      data_sampling_loc,\n                      data_attn_weight,\n                      batch_size, \n                      spatial_size, \n                      num_heads,\n                      channels, \n                      num_levels,\n                      num_query,\n                      num_point,\n                      grad_value,\n                      grad_sampling_loc,\n                      grad_attn_weight);\n        break;\n      case 32:\n        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 32>\n        <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,\n            0, stream>>>(\n                      num_kernels, \n                      grad_col,\n                      data_value,\n                      data_spatial_shapes,\n                      data_level_start_index, \n                      data_sampling_loc,\n                      data_attn_weight,\n                      batch_size, \n                      spatial_size, \n                      num_heads,\n                      channels, \n                      num_levels,\n                      num_query,\n                      num_point,\n                      grad_value,\n                      grad_sampling_loc,\n                      grad_attn_weight);\n        break;\n      case 64:\n       
 ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 64>\n        <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,\n            0, stream>>>(\n                      num_kernels, \n                      grad_col,\n                      data_value,\n                      data_spatial_shapes,\n                      data_level_start_index, \n                      data_sampling_loc,\n                      data_attn_weight,\n                      batch_size, \n                      spatial_size, \n                      num_heads,\n                      channels, \n                      num_levels,\n                      num_query,\n                      num_point,\n                      grad_value,\n                      grad_sampling_loc,\n                      grad_attn_weight);\n        break;\n      case 128:\n        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 128>\n        <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,\n            0, stream>>>(\n                      num_kernels, \n                      grad_col,\n                      data_value,\n                      data_spatial_shapes,\n                      data_level_start_index, \n                      data_sampling_loc,\n                      data_attn_weight,\n                      batch_size, \n                      spatial_size, \n                      num_heads,\n                      channels, \n                      num_levels,\n                      num_query,\n                      num_point,\n                      grad_value,\n                      grad_sampling_loc,\n                      grad_attn_weight);\n        break;\n      case 256:\n        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 256>\n        <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,\n            0, stream>>>(\n                      num_kernels, \n                      grad_col,\n                      
data_value,\n                      data_spatial_shapes,\n                      data_level_start_index, \n                      data_sampling_loc,\n                      data_attn_weight,\n                      batch_size, \n                      spatial_size, \n                      num_heads,\n                      channels, \n                      num_levels,\n                      num_query,\n                      num_point,\n                      grad_value,\n                      grad_sampling_loc,\n                      grad_attn_weight);\n        break;\n      case 512:\n        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 512>\n        <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,\n            0, stream>>>(\n                      num_kernels, \n                      grad_col,\n                      data_value,\n                      data_spatial_shapes,\n                      data_level_start_index, \n                      data_sampling_loc,\n                      data_attn_weight,\n                      batch_size, \n                      spatial_size, \n                      num_heads,\n                      channels, \n                      num_levels,\n                      num_query,\n                      num_point,\n                      grad_value,\n                      grad_sampling_loc,\n                      grad_attn_weight);\n        break;\n      case 1024:\n        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 1024>\n        <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,\n            0, stream>>>(\n                      num_kernels, \n                      grad_col,\n                      data_value,\n                      data_spatial_shapes,\n                      data_level_start_index, \n                      data_sampling_loc,\n                      data_attn_weight,\n                      batch_size, \n                      spatial_size, \n                
      num_heads,\n                      channels, \n                      num_levels,\n                      num_query,\n                      num_point,\n                      grad_value,\n                      grad_sampling_loc,\n                      grad_attn_weight);\n        break;\n      default:\n        if (channels < 64)\n        {\n          ms_deformable_col2im_gpu_kernel_shm_reduce_v1<scalar_t>\n          <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,\n              num_threads*3*sizeof(scalar_t), stream>>>(\n                        num_kernels, \n                        grad_col,\n                        data_value,\n                        data_spatial_shapes,\n                        data_level_start_index, \n                        data_sampling_loc,\n                        data_attn_weight,\n                        batch_size, \n                        spatial_size, \n                        num_heads,\n                        channels, \n                        num_levels,\n                        num_query,\n                        num_point,\n                        grad_value,\n                        grad_sampling_loc,\n                        grad_attn_weight);\n        }\n        else\n        {\n          ms_deformable_col2im_gpu_kernel_shm_reduce_v2<scalar_t>\n          <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,\n              num_threads*3*sizeof(scalar_t), stream>>>(\n                        num_kernels, \n                        grad_col,\n                        data_value,\n                        data_spatial_shapes,\n                        data_level_start_index, \n                        data_sampling_loc,\n                        data_attn_weight,\n                        batch_size, \n                        spatial_size, \n                        num_heads,\n                        channels, \n                        num_levels,\n                        num_query,\n                        
num_point,\n                        grad_value,\n                        grad_sampling_loc,\n                        grad_attn_weight);\n        }\n    }\n  }\n  cudaError_t err = cudaGetLastError();\n  if (err != cudaSuccess)\n  {\n    printf(\"error in ms_deformable_col2im_cuda: %s\\n\", cudaGetErrorString(err));\n  }\n\n}"
  },
  {
    "path": "llava/model/openseed/body/encoder/ops/src/ms_deform_attn.h",
    "content": "/*!\n**************************************************************************************************\n* Deformable DETR\n* Copyright (c) 2020 SenseTime. All Rights Reserved.\n* Licensed under the Apache License, Version 2.0 [see LICENSE for details]\n**************************************************************************************************\n* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0\n**************************************************************************************************\n*/\n\n/*!\n* Copyright (c) Facebook, Inc. and its affiliates.\n* Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR\n*/\n\n#pragma once\n\n#include \"cpu/ms_deform_attn_cpu.h\"\n\n#ifdef WITH_CUDA\n#include \"cuda/ms_deform_attn_cuda.h\"\n#endif\n\n\nat::Tensor\nms_deform_attn_forward(\n    const at::Tensor &value, \n    const at::Tensor &spatial_shapes,\n    const at::Tensor &level_start_index,\n    const at::Tensor &sampling_loc,\n    const at::Tensor &attn_weight,\n    const int im2col_step)\n{\n    if (value.type().is_cuda())\n    {\n#ifdef WITH_CUDA\n        return ms_deform_attn_cuda_forward(\n            value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);\n#else\n        AT_ERROR(\"Not compiled with GPU support\");\n#endif\n    }\n    AT_ERROR(\"Not implemented on the CPU\");\n}\n\nstd::vector<at::Tensor>\nms_deform_attn_backward(\n    const at::Tensor &value, \n    const at::Tensor &spatial_shapes,\n    const at::Tensor &level_start_index,\n    const at::Tensor &sampling_loc,\n    const at::Tensor &attn_weight,\n    const at::Tensor &grad_output,\n    const int im2col_step)\n{\n    if (value.type().is_cuda())\n    {\n#ifdef WITH_CUDA\n        return ms_deform_attn_cuda_backward(\n            value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);\n#else\n        AT_ERROR(\"Not compiled with GPU 
support\");\n#endif\n    }\n    AT_ERROR(\"Not implemented on the CPU\");\n}\n\n"
  },
  {
    "path": "llava/model/openseed/body/encoder/ops/src/vision.cpp",
    "content": "/*!\n**************************************************************************************************\n* Deformable DETR\n* Copyright (c) 2020 SenseTime. All Rights Reserved.\n* Licensed under the Apache License, Version 2.0 [see LICENSE for details]\n**************************************************************************************************\n* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0\n**************************************************************************************************\n*/\n\n/*!\n* Copyright (c) Facebook, Inc. and its affiliates.\n* Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR\n*/\n\n#include \"ms_deform_attn.h\"\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"ms_deform_attn_forward\", &ms_deform_attn_forward, \"ms_deform_attn_forward\");\n  m.def(\"ms_deform_attn_backward\", &ms_deform_attn_backward, \"ms_deform_attn_backward\");\n}\n"
  },
  {
    "path": "llava/model/openseed/body/encoder/ops/test.py",
    "content": "# ------------------------------------------------------------------------------------------------\n# Deformable DETR\n# Copyright (c) 2020 SenseTime. All Rights Reserved.\n# Licensed under the Apache License, Version 2.0 [see LICENSE for details]\n# ------------------------------------------------------------------------------------------------\n# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0\n# ------------------------------------------------------------------------------------------------\n\n# Copyright (c) Facebook, Inc. and its affiliates.\n# Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR\n\nfrom __future__ import absolute_import\nfrom __future__ import print_function\nfrom __future__ import division\n\nimport time\nimport torch\nimport torch.nn as nn\nfrom torch.autograd import gradcheck\n\nfrom functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch\n\n\nN, M, D = 1, 2, 2\nLq, L, P = 2, 2, 2\nshapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()\nlevel_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1]))\nS = sum([(H*W).item() for H, W in shapes])\n\n\ntorch.manual_seed(3)\n\n\n@torch.no_grad()\ndef check_forward_equal_with_pytorch_double():\n    value = torch.rand(N, S, M, D).cuda() * 0.01\n    sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()\n    attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5\n    attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)\n    im2col_step = 2\n    output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu()\n    output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu()\n    fwdok = torch.allclose(output_cuda, 
output_pytorch)\n    max_abs_err = (output_cuda - output_pytorch).abs().max()\n    max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()\n\n    print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')\n\n\n@torch.no_grad()\ndef check_forward_equal_with_pytorch_float():\n    value = torch.rand(N, S, M, D).cuda() * 0.01\n    sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()\n    attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5\n    attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)\n    im2col_step = 2\n    output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu()\n    output_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step).detach().cpu()\n    fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3)\n    max_abs_err = (output_cuda - output_pytorch).abs().max()\n    max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()\n\n    print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')\n\n\ndef check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True):\n\n    value = torch.rand(N, S, M, channels).cuda() * 0.01\n    sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()\n    attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5\n    attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)\n    im2col_step = 2\n    func = MSDeformAttnFunction.apply\n\n    value.requires_grad = grad_value\n    sampling_locations.requires_grad = grad_sampling_loc\n    attention_weights.requires_grad = grad_attn_weight\n\n    gradok = gradcheck(func, (value.double(), shapes, level_start_index, sampling_locations.double(), 
attention_weights.double(), im2col_step))\n\n    print(f'* {gradok} check_gradient_numerical(D={channels})')\n\n\nif __name__ == '__main__':\n    check_forward_equal_with_pytorch_double()\n    check_forward_equal_with_pytorch_float()\n\n    for channels in [30, 32, 64, 71, 1025, 2048, 3096]:\n        check_gradient_numerical(channels, True, True, True)\n\n\n\n"
  },
  {
    "path": "llava/model/openseed/body/encoder/registry.py",
    "content": "_model_entrypoints = {}\n\ndef register_encoder(fn):\n    module_name_split = fn.__module__.split('.')\n    model_name = module_name_split[-1]\n    _model_entrypoints[model_name] = fn\n    return fn\n\ndef model_entrypoints(model_name):\n    return _model_entrypoints[model_name]\n\ndef is_model(model_name):\n    return model_name in _model_entrypoints\n"
  },
  {
    "path": "llava/model/openseed/body/encoder/transformer_encoder_fpn.py",
    "content": "# Copyright (c) Facebook, Inc. and its affiliates.\nimport logging\nimport numpy as np\nfrom typing import Callable, Dict, List, Optional, Tuple, Union\n\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom torch.nn.init import xavier_uniform_, constant_, uniform_, normal_\nfrom torch.cuda.amp import autocast\n\nimport fvcore.nn.weight_init as weight_init\nfrom detectron2.layers import Conv2d, DeformConv, ShapeSpec, get_norm\n\nfrom .registry import register_encoder\nfrom ..transformer_blocks import TransformerEncoder, TransformerEncoderLayer, _get_clones, _get_activation_fn\nfrom ...modules import PositionEmbeddingSine\nfrom ...utils import configurable\n\n\n# This is a modified FPN decoder.\nclass BasePixelDecoder(nn.Module):\n    def __init__(\n        self,\n        input_shape: Dict[str, ShapeSpec],\n        *,\n        conv_dim: int,\n        mask_dim: int,\n        mask_on: bool,\n        norm: Optional[Union[str, Callable]] = None,\n    ):\n        \"\"\"\n        NOTE: this interface is experimental.\n        Args:\n            input_shape: shapes (channels and stride) of the input features\n            conv_dims: number of output channels for the intermediate conv layers.\n            mask_dim: number of output channels for the final conv layer.\n            norm (str or callable): normalization for all conv layers\n        \"\"\"\n        super().__init__()\n\n        input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)\n        self.in_features = [k for k, v in input_shape]  # starting from \"res2\" to \"res5\"\n        feature_channels = [v.channels for k, v in input_shape]\n\n        lateral_convs = []\n        output_convs = []\n\n        use_bias = norm == \"\"\n        for idx, in_channels in enumerate(feature_channels):\n            if idx == len(self.in_features) - 1:\n                output_norm = get_norm(norm, conv_dim)\n                output_conv = Conv2d(\n                    
in_channels,\n                    conv_dim,\n                    kernel_size=3,\n                    stride=1,\n                    padding=1,\n                    bias=use_bias,\n                    norm=output_norm,\n                    activation=F.relu,\n                )\n                weight_init.c2_xavier_fill(output_conv)\n                self.add_module(\"layer_{}\".format(idx + 1), output_conv)\n\n                lateral_convs.append(None)\n                output_convs.append(output_conv)\n            else:\n                lateral_norm = get_norm(norm, conv_dim)\n                output_norm = get_norm(norm, conv_dim)\n\n                lateral_conv = Conv2d(\n                    in_channels, conv_dim, kernel_size=1, bias=use_bias, norm=lateral_norm\n                )\n                output_conv = Conv2d(\n                    conv_dim,\n                    conv_dim,\n                    kernel_size=3,\n                    stride=1,\n                    padding=1,\n                    bias=use_bias,\n                    norm=output_norm,\n                    activation=F.relu,\n                )\n                weight_init.c2_xavier_fill(lateral_conv)\n                weight_init.c2_xavier_fill(output_conv)\n                self.add_module(\"adapter_{}\".format(idx + 1), lateral_conv)\n                self.add_module(\"layer_{}\".format(idx + 1), output_conv)\n\n                lateral_convs.append(lateral_conv)\n                output_convs.append(output_conv)\n        # Place convs into top-down order (from low to high resolution)\n        # to make the top-down computation in forward clearer.\n        self.lateral_convs = lateral_convs[::-1]\n        self.output_convs = output_convs[::-1]\n\n        self.mask_on = mask_on\n        if self.mask_on:\n            self.mask_dim = mask_dim\n            self.mask_features = Conv2d(\n                conv_dim,\n                mask_dim,\n                kernel_size=3,\n                stride=1,\n            
    padding=1,\n            )\n            weight_init.c2_xavier_fill(self.mask_features)\n\n        self.maskformer_num_feature_levels = 3  # always use 3 scales\n\n    @classmethod\n    def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):\n        enc_cfg = cfg['MODEL']['ENCODER']\n        ret = {}\n        ret[\"input_shape\"] = {\n            k: v for k, v in input_shape.items() if k in enc_cfg['IN_FEATURES']\n        }\n        ret[\"conv_dim\"] = enc_cfg['CONVS_DIM']\n        ret[\"mask_dim\"] = enc_cfg['MASK_DIM']\n        ret[\"norm\"] = enc_cfg['NORM']\n        return ret\n\n    def forward_features(self, features):\n        multi_scale_features = []\n        num_cur_levels = 0\n        # Reverse feature maps into top-down order (from low to high resolution)\n        for idx, f in enumerate(self.in_features[::-1]):\n            x = features[f]\n            lateral_conv = self.lateral_convs[idx]\n            output_conv = self.output_convs[idx]\n            if lateral_conv is None:\n                y = output_conv(x)\n            else:\n                cur_fpn = lateral_conv(x)\n                # Following FPN implementation, we use nearest upsampling here\n                y = cur_fpn + F.interpolate(y, size=cur_fpn.shape[-2:], mode=\"nearest\")\n                y = output_conv(y)\n            if num_cur_levels < self.maskformer_num_feature_levels:\n                multi_scale_features.append(y)\n                num_cur_levels += 1\n        \n        mask_features = self.mask_features(y) if self.mask_on else None\n        return mask_features, None, multi_scale_features\n\n    def forward(self, features, targets=None):\n        logger = logging.getLogger(__name__)\n        logger.warning(\"Calling forward() may cause unpredicted behavior of PixelDecoder module.\")\n        return self.forward_features(features)\n\n\nclass TransformerEncoderOnly(nn.Module):\n    def __init__(\n        self,\n        d_model=512,\n        nhead=8,\n        
num_encoder_layers=6,\n        dim_feedforward=2048,\n        dropout=0.1,\n        activation=\"relu\",\n        normalize_before=False,\n    ):\n        super().__init__()\n\n        encoder_layer = TransformerEncoderLayer(\n            d_model, nhead, dim_feedforward, dropout, activation, normalize_before\n        )\n        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None\n        self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)\n\n        self._reset_parameters()\n\n        self.d_model = d_model\n        self.nhead = nhead\n\n    def _reset_parameters(self):\n        for p in self.parameters():\n            if p.dim() > 1:\n                nn.init.xavier_uniform_(p)\n\n    def forward(self, src, mask, pos_embed):\n        # flatten NxCxHxW to HWxNxC\n        bs, c, h, w = src.shape\n        src = src.flatten(2).permute(2, 0, 1)\n        pos_embed = pos_embed.flatten(2).permute(2, 0, 1)\n        if mask is not None:\n            mask = mask.flatten(1)\n\n        memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)\n        return memory.permute(1, 2, 0).view(bs, c, h, w)\n\n\n# This is a modified FPN decoder with extra Transformer encoder that processes the lowest-resolution feature map.\nclass TransformerEncoderPixelDecoder(BasePixelDecoder):\n    @configurable\n    def __init__(\n        self,\n        input_shape: Dict[str, ShapeSpec],\n        *,\n        transformer_dropout: float,\n        transformer_nheads: int,\n        transformer_dim_feedforward: int,\n        transformer_enc_layers: int,\n        transformer_pre_norm: bool,\n        conv_dim: int,\n        mask_dim: int,\n        mask_on: int,\n        norm: Optional[Union[str, Callable]] = None,\n    ):\n        \"\"\"\n        NOTE: this interface is experimental.\n        Args:\n            input_shape: shapes (channels and stride) of the input features\n            transformer_dropout: dropout probability in transformer\n       
     transformer_nheads: number of heads in transformer\n            transformer_dim_feedforward: dimension of feedforward network\n            transformer_enc_layers: number of transformer encoder layers\n            transformer_pre_norm: whether to use pre-layernorm or not\n            conv_dims: number of output channels for the intermediate conv layers.\n            mask_dim: number of output channels for the final conv layer.\n            norm (str or callable): normalization for all conv layers\n        \"\"\"\n        super().__init__(input_shape, conv_dim=conv_dim, mask_dim=mask_dim, norm=norm, mask_on=mask_on)\n\n        input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)\n        self.in_features = [k for k, v in input_shape]  # starting from \"res2\" to \"res5\"\n        feature_strides = [v.stride for k, v in input_shape]\n        feature_channels = [v.channels for k, v in input_shape]\n\n        in_channels = feature_channels[len(self.in_features) - 1]\n        self.input_proj = Conv2d(in_channels, conv_dim, kernel_size=1)\n        weight_init.c2_xavier_fill(self.input_proj)\n        self.transformer = TransformerEncoderOnly(\n            d_model=conv_dim,\n            dropout=transformer_dropout,\n            nhead=transformer_nheads,\n            dim_feedforward=transformer_dim_feedforward,\n            num_encoder_layers=transformer_enc_layers,\n            normalize_before=transformer_pre_norm,\n        )\n        N_steps = conv_dim // 2\n        self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)\n\n        # update layer\n        use_bias = norm == \"\"\n        output_norm = get_norm(norm, conv_dim)\n        output_conv = Conv2d(\n            conv_dim,\n            conv_dim,\n            kernel_size=3,\n            stride=1,\n            padding=1,\n            bias=use_bias,\n            norm=output_norm,\n            activation=F.relu,\n        )\n        weight_init.c2_xavier_fill(output_conv)\n        
delattr(self, \"layer_{}\".format(len(self.in_features)))\n        self.add_module(\"layer_{}\".format(len(self.in_features)), output_conv)\n        self.output_convs[0] = output_conv\n\n    @classmethod\n    def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):\n        enc_cfg = cfg['MODEL']['ENCODER']\n        dec_cfg = cfg['MODEL']['DECODER']\n\n        ret = super().from_config(cfg, input_shape)\n        ret[\"transformer_dropout\"] = dec_cfg['DROPOUT']\n        ret[\"transformer_nheads\"] = dec_cfg['NHEADS']\n        ret[\"transformer_dim_feedforward\"] = dec_cfg['DIM_FEEDFORWARD']\n        ret[\"transformer_enc_layers\"] = enc_cfg['TRANSFORMER_ENC_LAYERS']  # a separate config\n        ret[\"transformer_pre_norm\"] = dec_cfg['PRE_NORM']\n\n        ret['mask_on'] = cfg['MODEL']['DECODER']['MASK']\n        return ret\n\n    def forward_features(self, features):\n        multi_scale_features = []\n        num_cur_levels = 0\n        \n        # Reverse feature maps into top-down order (from low to high resolution)\n        for idx, f in enumerate(self.in_features[::-1]):\n            x = features[f]\n            lateral_conv = self.lateral_convs[idx]\n            output_conv = self.output_convs[idx]\n            if lateral_conv is None:\n                transformer = self.input_proj(x)\n                pos = self.pe_layer(x)\n                transformer = self.transformer(transformer, None, pos)\n                y = output_conv(transformer)\n                # save intermediate feature as input to Transformer decoder\n                transformer_encoder_features = transformer\n            else:\n                cur_fpn = lateral_conv(x)\n                # Following FPN implementation, we use nearest upsampling here\n                y = cur_fpn + F.interpolate(y, size=cur_fpn.shape[-2:], mode=\"nearest\")\n                y = output_conv(y)\n            if num_cur_levels < self.maskformer_num_feature_levels:\n                
multi_scale_features.append(y)\n                num_cur_levels += 1\n\n        mask_features = self.mask_features(y) if self.mask_on else None\n        return mask_features, transformer_encoder_features, multi_scale_features\n\n    def forward(self, features, targets=None):\n        logger = logging.getLogger(__name__)\n        logger.warning(\"Calling forward() may cause unpredicted behavior of PixelDecoder module.\")\n        return self.forward_features(features)\n\n\n\n@register_encoder\ndef get_transformer_encoder_fpn(cfg, input_shape):\n    \"\"\"\n    Build a TransformerEncoderPixelDecoder (an FPN pixel decoder with a Transformer encoder on the lowest-resolution feature map) from `cfg`.\n    \"\"\"\n    model = TransformerEncoderPixelDecoder(cfg, input_shape)    \n    forward_features = getattr(model, \"forward_features\", None)\n    if not callable(forward_features):\n        raise ValueError(\n            \"Only SEM_SEG_HEADS with forward_features method can be used as pixel decoder. \"\n            f\"Please implement forward_features for {name} to only return mask features.\"\n        )\n    return model"
  },
  {
    "path": "llava/model/openseed/body/openseed_head.py",
    "content": "# ------------------------------------------------------------------------\n# Copyright (c) 2022 IDEA. All Rights Reserved.\n# Licensed under the Apache License, Version 2.0 [see LICENSE for details]\n# ------------------------------------------------------------------------\n# Modified from Mask2Former https://github.com/facebookresearch/Mask2Former by Feng Li and Hao Zhang.\n# ------------------------------------------------------------------------------\nimport logging\nfrom typing import Callable, Dict, List, Optional, Tuple, Union\n\nfrom torch import nn\n\nfrom detectron2.layers import Conv2d, ShapeSpec, get_norm\nfrom detectron2.modeling import SEM_SEG_HEADS_REGISTRY\n\nfrom .registry import register_body\nfrom .encoder import build_encoder\nfrom .decoder import build_decoder\nfrom ..utils import configurable\n\n\nclass OpenSeeDHead(nn.Module):\n    @configurable\n    def __init__(\n        self,\n        input_shape: Dict[str, ShapeSpec],\n        *,\n        num_classes: int,\n        pixel_decoder: nn.Module,\n        loss_weight: float = 1.0,\n        ignore_value: int = -1,\n        transformer_predictor: nn.Module,\n    ):\n        \"\"\"\n        Args:\n            input_shape: shapes (channels and stride) of the input features\n            num_classes: number of classes to predict\n            pixel_decoder: the pixel decoder module\n            loss_weight: loss weight\n            ignore_value: category id to be ignored during training.\n            transformer_predictor: the transformer decoder that makes prediction\n            transformer_in_feature: input feature name to the transformer_predictor\n        \"\"\"\n        super().__init__()\n        input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)\n        self.in_features = [k for k, v in input_shape]\n        self.ignore_value = ignore_value\n        self.common_stride = 4\n        self.loss_weight = loss_weight\n\n        self.pixel_decoder = 
pixel_decoder\n        self.predictor = transformer_predictor\n\n        self.num_classes = num_classes\n\n    @classmethod\n    def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec], lang_encoder: nn.Module, extra: dict):\n        enc_cfg = cfg['MODEL']['ENCODER']\n        dec_cfg = cfg['MODEL']['DECODER']\n        transformer_predictor_in_channels = enc_cfg['CONVS_DIM']\n\n        return {\n            \"input_shape\": {\n                k: v for k, v in input_shape.items() if k in enc_cfg['IN_FEATURES']\n            },\n            \"ignore_value\": enc_cfg['IGNORE_VALUE'],\n            \"num_classes\": enc_cfg.get('NUM_CLASSES', None),\n            \"pixel_decoder\": build_encoder(cfg, input_shape),\n            \"loss_weight\": enc_cfg['LOSS_WEIGHT'],\n            \"transformer_predictor\": build_decoder(\n                cfg,\n                transformer_predictor_in_channels,\n                mask_classification=True,\n                extra=extra,\n            ),\n        }\n\n    def forward(self, features, mask=None,targets=None, target_queries=None, target_vlp=None, task='seg', extra={},default_text_embeddings=None):\n        mask_features, transformer_encoder_features, multi_scale_features = self.pixel_decoder.forward_features(features, mask)\n\n        predictions = self.predictor(multi_scale_features, mask_features, mask, targets=targets,\n                                         target_queries=target_queries, target_vlp=target_vlp, task=task, extra=extra,default_text_embeddings=default_text_embeddings)\n        return predictions\n\n\n@register_body\ndef get_maskdino_head(cfg, input_shape, lang_encoder, extra):\n    return OpenSeeDHead(cfg, input_shape, lang_encoder, extra)"
  },
  {
    "path": "llava/model/openseed/body/registry.py",
    "content": "_model_entrypoints = {}\n\n\ndef register_body(fn):\n    module_name_split = fn.__module__.split('.')\n    model_name = module_name_split[-1]\n    _model_entrypoints[model_name] = fn\n    return fn\n\ndef model_entrypoints(model_name):\n    return _model_entrypoints[model_name]\n\ndef is_model(model_name):\n    return model_name in _model_entrypoints"
  },
  {
    "path": "llava/model/openseed/body/transformer_blocks.py",
    "content": "# Copyright (c) Facebook, Inc. and its affiliates.\n# Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/transformer.py\n\"\"\"\nTransformer class.\n\nCopy-paste from torch.nn.Transformer with modifications:\n    * positional encodings are passed in MHattention\n    * extra LN at the end of encoder is removed\n    * decoder returns a stack of activations from all decoding layers\n\"\"\"\nimport copy\nfrom typing import List, Optional\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor, nn\n\n\nclass Transformer(nn.Module):\n    def __init__(\n        self,\n        d_model=512,\n        nhead=8,\n        num_encoder_layers=6,\n        num_decoder_layers=6,\n        dim_feedforward=2048,\n        dropout=0.1,\n        activation=\"relu\",\n        normalize_before=False,\n        return_intermediate_dec=False,\n    ):\n        super().__init__()\n\n        encoder_layer = TransformerEncoderLayer(\n            d_model, nhead, dim_feedforward, dropout, activation, normalize_before\n        )\n        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None\n        self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)\n\n        decoder_layer = TransformerDecoderLayer(\n            d_model, nhead, dim_feedforward, dropout, activation, normalize_before\n        )\n        decoder_norm = nn.LayerNorm(d_model)\n        self.decoder = TransformerDecoder(\n            decoder_layer,\n            num_decoder_layers,\n            decoder_norm,\n            return_intermediate=return_intermediate_dec,\n        )\n\n        self._reset_parameters()\n\n        self.d_model = d_model\n        self.nhead = nhead\n\n    def _reset_parameters(self):\n        for p in self.parameters():\n            if p.dim() > 1:\n                nn.init.xavier_uniform_(p)\n\n    def forward(self, src, mask, query_embed, pos_embed):\n        # flatten NxCxHxW to HWxNxC\n        bs, 
c, h, w = src.shape\n        src = src.flatten(2).permute(2, 0, 1)\n        pos_embed = pos_embed.flatten(2).permute(2, 0, 1)\n        query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)\n        if mask is not None:\n            mask = mask.flatten(1)\n\n        tgt = torch.zeros_like(query_embed)\n        memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)\n        hs = self.decoder(\n            tgt, memory, memory_key_padding_mask=mask, pos=pos_embed, query_pos=query_embed\n        )\n        return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w)\n\n\nclass TransformerEncoder(nn.Module):\n    def __init__(self, encoder_layer, num_layers, norm=None):\n        super().__init__()\n        self.layers = _get_clones(encoder_layer, num_layers)\n        self.num_layers = num_layers\n        self.norm = norm\n\n    def forward(\n        self,\n        src,\n        mask: Optional[Tensor] = None,\n        src_key_padding_mask: Optional[Tensor] = None,\n        pos: Optional[Tensor] = None,\n    ):\n        output = src\n\n        for layer in self.layers:\n            output = layer(\n                output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pos=pos\n            )\n\n        if self.norm is not None:\n            output = self.norm(output)\n\n        return output\n\n\nclass TransformerDecoder(nn.Module):\n    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):\n        super().__init__()\n        self.layers = _get_clones(decoder_layer, num_layers)\n        self.num_layers = num_layers\n        self.norm = norm\n        self.return_intermediate = return_intermediate\n\n    def forward(\n        self,\n        tgt,\n        memory,\n        tgt_mask: Optional[Tensor] = None,\n        memory_mask: Optional[Tensor] = None,\n        tgt_key_padding_mask: Optional[Tensor] = None,\n        memory_key_padding_mask: Optional[Tensor] = None,\n        pos: Optional[Tensor] = None,\n       
 query_pos: Optional[Tensor] = None,\n    ):\n        output = tgt\n\n        intermediate = []\n\n        for layer in self.layers:\n            output = layer(\n                output,\n                memory,\n                tgt_mask=tgt_mask,\n                memory_mask=memory_mask,\n                tgt_key_padding_mask=tgt_key_padding_mask,\n                memory_key_padding_mask=memory_key_padding_mask,\n                pos=pos,\n                query_pos=query_pos,\n            )\n            if self.return_intermediate:\n                intermediate.append(self.norm(output))\n\n        if self.norm is not None:\n            output = self.norm(output)\n            if self.return_intermediate:\n                intermediate.pop()\n                intermediate.append(output)\n\n        if self.return_intermediate:\n            return torch.stack(intermediate)\n\n        return output.unsqueeze(0)\n\n\nclass TransformerEncoderLayer(nn.Module):\n    def __init__(\n        self,\n        d_model,\n        nhead,\n        dim_feedforward=2048,\n        dropout=0.1,\n        activation=\"relu\",\n        normalize_before=False,\n    ):\n        super().__init__()\n        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)\n        # Implementation of Feedforward model\n        self.linear1 = nn.Linear(d_model, dim_feedforward)\n        self.dropout = nn.Dropout(dropout)\n        self.linear2 = nn.Linear(dim_feedforward, d_model)\n\n        self.norm1 = nn.LayerNorm(d_model)\n        self.norm2 = nn.LayerNorm(d_model)\n        self.dropout1 = nn.Dropout(dropout)\n        self.dropout2 = nn.Dropout(dropout)\n\n        self.activation = _get_activation_fn(activation)\n        self.normalize_before = normalize_before\n\n    def with_pos_embed(self, tensor, pos: Optional[Tensor]):\n        return tensor if pos is None else tensor + pos\n\n    def forward_post(\n        self,\n        src,\n        src_mask: Optional[Tensor] = None,\n        
src_key_padding_mask: Optional[Tensor] = None,\n        pos: Optional[Tensor] = None,\n    ):\n        q = k = self.with_pos_embed(src, pos)\n\n        src2 = self.self_attn(\n            q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask\n        )[0]\n        src = src + self.dropout1(src2)\n        src = self.norm1(src)\n        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))\n        src = src + self.dropout2(src2)\n        src = self.norm2(src)\n        return src\n\n    def forward_pre(\n        self,\n        src,\n        src_mask: Optional[Tensor] = None,\n        src_key_padding_mask: Optional[Tensor] = None,\n        pos: Optional[Tensor] = None,\n    ):\n        src2 = self.norm1(src)\n        q = k = self.with_pos_embed(src2, pos)\n        src2 = self.self_attn(\n            q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask\n        )[0]\n        src = src + self.dropout1(src2)\n        src2 = self.norm2(src)\n        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))\n        src = src + self.dropout2(src2)\n        return src\n\n    def forward(\n        self,\n        src,\n        src_mask: Optional[Tensor] = None,\n        src_key_padding_mask: Optional[Tensor] = None,\n        pos: Optional[Tensor] = None,\n    ):\n        if self.normalize_before:\n            return self.forward_pre(src, src_mask, src_key_padding_mask, pos)\n        return self.forward_post(src, src_mask, src_key_padding_mask, pos)\n\n\nclass TransformerDecoderLayer(nn.Module):\n    def __init__(\n        self,\n        d_model,\n        nhead,\n        dim_feedforward=2048,\n        dropout=0.1,\n        activation=\"relu\",\n        normalize_before=False,\n    ):\n        super().__init__()\n        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)\n        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)\n        # Implementation of 
Feedforward model\n        self.linear1 = nn.Linear(d_model, dim_feedforward)\n        self.dropout = nn.Dropout(dropout)\n        self.linear2 = nn.Linear(dim_feedforward, d_model)\n\n        self.norm1 = nn.LayerNorm(d_model)\n        self.norm2 = nn.LayerNorm(d_model)\n        self.norm3 = nn.LayerNorm(d_model)\n        self.dropout1 = nn.Dropout(dropout)\n        self.dropout2 = nn.Dropout(dropout)\n        self.dropout3 = nn.Dropout(dropout)\n\n        self.activation = _get_activation_fn(activation)\n        self.normalize_before = normalize_before\n\n    def with_pos_embed(self, tensor, pos: Optional[Tensor]):\n        return tensor if pos is None else tensor + pos\n\n    def forward_post(\n        self,\n        tgt,\n        memory,\n        tgt_mask: Optional[Tensor] = None,\n        memory_mask: Optional[Tensor] = None,\n        tgt_key_padding_mask: Optional[Tensor] = None,\n        memory_key_padding_mask: Optional[Tensor] = None,\n        pos: Optional[Tensor] = None,\n        query_pos: Optional[Tensor] = None,\n    ):\n        q = k = self.with_pos_embed(tgt, query_pos)\n        tgt2 = self.self_attn(\n            q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask\n        )[0]\n        tgt = tgt + self.dropout1(tgt2)\n        tgt = self.norm1(tgt)\n        tgt2 = self.multihead_attn(\n            query=self.with_pos_embed(tgt, query_pos),\n            key=self.with_pos_embed(memory, pos),\n            value=memory,\n            attn_mask=memory_mask,\n            key_padding_mask=memory_key_padding_mask,\n        )[0]\n        tgt = tgt + self.dropout2(tgt2)\n        tgt = self.norm2(tgt)\n        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))\n        tgt = tgt + self.dropout3(tgt2)\n        tgt = self.norm3(tgt)\n        return tgt\n\n    def forward_pre(\n        self,\n        tgt,\n        memory,\n        tgt_mask: Optional[Tensor] = None,\n        memory_mask: Optional[Tensor] = None,\n        
tgt_key_padding_mask: Optional[Tensor] = None,\n        memory_key_padding_mask: Optional[Tensor] = None,\n        pos: Optional[Tensor] = None,\n        query_pos: Optional[Tensor] = None,\n    ):\n        tgt2 = self.norm1(tgt)\n        q = k = self.with_pos_embed(tgt2, query_pos)\n        tgt2 = self.self_attn(\n            q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask\n        )[0]\n        tgt = tgt + self.dropout1(tgt2)\n        tgt2 = self.norm2(tgt)\n        tgt2 = self.multihead_attn(\n            query=self.with_pos_embed(tgt2, query_pos),\n            key=self.with_pos_embed(memory, pos),\n            value=memory,\n            attn_mask=memory_mask,\n            key_padding_mask=memory_key_padding_mask,\n        )[0]\n        tgt = tgt + self.dropout2(tgt2)\n        tgt2 = self.norm3(tgt)\n        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))\n        tgt = tgt + self.dropout3(tgt2)\n        return tgt\n\n    def forward(\n        self,\n        tgt,\n        memory,\n        tgt_mask: Optional[Tensor] = None,\n        memory_mask: Optional[Tensor] = None,\n        tgt_key_padding_mask: Optional[Tensor] = None,\n        memory_key_padding_mask: Optional[Tensor] = None,\n        pos: Optional[Tensor] = None,\n        query_pos: Optional[Tensor] = None,\n    ):\n        if self.normalize_before:\n            return self.forward_pre(\n                tgt,\n                memory,\n                tgt_mask,\n                memory_mask,\n                tgt_key_padding_mask,\n                memory_key_padding_mask,\n                pos,\n                query_pos,\n            )\n        return self.forward_post(\n            tgt,\n            memory,\n            tgt_mask,\n            memory_mask,\n            tgt_key_padding_mask,\n            memory_key_padding_mask,\n            pos,\n            query_pos,\n        )\n\n\ndef _get_clones(module, N):\n    return 
nn.ModuleList([copy.deepcopy(module) for i in range(N)])\n\n\ndef _get_activation_fn(activation):\n    \"\"\"Return an activation function given a string\"\"\"\n    if activation == \"relu\":\n        return F.relu\n    if activation == \"gelu\":\n        return F.gelu\n    if activation == \"glu\":\n        return F.glu\n    raise RuntimeError(f\"activation should be relu/gelu/glu, not {activation}.\")\n"
  },
  {
    "path": "llava/model/openseed/language/LangEncoder/__init__.py",
    "content": "from __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nfrom .build import build_lang_encoder\nfrom .build import build_tokenizer\n\nfrom .transformer import *"
  },
  {
    "path": "llava/model/openseed/language/LangEncoder/build.py",
    "content": "import os\n\nfrom transformers import CLIPTokenizer, CLIPTokenizerFast\nfrom transformers import AutoTokenizer\n\nfrom .registry import lang_encoders\nfrom .registry import is_lang_encoder\n\n\ndef build_lang_encoder(config_encoder, tokenizer, verbose, **kwargs):\n    model_name = config_encoder['NAME']\n\n    if not is_lang_encoder(model_name):\n        raise ValueError(f'Unknown model: {model_name}')\n\n    return lang_encoders(model_name)(config_encoder, tokenizer, verbose, **kwargs)\n\n\ndef build_tokenizer(config_encoder):\n    tokenizer = None\n    os.environ['TOKENIZERS_PARALLELISM'] = 'true'\n    if config_encoder['TOKENIZER'] == 'clip':\n        pretrained_tokenizer = config_encoder.get(\n            'PRETRAINED_TOKENIZER', 'openai/clip-vit-base-patch32'\n        )\n        tokenizer = CLIPTokenizer.from_pretrained(pretrained_tokenizer)\n        tokenizer.add_special_tokens({'cls_token': tokenizer.eos_token})\n    elif config_encoder['TOKENIZER'] == 'clip-fast':\n        pretrained_tokenizer = config_encoder.get(\n            'PRETRAINED_TOKENIZER', 'openai/clip-vit-base-patch32'\n        )\n        tokenizer = CLIPTokenizerFast.from_pretrained(pretrained_tokenizer, from_slow=True)\n    else:\n        tokenizer = AutoTokenizer.from_pretrained(config_encoder['TOKENIZER'])\n\n    return tokenizer\n"
  },
  {
    "path": "llava/model/openseed/language/LangEncoder/registry.py",
    "content": "_lang_encoders = {}\n\n\ndef register_lang_encoder(fn):\n    module_name_split = fn.__module__.split('.')\n    model_name = module_name_split[-1]\n\n    _lang_encoders[model_name] = fn\n\n    return fn\n\n\ndef lang_encoders(model_name):\n    return _lang_encoders[model_name]\n\n\ndef is_lang_encoder(model_name):\n    return model_name in _lang_encoders\n"
  },
  {
    "path": "llava/model/openseed/language/LangEncoder/transformer.py",
    "content": "from collections import OrderedDict\nfrom typing import Tuple, Union\nimport logging\nimport os\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\n\nfrom timm.models.layers import DropPath, trunc_normal_\n\nfrom .registry import register_lang_encoder\nfrom detectron2.utils.comm import is_main_process\nfrom utils.model import register_norm_module\n\nlogger = logging.getLogger(__name__)\n\n\n@register_norm_module\nclass LayerNorm(nn.Module):\n    def __init__(self, hidden_size, eps=1e-12):\n        \"\"\"Construct a layernorm module in the TF style (epsilon inside the square root).\n        \"\"\"\n        super(LayerNorm, self).__init__()\n        self.weight = nn.Parameter(torch.ones(hidden_size))\n        self.bias = nn.Parameter(torch.zeros(hidden_size))\n        self.variance_epsilon = eps\n\n    def forward(self, x):\n        pdtype = x.dtype\n        x = x.float()\n        u = x.mean(-1, keepdim=True)\n        s = (x - u).pow(2).mean(-1, keepdim=True)\n        x = (x - u) / torch.sqrt(s + self.variance_epsilon)\n        return self.weight * x.to(pdtype) + self.bias\n\n\nclass QuickGELU(nn.Module):\n    def forward(self, x: torch.Tensor):\n        return x * torch.sigmoid(1.702 * x)\n\n\nclass ResidualAttentionBlock(nn.Module):\n    def __init__(self,\n                 d_model: int,\n                 n_head: int,\n                 attn_mask: torch.Tensor = None,\n                 drop_path: float = 0.0):\n        super().__init__()\n\n        self.attn = nn.MultiheadAttention(d_model, n_head)\n        self.ln_1 = LayerNorm(d_model)\n        self.mlp = nn.Sequential(OrderedDict([\n            (\"c_fc\", nn.Linear(d_model, d_model * 4)),\n            (\"gelu\", QuickGELU()),\n            (\"c_proj\", nn.Linear(d_model * 4, d_model))\n        ]))\n        self.ln_2 = LayerNorm(d_model)\n        self.attn_mask = attn_mask\n        self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity()\n\n    def attention(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None):\n        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) \\\n            if self.attn_mask is not None else None\n\n\n        return self.attn(\n            x, x, x,\n            key_padding_mask=key_padding_mask,\n            need_weights=False,\n            attn_mask=self.attn_mask\n        )[0]\n\n    def forward(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None):\n        x = x + self.drop_path(self.attention(self.ln_1(x), key_padding_mask=key_padding_mask))\n        x = x + self.drop_path(self.mlp(self.ln_2(x)))\n        return x\n\n\nclass Transformer(nn.Module):\n    def __init__(self,\n                 context_length: int,\n                 vocab_size: int,\n                 width: int,\n                 layers: int,\n                 heads: int,\n                 drop_path: float = 0.0,\n                 autogressive: bool =True):\n        super().__init__()\n\n        self.token_embedding = nn.Embedding(vocab_size, width)\n\n        self.context_length = context_length\n        self.positional_embedding = nn.Parameter(\n            torch.empty(self.context_length, width)\n        )\n\n        self.width = width\n        self.layers = layers\n        self.autogressive = autogressive\n        attn_mask = self.build_attention_mask() if autogressive else None\n        dpr = [x.item() for x in torch.linspace(0, drop_path, layers)]  # stochastic depth decay rule\n        self.resblocks = nn.ModuleList(\n            [\n                ResidualAttentionBlock(width, heads, attn_mask, dpr[i])\n                for i in range(layers)\n            ]\n        )\n\n        self.ln_final = LayerNorm(width)\n\n        trunc_normal_(self.positional_embedding, std=.02)\n        # nn.init.normal_(self.token_embedding, std=.02)\n        trunc_normal_(self.token_embedding.weight, std=.02)\n        self.apply(self._init_weights)\n\n    @property\n 
   def dim_out(self):\n        return self.width\n\n    def build_attention_mask(self):\n        # lazily create causal attention mask, with full attention between the vision tokens\n        # pytorch uses additive attention mask; fill with -inf\n        mask = torch.empty(self.context_length, self.context_length)\n        mask.fill_(float(\"-inf\"))\n        mask.triu_(1)  # zero out the lower diagonal\n        return mask\n\n    def _init_weights(self, m):\n        if isinstance(m, (nn.Linear, nn.Conv2d)):\n            if is_main_process():\n                logger.info('=> init weight of Linear/Conv2d from trunc norm')\n            trunc_normal_(m.weight, std=0.02)\n            if m.bias is not None:\n                if is_main_process():\n                    logger.info('=> init bias of Linear/Conv2d to zeros')\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):\n            nn.init.constant_(m.bias, 0)\n\n    def load_pretrained(self, pretrained='', pretrained_layers=[], verbose=True):\n        if os.path.isfile(pretrained):\n            pretrained_dict = torch.load(pretrained, map_location='cpu')\n            logging.info(f'=> loading pretrained model {pretrained}')\n            model_dict = self.state_dict()\n            stripped_key = lambda x: x[13:] if x.startswith('lang_encoder.') else x\n            pretrained_dict = {\n                stripped_key(k): v for k, v in pretrained_dict.items()\n                if stripped_key(k) in model_dict.keys()\n            }\n            need_init_state_dict = {}\n            for k, v in pretrained_dict.items():\n                need_init = (\n                    k.split('.')[0] in pretrained_layers\n                    or pretrained_layers[0] == '*'\n                )\n                if need_init:\n                    if verbose:\n                        logger.info(f'=> init {k} from {pretrained}')\n\n                    if 'positional_embedding' in k and 
v.size() != model_dict[k].size():\n                        positional_embedding_pretrained = v\n                        positional_embedding_current = model_dict[k]\n                        L1, nH1 = positional_embedding_pretrained.size()\n                        L2, nH2 = positional_embedding_current.size()\n                        if nH1 != nH2:\n                            logger.info(f\"Error in loading {k}, passing\")\n                        else:\n                            if L1 != L2:\n                                logger.info(\n                                    '=> load_pretrained: resized variant: {} to {}'\n                                        .format((L1, nH1), (L2, nH2))\n                                )\n\n                                posemb = positional_embedding_pretrained.float()\n                                posemb_grid = posemb.unsqueeze(dim=0).permute(0, 2, 1)\n                                posemb_grid = torch.nn.functional.interpolate(posemb_grid, size=L2, mode='linear')\n                                posemb_grid = posemb_grid.permute(0, 2, 1).squeeze(dim=0)\n                                v = posemb_grid\n\n                    need_init_state_dict[k] = v\n\n            self.load_state_dict(need_init_state_dict, strict=False)\n\n\n    @torch.jit.ignore\n    def no_weight_decay(self):\n        return {\n            'positional_embedding',\n            'token_embedding',\n        }\n\n    def forward(self, input_ids, attention_mask=None):\n        key_padding_mask = (attention_mask == 0) if (not self.autogressive and attention_mask is not None) else None\n        # key_padding_mask = (input_ids == 0) if not self.autogressive else None\n        x = self.token_embedding(input_ids)  # [batch_size, n_ctx, d_model]\n        x = x + self.positional_embedding\n        x = x.permute(1, 0, 2)  # NLD -> LND\n        for block in self.resblocks:\n            x = block(x, key_padding_mask)\n        x = x.permute(1, 0, 2)  # LND -> 
NLD\n\n        x = self.ln_final(x)\n\n        return {'last_hidden_state': x}\n\n\n@register_lang_encoder\ndef lang_encoder(config_encoder, tokenizer, verbose, **kwargs):\n    transformer = Transformer(\n        context_length=config_encoder['CONTEXT_LENGTH'],\n        vocab_size=tokenizer.vocab_size,\n        width=config_encoder['WIDTH'],\n        layers=config_encoder['LAYERS'],\n        heads=config_encoder['HEADS'],\n        autogressive=config_encoder.get('AUTOGRESSIVE', True)\n    )\n\n    if config_encoder.get('LOAD_PRETRAINED', False):\n        transformer.load_pretrained(config_encoder['PRETRAINED'], config_encoder.get('PRETRAINED_LAYERS', ['*']))\n    return transformer\n"
  },
  {
    "path": "llava/model/openseed/language/__init__.py",
    "content": "# from .vlpencoder import *\n# from .encoder import *\n# # from .loss import *\n# from .build import build_language_encoder"
  },
  {
    "path": "llava/model/openseed/language/build.py",
    "content": "from .registry import model_entrypoints\nfrom .registry import is_model\n\n\ndef build_language_encoder(config, **kwargs):\n    model_name = config['MODEL']['TEXT']['ARCH']\n\n    if not is_model(model_name):\n        raise ValueError(f'Unknown model: {model_name}')\n\n    return model_entrypoints(model_name)(config, **kwargs)"
  },
  {
    "path": "llava/model/openseed/language/encoder.py",
    "content": "import torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nfrom timm.models.layers import trunc_normal_\n\nfrom .registry import register_model\nfrom ..utils import configurable\nfrom .LangEncoder import build_tokenizer, build_lang_encoder\nfrom utils.prompt_engineering import prompt_engineering, get_prompt_templates\n\n\nclass LanguageEncoder(nn.Module):\n\n    @configurable\n    def __init__(\n        self,\n        tokenizer,\n        tokenizer_type,\n        lang_encoder,\n        lang_projection,\n        max_token_num,\n    ):\n        super().__init__()\n        self.tokenizer = tokenizer\n        self.tokenizer_type = tokenizer_type\n        self.lang_encoder = lang_encoder\n        self.lang_proj = lang_projection\n        self.max_token_num = max_token_num\n        self.logit_scale = nn.Parameter(torch.ones([]))\n\n    @classmethod\n    def from_config(cls, cfg):\n        # build up text encoder\n        tokenizer = build_tokenizer(cfg['MODEL']['TEXT'])\n        tokenizer_type = cfg['MODEL']['TEXT']['TOKENIZER']\n        lang_encoder = build_lang_encoder(cfg['MODEL']['TEXT'], tokenizer, cfg['VERBOSE'])\n        max_token_num = cfg['MODEL']['TEXT']['CONTEXT_LENGTH']\n        \n        dim_lang = cfg['MODEL']['TEXT']['WIDTH']\n        dim_projection = cfg['MODEL']['DIM_PROJ']\n        lang_projection = nn.Parameter(torch.empty(dim_lang, dim_projection))\n        trunc_normal_(lang_projection, std=.02)\n\n        return {\n            \"tokenizer\": tokenizer,\n            \"tokenizer_type\": tokenizer_type,\n            \"lang_encoder\": lang_encoder,\n            \"lang_projection\": lang_projection,\n            \"max_token_num\": max_token_num,\n        }\n\n    # @torch.no_grad()\n    def get_text_embeddings(self, class_names, name='default', is_eval=False, add_bgd=False, prompt=True, norm=True):\n        if not is_eval:\n            if prompt:\n                # randomly sample one template\n                
arbitary_concepts = [\n                    prompt_engineering(class_names[label].replace('-other','').replace('-merged','').replace('-stuff',''), topk=10000, suffix='.') \\\n                    for label in range(len(class_names))\n                ]\n                if add_bgd:\n                    arbitary_concepts.append(\"A background in coco.\")\n            else:\n                arbitary_concepts = class_names\n            \n            input_ids = []\n            attention_masks = []\n            for txt in arbitary_concepts:\n                tokens = self.tokenizer(\n                    txt, padding='max_length', truncation=True, max_length=self.max_token_num, return_tensors='pt'\n                )\n                tokens['input_ids'].squeeze_()\n                tokens['attention_mask'].squeeze_()\n\n                input_ids.append(tokens['input_ids'])\n                attention_masks.append(tokens['attention_mask'])\n\n            arbitary_tokens = torch.stack(input_ids)\n            arbitary_attention_masks = torch.stack(attention_masks)\n\n            text_emb = self.forward_language((arbitary_tokens.cuda(), arbitary_attention_masks.cuda()), norm=norm)\n            setattr(self, '{}_text_embeddings'.format(name), text_emb)\n        else:\n            with torch.no_grad():\n                def extract_mean_emb(txts):\n                    tokens = self.tokenizer(\n                        txts, padding='max_length', truncation=True, max_length=self.max_token_num, return_tensors='pt'\n                    )\n                    clss_embedding = self.forward_language((tokens['input_ids'].cuda(), tokens['attention_mask'].cuda()), norm=norm)\n                    clss_embedding = clss_embedding.mean(dim=0)\n                    clss_embedding /= clss_embedding.norm()\n                    return clss_embedding\n\n                templates = get_prompt_templates()\n                clss_embeddings = []\n                for clss in class_names:\n                    
txts = [template.format(clss.replace('-other','').replace('-merged','').replace('-stuff','')) for template in templates]\n                    clss_embeddings.append(extract_mean_emb(txts))\n\n                if add_bgd:\n                    txts = [\"A background in coco.\"]\n                    clss_embeddings.append(extract_mean_emb(txts))\n\n                text_emb = torch.stack(clss_embeddings, dim=0)\n                setattr(self, '{}_text_embeddings'.format(name), text_emb)\n\n    # @torch.no_grad()\n    def forward_language(self, texts, norm=True):\n        x = self.lang_encoder(*texts)\n        x = x['last_hidden_state']\n\n        if self.tokenizer_type == 'clip':\n            x = x[torch.arange(x.size(0)), texts[0].argmax(dim=-1)]\n        else:\n            x = x[:, 0]\n\n        x = x @ self.lang_proj\n        if norm:\n            x = x / (x.norm(dim=-1, keepdim=True) + 1e-7)\n        return x\n    \n    def compute_similarity(self, v_emb, name='default'):\n        v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)\n        t_emb = getattr(self, '{}_text_embeddings'.format(name))\n        output = self.logit_scale.exp() * v_emb @ t_emb.unsqueeze(0).transpose(1, 2)\n        return output\n\n\n@register_model\ndef get_language_model(cfg, **kwargs):\n    return LanguageEncoder(cfg)"
  },
  {
    "path": "llava/model/openseed/language/registry.py",
    "content": "_model_entrypoints = {}\n\ndef register_model(fn):\n    module_name_split = fn.__module__.split('.')\n    model_name = module_name_split[-1]\n    _model_entrypoints[model_name] = fn\n    return fn\n\ndef model_entrypoints(model_name):\n    return _model_entrypoints[model_name]\n\ndef is_model(model_name):\n    return model_name in _model_entrypoints"
  },
  {
    "path": "llava/model/openseed/language/vlpencoder.py",
    "content": "# --------------------------------------------------------\n# X-Decoder -- Generalized Decoding for Pixel, Image, and Language\n# Copyright (c) 2022 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# Written by Xueyan Zou (xueyan@cs.wisc.edu)\n# --------------------------------------------------------\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nfrom timm.models.layers import trunc_normal_\n\nfrom .registry import register_model\nfrom ..utils import configurable\nfrom .LangEncoder import build_tokenizer, build_lang_encoder\nfrom utils.prompt_engineering import prompt_engineering, get_prompt_templates\n\n\nclass LanguageEncoder(nn.Module):\n\n    @configurable\n    def __init__(\n        self,\n        tokenizer,\n        tokenizer_type,\n        lang_encoder,\n        lang_projection,\n        max_token_num,\n        queue_operator,\n    ):\n        super().__init__()\n        # seg\n        self.tokenizer = tokenizer\n        self.tokenizer_type = tokenizer_type\n        self.lang_encoder = lang_encoder\n        self.lang_proj = lang_projection\n        self.max_token_num = max_token_num\n        self.logit_scale = nn.Parameter(torch.ones([]))\n        \n        # captioning & retrieval\n        for key, value in queue_operator.items():\n            self.register_buffer(key, value)\n            \n\n    @classmethod\n    def from_config(cls, cfg):\n        # build up text encoder for seg\n        tokenizer = build_tokenizer(cfg['MODEL']['TEXT'])\n        tokenizer_type = cfg['MODEL']['TEXT']['TOKENIZER']\n        lang_encoder = build_lang_encoder(cfg['MODEL']['TEXT'], tokenizer, cfg['VERBOSE'])\n        max_token_num = cfg['MODEL']['TEXT']['CONTEXT_LENGTH']\n        \n        dim_lang = cfg['MODEL']['TEXT']['WIDTH']\n        dim_projection = cfg['MODEL']['DIM_PROJ']\n        lang_projection = nn.Parameter(torch.empty(dim_lang, dim_projection))\n        trunc_normal_(lang_projection, std=.02)\n\n  
      # tested not working better      \n        queue_operator = {}\n\n        return {\n            \"tokenizer\": tokenizer,\n            \"tokenizer_type\": tokenizer_type,\n            \"lang_encoder\": lang_encoder,\n            \"lang_projection\": lang_projection,\n            \"max_token_num\": max_token_num,\n            \"queue_operator\": queue_operator,\n        }\n\n    def get_text_embeddings(self, class_names, name='default', is_eval=False, add_bgd=False, prompt=True, norm=True):\n        if not is_eval:\n            if prompt:\n                # randomly sample one template\n                arbitary_concepts = [\n                    prompt_engineering(class_names[label].replace('-other','').replace('-merged','').replace('-stuff',''), topk=10000, suffix='.') \\\n                    for label in range(len(class_names))\n                ]\n                if add_bgd:\n                    arbitary_concepts.append(\"A background in coco.\")\n            else:\n                arbitary_concepts = class_names\n            \n            input_ids = []\n            attention_masks = []\n            for txt in arbitary_concepts:\n                tokens = self.tokenizer(\n                    txt, padding='max_length', truncation=True, max_length=self.max_token_num, return_tensors='pt'\n                )\n                tokens['input_ids'].squeeze_()\n                tokens['attention_mask'].squeeze_()\n\n                input_ids.append(tokens['input_ids'])\n                attention_masks.append(tokens['attention_mask'])\n\n            arbitary_tokens = torch.stack(input_ids)\n            arbitary_attention_masks = torch.stack(attention_masks)\n\n            text_emb = self.forward_language((arbitary_tokens.cuda(), arbitary_attention_masks.cuda()), norm=norm)\n            setattr(self, '{}_text_embeddings'.format(name), text_emb)\n        else:\n            with torch.no_grad():\n                def extract_mean_emb(txts):\n                    tokens = 
self.tokenizer(\n                        txts, padding='max_length', truncation=True, max_length=self.max_token_num, return_tensors='pt'\n                    )\n                    clss_embedding = self.forward_language((tokens['input_ids'].cuda(), tokens['attention_mask'].cuda()), norm=norm)\n                    clss_embedding = clss_embedding.mean(dim=0)\n                    clss_embedding /= clss_embedding.norm()\n                    return clss_embedding\n\n                templates = get_prompt_templates()\n                clss_embeddings = []\n                if prompt:\n                    for clss in class_names:\n                        txts = [template.format(clss.replace('-other','').replace('-merged','').replace('-stuff','')) for template in templates]\n                        clss_embeddings.append(extract_mean_emb(txts))\n                else:\n                    clss_embeddings.append(extract_mean_emb(class_names))\n\n                if add_bgd:\n                    txts = [\"A background in coco.\"]\n                    clss_embeddings.append(extract_mean_emb(txts))\n\n                text_emb = torch.stack(clss_embeddings, dim=0)\n                setattr(self, '{}_text_embeddings'.format(name), text_emb)\n\n    def get_text_token_embeddings(self, txts, name='default', token=False, norm=False):\n        if not token:\n            tokens = self.tokenizer(\n                txts, padding='max_length', truncation=True, max_length=self.max_token_num, return_tensors='pt'\n            )\n            tokens = {key: value.cuda() for key, value in tokens.items()}\n        else:\n            tokens = txts\n        token_emb, class_emb = self.forward_language_token((tokens['input_ids'], tokens['attention_mask']), norm=norm)\n        ret = {\"tokens\": tokens,\n                \"token_emb\": token_emb,\n                \"class_emb\": class_emb,}\n        setattr(self, '{}_token_embeddings'.format(name), ret)\n        return ret\n\n    def forward_language(self, 
texts, norm=True):\n        x = self.lang_encoder(*texts)\n        x = x['last_hidden_state']\n\n        if self.tokenizer_type == 'clip':\n            x = x[torch.arange(x.size(0)), texts[0].argmax(dim=-1)]\n        else:\n            x = x[:, 0]\n\n        x = x @ self.lang_proj\n        if norm:\n            x = x / (x.norm(dim=-1, keepdim=True) + 1e-7)\n        return x\n    \n    def forward_language_token(self, texts, norm=False):\n        x = self.lang_encoder(*texts)\n        token_x = x['last_hidden_state']\n\n        if self.tokenizer_type == 'clip':\n            class_x = token_x[torch.arange(token_x.size(0)), texts[0].argmax(dim=-1)]\n        else:\n            class_x = token_x[:, 0]\n\n        class_x = class_x @ self.lang_proj\n        token_x = token_x @ self.lang_proj\n\n        if norm:\n            class_x = class_x / (class_x.norm(dim=-1, keepdim=True) + 1e-7)\n            token_x = token_x / (token_x.norm(dim=-1, keepdim=True) + 1e-7)\n\n        return token_x, class_x\n    \n    def compute_similarity(self, v_emb, name='default', fake=False):\n        if fake:\n            return None\n        v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)\n        t_emb = getattr(self, '{}_text_embeddings'.format(name))\n        output = self.logit_scale.exp() * v_emb @ t_emb.unsqueeze(0).transpose(1, 2)\n        return output\n\n\n@register_model\ndef get_language_model(cfg, **kwargs):\n    return LanguageEncoder(cfg)"
  },
  {
    "path": "llava/model/openseed/modules/__init__.py",
    "content": "from .point_features import *\nfrom .position_encoding import *\nfrom .postprocessing import *\nfrom .attention import *\nfrom .matcher import *\nfrom .criterion import *"
  },
  {
    "path": "llava/model/openseed/modules/attention.py",
    "content": "import warnings\nfrom typing import Optional, Tuple\n\nimport torch\nimport torch.nn as nn\nfrom torch import Tensor\nfrom torch.nn.init import constant_, xavier_normal_, xavier_uniform_\nfrom torch.nn.parameter import Parameter\nfrom torch.overrides import has_torch_function, handle_torch_function\nfrom torch.nn.functional import pad, linear, softmax, dropout\n\n\ndef multi_head_attention_forward(\n    query: Tensor,\n    key: Tensor,\n    value: Tensor,\n    embed_dim_to_check: int,\n    num_heads: int,\n    in_proj_weight: Tensor,\n    in_proj_bias: Tensor,\n    bias_k: Optional[Tensor],\n    bias_v: Optional[Tensor],\n    add_zero_attn: bool,\n    dropout_p: float,\n    out_proj_weight: Tensor,\n    out_proj_bias: Tensor,\n    training: bool = True,\n    key_padding_mask: Optional[Tensor] = None,\n    need_weights: bool = True,\n    attn_mask: Optional[Tensor] = None,\n    use_separate_proj_weight: bool = False,\n    q_proj_weight: Optional[Tensor] = None,\n    k_proj_weight: Optional[Tensor] = None,\n    v_proj_weight: Optional[Tensor] = None,\n    static_k: Optional[Tensor] = None,\n    static_v: Optional[Tensor] = None,\n) -> Tuple[Tensor, Optional[Tensor]]:\n    r\"\"\"\n    Args:\n        query, key, value: map a query and a set of key-value pairs to an output.\n            See \"Attention Is All You Need\" for more details.\n        embed_dim_to_check: total dimension of the model.\n        num_heads: parallel attention heads.\n        in_proj_weight, in_proj_bias: input projection weight and bias.\n        bias_k, bias_v: bias of the key and value sequences to be added at dim=0.\n        add_zero_attn: add a new batch of zeros to the key and\n                       value sequences at dim=1.\n        dropout_p: probability of an element to be zeroed.\n        out_proj_weight, out_proj_bias: the output projection weight and bias.\n        training: apply dropout if is ``True``.\n        key_padding_mask: if provided, specified padding 
elements in the key will\n            be ignored by the attention. This is an binary mask. When the value is True,\n            the corresponding value on the attention layer will be filled with -inf.\n        need_weights: output attn_output_weights.\n        attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all\n            the batches while a 3D mask allows to specify a different mask for the entries of each batch.\n        use_separate_proj_weight: the function accept the proj. weights for query, key,\n            and value in different forms. If false, in_proj_weight will be used, which is\n            a combination of q_proj_weight, k_proj_weight, v_proj_weight.\n        q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.\n        static_k, static_v: static key and value used for attention operators.\n\n\n    Shape:\n        Inputs:\n        - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is\n          the embedding dimension.\n        - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is\n          the embedding dimension.\n        - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is\n          the embedding dimension.\n        - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.\n          If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions\n          will be unchanged. 
If a BoolTensor is provided, the positions with the\n          value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.\n        - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.\n          3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,\n          S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked\n          positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend\n          while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``\n          are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor\n          is provided, it will be added to the attention weight.\n        - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,\n          N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.\n        - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,\n          N is the batch size, E is the embedding dimension. 
E/num_heads is the head dimension.\n\n        Outputs:\n        - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,\n          E is the embedding dimension.\n        - attn_output_weights: :math:`(N, L, S)` where N is the batch size,\n          L is the target sequence length, S is the source sequence length.\n    \"\"\"\n    tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias)\n    if has_torch_function(tens_ops):\n        return handle_torch_function(\n            multi_head_attention_forward,\n            tens_ops,\n            query,\n            key,\n            value,\n            embed_dim_to_check,\n            num_heads,\n            in_proj_weight,\n            in_proj_bias,\n            bias_k,\n            bias_v,\n            add_zero_attn,\n            dropout_p,\n            out_proj_weight,\n            out_proj_bias,\n            training=training,\n            key_padding_mask=key_padding_mask,\n            need_weights=need_weights,\n            attn_mask=attn_mask,\n            use_separate_proj_weight=use_separate_proj_weight,\n            q_proj_weight=q_proj_weight,\n            k_proj_weight=k_proj_weight,\n            v_proj_weight=v_proj_weight,\n            static_k=static_k,\n            static_v=static_v,\n        )\n    tgt_len, bsz, embed_dim = query.size()\n    assert embed_dim == embed_dim_to_check\n    # allow MHA to have different sizes for the feature dimension\n    assert key.size(0) == value.size(0) and key.size(1) == value.size(1)\n\n    head_dim = embed_dim // num_heads\n    assert head_dim * num_heads == embed_dim, \"embed_dim must be divisible by num_heads\"\n    scaling = float(head_dim) ** -0.5\n\n    if not use_separate_proj_weight:\n        if (query is key or torch.equal(query, key)) and (key is value or torch.equal(key, value)):\n            # self-attention\n            q, k, v = linear(query, in_proj_weight, 
in_proj_bias).chunk(3, dim=-1)\n\n        elif key is value or torch.equal(key, value):\n            # encoder-decoder attention\n            # This is inline in_proj function with in_proj_weight and in_proj_bias\n            _b = in_proj_bias\n            _start = 0\n            _end = embed_dim\n            _w = in_proj_weight[_start:_end, :]\n            if _b is not None:\n                _b = _b[_start:_end]\n            q = linear(query, _w, _b)\n\n            if key is None:\n                assert value is None\n                k = None\n                v = None\n            else:\n\n                # This is inline in_proj function with in_proj_weight and in_proj_bias\n                _b = in_proj_bias\n                _start = embed_dim\n                _end = None\n                _w = in_proj_weight[_start:, :]\n                if _b is not None:\n                    _b = _b[_start:]\n                k, v = linear(key, _w, _b).chunk(2, dim=-1)\n\n        else:\n            # This is inline in_proj function with in_proj_weight and in_proj_bias\n            _b = in_proj_bias\n            _start = 0\n            _end = embed_dim\n            _w = in_proj_weight[_start:_end, :]\n            if _b is not None:\n                _b = _b[_start:_end]\n            q = linear(query, _w, _b)\n\n            # This is inline in_proj function with in_proj_weight and in_proj_bias\n            _b = in_proj_bias\n            _start = embed_dim\n            _end = embed_dim * 2\n            _w = in_proj_weight[_start:_end, :]\n            if _b is not None:\n                _b = _b[_start:_end]\n            k = linear(key, _w, _b)\n\n            # This is inline in_proj function with in_proj_weight and in_proj_bias\n            _b = in_proj_bias\n            _start = embed_dim * 2\n            _end = None\n            _w = in_proj_weight[_start:, :]\n            if _b is not None:\n                _b = _b[_start:]\n            v = linear(value, _w, _b)\n    else:\n       
 q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight)\n        len1, len2 = q_proj_weight_non_opt.size()\n        assert len1 == embed_dim and len2 == query.size(-1)\n\n        k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight)\n        len1, len2 = k_proj_weight_non_opt.size()\n        assert len1 == embed_dim and len2 == key.size(-1)\n\n        v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight)\n        len1, len2 = v_proj_weight_non_opt.size()\n        assert len1 == embed_dim and len2 == value.size(-1)\n\n        if in_proj_bias is not None:\n            q = linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim])\n            k = linear(key, k_proj_weight_non_opt, in_proj_bias[embed_dim : (embed_dim * 2)])\n            v = linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2) :])\n        else:\n            q = linear(query, q_proj_weight_non_opt, in_proj_bias)\n            k = linear(key, k_proj_weight_non_opt, in_proj_bias)\n            v = linear(value, v_proj_weight_non_opt, in_proj_bias)\n    q = q * scaling\n\n    if attn_mask is not None:\n        assert (\n            attn_mask.dtype == torch.float32\n            or attn_mask.dtype == torch.float64\n            or attn_mask.dtype == torch.float16\n            or attn_mask.dtype == torch.uint8\n            or attn_mask.dtype == torch.bool\n        ), \"Only float, byte, and bool types are supported for attn_mask, not {}\".format(attn_mask.dtype)\n        if attn_mask.dtype == torch.uint8:\n            warnings.warn(\"Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. 
Use bool tensor instead.\")\n            attn_mask = attn_mask.to(torch.bool)\n\n        if attn_mask.dim() == 2:\n            attn_mask = attn_mask.unsqueeze(0)\n            if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:\n                raise RuntimeError(\"The size of the 2D attn_mask is not correct.\")\n        elif attn_mask.dim() == 3:\n            if list(attn_mask.size()) != [bsz * num_heads, query.size(0), key.size(0)]:\n                raise RuntimeError(\"The size of the 3D attn_mask is not correct.\")\n        else:\n            raise RuntimeError(\"attn_mask's dimension {} is not supported\".format(attn_mask.dim()))\n        # attn_mask's dim is 3 now.\n\n    # convert ByteTensor key_padding_mask to bool\n    if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:\n        warnings.warn(\n            \"Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.\"\n        )\n        key_padding_mask = key_padding_mask.to(torch.bool)\n\n    if bias_k is not None and bias_v is not None:\n        if static_k is None and static_v is None:\n            k = torch.cat([k, bias_k.repeat(1, bsz, 1)])\n            v = torch.cat([v, bias_v.repeat(1, bsz, 1)])\n            if attn_mask is not None:\n                attn_mask = pad(attn_mask, (0, 1))\n            if key_padding_mask is not None:\n                key_padding_mask = pad(key_padding_mask, (0, 1))\n        else:\n            assert static_k is None, \"bias cannot be added to static key.\"\n            assert static_v is None, \"bias cannot be added to static value.\"\n    else:\n        assert bias_k is None\n        assert bias_v is None\n\n    q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)\n    if k is not None:\n        k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)\n    if v is not None:\n        v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)\n\n    if 
static_k is not None:\n        assert static_k.size(0) == bsz * num_heads\n        assert static_k.size(2) == head_dim\n        k = static_k\n\n    if static_v is not None:\n        assert static_v.size(0) == bsz * num_heads\n        assert static_v.size(2) == head_dim\n        v = static_v\n\n    src_len = k.size(1)\n\n    if key_padding_mask is not None:\n        # assert key_padding_mask.size(0) == bsz\n        assert key_padding_mask.size(1) == src_len\n\n    if add_zero_attn:\n        src_len += 1\n        k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1)\n        v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1)\n        if attn_mask is not None:\n            attn_mask = pad(attn_mask, (0, 1))\n        if key_padding_mask is not None:\n            key_padding_mask = pad(key_padding_mask, (0, 1))\n\n    attn_output_weights = torch.bmm(q, k.transpose(1, 2))\n    assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]\n\n    if attn_mask is not None:\n        if attn_mask.dtype == torch.bool:\n            attn_output_weights.masked_fill_(attn_mask, float(\"-inf\"))\n        else:\n            attn_output_weights += attn_mask\n\n    if key_padding_mask is not None:\n        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)\n        attn_output_weights = attn_output_weights.masked_fill(\n            key_padding_mask.unsqueeze(1),\n            float(\"-inf\"),\n        )\n        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)\n\n    attn_output_weights = softmax(attn_output_weights, dim=-1)\n    attn_output_weights = dropout(attn_output_weights, p=dropout_p, training=training)\n\n    attn_output = torch.bmm(attn_output_weights, v)\n    assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]\n    attn_output = attn_output.transpose(0, 
1).contiguous().view(tgt_len, bsz, embed_dim)\n    attn_output = linear(attn_output, out_proj_weight, out_proj_bias)\n\n    if need_weights:\n        # average attention weights over heads\n        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)\n        return attn_output, attn_output_weights.sum(dim=1) / num_heads\n    else:\n        return attn_output, None\n\n\n# This class exists solely for Transformer; it has an annotation stating\n# that bias is never None, which appeases TorchScript\nclass _LinearWithBias(nn.Linear):\n    bias: Tensor  # type: ignore\n\n    def __init__(self, in_features: int, out_features: int) -> None:\n        super().__init__(in_features, out_features, bias=True)  # type: ignore\n\n\nclass MultiheadAttention(nn.Module):\n    r\"\"\"Allows the model to jointly attend to information\n    from different representation subspaces.\n    See `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_\n\n    .. math::\n        \\text{MultiHead}(Q, K, V) = \\text{Concat}(head_1,\\dots,head_h)W^O\n\n    where :math:`head_i = \\text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.\n\n    Args:\n        embed_dim: total dimension of the model.\n        num_heads: parallel attention heads.\n        dropout: a Dropout layer on attn_output_weights. Default: 0.0.\n        bias: add bias as module parameter. Default: True.\n        add_bias_kv: add bias to the key and value sequences at dim=0.\n        add_zero_attn: add a new batch of zeros to the key and\n                       value sequences at dim=1.\n        kdim: total number of features in key. Default: None.\n        vdim: total number of features in value. 
Default: None.\n\n    Note that if :attr:`kdim` and :attr:`vdim` are None, they will be set\n    to :attr:`embed_dim` such that query, key, and value have the same\n    number of features.\n\n    Examples::\n\n        >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)\n        >>> attn_output, attn_output_weights = multihead_attn(query, key, value)\n    \"\"\"\n    bias_k: Optional[torch.Tensor]\n    bias_v: Optional[torch.Tensor]\n\n    def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None):\n        super(MultiheadAttention, self).__init__()\n        self.embed_dim = embed_dim\n        self.kdim = kdim if kdim is not None else embed_dim\n        self.vdim = vdim if vdim is not None else embed_dim\n        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim\n\n        self.num_heads = num_heads\n        self.dropout = dropout\n        self.head_dim = embed_dim // num_heads\n        assert self.head_dim * num_heads == self.embed_dim, \"embed_dim must be divisible by num_heads\"\n\n        if self._qkv_same_embed_dim is False:\n            self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))\n            self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))\n            self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))\n            self.register_parameter('in_proj_weight', None)\n        else:\n            self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))\n            self.register_parameter('q_proj_weight', None)\n            self.register_parameter('k_proj_weight', None)\n            self.register_parameter('v_proj_weight', None)\n\n        if bias:\n            self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))\n        else:\n            self.register_parameter('in_proj_bias', None)\n        self.out_proj = _LinearWithBias(embed_dim, embed_dim)\n\n        if add_bias_kv:\n            
self.bias_k = Parameter(torch.empty(1, 1, embed_dim))\n            self.bias_v = Parameter(torch.empty(1, 1, embed_dim))\n        else:\n            self.bias_k = self.bias_v = None\n\n        self.add_zero_attn = add_zero_attn\n\n        self._reset_parameters()\n\n    def _reset_parameters(self):\n        if self._qkv_same_embed_dim:\n            xavier_uniform_(self.in_proj_weight)\n        else:\n            xavier_uniform_(self.q_proj_weight)\n            xavier_uniform_(self.k_proj_weight)\n            xavier_uniform_(self.v_proj_weight)\n\n        if self.in_proj_bias is not None:\n            constant_(self.in_proj_bias, 0.)\n            constant_(self.out_proj.bias, 0.)\n        if self.bias_k is not None:\n            xavier_normal_(self.bias_k)\n        if self.bias_v is not None:\n            xavier_normal_(self.bias_v)\n\n    def __setstate__(self, state):\n        # Support loading old MultiheadAttention checkpoints generated by v1.1.0\n        if '_qkv_same_embed_dim' not in state:\n            state['_qkv_same_embed_dim'] = True\n\n        super(MultiheadAttention, self).__setstate__(state)\n\n    def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: Optional[Tensor] = None,\n                need_weights: bool = True, attn_mask: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:\n        r\"\"\"\n    Args:\n        query, key, value: map a query and a set of key-value pairs to an output.\n            See \"Attention Is All You Need\" for more details.\n        key_padding_mask: if provided, specified padding elements in the key will\n            be ignored by the attention. When given a binary mask and a value is True,\n            the corresponding value on the attention layer will be ignored. 
When given\n            a byte mask and a value is non-zero, the corresponding value on the attention\n            layer will be ignored\n        need_weights: output attn_output_weights.\n        attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all\n            the batches while a 3D mask allows to specify a different mask for the entries of each batch.\n\n    Shapes for inputs:\n        - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is\n          the embedding dimension.\n        - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is\n          the embedding dimension.\n        - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is\n          the embedding dimension.\n        - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.\n          If a ByteTensor is provided, the non-zero positions will be ignored while the position\n          with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the\n          value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.\n        - attn_mask: if a 2D mask: :math:`(L, S)` where L is the target sequence length, S is the\n          source sequence length.\n\n          If a 3D mask: :math:`(N\\cdot\\text{num\\_heads}, L, S)` where N is the batch size, L is the target sequence\n          length, S is the source sequence length. ``attn_mask`` ensure that position i is allowed to attend\n          the unmasked positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend\n          while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``\n          is not allowed to attend while ``False`` values will be unchanged. 
If a FloatTensor\n          is provided, it will be added to the attention weight.\n\n    Shapes for outputs:\n        - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,\n          E is the embedding dimension.\n        - attn_output_weights: :math:`(N, L, S)` where N is the batch size,\n          L is the target sequence length, S is the source sequence length.\n        \"\"\"\n        if not self._qkv_same_embed_dim:\n            return multi_head_attention_forward(\n                query, key, value, self.embed_dim, self.num_heads,\n                self.in_proj_weight, self.in_proj_bias,\n                self.bias_k, self.bias_v, self.add_zero_attn,\n                self.dropout, self.out_proj.weight, self.out_proj.bias,\n                training=self.training,\n                key_padding_mask=key_padding_mask, need_weights=need_weights,\n                attn_mask=attn_mask, use_separate_proj_weight=True,\n                q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,\n                v_proj_weight=self.v_proj_weight)\n        else:\n            return multi_head_attention_forward(\n                query, key, value, self.embed_dim, self.num_heads,\n                self.in_proj_weight, self.in_proj_bias,\n                self.bias_k, self.bias_v, self.add_zero_attn,\n                self.dropout, self.out_proj.weight, self.out_proj.bias,\n                training=self.training,\n                key_padding_mask=key_padding_mask, need_weights=need_weights,\n                attn_mask=attn_mask)"
  },
  {
    "path": "llava/model/openseed/modules/criterion.py",
    "content": "# ------------------------------------------------------------------------\n# DINO\n# Copyright (c) 2023 IDEA. All Rights Reserved.\n# Licensed under the Apache License, Version 2.0 [see LICENSE for details]\n# ------------------------------------------------------------------------\n# Modified from MaskDINO https://github.com/IDEA-Research/MaskDINO by Hao Zhang and  Feng Li.\n# ------------------------------------------------------------------------\n\"\"\"\nMaskFormer criterion.\n\"\"\"\nimport logging\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\n\nfrom timm.loss import SoftTargetCrossEntropy\nfrom detectron2.utils.comm import get_world_size\nfrom detectron2.projects.point_rend.point_features import (\n    get_uncertain_point_coords_with_randomness,\n    point_sample,\n)\n\nfrom ..utils.misc import is_dist_avail_and_initialized, nested_tensor_from_tensor_list, _max_by_axis\nfrom ..utils import box_ops\nimport random\n\n\ndef sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2):\n    \"\"\"\n    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.\n    Args:\n        inputs: A float tensor of arbitrary shape.\n                The predictions for each example.\n        targets: A float tensor with the same shape as inputs. Stores the binary\n                 classification label for each element in inputs\n                (0 for the negative class and 1 for the positive class).\n        num_boxes: Normalization constant; the reduced loss is divided by it.\n        alpha: (optional) Weighting factor in range (0,1) to balance\n                positive vs negative examples. Default = 0.25; a negative\n                value disables the weighting.\n        gamma: Exponent of the modulating factor (1 - p_t) to\n               balance easy vs hard examples.\n    Returns:\n        Scalar loss tensor: the element-wise focal loss averaged over the last\n        two dimensions, summed over any remaining ones, multiplied by 10 and\n        divided by num_boxes.\n    \"\"\"\n    prob = inputs.sigmoid()\n    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction=\"none\")\n    p_t = prob * targets + (1 - prob) * (1 - targets)\n    loss = ce_loss * ((1 - p_t) ** gamma)\n\n    if alpha >= 0:\n        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)\n        loss = alpha_t * loss\n\n    return loss.mean(-1).mean(-1).sum() * 10. / num_boxes\n\n\ndef dice_loss(\n        inputs: torch.Tensor,\n        targets: torch.Tensor,\n        num_masks: float,\n    ):\n    \"\"\"\n    Compute the DICE loss, similar to generalized IOU for masks\n    Args:\n        inputs: A float tensor of arbitrary shape.\n                The predictions for each example.\n        targets: A float tensor with the same shape as inputs. Stores the binary\n                 classification label for each element in inputs\n                (0 for the negative class and 1 for the positive class).\n        num_masks: Normalization constant; the summed loss is divided by it.\n    Returns:\n        Scalar loss tensor.\n    Note:\n        Only ``inputs`` is flattened from dim 1; ``targets`` is reduced over its\n        last dim as given, so it should already match the flattened inputs.\n    \"\"\"\n    inputs = inputs.sigmoid()\n    inputs = inputs.flatten(1)\n    numerator = 2 * (inputs * targets).sum(-1)\n    denominator = inputs.sum(-1) + targets.sum(-1)\n    # +1 smoothing keeps the ratio defined when both masks are empty.\n    loss = 1 - (numerator + 1) / (denominator + 1)\n    return loss.sum() / num_masks\n\n\ndice_loss_jit = torch.jit.script(\n    dice_loss\n)  # type: torch.jit.ScriptModule\n\n\ndef sigmoid_ce_loss(\n        inputs: torch.Tensor,\n        targets: torch.Tensor,\n        num_masks: float,\n    ):\n    \"\"\"\n    Args:\n        inputs: A float tensor of arbitrary shape.\n                The predictions for each example.\n        targets: A float tensor with the same shape as inputs. Stores the binary\n                 classification label for each element in inputs\n                (0 for the negative class and 1 for the positive class).\n        num_masks: Normalization constant; the summed loss is divided by it.\n    Returns:\n        Scalar loss tensor: per-element BCE averaged over dim 1, summed over\n        the remaining dims and divided by num_masks.\n    \"\"\"\n    loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction=\"none\")\n\n    return loss.mean(1).sum() / num_masks\n\n\nsigmoid_ce_loss_jit = torch.jit.script(\n    sigmoid_ce_loss\n)  # type: torch.jit.ScriptModule\n\n\ndef calculate_uncertainty(logits):\n    \"\"\"\n    Estimate uncertainty as the negated L1 distance between 0.0 and each logit,\n    so locations whose logit is closest to the decision boundary score highest.\n    Args:\n        logits (Tensor): A tensor of shape (R, 1, ...) of class-agnostic (or\n            already class-selected) logits, where R is the total number of\n            predicted masks in all images.\n    Returns:\n        scores (Tensor): A tensor of shape (R, 1, ...) that contains uncertainty scores with\n            the most uncertain locations having the highest uncertainty score.\n    \"\"\"\n    assert logits.shape[1] == 1\n    return -torch.abs(logits)\n\n\nclass SetCriterion(nn.Module):\n    \"\"\"This class computes the loss for DETR.\n    The process happens in two steps:\n        1) we compute hungarian assignment between ground truth boxes and the outputs of the model\n        2) we supervise each pair of matched ground-truth / prediction (supervise class and box)\n    \"\"\"\n\n    def __init__(self, num_classes, matcher, weight_dict, eos_coef, top_x_layers, losses,\n                 num_points, oversample_ratio, importance_sample_ratio, grounding_weight, dn=\"no\", dn_losses=None, panoptic_on=False, semantic_ce_loss=False):\n        \"\"\"Create the criterion.\n        Parameters:\n            num_classes: number of object categories, omitting the special no-object category\n            matcher: module able to compute a matching between targets and proposals\n            weight_dict: dict containing as key the names of the losses and as values their relative weight.\n            eos_coef: relative classification weight applied to the no-object category\n            top_x_layers: dict with 'mask' and 'box' entries bounding the layer_id values\n                for which the classification/mask and box losses are computed.\n            losses: list of all the losses to be applied. See get_loss for list of available losses.\n            num_points: number of points sampled per mask for the mask losses.\n            oversample_ratio, importance_sample_ratio: point-sampling parameters\n                forwarded to get_uncertain_point_coords_with_randomness.\n            grounding_weight: stored as an attribute; not read by the losses in this module.\n            dn: denoising mode; \"no\" disables the denoising losses.\n            dn_losses: losses applied to the denoising queries (default: none).\n            panoptic_on: if True, box losses only use targets whose label is < 80.\n            semantic_ce_loss: if True, use softmax cross-entropy instead of the\n                sigmoid focal loss for the classification losses.\n        \"\"\"\n        super().__init__()\n        self.num_classes = num_classes\n        self.matcher = matcher\n        self.weight_dict = weight_dict\n        self.eos_coef = eos_coef\n        self.top_x_layers = top_x_layers\n        self.losses = losses\n        self.dn = dn\n        # None (not a mutable [] default) so instances never share one list.\n        self.dn_losses = dn_losses if dn_losses is not None else []\n        # Per-class CE weights; the extra last slot is the no-object class.\n        empty_weight = torch.ones(self.num_classes + 1)\n        empty_weight[-1] = self.eos_coef\n        self.register_buffer(\"empty_weight\", empty_weight)\n\n        # pointwise mask loss parameters\n        self.num_points = num_points\n        self.oversample_ratio = oversample_ratio\n        self.importance_sample_ratio = importance_sample_ratio\n        self.focal_alpha = 0.25\n\n        self.panoptic_on = panoptic_on\n        self.semantic_ce_loss = semantic_ce_loss\n        self.grounding_weight = grounding_weight\n        # When True, forward() does not all-reduce num_masks across processes.\n        self.conversation = False\n\n    def loss_labels_ce(self, outputs, targets, indices, num_masks, layer_id=None, extra=None):\n        \"\"\"Classification loss (softmax cross-entropy, NLL)\n        targets dicts must contain the key \"labels\" containing a tensor of dim [nb_target_boxes]\n        Unmatched queries are assigned the no-object class (index num_classes),\n        whose weight in self.empty_weight is eos_coef.\n        \"\"\"\n        if layer_id > self.top_x_layers['mask']:\n            return {\"loss_mask_cls_0\": 0}\n        assert \"pred_logits\" in outputs\n        if indices is None or len(targets) == 0:\n            # Zero-valued loss derived from the logits so they stay in the autograd graph.\n            loss_ce = outputs['pred_logits'].sum() * 0.0\n            losses = {\"loss_mask_cls_0\": loss_ce}\n            return losses\n\n        src_logits = outputs[\"pred_logits\"].type(self.empty_weight.dtype)\n\n        idx = self._get_src_permutation_idx(indices)\n        target_classes_o = torch.cat([t[\"labels\"][J] for t, (_, J) in zip(targets, indices)])\n        target_classes = torch.full(\n            src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device\n        )\n        target_classes[idx] = target_classes_o\n\n        loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)\n        losses = {\"loss_mask_cls_0\": loss_ce}\n        return losses\n\n    def loss_labels_masked(self, outputs, targets, indices, num_boxes, log=True, layer_id=None, extra=None):\n        \"\"\"Classification loss (Binary focal loss) restricted to the matched queries.\n        targets dicts must contain the key \"labels\" containing a tensor of dim [nb_target_boxes]\n        Selected by get_loss for the 'labels_dn' / 'dn_labels' (denoising) losses.\n        \"\"\"\n        if layer_id > self.top_x_layers['mask']:\n            return {\"loss_mask_cls_0\": 0}\n        assert 'pred_logits' in outputs\n        if indices is None or len(targets) == 0:\n            loss_ce = outputs['pred_logits'].sum() * 0.0\n            losses = {\"loss_mask_cls_0\": loss_ce}\n            return losses\n\n        src_logits = outputs['pred_logits']\n\n        idx = self._get_src_permutation_idx(indices)\n        target_classes_o = torch.cat([t[\"labels\"][J] for t, (_, J) in zip(targets, indices)])\n        target_classes = torch.full(src_logits.shape[:2], self.num_classes,\n                                    dtype=torch.int64, device=src_logits.device)\n        target_classes[idx] = target_classes_o\n\n        # One-hot with an extra no-object column, which is dropped right after.\n        target_classes_onehot = torch.zeros([src_logits.shape[0], src_logits.shape[1], src_logits.shape[2]+1],\n                                            dtype=src_logits.dtype, layout=src_logits.layout, device=src_logits.device)\n        target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1)\n\n        target_classes_onehot = target_classes_onehot[:,:,:-1]\n        loss_ce = sigmoid_focal_loss(src_logits[idx], target_classes_onehot[idx], num_boxes, alpha=self.focal_alpha, gamma=2) * src_logits.shape[1]\n        losses = {'loss_mask_cls_0': loss_ce}\n\n        return losses\n\n    def loss_labels(self, outputs, targets, indices, num_boxes, log=True, layer_id=None, extra=None):\n        \"\"\"Classification loss (Binary focal loss) over all queries.\n        targets dicts must contain the key \"labels\" containing a tensor of dim [nb_target_boxes]\n        Matched queries get a target of 1 at their ground-truth label(s); all other\n        entries, including every entry of unmatched queries, are 0.\n        \"\"\"\n        if layer_id > self.top_x_layers['mask']:\n            return {\"loss_mask_cls_0\": 0}\n        assert 'pred_logits' in outputs\n        if indices is None or len(targets) == 0:\n            loss_ce = outputs['pred_logits'].sum() * 0.0\n            losses = {\"loss_mask_cls_0\": loss_ce}\n            return losses\n\n        src_logits = outputs['pred_logits']\n\n        target_classes_onehot = torch.zeros([src_logits.shape[0], src_logits.shape[1], src_logits.shape[2]],\n                                            dtype=src_logits.dtype, layout=src_logits.layout, device=src_logits.device)\n        # Filled pair by pair. NOTE(review): presumably so a label entry may hold several\n        # class ids (a single scatter_ would need exactly one per pair) - confirm.\n        for batch_id, indices_ in enumerate(indices):\n            for src, tgt in zip(*indices_):\n                gt_lbs = targets[batch_id]['labels'][tgt]\n                target_classes_onehot[batch_id, src, gt_lbs] = 1\n        loss_ce = sigmoid_focal_loss(src_logits, target_classes_onehot, num_boxes, alpha=self.focal_alpha, gamma=2) * src_logits.shape[1]\n        losses = {'loss_mask_cls_0': loss_ce}\n\n        return losses\n\n    def loss_boxes(self, outputs, targets, indices, num_boxes, layer_id=None, extra=None):\n        \"\"\"Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss\n           targets dicts must contain the key \"boxes\" containing a tensor of dim [nb_target_boxes, 4]\n           The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.\n        \"\"\"\n        if layer_id >= self.top_x_layers['box']:\n            return {\"loss_bbox_0\": 0, \"loss_giou_0\": 0}\n        assert 'pred_boxes' in outputs\n        if indices is None or len(targets) == 0:\n            loss = outputs['pred_boxes'].sum() * 0.0\n            losses = {\"loss_bbox_0\": loss, \"loss_giou_0\": loss}\n            return losses\n\n        idx = self._get_src_permutation_idx(indices)\n        src_boxes = outputs['pred_boxes'][idx]\n        target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)\n\n        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')\n        losses = {}\n        losses['loss_bbox_0'] = loss_bbox.sum() / num_boxes\n\n        loss_giou = 1 - torch.diag(box_ops.generalized_box_iou(\n            box_ops.box_cxcywh_to_xyxy(src_boxes),\n            box_ops.box_cxcywh_to_xyxy(target_boxes)))\n        losses['loss_giou_0'] = loss_giou.sum() / num_boxes\n\n        return losses\n\n    def loss_boxes_panoptic(self, outputs, targets, indices, num_boxes, layer_id=None, extra=None):\n        \"\"\"Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss\n           targets dicts must contain the key \"boxes\" containing a tensor of dim [nb_target_boxes, 4]\n           The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.\n           Only matched pairs whose target label is < 80 contribute.\n        \"\"\"\n        if layer_id >= self.top_x_layers['box']:\n            return {\"loss_bbox_0\": 0, \"loss_giou_0\": 0}\n        assert 'pred_boxes' in outputs\n        if indices is None or len(targets) == 0:\n            loss = outputs['pred_boxes'].sum() * 0.0\n            losses = {\"loss_bbox_0\": loss, \"loss_giou_0\": loss}\n            return losses\n\n        idx = self._get_src_permutation_idx(indices)\n        src_boxes = outputs['pred_boxes'][idx]\n        target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)\n        target_labels = torch.cat([t['labels'][i] for t, (_, i) in zip(targets, indices)], dim=0)\n        # Label ids below the hard-coded 80 are treated as \"thing\" classes; others get no box loss.\n        isthing = target_labels < 80\n        target_boxes = target_boxes[isthing]\n        src_boxes = src_boxes[isthing]\n\n        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')\n        losses = {}\n        losses['loss_bbox_0'] = loss_bbox.sum() / num_boxes\n\n        loss_giou = 1 - torch.diag(box_ops.generalized_box_iou(\n            box_ops.box_cxcywh_to_xyxy(src_boxes),\n            box_ops.box_cxcywh_to_xyxy(target_boxes)))\n        losses['loss_giou_0'] = loss_giou.sum() / num_boxes\n        return losses\n\n    def loss_masks(self, outputs, targets, indices, num_masks, layer_id=None, extra=None):\n        \"\"\"Compute the losses related to the masks: the sigmoid BCE loss and the dice loss.\n        targets dicts must contain the key \"masks\" containing a tensor of dim [nb_target_boxes, h, w]\n        Both losses are evaluated on sampled points rather than on the full masks.\n        \"\"\"\n        if layer_id >= self.top_x_layers['mask']:\n            return {\"loss_mask_bce_0\": 0, \"loss_mask_dice_0\": 0}\n        assert \"pred_masks\" in outputs\n        if indices is None or len(targets) == 0:\n            loss = outputs['pred_masks'].sum() * 0.0\n            losses = {\"loss_mask_bce_0\": loss, \"loss_mask_dice_0\": loss}\n            return losses\n\n        src_idx = self._get_src_permutation_idx(indices)\n        tgt_idx = self._get_tgt_permutation_idx(indices)\n        src_masks = outputs[\"pred_masks\"]\n        src_masks = src_masks[src_idx]\n        masks = [t[\"masks\"] for t in targets]\n        # TODO use valid to mask invalid areas due to padding in loss\n        target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()\n        target_masks = target_masks.to(src_masks)\n        target_masks = target_masks[tgt_idx]\n\n        # No need to upsample predictions as we are using normalized coordinates :)\n        # N x 1 x H x W\n        src_masks = src_masks[:, None]\n        target_masks = target_masks[:, None]\n\n        with torch.no_grad():\n            # sample point_coords\n            point_coords = get_uncertain_point_coords_with_randomness(\n                src_masks.float(),\n                lambda logits: calculate_uncertainty(logits.float()),\n                self.num_points,\n                self.oversample_ratio,\n                self.importance_sample_ratio,\n            )\n            # get gt labels\n            point_labels = point_sample(\n                target_masks.float(),\n                point_coords.float(),\n                align_corners=False,\n            ).squeeze(1)\n\n        point_logits = point_sample(\n            src_masks.float(),\n            point_coords.float(),\n            align_corners=False,\n        ).squeeze(1)\n\n        losses = {\n            \"loss_mask_bce_0\": sigmoid_ce_loss_jit(point_logits, point_labels, num_masks),\n            \"loss_mask_dice_0\": dice_loss_jit(point_logits, point_labels, num_masks),\n        }\n\n        del src_masks\n        del target_masks\n        return losses\n\n    def prep_for_dn(self, mask_dict):\n        \"\"\"Unpack the denoising metadata.\n\n        Returns:\n            (output_known_lbs_bboxes, num_tgt, single_pad, scalar), where num_tgt is the\n            number of elements of mask_dict['known_indice'] and single_pad is\n            pad_size // scalar (pad_size must be divisible by scalar).\n        \"\"\"\n        output_known_lbs_bboxes = mask_dict['output_known_lbs_bboxes']\n\n        known_indice = mask_dict['known_indice']\n        scalar, pad_size = mask_dict['scalar'], mask_dict['pad_size']\n        assert pad_size % scalar == 0\n        single_pad = pad_size // scalar\n\n        num_tgt = known_indice.numel()\n        return output_known_lbs_bboxes, num_tgt, single_pad, scalar\n\n    def _get_src_permutation_idx(self, indices):\n        # permute predictions following indices; returns (batch_idx, src_idx) for advanced indexing\n        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])\n        src_idx = torch.cat([src for (src, _) in indices])\n        return batch_idx, src_idx\n\n    def _get_tgt_permutation_idx(self, indices):\n        # permute targets following indices; returns (batch_idx, tgt_idx) for advanced indexing\n        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])\n        tgt_idx = torch.cat([tgt for (_, tgt) in indices])\n        return batch_idx, tgt_idx\n\n    def get_loss(self, loss, outputs, targets, indices, num_masks=None, layer_id=None, extra=None):\n        \"\"\"Dispatch to the loss function registered under the name `loss`.\"\"\"\n        loss_map = {\n            'labels': self.loss_labels_ce if self.semantic_ce_loss else self.loss_labels,\n            'labels_dn': self.loss_labels_ce if self.semantic_ce_loss else self.loss_labels_masked,\n            'dn_labels': self.loss_labels_ce if self.semantic_ce_loss else self.loss_labels_masked,\n            'masks': self.loss_masks,\n            'boxes': self.loss_boxes_panoptic if self.panoptic_on else self.loss_boxes,\n        }\n        assert loss in loss_map, f\"do you really want to compute {loss} loss?\"\n        return loss_map[loss](outputs, targets, indices, num_masks, layer_id=layer_id, extra=extra)\n\n    def forward(self, outputs, targets, mask_dict=None, extra=None, task='seg'):\n        \"\"\"This performs the loss computation.\n        Parameters:\n             outputs: dict of tensors, see the output specification of the model for the format\n             targets: list of dicts, such that len(targets) == batch_size.\n                      The expected keys in each dict depends on the losses applied, see each loss' doc\n             mask_dict: denoising metadata unpacked by prep_for_dn; None when there are no denoising queries.\n             extra: currently ignored; it is replaced by an empty dict before matching.\n             task: 'det' and 'seg_from_teacher' match with class and box costs only;\n                   'det' additionally skips every mask loss.\n        Returns:\n            dict mapping loss names (suffixed by layer index, '_dn' and/or '_interm') to loss values.\n        \"\"\"\n        # TODO: use different matching and loss weight when only detection\n        outputs_without_aux = {k: v for k, v in outputs.items() if k != \"aux_outputs\"}\n        # Index and placeholder tensors follow the device of the model outputs instead\n        # of a hard-coded 'cuda', so the criterion also runs on CPU or a non-default GPU.\n        device = next(iter(outputs.values())).device\n        match_cost = [\"cls\", \"box\", \"mask\"]\n        if task == 'det' or task == 'seg_from_teacher':\n            match_cost = [\"cls\", \"box\"]\n        # Denoising queries use fixed indices instead of Hungarian matching: each ground\n        # truth is repeated `scalar` times, copy i being offset by i * single_pad.\n        if self.dn != \"no\" and mask_dict is not None:\n            output_known_lbs_bboxes, num_tgt, single_pad, scalar = self.prep_for_dn(mask_dict)\n            exc_idx = []\n            for i in range(len(targets)):\n                if len(targets[i]['labels']) > 0:\n                    t = torch.arange(0, len(targets[i]['labels'])).long().to(device)\n                    t = t.unsqueeze(0).repeat(scalar, 1)\n                    tgt_idx = t.flatten()\n                    output_idx = (torch.tensor(range(scalar)) * single_pad).long().to(device).unsqueeze(1) + t\n                    output_idx = output_idx.flatten()\n                else:\n                    output_idx = tgt_idx = torch.tensor([]).long().to(device)\n                exc_idx.append((output_idx, tgt_idx))\n        # NOTE(review): the caller-supplied `extra` is discarded; an empty dict is used instead.\n        extra = dict()\n        indices = self.matcher(outputs_without_aux, targets, match_cost, extra=extra)\n        # Compute the average number of target boxes across all nodes, for normalization purposes\n        num_masks = sum(len(t[\"labels\"]) for t in targets)\n        num_masks = torch.as_tensor([num_masks], dtype=torch.float, device=device)\n        if is_dist_avail_and_initialized() and not self.conversation:\n            torch.distributed.all_reduce(num_masks)\n            num_masks = torch.clamp(num_masks / get_world_size(), min=1).item()\n        else:\n            num_masks = torch.clamp(num_masks, min=1).item()\n\n        # Compute all the requested losses\n        losses = {}\n        for loss in self.losses:\n            if task == 'det' and loss == 'masks':\n                continue   # not compute mask loss for detection data only\n            losses.update(self.get_loss(loss, outputs, targets, indices, num_masks, layer_id=0, extra=extra))\n\n        if self.dn != \"no\" and mask_dict is not None:\n            l_dict = {}\n            for loss in self.dn_losses:\n                if task == 'det' and loss == 'masks':\n                    continue  # not compute mask loss for detection data only\n                if loss == 'labels':\n                    loss = 'labels_dn'\n                l_dict.update(self.get_loss(loss, output_known_lbs_bboxes, targets, exc_idx, num_masks*scalar, layer_id=0))\n            l_dict = {k + '_dn': v for k, v in l_dict.items()}\n            losses.update(l_dict)\n        elif self.dn != \"no\":\n            # No denoising queries this step: zero placeholders, presumably to keep the\n            # set of loss keys identical across steps.\n            l_dict = dict()\n            l_dict['loss_bbox_0_dn'] = torch.as_tensor(0., device=device)\n            l_dict['loss_giou_0_dn'] = torch.as_tensor(0., device=device)\n            l_dict['loss_mask_cls_0_dn'] = torch.as_tensor(0., device=device)\n            if task != 'det' and 'masks' in self.dn_losses:\n                l_dict['loss_mask_bce_0_dn'] = torch.as_tensor(0., device=device)\n                l_dict['loss_mask_dice_0_dn'] = torch.as_tensor(0., device=device)\n            losses.update(l_dict)\n\n        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.\n        if \"aux_outputs\" in outputs:\n            # Denoising losses are added from aux index 0 when interm_outputs exist, else from aux index 1.\n            start = 0 if 'interm_outputs' in outputs else 1\n            for i, aux_outputs in enumerate(outputs[\"aux_outputs\"]):\n                indices = self.matcher(aux_outputs, targets, match_cost, extra=extra)\n                for loss in self.losses:\n                    if task == 'det' and loss == 'masks':\n                        continue  # not compute mask loss for detection data only\n                    l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_masks, layer_id=(i+1), extra=extra)\n                    l_dict = {k.replace('_0', f\"_{i+1}\"): v for k, v in l_dict.items()}\n                    losses.update(l_dict)\n                if i >= start:\n                    if self.dn != \"no\" and mask_dict is not None:\n                        out_ = output_known_lbs_bboxes['aux_outputs'][i]\n                        l_dict = {}\n                        for loss in self.dn_losses:\n                            if task == 'det' and loss == 'masks':\n                                continue  # not compute mask loss for detection data only\n                            if loss == 'labels':\n                                loss = 'labels_dn'\n                            l_dict.update(\n                                self.get_loss(loss, out_, targets, exc_idx, num_masks * scalar, layer_id=(i+1), extra=extra))\n                        l_dict = {k.replace('_0', f\"_{i+1}_dn\"): v for k, v in l_dict.items()}\n                        losses.update(l_dict)\n                    elif self.dn != \"no\":\n                        # NOTE(review): unlike the layer-0 placeholders above, the mask entries here\n                        # also require self.dn == \"seg\"; confirm which condition is intended.\n                        l_dict = dict()\n                        l_dict[f'loss_bbox_{i+1}_dn'] = torch.as_tensor(0., device=device)\n                        l_dict[f'loss_giou_{i+1}_dn'] = torch.as_tensor(0., device=device)\n                        l_dict[f'loss_mask_cls_{i+1}_dn'] = torch.as_tensor(0., device=device)\n                        if self.dn == \"seg\" and task != 'det' and 'masks' in self.dn_losses:\n                            l_dict[f'loss_mask_bce_{i+1}_dn'] = torch.as_tensor(0., device=device)\n                            l_dict[f'loss_mask_dice_{i+1}_dn'] = torch.as_tensor(0., device=device)\n                        losses.update(l_dict)\n\n        # interm_outputs loss\n        if 'interm_outputs' in outputs:\n            interm_outputs = outputs['interm_outputs']\n            indices = self.matcher(interm_outputs, targets, match_cost, extra=extra)\n            full_set = ['labels', 'masks', 'boxes']\n            # Configured losses that are also supported on interm outputs. The former\n            # `set(self.losses) and set(full_set)` returned set(full_set) whenever self.losses\n            # was non-empty (Python `and` yields an operand, not an intersection); iterating a\n            # list also makes the order deterministic.\n            for loss in [name for name in full_set if name in self.losses]:\n                if task == 'det' and loss == 'masks':\n                    continue  # not compute mask loss for detection data only\n                l_dict = self.get_loss(loss, interm_outputs, targets, indices, num_masks, layer_id=-1, extra=extra)\n                l_dict = {k + '_interm': v for k, v in l_dict.items()}\n                losses.update(l_dict)\n\n        return losses\n\n    def __repr__(self):\n        head = \"Criterion \" + self.__class__.__name__\n        body = [\n            \"matcher: {}\".format(self.matcher.__repr__(_repr_indent=8)),\n            \"losses: {}\".format(self.losses),\n            \"weight_dict: {}\".format(self.weight_dict),\n            \"num_classes: {}\".format(self.num_classes),\n            \"eos_coef: {}\".format(self.eos_coef),\n            \"num_points: {}\".format(self.num_points),\n            \"oversample_ratio: {}\".format(self.oversample_ratio),\n            \"importance_sample_ratio: {}\".format(self.importance_sample_ratio),\n        ]\n        _repr_indent = 4\n        lines = [head] + [\" \" * _repr_indent + line for line in body]\n        return \"\\n\".join(lines)\n"
  },
  {
    "path": "llava/model/openseed/modules/matcher.py",
    "content": "# ------------------------------------------------------------------------\n# DINO\n# Copyright (c) 2023 IDEA. All Rights Reserved.\n# Licensed under the Apache License, Version 2.0 [see LICENSE for details]\n# ------------------------------------------------------------------------\n# Modified from MaskDINO https://github.com/IDEA-Research/MaskDINO by Hao Zhang and Feng Li.\n# ------------------------------------------------------------------------\n\n\"\"\"\nModules to compute the matching cost and solve the corresponding LSAP.\n\"\"\"\nimport torch\nimport torch.nn.functional as F\nimport numpy as np\nfrom scipy.optimize import linear_sum_assignment\nfrom torch import nn\nfrom torch.cuda.amp import autocast\n\nfrom detectron2.projects.point_rend.point_features import point_sample\nfrom ..utils.box_ops import generalized_box_iou,box_cxcywh_to_xyxy\n\ndef batch_dice_loss(inputs: torch.Tensor, targets: torch.Tensor):\n    \"\"\"\n    Compute the DICE loss, similar to generalized IOU for masks\n    Args:\n        inputs: A float tensor of arbitrary shape.\n                The predictions for each example.\n        targets: A float tensor with the same shape as inputs. 
Stores the binary\n                 classification label for each element in inputs\n                (0 for the negative class and 1 for the positive class).\n    \"\"\"\n    inputs = inputs.sigmoid()\n    inputs = inputs.flatten(1)\n    numerator = 2 * torch.einsum(\"nc,mc->nm\", inputs, targets)\n    denominator = inputs.sum(-1)[:, None] + targets.sum(-1)[None, :]\n    loss = 1 - (numerator + 1) / (denominator + 1)\n    return loss\n\n\nbatch_dice_loss_jit = torch.jit.script(\n    batch_dice_loss\n)  # type: torch.jit.ScriptModule\n\n\ndef batch_sigmoid_ce_loss(inputs: torch.Tensor, targets: torch.Tensor):\n    \"\"\"\n    Args:\n        inputs: A float tensor of arbitrary shape.\n                The predictions for each example.\n        targets: A float tensor with the same shape as inputs. Stores the binary\n                 classification label for each element in inputs\n                (0 for the negative class and 1 for the positive class).\n    Returns:\n        Loss tensor\n    \"\"\"\n    hw = inputs.shape[1]\n\n    pos = F.binary_cross_entropy_with_logits(\n        inputs, torch.ones_like(inputs), reduction=\"none\"\n    )\n    neg = F.binary_cross_entropy_with_logits(\n        inputs, torch.zeros_like(inputs), reduction=\"none\"\n    )\n\n    loss = torch.einsum(\"nc,mc->nm\", pos, targets) + torch.einsum(\n        \"nc,mc->nm\", neg, (1 - targets)\n    )\n\n    return loss / hw\n\n\nbatch_sigmoid_ce_loss_jit = torch.jit.script(\n    batch_sigmoid_ce_loss\n)  # type: torch.jit.ScriptModule\n\n\nclass HungarianMatcher(nn.Module):\n    \"\"\"This class computes an assignment between the targets and the predictions of the network\n\n    For efficiency reasons, the targets don't include the no_object. Because of this, in general,\n    there are more predictions than targets. 
In this case, we do a 1-to-1 matching of the best predictions,\n    while the others are un-matched (and thus treated as non-objects).\n    \"\"\"\n\n    def __init__(self, cost_class: float = 1, cost_mask: float = 1, cost_dice: float = 1, num_points: int = 0,\n                 cost_box: float = 0, cost_giou: float = 0, panoptic_on: bool = False):\n        \"\"\"Creates the matcher\n\n        Params:\n            cost_class: This is the relative weight of the classification error in the matching cost\n            cost_mask: This is the relative weight of the focal loss of the binary mask in the matching cost\n            cost_dice: This is the relative weight of the dice loss of the binary mask in the matching cost\n        \"\"\"\n        super().__init__()\n        self.cost_class = cost_class\n        self.cost_mask = cost_mask\n        self.cost_dice = cost_dice\n        self.cost_box = cost_box\n        self.cost_giou = cost_giou\n\n        self.panoptic_on = panoptic_on\n\n        assert cost_class != 0 or cost_mask != 0 or cost_dice != 0, \"all costs cant be 0\"\n\n        self.num_points = num_points\n\n    @torch.no_grad()\n    def memory_efficient_forward(self, outputs, targets, cost=[\"cls\", \"box\", \"mask\"],split_pano=None):\n        \"\"\"More memory-friendly matching. 
Change cost to compute only certain loss in matching\"\"\"\n        bs, num_queries = outputs[\"pred_logits\"].shape[:2]\n\n        indices = []\n\n        # Iterate through batch size\n        for b in range(bs):\n            out_bbox = outputs[\"pred_boxes\"][b].float()\n            if 'box' in cost:\n                tgt_bbox=targets[b][\"boxes\"]\n                cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)\n                cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox))\n            else:\n                cost_bbox = torch.tensor(0).to(out_bbox)\n                cost_giou = torch.tensor(0).to(out_bbox)\n\n            out_prob = outputs[\"pred_logits\"][b].sigmoid().float()  # [num_queries, num_classes]\n            tgt_ids = targets[b][\"labels\"]\n            cost_class=torch.zeros_like(cost_bbox)\n            alpha = 0.25\n            gamma = 2.0\n            neg_cost_class = (1 - alpha) * (out_prob ** gamma) * (-(1 - out_prob + 1e-6).log())\n            pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-6).log())\n            for idx, tgt_ids_ in enumerate(tgt_ids):\n                if len(tgt_ids_) == 0:\n                    continue\n                cost_class_tmp = pos_cost_class[:, tgt_ids_] - neg_cost_class[:, tgt_ids_]\n                cost_class_tmp = cost_class_tmp.mean(dim=1, keepdim=False)\n                cost_class[:, idx] = cost_class_tmp\n\n            # Compute the classification cost. 
Contrary to the loss, we don't use the NLL,\n            # but approximate it in 1 - proba[target class].\n            # The 1 is a constant that doesn't change the matching, it can be ommitted.\n            # cost_class = -out_prob[:, tgt_ids]\n            if 'mask' in cost:\n                out_mask = outputs[\"pred_masks\"][b].float()   # [num_queries, H_pred, W_pred]\n                # gt masks are already padded when preparing target\n                tgt_mask = targets[b][\"masks\"].to(out_mask).float()\n\n                out_mask = out_mask[:, None]\n                tgt_mask = tgt_mask[:, None]\n                # all masks share the same set of points for efficient matching!\n                point_coords = torch.rand(1, self.num_points, 2, device=out_mask.device, dtype=tgt_mask.dtype)\n                # get gt labels\n                tgt_mask = point_sample(\n                    tgt_mask.float(),\n                    point_coords.repeat(tgt_mask.shape[0], 1, 1).float(),\n                    align_corners=False,\n                ).squeeze(1)\n\n                out_mask = point_sample(\n                    out_mask.float(),\n                    point_coords.repeat(out_mask.shape[0], 1, 1).float(),\n                    align_corners=False,\n                ).squeeze(1)\n\n                with autocast(enabled=False):\n                    out_mask = out_mask.float()\n                    tgt_mask = tgt_mask.float()\n                    # Compute the focal loss between masks\n                    cost_mask = batch_sigmoid_ce_loss_jit(out_mask, tgt_mask)\n\n                    # Compute the dice loss betwen masks\n                    cost_dice = batch_dice_loss_jit(out_mask, tgt_mask)\n            else:\n                cost_mask = torch.tensor(0).to(out_bbox)\n                cost_dice = torch.tensor(0).to(out_bbox)\n            \n            # Final cost matrix\n            if self.panoptic_on:\n                isthing = tgt_ids<80\n                cost_bbox[:, 
~isthing] = cost_bbox[:, isthing].mean()\n                cost_giou[:, ~isthing] = cost_giou[:, isthing].mean()\n                cost_bbox[cost_bbox.isnan()] = 0.0\n                cost_giou[cost_giou.isnan()] = 0.0\n\n            C = (\n                self.cost_mask * cost_mask\n                + self.cost_class * cost_class\n                + self.cost_dice * cost_dice\n                + self.cost_box*cost_bbox\n                + self.cost_giou*cost_giou\n            )\n            C = C.reshape(num_queries, -1).cpu()\n            # if split_pano is not None:\n            #     n_q_th=split_pano['n_q_th']\n            #     th_mask=tgt_ids<80 # There are 80 COCO thing classes (should be modified when trained with other panoptic datasets)\n            #     C[n_q_th:,th_mask]=1e4\n            #     C[:n_q_th,~th_mask]=1e4\n            indices.append(linear_sum_assignment(C))\n\n        return [\n            (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))\n            for i, j in indices\n        ]\n\n    @torch.no_grad()\n    def forward(self, outputs, targets, cost=[\"cls\", \"box\", \"mask\"], mode='default', extra={}):\n        \"\"\"Performs the matching\n\n        Params:\n            outputs: This is a dict that contains at least these entries:\n                 \"pred_logits\": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits\n                 \"pred_masks\": Tensor of dim [batch_size, num_queries, H_pred, W_pred] with the predicted masks\n\n            targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:\n                 \"labels\": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth\n                           objects in the target) containing the class labels\n                 \"masks\": Tensor of dim [num_target_boxes, H_gt, W_gt] containing the target masks\n\n        Returns:\n            A list of 
size batch_size, containing tuples of (index_i, index_j) where:\n                - index_i is the indices of the selected predictions (in order)\n                - index_j is the indices of the corresponding selected targets (in order)\n            For each batch element, it holds:\n                len(index_i) = len(index_j) = min(num_queries, num_target_boxes)\n        \"\"\"\n        if mode == 'default':\n            if extra is not None:\n                split_pano = extra.get('split_pano', None)\n            else:\n                split_pano=None\n            return self.memory_efficient_forward(outputs, targets, cost,split_pano=split_pano)\n        else:\n            assert False, \"Mode {} is not supported.\".format(mode)\n\n    def __repr__(self, _repr_indent=4):\n        head = \"Matcher \" + self.__class__.__name__\n        body = [\n            \"cost_class: {}\".format(self.cost_class),\n            \"cost_mask: {}\".format(self.cost_mask),\n            \"cost_dice: {}\".format(self.cost_dice),\n        ]\n        lines = [head] + [\" \" * _repr_indent + line for line in body]\n        return \"\\n\".join(lines)\n"
  },
  {
    "path": "llava/model/openseed/modules/point_features.py",
    "content": "# Copyright (c) Facebook, Inc. and its affiliates.\nimport torch\nfrom torch.nn import functional as F\n\nfrom detectron2.layers import cat, shapes_to_tensor\nfrom detectron2.structures import BitMasks, Boxes\n\n# from ..layers import cat, shapes_to_tensor\n# from ..structures import BitMasks, Boxes\n\n\"\"\"\nShape shorthand in this module:\n\n    N: minibatch dimension size, i.e. the number of RoIs for instance segmenation or the\n        number of images for semantic segmenation.\n    R: number of ROIs, combined over all images, in the minibatch\n    P: number of points\n\"\"\"\n\n\ndef point_sample(input, point_coords, **kwargs):\n    \"\"\"\n    A wrapper around :function:`torch.nn.functional.grid_sample` to support 3D point_coords tensors.\n    Unlike :function:`torch.nn.functional.grid_sample` it assumes `point_coords` to lie inside\n    [0, 1] x [0, 1] square.\n\n    Args:\n        input (Tensor): A tensor of shape (N, C, H, W) that contains features map on a H x W grid.\n        point_coords (Tensor): A tensor of shape (N, P, 2) or (N, Hgrid, Wgrid, 2) that contains\n        [0, 1] x [0, 1] normalized point coordinates.\n\n    Returns:\n        output (Tensor): A tensor of shape (N, C, P) or (N, C, Hgrid, Wgrid) that contains\n            features for points in `point_coords`. 
The features are obtained via bilinear\n            interplation from `input` the same way as :function:`torch.nn.functional.grid_sample`.\n    \"\"\"\n    add_dim = False\n    if point_coords.dim() == 3:\n        add_dim = True\n        point_coords = point_coords.unsqueeze(2)\n    output = F.grid_sample(input, 2.0 * point_coords - 1.0, **kwargs)\n    if add_dim:\n        output = output.squeeze(3)\n    return output\n\n\ndef generate_regular_grid_point_coords(R, side_size, device):\n    \"\"\"\n    Generate regular square grid of points in [0, 1] x [0, 1] coordinate space.\n\n    Args:\n        R (int): The number of grids to sample, one for each region.\n        side_size (int): The side size of the regular grid.\n        device (torch.device): Desired device of returned tensor.\n\n    Returns:\n        (Tensor): A tensor of shape (R, side_size^2, 2) that contains coordinates\n            for the regular grids.\n    \"\"\"\n    aff = torch.tensor([[[0.5, 0, 0.5], [0, 0.5, 0.5]]], device=device)\n    r = F.affine_grid(aff, torch.Size((1, 1, side_size, side_size)), align_corners=False)\n    return r.view(1, -1, 2).expand(R, -1, -1)\n\n\ndef get_uncertain_point_coords_with_randomness(\n    coarse_logits, uncertainty_func, num_points, oversample_ratio, importance_sample_ratio\n):\n    \"\"\"\n    Sample points in [0, 1] x [0, 1] coordinate space based on their uncertainty. 
The unceratinties\n        are calculated for each point using 'uncertainty_func' function that takes point's logit\n        prediction as input.\n    See PointRend paper for details.\n\n    Args:\n        coarse_logits (Tensor): A tensor of shape (N, C, Hmask, Wmask) or (N, 1, Hmask, Wmask) for\n            class-specific or class-agnostic prediction.\n        uncertainty_func: A function that takes a Tensor of shape (N, C, P) or (N, 1, P) that\n            contains logit predictions for P points and returns their uncertainties as a Tensor of\n            shape (N, 1, P).\n        num_points (int): The number of points P to sample.\n        oversample_ratio (int): Oversampling parameter.\n        importance_sample_ratio (float): Ratio of points that are sampled via importnace sampling.\n\n    Returns:\n        point_coords (Tensor): A tensor of shape (N, P, 2) that contains the coordinates of P\n            sampled points.\n    \"\"\"\n    assert oversample_ratio >= 1\n    assert importance_sample_ratio <= 1 and importance_sample_ratio >= 0\n    num_boxes = coarse_logits.shape[0]\n    num_sampled = int(num_points * oversample_ratio)\n    point_coords = torch.rand(num_boxes, num_sampled, 2, device=coarse_logits.device, dtype=coarse_logits.dtype)\n    point_logits = point_sample(coarse_logits, point_coords, align_corners=False)\n    # It is crucial to calculate uncertainty based on the sampled prediction value for the points.\n    # Calculating uncertainties of the coarse predictions first and sampling them for points leads\n    # to incorrect results.\n    # To illustrate this: assume uncertainty_func(logits)=-abs(logits), a sampled point between\n    # two coarse predictions with -1 and 1 logits has 0 logits, and therefore 0 uncertainty value.\n    # However, if we calculate uncertainties for the coarse predictions first,\n    # both will have -1 uncertainty, and the sampled point will get -1 uncertainty.\n    point_uncertainties = uncertainty_func(point_logits)\n 
   num_uncertain_points = int(importance_sample_ratio * num_points)\n    num_random_points = num_points - num_uncertain_points\n    idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]\n    shift = num_sampled * torch.arange(num_boxes, dtype=torch.long, device=coarse_logits.device)\n    idx += shift[:, None]\n    point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view(\n        num_boxes, num_uncertain_points, 2\n    )\n    if num_random_points > 0:\n        point_coords = cat(\n            [\n                point_coords,\n                torch.rand(num_boxes, num_random_points, 2, device=coarse_logits.device),\n            ],\n            dim=1,\n        )\n    return point_coords\n\n\ndef get_uncertain_point_coords_on_grid(uncertainty_map, num_points):\n    \"\"\"\n    Find `num_points` most uncertain points from `uncertainty_map` grid.\n\n    Args:\n        uncertainty_map (Tensor): A tensor of shape (N, 1, H, W) that contains uncertainty\n            values for a set of points on a regular H x W grid.\n        num_points (int): The number of points P to select.\n\n    Returns:\n        point_indices (Tensor): A tensor of shape (N, P) that contains indices from\n            [0, H x W) of the most uncertain points.\n        point_coords (Tensor): A tensor of shape (N, P, 2) that contains [0, 1] x [0, 1] normalized\n            coordinates of the most uncertain points from the H x W grid.\n    \"\"\"\n    R, _, H, W = uncertainty_map.shape\n    h_step = 1.0 / float(H)\n    w_step = 1.0 / float(W)\n\n    num_points = min(H * W, num_points)\n    point_indices = torch.topk(uncertainty_map.view(R, H * W), k=num_points, dim=1)[1]\n    point_coords = torch.zeros(R, num_points, 2, dtype=torch.float, device=uncertainty_map.device)\n    point_coords[:, :, 0] = w_step / 2.0 + (point_indices % W).to(torch.float) * w_step\n    point_coords[:, :, 1] = h_step / 2.0 + (point_indices // W).to(torch.float) * h_step\n    return point_indices, 
point_coords\n\n\ndef point_sample_fine_grained_features(features_list, feature_scales, boxes, point_coords):\n    \"\"\"\n    Get features from feature maps in `features_list` that correspond to specific point coordinates\n        inside each bounding box from `boxes`.\n\n    Args:\n        features_list (list[Tensor]): A list of feature map tensors to get features from.\n        feature_scales (list[float]): A list of scales for tensors in `features_list`.\n        boxes (list[Boxes]): A list of I Boxes  objects that contain R_1 + ... + R_I = R boxes all\n            together.\n        point_coords (Tensor): A tensor of shape (R, P, 2) that contains\n            [0, 1] x [0, 1] box-normalized coordinates of the P sampled points.\n\n    Returns:\n        point_features (Tensor): A tensor of shape (R, C, P) that contains features sampled\n            from all features maps in feature_list for P sampled points for all R boxes in `boxes`.\n        point_coords_wrt_image (Tensor): A tensor of shape (R, P, 2) that contains image-level\n            coordinates of P points.\n    \"\"\"\n    cat_boxes = Boxes.cat(boxes)\n    num_boxes = [b.tensor.size(0) for b in boxes]\n\n    point_coords_wrt_image = get_point_coords_wrt_image(cat_boxes.tensor, point_coords)\n    split_point_coords_wrt_image = torch.split(point_coords_wrt_image, num_boxes)\n\n    point_features = []\n    for idx_img, point_coords_wrt_image_per_image in enumerate(split_point_coords_wrt_image):\n        point_features_per_image = []\n        for idx_feature, feature_map in enumerate(features_list):\n            h, w = feature_map.shape[-2:]\n            scale = shapes_to_tensor([w, h]) / feature_scales[idx_feature]\n            point_coords_scaled = point_coords_wrt_image_per_image / scale.to(feature_map.device)\n            point_features_per_image.append(\n                point_sample(\n                    feature_map[idx_img].unsqueeze(0),\n                    point_coords_scaled.unsqueeze(0),\n         
           align_corners=False,\n                )\n                .squeeze(0)\n                .transpose(1, 0)\n            )\n        point_features.append(cat(point_features_per_image, dim=1))\n\n    return cat(point_features, dim=0), point_coords_wrt_image\n\n\ndef get_point_coords_wrt_image(boxes_coords, point_coords):\n    \"\"\"\n    Convert box-normalized [0, 1] x [0, 1] point cooordinates to image-level coordinates.\n\n    Args:\n        boxes_coords (Tensor): A tensor of shape (R, 4) that contains bounding boxes.\n            coordinates.\n        point_coords (Tensor): A tensor of shape (R, P, 2) that contains\n            [0, 1] x [0, 1] box-normalized coordinates of the P sampled points.\n\n    Returns:\n        point_coords_wrt_image (Tensor): A tensor of shape (R, P, 2) that contains\n            image-normalized coordinates of P sampled points.\n    \"\"\"\n    with torch.no_grad():\n        point_coords_wrt_image = point_coords.clone()\n        point_coords_wrt_image[:, :, 0] = point_coords_wrt_image[:, :, 0] * (\n            boxes_coords[:, None, 2] - boxes_coords[:, None, 0]\n        )\n        point_coords_wrt_image[:, :, 1] = point_coords_wrt_image[:, :, 1] * (\n            boxes_coords[:, None, 3] - boxes_coords[:, None, 1]\n        )\n        point_coords_wrt_image[:, :, 0] += boxes_coords[:, None, 0]\n        point_coords_wrt_image[:, :, 1] += boxes_coords[:, None, 1]\n    return point_coords_wrt_image\n\n\ndef sample_point_labels(instances, point_coords):\n    \"\"\"\n    Sample point labels from ground truth mask given point_coords.\n\n    Args:\n        instances (list[Instances]): A list of N Instances, where N is the number of images\n            in the batch. So, i_th elememt of the list contains R_i objects and R_1 + ... + R_N is\n            equal to R. 
The ground-truth gt_masks in each instance will be used to compute labels.\n        points_coords (Tensor): A tensor of shape (R, P, 2), where R is the total number of\n            instances and P is the number of points for each instance. The coordinates are in\n            the absolute image pixel coordinate space, i.e. [0, H] x [0, W].\n\n    Returns:\n        Tensor: A tensor of shape (R, P) that contains the labels of P sampled points.\n    \"\"\"\n    with torch.no_grad():\n        gt_mask_logits = []\n        point_coords_splits = torch.split(\n            point_coords, [len(instances_per_image) for instances_per_image in instances]\n        )\n        for i, instances_per_image in enumerate(instances):\n            if len(instances_per_image) == 0:\n                continue\n            assert isinstance(\n                instances_per_image.gt_masks, BitMasks\n            ), \"Point head works with GT in 'bitmask' format. Set INPUT.MASK_FORMAT to 'bitmask'.\"\n\n            gt_bit_masks = instances_per_image.gt_masks.tensor\n            h, w = instances_per_image.gt_masks.image_size\n            scale = torch.tensor([w, h], dtype=torch.float, device=gt_bit_masks.device)\n            points_coord_grid_sample_format = point_coords_splits[i] / scale\n            gt_mask_logits.append(\n                point_sample(\n                    gt_bit_masks.to(torch.float32).unsqueeze(1),\n                    points_coord_grid_sample_format,\n                    align_corners=False,\n                ).squeeze(1)\n            )\n\n    point_labels = cat(gt_mask_logits)\n    return point_labels\n"
  },
  {
    "path": "llava/model/openseed/modules/position_encoding.py",
    "content": "# Copyright (c) Facebook, Inc. and its affiliates.\n# # Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/position_encoding.py\n\"\"\"\nVarious positional encodings for the transformer.\n\"\"\"\nimport math\n\nimport torch\nfrom torch import nn\n\n\nclass PositionEmbeddingSine(nn.Module):\n    \"\"\"\n    This is a more standard version of the position embedding, very similar to the one\n    used by the Attention is all you need paper, generalized to work on images.\n    \"\"\"\n\n    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):\n        super().__init__()\n        self.num_pos_feats = num_pos_feats\n        self.temperature = temperature\n        self.normalize = normalize\n        if scale is not None and normalize is False:\n            raise ValueError(\"normalize should be True if scale is passed\")\n        if scale is None:\n            scale = 2 * math.pi\n        self.scale = scale\n\n    def forward(self, x, mask=None):\n        if mask is None:\n            mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)\n        not_mask = ~mask\n        y_embed = not_mask.cumsum(1, dtype=x.dtype)\n        x_embed = not_mask.cumsum(2, dtype=x.dtype)\n        if self.normalize:\n            eps = 1e-6\n            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale\n            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale\n\n        dim_t = torch.arange(self.num_pos_feats, dtype=x.dtype, device=x.device)\n        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)\n\n        pos_x = x_embed[:, :, :, None] / dim_t\n        pos_y = y_embed[:, :, :, None] / dim_t\n        pos_x = torch.stack(\n            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4\n        ).flatten(3)\n        pos_y = torch.stack(\n            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4\n     
   ).flatten(3)\n        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)\n        return pos\n    \n    def __repr__(self, _repr_indent=4):\n        head = \"Positional encoding \" + self.__class__.__name__\n        body = [\n            \"num_pos_feats: {}\".format(self.num_pos_feats),\n            \"temperature: {}\".format(self.temperature),\n            \"normalize: {}\".format(self.normalize),\n            \"scale: {}\".format(self.scale),\n        ]\n        # _repr_indent = 4\n        lines = [head] + [\" \" * _repr_indent + line for line in body]\n        return \"\\n\".join(lines)\n"
  },
  {
    "path": "llava/model/openseed/modules/postprocessing.py",
    "content": "# Copyright (c) Facebook, Inc. and its affiliates.\nimport torch\nfrom torch.nn import functional as F\n\nfrom detectron2.structures import Instances, ROIMasks\n\n\n# perhaps should rename to \"resize_instance\"\ndef detector_postprocess(\n    results: Instances, output_height: int, output_width: int, mask_threshold: float = 0.5\n):\n    \"\"\"\n    Resize the output instances.\n    The input images are often resized when entering an object detector.\n    As a result, we often need the outputs of the detector in a different\n    resolution from its inputs.\n\n    This function will resize the raw outputs of an R-CNN detector\n    to produce outputs according to the desired output resolution.\n\n    Args:\n        results (Instances): the raw outputs from the detector.\n            `results.image_size` contains the input image resolution the detector sees.\n            This object might be modified in-place.\n        output_height, output_width: the desired output resolution.\n\n    Returns:\n        Instances: the resized output from the model, based on the output resolution\n    \"\"\"\n    if isinstance(output_width, torch.Tensor):\n        # This shape might (but not necessarily) be tensors during tracing.\n        # Converts integer tensors to float temporaries to ensure true\n        # division is performed when computing scale_x and scale_y.\n        output_width_tmp = output_width.float()\n        output_height_tmp = output_height.float()\n        new_size = torch.stack([output_height, output_width])\n    else:\n        new_size = (output_height, output_width)\n        output_width_tmp = output_width\n        output_height_tmp = output_height\n\n    scale_x, scale_y = (\n        output_width_tmp / results.image_size[1],\n        output_height_tmp / results.image_size[0],\n    )\n    results = Instances(new_size, **results.get_fields())\n\n    if results.has(\"pred_boxes\"):\n        output_boxes = results.pred_boxes\n    elif 
results.has(\"proposal_boxes\"):\n        output_boxes = results.proposal_boxes\n    else:\n        output_boxes = None\n    assert output_boxes is not None, \"Predictions must contain boxes!\"\n\n    output_boxes.scale(scale_x, scale_y)\n    output_boxes.clip(results.image_size)\n\n    results = results[output_boxes.nonempty()]\n\n    if results.has(\"pred_masks\"):\n        if isinstance(results.pred_masks, ROIMasks):\n            roi_masks = results.pred_masks\n        else:\n            # pred_masks is a tensor of shape (N, 1, M, M)\n            roi_masks = ROIMasks(results.pred_masks[:, 0, :, :])\n        results.pred_masks = roi_masks.to_bitmasks(\n            results.pred_boxes, output_height, output_width, mask_threshold\n        ).tensor  # TODO return ROIMasks/BitMask object in the future\n\n    if results.has(\"pred_keypoints\"):\n        results.pred_keypoints[:, :, 0] *= scale_x\n        results.pred_keypoints[:, :, 1] *= scale_y\n\n    return results\n\ndef bbox_postprocess(result, input_size, img_size, output_height, output_width):\n    \"\"\"\n    result: [xc,yc,w,h] range [0,1] to [x1,y1,x2,y2] range [0,w], [0,h]\n    \"\"\"\n    if result is None:\n        return None\n    \n    scale = torch.tensor([input_size[1], input_size[0], input_size[1], input_size[0]])[None,:].to(result.device)\n    result = result.sigmoid() * scale\n    x1,y1,x2,y2 = result[:,0] - result[:,2]/2, result[:,1] - result[:,3]/2, result[:,0] + result[:,2]/2, result[:,1] + result[:,3]/2\n    h,w = img_size\n\n    x1 = x1.clamp(min=0, max=w)\n    y1 = y1.clamp(min=0, max=h)\n    x2 = x2.clamp(min=0, max=w)\n    y2 = y2.clamp(min=0, max=h)\n\n    box = torch.stack([x1,y1,x2,y2]).permute(1,0)\n    scale = torch.tensor([output_width/w, output_height/h, output_width/w, output_height/h])[None,:].to(result.device)\n    box = box*scale\n    return box\n\ndef sem_seg_postprocess(result, img_size, output_height, output_width):\n    \"\"\"\n    Return semantic segmentation predictions in 
the original resolution.\n\n    The input images are often resized when entering semantic segmentor. Moreover, in same\n    cases, they also padded inside segmentor to be divisible by maximum network stride.\n    As a result, we often need the predictions of the segmentor in a different\n    resolution from its inputs.\n\n    Args:\n        result (Tensor): semantic segmentation prediction logits. A tensor of shape (C, H, W),\n            where C is the number of classes, and H, W are the height and width of the prediction.\n        img_size (tuple): image size that segmentor is taking as input.\n        output_height, output_width: the desired output resolution.\n\n    Returns:\n        semantic segmentation prediction (Tensor): A tensor of the shape\n            (C, output_height, output_width) that contains per-pixel soft predictions.\n    \"\"\"\n    result = result[:, : img_size[0], : img_size[1]].expand(1, -1, -1, -1)\n    result = F.interpolate(\n        result, size=(output_height, output_width), mode=\"bicubic\", align_corners=False, antialias=True\n    )[0]\n    return result\n"
  },
  {
    "path": "llava/model/openseed/utils/__init__.py",
    "content": "from .config import *\nfrom .misc import *"
  },
  {
    "path": "llava/model/openseed/utils/box_ops.py",
    "content": "# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved\n\"\"\"\nUtilities for bounding box manipulation and GIoU.\n\"\"\"\nimport torch\nfrom torchvision.ops.boxes import box_area\n\n\ndef box_cxcywh_to_xyxy(x):\n    x_c, y_c, w, h = x.unbind(-1)\n    b = [(x_c - 0.5 * w), (y_c - 0.5 * h),\n         (x_c + 0.5 * w), (y_c + 0.5 * h)]\n    return torch.stack(b, dim=-1)\n\n\ndef box_xyxy_to_cxcywh(x):\n    x0, y0, x1, y1 = x.unbind(-1)\n    b = [(x0 + x1) / 2, (y0 + y1) / 2,\n         (x1 - x0), (y1 - y0)]\n    return torch.stack(b, dim=-1)\n\ndef box_xywh_to_xyxy(x):\n    x0, y0, x1, y1 = x.unbind(-1)\n    b = [x0, y0, (x0 + x1), (y0 + y1)]\n    return torch.stack(b, dim=-1)\n\n\n# modified from torchvision to also return the union\ndef box_iou(boxes1, boxes2):\n    area1 = box_area(boxes1)\n    area2 = box_area(boxes2)\n\n    lt = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]\n    rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]\n\n    wh = (rb - lt).clamp(min=0)  # [N,M,2]\n    inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]\n\n    union = area1[:, None] + area2 - inter\n\n    iou = inter / (union+1e-6)\n    return iou, union\n\n\ndef generalized_box_iou(boxes1, boxes2):\n    \"\"\"\n    Generalized IoU from https://giou.stanford.edu/\n\n    The boxes should be in [x0, y0, x1, y1] format\n\n    Returns a [N, M] pairwise matrix, where N = len(boxes1)\n    and M = len(boxes2)\n    \"\"\"\n    # degenerate boxes gives inf / nan results\n    # so do an early check\n    assert (boxes1[:, 2:] >= boxes1[:, :2]).all()\n    assert (boxes2[:, 2:] >= boxes2[:, :2]).all()\n    iou, union = box_iou(boxes1, boxes2)\n\n    lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])\n    rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])\n\n    wh = (rb - lt).clamp(min=0)  # [N,M,2]\n    area = wh[:, :, 0] * wh[:, :, 1]\n\n    return iou - (area - union) / (area+1e-6)\n\n\ndef masks_to_boxes(masks):\n    \"\"\"Compute the bounding boxes 
around the provided masks\n\n    The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.\n\n    Returns a [N, 4] tensors, with the boxes in xyxy format\n    \"\"\"\n    if masks.numel() == 0:\n        return torch.zeros((0, 4), device=masks.device)\n\n    h, w = masks.shape[-2:]\n\n    y = torch.arange(0, h, dtype=torch.float)\n    x = torch.arange(0, w, dtype=torch.float)\n    y, x = torch.meshgrid(y, x)\n\n    x_mask = (masks * x.unsqueeze(0))\n    x_max = x_mask.flatten(1).max(-1)[0]\n    x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]\n\n    y_mask = (masks * y.unsqueeze(0))\n    y_max = y_mask.flatten(1).max(-1)[0]\n    y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]\n\n    return torch.stack([x_min, y_min, x_max, y_max], 1)"
  },
  {
    "path": "llava/model/openseed/utils/config.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (c) Facebook, Inc. and its affiliates.\n\nimport functools\nimport inspect\n\ndef configurable(init_func=None, *, from_config=None):\n    \"\"\"\n    Decorate a function or a class's __init__ method so that it can be called\n    with a :class:`CfgNode` object using a :func:`from_config` function that translates\n    :class:`CfgNode` to arguments.\n\n    Examples:\n    ::\n        # Usage 1: Decorator on __init__:\n        class A:\n            @configurable\n            def __init__(self, a, b=2, c=3):\n                pass\n\n            @classmethod\n            def from_config(cls, cfg):   # 'cfg' must be the first argument\n                # Returns kwargs to be passed to __init__\n                return {\"a\": cfg.A, \"b\": cfg.B}\n\n        a1 = A(a=1, b=2)  # regular construction\n        a2 = A(cfg)       # construct with a cfg\n        a3 = A(cfg, b=3, c=4)  # construct with extra overwrite\n\n        # Usage 2: Decorator on any function. Needs an extra from_config argument:\n        @configurable(from_config=lambda cfg: {\"a: cfg.A, \"b\": cfg.B})\n        def a_func(a, b=2, c=3):\n            pass\n\n        a1 = a_func(a=1, b=2)  # regular call\n        a2 = a_func(cfg)       # call with a cfg\n        a3 = a_func(cfg, b=3, c=4)  # call with extra overwrite\n\n    Args:\n        init_func (callable): a class's ``__init__`` method in usage 1. The\n            class must have a ``from_config`` classmethod which takes `cfg` as\n            the first argument.\n        from_config (callable): the from_config function in usage 2. It must take `cfg`\n            as its first argument.\n    \"\"\"\n\n    if init_func is not None:\n        assert (\n            inspect.isfunction(init_func)\n            and from_config is None\n            and init_func.__name__ == \"__init__\"\n        ), \"Incorrect use of @configurable. 
Check API documentation for examples.\"\n\n        @functools.wraps(init_func)\n        def wrapped(self, *args, **kwargs):\n            try:\n                from_config_func = type(self).from_config\n            except AttributeError as e:\n                raise AttributeError(\n                    \"Class with @configurable must have a 'from_config' classmethod.\"\n                ) from e\n            if not inspect.ismethod(from_config_func):\n                raise TypeError(\"Class with @configurable must have a 'from_config' classmethod.\")\n\n            if _called_with_cfg(*args, **kwargs):\n                explicit_args = _get_args_from_config(from_config_func, *args, **kwargs)\n                init_func(self, **explicit_args)\n            else:\n                init_func(self, *args, **kwargs)\n\n        return wrapped\n\n    else:\n        if from_config is None:\n            return configurable  # @configurable() is made equivalent to @configurable\n        assert inspect.isfunction(\n            from_config\n        ), \"from_config argument of configurable must be a function!\"\n\n        def wrapper(orig_func):\n            @functools.wraps(orig_func)\n            def wrapped(*args, **kwargs):\n                if _called_with_cfg(*args, **kwargs):\n                    explicit_args = _get_args_from_config(from_config, *args, **kwargs)\n                    return orig_func(**explicit_args)\n                else:\n                    return orig_func(*args, **kwargs)\n\n            wrapped.from_config = from_config\n            return wrapped\n\n        return wrapper\n\ndef _called_with_cfg(*args, **kwargs):\n    \"\"\"\n    Returns:\n        bool: whether the arguments contain CfgNode and should be considered\n            forwarded to from_config.\n    \"\"\"\n    from omegaconf import DictConfig, OmegaConf, ListConfig\n    # from detectron2.config import LazyConfig\n\n    if len(args) and (isinstance(args[0], (dict)) or (isinstance(args[0], 
(DictConfig)))):\n        return True\n    if isinstance(kwargs.pop(\"cfg\", None), (dict)):\n        return True\n    # `from_config`'s first argument is forced to be \"cfg\".\n    # So the above check covers all cases.\n    return False\n\ndef _get_args_from_config(from_config_func, *args, **kwargs):\n    \"\"\"\n    Use `from_config` to obtain explicit arguments.\n\n    Returns:\n        dict: arguments to be used for cls.__init__\n    \"\"\"\n    signature = inspect.signature(from_config_func)\n    if list(signature.parameters.keys())[0] != \"cfg\":\n        if inspect.isfunction(from_config_func):\n            name = from_config_func.__name__\n        else:\n            name = f\"{from_config_func.__self__}.from_config\"\n        raise TypeError(f\"{name} must take 'cfg' as the first argument!\")\n    support_var_arg = any(\n        param.kind in [param.VAR_POSITIONAL, param.VAR_KEYWORD]\n        for param in signature.parameters.values()\n    )\n    if support_var_arg:  # forward all arguments to from_config, if from_config accepts them\n        ret = from_config_func(*args, **kwargs)\n    else:\n        # forward supported arguments to from_config\n        supported_arg_names = set(signature.parameters.keys())\n        extra_kwargs = {}\n        for name in list(kwargs.keys()):\n            if name not in supported_arg_names:\n                extra_kwargs[name] = kwargs.pop(name)\n        ret = from_config_func(*args, **kwargs)\n        # forward the other arguments to __init__\n        ret.update(extra_kwargs)\n    return ret"
  },
  {
    "path": "llava/model/openseed/utils/misc.py",
    "content": "# Copyright (c) Facebook, Inc. and its affiliates.\n# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/util/misc.py\n\n# --------------------------------------------------------\n# X-Decoder -- Generalized Decoding for Pixel, Image, and Language\n# Copyright (c) 2022 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# Modified by Xueyan Zou (xueyan@cs.wisc.edu)\n# --------------------------------------------------------\n\n\"\"\"\nMisc functions, including distributed helpers.\n\nMostly copy-paste from torchvision references.\n\"\"\"\nfrom typing import List, Optional\n\nimport torch\nimport torch.distributed as dist\nimport torchvision\nfrom torch import Tensor\n\n# from utils.constants import *\n\n\ndef _max_by_axis(the_list):\n    # type: (List[List[int]]) -> List[int]\n    maxes = the_list[0]\n    for sublist in the_list[1:]:\n        for index, item in enumerate(sublist):\n            maxes[index] = max(maxes[index], item)\n    return maxes\n\nclass NestedTensor(object):\n    def __init__(self, tensors, mask: Optional[Tensor]):\n        self.tensors = tensors\n        self.mask = mask\n\n    def to(self, device):\n        # type: (Device) -> NestedTensor # noqa\n        cast_tensor = self.tensors.to(device)\n        mask = self.mask\n        if mask is not None:\n            assert mask is not None\n            cast_mask = mask.to(device)\n        else:\n            cast_mask = None\n        return NestedTensor(cast_tensor, cast_mask)\n\n    def decompose(self):\n        return self.tensors, self.mask\n\n    def __repr__(self):\n        return str(self.tensors)\n\ndef nested_tensor_from_tensor_list(tensor_list: List[Tensor]):\n    # TODO make this more general\n    if tensor_list[0].ndim == 3:\n        if torchvision._is_tracing():\n            # nested_tensor_from_tensor_list() does not export well to ONNX\n            # call _onnx_nested_tensor_from_tensor_list() instead\n            
return _onnx_nested_tensor_from_tensor_list(tensor_list)\n\n        # TODO make it support different-sized images\n        max_size = _max_by_axis([list(img.shape) for img in tensor_list])\n        # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))\n        batch_shape = [len(tensor_list)] + max_size\n        b, c, h, w = batch_shape\n        dtype = tensor_list[0].dtype\n        device = tensor_list[0].device\n        tensor = torch.zeros(batch_shape, dtype=dtype, device=device)\n        mask = torch.ones((b, h, w), dtype=torch.bool, device=device)\n        for img, pad_img, m in zip(tensor_list, tensor, mask):\n            pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)\n            m[: img.shape[1], : img.shape[2]] = False\n    elif tensor_list[0].ndim == 2:\n        if torchvision._is_tracing():\n            # nested_tensor_from_tensor_list() does not export well to ONNX\n            # call _onnx_nested_tensor_from_tensor_list() instead\n            return _onnx_nested_tensor_from_tensor_list(tensor_list)\n\n        # TODO make it support different-sized images\n        max_size = _max_by_axis([list(txt.shape) for txt in tensor_list])\n        # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))\n        batch_shape = [len(tensor_list)] + max_size\n        b, c, l = batch_shape\n        dtype = tensor_list[0].dtype\n        device = tensor_list[0].device\n        tensor = torch.zeros(batch_shape, dtype=dtype, device=device)\n        mask = torch.ones((b, l), dtype=torch.bool, device=device)\n        for txt, pad_txt, m in zip(tensor_list, tensor, mask):\n            pad_txt[: txt.shape[0], : txt.shape[1]] = txt\n            m[: txt.shape[1]] = False\n    else:\n        raise ValueError(\"not supported\")\n    return NestedTensor(tensor, mask)\n\ndef _collate_and_pad_divisibility(tensor_list: list, div=32):\n    max_size = []\n    for i in range(tensor_list[0].dim()):\n        max_size_i = 
torch.max(\n            torch.tensor([img.shape[i] for img in tensor_list]).to(torch.float32)\n        ).to(torch.int64)\n        max_size.append(max_size_i)\n    max_size = tuple(max_size)\n\n    c,h,w = max_size\n    pad_h = (div - h % div) if h % div != 0 else 0\n    pad_w = (div - w % div) if w % div != 0 else 0\n    max_size = (c,h+pad_h,w+pad_w)\n    \n    # work around for\n    # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)\n    # m[: img.shape[1], :img.shape[2]] = False\n    # which is not yet supported in onnx\n    padded_imgs = []\n    padded_masks = []\n    for img in tensor_list:\n        padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]\n        padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))\n        padded_imgs.append(padded_img)\n\n        m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)\n        padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), \"constant\", 1)\n        padded_masks.append(padded_mask.to(torch.bool))\n    \n    return padded_imgs\n\n# _onnx_nested_tensor_from_tensor_list() is an implementation of\n# nested_tensor_from_tensor_list() that is supported by ONNX tracing.\n@torch.jit.unused\ndef _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:\n    max_size = []\n    for i in range(tensor_list[0].dim()):\n        max_size_i = torch.max(\n            torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)\n        ).to(torch.int64)\n        max_size.append(max_size_i)\n    max_size = tuple(max_size)\n\n    # work around for\n    # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)\n    # m[: img.shape[1], :img.shape[2]] = False\n    # which is not yet supported in onnx\n    padded_imgs = []\n    padded_masks = []\n    for img in tensor_list:\n        padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]\n        padded_img = 
torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))\n        padded_imgs.append(padded_img)\n\n        m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)\n        padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), \"constant\", 1)\n        padded_masks.append(padded_mask.to(torch.bool))\n\n    tensor = torch.stack(padded_imgs)\n    mask = torch.stack(padded_masks)\n\n    return NestedTensor(tensor, mask=mask)\n\ndef is_dist_avail_and_initialized():\n    if not dist.is_available():\n        return False\n    if not dist.is_initialized():\n        return False\n    return True\n\n\n# def get_class_names(name, background=True):\n#     if name is None:\n#         return None\n#     if 'refcoco' in name:\n#         class_names = [\"noun\"]\n#     elif 'coco' in name and 'pan' not in name:\n#         class_names = COCO_INSTANCE_CLASSES + [\"background\"]\n#     elif 'coco' in name:\n#         class_names = COCO_PANOPTIC_CLASSES + [\"background\"]\n#     elif 'ade20k_full' in name:\n#         class_names = ADE20K_847 + [\"background\"]\n#     elif 'ade' in name:\n#         class_names = ADE_PANOPTIC_CLASSES + [\"background\"]\n#     elif 'voc' in name:\n#         class_names = PASCAL_CLASSES + [\"background\"]\n#     elif 'vlp' in name:\n#         class_names = [\"noun\"]\n#     elif 'tsv' in name:\n#         class_names = [\"noun\"]\n#     elif 'phrasecut' in name:\n#         class_names = [\"noun\"]\n#     elif 'openimage' in name:\n#         class_names = [\"noun\"]\n#     elif 'imagenet' in name:\n#         class_names = IMAGENET_CLASSES\n#     elif 'context_459' in name:\n#         class_names = PASCAL_CONTEXT_459 + [\"background\"]\n#     elif 'context_59' in name:\n#         class_names = PASCAL_CONTEXT_59 + [\"background\"]\n#     elif 'context_33' in name:\n#         class_names = PASCAL_CONTEXT_33\n#     elif 'sunrgbd_37' in name:\n#         class_names = SUN_RGBD_37\n#     elif 'scannet_41' in name:\n# 
        class_names = SCAN_40\n#     elif 'scannet_38' in name:\n#         class_names = SCAN_37\n#     elif 'scannet_21' in name:\n#         class_names = SCAN_20\n#     elif 'object365' in name:\n#         class_names = OBJECT365\n#     elif 'lvis' in name:\n#         class_names = LVIS_CATEGORIES\n#     elif 'seginw' in name:\n#         class_names = SEGINW_CATEGORIES[name.replace('_train', '').replace('_val', '')] + [\"background\"]\n#     elif name == 'cityscapes_fine_sem_seg_val':\n#         class_names = CITYSCAPES\n#     elif name == 'cityscapes_fine_instance_seg_val':\n#         class_names = CITYSCAPES_THING + [\"background\"]\n#     elif name in ['cityscapes_fine_panoptic_val', 'cityscapes_fine_panoptic_train']:\n#         class_names = CITYSCAPES + [\"background\"]\n#     elif name == 'bdd10k_val_sem_seg':\n#         class_names = BDD_SEM\n#     elif name == 'bdd10k_40_panoptic_val':\n#         class_names = BDD_PANO\n#     else:\n#         assert False, \"text dataset name {} is not defined\".format(name)\n#\n#     if background == False and \"background\" in class_names:\n#         class_names.pop(class_names.index(\"background\"))\n#\n#     return class_names\n\n# TODO: add background to \n# def get_class_names(name):\n#     if name is None:\n#         return None\n#     elif 'refcoco' in name:\n#         return [\"background\"]\n#     elif 'coco' in name:\n#         return COCO_PANOPTIC_CLASSES + [\"background\"]\n#     elif 'ade20k_full' in name:\n#         return ADE20K_847 + [\"background\"]\n#     elif 'ade' in name:\n#         return ADE_PANOPTIC_CLASSES + [\"background\"]\n#     elif 'scannet_41' in name:\n#         return SCAN_40\n#     elif 'scannet_21' in name:\n#         return SCAN_20\n#     elif 'sun' in name:\n#         return SUN_RGBD_37\n#     elif name == 'cityscapes_fine_sem_seg_val':\n#         return CITYSCAPES + [\"background\"]\n#     elif name == 'cityscapes_fine_instance_seg_val':\n#         return CITYSCAPES_THING + 
[\"background\"]\n#     elif name in ['cityscapes_fine_panoptic_val']:\n#         return CITYSCAPES + [\"background\"]\n#     elif name == 'bdd10k_val_sem_seg':\n#         return BDD_SEM + [\"background\"]\n#     elif name == 'bdd10k_40_panoptic_val':\n#         return BDD_PANO + [\"background\"]\n#     elif 'vlp' in name:\n#         return [\"background\"]\n#     else:\n#         assert False, \"text dataset name {} is not defined\".format(name)\n"
  },
  {
    "path": "llava/model/semsam/BaseModel.py",
    "content": "import os\nimport logging\n\nimport torch\nimport torch.nn as nn\n\nfrom utils.model import align_and_update_state_dicts\n\nlogger = logging.getLogger(__name__)\n\n\nclass BaseModel(nn.Module):\n    def __init__(self, opt, module: nn.Module):\n        super(BaseModel, self).__init__()\n        self.opt = opt\n        self.model = module\n\n    def forward(self, *inputs, **kwargs):\n        outputs = self.model(*inputs, **kwargs)\n        return outputs\n\n    def save_pretrained(self, save_dir):\n        torch.save(self.model.state_dict(), save_path)\n\n    def from_pretrained(self, load_dir):\n        state_dict = torch.load(load_dir, map_location='cpu')\n        # import pdb;pdb.set_trace()\n        # import pdb;pdb.set_trace()\n        if 'model' in state_dict:\n            state_dict=state_dict['model']\n            state_dict={k[6:]:v for k,v in state_dict.items()}\n\n        # if self.opt['MODEL']['LLAMA'].get('lora_r',0)>0:\n        #     new_sd = dict()\n        #     for k,v in state_dict.items():\n        #         if k.startswith(\"llama.\"):\n        #             if k.startswith(\"llama.base_model.\"):\n        #                 new_sd=state_dict\n        #                 break\n        #             new_sd[k.replace(\"llama.\",\"llama.base_model.model.\")]=v\n        #         else:\n        #             new_sd[k]=v\n        # else:\n        #     new_sd = state_dict\n        new_sd = align_and_update_state_dicts(self.model.state_dict(), state_dict)\n        self.model.load_state_dict(new_sd, strict=False)\n        return self\n"
  },
  {
    "path": "llava/model/semsam/__init__.py",
    "content": "from __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nfrom .architectures import build_model"
  },
  {
    "path": "llava/model/semsam/architectures/__init__.py",
    "content": "from .idino_model_partwhole_all_llm_ref_feats_all_det_pretrainv1 import *\nfrom .build import build_model"
  },
  {
    "path": "llava/model/semsam/architectures/build.py",
    "content": "from .registry import model_entrypoints\nfrom .registry import is_model\n\ndef build_model(config, **kwargs):\n    model_name = config['MODEL']['NAME']\n\n    if not is_model(model_name):\n        raise ValueError(f'Unkown model: {model_name}')\n\n    return model_entrypoints(model_name)(config, **kwargs)"
  },
  {
    "path": "llava/model/semsam/architectures/idino_model_partwhole_all_llm_ref_feats_all_det_pretrainv1.py",
    "content": "# Copyright (c) Facebook, Inc. and its affiliates.\nfrom typing import Tuple\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\nimport transformers\nfrom .registry import register_model\nfrom ..utils import configurable, box_ops\nfrom ..backbone import build_backbone, Backbone\nfrom ..body import build_openseed_head\nfrom ..modules import sem_seg_postprocess, HungarianMatcher\nfrom ..modules import SetCriterionLLM as SetCriterion\nfrom detectron2.structures import Boxes, ImageList, Instances, BitMasks\nfrom detectron2.utils.memory import retry_if_cuda_oom\nfrom detectron2.data import MetadataCatalog\nimport torch.distributed as dist\nimport random\nimport os\nimport torchvision\nfrom PIL import Image\n\n\ndef dice_loss(\n        inputs: torch.Tensor,\n        targets: torch.Tensor,\n        # num_masks,\n):\n    \"\"\"\n    Compute the DICE loss, similar to generalized IOU for masks\n    Args:\n        inputs: A float tensor of arbitrary shape.\n                The predictions for each example.\n        targets: A float tensor with the same shape as inputs. 
Stores the binary\n                 classification label for each element in inputs\n                (0 for the negative class and 1 for the positive class).\n    \"\"\"\n    inputs = inputs.sigmoid()\n    inputs = inputs.flatten(1)\n    numerator = 2 * (inputs * targets).sum(-1)\n    denominator = inputs.sum(-1) + targets.sum(-1)\n    loss = 1 - (numerator + 1) / (denominator + 1)\n    # only match the lowest loss\n    # loss = loss.view(-1, 3)\n    # loss = loss.min(1)[0]\n    return loss.sum()\n    # return loss\n\n\ndef iou_score_loss(inputs, targets):\n    ce_loss = F.mse_loss(inputs, targets, reduction=\"none\")\n    return ce_loss\n\n\ndice_loss_jit = torch.jit.script(\n    dice_loss\n)  # type: torch.jit.ScriptModule\n\n\ndef sigmoid_ce_loss(\n        inputs: torch.Tensor,\n        targets: torch.Tensor,\n        # num_masks,\n):\n    \"\"\"\n    Args:\n        inputs: A float tensor of arbitrary shape.\n                The predictions for each example.\n        targets: A float tensor with the same shape as inputs. Stores the binary\n                 classification label for each element in inputs\n                (0 for the negative class and 1 for the positive class).\n    Returns:\n        Loss tensor\n    \"\"\"\n    loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction=\"none\")\n    loss = loss.mean(1)\n    # loss = loss.view(-1, 3).min(1)[0]\n\n    return loss.sum()\n    # return loss\n\n\nsigmoid_ce_loss_jit = torch.jit.script(\n    sigmoid_ce_loss\n)  # type: torch.jit.ScriptModule\n\n\ndef calculate_uncertainty(logits):\n    \"\"\"\n    We estimate uncerainty as L1 distance between 0.0 and the logit prediction in 'logits' for the\n        foreground class in `classes`.\n    Args:\n        logits (Tensor): A tensor of shape (R, 1, ...) for class-specific or\n            class-agnostic, where R is the total number of predicted masks in all images and C is\n            the number of foreground classes. 
The values are logits.\n    Returns:\n        scores (Tensor): A tensor of shape (R, 1, ...) that contains uncertainty scores with\n            the most uncertain locations having the highest uncertainty score.\n    \"\"\"\n    assert logits.shape[1] == 1\n    gt_class_logits = logits.clone()\n    return -(torch.abs(gt_class_logits))\n\n\ndef sigmoid_focal_loss(inputs, targets, alpha: float = 0.25, gamma: float = 2):\n    \"\"\"\n    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.\n    Args:\n        inputs: A float tensor of arbitrary shape.\n                The predictions for each example.\n        targets: A float tensor with the same shape as inputs. Stores the binary\n                 classification label for each element in inputs\n                (0 for the negative class and 1 for the positive class).\n        alpha: (optional) Weighting factor in range (0,1) to balance\n                positive vs negative examples. Default = -1 (no weighting).\n        gamma: Exponent of the modulating factor (1 - p_t) to\n               balance easy vs hard examples.\n    Returns:\n        Loss tensor\n    \"\"\"\n    prob = inputs.sigmoid()\n    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction=\"none\")\n    p_t = prob * targets + (1 - prob) * (1 - targets)\n    loss = ce_loss * ((1 - p_t) ** gamma)\n\n    if alpha >= 0:\n        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)\n        loss = alpha_t * loss\n\n    return loss.sum()\n\n\nclass SemanticSAM(nn.Module):\n    \"\"\"\n    Main class for mask classification semantic segmentation architectures.\n    \"\"\"\n\n    @configurable\n    def __init__(\n            self,\n            *,\n            backbone: Backbone,\n            sem_seg_head: nn.Module,\n            criterion: nn.Module,\n            num_queries: int,\n            object_mask_threshold: float,\n            overlap_threshold: float,\n            metadata,\n            size_divisibility: 
int,\n            sem_seg_postprocess_before_inference: bool,\n            pixel_mean: Tuple[float],\n            pixel_std: Tuple[float],\n            # inference\n            semantic_on: bool,\n            panoptic_on: bool,\n            instance_on: bool,\n            test_topk_per_image: int,\n            data_loader: str,\n            pano_temp: float,\n            focus_on_box: bool = False,\n            transform_eval: bool = False,\n            semantic_ce_loss: bool = False,\n            train_dataset_name: str,\n            background: bool,\n            coco_on=True,\n            coco_mask_on=True,\n            o365_on=True,\n            ade_on=True,\n            merge_class=False,\n            sam_on: bool = True,\n            pascal_part_on: bool = True,\n            regenerate_point: bool = False,\n            num_mask_tokens: int = 3,\n            interactive_pretrain=False,\n\n            match_loss=True,\n            num_vg=2,\n            vis_out=\"vis/\",\n            coco_old=True,\n            clip_on=False,\n    ):\n        \"\"\"\n        Args:\n            backbone: a backbone module, must follow detectron2's backbone interface\n            sem_seg_head: a module that predicts semantic segmentation from backbone features\n            criterion: a module that defines the loss\n            num_queries: int, number of queries\n            object_mask_threshold: float, threshold to filter query based on classification score\n                for panoptic segmentation inference\n            overlap_threshold: overlap threshold used in general inference for panoptic segmentation\n            metadata: dataset meta, get `thing` and `stuff` category names for panoptic\n                segmentation inference\n            size_divisibility: Some backbones require the input height and width to be divisible by a\n                specific integer. 
We can use this to override such requirement.\n            sem_seg_postprocess_before_inference: whether to resize the prediction back\n                to original input size before semantic segmentation inference or after.\n                For high-resolution dataset like Mapillary, resizing predictions before\n                inference will cause OOM error.\n            pixel_mean, pixel_std: list or tuple with #channels element, representing\n                the per-channel mean and std to be used to normalize the input image\n            semantic_on: bool, whether to output semantic segmentation prediction\n            instance_on: bool, whether to output instance segmentation prediction\n            panoptic_on: bool, whether to output panoptic segmentation prediction\n            test_topk_per_image: int, instance segmentation parameter, keep topk instances per image\n        \"\"\"\n        super().__init__()\n        self.backbone = backbone\n        self.pano_temp = pano_temp\n        self.sem_seg_head = sem_seg_head\n        self.criterion = criterion\n        self.num_queries = num_queries\n        self.overlap_threshold = overlap_threshold\n        self.object_mask_threshold = object_mask_threshold\n        self.metadata = metadata\n        self.num_vg = num_vg\n        if size_divisibility < 0:\n            size_divisibility = self.backbone.size_divisibility\n        self.size_divisibility = size_divisibility\n        self.sem_seg_postprocess_before_inference = sem_seg_postprocess_before_inference\n        self.register_buffer(\"pixel_mean\", torch.Tensor(pixel_mean).view(-1, 1, 1), False)\n        self.register_buffer(\"pixel_std\", torch.Tensor(pixel_std).view(-1, 1, 1), False)\n\n        self.semantic_on = semantic_on\n        self.instance_on = instance_on\n        self.panoptic_on = panoptic_on\n        self.test_topk_per_image = test_topk_per_image\n        self.data_loader = data_loader\n        self.focus_on_box = focus_on_box\n        
self.transform_eval = transform_eval\n        self.semantic_ce_loss = semantic_ce_loss\n        self.coco_keys = None\n        self.train_class_names = dict()\n        self.train_dataset_name = train_dataset_name\n        self.coco_mask_on = coco_mask_on\n        self.task_switch = {'coco': coco_on, 'o365': o365_on, 'sam': sam_on, 'pascal_part': pascal_part_on,\n                            \"ade\": ade_on}\n        self.interactive_pretrain = interactive_pretrain\n        self.dbg = False\n        self.positive = 0\n        self.num_objs = 0\n        self.num_hits = 0\n        self.num_refer = 0\n        self.ref_iou = 0.0\n        self.random_iou = 0.0\n        self.match_loss = match_loss\n\n        self.clip_on = clip_on\n        self.num_all_masks = 0.\n        self.coco_old = coco_old\n        self.multimodal_cfg = {'is_multimodal': True, 'image_token_len': 140, 'use_im_start_end': True}\n\n        self.logit_scale = nn.Parameter(torch.ones([]))\n        self.vis_out = vis_out\n        self.obj_projector = nn.Linear(256, 4096)\n        print(\"self.task_switch \", self.task_switch)\n\n        if not self.semantic_on:\n            assert self.sem_seg_postprocess_before_inference\n\n        self.max_num_instance = 100\n        self.num_mask_tokens = num_mask_tokens\n        self.regenerate_point = regenerate_point\n\n    @classmethod\n    def from_config(cls, cfg):\n        enc_cfg = cfg['MODEL']['ENCODER']\n        dec_cfg = cfg['MODEL']['DECODER']\n        \n        deep_supervision = dec_cfg['DEEP_SUPERVISION']\n        no_object_weight = dec_cfg['NO_OBJECT_WEIGHT']\n\n        # loss weights\n        iou_weight = dec_cfg['IOU_WEIGHT']\n        class_weight = dec_cfg['CLASS_WEIGHT']\n        cost_class_weight = dec_cfg['COST_CLASS_WEIGHT']\n        cost_dice_weight = dec_cfg['COST_DICE_WEIGHT']\n        dice_weight = dec_cfg['DICE_WEIGHT']\n        cost_mask_weight = dec_cfg['COST_MASK_WEIGHT']\n        mask_weight = dec_cfg['MASK_WEIGHT']\n        
cost_box_weight = dec_cfg['COST_BOX_WEIGHT']\n        box_weight = dec_cfg['BOX_WEIGHT']\n        cost_giou_weight = dec_cfg['COST_GIOU_WEIGHT']\n        giou_weight = dec_cfg['GIOU_WEIGHT']\n\n        refer_weight = dec_cfg['REFER_WEIGHT']\n        fix_backbone = cfg.get('fix_backbone', False)\n\n        # building matcher\n        matcher = HungarianMatcher(\n            cost_class=cost_class_weight,\n            cost_mask=cost_mask_weight,\n            cost_dice=cost_dice_weight,\n            cost_box=cost_box_weight,\n            cost_giou=cost_giou_weight,\n            num_points=dec_cfg['TRAIN_NUM_POINTS'],\n        )\n\n        # MaskDINO losses and weight_dict\n        weight_dict = {\"loss_mask_cls_0\": class_weight}\n        weight_dict.update({\"loss_mask_bce_0\": mask_weight, \"loss_mask_dice_0\": dice_weight})\n        weight_dict.update({\"loss_bbox_0\": box_weight, \"loss_giou_0\": giou_weight})\n        weight_dict.update({\"iou_score_loss_0\": iou_weight})\n        weight_dict.update({\"loss_mask_part_cls_0\": class_weight})\n        # two stage is the query selection scheme\n        if dec_cfg['TWO_STAGE']:\n            interm_weight_dict = {}\n            interm_weight_dict.update({k + f'_interm': v for k, v in weight_dict.items()})\n            weight_dict.update(interm_weight_dict)\n        # denoising training\n        dn = dec_cfg['DN']\n        # TODO hack for dn lable loss\n        if dn == \"standard\":\n            weight_dict.update({k + f\"_dn\": v for k, v in weight_dict.items() if k != \"loss_mask\" and k != \"loss_dice\"})\n            dn_losses = [\"dn_labels\", \"boxes\"]\n        elif dn == \"seg\":\n            weight_dict.update({k + f\"_dn\": v for k, v in weight_dict.items()})\n            dn_losses = [\"masks\", \"dn_labels\", \"boxes\"]\n        else:\n            dn_losses = []\n        if deep_supervision:\n            dec_layers = dec_cfg['DEC_LAYERS']\n            aux_weight_dict = {}\n            for i in 
range(dec_layers):\n                aux_weight_dict.update({k.replace('_0', '_{}'.format(i + 1)): v for k, v in weight_dict.items()})\n            weight_dict.update(aux_weight_dict)\n        if dec_cfg['BOX']:\n            losses = [\"masks\", \"labels\", \"boxes\"]\n        else:\n            losses = [\"masks\", \"labels\", ]\n        if dec_cfg['PART']:\n            losses.append('labels_part')\n        weight_dict.update({'all': 1.0, 'sam': 1.0, 'pas': 1.0})\n\n        # update task switch\n        task_switch = {}\n        task_switch.update({'bbox': dec_cfg.get('DETECTION', True), 'mask': dec_cfg.get('MASK', True)})\n        weight_multiplier= dec_cfg.get('WEIGHT_MULTIPLIER', 1.0)\n        weight_dict={k:v*weight_multiplier for k,v in weight_dict.items()}\n\n        # building criterion\n        criterion = SetCriterion(\n            enc_cfg['NUM_CLASSES'],\n            matcher=matcher,\n            weight_dict=weight_dict,\n            # top_x_layers=top_x_layers,\n            eos_coef=no_object_weight,\n            losses=losses,\n            num_points=dec_cfg['TRAIN_NUM_POINTS'],\n            oversample_ratio=dec_cfg['OVERSAMPLE_RATIO'],\n            importance_sample_ratio=dec_cfg['IMPORTANCE_SAMPLE_RATIO'],\n            # grounding_weight=None,\n            dn=dec_cfg['DN'],\n            dn_losses=dn_losses,\n            panoptic_on=dec_cfg['PANO_BOX_LOSS'],\n            semantic_ce_loss=dec_cfg['TEST']['SEMANTIC_ON'] and dec_cfg['SEMANTIC_CE_LOSS'] and not dec_cfg['TEST'][\n                'PANOPTIC_ON'],\n            num_mask_tokens=dec_cfg.get('NUM_INTERACTIVE_TOKENS', 3)\n        )\n\n        # build model\n        extra = {'task_switch': task_switch}\n        backbone = build_backbone(cfg)\n        if fix_backbone:\n            for name, param in backbone.named_parameters():\n                param.requires_grad = False\n        # backbone\n        sem_seg_head = build_openseed_head(cfg, backbone.output_shape(), None, extra=extra)\n        if 
fix_backbone:\n            for name, param in sem_seg_head.named_parameters():\n                param.requires_grad = False\n\n        return {\n            \"backbone\": backbone,\n            \"sem_seg_head\": sem_seg_head,\n            \"criterion\": criterion,\n            \"num_queries\": dec_cfg['NUM_OBJECT_QUERIES'],\n            \"object_mask_threshold\": dec_cfg['TEST']['OBJECT_MASK_THRESHOLD'],\n            \"overlap_threshold\": dec_cfg['TEST']['OVERLAP_THRESHOLD'],\n            \"metadata\": MetadataCatalog.get(cfg['DATASETS']['TRAIN'][0]),\n            \"size_divisibility\": dec_cfg['SIZE_DIVISIBILITY'],\n            \"sem_seg_postprocess_before_inference\": (\n                    dec_cfg['TEST']['SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE']\n                    or dec_cfg['TEST']['PANOPTIC_ON']\n                    or dec_cfg['TEST']['INSTANCE_ON']\n            ),\n            \"pixel_mean\": cfg['INPUT']['PIXEL_MEAN'],\n            \"pixel_std\": cfg['INPUT']['PIXEL_STD'],\n            # inference\n            \"semantic_on\": dec_cfg['TEST']['SEMANTIC_ON'],\n            \"instance_on\": dec_cfg['TEST']['INSTANCE_ON'],\n            \"panoptic_on\": dec_cfg['TEST']['PANOPTIC_ON'],\n            \"test_topk_per_image\": cfg['COCO']['TEST']['DETECTIONS_PER_IMAGE'],\n            \"data_loader\": None,\n            \"focus_on_box\": cfg['MODEL']['DECODER']['TEST']['TEST_FOUCUS_ON_BOX'],\n            \"transform_eval\": cfg['MODEL']['DECODER']['TEST']['PANO_TRANSFORM_EVAL'],\n            \"pano_temp\": cfg['MODEL']['DECODER']['TEST']['PANO_TEMPERATURE'],\n            \"semantic_ce_loss\": cfg['MODEL']['DECODER']['TEST']['SEMANTIC_ON'] and cfg['MODEL']['DECODER'][\n                'SEMANTIC_CE_LOSS'] and not cfg['MODEL']['DECODER']['TEST']['PANOPTIC_ON'],\n            \"train_dataset_name\": cfg['DATASETS']['TRAIN'],  # HACK for only two training set\n            \"background\": cfg['MODEL'].get('BACKGROUND', True),\n            \"coco_on\": dec_cfg.get('COCO', 
True),\n            \"coco_mask_on\": dec_cfg.get('COCO_MASK', True),\n            \"o365_on\": dec_cfg.get('O365', True),\n            \"sam_on\": dec_cfg.get('SAM', True),\n            \"pascal_part_on\": dec_cfg.get('PASCAL', True),\n            \"regenerate_point\": dec_cfg.get('RE_POINT', False),\n            \"num_mask_tokens\": dec_cfg.get('NUM_INTERACTIVE_TOKENS', 3),\n            \"ade_on\": dec_cfg.get('ADE', False),\n            \"interactive_pretrain\": dec_cfg.get('pretrain', False),\n            \"match_loss\": dec_cfg.get('match_loss', True),\n            \"vis_out\": os.path.join(cfg.get('OUTPUT_DIR', 'out'), str(cfg.get('VIS_OUT', 'vis'))),\n            \"coco_old\": cfg.get(\"coco_old\", True),\n            # \"points_per_side_eval\": cfg.get(\"points_per_side_eval\", 30),\n            \"clip_on\": cfg.get(\"clip\", False),\n\n        }\n\n    @property\n    def device(self):\n        return self.pixel_mean.device\n\n    def evaluate_demo(self, batched_inputs, all_whole=None, all_parts=None, mask_features=None,\n                      multi_scale_features=None, return_features=False):\n        assert len(batched_inputs) == 1, \"only support batch size equal to 1\"\n        prediction_switch = {'part': False, 'whole': False, 'seg': True, 'det': True}\n        images = [x[\"image\"].to(self.device) for x in batched_inputs]\n        images = [(x - self.pixel_mean) / self.pixel_std for x in images]\n        images = ImageList.from_tensors(images, self.size_divisibility)\n        targets = batched_inputs[0]['targets']\n        height = images[0].shape[1]\n        width = images[0].shape[2]\n        padded_h = images.tensor.shape[-2]  # divisable to 32\n        padded_w = images.tensor.shape[-1]\n        targets[0]['points'] = targets[0]['points'] * torch.as_tensor([width, height, width, height], dtype=torch.float,\n                                                                      device=self.device) / torch.as_tensor(\n            [padded_w, 
padded_h, padded_w, padded_h], dtype=torch.float, device=self.device)\n\n        features = self.backbone(images.tensor)\n        mask_features, transformer_encoder_features, multi_scale_features = self.sem_seg_head.pixel_decoder.forward_features(\n            features, None)\n        outputs, mask_dict = self.sem_seg_head.predictor(multi_scale_features, mask_features, None, targets=targets,\n                                                         target_queries=None, target_vlp=None, task='demo',\n                                                         extra=prediction_switch)\n\n        pred_ious = None\n        if 'pred_ious' in outputs.keys():\n            pred_ious = outputs[\"pred_ious\"]\n\n        _, index = pred_ious.view(-1, 3).max(1)\n        index = torch.zeros_like(index)\n        obj_feats = outputs['obj_features'][0].view(-1, self.num_mask_tokens, 256)\n        obj_feats = torch.gather(obj_feats, 1, index[..., None, None].repeat(1, 1, 256))[:, 0]\n        mask_pred_results = outputs[\"pred_masks\"]\n\n        # upsample masks\n        mask_pred_results = F.interpolate(\n            mask_pred_results.float(),\n            size=(images.tensor.shape[-2], images.tensor.shape[-1]),\n            mode=\"bilinear\",\n            align_corners=False,\n        )\n        mask_pred_results = mask_pred_results.view(-1, self.num_mask_tokens, images.tensor.shape[-2],\n                                                   images.tensor.shape[-1])\n        mask_pred_results = torch.gather(mask_pred_results, 1,\n                                         index[..., None, None, None].repeat(1, 1, images.tensor.shape[-2],\n                                                                             images.tensor.shape[-1]))\n        pred_masks = mask_pred_results[:, 0]\n\n        image_size = images.image_sizes[0]\n\n        height = image_size[0]\n        width = image_size[1]\n        if self.sem_seg_postprocess_before_inference:\n            pred_masks = 
retry_if_cuda_oom(sem_seg_postprocess)(\n                pred_masks, image_size, height, width\n            )\n        return pred_masks, pred_ious, self.obj_projector(obj_feats)\n\n    def forward(self, batched_inputs, inference_task='seg',detach=False):\n\n        if self.training:\n            obj_feats,inter_losses= self.forward_det_pretrain(batched_inputs)\n            for k in list(inter_losses.keys()):\n                if k in self.criterion.weight_dict:\n                    inter_losses[k] *= self.criterion.weight_dict[k]\n                    # losses[k] *= scale\n                else:\n                    # remove this loss if not specified in `weight_dict`\n                    inter_losses.pop(k)\n            new_losses = {}\n            for key, value in inter_losses.items():\n                new_losses['inter' + '.' + str(key)] = inter_losses[key]\n            if detach:\n                return [self.obj_projector(feat.detach()) for feat in obj_feats],new_losses\n            else:\n                return [self.obj_projector(feat)[0] for feat in obj_feats],new_losses\n        else:\n            return self.evaluate_demo(batched_inputs)\n\n    def forward_det_pretrain(self, batched_inputs, task='seg',\n                             prediction_switch={'part': True, 'whole': True, 'seg': True, 'det': True}, dataname='coco',\n                             semantic=False):\n        images = [x[\"image\"].to(self.device) for x in batched_inputs]\n        images = [(x - self.pixel_mean) / self.pixel_std for x in images]\n        images = ImageList.from_tensors(images, self.size_divisibility)\n\n        features = self.backbone(images.tensor)\n        mask_features, transformer_encoder_features, multi_scale_features = self.sem_seg_head.pixel_decoder.forward_features(\n            features, None)\n        if self.clip_on:\n            image = \\\n            preprocess.preprocess(Image.fromarray(batched_inputs[0]['image_ori']), 
return_tensors='pt')['pixel_values'][\n                0]\n        prediction_switch = {'part': False, 'whole': False, 'seg': True, 'det': True}\n\n        # self.criterion.num_classes = len(train_class_names)\n        train_class_names_part = None\n        # if prediction_switch['part']:\n        #     train_class_names_part = self.train_class_names[dataname + '_part']\n        #     self.criterion.num_classes_part = len(train_class_names)\n        if \"instances\" in batched_inputs[0]:\n            gt_instances = [x[\"instances\"].to(self.device) for x in batched_inputs]\n            targets = self.prepare_targets_sam(gt_instances, images, prediction_switch=prediction_switch)\n        else:\n            targets = None\n            print(\"empty targets\", targets, task)\n\n        if prediction_switch['whole']:\n            prediction_switch['whole'] = False\n        if prediction_switch['part']:\n            prediction_switch['part'] = False\n\n        # tgt_temp=[]\n        obj_features_ls=[]\n        losses_total=None\n        num_masks=0\n        for i,tgt in enumerate(targets):\n            tgt_temp=[tgt]\n            outputs_gt, mask_dict = self.sem_seg_head.predictor([feat[i:i+1] for feat in multi_scale_features], mask_features[i:i+1], None,\n                                                            targets=tgt_temp,\n                                                            target_queries=None, target_vlp=None, task='seg',\n                                                            extra=prediction_switch)\n            self.criterion.index = torch.zeros_like(batched_inputs[i]['instances'].gt_classes).to(outputs_gt['obj_features'].device)\n\n            losses, index = self.criterion(outputs_gt, tgt_temp, mask_dict, task='seg', extra=prediction_switch,\n                                           return_idx=True)\n            index=self.criterion.index\n            bs, n, h, w = outputs_gt[\"pred_masks\"].shape\n            obj_features = 
outputs_gt['obj_features'].view(bs, -1, self.num_mask_tokens, 256)\n            obj_features = torch.gather(obj_features, 2, index[None][..., None, None].repeat(1, 1, 1, 256))[:,:, 0]\n            # mask_pred_results = outputs_gt[\"pred_masks\"][0].view(-1, self.num_mask_tokens, h, w)\n\n            obj_features_ls.append(obj_features)\n            num_masks+=losses['num_masks']\n            if losses_total is None:\n                losses_total=dict()\n                for key in losses.keys():\n                    if key != 'num_masks':\n                        losses_total[key] = losses[key] * losses['num_masks']\n            else:\n                for key in losses.keys():\n                    if key != 'num_masks':\n                        losses_total[key]+=losses[key]*losses['num_masks']\n        for key in losses_total.keys():\n            if key != 'num_masks':\n                losses_total[key] = losses_total[key]/num_masks\n        return obj_features_ls,losses_total\n\n    def prepare_targets_sam(self, targets, images, prediction_switch, task='seg', min_box=0.33, max_box=1.0):\n        h_pad, w_pad = images.tensor.shape[-2:]\n        new_targets = []\n\n        # box_start = random.randint(int((self.max_num_instance - 1)/2), self.max_num_instance - 1)  # box based interactive after this number; about 1/4\n        # if random.random()<0.5 and self.dbg:\n        #     # import pdb;pdb.set_trace()\n        #     targets[0]=targets[0][:0]\n        if not self.dbg:\n            self.empty_targets = targets\n            self.dbg = True\n        if len(targets[0]) == 0:\n            empty = True\n            targets = self.empty_targets\n        else:\n            empty = False\n        for targets_per_image in targets:\n            gt_boxes = targets_per_image.gt_boxes if torch.is_tensor(\n                targets_per_image.gt_boxes) else targets_per_image.gt_boxes.tensor\n            # empty=len(gt_boxes)==0\n            assert len(gt_boxes)>0\n            
self.max_num_instance =  len(gt_boxes)\n            box_start = random.randint(int(self.max_num_instance * min_box), int(self.max_num_instance * max_box))\n            # pad gt\n            h, w = targets_per_image.image_size\n\n            image_size_xyxy = torch.as_tensor([w, h, w, h], dtype=torch.float, device=self.device)\n            gt_masks = targets_per_image.gt_masks if torch.is_tensor(\n                targets_per_image.gt_masks) else targets_per_image.gt_masks.tensor\n\n            padded_masks = torch.zeros((gt_masks.shape[0], h_pad, w_pad), dtype=gt_masks.dtype, device=gt_masks.device)\n            padded_masks[:, : gt_masks.shape[1], : gt_masks.shape[2]] = gt_masks\n            num_mask = targets_per_image.gt_classes.shape[0]\n\n            index = torch.arange(num_mask)\n\n            if self.max_num_instance > num_mask:\n                rep = 0 if num_mask == 0 else int(self.max_num_instance / num_mask) + 1\n                index = index.repeat(rep)\n            index = index[:self.max_num_instance]\n\n            # if self.regenerate_point and box_start > 0:\n            point_coords = []\n            if box_start > 0:\n                for i in range(box_start):\n                    mask = gt_masks[index[i]].clone()\n                    candidate_indices = mask.nonzero()\n                    if len(candidate_indices) == 0:\n                        print('wrong')\n                        selected_point = torch.tensor([0, 0]).cuda()\n                    else:\n                        selected_index = random.randint(0, len(candidate_indices) - 1)\n                        selected_point = candidate_indices[selected_index].flip(0)\n                    selected_point = torch.cat([selected_point - 3, selected_point + 3], 0)\n                    point_coords.append(selected_point)\n                point_coords = torch.stack(point_coords).to('cuda')\n            # else:\n                # point_coords = targets_per_image.point_coords[index[:box_start]]\n    
        # point_coords = targets_per_image.gt_boxes.tensor[index[:box_start]]\n            new_target = {\n                \"ori_mask_num\": len(targets_per_image.gt_classes),\n                \"labels\": targets_per_image.gt_classes[index] if prediction_switch['whole'] else None,\n                \"masks\": padded_masks[index],\n                \"boxes\": box_ops.box_xyxy_to_cxcywh(gt_boxes[index]) / image_size_xyxy,\n                \"points\": box_ops.box_xyxy_to_cxcywh(point_coords) / image_size_xyxy if len(point_coords) > 0 else None,\n                # \"pb\":torch.randint(2,(min(self.max_num_instance,len(targets_per_image.gt_classes)),),device=gt_masks.device),\n                \"pb\": torch.cat([torch.zeros(box_start), torch.ones(self.max_num_instance - box_start)], 0),\n                \"gt_whole_classes\": targets_per_image.gt_whole_classes[index] if targets_per_image.has(\n                    'gt_whole_classes') and prediction_switch['whole'] else None,\n                \"gt_part_classes\": targets_per_image.gt_part_classes[index] if targets_per_image.has(\n                    'gt_part_classes') and prediction_switch['part'] else None,\n            }\n            # handle coco data format\n            if prediction_switch['whole'] and not prediction_switch['part']:\n                new_target['gt_whole_classes'] = targets_per_image.gt_classes[index]\n            if new_target[\"points\"] is not None:\n                new_target[\"boxes_dn\"] = torch.cat([new_target[\"points\"], new_target[\"boxes\"][box_start:]], 0)\n            else:\n                new_target[\"boxes_dn\"] = new_target[\"boxes\"][box_start:]\n\n            new_target['box_start'] = box_start\n            new_target['empty'] = empty\n            new_targets.append(new_target)\n\n        return new_targets\n\n\n\n@register_model\ndef get_segmentation_model(cfg, **kwargs):\n    return SemanticSAM(cfg)"
  },
  {
    "path": "llava/model/semsam/architectures/registry.py",
    "content": "_model_entrypoints = {}\n\ndef register_model(fn):\n    module_name_split = fn.__module__.split('.')\n    model_name = module_name_split[-1]\n    _model_entrypoints[model_name] = fn\n    return fn\n\ndef model_entrypoints(model_name):\n    return _model_entrypoints[model_name]\n\ndef is_model(model_name):\n    return model_name in _model_entrypoints"
  },
  {
    "path": "llava/model/semsam/backbone/__init__.py",
    "content": "from .build import build_backbone\n\nfrom .focal import *\nfrom .focal_dw import *\nfrom .swin import *\nfrom .backbone import *"
  },
  {
    "path": "llava/model/semsam/backbone/backbone.py",
    "content": "# Copyright (c) Facebook, Inc. and its affiliates.\nimport torch.nn as nn\n\nfrom detectron2.modeling import ShapeSpec\n\n# from ..layers import ShapeSpec\n\n__all__ = [\"Backbone\"]\n\n\nclass Backbone(nn.Module):\n    \"\"\"\n    Abstract base class for network backbones.\n    \"\"\"\n\n    def __init__(self):\n        \"\"\"\n        The `__init__` method of any subclass can specify its own set of arguments.\n        \"\"\"\n        super().__init__()\n\n    def forward(self):\n        \"\"\"\n        Subclasses must override this method, but adhere to the same return type.\n\n        Returns:\n            dict[str->Tensor]: mapping from feature name (e.g., \"res2\") to tensor\n        \"\"\"\n        pass\n\n    @property\n    def size_divisibility(self) -> int:\n        \"\"\"\n        Some backbones require the input height and width to be divisible by a\n        specific integer. This is typically true for encoder / decoder type networks\n        with lateral connection (e.g., FPN) for which feature maps need to match\n        dimension in the \"bottom up\" and \"top down\" paths. Set to 0 if no specific\n        input size divisibility is required.\n        \"\"\"\n        return 0\n\n    def output_shape(self):\n        \"\"\"\n        Returns:\n            dict[str->ShapeSpec]\n        \"\"\"\n        # this is a backward-compatible default\n        return {\n            name: ShapeSpec(\n                channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]\n            )\n            for name in self._out_features\n        }\n"
  },
  {
    "path": "llava/model/semsam/backbone/build.py",
    "content": "from .registry import model_entrypoints\nfrom .registry import is_model\n\nfrom .backbone import *\n\ndef build_backbone(config, **kwargs):\n    model_name = config['MODEL']['BACKBONE']['NAME']\n    if not is_model(model_name):\n        raise ValueError(f'Unkown model: {model_name}')\n\n    return model_entrypoints(model_name)(config, **kwargs)"
  },
  {
    "path": "llava/model/semsam/backbone/focal.py",
    "content": "# --------------------------------------------------------\n# FocalNet for Semantic Segmentation\n# Copyright (c) 2022 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# Written by Jianwei Yang\n# --------------------------------------------------------\nimport math\nimport time\nimport numpy as np\nimport logging\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.utils.checkpoint as checkpoint\nfrom timm.models.layers import DropPath, to_2tuple, trunc_normal_\n\nfrom detectron2.utils.file_io import PathManager\nfrom detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec\n\nfrom .registry import register_backbone\n\nlogger = logging.getLogger(__name__)\n\nclass Mlp(nn.Module):\n    \"\"\" Multilayer perceptron.\"\"\"\n\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features, out_features)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\nclass FocalModulation(nn.Module):\n    \"\"\" Focal Modulation\n\n    Args:\n        dim (int): Number of input channels.\n        proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0\n        focal_level (int): Number of focal levels\n        focal_window (int): Focal window size at focal level 1\n        focal_factor (int, default=2): Step to increase the focal window\n        use_postln (bool, default=False): Whether use post-modulation layernorm\n    \"\"\"\n\n    def __init__(self, dim, proj_drop=0., focal_level=2, focal_window=7, focal_factor=2, use_postln=False, use_postln_in_modulation=False, scaling_modulator=False):\n\n        super().__init__()\n        self.dim = dim\n\n        # specific args for focalv3\n        self.focal_level = focal_level\n        self.focal_window = focal_window\n        self.focal_factor = focal_factor\n        self.use_postln_in_modulation = use_postln_in_modulation\n        self.scaling_modulator = scaling_modulator\n\n        self.f = nn.Linear(dim, 2*dim+(self.focal_level+1), bias=True)\n        self.h = nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0, groups=1, bias=True)\n\n        self.act = nn.GELU()\n        self.proj = nn.Linear(dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n        self.focal_layers = nn.ModuleList()\n\n        if self.use_postln_in_modulation:\n            self.ln = nn.LayerNorm(dim)\n\n        for k in range(self.focal_level):\n            kernel_size = self.focal_factor*k + self.focal_window\n            self.focal_layers.append(\n                nn.Sequential(\n                    nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=1, groups=dim, \n                        padding=kernel_size//2, bias=False),\n                    nn.GELU(),\n                    )\n                )\n\n    def forward(self, x):\n        \"\"\" Forward function.\n\n        Args:\n            x: input features with shape of (B, H, W, C)\n        \"\"\"\n        B, nH, nW, C = x.shape\n        x = self.f(x)\n        x = x.permute(0, 3, 1, 2).contiguous()\n        q, ctx, gates = torch.split(x, (C, C, self.focal_level+1), 1)\n        \n        ctx_all = 0\n        for 
l in range(self.focal_level):                     \n            ctx = self.focal_layers[l](ctx)\n            ctx_all = ctx_all + ctx*gates[:, l:l+1]\n        ctx_global = self.act(ctx.mean(2, keepdim=True).mean(3, keepdim=True))\n        ctx_all = ctx_all + ctx_global*gates[:,self.focal_level:]\n\n        if self.scaling_modulator:\n            ctx_all = ctx_all / (self.focal_level + 1)\n\n        x_out = q * self.h(ctx_all)\n        x_out = x_out.permute(0, 2, 3, 1).contiguous()\n        if self.use_postln_in_modulation:\n            x_out = self.ln(x_out)            \n        x_out = self.proj(x_out)\n        x_out = self.proj_drop(x_out)\n        return x_out\n\nclass FocalModulationBlock(nn.Module):\n    \"\"\" Focal Modulation Block.\n\n    Args:\n        dim (int): Number of input channels.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n        drop (float, optional): Dropout rate. Default: 0.0\n        drop_path (float, optional): Stochastic depth rate. Default: 0.0\n        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU\n        norm_layer (nn.Module, optional): Normalization layer.  
Default: nn.LayerNorm\n        focal_level (int): number of focal levels\n        focal_window (int): focal kernel size at level 1\n    \"\"\"\n\n    def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0., \n                 act_layer=nn.GELU, norm_layer=nn.LayerNorm,\n                 focal_level=2, focal_window=9, \n                 use_postln=False, use_postln_in_modulation=False,\n                 scaling_modulator=False, \n                 use_layerscale=False, \n                 layerscale_value=1e-4):\n        super().__init__()\n        self.dim = dim\n        self.mlp_ratio = mlp_ratio\n        self.focal_window = focal_window\n        self.focal_level = focal_level\n        self.use_postln = use_postln\n        self.use_layerscale = use_layerscale\n\n        self.norm1 = norm_layer(dim)\n        self.modulation = FocalModulation(\n            dim, focal_window=self.focal_window, focal_level=self.focal_level, proj_drop=drop, use_postln_in_modulation=use_postln_in_modulation, scaling_modulator=scaling_modulator\n        )            \n\n        self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)\n\n        self.H = None\n        self.W = None\n\n        self.gamma_1 = 1.0\n        self.gamma_2 = 1.0\n        if self.use_layerscale:\n            self.gamma_1 = nn.Parameter(layerscale_value * torch.ones((dim)), requires_grad=True)\n            self.gamma_2 = nn.Parameter(layerscale_value * torch.ones((dim)), requires_grad=True)\n\n    def forward(self, x):\n        \"\"\" Forward function.\n\n        Args:\n            x: Input feature, tensor size (B, H*W, C).\n            H, W: Spatial resolution of the input feature.\n        \"\"\"\n        B, L, C = x.shape\n        H, W = self.H, self.W\n        assert L == H * W, \"input feature has wrong size\"\n\n        shortcut = x\n        if not self.use_postln:\n            x = self.norm1(x)\n        x = x.view(B, H, W, C)\n        \n        # FM\n        x = self.modulation(x).view(B, H * W, C)\n        if self.use_postln:\n            x = self.norm1(x)\n\n        # FFN\n        x = shortcut + self.drop_path(self.gamma_1 * x)\n\n        if self.use_postln:\n            x = x + self.drop_path(self.gamma_2 * self.norm2(self.mlp(x)))\n        else:\n            x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))\n\n        return x\n\nclass BasicLayer(nn.Module):\n    \"\"\" A basic focal modulation layer for one stage.\n\n    Args:\n        dim (int): Number of feature channels\n        depth (int): Depths of this stage.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.\n        drop (float, optional): Dropout rate. Default: 0.0\n        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0\n        norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm\n        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None\n        focal_level (int): Number of focal levels\n        focal_window (int): Focal window size at focal level 1\n        use_conv_embed (bool): Use overlapped convolution for patch embedding or now. Default: False\n        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False\n    \"\"\"\n\n    def __init__(self,\n                 dim,\n                 depth,\n                 mlp_ratio=4.,\n                 drop=0.,\n                 drop_path=0.,\n                 norm_layer=nn.LayerNorm,\n                 downsample=None,\n                 focal_window=9, \n                 focal_level=2, \n                 use_conv_embed=False,     \n                 use_postln=False,          \n                 use_postln_in_modulation=False, \n                 scaling_modulator=False,\n                 use_layerscale=False,                   \n                 use_checkpoint=False\n        ):\n        super().__init__()\n        self.depth = depth\n        self.use_checkpoint = use_checkpoint\n\n        # build blocks\n        self.blocks = nn.ModuleList([\n            FocalModulationBlock(\n                dim=dim,\n                mlp_ratio=mlp_ratio,\n                drop=drop,\n                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,\n                focal_window=focal_window, \n                focal_level=focal_level, \n                use_postln=use_postln, \n                use_postln_in_modulation=use_postln_in_modulation, \n                scaling_modulator=scaling_modulator,\n                use_layerscale=use_layerscale, \n                norm_layer=norm_layer)\n            for i in range(depth)])\n\n        # patch merging layer\n        if downsample is not None:\n            self.downsample = downsample(\n                patch_size=2,\n                in_chans=dim, 
embed_dim=2*dim, \n                use_conv_embed=use_conv_embed, \n                norm_layer=norm_layer, \n                is_stem=False\n            )\n\n        else:\n            self.downsample = None\n\n    def forward(self, x, H, W):\n        \"\"\" Forward function.\n\n        Args:\n            x: Input feature, tensor size (B, H*W, C).\n            H, W: Spatial resolution of the input feature.\n        \"\"\"\n        for blk in self.blocks:\n            blk.H, blk.W = H, W\n            if self.use_checkpoint:\n                x = checkpoint.checkpoint(blk, x)\n            else:\n                x = blk(x)\n        if self.downsample is not None:\n            x_reshaped = x.transpose(1, 2).view(x.shape[0], x.shape[-1], H, W)\n            x_down = self.downsample(x_reshaped)   \n            x_down = x_down.flatten(2).transpose(1, 2)            \n            Wh, Ww = (H + 1) // 2, (W + 1) // 2\n            return x, H, W, x_down, Wh, Ww\n        else:\n            return x, H, W, x, H, W\n\n\nclass PatchEmbed(nn.Module):\n    \"\"\" Image to Patch Embedding\n\n    Args:\n        patch_size (int): Patch token size. Default: 4.\n        in_chans (int): Number of input image channels. Default: 3.\n        embed_dim (int): Number of linear projection output channels. Default: 96.\n        norm_layer (nn.Module, optional): Normalization layer. Default: None\n        use_conv_embed (bool): Whether use overlapped convolution for patch embedding. Default: False\n        is_stem (bool): Is the stem block or not. 
\n    \"\"\"\n\n    def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None, use_conv_embed=False, is_stem=False):\n        super().__init__()\n        patch_size = to_2tuple(patch_size)\n        self.patch_size = patch_size\n\n        self.in_chans = in_chans\n        self.embed_dim = embed_dim\n\n        if use_conv_embed:\n            # if we choose to use conv embedding, then we treat the stem and non-stem differently\n            if is_stem:\n                kernel_size = 7; padding = 2; stride = 4\n            else:\n                kernel_size = 3; padding = 1; stride = 2\n            self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)                    \n        else:\n            self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)\n\n        if norm_layer is not None:\n            self.norm = norm_layer(embed_dim)\n        else:\n            self.norm = None\n\n    def forward(self, x):\n        \"\"\"Forward function.\"\"\"\n        _, _, H, W = x.size()\n        if W % self.patch_size[1] != 0:\n            x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))\n        if H % self.patch_size[0] != 0:\n            x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))\n\n        x = self.proj(x)  # B C Wh Ww\n        if self.norm is not None:\n            Wh, Ww = x.size(2), x.size(3)\n            x = x.flatten(2).transpose(1, 2)\n            x = self.norm(x)\n            x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)\n\n        return x\n\n\nclass FocalNet(nn.Module):\n    \"\"\" FocalNet backbone.\n\n    Args:\n        pretrain_img_size (int): Input image size for training the pretrained model,\n            used in absolute postion embedding. Default 224.\n        patch_size (int | tuple(int)): Patch size. Default: 4.\n        in_chans (int): Number of input image channels. 
Default: 3.\n        embed_dim (int): Number of linear projection output channels. Default: 96.\n        depths (tuple[int]): Depths of each Swin Transformer stage.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.\n        drop_rate (float): Dropout rate.\n        drop_path_rate (float): Stochastic depth rate. Default: 0.2.\n        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.\n        patch_norm (bool): If True, add normalization after patch embedding. Default: True.\n        out_indices (Sequence[int]): Output from which stages.\n        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).\n            -1 means not freezing any parameters.\n        focal_levels (Sequence[int]): Number of focal levels at four stages\n        focal_windows (Sequence[int]): Focal window sizes at first focal level at four stages\n        use_conv_embed (bool): Whether use overlapped convolution for patch embedding\n        use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False.\n    \"\"\"\n\n    def __init__(self,\n                 pretrain_img_size=1600,\n                 patch_size=4,\n                 in_chans=3,\n                 embed_dim=96,\n                 depths=[2, 2, 6, 2],\n                 mlp_ratio=4.,\n                 drop_rate=0.,\n                 drop_path_rate=0.2,\n                 norm_layer=nn.LayerNorm,\n                 patch_norm=True,\n                 out_indices=[0, 1, 2, 3],\n                 frozen_stages=-1,\n                 focal_levels=[2,2,2,2], \n                 focal_windows=[9,9,9,9],\n                 use_conv_embed=False, \n                 use_postln=False, \n                 use_postln_in_modulation=False, \n                 scaling_modulator=False,\n                 use_layerscale=False, \n                 use_checkpoint=False, \n        ):\n        super().__init__()\n\n        self.pretrain_img_size = pretrain_img_size\n        self.num_layers = len(depths)\n        self.embed_dim = embed_dim\n        self.patch_norm = patch_norm\n        self.out_indices = out_indices\n        self.frozen_stages = frozen_stages\n\n        # split image into non-overlapping patches\n        self.patch_embed = PatchEmbed(\n            patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,\n            norm_layer=norm_layer if self.patch_norm else None, \n            use_conv_embed=use_conv_embed, is_stem=True)\n\n        self.pos_drop = nn.Dropout(p=drop_rate)\n\n        # stochastic depth\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule\n\n        # build layers\n        self.layers = nn.ModuleList()\n        for i_layer in range(self.num_layers):\n            layer = BasicLayer(\n                dim=int(embed_dim * 2 ** i_layer),\n                depth=depths[i_layer],\n                mlp_ratio=mlp_ratio,\n                drop=drop_rate,\n                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],\n 
               norm_layer=norm_layer,\n                downsample=PatchEmbed if (i_layer < self.num_layers - 1) else None,\n                focal_window=focal_windows[i_layer], \n                focal_level=focal_levels[i_layer], \n                use_conv_embed=use_conv_embed,\n                use_postln=use_postln, \n                use_postln_in_modulation=use_postln_in_modulation,\n                scaling_modulator=scaling_modulator,\n                use_layerscale=use_layerscale, \n                use_checkpoint=use_checkpoint)\n            self.layers.append(layer)\n\n        num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]\n        self.num_features = num_features\n\n        # add a norm layer for each output\n        for i_layer in out_indices:\n            layer = norm_layer(num_features[i_layer])\n            layer_name = f'norm{i_layer}'\n            self.add_module(layer_name, layer)\n\n        self._freeze_stages()\n\n    def _freeze_stages(self):\n        if self.frozen_stages >= 0:\n            self.patch_embed.eval()\n            for param in self.patch_embed.parameters():\n                param.requires_grad = False\n\n        if self.frozen_stages >= 2:\n            self.pos_drop.eval()\n            for i in range(0, self.frozen_stages - 1):\n                m = self.layers[i]\n                m.eval()\n                for param in m.parameters():\n                    param.requires_grad = False\n\n    def init_weights(self, pretrained=None):\n        \"\"\"Initialize the weights in backbone.\n\n        Args:\n            pretrained (str, optional): Path to pre-trained weights.\n                Defaults to None.\n        \"\"\"\n\n        def _init_weights(m):\n            if isinstance(m, nn.Linear):\n                trunc_normal_(m.weight, std=.02)\n                if isinstance(m, nn.Linear) and m.bias is not None:\n                    nn.init.constant_(m.bias, 0)\n            elif isinstance(m, nn.LayerNorm):\n          
      nn.init.constant_(m.bias, 0)\n                nn.init.constant_(m.weight, 1.0)\n\n        if isinstance(pretrained, str):\n            self.apply(_init_weights)\n            logger = get_root_logger()\n            load_checkpoint(self, pretrained, strict=False, logger=logger)\n        elif pretrained is None:\n            self.apply(_init_weights)\n        else:\n            raise TypeError('pretrained must be a str or None')\n\n    def load_weights(self, pretrained_dict=None, pretrained_layers=[], verbose=True):\n        model_dict = self.state_dict()\n\n        missed_dict = [k for k in model_dict.keys() if k not in pretrained_dict]\n        logger.info(f'=> Missed keys {missed_dict}')\n        unexpected_dict = [k for k in pretrained_dict.keys() if k not in model_dict]\n        logger.info(f'=> Unexpected keys {unexpected_dict}')\n\n        pretrained_dict = {\n            k: v for k, v in pretrained_dict.items()\n            if k in model_dict.keys()\n        }\n        \n        need_init_state_dict = {}\n        for k, v in pretrained_dict.items():\n            need_init = (\n                (\n                    k.split('.')[0] in pretrained_layers\n                    or pretrained_layers[0] == '*'\n                )\n                and 'relative_position_index' not in k\n                and 'attn_mask' not in k\n            )\n\n            if need_init:\n                # if verbose:\n                #     logger.info(f'=> init {k} from {pretrained}')\n\n                if ('pool_layers' in k) or ('focal_layers' in k) and v.size() != model_dict[k].size():\n                    table_pretrained = v\n                    table_current = model_dict[k]\n                    fsize1 = table_pretrained.shape[2]\n                    fsize2 = table_current.shape[2]\n\n                    # NOTE: different from interpolation used in self-attention, we use padding or clipping for focal conv\n                    if fsize1 < fsize2:\n                        
table_pretrained_resized = torch.zeros(table_current.shape)\n                        table_pretrained_resized[:, :, (fsize2-fsize1)//2:-(fsize2-fsize1)//2, (fsize2-fsize1)//2:-(fsize2-fsize1)//2] = table_pretrained\n                        v = table_pretrained_resized\n                    elif fsize1 > fsize2:\n                        table_pretrained_resized = table_pretrained[:, :, (fsize1-fsize2)//2:-(fsize1-fsize2)//2, (fsize1-fsize2)//2:-(fsize1-fsize2)//2]\n                        v = table_pretrained_resized\n\n\n                if (\"modulation.f\" in k or \"pre_conv\" in k): \n                    table_pretrained = v\n                    table_current = model_dict[k]\n                    if table_pretrained.shape != table_current.shape:\n                        if len(table_pretrained.shape) == 2:\n                            dim = table_pretrained.shape[1]\n                            assert table_current.shape[1] == dim\n                            L1 = table_pretrained.shape[0]\n                            L2 = table_current.shape[0]\n\n                            if L1 < L2:\n                                table_pretrained_resized = torch.zeros(table_current.shape)\n                                # copy for linear project\n                                table_pretrained_resized[:2*dim] = table_pretrained[:2*dim]\n                                # copy for global token gating\n                                table_pretrained_resized[-1] = table_pretrained[-1]\n                                # copy for first multiple focal levels\n                                table_pretrained_resized[2*dim:2*dim+(L1-2*dim-1)] = table_pretrained[2*dim:-1]\n                                # reassign pretrained weights\n                                v = table_pretrained_resized\n                            elif L1 > L2:\n                                raise NotImplementedError\n                        elif len(table_pretrained.shape) == 1:\n                        
    dim = table_pretrained.shape[0]\n                            L1 = table_pretrained.shape[0]\n                            L2 = table_current.shape[0]\n                            if L1 < L2:\n                                table_pretrained_resized = torch.zeros(table_current.shape)\n                                # copy for linear project\n                                table_pretrained_resized[:dim] = table_pretrained[:dim]\n                                # copy for global token gating\n                                table_pretrained_resized[-1] = table_pretrained[-1]\n                                # copy for first multiple focal levels\n                                # table_pretrained_resized[dim:2*dim+(L1-2*dim-1)] = table_pretrained[2*dim:-1]\n                                # reassign pretrained weights\n                                v = table_pretrained_resized\n                            elif L1 > L2:\n                                raise NotImplementedError    \n\n                need_init_state_dict[k] = v\n        \n        self.load_state_dict(need_init_state_dict, strict=False)\n\n\n    def forward(self, x):\n        \"\"\"Forward function.\"\"\"\n        tic = time.time()\n        x = self.patch_embed(x)\n        Wh, Ww = x.size(2), x.size(3)\n\n        x = x.flatten(2).transpose(1, 2)\n        x = self.pos_drop(x)\n\n        outs = {}\n        for i in range(self.num_layers):\n            layer = self.layers[i]\n            x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)\n            if i in self.out_indices:\n                norm_layer = getattr(self, f'norm{i}')\n                x_out = norm_layer(x_out)\n\n                out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()\n                outs[\"res{}\".format(i + 2)] = out\n                \n        if len(self.out_indices) == 0:\n            outs[\"res5\"] = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()\n\n        toc = 
time.time()\n        return outs\n\n    def train(self, mode=True):\n        \"\"\"Convert the model into training mode while keep layers freezed.\"\"\"\n        super(FocalNet, self).train(mode)\n        self._freeze_stages()\n\n\nclass D2FocalNet(FocalNet, Backbone):\n    def __init__(self, cfg, input_shape):\n\n        pretrain_img_size = cfg['BACKBONE']['FOCAL']['PRETRAIN_IMG_SIZE']\n        patch_size = cfg['BACKBONE']['FOCAL']['PATCH_SIZE']\n        in_chans = 3\n        embed_dim = cfg['BACKBONE']['FOCAL']['EMBED_DIM']\n        depths = cfg['BACKBONE']['FOCAL']['DEPTHS']\n        mlp_ratio = cfg['BACKBONE']['FOCAL']['MLP_RATIO']\n        drop_rate = cfg['BACKBONE']['FOCAL']['DROP_RATE']\n        drop_path_rate = cfg['BACKBONE']['FOCAL']['DROP_PATH_RATE']\n        norm_layer = nn.LayerNorm\n        patch_norm = cfg['BACKBONE']['FOCAL']['PATCH_NORM']\n        use_checkpoint = cfg['BACKBONE']['FOCAL']['USE_CHECKPOINT']\n        out_indices = cfg['BACKBONE']['FOCAL']['OUT_INDICES']\n        scaling_modulator = cfg['BACKBONE']['FOCAL'].get('SCALING_MODULATOR', False)\n\n        super().__init__(\n            pretrain_img_size,\n            patch_size,\n            in_chans,\n            embed_dim,\n            depths,\n            mlp_ratio,\n            drop_rate,\n            drop_path_rate,\n            norm_layer,\n            patch_norm,\n            out_indices,\n            focal_levels=cfg['BACKBONE']['FOCAL']['FOCAL_LEVELS'],\n            focal_windows=cfg['BACKBONE']['FOCAL']['FOCAL_WINDOWS'],   \n            use_conv_embed=cfg['BACKBONE']['FOCAL']['USE_CONV_EMBED'],    \n            use_postln=cfg['BACKBONE']['FOCAL']['USE_POSTLN'],       \n            use_postln_in_modulation=cfg['BACKBONE']['FOCAL']['USE_POSTLN_IN_MODULATION'], \n            scaling_modulator=scaling_modulator,\n            use_layerscale=cfg['BACKBONE']['FOCAL']['USE_LAYERSCALE'], \n            use_checkpoint=use_checkpoint,\n        )\n\n        self._out_features = 
cfg['BACKBONE']['FOCAL']['OUT_FEATURES']\n\n        self._out_feature_strides = {\n            \"res2\": 4,\n            \"res3\": 8,\n            \"res4\": 16,\n            \"res5\": 32,\n        }\n        self._out_feature_channels = {\n            \"res2\": self.num_features[0],\n            \"res3\": self.num_features[1],\n            \"res4\": self.num_features[2],\n            \"res5\": self.num_features[3],\n        }\n\n    def forward(self, x):\n        \"\"\"\n        Args:\n            x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.\n        Returns:\n            dict[str->Tensor]: names and the corresponding features\n        \"\"\"\n        assert (\n            x.dim() == 4\n        ), f\"SwinTransformer takes an input of shape (N, C, H, W). Got {x.shape} instead!\"\n        outputs = {}\n        y = super().forward(x)\n        for k in y.keys():\n            if k in self._out_features:\n                outputs[k] = y[k]\n        return outputs\n\n    def output_shape(self):\n        return {\n            name: ShapeSpec(\n                channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]\n            )\n            for name in self._out_features\n        }\n\n    @property\n    def size_divisibility(self):\n        return 32\n\n@register_backbone\ndef get_focal_backbone(cfg):\n    focal = D2FocalNet(cfg['MODEL'], 224)    \n\n    if cfg['MODEL']['BACKBONE']['LOAD_PRETRAINED'] is True:\n        filename = cfg['MODEL']['BACKBONE']['PRETRAINED']\n        logger.info(f'=> init from {filename}')\n        with PathManager.open(filename, \"rb\") as f:\n            ckpt = torch.load(f)['model']\n        focal.load_weights(ckpt, cfg['MODEL']['BACKBONE']['FOCAL'].get('PRETRAINED_LAYERS', ['*']), cfg['VERBOSE'])\n\n    return focal"
  },
  {
    "path": "llava/model/semsam/backbone/focal_dw.py",
    "content": "# --------------------------------------------------------\n# FocalNet for Semantic Segmentation\n# Copyright (c) 2022 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# Written by Jianwei Yang\n# --------------------------------------------------------\nimport math\nimport time\nimport numpy as np\nimport logging\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.utils.checkpoint as checkpoint\nfrom timm.models.layers import DropPath, to_2tuple, trunc_normal_\n\nfrom detectron2.utils.file_io import PathManager\nfrom detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec\n\nfrom .registry import register_backbone\n\nlogger = logging.getLogger(__name__)\n\nclass Mlp(nn.Module):\n    \"\"\" Multilayer perceptron.\"\"\"\n\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features, out_features)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\nclass FocalModulation(nn.Module):\n    \"\"\" Focal Modulation\n\n    Args:\n        dim (int): Number of input channels.\n        proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0\n        focal_level (int): Number of focal levels\n        focal_window (int): Focal window size at focal level 1\n        focal_factor (int, default=2): Step to increase the focal window\n        use_postln (bool, default=False): Whether use post-modulation layernorm\n    \"\"\"\n\n    def __init__(self, dim, proj_drop=0., focal_level=2, focal_window=7, focal_factor=2, use_postln=False, use_postln_in_modulation=False, scaling_modulator=False):\n\n        super().__init__()\n        self.dim = dim\n\n        # specific args for focalv3\n        self.focal_level = focal_level\n        self.focal_window = focal_window\n        self.focal_factor = focal_factor\n        self.use_postln_in_modulation = use_postln_in_modulation\n        self.scaling_modulator = scaling_modulator\n\n        self.f = nn.Linear(dim, 2*dim+(self.focal_level+1), bias=True)\n        self.h = nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0, groups=1, bias=True)\n\n        self.act = nn.GELU()\n        self.proj = nn.Linear(dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n        self.focal_layers = nn.ModuleList()\n\n        if self.use_postln_in_modulation:\n            self.ln = nn.LayerNorm(dim)\n\n        for k in range(self.focal_level):\n            kernel_size = self.focal_factor*k + self.focal_window\n            self.focal_layers.append(\n                nn.Sequential(\n                    nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=1, groups=dim, \n                        padding=kernel_size//2, bias=False),\n                    nn.GELU(),\n                    )\n                )\n\n    def forward(self, x):\n        \"\"\" Forward function.\n\n        Args:\n            x: input features with shape of (B, H, W, C)\n        \"\"\"\n        B, nH, nW, C = x.shape\n        x = self.f(x)\n        x = x.permute(0, 3, 1, 2).contiguous()\n        q, ctx, gates = torch.split(x, (C, C, self.focal_level+1), 1)\n        \n        ctx_all = 0\n        for 
l in range(self.focal_level):                     \n            ctx = self.focal_layers[l](ctx)\n            ctx_all = ctx_all + ctx*gates[:, l:l+1]\n        ctx_global = self.act(ctx.mean(2, keepdim=True).mean(3, keepdim=True))\n        ctx_all = ctx_all + ctx_global*gates[:,self.focal_level:]\n\n        if self.scaling_modulator:\n            ctx_all = ctx_all / (self.focal_level + 1)\n\n        x_out = q * self.h(ctx_all)\n        x_out = x_out.permute(0, 2, 3, 1).contiguous()\n        if self.use_postln_in_modulation:\n            x_out = self.ln(x_out)            \n        x_out = self.proj(x_out)\n        x_out = self.proj_drop(x_out)\n        return x_out\n\nclass FocalModulationBlock(nn.Module):\n    \"\"\" Focal Modulation Block.\n\n    Args:\n        dim (int): Number of input channels.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n        drop (float, optional): Dropout rate. Default: 0.0\n        drop_path (float, optional): Stochastic depth rate. Default: 0.0\n        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU\n        norm_layer (nn.Module, optional): Normalization layer.  
Default: nn.LayerNorm\n        focal_level (int): number of focal levels\n        focal_window (int): focal kernel size at level 1\n    \"\"\"\n\n    def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0., \n                 act_layer=nn.GELU, norm_layer=nn.LayerNorm,\n                 focal_level=2, focal_window=9, \n                 use_postln=False, use_postln_in_modulation=False,\n                 scaling_modulator=False, \n                 use_layerscale=False, \n                 layerscale_value=1e-4):\n        super().__init__()\n        self.dim = dim\n        self.mlp_ratio = mlp_ratio\n        self.focal_window = focal_window\n        self.focal_level = focal_level\n        self.use_postln = use_postln\n        self.use_layerscale = use_layerscale\n\n        self.dw1 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim)\n        self.norm1 = norm_layer(dim)\n        self.modulation = FocalModulation(\n            dim, focal_window=self.focal_window, focal_level=self.focal_level, proj_drop=drop, use_postln_in_modulation=use_postln_in_modulation, scaling_modulator=scaling_modulator\n        )            \n\n        self.dw2 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim)\n        self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)\n\n        self.H = None\n        self.W = None\n\n        self.gamma_1 = 1.0\n        self.gamma_2 = 1.0\n        if self.use_layerscale:\n            self.gamma_1 = nn.Parameter(layerscale_value * torch.ones((dim)), requires_grad=True)\n            self.gamma_2 = nn.Parameter(layerscale_value * torch.ones((dim)), requires_grad=True)\n\n    def forward(self, x):\n        \"\"\" Forward function.\n\n        Args:\n            x: Input feature, tensor size (B, H*W, C).\n            H, W: Spatial resolution of the input feature.\n        \"\"\"\n        B, L, C = x.shape\n        H, W = self.H, self.W\n        assert L == H * W, \"input feature has wrong size\"\n\n        x = x.view(B, H, W, C).permute(0, 3, 1, 2).contiguous()\n        x = x + self.dw1(x)\n        x = x.permute(0, 2, 3, 1).contiguous().view(B, L, C)\n\n        shortcut = x\n        if not self.use_postln:\n            x = self.norm1(x)\n        x = x.view(B, H, W, C)\n        \n        # FM\n        x = self.modulation(x).view(B, H * W, C)\n        x = shortcut + self.drop_path(self.gamma_1 * x)\n        if self.use_postln:\n            x = self.norm1(x)\n\n        x = x.view(B, H, W, C).permute(0, 3, 1, 2).contiguous()\n        x = x + self.dw2(x)\n        x = x.permute(0, 2, 3, 1).contiguous().view(B, L, C)\n\n        if not self.use_postln:\n            x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))        \n        else:\n            x = x + self.drop_path(self.gamma_2 * self.mlp(x))\n            x = self.norm2(x)\n\n        return x\n\nclass BasicLayer(nn.Module):\n    \"\"\" A basic focal modulation layer for one stage.\n\n    Args:\n        dim (int): Number of feature channels\n        depth (int): Depths of this stage.\n        mlp_ratio (float): Ratio of mlp hidden dim 
to embedding dim. Default: 4.\n        drop (float, optional): Dropout rate. Default: 0.0\n        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0\n        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm\n        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None\n        focal_level (int): Number of focal levels\n        focal_window (int): Focal window size at focal level 1\n        use_conv_embed (bool): Use overlapped convolution for patch embedding or now. Default: False\n        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False\n    \"\"\"\n\n    def __init__(self,\n                 dim,\n                 depth,\n                 mlp_ratio=4.,\n                 drop=0.,\n                 drop_path=0.,\n                 norm_layer=nn.LayerNorm,\n                 downsample=None,\n                 focal_window=9, \n                 focal_level=2, \n                 use_conv_embed=False,     \n                 use_postln=False,          \n                 use_postln_in_modulation=False, \n                 scaling_modulator=False,\n                 use_layerscale=False,                   \n                 use_checkpoint=False, \n                 use_pre_norm=False, \n        ):\n        super().__init__()\n        self.depth = depth\n        self.use_checkpoint = use_checkpoint\n\n        # build blocks\n        self.blocks = nn.ModuleList([\n            FocalModulationBlock(\n                dim=dim,\n                mlp_ratio=mlp_ratio,\n                drop=drop,\n                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,\n                focal_window=focal_window, \n                focal_level=focal_level, \n                use_postln=use_postln, \n                use_postln_in_modulation=use_postln_in_modulation, \n                scaling_modulator=scaling_modulator,\n                
use_layerscale=use_layerscale, \n                norm_layer=norm_layer)\n            for i in range(depth)])\n\n        # patch merging layer\n        if downsample is not None:\n            self.downsample = downsample(\n                patch_size=2,\n                in_chans=dim, embed_dim=2*dim, \n                use_conv_embed=use_conv_embed, \n                norm_layer=norm_layer, \n                is_stem=False, \n                use_pre_norm=use_pre_norm\n            )\n\n        else:\n            self.downsample = None\n\n    def forward(self, x, H, W):\n        \"\"\" Forward function.\n\n        Args:\n            x: Input feature, tensor size (B, H*W, C).\n            H, W: Spatial resolution of the input feature.\n        \"\"\"\n        for blk in self.blocks:\n            blk.H, blk.W = H, W\n            if self.use_checkpoint:\n                x = checkpoint.checkpoint(blk, x)\n            else:\n                x = blk(x)\n        if self.downsample is not None:\n            x_reshaped = x.transpose(1, 2).view(x.shape[0], x.shape[-1], H, W)\n            x_down = self.downsample(x_reshaped)   \n            x_down = x_down.flatten(2).transpose(1, 2)            \n            Wh, Ww = (H + 1) // 2, (W + 1) // 2\n            return x, H, W, x_down, Wh, Ww\n        else:\n            return x, H, W, x, H, W\n\n\n# class PatchEmbed(nn.Module):\n#     r\"\"\" Image to Patch Embedding\n\n#     Args:\n#         img_size (int): Image size.  Default: 224.\n#         patch_size (int): Patch token size. Default: 4.\n#         in_chans (int): Number of input image channels. Default: 3.\n#         embed_dim (int): Number of linear projection output channels. Default: 96.\n#         norm_layer (nn.Module, optional): Normalization layer. 
Default: None\n#     \"\"\"\n\n#     def __init__(self, img_size=(224, 224), patch_size=4, in_chans=3, embed_dim=96, \n#         use_conv_embed=False, norm_layer=None, is_stem=False, use_pre_norm=False):\n#         super().__init__()\n#         patch_size = to_2tuple(patch_size)\n#         patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]\n#         self.img_size = img_size\n#         self.patch_size = patch_size\n#         self.patches_resolution = patches_resolution\n#         self.num_patches = patches_resolution[0] * patches_resolution[1]\n\n#         self.in_chans = in_chans\n#         self.embed_dim = embed_dim\n#         self.use_pre_norm = use_pre_norm\n\n#         if use_conv_embed:\n#             # if we choose to use conv embedding, then we treat the stem and non-stem differently\n#             if is_stem:\n#                 kernel_size = 7; padding = 3; stride = 4\n#             else:\n#                 kernel_size = 3; padding = 1; stride = 2\n#             self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)\n#         else:\n#             self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)\n        \n#         if self.use_pre_norm:\n#             if norm_layer is not None:\n#                 self.norm = norm_layer(in_chans)\n#             else:\n#                 self.norm = None\n#         else:\n#             if norm_layer is not None:\n#                 self.norm = norm_layer(embed_dim)\n#             else:\n#                 self.norm = None\n\n#     def forward(self, x):\n#         B, C, H, W = x.shape\n#         # FIXME look at relaxing size constraints\n#         assert H == self.img_size[0] and W == self.img_size[1], \\\n#             f\"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).\"\n        \n#         if self.use_pre_norm:\n#             if self.norm is not None:\n#                 x = 
x.flatten(2).transpose(1, 2)  # B Ph*Pw C\n#                 x = self.norm(x).transpose(1, 2).view(B, C, H, W)\n#             x = self.proj(x).flatten(2).transpose(1, 2)\n#         else:\n#             x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C\n#             if self.norm is not None:\n#                 x = self.norm(x)\n#         return x\n\n#     def flops(self):\n#         Ho, Wo = self.patches_resolution\n#         flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])\n#         if self.norm is not None:\n#             flops += Ho * Wo * self.embed_dim\n#         return flops\n\nclass PatchEmbed(nn.Module):\n    \"\"\" Image to Patch Embedding\n\n    Args:\n        patch_size (int): Patch token size. Default: 4.\n        in_chans (int): Number of input image channels. Default: 3.\n        embed_dim (int): Number of linear projection output channels. Default: 96.\n        norm_layer (nn.Module, optional): Normalization layer. Default: None\n        use_conv_embed (bool): Whether use overlapped convolution for patch embedding. Default: False\n        is_stem (bool): Is the stem block or not. 
\n    \"\"\"\n\n    def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None, use_conv_embed=False, is_stem=False, use_pre_norm=False):\n        super().__init__()\n        patch_size = to_2tuple(patch_size)\n        self.patch_size = patch_size\n\n        self.in_chans = in_chans\n        self.embed_dim = embed_dim\n        self.use_pre_norm = use_pre_norm\n\n        if use_conv_embed:\n            # if we choose to use conv embedding, then we treat the stem and non-stem differently\n            if is_stem:\n                kernel_size = 7; padding = 3; stride = 4\n            else:\n                kernel_size = 3; padding = 1; stride = 2\n            self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)                    \n        else:\n            self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)\n\n        if self.use_pre_norm:\n            if norm_layer is not None:\n                self.norm = norm_layer(in_chans)\n            else:\n                self.norm = None       \n        else:\n            if norm_layer is not None:\n                self.norm = norm_layer(embed_dim)\n            else:\n                self.norm = None\n\n    def forward(self, x):\n        \"\"\"Forward function.\"\"\"\n        B, C, H, W = x.size()\n        if W % self.patch_size[1] != 0:\n            x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))\n        if H % self.patch_size[0] != 0:\n            x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))\n\n        if self.use_pre_norm:\n            if self.norm is not None:\n                x = x.flatten(2).transpose(1, 2)  # B Ph*Pw C\n                x = self.norm(x).transpose(1, 2).view(B, C, H, W)\n            x = self.proj(x)\n        else:\n            x = self.proj(x)  # B C Wh Ww\n            if self.norm is not None:\n                Wh, Ww = x.size(2), x.size(3)\n                x = 
x.flatten(2).transpose(1, 2)\n                x = self.norm(x)\n                x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)\n\n        return x\n\n\nclass FocalNet(nn.Module):\n    \"\"\" FocalNet backbone.\n\n    Args:\n        pretrain_img_size (int): Input image size for training the pretrained model,\n            used in absolute postion embedding. Default 224.\n        patch_size (int | tuple(int)): Patch size. Default: 4.\n        in_chans (int): Number of input image channels. Default: 3.\n        embed_dim (int): Number of linear projection output channels. Default: 96.\n        depths (tuple[int]): Depths of each Swin Transformer stage.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.\n        drop_rate (float): Dropout rate.\n        drop_path_rate (float): Stochastic depth rate. Default: 0.2.\n        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.\n        patch_norm (bool): If True, add normalization after patch embedding. Default: True.\n        out_indices (Sequence[int]): Output from which stages.\n        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).\n            -1 means not freezing any parameters.\n        focal_levels (Sequence[int]): Number of focal levels at four stages\n        focal_windows (Sequence[int]): Focal window sizes at first focal level at four stages\n        use_conv_embed (bool): Whether use overlapped convolution for patch embedding\n        use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False.\n    \"\"\"\n\n    def __init__(self,\n                 pretrain_img_size=1600,\n                 patch_size=4,\n                 in_chans=3,\n                 embed_dim=96,\n                 depths=[2, 2, 6, 2],\n                 mlp_ratio=4.,\n                 drop_rate=0.,\n                 drop_path_rate=0.2,\n                 norm_layer=nn.LayerNorm,\n                 patch_norm=True,\n                 out_indices=[0, 1, 2, 3],\n                 frozen_stages=-1,\n                 focal_levels=[2,2,2,2], \n                 focal_windows=[9,9,9,9],\n                 use_pre_norms=[False, False, False, False], \n                 use_conv_embed=False, \n                 use_postln=False, \n                 use_postln_in_modulation=False, \n                 scaling_modulator=False,\n                 use_layerscale=False, \n                 use_checkpoint=False, \n        ):\n        super().__init__()\n\n        self.pretrain_img_size = pretrain_img_size\n        self.num_layers = len(depths)\n        self.embed_dim = embed_dim\n        self.patch_norm = patch_norm\n        self.out_indices = out_indices\n        self.frozen_stages = frozen_stages\n\n        # split image into non-overlapping patches\n        self.patch_embed = PatchEmbed(\n            patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,\n            norm_layer=norm_layer if self.patch_norm else None, \n            use_conv_embed=use_conv_embed, is_stem=True, use_pre_norm=False)\n\n        self.pos_drop = nn.Dropout(p=drop_rate)\n\n        # stochastic depth\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule\n\n        # build layers\n        self.layers = nn.ModuleList()\n        for i_layer in range(self.num_layers):\n            layer = BasicLayer(\n                dim=int(embed_dim * 2 ** i_layer),\n                depth=depths[i_layer],\n                mlp_ratio=mlp_ratio,\n                
drop=drop_rate,\n                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],\n                norm_layer=norm_layer,\n                downsample=PatchEmbed if (i_layer < self.num_layers - 1) else None,\n                focal_window=focal_windows[i_layer], \n                focal_level=focal_levels[i_layer], \n                use_pre_norm=use_pre_norms[i_layer], \n                use_conv_embed=use_conv_embed,\n                use_postln=use_postln, \n                use_postln_in_modulation=use_postln_in_modulation,\n                scaling_modulator=scaling_modulator,\n                use_layerscale=use_layerscale, \n                use_checkpoint=use_checkpoint)\n            self.layers.append(layer)\n\n        num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]\n        self.num_features = num_features        \n        # self.norm = norm_layer(num_features[-1])\n\n        # add a norm layer for each output\n        for i_layer in self.out_indices:\n            layer = norm_layer(num_features[i_layer])\n            layer_name = f'norm{i_layer}'\n            self.add_module(layer_name, layer)\n\n        self._freeze_stages()\n\n    def _freeze_stages(self):\n        if self.frozen_stages >= 0:\n            self.patch_embed.eval()\n            for param in self.patch_embed.parameters():\n                param.requires_grad = False\n\n        if self.frozen_stages >= 2:\n            self.pos_drop.eval()\n            for i in range(0, self.frozen_stages - 1):\n                m = self.layers[i]\n                m.eval()\n                for param in m.parameters():\n                    param.requires_grad = False\n\n    def init_weights(self, pretrained=None):\n        \"\"\"Initialize the weights in backbone.\n\n        Args:\n            pretrained (str, optional): Path to pre-trained weights.\n                Defaults to None.\n        \"\"\"\n\n        def _init_weights(m):\n            if isinstance(m, nn.Linear):\n       
         trunc_normal_(m.weight, std=.02)\n                if isinstance(m, nn.Linear) and m.bias is not None:\n                    nn.init.constant_(m.bias, 0)\n            elif isinstance(m, nn.LayerNorm):\n                nn.init.constant_(m.bias, 0)\n                nn.init.constant_(m.weight, 1.0)\n\n        if isinstance(pretrained, str):\n            self.apply(_init_weights)\n            logger = get_root_logger()\n            load_checkpoint(self, pretrained, strict=False, logger=logger)\n        elif pretrained is None:\n            self.apply(_init_weights)\n        else:\n            raise TypeError('pretrained must be a str or None')\n\n    def load_weights(self, pretrained_dict=None, pretrained_layers=[], verbose=True):\n        model_dict = self.state_dict()\n\n        missed_dict = [k for k in model_dict.keys() if k not in pretrained_dict]\n        logger.info(f'=> Missed keys {missed_dict}')\n        unexpected_dict = [k for k in pretrained_dict.keys() if k not in model_dict]\n        logger.info(f'=> Unexpected keys {unexpected_dict}')\n\n        pretrained_dict = {\n            k: v for k, v in pretrained_dict.items()\n            if k in model_dict.keys()\n        }\n        \n        need_init_state_dict = {}\n        for k, v in pretrained_dict.items():\n            need_init = (\n                (\n                    k.split('.')[0] in pretrained_layers\n                    or pretrained_layers[0] == '*'\n                )\n                and 'relative_position_index' not in k\n                and 'attn_mask' not in k\n            )\n\n            if need_init:\n                # if verbose:\n                #     logger.info(f'=> init {k} from {pretrained}')\n\n                if ('pool_layers' in k) or ('focal_layers' in k) and v.size() != model_dict[k].size():\n                    table_pretrained = v\n                    table_current = model_dict[k]\n                    fsize1 = table_pretrained.shape[2]\n                    fsize2 = 
table_current.shape[2]\n\n                    # NOTE: different from interpolation used in self-attention, we use padding or clipping for focal conv\n                    if fsize1 < fsize2:\n                        table_pretrained_resized = torch.zeros(table_current.shape)\n                        table_pretrained_resized[:, :, (fsize2-fsize1)//2:-(fsize2-fsize1)//2, (fsize2-fsize1)//2:-(fsize2-fsize1)//2] = table_pretrained\n                        v = table_pretrained_resized\n                    elif fsize1 > fsize2:\n                        table_pretrained_resized = table_pretrained[:, :, (fsize1-fsize2)//2:-(fsize1-fsize2)//2, (fsize1-fsize2)//2:-(fsize1-fsize2)//2]\n                        v = table_pretrained_resized\n\n\n                if (\"modulation.f\" in k or \"pre_conv\" in k): \n                    table_pretrained = v\n                    table_current = model_dict[k]\n                    if table_pretrained.shape != table_current.shape:\n                        if len(table_pretrained.shape) == 2:\n                            dim = table_pretrained.shape[1]\n                            assert table_current.shape[1] == dim\n                            L1 = table_pretrained.shape[0]\n                            L2 = table_current.shape[0]\n\n                            if L1 < L2:\n                                table_pretrained_resized = torch.zeros(table_current.shape)\n                                # copy for linear project\n                                table_pretrained_resized[:2*dim] = table_pretrained[:2*dim]\n                                # copy for global token gating\n                                table_pretrained_resized[-1] = table_pretrained[-1]\n                                # copy for first multiple focal levels\n                                table_pretrained_resized[2*dim:2*dim+(L1-2*dim-1)] = table_pretrained[2*dim:-1]\n                                # reassign pretrained weights\n                                v = 
table_pretrained_resized\n                            elif L1 > L2:\n                                raise NotImplementedError\n                        elif len(table_pretrained.shape) == 1:\n                            dim = table_pretrained.shape[0]\n                            L1 = table_pretrained.shape[0]\n                            L2 = table_current.shape[0]\n                            if L1 < L2:\n                                table_pretrained_resized = torch.zeros(table_current.shape)\n                                # copy for linear project\n                                table_pretrained_resized[:dim] = table_pretrained[:dim]\n                                # copy for global token gating\n                                table_pretrained_resized[-1] = table_pretrained[-1]\n                                # copy for first multiple focal levels\n                                # table_pretrained_resized[dim:2*dim+(L1-2*dim-1)] = table_pretrained[2*dim:-1]\n                                # reassign pretrained weights\n                                v = table_pretrained_resized\n                            elif L1 > L2:\n                                raise NotImplementedError    \n\n                need_init_state_dict[k] = v\n        \n        self.load_state_dict(need_init_state_dict, strict=False)\n\n\n    def forward(self, x):\n        \"\"\"Forward function.\"\"\"\n        tic = time.time()\n        x = self.patch_embed(x)\n        Wh, Ww = x.size(2), x.size(3)\n\n        x = x.flatten(2).transpose(1, 2)\n        x = self.pos_drop(x)\n\n        outs = {}\n        for i in range(self.num_layers):\n            layer = self.layers[i]\n            x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)\n            if i in self.out_indices:\n                norm_layer = getattr(self, f'norm{i}')\n                x_out = norm_layer(x_out)\n\n                out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()\n                
outs[\"res{}\".format(i + 2)] = out\n                \n        if len(self.out_indices) == 0:\n            outs[\"res5\"] = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()\n\n        toc = time.time()\n        return outs\n\n    def train(self, mode=True):\n        \"\"\"Convert the model into training mode while keep layers freezed.\"\"\"\n        super(FocalNet, self).train(mode)\n        self._freeze_stages()\n\n\nclass D2FocalNet(FocalNet, Backbone):\n    def __init__(self, cfg, input_shape):\n\n        pretrain_img_size = cfg['BACKBONE']['FOCAL']['PRETRAIN_IMG_SIZE']\n        patch_size = cfg['BACKBONE']['FOCAL']['PATCH_SIZE']\n        in_chans = 3\n        embed_dim = cfg['BACKBONE']['FOCAL']['EMBED_DIM']\n        depths = cfg['BACKBONE']['FOCAL']['DEPTHS']\n        mlp_ratio = cfg['BACKBONE']['FOCAL']['MLP_RATIO']\n        drop_rate = cfg['BACKBONE']['FOCAL']['DROP_RATE']\n        drop_path_rate = cfg['BACKBONE']['FOCAL']['DROP_PATH_RATE']\n        norm_layer = nn.LayerNorm\n        patch_norm = cfg['BACKBONE']['FOCAL']['PATCH_NORM']\n        use_checkpoint = cfg['BACKBONE']['FOCAL']['USE_CHECKPOINT']\n        out_indices = cfg['BACKBONE']['FOCAL']['OUT_INDICES']\n        scaling_modulator = cfg['BACKBONE']['FOCAL'].get('SCALING_MODULATOR', False)\n\n        super().__init__(\n            pretrain_img_size,\n            patch_size,\n            in_chans,\n            embed_dim,\n            depths,\n            mlp_ratio,\n            drop_rate,\n            drop_path_rate,\n            norm_layer,\n            patch_norm,\n            out_indices,\n            focal_levels=cfg['BACKBONE']['FOCAL']['FOCAL_LEVELS'],\n            focal_windows=cfg['BACKBONE']['FOCAL']['FOCAL_WINDOWS'],   \n            use_conv_embed=cfg['BACKBONE']['FOCAL']['USE_CONV_EMBED'],    \n            use_postln=cfg['BACKBONE']['FOCAL']['USE_POSTLN'],       \n            use_postln_in_modulation=cfg['BACKBONE']['FOCAL']['USE_POSTLN_IN_MODULATION'], \n        
    scaling_modulator=scaling_modulator,\n            use_layerscale=cfg['BACKBONE']['FOCAL']['USE_LAYERSCALE'], \n            use_checkpoint=use_checkpoint,\n        )\n\n        self._out_features = cfg['BACKBONE']['FOCAL']['OUT_FEATURES']\n\n        self._out_feature_strides = {\n            \"res2\": 4,\n            \"res3\": 8,\n            \"res4\": 16,\n            \"res5\": 32,\n        }\n        self._out_feature_channels = {\n            \"res2\": self.num_features[0],\n            \"res3\": self.num_features[1],\n            \"res4\": self.num_features[2],\n            \"res5\": self.num_features[3],\n        }\n\n    def forward(self, x):\n        \"\"\"\n        Args:\n            x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.\n        Returns:\n            dict[str->Tensor]: names and the corresponding features\n        \"\"\"\n        assert (\n            x.dim() == 4\n        ), f\"SwinTransformer takes an input of shape (N, C, H, W). 
Got {x.shape} instead!\"\n        outputs = {}\n        y = super().forward(x)\n        for k in y.keys():\n            if k in self._out_features:\n                outputs[k] = y[k]\n        return outputs\n\n    def output_shape(self):\n        return {\n            name: ShapeSpec(\n                channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]\n            )\n            for name in self._out_features\n        }\n\n    @property\n    def size_divisibility(self):\n        return 32\n\n@register_backbone\ndef get_focal_backbone(cfg):\n    focal = D2FocalNet(cfg['MODEL'], 224)    \n\n    if cfg['MODEL']['BACKBONE']['LOAD_PRETRAINED'] is True:\n        filename = cfg['MODEL']['BACKBONE']['PRETRAINED']\n        logger.info(f'=> init from {filename}')\n        with PathManager.open(filename, \"rb\") as f:\n            ckpt = torch.load(f)['model']\n        focal.load_weights(ckpt, cfg['MODEL']['BACKBONE']['FOCAL'].get('PRETRAINED_LAYERS', ['*']), cfg['VERBOSE'])\n\n    return focal"
  },
  {
    "path": "llava/model/semsam/backbone/registry.py",
    "content": "\"\"\"Backbone registry.\n\nMaps a short backbone name to the builder function that creates it. A builder\nis registered under the last component of its defining module's dotted name,\ne.g. a function defined in ``pkg.backbone.swin`` is registered as ``swin``.\n\"\"\"\n\n_model_entrypoints = {}\n\n\ndef register_backbone(fn):\n    \"\"\"Decorator that registers ``fn`` under its module's short name.\n\n    Returns ``fn`` unchanged so it can be used as ``@register_backbone``.\n    Registering a second builder from a module with the same short name\n    replaces the earlier entry.\n    \"\"\"\n    model_name = fn.__module__.split('.')[-1]\n    _model_entrypoints[model_name] = fn\n    return fn\n\n\ndef model_entrypoints(model_name):\n    \"\"\"Return the builder registered under ``model_name``.\n\n    Raises:\n        KeyError: if nothing is registered under ``model_name``; the message\n            lists the names that are registered.\n    \"\"\"\n    try:\n        return _model_entrypoints[model_name]\n    except KeyError:\n        available = ', '.join(sorted(_model_entrypoints)) or '<none>'\n        raise KeyError(\n            f\"Unknown backbone '{model_name}'; registered backbones: {available}\"\n        ) from None\n\n\ndef is_model(model_name):\n    \"\"\"Return True if a builder is registered under ``model_name``.\"\"\"\n    return model_name in _model_entrypoints\n"
  },
  {
    "path": "llava/model/semsam/backbone/swin.py",
    "content": "# --------------------------------------------------------\n# Swin Transformer\n# Copyright (c) 2021 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# Written by Ze Liu, Yutong Lin, Yixuan Wei\n# --------------------------------------------------------\n\n# Copyright (c) Facebook, Inc. and its affiliates.\n# Modified by Bowen Cheng from https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation/blob/main/mmseg/models/backbones/swin_transformer.py\nimport logging\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.utils.checkpoint as checkpoint\nfrom timm.models.layers import DropPath, to_2tuple, trunc_normal_\n\nfrom detectron2.modeling import Backbone, ShapeSpec\nfrom detectron2.utils.file_io import PathManager\n\nfrom .registry import register_backbone\n\nlogger = logging.getLogger(__name__)\n\n\nclass Mlp(nn.Module):\n    \"\"\"Multilayer perceptron.\"\"\"\n\n    def __init__(\n        self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0\n    ):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features, out_features)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\n\ndef window_partition(x, window_size):\n    \"\"\"\n    Args:\n        x: (B, H, W, C)\n        window_size (int): window size\n    Returns:\n        windows: (num_windows*B, window_size, window_size, C)\n    \"\"\"\n    B, H, W, C = x.shape\n    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)\n    windows = x.permute(0, 1, 3, 2, 4, 
5).contiguous().view(-1, window_size, window_size, C)\n    return windows\n\n\ndef window_reverse(windows, window_size, H, W):\n    \"\"\"\n    Args:\n        windows: (num_windows*B, window_size, window_size, C)\n        window_size (int): Window size\n        H (int): Height of image\n        W (int): Width of image\n    Returns:\n        x: (B, H, W, C)\n    \"\"\"\n    B = int(windows.shape[0] / (H * W / window_size / window_size))\n    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)\n    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)\n    return x\n\n\nclass WindowAttention(nn.Module):\n    \"\"\"Window based multi-head self attention (W-MSA) module with relative position bias.\n    It supports both of shifted and non-shifted window.\n    Args:\n        dim (int): Number of input channels.\n        window_size (tuple[int]): The height and width of the window.\n        num_heads (int): Number of attention heads.\n        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True\n        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set\n        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0\n        proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0\n    \"\"\"\n\n    def __init__(\n        self,\n        dim,\n        window_size,\n        num_heads,\n        qkv_bias=True,\n        qk_scale=None,\n        attn_drop=0.0,\n        proj_drop=0.0,\n    ):\n\n        super().__init__()\n        self.dim = dim\n        self.window_size = window_size  # Wh, Ww\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        self.scale = qk_scale or head_dim ** -0.5\n\n        # define a parameter table of relative position bias\n        self.relative_position_bias_table = nn.Parameter(\n            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)\n        )  # 2*Wh-1 * 2*Ww-1, nH\n\n        # get pair-wise relative position index for each token inside the window\n        coords_h = torch.arange(self.window_size[0])\n        coords_w = torch.arange(self.window_size[1])\n        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww\n        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww\n        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww\n        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2\n        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0\n        relative_coords[:, :, 1] += self.window_size[1] - 1\n        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1\n        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww\n        self.register_buffer(\"relative_position_index\", relative_position_index)\n\n        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n\n        trunc_normal_(self.relative_position_bias_table, std=0.02)\n        self.softmax = nn.Softmax(dim=-1)\n\n    def forward(self, x, mask=None):\n        \"\"\"Forward function.\n        Args:\n    
        x: input features with shape of (num_windows*B, N, C)\n            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None\n        \"\"\"\n        B_, N, C = x.shape\n        qkv = (\n            self.qkv(x)\n            .reshape(B_, N, 3, self.num_heads, C // self.num_heads)\n            .permute(2, 0, 3, 1, 4)\n        )\n        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)\n\n        q = q * self.scale\n        attn = q @ k.transpose(-2, -1)\n        \n        relative_position_bias = self.relative_position_bias_table[\n            self.relative_position_index.view(-1)\n        ].view(\n            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1\n        )  # Wh*Ww,Wh*Ww,nH\n        relative_position_bias = relative_position_bias.permute(\n            2, 0, 1\n        ).contiguous()  # nH, Wh*Ww, Wh*Ww\n        attn = attn + relative_position_bias.unsqueeze(0)\n\n        if mask is not None:\n            nW = mask.shape[0]\n            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)\n            attn = attn.view(-1, self.num_heads, N, N)\n            attn = self.softmax(attn)\n        else:\n            attn = self.softmax(attn)\n\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n\n        return x\n\n\nclass SwinTransformerBlock(nn.Module):\n    \"\"\"Swin Transformer Block.\n    Args:\n        dim (int): Number of input channels.\n        num_heads (int): Number of attention heads.\n        window_size (int): Window size.\n        shift_size (int): Shift size for SW-MSA.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. 
Default: True\n        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.\n        drop (float, optional): Dropout rate. Default: 0.0\n        attn_drop (float, optional): Attention dropout rate. Default: 0.0\n        drop_path (float, optional): Stochastic depth rate. Default: 0.0\n        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU\n        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm\n    \"\"\"\n\n    def __init__(\n        self,\n        dim,\n        num_heads,\n        window_size=7,\n        shift_size=0,\n        mlp_ratio=4.0,\n        qkv_bias=True,\n        qk_scale=None,\n        drop=0.0,\n        attn_drop=0.0,\n        drop_path=0.0,\n        act_layer=nn.GELU,\n        norm_layer=nn.LayerNorm,\n    ):\n        super().__init__()\n        self.dim = dim\n        self.num_heads = num_heads\n        self.window_size = window_size\n        self.shift_size = shift_size\n        self.mlp_ratio = mlp_ratio\n        assert 0 <= self.shift_size < self.window_size, \"shift_size must in 0-window_size\"\n\n        self.norm1 = norm_layer(dim)\n        self.attn = WindowAttention(\n            dim,\n            window_size=to_2tuple(self.window_size),\n            num_heads=num_heads,\n            qkv_bias=qkv_bias,\n            qk_scale=qk_scale,\n            attn_drop=attn_drop,\n            proj_drop=drop,\n        )\n\n        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = Mlp(\n            in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop\n        )\n\n        self.H = None\n        self.W = None\n\n    def forward(self, x, mask_matrix):\n        \"\"\"Forward function.\n        Args:\n            x: Input feature, tensor size (B, H*W, C).\n            H, W: Spatial resolution of the input feature.\n    
        mask_matrix: Attention mask for cyclic shift.\n        \"\"\"\n        B, L, C = x.shape\n        H, W = self.H, self.W\n        assert L == H * W, \"input feature has wrong size\"\n\n        # HACK model will not upsampling\n        # if min([H, W]) <= self.window_size:\n            # if window size is larger than input resolution, we don't partition windows\n            # self.shift_size = 0\n            # self.window_size = min([H,W])\n\n        shortcut = x\n        x = self.norm1(x)\n        x = x.view(B, H, W, C)\n\n        # pad feature maps to multiples of window size\n        pad_l = pad_t = 0\n        pad_r = (self.window_size - W % self.window_size) % self.window_size\n        pad_b = (self.window_size - H % self.window_size) % self.window_size\n        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))\n        _, Hp, Wp, _ = x.shape\n\n        # cyclic shift\n        if self.shift_size > 0:\n            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))\n            attn_mask = mask_matrix\n        else:\n            shifted_x = x\n            attn_mask = None\n\n        # partition windows\n        x_windows = window_partition(\n            shifted_x, self.window_size\n        )  # nW*B, window_size, window_size, C\n        x_windows = x_windows.view(\n            -1, self.window_size * self.window_size, C\n        )  # nW*B, window_size*window_size, C\n\n        # W-MSA/SW-MSA\n        attn_windows = self.attn(x_windows, mask=attn_mask)  # nW*B, window_size*window_size, C\n\n        # merge windows\n        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)\n        shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)  # B H' W' C\n\n        # reverse cyclic shift\n        if self.shift_size > 0:\n            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))\n        else:\n            x = shifted_x\n\n        if pad_r > 0 or pad_b > 0:\n     
       x = x[:, :H, :W, :].contiguous()\n\n        x = x.view(B, H * W, C)\n\n        # FFN\n        x = shortcut + self.drop_path(x)\n        x = x + self.drop_path(self.mlp(self.norm2(x)))\n        return x\n\n\nclass PatchMerging(nn.Module):\n    \"\"\"Patch Merging Layer\n    Args:\n        dim (int): Number of input channels.\n        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm\n    \"\"\"\n\n    def __init__(self, dim, norm_layer=nn.LayerNorm):\n        super().__init__()\n        self.dim = dim\n        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)\n        self.norm = norm_layer(4 * dim)\n\n    def forward(self, x, H, W):\n        \"\"\"Forward function.\n        Args:\n            x: Input feature, tensor size (B, H*W, C).\n            H, W: Spatial resolution of the input feature.\n        \"\"\"\n        B, L, C = x.shape\n        assert L == H * W, \"input feature has wrong size\"\n\n        x = x.view(B, H, W, C)\n\n        # padding\n        pad_input = (H % 2 == 1) or (W % 2 == 1)\n        if pad_input:\n            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))\n\n        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C\n        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C\n        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C\n        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C\n        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C\n        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C\n\n        x = self.norm(x)\n        x = self.reduction(x)\n\n        return x\n\n\nclass BasicLayer(nn.Module):\n    \"\"\"A basic Swin Transformer layer for one stage.\n    Args:\n        dim (int): Number of feature channels\n        depth (int): Depths of this stage.\n        num_heads (int): Number of attention head.\n        window_size (int): Local window size. Default: 7.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.\n        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. 
Default: True\n        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.\n        drop (float, optional): Dropout rate. Default: 0.0\n        attn_drop (float, optional): Attention dropout rate. Default: 0.0\n        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0\n        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm\n        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None\n        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.\n    \"\"\"\n\n    def __init__(\n        self,\n        dim,\n        depth,\n        num_heads,\n        window_size=7,\n        mlp_ratio=4.0,\n        qkv_bias=True,\n        qk_scale=None,\n        drop=0.0,\n        attn_drop=0.0,\n        drop_path=0.0,\n        norm_layer=nn.LayerNorm,\n        downsample=None,\n        use_checkpoint=False,\n    ):\n        super().__init__()\n        self.window_size = window_size\n        self.shift_size = window_size // 2\n        self.depth = depth\n        self.use_checkpoint = use_checkpoint\n\n        # build blocks\n        self.blocks = nn.ModuleList(\n            [\n                SwinTransformerBlock(\n                    dim=dim,\n                    num_heads=num_heads,\n                    window_size=window_size,\n                    shift_size=0 if (i % 2 == 0) else window_size // 2,\n                    mlp_ratio=mlp_ratio,\n                    qkv_bias=qkv_bias,\n                    qk_scale=qk_scale,\n                    drop=drop,\n                    attn_drop=attn_drop,\n                    drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,\n                    norm_layer=norm_layer,\n                )\n                for i in range(depth)\n            ]\n        )\n\n        # patch merging layer\n        if downsample is not None:\n            self.downsample = 
downsample(dim=dim, norm_layer=norm_layer)\n        else:\n            self.downsample = None\n\n    def forward(self, x, H, W):\n        \"\"\"Forward function.\n        Args:\n            x: Input feature, tensor size (B, H*W, C).\n            H, W: Spatial resolution of the input feature.\n        \"\"\"\n\n        # calculate attention mask for SW-MSA\n        Hp = int(np.ceil(H / self.window_size)) * self.window_size\n        Wp = int(np.ceil(W / self.window_size)) * self.window_size\n        img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # 1 Hp Wp 1\n        h_slices = (\n            slice(0, -self.window_size),\n            slice(-self.window_size, -self.shift_size),\n            slice(-self.shift_size, None),\n        )\n        w_slices = (\n            slice(0, -self.window_size),\n            slice(-self.window_size, -self.shift_size),\n            slice(-self.shift_size, None),\n        )\n        cnt = 0\n        for h in h_slices:\n            for w in w_slices:\n                img_mask[:, h, w, :] = cnt\n                cnt += 1\n\n        mask_windows = window_partition(\n            img_mask, self.window_size\n        )  # nW, window_size, window_size, 1\n        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)\n        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)\n        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(\n            attn_mask == 0, float(0.0)\n        ).type(x.dtype)\n        \n        for blk in self.blocks:\n            blk.H, blk.W = H, W\n            if self.use_checkpoint:\n                x = checkpoint.checkpoint(blk, x, attn_mask)\n            else:\n                x = blk(x, attn_mask)\n        if self.downsample is not None:\n            x_down = self.downsample(x, H, W)\n            Wh, Ww = (H + 1) // 2, (W + 1) // 2\n            return x, H, W, x_down, Wh, Ww\n        else:\n            return x, H, W, x, H, W\n\n\nclass 
PatchEmbed(nn.Module):\n    \"\"\"Image to Patch Embedding\n    Args:\n        patch_size (int): Patch token size. Default: 4.\n        in_chans (int): Number of input image channels. Default: 3.\n        embed_dim (int): Number of linear projection output channels. Default: 96.\n        norm_layer (nn.Module, optional): Normalization layer. Default: None\n    \"\"\"\n\n    def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):\n        super().__init__()\n        patch_size = to_2tuple(patch_size)\n        self.patch_size = patch_size\n\n        self.in_chans = in_chans\n        self.embed_dim = embed_dim\n\n        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)\n        if norm_layer is not None:\n            self.norm = norm_layer(embed_dim)\n        else:\n            self.norm = None\n\n    def forward(self, x):\n        \"\"\"Forward function.\"\"\"\n        # padding\n        _, _, H, W = x.size()\n        if W % self.patch_size[1] != 0:\n            x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))\n        if H % self.patch_size[0] != 0:\n            x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))\n\n        x = self.proj(x)  # B C Wh Ww\n        if self.norm is not None:\n            Wh, Ww = x.size(2), x.size(3)\n            x = x.flatten(2).transpose(1, 2)\n            x = self.norm(x)\n            x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)\n\n        return x\n\n\nclass SwinTransformer(nn.Module):\n    \"\"\"Swin Transformer backbone.\n        A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`  -\n          https://arxiv.org/pdf/2103.14030\n    Args:\n        pretrain_img_size (int): Input image size for training the pretrained model,\n            used in absolute postion embedding. Default 224.\n        patch_size (int | tuple(int)): Patch size. 
Default: 4.\n        in_chans (int): Number of input image channels. Default: 3.\n        embed_dim (int): Number of linear projection output channels. Default: 96.\n        depths (tuple[int]): Depths of each Swin Transformer stage.\n        num_heads (tuple[int]): Number of attention head of each stage.\n        window_size (int): Window size. Default: 7.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.\n        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True\n        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.\n        drop_rate (float): Dropout rate.\n        attn_drop_rate (float): Attention dropout rate. Default: 0.\n        drop_path_rate (float): Stochastic depth rate. Default: 0.2.\n        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.\n        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.\n        patch_norm (bool): If True, add normalization after patch embedding. Default: True.\n        out_indices (Sequence[int]): Output from which stages.\n        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).\n            -1 means not freezing any parameters.\n        use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False.\n    \"\"\"\n\n    def __init__(\n        self,\n        pretrain_img_size=224,\n        patch_size=4,\n        in_chans=3,\n        embed_dim=96,\n        depths=[2, 2, 6, 2],\n        num_heads=[3, 6, 12, 24],\n        window_size=7,\n        mlp_ratio=4.0,\n        qkv_bias=True,\n        qk_scale=None,\n        drop_rate=0.0,\n        attn_drop_rate=0.0,\n        drop_path_rate=0.2,\n        norm_layer=nn.LayerNorm,\n        ape=False,\n        patch_norm=True,\n        out_indices=(0, 1, 2, 3),\n        frozen_stages=-1,\n        use_checkpoint=False,\n    ):\n        super().__init__()\n\n        self.pretrain_img_size = pretrain_img_size\n        self.num_layers = len(depths)\n        self.embed_dim = embed_dim\n        self.ape = ape\n        self.patch_norm = patch_norm\n        self.out_indices = out_indices\n        self.frozen_stages = frozen_stages\n\n        # split image into non-overlapping patches\n        self.patch_embed = PatchEmbed(\n            patch_size=patch_size,\n            in_chans=in_chans,\n            embed_dim=embed_dim,\n            norm_layer=norm_layer if self.patch_norm else None,\n        )\n\n        # absolute position embedding\n        if self.ape:\n            pretrain_img_size = to_2tuple(pretrain_img_size)\n            patch_size = to_2tuple(patch_size)\n            patches_resolution = [\n                pretrain_img_size[0] // patch_size[0],\n                pretrain_img_size[1] // patch_size[1],\n            ]\n\n            self.absolute_pos_embed = nn.Parameter(\n                torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])\n            )\n            trunc_normal_(self.absolute_pos_embed, std=0.02)\n\n        self.pos_drop = nn.Dropout(p=drop_rate)\n\n        # stochastic depth\n        dpr = [\n            x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))\n        ]  # stochastic depth decay rule\n\n        # build layers\n        self.layers = 
nn.ModuleList()\n        for i_layer in range(self.num_layers):\n            layer = BasicLayer(\n                dim=int(embed_dim * 2 ** i_layer),\n                depth=depths[i_layer],\n                num_heads=num_heads[i_layer],\n                window_size=window_size,\n                mlp_ratio=mlp_ratio,\n                qkv_bias=qkv_bias,\n                qk_scale=qk_scale,\n                drop=drop_rate,\n                attn_drop=attn_drop_rate,\n                drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],\n                norm_layer=norm_layer,\n                downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,\n                use_checkpoint=use_checkpoint,\n            )\n            self.layers.append(layer)\n\n        num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]\n        self.num_features = num_features\n\n        # add a norm layer for each output\n        for i_layer in out_indices:\n            layer = norm_layer(num_features[i_layer])\n            layer_name = f\"norm{i_layer}\"\n            self.add_module(layer_name, layer)\n\n        self._freeze_stages()\n\n    def _freeze_stages(self):\n        if self.frozen_stages >= 0:\n            self.patch_embed.eval()\n            for param in self.patch_embed.parameters():\n                param.requires_grad = False\n\n        if self.frozen_stages >= 1 and self.ape:\n            self.absolute_pos_embed.requires_grad = False\n\n        if self.frozen_stages >= 2:\n            self.pos_drop.eval()\n            for i in range(0, self.frozen_stages - 1):\n                m = self.layers[i]\n                m.eval()\n                for param in m.parameters():\n                    param.requires_grad = False\n\n    def init_weights(self, pretrained=None):\n        \"\"\"Initialize the weights in backbone.\n        Args:\n            pretrained (str, optional): Path to pre-trained weights.\n                Defaults to None.\n       
 \"\"\"\n\n        def _init_weights(m):\n            if isinstance(m, nn.Linear):\n                trunc_normal_(m.weight, std=0.02)\n                if isinstance(m, nn.Linear) and m.bias is not None:\n                    nn.init.constant_(m.bias, 0)\n            elif isinstance(m, nn.LayerNorm):\n                nn.init.constant_(m.bias, 0)\n                nn.init.constant_(m.weight, 1.0)\n\n\n    def load_weights(self, pretrained_dict=None, pretrained_layers=[], verbose=True):\n        model_dict = self.state_dict()\n        pretrained_dict = {\n            k: v for k, v in pretrained_dict.items()\n            if k in model_dict.keys()\n        }\n        need_init_state_dict = {}\n        for k, v in pretrained_dict.items():\n            need_init = (\n                    (\n                            k.split('.')[0] in pretrained_layers\n                            or pretrained_layers[0] == '*'\n                    )\n                    and 'relative_position_index' not in k\n                    and 'attn_mask' not in k\n            )\n\n            if need_init:\n                # if verbose:\n                #     logger.info(f'=> init {k} from {pretrained}')\n\n                if 'relative_position_bias_table' in k and v.size() != model_dict[k].size():\n                    relative_position_bias_table_pretrained = v\n                    relative_position_bias_table_current = model_dict[k]\n                    L1, nH1 = relative_position_bias_table_pretrained.size()\n                    L2, nH2 = relative_position_bias_table_current.size()\n                    if nH1 != nH2:\n                        logger.info(f\"Error in loading {k}, passing\")\n                    else:\n                        if L1 != L2:\n                            logger.info(\n                                '=> load_pretrained: resized variant: {} to {}'\n                                    .format((L1, nH1), (L2, nH2))\n                            )\n                          
  S1 = int(L1 ** 0.5)\n                            S2 = int(L2 ** 0.5)\n                            relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate(\n                                relative_position_bias_table_pretrained.permute(1, 0).view(1, nH1, S1, S1),\n                                size=(S2, S2),\n                                mode='bicubic')\n                            v = relative_position_bias_table_pretrained_resized.view(nH2, L2).permute(1, 0)\n\n                if 'absolute_pos_embed' in k and v.size() != model_dict[k].size():\n                    absolute_pos_embed_pretrained = v\n                    absolute_pos_embed_current = model_dict[k]\n                    _, L1, C1 = absolute_pos_embed_pretrained.size()\n                    _, L2, C2 = absolute_pos_embed_current.size()\n                    if C1 != C1:\n                        logger.info(f\"Error in loading {k}, passing\")\n                    else:\n                        if L1 != L2:\n                            logger.info(\n                                '=> load_pretrained: resized variant: {} to {}'\n                                    .format((1, L1, C1), (1, L2, C2))\n                            )\n                            S1 = int(L1 ** 0.5)\n                            S2 = int(L2 ** 0.5)\n                            absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.reshape(-1, S1, S1, C1)\n                            absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.permute(0, 3, 1, 2)\n                            absolute_pos_embed_pretrained_resized = torch.nn.functional.interpolate(\n                                absolute_pos_embed_pretrained, size=(S2, S2), mode='bicubic')\n                            v = absolute_pos_embed_pretrained_resized.permute(0, 2, 3, 1).flatten(1, 2)\n\n                need_init_state_dict[k] = v\n        self.load_state_dict(need_init_state_dict, strict=False)\n\n\n    def 
forward(self, x):\n        \"\"\"Forward function.\"\"\"\n        x = self.patch_embed(x)\n\n        Wh, Ww = x.size(2), x.size(3)\n        if self.ape:\n            # interpolate the position embedding to the corresponding size\n            absolute_pos_embed = F.interpolate(\n                self.absolute_pos_embed, size=(Wh, Ww), mode=\"bicubic\"\n            )\n            x = (x + absolute_pos_embed).flatten(2).transpose(1, 2)  # B Wh*Ww C\n        else:\n            x = x.flatten(2).transpose(1, 2)\n        x = self.pos_drop(x)\n\n        outs = {}\n        for i in range(self.num_layers):\n            layer = self.layers[i]\n            x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)\n\n            if i in self.out_indices:\n                norm_layer = getattr(self, f\"norm{i}\")\n                x_out = norm_layer(x_out)\n\n                out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()\n                outs[\"res{}\".format(i + 2)] = out\n\n        if len(self.out_indices) == 0:\n            outs[\"res5\"] = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()\n        \n\n        return outs\n\n    def train(self, mode=True):\n        \"\"\"Convert the model into training mode while keep layers freezed.\"\"\"\n        super(SwinTransformer, self).train(mode)\n        self._freeze_stages()\n\n\nclass D2SwinTransformer(SwinTransformer, Backbone):\n    def __init__(self, cfg, pretrain_img_size, patch_size, in_chans, embed_dim, \n                 depths, num_heads, window_size, mlp_ratio, qkv_bias, qk_scale,\n                 drop_rate, attn_drop_rate, drop_path_rate, norm_layer, ape, \n                 patch_norm, out_indices, use_checkpoint):\n        super().__init__(\n            pretrain_img_size,\n            patch_size,\n            in_chans,\n            embed_dim,\n            depths,\n            num_heads,\n            window_size,\n            mlp_ratio,\n            qkv_bias,\n            
qk_scale,\n            drop_rate,\n            attn_drop_rate,\n            drop_path_rate,\n            norm_layer,\n            ape,\n            patch_norm,\n            out_indices,\n            use_checkpoint=use_checkpoint,\n        )\n\n        self._out_features = cfg['OUT_FEATURES']\n\n        self._out_feature_strides = {\n            \"res2\": 4,\n            \"res3\": 8,\n            \"res4\": 16,\n            \"res5\": 32,\n        }\n        self._out_feature_channels = {\n            \"res2\": self.num_features[0],\n            \"res3\": self.num_features[1],\n            \"res4\": self.num_features[2],\n            \"res5\": self.num_features[3],\n        }\n\n    def forward(self, x):\n        \"\"\"\n        Args:\n            x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.\n        Returns:\n            dict[str->Tensor]: names and the corresponding features\n        \"\"\"\n        assert (\n            x.dim() == 4\n        ), f\"SwinTransformer takes an input of shape (N, C, H, W). 
Got {x.shape} instead!\"\n        outputs = {}\n        y = super().forward(x)\n        for k in y.keys():\n            if k in self._out_features:\n                outputs[k] = y[k]\n        return outputs\n\n    def output_shape(self):\n        feature_names = list(set(self._out_feature_strides.keys()) & set(self._out_features))\n        return {\n            name: ShapeSpec(\n                channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]\n            )\n            for name in feature_names\n        }\n\n    @property\n    def size_divisibility(self):\n        return 32\n\n\n@register_backbone\ndef get_swin_backbone(cfg):\n    swin_cfg = cfg['MODEL']['BACKBONE']['SWIN']\n\n    pretrain_img_size = swin_cfg['PRETRAIN_IMG_SIZE']\n    patch_size = swin_cfg['PATCH_SIZE']\n    in_chans = 3\n    embed_dim = swin_cfg['EMBED_DIM']\n    depths = swin_cfg['DEPTHS']\n    num_heads = swin_cfg['NUM_HEADS']\n    window_size = swin_cfg['WINDOW_SIZE']\n    mlp_ratio = swin_cfg['MLP_RATIO']\n    qkv_bias = swin_cfg['QKV_BIAS']\n    qk_scale = swin_cfg['QK_SCALE']\n    drop_rate = swin_cfg['DROP_RATE']\n    attn_drop_rate = swin_cfg['ATTN_DROP_RATE']\n    drop_path_rate = swin_cfg['DROP_PATH_RATE']\n    norm_layer = nn.LayerNorm\n    ape = swin_cfg['APE']\n    patch_norm = swin_cfg['PATCH_NORM']\n    use_checkpoint = swin_cfg['USE_CHECKPOINT']\n    out_indices = swin_cfg.get('OUT_INDICES', [0,1,2,3])\n    \n    swin = D2SwinTransformer(\n        swin_cfg,\n        pretrain_img_size,\n        patch_size,\n        in_chans,\n        embed_dim,\n        depths,\n        num_heads,\n        window_size,\n        mlp_ratio,\n        qkv_bias,\n        qk_scale,\n        drop_rate,\n        attn_drop_rate,\n        drop_path_rate,\n        norm_layer,\n        ape,\n        patch_norm,\n        out_indices,\n        use_checkpoint=use_checkpoint,\n    )    \n\n    if cfg['MODEL']['BACKBONE']['LOAD_PRETRAINED'] is True:\n        filename = 
cfg['MODEL']['BACKBONE']['PRETRAINED']\n        with PathManager.open(filename, \"rb\") as f:\n            # ckpt = torch.load(f, map_location=cfg['device'])['model']\n            ckpt = torch.load(f, map_location='cpu')['model']\n        swin.load_weights(ckpt, swin_cfg.get('PRETRAINED_LAYERS', ['*']), cfg['VERBOSE'])\n\n    return swin"
  },
  {
    "path": "llava/model/semsam/backbone/swin_new.py",
    "content": "# --------------------------------------------------------\n# Swin Transformer\n# Copyright (c) 2021 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# Written by Ze Liu, Yutong Lin, Yixuan Wei\n# --------------------------------------------------------\n\n# Copyright (c) Facebook, Inc. and its affiliates.\n# Modified by Bowen Cheng from https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation/blob/main/mmseg/models/backbones/swin_transformer.py\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.utils.checkpoint as checkpoint\nfrom timm.models.layers import DropPath, to_2tuple, trunc_normal_\n\nfrom detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec\n\n\nclass Mlp(nn.Module):\n    \"\"\"Multilayer perceptron.\"\"\"\n\n    def __init__(\n        self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0\n    ):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features, out_features)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\n\ndef window_partition(x, window_size):\n    \"\"\"\n    Args:\n        x: (B, H, W, C)\n        window_size (int): window size\n    Returns:\n        windows: (num_windows*B, window_size, window_size, C)\n    \"\"\"\n    B, H, W, C = x.shape\n    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)\n    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)\n    return windows\n\n\ndef window_reverse(windows, window_size, H, W):\n    \"\"\"\n    
Args:\n        windows: (num_windows*B, window_size, window_size, C)\n        window_size (int): Window size\n        H (int): Height of image\n        W (int): Width of image\n    Returns:\n        x: (B, H, W, C)\n    \"\"\"\n    B = int(windows.shape[0] / (H * W / window_size / window_size))\n    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)\n    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)\n    return x\n\n\nclass WindowAttention(nn.Module):\n    \"\"\"Window based multi-head self attention (W-MSA) module with relative position bias.\n    It supports both of shifted and non-shifted window.\n    Args:\n        dim (int): Number of input channels.\n        window_size (tuple[int]): The height and width of the window.\n        num_heads (int): Number of attention heads.\n        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True\n        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set\n        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0\n        proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0\n    \"\"\"\n\n    def __init__(\n        self,\n        dim,\n        window_size,\n        num_heads,\n        qkv_bias=True,\n        qk_scale=None,\n        attn_drop=0.0,\n        proj_drop=0.0,\n    ):\n\n        super().__init__()\n        self.dim = dim\n        self.window_size = window_size  # Wh, Ww\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        self.scale = qk_scale or head_dim ** -0.5\n\n        # define a parameter table of relative position bias\n        self.relative_position_bias_table = nn.Parameter(\n            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)\n        )  # 2*Wh-1 * 2*Ww-1, nH\n\n        # get pair-wise relative position index for each token inside the window\n        coords_h = torch.arange(self.window_size[0])\n        coords_w = torch.arange(self.window_size[1])\n        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww\n        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww\n        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww\n        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2\n        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0\n        relative_coords[:, :, 1] += self.window_size[1] - 1\n        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1\n        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww\n        self.register_buffer(\"relative_position_index\", relative_position_index)\n\n        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n\n        trunc_normal_(self.relative_position_bias_table, std=0.02)\n        self.softmax = nn.Softmax(dim=-1)\n\n    def forward(self, x, mask=None):\n        \"\"\"Forward function.\n        Args:\n    
        x: input features with shape of (num_windows*B, N, C)\n            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None\n        \"\"\"\n        B_, N, C = x.shape\n        qkv = (\n            self.qkv(x)\n            .reshape(B_, N, 3, self.num_heads, C // self.num_heads)\n            .permute(2, 0, 3, 1, 4)\n        )\n        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)\n\n        q = q * self.scale\n        attn = q @ k.transpose(-2, -1)\n\n        relative_position_bias = self.relative_position_bias_table[\n            self.relative_position_index.view(-1)\n        ].view(\n            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1\n        )  # Wh*Ww,Wh*Ww,nH\n        relative_position_bias = relative_position_bias.permute(\n            2, 0, 1\n        ).contiguous()  # nH, Wh*Ww, Wh*Ww\n        attn = attn + relative_position_bias.unsqueeze(0)\n\n        if mask is not None:\n            nW = mask.shape[0]\n            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)\n            attn = attn.view(-1, self.num_heads, N, N)\n            attn = self.softmax(attn)\n        else:\n            attn = self.softmax(attn)\n\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\n\nclass SwinTransformerBlock(nn.Module):\n    \"\"\"Swin Transformer Block.\n    Args:\n        dim (int): Number of input channels.\n        num_heads (int): Number of attention heads.\n        window_size (int): Window size.\n        shift_size (int): Shift size for SW-MSA.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. 
Default: True\n        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.\n        drop (float, optional): Dropout rate. Default: 0.0\n        attn_drop (float, optional): Attention dropout rate. Default: 0.0\n        drop_path (float, optional): Stochastic depth rate. Default: 0.0\n        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU\n        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm\n    \"\"\"\n\n    def __init__(\n        self,\n        dim,\n        num_heads,\n        window_size=7,\n        shift_size=0,\n        mlp_ratio=4.0,\n        qkv_bias=True,\n        qk_scale=None,\n        drop=0.0,\n        attn_drop=0.0,\n        drop_path=0.0,\n        act_layer=nn.GELU,\n        norm_layer=nn.LayerNorm,\n    ):\n        super().__init__()\n        self.dim = dim\n        self.num_heads = num_heads\n        self.window_size = window_size\n        self.shift_size = shift_size\n        self.mlp_ratio = mlp_ratio\n        assert 0 <= self.shift_size < self.window_size, \"shift_size must in 0-window_size\"\n\n        self.norm1 = norm_layer(dim)\n        self.attn = WindowAttention(\n            dim,\n            window_size=to_2tuple(self.window_size),\n            num_heads=num_heads,\n            qkv_bias=qkv_bias,\n            qk_scale=qk_scale,\n            attn_drop=attn_drop,\n            proj_drop=drop,\n        )\n\n        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = Mlp(\n            in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop\n        )\n\n        self.H = None\n        self.W = None\n\n    def forward(self, x, mask_matrix):\n        \"\"\"Forward function.\n        Args:\n            x: Input feature, tensor size (B, H*W, C).\n            H, W: Spatial resolution of the input feature.\n    
        mask_matrix: Attention mask for cyclic shift.\n        \"\"\"\n        B, L, C = x.shape\n        H, W = self.H, self.W\n        assert L == H * W, \"input feature has wrong size\"\n\n        shortcut = x\n        x = self.norm1(x)\n        x = x.view(B, H, W, C)\n\n        # pad feature maps to multiples of window size\n        pad_l = pad_t = 0\n        pad_r = (self.window_size - W % self.window_size) % self.window_size\n        pad_b = (self.window_size - H % self.window_size) % self.window_size\n        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))\n        _, Hp, Wp, _ = x.shape\n\n        # cyclic shift\n        if self.shift_size > 0:\n            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))\n            attn_mask = mask_matrix\n        else:\n            shifted_x = x\n            attn_mask = None\n\n        # partition windows\n        x_windows = window_partition(\n            shifted_x, self.window_size\n        )  # nW*B, window_size, window_size, C\n        x_windows = x_windows.view(\n            -1, self.window_size * self.window_size, C\n        )  # nW*B, window_size*window_size, C\n\n        # W-MSA/SW-MSA\n        attn_windows = self.attn(x_windows, mask=attn_mask)  # nW*B, window_size*window_size, C\n\n        # merge windows\n        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)\n        shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)  # B H' W' C\n\n        # reverse cyclic shift\n        if self.shift_size > 0:\n            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))\n        else:\n            x = shifted_x\n\n        if pad_r > 0 or pad_b > 0:\n            x = x[:, :H, :W, :].contiguous()\n\n        x = x.view(B, H * W, C)\n\n        # FFN\n        x = shortcut + self.drop_path(x)\n        x = x + self.drop_path(self.mlp(self.norm2(x)))\n\n        return x\n\n\nclass PatchMerging(nn.Module):\n    
\"\"\"Patch Merging Layer\n    Args:\n        dim (int): Number of input channels.\n        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm\n    \"\"\"\n\n    def __init__(self, dim, norm_layer=nn.LayerNorm):\n        super().__init__()\n        self.dim = dim\n        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)\n        self.norm = norm_layer(4 * dim)\n\n    def forward(self, x, H, W):\n        \"\"\"Forward function.\n        Args:\n            x: Input feature, tensor size (B, H*W, C).\n            H, W: Spatial resolution of the input feature.\n        \"\"\"\n        B, L, C = x.shape\n        assert L == H * W, \"input feature has wrong size\"\n\n        x = x.view(B, H, W, C)\n\n        # padding\n        pad_input = (H % 2 == 1) or (W % 2 == 1)\n        if pad_input:\n            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))\n\n        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C\n        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C\n        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C\n        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C\n        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C\n        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C\n\n        x = self.norm(x)\n        x = self.reduction(x)\n\n        return x\n\n\nclass BasicLayer(nn.Module):\n    \"\"\"A basic Swin Transformer layer for one stage.\n    Args:\n        dim (int): Number of feature channels\n        depth (int): Depths of this stage.\n        num_heads (int): Number of attention head.\n        window_size (int): Local window size. Default: 7.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.\n        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True\n        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.\n        drop (float, optional): Dropout rate. Default: 0.0\n        attn_drop (float, optional): Attention dropout rate. 
Default: 0.0\n        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0\n        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm\n        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None\n        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.\n    \"\"\"\n\n    def __init__(\n        self,\n        dim,\n        depth,\n        num_heads,\n        window_size=7,\n        mlp_ratio=4.0,\n        qkv_bias=True,\n        qk_scale=None,\n        drop=0.0,\n        attn_drop=0.0,\n        drop_path=0.0,\n        norm_layer=nn.LayerNorm,\n        downsample=None,\n        use_checkpoint=False,\n    ):\n        super().__init__()\n        self.window_size = window_size\n        self.shift_size = window_size // 2\n        self.depth = depth\n        self.use_checkpoint = use_checkpoint\n\n        # build blocks\n        self.blocks = nn.ModuleList(\n            [\n                SwinTransformerBlock(\n                    dim=dim,\n                    num_heads=num_heads,\n                    window_size=window_size,\n                    shift_size=0 if (i % 2 == 0) else window_size // 2,\n                    mlp_ratio=mlp_ratio,\n                    qkv_bias=qkv_bias,\n                    qk_scale=qk_scale,\n                    drop=drop,\n                    attn_drop=attn_drop,\n                    drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,\n                    norm_layer=norm_layer,\n                )\n                for i in range(depth)\n            ]\n        )\n\n        # patch merging layer\n        if downsample is not None:\n            self.downsample = downsample(dim=dim, norm_layer=norm_layer)\n        else:\n            self.downsample = None\n\n    def forward(self, x, H, W):\n        \"\"\"Forward function.\n        Args:\n            x: Input feature, tensor size (B, H*W, C).\n         
   H, W: Spatial resolution of the input feature.\n        \"\"\"\n\n        # calculate attention mask for SW-MSA\n        Hp = int(np.ceil(H / self.window_size)) * self.window_size\n        Wp = int(np.ceil(W / self.window_size)) * self.window_size\n        img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # 1 Hp Wp 1\n        h_slices = (\n            slice(0, -self.window_size),\n            slice(-self.window_size, -self.shift_size),\n            slice(-self.shift_size, None),\n        )\n        w_slices = (\n            slice(0, -self.window_size),\n            slice(-self.window_size, -self.shift_size),\n            slice(-self.shift_size, None),\n        )\n        cnt = 0\n        for h in h_slices:\n            for w in w_slices:\n                img_mask[:, h, w, :] = cnt\n                cnt += 1\n\n        mask_windows = window_partition(\n            img_mask, self.window_size\n        )  # nW, window_size, window_size, 1\n        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)\n        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)\n        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(\n            attn_mask == 0, float(0.0)\n        )\n\n        for blk in self.blocks:\n            blk.H, blk.W = H, W\n            if self.use_checkpoint:\n                x = checkpoint.checkpoint(blk, x, attn_mask)\n            else:\n                x = blk(x, attn_mask)\n        if self.downsample is not None:\n            x_down = self.downsample(x, H, W)\n            Wh, Ww = (H + 1) // 2, (W + 1) // 2\n            return x, H, W, x_down, Wh, Ww\n        else:\n            return x, H, W, x, H, W\n\n\nclass PatchEmbed(nn.Module):\n    \"\"\"Image to Patch Embedding\n    Args:\n        patch_size (int): Patch token size. Default: 4.\n        in_chans (int): Number of input image channels. Default: 3.\n        embed_dim (int): Number of linear projection output channels. 
Default: 96.\n        norm_layer (nn.Module, optional): Normalization layer. Default: None\n    \"\"\"\n\n    def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):\n        super().__init__()\n        patch_size = to_2tuple(patch_size)\n        self.patch_size = patch_size\n\n        self.in_chans = in_chans\n        self.embed_dim = embed_dim\n\n        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)\n        if norm_layer is not None:\n            self.norm = norm_layer(embed_dim)\n        else:\n            self.norm = None\n\n    def forward(self, x):\n        \"\"\"Forward function.\"\"\"\n        # padding\n        _, _, H, W = x.size()\n        if W % self.patch_size[1] != 0:\n            x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))\n        if H % self.patch_size[0] != 0:\n            x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))\n\n        x = self.proj(x)  # B C Wh Ww\n        if self.norm is not None:\n            Wh, Ww = x.size(2), x.size(3)\n            x = x.flatten(2).transpose(1, 2)\n            x = self.norm(x)\n            x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)\n\n        return x\n\n\nclass SwinTransformer(nn.Module):\n    \"\"\"Swin Transformer backbone.\n        A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`  -\n          https://arxiv.org/pdf/2103.14030\n    Args:\n        pretrain_img_size (int): Input image size for training the pretrained model,\n            used in absolute postion embedding. Default 224.\n        patch_size (int | tuple(int)): Patch size. Default: 4.\n        in_chans (int): Number of input image channels. Default: 3.\n        embed_dim (int): Number of linear projection output channels. 
Default: 96.\n        depths (tuple[int]): Depths of each Swin Transformer stage.\n        num_heads (tuple[int]): Number of attention head of each stage.\n        window_size (int): Window size. Default: 7.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.\n        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True\n        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.\n        drop_rate (float): Dropout rate.\n        attn_drop_rate (float): Attention dropout rate. Default: 0.\n        drop_path_rate (float): Stochastic depth rate. Default: 0.2.\n        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.\n        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.\n        patch_norm (bool): If True, add normalization after patch embedding. Default: True.\n        out_indices (Sequence[int]): Output from which stages.\n        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).\n            -1 means not freezing any parameters.\n        use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False.\n    \"\"\"\n\n    def __init__(\n        self,\n        pretrain_img_size=224,\n        patch_size=4,\n        in_chans=3,\n        embed_dim=96,\n        depths=[2, 2, 6, 2],\n        num_heads=[3, 6, 12, 24],\n        window_size=7,\n        mlp_ratio=4.0,\n        qkv_bias=True,\n        qk_scale=None,\n        drop_rate=0.0,\n        attn_drop_rate=0.0,\n        drop_path_rate=0.2,\n        norm_layer=nn.LayerNorm,\n        ape=False,\n        patch_norm=True,\n        out_indices=(0, 1, 2, 3),\n        frozen_stages=-1,\n        use_checkpoint=False,\n    ):\n        super().__init__()\n\n        self.pretrain_img_size = pretrain_img_size\n        self.num_layers = len(depths)\n        self.embed_dim = embed_dim\n        self.ape = ape\n        self.patch_norm = patch_norm\n        self.out_indices = out_indices\n        self.frozen_stages = frozen_stages\n\n        # split image into non-overlapping patches\n        self.patch_embed = PatchEmbed(\n            patch_size=patch_size,\n            in_chans=in_chans,\n            embed_dim=embed_dim,\n            norm_layer=norm_layer if self.patch_norm else None,\n        )\n\n        # absolute position embedding\n        if self.ape:\n            pretrain_img_size = to_2tuple(pretrain_img_size)\n            patch_size = to_2tuple(patch_size)\n            patches_resolution = [\n                pretrain_img_size[0] // patch_size[0],\n                pretrain_img_size[1] // patch_size[1],\n            ]\n\n            self.absolute_pos_embed = nn.Parameter(\n                torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])\n            )\n            trunc_normal_(self.absolute_pos_embed, std=0.02)\n\n        self.pos_drop = nn.Dropout(p=drop_rate)\n\n        # stochastic depth\n        dpr = [\n            x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))\n        ]  # stochastic depth decay rule\n\n        # build layers\n        self.layers = 
nn.ModuleList()\n        for i_layer in range(self.num_layers):\n            layer = BasicLayer(\n                dim=int(embed_dim * 2 ** i_layer),\n                depth=depths[i_layer],\n                num_heads=num_heads[i_layer],\n                window_size=window_size,\n                mlp_ratio=mlp_ratio,\n                qkv_bias=qkv_bias,\n                qk_scale=qk_scale,\n                drop=drop_rate,\n                attn_drop=attn_drop_rate,\n                drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],\n                norm_layer=norm_layer,\n                downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,\n                use_checkpoint=use_checkpoint,\n            )\n            self.layers.append(layer)\n\n        num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]\n        self.num_features = num_features\n\n        # add a norm layer for each output\n        for i_layer in out_indices:\n            layer = norm_layer(num_features[i_layer])\n            layer_name = f\"norm{i_layer}\"\n            self.add_module(layer_name, layer)\n\n        self._freeze_stages()\n\n    def _freeze_stages(self):\n        if self.frozen_stages >= 0:\n            self.patch_embed.eval()\n            for param in self.patch_embed.parameters():\n                param.requires_grad = False\n\n        if self.frozen_stages >= 1 and self.ape:\n            self.absolute_pos_embed.requires_grad = False\n\n        if self.frozen_stages >= 2:\n            self.pos_drop.eval()\n            for i in range(0, self.frozen_stages - 1):\n                m = self.layers[i]\n                m.eval()\n                for param in m.parameters():\n                    param.requires_grad = False\n\n    def init_weights(self, pretrained=None):\n        \"\"\"Initialize the weights in backbone.\n        Args:\n            pretrained (str, optional): Path to pre-trained weights.\n                Defaults to None.\n       
 \"\"\"\n\n        def _init_weights(m):\n            if isinstance(m, nn.Linear):\n                trunc_normal_(m.weight, std=0.02)\n                if isinstance(m, nn.Linear) and m.bias is not None:\n                    nn.init.constant_(m.bias, 0)\n            elif isinstance(m, nn.LayerNorm):\n                nn.init.constant_(m.bias, 0)\n                nn.init.constant_(m.weight, 1.0)\n\n    def forward(self, x):\n        \"\"\"Forward function.\"\"\"\n        x = self.patch_embed(x)\n\n        Wh, Ww = x.size(2), x.size(3)\n        if self.ape:\n            # interpolate the position embedding to the corresponding size\n            absolute_pos_embed = F.interpolate(\n                self.absolute_pos_embed, size=(Wh, Ww), mode=\"bicubic\"\n            )\n            x = (x + absolute_pos_embed).flatten(2).transpose(1, 2)  # B Wh*Ww C\n        else:\n            x = x.flatten(2).transpose(1, 2)\n        x = self.pos_drop(x)\n\n        outs = {}\n        for i in range(self.num_layers):\n            layer = self.layers[i]\n            x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)\n\n            if i in self.out_indices:\n                norm_layer = getattr(self, f\"norm{i}\")\n                x_out = norm_layer(x_out)\n\n                out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()\n                outs[\"res{}\".format(i + 2)] = out\n\n        return outs\n\n    def train(self, mode=True):\n        \"\"\"Convert the model into training mode while keep layers freezed.\"\"\"\n        super(SwinTransformer, self).train(mode)\n        self._freeze_stages()\n\n\n@BACKBONE_REGISTRY.register()\nclass D2SwinTransformer(SwinTransformer, Backbone):\n    def __init__(self, cfg, input_shape):\n\n        pretrain_img_size = cfg.MODEL.SWIN.PRETRAIN_IMG_SIZE\n        patch_size = cfg.MODEL.SWIN.PATCH_SIZE\n        in_chans = 3\n        embed_dim = cfg.MODEL.SWIN.EMBED_DIM\n        depths = cfg.MODEL.SWIN.DEPTHS\n        num_heads = 
cfg.MODEL.SWIN.NUM_HEADS\n        window_size = cfg.MODEL.SWIN.WINDOW_SIZE\n        mlp_ratio = cfg.MODEL.SWIN.MLP_RATIO\n        qkv_bias = cfg.MODEL.SWIN.QKV_BIAS\n        qk_scale = cfg.MODEL.SWIN.QK_SCALE\n        drop_rate = cfg.MODEL.SWIN.DROP_RATE\n        attn_drop_rate = cfg.MODEL.SWIN.ATTN_DROP_RATE\n        drop_path_rate = cfg.MODEL.SWIN.DROP_PATH_RATE\n        norm_layer = nn.LayerNorm\n        ape = cfg.MODEL.SWIN.APE\n        patch_norm = cfg.MODEL.SWIN.PATCH_NORM\n        use_checkpoint = cfg.MODEL.SWIN.USE_CHECKPOINT\n\n        super().__init__(\n            pretrain_img_size,\n            patch_size,\n            in_chans,\n            embed_dim,\n            depths,\n            num_heads,\n            window_size,\n            mlp_ratio,\n            qkv_bias,\n            qk_scale,\n            drop_rate,\n            attn_drop_rate,\n            drop_path_rate,\n            norm_layer,\n            ape,\n            patch_norm,\n            use_checkpoint=use_checkpoint,\n        )\n\n        self._out_features = cfg.MODEL.SWIN.OUT_FEATURES\n\n        self._out_feature_strides = {\n            \"res2\": 4,\n            \"res3\": 8,\n            \"res4\": 16,\n            \"res5\": 32,\n        }\n        self._out_feature_channels = {\n            \"res2\": self.num_features[0],\n            \"res3\": self.num_features[1],\n            \"res4\": self.num_features[2],\n            \"res5\": self.num_features[3],\n        }\n\n    def forward(self, x):\n        \"\"\"\n        Args:\n            x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.\n        Returns:\n            dict[str->Tensor]: names and the corresponding features\n        \"\"\"\n        assert (\n            x.dim() == 4\n        ), f\"SwinTransformer takes an input of shape (N, C, H, W). 
Got {x.shape} instead!\"\n        outputs = {}\n        y = super().forward(x)\n        for k in y.keys():\n            if k in self._out_features:\n                outputs[k] = y[k]\n        return outputs\n\n    def output_shape(self):\n        return {\n            name: ShapeSpec(\n                channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]\n            )\n            for name in self._out_features\n        }\n\n    @property\n    def size_divisibility(self):\n        return 32\n"
  },
  {
    "path": "llava/model/semsam/body/__init__.py",
    "content": "from .build import build_openseed_head"
  },
  {
    "path": "llava/model/semsam/body/build.py",
    "content": "from .registry import model_entrypoints\nfrom .registry import is_model\nfrom .openseed_head import *\n\n\ndef build_openseed_head(config, *args, **kwargs):\n    model_name = config['MODEL']['HEAD']\n    if not is_model(model_name):\n        raise ValueError(f'Unkown model: {model_name}')\n\n    body = model_entrypoints(model_name)(config, *args, **kwargs)\n    return body"
  },
  {
    "path": "llava/model/semsam/body/decoder/__init__.py",
    "content": "from .build import build_decoder\nfrom .idino_decoder_no_iou_token_partwhole_all_llm import *"
  },
  {
    "path": "llava/model/semsam/body/decoder/build.py",
    "content": "from .registry import model_entrypoints\nfrom .registry import is_model\n\n\ndef build_decoder(config, *args, **kwargs):\n    model_name = config['MODEL']['DECODER']['NAME']\n\n    if not is_model(model_name):\n        raise ValueError(f'Unkown model: {model_name}')\n\n    return model_entrypoints(model_name)(config, *args, **kwargs)"
  },
  {
    "path": "llava/model/semsam/body/decoder/idino_decoder_no_iou_token_partwhole_all_llm.py",
    "content": "# ------------------------------------------------------------------------\n# DINO\n# Copyright (c) 2022 IDEA. All Rights Reserved.\n# Licensed under the Apache License, Version 2.0 [see LICENSE for details]\n# ------------------------------------------------------------------------\n# Modified from Mask2Former https://github.com/facebookresearch/Mask2Former by Feng Li and Hao Zhang.\nimport logging\nimport fvcore.nn.weight_init as weight_init\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nfrom detectron2.layers import Conv2d\nfrom detectron2.utils.registry import Registry\nfrom detectron2.structures import BitMasks\nfrom timm.models.layers import trunc_normal_\n\nfrom .registry import register_decoder\nfrom .utils.dino_decoder import TransformerDecoder, DeformableTransformerDecoderLayer\nfrom .utils import MLP, gen_encoder_output_proposals, inverse_sigmoid\nfrom ...utils import box_ops\nfrom ...utils import configurable\n\nclass MaskDINODecoder(nn.Module):\n    @configurable\n    def __init__(\n            self,\n            lang_encoder: nn.Module,\n            in_channels,\n            mask_classification=True,\n            *,\n            num_classes: int,\n            hidden_dim: int,\n            dim_proj: int,\n            num_queries: int,\n            nheads: int,\n            dim_feedforward: int,\n            dec_layers: int,\n            mask_dim: int,\n            enforce_input_project: bool,\n            two_stage: bool,\n            dn: str,\n            noise_scale:float,\n            dn_num:int,\n            initialize_box_type:bool,\n            initial_pred:bool,\n            learn_tgt: bool,\n            total_num_feature_levels: int = 4,\n            dropout: float = 0.0,\n            activation: str = 'relu',\n            nhead: int = 8,\n            dec_n_points: int = 4,\n            return_intermediate_dec: bool = True,\n            query_dim: int = 4,\n            dec_layer_share: bool = 
False,\n            semantic_ce_loss: bool = False,\n            num_mask_tokens: int = 3,\n    ):\n        \"\"\"\n        NOTE: this interface is experimental.\n        Args:\n            in_channels: channels of the input features\n            mask_classification: whether to add mask classifier or not\n            num_classes: number of classes\n            hidden_dim: Transformer feature dimension\n            num_queries: number of queries\n            nheads: number of heads\n            dim_feedforward: feature dimension in feedforward network\n            enc_layers: number of Transformer encoder layers\n            dec_layers: number of Transformer decoder layers\n            pre_norm: whether to use pre-LayerNorm or not\n            mask_dim: mask feature dimension\n            enforce_input_project: add input project 1x1 conv even if input\n                channels and hidden dim is identical\n            d_model: transformer dimension\n            dropout: dropout rate\n            activation: activation function\n            nhead: num heads in multi-head attention\n            dec_n_points: number of sampling points in decoder\n            return_intermediate_dec: return the intermediate results of decoder\n            query_dim: 4 -> (x, y, w, h)\n            dec_layer_share: whether to share each decoder layer\n            semantic_ce_loss: use ce loss for semantic segmentation\n        \"\"\"\n        super().__init__()\n\n        assert mask_classification, \"Only support mask classification model\"\n        self.mask_classification = mask_classification\n        self.num_feature_levels = total_num_feature_levels\n        self.initial_pred = initial_pred\n\n        # define Transformer decoder here\n        self.dn=dn\n        self.learn_tgt = learn_tgt\n        self.noise_scale=noise_scale\n        self.dn_num=dn_num\n        self.num_heads = nheads\n        self.num_layers = dec_layers\n        self.two_stage=two_stage\n        
self.initialize_box_type = initialize_box_type\n        self.total_num_feature_levels = total_num_feature_levels\n\n        self.num_queries = num_queries\n        \n        self.semantic_ce_loss = semantic_ce_loss\n        interactive_only = True\n        # learnable query features\n        if num_queries>0 and not interactive_only:\n            if not two_stage or self.learn_tgt:\n                self.query_feat = nn.Embedding(num_queries, hidden_dim)\n            if not two_stage and initialize_box_type == 'no':\n                self.query_embed = nn.Embedding(num_queries, 4)\n        # if two_stage:\n        #     self.enc_output = nn.Linear(hidden_dim, hidden_dim)\n        #     self.enc_output_norm = nn.LayerNorm(hidden_dim)\n\n        self.input_proj = nn.ModuleList()\n        for _ in range(self.num_feature_levels):\n            if in_channels != hidden_dim or enforce_input_project:\n                self.input_proj.append(Conv2d(in_channels, hidden_dim, kernel_size=1))\n                weight_init.c2_xavier_fill(self.input_proj[-1])\n            else:\n                self.input_proj.append(nn.Sequential())\n        self.num_classes=num_classes\n        # output FFNs\n        assert self.mask_classification, \"why not class embedding?\"\n        # self.label_enc=nn.Embedding(505, hidden_dim)  # this is a hack for o365+coco (365+133=498)\n        self.dim_proj = dim_proj\n        self.lang_encoder = lang_encoder\n        # if lang_encoder is not None:\n        self.lang_mapper = nn.Parameter(torch.empty(dim_proj, hidden_dim))\n        trunc_normal_(self.lang_mapper, std=.02)\n\n        self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)\n\n        # init decoder\n        self.decoder_norm = decoder_norm = nn.LayerNorm(hidden_dim)\n        decoder_layer = DeformableTransformerDecoderLayer(hidden_dim, dim_feedforward,\n                                                          dropout, activation,\n                                                         
 self.num_feature_levels, nhead, dec_n_points)\n        self.decoder = TransformerDecoder(decoder_layer, self.num_layers, decoder_norm,\n                                          return_intermediate=return_intermediate_dec,\n                                          d_model=hidden_dim, query_dim=query_dim,\n                                          num_feature_levels=self.num_feature_levels,\n                                          dec_layer_share=dec_layer_share,\n                                          )\n\n        self.hidden_dim = hidden_dim\n        self._bbox_embed = _bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)\n        nn.init.constant_(_bbox_embed.layers[-1].weight.data, 0)\n        nn.init.constant_(_bbox_embed.layers[-1].bias.data, 0)\n        box_embed_layerlist = [_bbox_embed for i in range(self.num_layers)]  # share box prediction each layer\n        self.bbox_embed = nn.ModuleList(box_embed_layerlist)\n        self.decoder.bbox_embed = self.bbox_embed\n        \n        # whole category classification\n        self.class_embed = nn.Parameter(torch.empty(hidden_dim, dim_proj))\n        trunc_normal_(self.class_embed, std=.02)\n        # part category classification\n        self.class_embed_part = nn.Parameter(torch.empty(hidden_dim, dim_proj))\n        trunc_normal_(self.class_embed_part, std=.02)\n\n        # FIXME iou head; iou prediction: 1. iou token to predict 3 score. 2. 
predict each iou score from query tokens\n        # FIXME seems we only need to stack these tokens in batch dimension to reduce self attention burden.\n        self.num_mask_tokens = num_mask_tokens  # sam uses 4 to handle multi prompts\n        self.iou_token = 0   # FIXME hack to remove iou token\n        self.num_all_tokens = self.num_mask_tokens + self.iou_token  # sam uses 4 to handle multi prompts\n        self.iou_prediction_head = MLP(hidden_dim, hidden_dim, 1, 3)\n        # self.iou_token = nn.Embedding(self.iou_token, hidden_dim)\n        self.mask_tokens = nn.Embedding(self.num_mask_tokens, hidden_dim)\n        self.pb_embedding=nn.Embedding(2,hidden_dim)\n        self.label_enc=nn.Embedding(2,hidden_dim)\n        \n        self.prediction_switch = None\n\n\n    @classmethod\n    def from_config(cls, cfg, in_channels, lang_encoder, mask_classification, extra):\n        ret = {}\n        ret[\"in_channels\"] = in_channels\n        ret[\"lang_encoder\"] = lang_encoder\n        ret[\"mask_classification\"] = mask_classification\n\n        enc_cfg = cfg['MODEL']['ENCODER']\n        dec_cfg = cfg['MODEL']['DECODER']\n\n        ret[\"num_classes\"] = enc_cfg['NUM_CLASSES']\n        ret[\"hidden_dim\"] = dec_cfg['HIDDEN_DIM']\n        ret[\"dim_proj\"] = cfg['MODEL']['DIM_PROJ']\n        ret[\"num_queries\"] = dec_cfg['NUM_OBJECT_QUERIES']\n\n        # Transformer parameters:\n        ret[\"num_mask_tokens\"] = dec_cfg.get('NUM_MASK_TOKENS', 3)\n        \n        ret[\"nheads\"] = dec_cfg['NHEADS']\n        ret[\"dim_feedforward\"] = dec_cfg['DIM_FEEDFORWARD']\n        ret[\"dec_layers\"] = dec_cfg['DEC_LAYERS']\n        ret[\"enforce_input_project\"] = dec_cfg['ENFORCE_INPUT_PROJ']\n        ret[\"mask_dim\"] = enc_cfg['MASK_DIM']\n        ret[\"two_stage\"] = dec_cfg['TWO_STAGE']\n        ret[\"initialize_box_type\"] = dec_cfg['INITIALIZE_BOX_TYPE']  # ['no', 'bitmask', 'mask2box']\n        ret[\"dn\"] = dec_cfg['DN']\n        ret[\"noise_scale\"] = 
dec_cfg['DN_NOISE_SCALE']\n        ret[\"dn_num\"] = dec_cfg['DN_NUM']\n        ret[\"initial_pred\"] = dec_cfg['INITIAL_PRED']\n        ret[\"learn_tgt\"] = dec_cfg['LEARN_TGT']\n        ret[\"total_num_feature_levels\"] = dec_cfg['TOTAL_NUM_FEATURE_LEVELS']\n        ret[\"num_mask_tokens\"] = dec_cfg.get('NUM_INTERACTIVE_TOKENS', 3)\n        ret[\"semantic_ce_loss\"] = dec_cfg['TEST']['SEMANTIC_ON'] and dec_cfg['SEMANTIC_CE_LOSS'] and not dec_cfg['TEST']['PANOPTIC_ON']\n\n        return ret\n\n    def prepare_for_dn(self, targets, tgt, refpoint_emb, batch_size):\n        \"\"\"\n        modified from dn-detr. You can refer to dn-detr\n        https://github.com/IDEA-Research/DN-DETR/blob/main/models/dn_dab_deformable_detr/dn_components.py\n        for more details\n            :param dn_args: scalar, noise_scale\n            :param tgt: original tgt (content) in the matching part\n            :param refpoint_emb: positional anchor queries in the matching part\n            :param batch_size: bs\n            \"\"\"\n        if self.training:\n            scalar, noise_scale = self.dn_num, self.noise_scale\n\n            known = [(torch.ones_like(t['labels'])).cuda() for t in targets]\n            know_idx = [torch.nonzero(t) for t in known]\n            known_num = [sum(k) for k in known]\n\n            # use fix number of dn queries\n            if max(known_num) > 0:\n                scalar = scalar // (int(max(known_num)))\n            else:\n                scalar = 0\n            if scalar == 0:\n                input_query_label = None\n                input_query_bbox = None\n                attn_mask = None\n                mask_dict = None\n                return input_query_label, input_query_bbox, attn_mask, mask_dict\n\n            # can be modified to selectively denosie some label or boxes; also known label prediction\n            unmask_bbox = unmask_label = torch.cat(known)\n            labels = torch.cat([t['labels'] for t in targets])\n            
# use languge as denosing content queries.\n            # if task == 'det':\n            #     labels = labels  # o365 start from 133 class\n            boxes = torch.cat([t['boxes'] for t in targets])\n            batch_idx = torch.cat([torch.full_like(t['labels'].long(), i) for i, t in enumerate(targets)])\n            # known\n            known_indice = torch.nonzero(unmask_label + unmask_bbox)\n            known_indice = known_indice.view(-1)\n\n            # noise\n            known_indice = known_indice.repeat(scalar, 1).view(-1)\n            known_labels = labels.repeat(scalar, 1).view(-1)\n            known_bid = batch_idx.repeat(scalar, 1).view(-1)\n            known_bboxs = boxes.repeat(scalar, 1)\n            known_labels_expaned = known_labels.clone()\n            known_bbox_expand = known_bboxs.clone()\n\n            if noise_scale > 0:\n                diff = torch.zeros_like(known_bbox_expand)\n                diff[:, :2] = known_bbox_expand[:, 2:] / 2\n                diff[:, 2:] = known_bbox_expand[:, 2:]\n                known_bbox_expand += torch.mul((torch.rand_like(known_bbox_expand) * 2 - 1.0),\n                                               diff).cuda() * noise_scale\n                known_bbox_expand = known_bbox_expand.clamp(min=0.0, max=1.0)\n\n            m = known_labels_expaned.long().to('cuda')\n            # import ipdb; ipdb.set_trace()\n            input_label_embed = torch.gather(self.lang_encoder.default_text_embeddings, 0,\n                                             m[:, None].repeat(1, self.dim_proj)) @ self.lang_mapper\n\n            input_bbox_embed = inverse_sigmoid(known_bbox_expand)\n            single_pad = int(max(known_num))\n            pad_size = int(single_pad * scalar)\n\n            padding_label = input_label_embed.new_zeros(pad_size, self.hidden_dim)\n            padding_bbox = input_bbox_embed.new_zeros(pad_size, 4)\n\n            if not refpoint_emb is None:\n                input_query_label = 
torch.cat([padding_label, tgt], dim=0).repeat(batch_size, 1, 1)\n                input_query_bbox = torch.cat([padding_bbox, refpoint_emb], dim=0).repeat(batch_size, 1, 1)\n            else:\n                input_query_label = padding_label.repeat(batch_size, 1, 1)\n                input_query_bbox = padding_bbox.repeat(batch_size, 1, 1)\n\n            # map\n            map_known_indice = input_label_embed.new_tensor([])\n            if len(known_num):\n                map_known_indice = torch.cat(\n                    [input_label_embed.new_tensor(range(num)) for num in known_num])  # [1,2, 1,2,3]\n                map_known_indice = torch.cat([map_known_indice + single_pad * i for i in range(scalar)]).long()\n            if len(known_bid):\n                input_query_label[(known_bid.long(), map_known_indice)] = input_label_embed\n                input_query_bbox[(known_bid.long(), map_known_indice)] = input_bbox_embed\n\n            tgt_size = pad_size + self.num_queries\n            attn_mask = input_label_embed.new_ones(tgt_size, tgt_size) < 0\n            # match query cannot see the reconstruct\n            attn_mask[pad_size:, :pad_size] = True\n            # reconstruct cannot see each other\n            for i in range(scalar):\n                if i == 0:\n                    attn_mask[single_pad * i:single_pad * (i + 1), single_pad * (i + 1):pad_size] = True\n                if i == scalar - 1:\n                    attn_mask[single_pad * i:single_pad * (i + 1), :single_pad * i] = True\n                else:\n                    attn_mask[single_pad * i:single_pad * (i + 1), single_pad * (i + 1):pad_size] = True\n                    attn_mask[single_pad * i:single_pad * (i + 1), :single_pad * i] = True\n            mask_dict = {\n                'known_indice': torch.as_tensor(known_indice).long(),\n                'batch_idx': torch.as_tensor(batch_idx).long(),\n                'map_known_indice': torch.as_tensor(map_known_indice).long(),\n              
  'known_lbs_bboxes': (known_labels, known_bboxs),\n                'know_idx': know_idx,\n                'pad_size': pad_size,\n                'scalar': scalar,\n            }\n        else:\n            if not refpoint_emb is None:\n                input_query_label = tgt.repeat(batch_size, 1, 1)\n                input_query_bbox = refpoint_emb.repeat(batch_size, 1, 1)\n            else:\n                input_query_label = None\n                input_query_bbox = None\n            attn_mask = None\n            mask_dict = None\n\n        # 100*batch*256\n        if not input_query_bbox is None:\n            input_query_label = input_query_label\n            input_query_bbox = input_query_bbox\n\n        return input_query_label, input_query_bbox, attn_mask, mask_dict\n\n    def prepare_for_dn_o3(self, targets, tgt, refpoint_emb, batch_size):\n        \"\"\"\n        modified from dn-detr. You can refer to dn-detr\n        https://github.com/IDEA-Research/DN-DETR/blob/main/models/dn_dab_deformable_detr/dn_components.py\n        for more details\n            :param dn_args: scalar, noise_scale\n            :param tgt: original tgt (content) in the matching part\n            :param refpoint_emb: positional anchor queries in the matching part\n            :param batch_size: bs\n            \"\"\"\n        if self.training:\n            scalar, noise_scale = self.dn_num, self.noise_scale\n\n            known = [(torch.ones_like(t['labels'])).cuda() for t in targets]\n            know_idx = [torch.nonzero(t) for t in known]\n            known_num = [sum(k) for k in known]\n\n            # use fix number of dn queries\n            if max(known_num) > 0:\n                scalar = 1\n            else:\n                scalar = 0\n            if scalar == 0:\n                input_query_label = None\n                input_query_bbox = None\n                attn_mask = None\n                mask_dict = None\n                return input_query_label, input_query_bbox, 
attn_mask, mask_dict\n\n            # can be modified to selectively denosie some label or boxes; also known label prediction\n            unmask_bbox = unmask_label = torch.cat(known)\n            labels = torch.cat([t['labels'] for t in targets])\n            # use languge as denosing content queries.\n            # if task == 'det':\n            #     labels = labels  # o365 start from 133 class\n            boxes = torch.cat([t['boxes'] for t in targets])\n            batch_idx = torch.cat([torch.full_like(t['labels'].long(), i) for i, t in enumerate(targets)])\n            # known\n            known_indice = torch.nonzero(unmask_label + unmask_bbox)\n            known_indice = known_indice.view(-1)\n\n            # noise\n            known_indice = known_indice.repeat(scalar, 1).view(-1)\n            known_labels = labels.repeat(scalar, 1).view(-1)\n            known_bid = batch_idx.repeat(scalar, 1).view(-1)\n            known_bboxs = boxes.repeat(scalar, 1)\n            known_labels_expaned = known_labels.clone()\n            known_bbox_expand = known_bboxs.clone()\n\n            if noise_scale > 0:\n                diff = torch.zeros_like(known_bbox_expand)\n                diff[:, :2] = known_bbox_expand[:, 2:] / 2\n                diff[:, 2:] = known_bbox_expand[:, 2:]\n                known_bbox_expand += torch.mul((torch.rand_like(known_bbox_expand) * 2 - 1.0),\n                                               diff).cuda() * noise_scale\n                known_bbox_expand = known_bbox_expand.clamp(min=0.0, max=1.0)\n\n            m = known_labels_expaned.long().to('cuda')\n            # import ipdb; ipdb.set_trace()\n            input_label_embed = self.pb_embedding(torch.ones_like(m))\n\n\n            input_bbox_embed = inverse_sigmoid(known_bbox_expand)\n            single_pad = int(max(known_num))\n            pad_size = int(single_pad * scalar)\n\n            padding_label = input_label_embed.new_zeros(pad_size, self.hidden_dim)\n            
padding_bbox = input_bbox_embed.new_zeros(pad_size, 4)\n\n            if not refpoint_emb is None:\n                input_query_label = torch.cat([padding_label, tgt], dim=0).repeat(batch_size, 1, 1)\n                input_query_bbox = torch.cat([padding_bbox, refpoint_emb], dim=0).repeat(batch_size, 1, 1)\n            else:\n                input_query_label = padding_label.repeat(batch_size, 1, 1)\n                input_query_bbox = padding_bbox.repeat(batch_size, 1, 1)\n\n            # map\n            map_known_indice = input_label_embed.new_tensor([])\n            if len(known_num):\n                map_known_indice = torch.cat(\n                    [input_label_embed.new_tensor(range(num)) for num in known_num])  # [1,2, 1,2,3]\n                map_known_indice = torch.cat([map_known_indice + single_pad * i for i in range(scalar)]).long()\n            if len(known_bid):\n                input_query_label[(known_bid.long(), map_known_indice)] = input_label_embed\n                input_query_bbox[(known_bid.long(), map_known_indice)] = input_bbox_embed\n\n            tgt_size = pad_size + self.num_queries\n            attn_mask = input_label_embed.new_ones(tgt_size, tgt_size) < 0\n            # match query cannot see the reconstruct\n            attn_mask[pad_size:, :pad_size] = True\n            # reconstruct cannot see each other\n            for i in range(scalar):\n                if i == 0:\n                    attn_mask[single_pad * i:single_pad * (i + 1), single_pad * (i + 1):pad_size] = True\n                if i == scalar - 1:\n                    attn_mask[single_pad * i:single_pad * (i + 1), :single_pad * i] = True\n                else:\n                    attn_mask[single_pad * i:single_pad * (i + 1), single_pad * (i + 1):pad_size] = True\n                    attn_mask[single_pad * i:single_pad * (i + 1), :single_pad * i] = True\n            mask_dict = {\n                'known_indice': torch.as_tensor(known_indice).long(),\n                
'batch_idx': torch.as_tensor(batch_idx).long(),\n                'map_known_indice': torch.as_tensor(map_known_indice).long(),\n                'known_lbs_bboxes': (known_labels, known_bboxs),\n                'know_idx': know_idx,\n                'pad_size': pad_size,\n                'scalar': scalar,\n            }\n        else:\n            if not refpoint_emb is None:\n                input_query_label = tgt.repeat(batch_size, 1, 1)\n                input_query_bbox = refpoint_emb.repeat(batch_size, 1, 1)\n            else:\n                input_query_label = None\n                input_query_bbox = None\n            attn_mask = None\n            mask_dict = None\n\n        # 100*batch*256\n        if not input_query_bbox is None:\n            input_query_label = input_query_label\n            input_query_bbox = input_query_bbox\n\n        return input_query_label, input_query_bbox, attn_mask, mask_dict\n\n    def prepare_for_dn_mo(self, targets, tgt, refpoint_emb, batch_size):\n        # if self.training:\n        scalar, noise_scale = self.dn_num,self.noise_scale\n\n        known = [(torch.ones_like(t['boxes'])).cuda() for t in targets]\n        know_idx = [torch.nonzero(t) for t in known]\n        known_num = [k.sum() for k in known]\n\n        if max(known_num)>0:\n            scalar=1  # FIXME this is wrong attention mask!!!\n        else:\n            scalar=0\n        if scalar==0:\n            input_query_label = None\n            input_query_bbox = None\n            attn_mask = None\n            mask_dict = None\n\n        #     return input_query_label, input_query_bbox, attn_mask, mask_dict\n\n        pb_labels = torch.stack([t['pb'] for t in targets])\n        # FIXME this is for future content-based interaction; pool content features as label embedding\n        labels = torch.zeros_like(pb_labels).long()\n        boxes = torch.stack([t['boxes_dn'] for t in targets])\n        box_start = [t['box_start'] for t in targets]\n\n\n        
known_labels = labels\n        known_pb_labels = pb_labels\n\n        known_bboxs = boxes\n        known_labels_expaned = known_labels.clone()\n        known_pb_labels_expaned = known_pb_labels.clone()\n        known_bbox_expand = known_bboxs.clone()\n\n        ############ noise on the label\n        # if noise_scale > 0:\n        #     p = torch.rand_like(known_labels_expaned.float())\n        #     chosen_indice = torch.nonzero(p < (noise_scale * 0.5)).view(-1)  # half of bbox prob\n        #     new_label = torch.randint_like(chosen_indice, 0, self.num_classes)  # randomly put a new one here\n        #     known_labels_expaned.scatter_(0, chosen_indice, new_label)\n        if noise_scale > 0 and self.training:\n            diff = torch.zeros_like(known_bbox_expand)\n            diff[:, :, :2] = known_bbox_expand[:, :, 2:] / 2\n            diff[:, :, 2:] = known_bbox_expand[:, :, 2:]\n            sc = 0.01\n            for i, st in enumerate(box_start):\n                diff[i, :st] = diff[i, :st] * sc\n            known_bbox_expand += torch.mul((torch.rand_like(known_bbox_expand) * 2 - 1.0),\n                                           diff).cuda() * noise_scale\n            # known_bbox_expand+=(torch.rand_like(known_bbox_expand)*2-1.0)*torch.tensor([[1,1,0.1,0.1]]).cuda()*noise_scale\n            known_bbox_expand = known_bbox_expand.clamp(min=0.0, max=1.0)\n\n        m = known_labels_expaned.long().to('cuda')\n        m_pb = known_pb_labels_expaned.long().to('cuda')\n        input_label_embed = self.label_enc(m)+self.pb_embedding(m_pb)\n        input_bbox_embed = inverse_sigmoid(known_bbox_expand)\n\n        input_label_embed = input_label_embed.repeat_interleave(self.num_all_tokens,1) + self.mask_tokens.weight.unsqueeze(0).repeat(input_label_embed.shape[0], input_label_embed.shape[1], 1)\n        input_bbox_embed = input_bbox_embed.repeat_interleave(self.num_all_tokens,1)\n\n        single_pad = self.num_all_tokens\n\n        # NOTE scalar is modified to 
100, each click cannot see each other\n        scalar = int(input_label_embed.shape[1]/self.num_all_tokens)\n\n        pad_size = input_label_embed.shape[1]\n\n        if input_label_embed.shape[1]>0:\n            input_query_label = input_label_embed\n            input_query_bbox = input_bbox_embed\n\n        tgt_size = pad_size\n        attn_mask = torch.ones(tgt_size, tgt_size).to('cuda') < 0\n        # match query cannot see the reconstruct\n        attn_mask[pad_size:, :pad_size] = True\n        # reconstruct cannot see each other\n        for i in range(scalar):\n            if i == 0:\n                attn_mask[single_pad * i:single_pad * (i + 1), single_pad * (i + 1):pad_size] = True\n            if i == scalar - 1:\n                attn_mask[single_pad * i:single_pad * (i + 1), :single_pad * i] = True\n            else:\n                attn_mask[single_pad * i:single_pad * (i + 1), single_pad * (i + 1):pad_size] = True\n                attn_mask[single_pad * i:single_pad * (i + 1), :single_pad * i] = True\n        mask_dict = {\n            'known_lbs_bboxes': (known_labels, known_bboxs),\n            # 'know_idx': know_idx,\n            'pad_size': pad_size,\n            'scalar': scalar,\n        }\n\n\n        # 100*batch*256\n        if not input_query_bbox is None:\n            input_query_label = input_query_label\n            input_query_bbox = input_query_bbox\n\n        return input_query_label,input_query_bbox,attn_mask,mask_dict\n\n    def prepare_for_dn_mo_infer(self, targets, tgt, refpoint_emb, batch_size):\n\n        known = [(torch.ones_like(t['points'])).cuda() for t in targets]\n        known_num = [k.sum() for k in known]\n\n        assert max(known_num)>0\n\n        pb_labels = torch.stack([t['pb'] for t in targets])\n        # FIXME this is for future content-based interaction; pool content features as label embedding\n        labels = torch.zeros_like(pb_labels).long()\n        boxes = torch.stack([t['points'] for t in targets])\n\n\n 
       known_labels = labels\n        known_pb_labels = pb_labels\n\n        known_bboxs = boxes\n        known_labels_expaned = known_labels.clone()\n        known_pb_labels_expaned = known_pb_labels.clone()\n        known_bbox_expand = known_bboxs.clone()\n\n        m = known_labels_expaned.long().to('cuda')\n        m_pb = known_pb_labels_expaned.long().to('cuda')\n        input_label_embed = self.label_enc(m)+self.pb_embedding(m_pb)\n        input_bbox_embed = inverse_sigmoid(known_bbox_expand)\n\n        input_label_embed = input_label_embed.repeat_interleave(self.num_all_tokens,1) + self.mask_tokens.weight.unsqueeze(0).repeat(input_label_embed.shape[0], input_label_embed.shape[1], 1)\n        input_bbox_embed = input_bbox_embed.repeat_interleave(self.num_all_tokens,1)\n\n\n        scalar = int(input_label_embed.shape[1]/self.num_all_tokens)\n\n        pad_size = input_label_embed.shape[1]\n\n        if input_label_embed.shape[1]>0:\n            input_query_label = input_label_embed\n            input_query_bbox = input_bbox_embed\n\n        attn_mask = None\n        mask_dict = {\n            'known_lbs_bboxes': (known_labels, known_bboxs),\n            # 'know_idx': know_idx,\n            'pad_size': pad_size,\n            'scalar': scalar,\n        }\n\n\n        return input_query_label,input_query_bbox,attn_mask,mask_dict\n\n    def dn_post_process(self,outputs_class,outputs_coord,mask_dict,outputs_mask):\n        \"\"\"\n            post process of dn after output from the transformer\n            put the dn part in the mask_dict\n            \"\"\"\n        assert mask_dict['pad_size'] > 0\n        output_known_class = outputs_class[:, :, :mask_dict['pad_size'], :]\n        outputs_class = outputs_class[:, :, mask_dict['pad_size']:, :]\n        output_known_coord = outputs_coord[:, :, :mask_dict['pad_size'], :]\n        outputs_coord = outputs_coord[:, :, mask_dict['pad_size']:, :]\n        output_known_mask = None\n        if outputs_mask is not 
None:\n            output_known_mask = outputs_mask[:, :, :mask_dict['pad_size'], :]\n            outputs_mask = outputs_mask[:, :, mask_dict['pad_size']:, :]\n        out = {'pred_logits': output_known_class[-1], 'pred_boxes': output_known_coord[-1],'pred_masks': None if output_known_mask is None else output_known_mask[-1]}\n\n        out['aux_outputs'] = self._set_aux_loss(output_known_class, output_known_mask,output_known_coord)\n        mask_dict['output_known_lbs_bboxes']=out\n        return outputs_class, outputs_coord, outputs_mask\n\n    def get_valid_ratio(self, mask):\n        _, H, W = mask.shape\n        valid_H = torch.sum(~mask[:, :, 0], 1)\n        valid_W = torch.sum(~mask[:, 0, :], 1)\n        valid_ratio_h = valid_H.float() / H\n        valid_ratio_w = valid_W.float() / W\n        valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)\n        return valid_ratio\n\n    def pred_box(self, reference, hs, ref0=None):\n        \"\"\"\n        :param reference: reference box coordinates from each decoder layer\n        :param hs: content\n        :param ref0: whether there are prediction from the first layer\n        \"\"\"\n        if ref0 is None:\n            outputs_coord_list = []\n        else:\n            outputs_coord_list = [ref0]\n        for dec_lid, (layer_ref_sig, layer_bbox_embed, layer_hs) in enumerate(zip(reference[:-1], self.bbox_embed, hs)):\n            layer_delta_unsig = layer_bbox_embed(layer_hs)\n            # layer_outputs_unsig = layer_delta_unsig + inverse_sigmoid(layer_ref_sig)\n            new_layer_ref_sig = layer_ref_sig.view(layer_ref_sig.shape[0], -1, self.num_all_tokens, layer_ref_sig.shape[-1])\n            new_layer_ref_sig = new_layer_ref_sig[:, :, :self.num_mask_tokens].reshape(new_layer_ref_sig.shape[0], -1, new_layer_ref_sig.shape[-1])\n            layer_outputs_unsig = layer_delta_unsig + inverse_sigmoid(new_layer_ref_sig)\n            layer_outputs_unsig = layer_outputs_unsig.sigmoid()\n            
outputs_coord_list.append(layer_outputs_unsig)\n        outputs_coord_list = torch.stack(outputs_coord_list)\n        return outputs_coord_list\n\n    def pred_box_old(self, reference, hs, ref0=None):\n        \"\"\"\n        :param reference: reference box coordinates from each decoder layer\n        :param hs: content\n        :param ref0: whether there are prediction from the first layer\n        \"\"\"\n        if ref0 is None:\n            outputs_coord_list = []\n        else:\n            outputs_coord_list = [ref0]\n        for dec_lid, (layer_ref_sig, layer_bbox_embed, layer_hs) in enumerate(zip(reference[:-1], self.bbox_embed, hs)):\n            layer_delta_unsig = layer_bbox_embed(layer_hs)\n            layer_outputs_unsig = layer_delta_unsig + inverse_sigmoid(layer_ref_sig)\n            layer_outputs_unsig = layer_outputs_unsig.sigmoid()\n            outputs_coord_list.append(layer_outputs_unsig)\n        outputs_coord_list = torch.stack(outputs_coord_list)\n        return outputs_coord_list\n\n    def forward(self, x, mask_features, masks, targets=None, target_queries=None, target_vlp=None, task='seg', extra={}):\n        \"\"\"\n        task: seg/det TODO add sam\n        \"\"\"\n        # task = 'sam'\n        prediction_switch = extra\n        self.prediction_switch = prediction_switch\n        assert len(x) == self.num_feature_levels\n        do_seg = (task != 'det')   # if task is det, not do segmentation training\n        size_list = []\n        # disable mask, it does not affect performance\n        enable_mask = 0\n        if masks is not None:\n            for src in x:\n                if src.size(2) % 32 or src.size(3) % 32:\n                    enable_mask = 1\n        if enable_mask == 0:\n            masks = [torch.zeros((src.size(0), src.size(2), src.size(3)), device=src.device, dtype=torch.bool) for src in x]\n        src_flatten = []\n        mask_flatten = []\n        spatial_shapes = []\n        for i in 
range(self.num_feature_levels):\n            idx=self.num_feature_levels-1-i\n            bs, c , h, w=x[idx].shape\n            size_list.append(x[i].shape[-2:])\n            spatial_shapes.append(x[idx].shape[-2:])\n            src_flatten.append(self.input_proj[idx](x[idx]).flatten(2).transpose(1, 2))\n            mask_flatten.append(masks[i].flatten(1))\n        src_flatten = torch.cat(src_flatten, 1)  # bs, \\sum{hxw}, c\n        mask_flatten = torch.cat(mask_flatten, 1)  # bs, \\sum{hxw}\n        spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device)\n        level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))\n        valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)\n\n        predictions_class = []\n        predictions_class_part = []\n        predictions_mask = []\n        predictions_iou_score = []\n\n        tgt_mask = None\n        mask_dict = None\n        if self.dn != \"no\":\n            assert targets is not None\n            if task=='demo':\n                input_query_label, input_query_bbox, tgt_mask, mask_dict = \\\n                    self.prepare_for_dn_mo_infer(targets, None, None, x[0].shape[0])\n            else:\n                input_query_label, input_query_bbox, tgt_mask, mask_dict = \\\n                    self.prepare_for_dn_mo(targets, None, None, x[0].shape[0])\n            tgt=input_query_label\n            refpoint_embed=input_query_bbox\n            if tgt is None:\n                tgt = torch.zeros(bs, self.num_queries, self.hidden_dim).cuda()\n                refpoint_embed = torch.zeros(bs, self.num_queries, 4).cuda()\n        # import pdb;pdb.set_trace()\n        refpoint_embed=refpoint_embed.to(tgt.dtype)\n        hs, references = self.decoder(\n            tgt=tgt.transpose(0, 1),\n            memory=src_flatten.transpose(0, 1),\n            memory_key_padding_mask=mask_flatten,\n            pos=None,\n         
   refpoints_unsigmoid=refpoint_embed.transpose(0, 1),\n            level_start_index=level_start_index,\n            spatial_shapes=spatial_shapes,\n            valid_ratios=valid_ratios,\n            tgt_mask=tgt_mask\n        )\n\n        new_hs = []\n        feats=[]\n        for i, output in enumerate(hs):\n            outputs_class, outputs_mask, iou_score, decoder_output_mask,decoder_output = self.idno_forward_prediction_heads(output.transpose(0, 1), mask_features, (self.training or (i == len(hs)-1)) and do_seg)\n            outputs_class_whole, outputs_class_part = outputs_class\n            predictions_class.append(outputs_class_whole)\n            predictions_class_part.append(outputs_class_part)\n            predictions_mask.append(outputs_mask)\n            feats.append(decoder_output)\n            if iou_score is not None:\n                predictions_iou_score.append(iou_score)\n                new_hs.append(decoder_output_mask)\n        if new_hs is not None:\n            hs = new_hs\n        # iteratively box prediction\n\n        out_boxes = self.pred_box(references, hs)\n        out_boxes[-1] = out_boxes[-1] + 0.0 * (self.label_enc.weight.sum() + self.pb_embedding.weight.sum() \n                                                               + self.mask_tokens.weight.sum() + self.lang_mapper.sum()+iou_score.sum())\n        if mask_dict is not None:\n            if predictions_mask is None:\n                predictions_class[-1] = predictions_class[-1]\n                for i in range(self.mask_embed.num_layers):\n                    predictions_class[-1] = predictions_class[-1] + 0.0 * (self.mask_embed.layers[i].weight[0][0] + self.mask_embed.layers[i].bias[0])  # avoid no mask loss\n                predictions_class[-1] = predictions_class[-1] + 0.0 * mask_features[0][0][0][0]  # avoid no mask loss\n\n            if do_seg:\n                predictions_mask = list(predictions_mask)\n        elif self.training:  # this is to insure self.label_enc 
participate in the model\n            for i in range(self.mask_embed.num_layers):\n                predictions_class[-1] = predictions_class[-1] + 0.0 * (\n                            self.mask_embed.layers[i].weight[0][0] + self.mask_embed.layers[i].bias[\n                        0])  # avoid no mask loss\n            predictions_class[-1] = predictions_class[-1] + 0.0 * mask_features[0][0][0][0]  # avoid no mask loss\n\n        out = {\n            'pred_logits': predictions_class[-1],\n            'obj_features': feats[-1],\n            'pred_logits_part': predictions_class_part[-1],\n            'pred_masks': None if not do_seg else predictions_mask[-1],\n            'pred_boxes':out_boxes[-1],\n            'pred_ious': predictions_iou_score[-1],\n            'aux_outputs': self._set_aux_loss(\n                predictions_class if self.mask_classification else None, predictions_mask, out_boxes, predictions_iou_score, predictions_class_part\n            )\n        }\n\n        return out, mask_dict\n\n    def forward_o365(self, x, mask_features, masks, targets=None, target_queries=None, target_vlp=None, task='seg', extra={}):\n        \"\"\"\n        task: seg/det TODO add sam\n        \"\"\"\n        # task = 'sam'\n        prediction_switch = extra\n        self.prediction_switch = prediction_switch\n        assert len(x) == self.num_feature_levels\n        do_seg = False   # if task is det, not do segmentation training\n        size_list = []\n        # disable mask, it does not affect performance\n        enable_mask = 0\n        if masks is not None:\n            for src in x:\n                if src.size(2) % 32 or src.size(3) % 32:\n                    enable_mask = 1\n        if enable_mask == 0:\n            masks = [torch.zeros((src.size(0), src.size(2), src.size(3)), device=src.device, dtype=torch.bool) for src in x]\n        src_flatten = []\n        mask_flatten = []\n        spatial_shapes = []\n        for i in range(self.num_feature_levels):\n    
        idx=self.num_feature_levels-1-i\n            bs, c , h, w=x[idx].shape\n            size_list.append(x[i].shape[-2:])\n            spatial_shapes.append(x[idx].shape[-2:])\n            src_flatten.append(self.input_proj[idx](x[idx]).flatten(2).transpose(1, 2))\n            mask_flatten.append(masks[i].flatten(1))\n        src_flatten = torch.cat(src_flatten, 1)  # bs, \\sum{hxw}, c\n        mask_flatten = torch.cat(mask_flatten, 1)  # bs, \\sum{hxw}\n        spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device)\n        level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))\n        valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)\n\n        predictions_class = []\n        # predictions_class_part = []\n        predictions_mask = []\n        # predictions_iou_score = []\n\n        tgt_mask = None\n        mask_dict = None\n        # if self.dn != \"no\":\n        assert targets is not None\n        input_query_label, input_query_bbox, tgt_mask, mask_dict = \\\n            self.prepare_for_dn_o3(targets, None, None, x[0].shape[0])\n        tgt=input_query_label\n        refpoint_embed=input_query_bbox\n        if tgt is None:\n            tgt = torch.zeros(bs, self.num_queries, self.hidden_dim).cuda()\n            refpoint_embed = torch.zeros(bs, self.num_queries, 4).cuda()\n\n        hs, references = self.decoder(\n            tgt=tgt.transpose(0, 1),\n            memory=src_flatten.transpose(0, 1),\n            memory_key_padding_mask=mask_flatten,\n            pos=None,\n            refpoints_unsigmoid=refpoint_embed.transpose(0, 1),\n            level_start_index=level_start_index,\n            spatial_shapes=spatial_shapes,\n            valid_ratios=valid_ratios,\n            tgt_mask=tgt_mask\n        )\n\n        # new_hs = []\n        for i, output in enumerate(hs):\n            outputs_class, outputs_mask = 
self.forward_prediction_heads(output.transpose(0, 1), mask_features, (self.training or (i == len(hs)-1)) and do_seg)\n            outputs_class_whole = outputs_class\n            predictions_class.append(outputs_class_whole)\n            # predictions_class_part.append(outputs_class_part)\n            predictions_mask.append(outputs_mask)\n            # if iou_score is not None:\n            #     predictions_iou_score.append(iou_score)\n            #     new_hs.append(decoder_output_mask)\n        # if new_hs is not None:\n        #     hs = new_hs\n        # iteratively box prediction\n        out_boxes = self.pred_box_old(references, hs)\n\n        out_boxes[-1] = out_boxes[-1] + 0.0 * (self.label_enc.weight.sum() + self.pb_embedding.weight.sum()\n                                                               + self.mask_tokens.weight.sum() + self.lang_mapper.sum())\n        if mask_dict is not None:\n            if predictions_mask is None:\n                predictions_class[-1] = predictions_class[-1]\n                for i in range(self.mask_embed.num_layers):\n                    predictions_class[-1] = predictions_class[-1] + 0.0 * (self.mask_embed.layers[i].weight[0][0] + self.mask_embed.layers[i].bias[0])  # avoid no mask loss\n                predictions_class[-1] = predictions_class[-1] + 0.0 * mask_features[0][0][0][0]  # avoid no mask loss\n\n            if do_seg:\n                predictions_mask = list(predictions_mask)\n        elif self.training:  # this is to insure self.label_enc participate in the model\n            for i in range(self.mask_embed.num_layers):\n                predictions_class[-1] = predictions_class[-1] + 0.0 * (\n                            self.mask_embed.layers[i].weight[0][0] + self.mask_embed.layers[i].bias[\n                        0])  # avoid no mask loss\n            predictions_class[-1] = predictions_class[-1] + 0.0 * mask_features[0][0][0][0]  # avoid no mask loss\n\n        out = {\n            'pred_logits': 
predictions_class[-1],\n            # 'pred_logits_part': predictions_class_part[-1],\n            'pred_masks': None if not do_seg else predictions_mask[-1],\n            'pred_boxes':out_boxes[-1],\n            # 'pred_ious': predictions_iou_score[-1],\n            'aux_outputs': self._set_aux_loss(\n                predictions_class if self.mask_classification else None, predictions_mask, out_boxes\n            )\n        }\n\n        return out, mask_dict\n\n    def forward_prediction_heads(self, output, mask_features, pred_mask=True):\n        decoder_output = self.decoder_norm(output)\n        decoder_output = decoder_output.transpose(0, 1)\n\n        class_embed = decoder_output @ self.class_embed\n        outputs_class = self.lang_encoder.compute_similarity(class_embed, name='whole')\n\n        outputs_mask = None\n        if pred_mask:\n            mask_embed = self.mask_embed(decoder_output)\n            outputs_mask = torch.einsum(\"bqc,bchw->bqhw\", mask_embed, mask_features)\n\n        return outputs_class, outputs_mask\n    \n    def idno_forward_prediction_heads(self, output, mask_features, pred_mask=True):\n        decoder_output = self.decoder_norm(output)\n        decoder_output = decoder_output.transpose(0, 1)\n        \n        decoder_output = decoder_output + 0.0 * (self.class_embed_part.sum() + self.class_embed.sum())\n\n        out = decoder_output.view(decoder_output.shape[0], -1, self.num_all_tokens, decoder_output.shape[-1])\n        decoder_output_mask = out[:, :, :self.num_mask_tokens].reshape(decoder_output.shape[0], -1, decoder_output.shape[-1])\n        # decoder_output_iou = out[:, :, -1].view(decoder_output.shape[0], -1, decoder_output.shape[-1])\n        decoder_output_iou = decoder_output_mask\n\n        outputs_mask = outputs_class_whole = outputs_class_part = None\n        if self.prediction_switch['whole']:\n            class_embed_whole = decoder_output @ self.class_embed\n            outputs_class_whole = 
self.lang_encoder.compute_similarity(class_embed_whole, name='whole')\n        if self.prediction_switch['part']:\n            class_embed_part = decoder_output @ self.class_embed_part\n            outputs_class_part = self.lang_encoder.compute_similarity(class_embed_part, name='part')\n        \n        outputs_class = (outputs_class_whole, outputs_class_part)\n        if self.prediction_switch['seg']:\n            mask_embed = self.mask_embed(decoder_output_mask)\n            if mask_embed.dtype==torch.float16 and mask_features.dtype==torch.float32:\n                mask_embed=mask_embed.to(torch.float32)\n            if mask_embed.dtype==torch.float32 and mask_features.dtype==torch.float16:\n                mask_features=mask_features.to(torch.float32)\n            outputs_mask = torch.einsum(\"bqc,bchw->bqhw\", mask_embed, mask_features.to(mask_embed.dtype))\n        iou_score = self.iou_prediction_head(decoder_output_iou).squeeze(-1).view(decoder_output.shape[0], -1, self.num_mask_tokens)\n        # outputs_mask = outputs_mask + 0.0 * iou_score.sum()  # TODO add iou prediction head\n\n        return outputs_class, outputs_mask, iou_score, decoder_output_mask,decoder_output\n\n    @torch.jit.unused\n    def _set_aux_loss(self, outputs_class=None, outputs_seg_masks=None, out_boxes=None, predictions_iou_score=None, predictions_class_part=None):\n        # this is a workaround to make torchscript happy, as torchscript\n        # doesn't support dictionary with non-homogeneous values, such\n        # as a dict having both a Tensor and a list.\n        # if self.mask_classification:\n        if out_boxes is None:\n            return [\n                {\"pred_logits\": a, \"pred_masks\": b}\n                for a, b in zip(outputs_class[:-1], outputs_seg_masks[:-1])\n            ]\n        elif outputs_seg_masks is None:\n            return [\n                {\"pred_logits\": a, \"pred_boxes\": c}\n                for a, c in zip(outputs_class[:-1], 
out_boxes[:-1])\n            ]\n        elif predictions_iou_score is None:\n            return [\n                {\"pred_logits\": a, \"pred_masks\": b, \"pred_boxes\":c}\n                for a, b, c in zip(outputs_class[:-1], outputs_seg_masks[:-1], out_boxes[:-1])\n            ]\n        else:\n            return [\n                {\"pred_logits\": a, \"pred_masks\": b, \"pred_boxes\":c, \"pred_ious\":d, \"pred_logits_part\": e}\n                for a, b, c, d, e in zip(outputs_class[:-1], outputs_seg_masks[:-1],out_boxes[:-1], predictions_iou_score[:-1], predictions_class_part[:-1])\n            ]\n\n@register_decoder\ndef get_maskdino_transformer_decoder(cfg, in_channels, lang_encoder, mask_classification, extra):\n    return MaskDINODecoder(cfg, in_channels, lang_encoder, mask_classification, extra)\n"
  },
  {
    "path": "llava/model/semsam/body/decoder/modules.py",
    "content": "from typing import Optional\n\nimport torch\nfrom torch import nn, Tensor\nfrom torch.nn import functional as F\n\nfrom timm.models.layers import trunc_normal_\nfrom detectron2.layers import Conv2d\nimport fvcore.nn.weight_init as weight_init\n\n\nclass SelfAttentionLayer(nn.Module):\n    \"\"\"Residual multi-head self-attention block with a single LayerNorm.\n\n    normalize_before=False (post-norm): attend, add the residual, then normalize.\n    normalize_before=True (pre-norm): normalize, attend, add the residual; the\n    output is not normalized again.\n\n    nn.MultiheadAttention is built without batch_first, so inputs use its\n    default (seq, batch, embed) layout.\n    \"\"\"\n\n    def __init__(self, d_model, nhead, dropout=0.0,\n                 activation=\"relu\", normalize_before=False):\n        super().__init__()\n        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)\n\n        self.norm = nn.LayerNorm(d_model)\n        self.dropout = nn.Dropout(dropout)\n\n        # Kept for a uniform constructor signature across layers; unused in forward.\n        self.activation = _get_activation_fn(activation)\n        self.normalize_before = normalize_before\n\n        self._reset_parameters()\n\n    def _reset_parameters(self):\n        # Xavier-uniform init for every matrix-shaped parameter; 1-D parameters\n        # (biases, LayerNorm affine) keep their default initialization.\n        for p in self.parameters():\n            if p.dim() > 1:\n                nn.init.xavier_uniform_(p)\n\n    def with_pos_embed(self, tensor, pos: Optional[Tensor]):\n        \"\"\"Return tensor + pos, or tensor unchanged when pos is None.\"\"\"\n        return tensor if pos is None else tensor + pos\n\n    def forward_post(self, tgt,\n                     tgt_mask: Optional[Tensor] = None,\n                     tgt_key_padding_mask: Optional[Tensor] = None,\n                     query_pos: Optional[Tensor] = None):\n        \"\"\"Post-norm path: attention -> dropout -> residual add -> LayerNorm.\"\"\"\n        # The positional embedding is added to queries and keys, not to values.\n        q = k = self.with_pos_embed(tgt, query_pos)\n        tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,\n                              key_padding_mask=tgt_key_padding_mask)[0]\n        tgt = tgt + self.dropout(tgt2)\n        tgt = self.norm(tgt)\n\n        return tgt\n\n    def forward_pre(self, tgt,\n                    tgt_mask: Optional[Tensor] = None,\n                    tgt_key_padding_mask: Optional[Tensor] = None,\n                    query_pos: Optional[Tensor] = None):\n        \"\"\"Pre-norm path: LayerNorm -> attention -> dropout -> residual add.\"\"\"\n        tgt2 = self.norm(tgt)\n        q = k = self.with_pos_embed(tgt2, query_pos)\n        tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,\n                              key_padding_mask=tgt_key_padding_mask)[0]\n        tgt = tgt + self.dropout(tgt2)\n\n        return tgt\n\n    def forward(self, tgt,\n                tgt_mask: Optional[Tensor] = None,\n                tgt_key_padding_mask: Optional[Tensor] = None,\n                query_pos: Optional[Tensor] = None):\n        \"\"\"Run the pre-norm or post-norm path depending on normalize_before.\"\"\"\n        if self.normalize_before:\n            return self.forward_pre(tgt, tgt_mask,\n                                    tgt_key_padding_mask, query_pos)\n        return self.forward_post(tgt, tgt_mask,\n                                 tgt_key_padding_mask, query_pos)\n\n\nclass CrossAttentionLayer(nn.Module):\n    \"\"\"Residual multi-head cross-attention block: tgt attends to memory.\n\n    Uses the same pre-/post-norm convention and (seq, batch, embed) layout as\n    SelfAttentionLayer, but every forward variant returns a tuple (tgt, attn)\n    where attn is the attention-weight output of nn.MultiheadAttention.\n    \"\"\"\n\n    def __init__(self, d_model, nhead, dropout=0.0,\n                 activation=\"relu\", normalize_before=False):\n        super().__init__()\n        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)\n\n        self.norm = nn.LayerNorm(d_model)\n        self.dropout = nn.Dropout(dropout)\n\n        # Kept for a uniform constructor signature across layers; unused in forward.\n        self.activation = _get_activation_fn(activation)\n        self.normalize_before = normalize_before\n\n        self._reset_parameters()\n\n    def _reset_parameters(self):\n        # Xavier-uniform init for every matrix-shaped parameter.\n        for p in self.parameters():\n            if p.dim() > 1:\n                nn.init.xavier_uniform_(p)\n\n    def with_pos_embed(self, tensor, pos: Optional[Tensor]):\n        \"\"\"Return tensor + pos, or tensor unchanged when pos is None.\"\"\"\n        return tensor if pos is None else tensor + pos\n\n    def forward_post(self, tgt, memory,\n                     memory_mask: Optional[Tensor] = None,\n                     memory_key_padding_mask: Optional[Tensor] = None,\n                     pos: Optional[Tensor] = None,\n                     query_pos: Optional[Tensor] = None):\n        \"\"\"Post-norm path; returns (normalized tgt, attention weights).\"\"\"\n        # query_pos is added to the queries and pos to the keys; values are raw memory.\n        tgt2, avg_attn = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),\n                                   key=self.with_pos_embed(memory, pos),\n                                   value=memory, attn_mask=memory_mask,\n                                   key_padding_mask=memory_key_padding_mask)\n        tgt = tgt + self.dropout(tgt2)\n        tgt = self.norm(tgt)\n        return tgt, avg_attn\n\n    def forward_pre(self, tgt, memory,\n                    memory_mask: Optional[Tensor] = None,\n                    memory_key_padding_mask: Optional[Tensor] = None,\n                    pos: Optional[Tensor] = None,\n                    query_pos: Optional[Tensor] = None):\n        \"\"\"Pre-norm path; returns (tgt, attention weights). Only tgt is normalized.\"\"\"\n        tgt2 = self.norm(tgt)\n        tgt2, avg_attn = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),\n                                   key=self.with_pos_embed(memory, pos),\n                                   value=memory, attn_mask=memory_mask,\n                                   key_padding_mask=memory_key_padding_mask)\n        tgt = tgt + self.dropout(tgt2)\n\n        return tgt, avg_attn\n\n    def forward(self, tgt, memory,\n                memory_mask: Optional[Tensor] = None,\n                memory_key_padding_mask: Optional[Tensor] = None,\n                pos: Optional[Tensor] = None,\n                query_pos: Optional[Tensor] = None):\n        \"\"\"Run the pre-norm or post-norm path; returns (tgt, attention weights).\"\"\"\n        if self.normalize_before:\n            return self.forward_pre(tgt, memory, memory_mask,\n                                    memory_key_padding_mask, pos, query_pos)\n        return self.forward_post(tgt, memory, memory_mask,\n                                 memory_key_padding_mask, pos, query_pos)\n\n\nclass FFNLayer(nn.Module):\n    \"\"\"Residual feed-forward block: Linear -> activation -> dropout -> Linear.\n\n    normalize_before selects pre-norm or post-norm in the same way as the\n    attention layers above.\n    \"\"\"\n\n    def __init__(self, d_model, dim_feedforward=2048, dropout=0.0,\n                 activation=\"relu\", normalize_before=False):\n        super().__init__()\n        # Implementation of Feedforward model\n        self.linear1 = nn.Linear(d_model, dim_feedforward)\n        self.dropout = nn.Dropout(dropout)\n        self.linear2 = nn.Linear(dim_feedforward, d_model)\n\n        self.norm = nn.LayerNorm(d_model)\n\n        # NOTE(review): F.glu halves the last dimension, so linear2 (which expects\n        # dim_feedforward inputs) would reject its output; only relu/gelu fit here.\n        self.activation = _get_activation_fn(activation)\n        self.normalize_before = normalize_before\n\n        self._reset_parameters()\n\n    def _reset_parameters(self):\n        # Xavier-uniform init for every matrix-shaped parameter.\n        for p in self.parameters():\n            if p.dim() > 1:\n                nn.init.xavier_uniform_(p)\n\n    def with_pos_embed(self, tensor, pos: Optional[Tensor]):\n        \"\"\"Return tensor + pos, or tensor unchanged when pos is None (not used by forward).\"\"\"\n        return tensor if pos is None else tensor + pos\n\n    def forward_post(self, tgt):\n        \"\"\"Post-norm path: FFN -> dropout -> residual add -> LayerNorm.\"\"\"\n        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))\n        tgt = tgt + self.dropout(tgt2)\n        tgt = self.norm(tgt)\n        return tgt\n\n    def forward_pre(self, tgt):\n        \"\"\"Pre-norm path: LayerNorm -> FFN -> dropout -> residual add.\"\"\"\n        tgt2 = self.norm(tgt)\n        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))\n        tgt = tgt + self.dropout(tgt2)\n        return tgt\n\n    def forward(self, tgt):\n        \"\"\"Run the pre-norm or post-norm path depending on normalize_before.\"\"\"\n        if self.normalize_before:\n            return self.forward_pre(tgt)\n        return self.forward_post(tgt)\n\n\ndef _get_activation_fn(activation):\n    \"\"\"Return the torch.nn.functional activation named by activation.\n\n    Raises:\n        RuntimeError: if activation is not \"relu\", \"gelu\" or \"glu\".\n    \"\"\"\n    if activation == \"relu\":\n        return F.relu\n    if activation == \"gelu\":\n        return F.gelu\n    if activation == \"glu\":\n        return F.glu\n    raise RuntimeError(f\"activation should be relu/gelu/glu, not {activation}.\")\n\n\nclass MLP(nn.Module):\n    \"\"\"Very simple multi-layer perceptron (also called FFN).\n\n    For num_layers >= 1, stacks num_layers Linear layers mapping\n    input_dim -> hidden_dim -> ... -> output_dim, with ReLU after every layer\n    except the last.\n    \"\"\"\n\n    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):\n        super().__init__()\n        self.num_layers = num_layers\n        h = [hidden_dim] * (num_layers - 1)\n        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))\n\n    def forward(self, x):\n        for i, layer in enumerate(self.layers):\n            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)\n        return x\n"
  },
  {
    "path": "llava/model/semsam/body/decoder/registry.py",
    "content": "_model_entrypoints = {}\n\ndef register_decoder(fn):\n    module_name_split = fn.__module__.split('.')\n    model_name = module_name_split[-1]\n    _model_entrypoints[model_name] = fn\n    return fn\n\ndef model_entrypoints(model_name):\n    return _model_entrypoints[model_name]\n\ndef is_model(model_name):\n    return model_name in _model_entrypoints"
  },
  {
    "path": "llava/model/semsam/body/decoder/utils/__init__.py",
    "content": "from .utils import *"
  },
  {
    "path": "llava/model/semsam/body/decoder/utils/dino_decoder.py",
    "content": "# ------------------------------------------------------------------------\n# DINO\n# Copyright (c) 2022 IDEA. All Rights Reserved.\n# Licensed under the Apache License, Version 2.0 [see LICENSE for details]\n# ------------------------------------------------------------------------\n# Modified from DINO https://github.com/IDEA-Research/DINO by Feng Li and Hao Zhang.\n# ------------------------------------------------------------------------\n\nfrom typing import Optional, List, Union\nimport torch\nfrom torch import nn, Tensor\nfrom torch.cuda.amp import autocast\n\nfrom .utils import MLP, _get_clones, _get_activation_fn, gen_sineembed_for_position, inverse_sigmoid\nfrom ...encoder.ops.modules import MSDeformAttn\nfrom torch.utils.checkpoint import checkpoint\n\n\nclass TransformerDecoder(nn.Module):\n\n    def __init__(self, decoder_layer, num_layers, norm=None,\n                 return_intermediate=False,\n                 d_model=256, query_dim=4,\n                 modulate_hw_attn=True,\n                 num_feature_levels=1,\n                 deformable_decoder=True,\n                 decoder_query_perturber=None,\n                 dec_layer_number=None,  # number of queries each layer in decoder\n                 rm_dec_query_scale=True,\n                 dec_layer_share=False,\n                 dec_layer_dropout_prob=None,\n                 task_switch=None,\n                 ):\n        super().__init__()\n        if num_layers > 0:\n            self.layers = _get_clones(decoder_layer, num_layers, layer_share=dec_layer_share)\n        else:\n            self.layers = []\n        self.num_layers = num_layers\n        self.norm = norm\n        self.return_intermediate = return_intermediate\n        assert return_intermediate, \"support return_intermediate only\"\n        self.query_dim = query_dim\n        assert query_dim in [2, 4], \"query_dim should be 2/4 but {}\".format(query_dim)\n        self.num_feature_levels = num_feature_levels\n\n 
       self.ref_point_head = MLP(query_dim // 2 * d_model, d_model, d_model, 2)\n        if not deformable_decoder:\n            self.query_pos_sine_scale = MLP(d_model, d_model, d_model, 2)\n        else:\n            self.query_pos_sine_scale = None\n\n        if rm_dec_query_scale:\n            self.query_scale = None\n        else:\n            raise NotImplementedError\n            self.query_scale = MLP(d_model, d_model, d_model, 2)\n        self.bbox_embed = None\n        self.class_embed = None\n\n        self.d_model = d_model\n        self.modulate_hw_attn = modulate_hw_attn\n        self.deformable_decoder = deformable_decoder\n\n        if not deformable_decoder and modulate_hw_attn:\n            self.ref_anchor_head = MLP(d_model, d_model, 2, 2)\n        else:\n            self.ref_anchor_head = None\n\n        self.decoder_query_perturber = decoder_query_perturber\n        self.box_pred_damping = None\n\n        self.dec_layer_number = dec_layer_number\n        if dec_layer_number is not None:\n            assert isinstance(dec_layer_number, list)\n            assert len(dec_layer_number) == num_layers\n            # assert dec_layer_number[0] ==\n\n        self.dec_layer_dropout_prob = dec_layer_dropout_prob\n        if dec_layer_dropout_prob is not None:\n            assert isinstance(dec_layer_dropout_prob, list)\n            assert len(dec_layer_dropout_prob) == num_layers\n            for i in dec_layer_dropout_prob:\n                assert 0.0 <= i <= 1.0\n\n        self.task_switch = task_switch\n        self._reset_parameters()\n\n    def _reset_parameters(self):\n        for p in self.parameters():\n            if p.dim() > 1:\n                nn.init.xavier_uniform_(p)\n        for m in self.modules():\n            if isinstance(m, MSDeformAttn):\n                m._reset_parameters()\n\n    def forward(self, tgt, memory,\n                tgt_mask: Optional[Tensor] = None,\n                memory_mask: Optional[Tensor] = None,\n              
  tgt_key_padding_mask: Optional[Tensor] = None,\n                memory_key_padding_mask: Optional[Tensor] = None,\n                pos: Optional[Tensor] = None,\n                refpoints_unsigmoid: Optional[Tensor] = None,  # num_queries, bs, 2\n                # for memory\n                level_start_index: Optional[Tensor] = None,  # num_levels\n                spatial_shapes: Optional[Tensor] = None,  # bs, num_levels, 2\n                valid_ratios: Optional[Tensor] = None,\n                # misc\n                extra: Optional[Tensor] = {}, # extra information\n                ):\n        \"\"\"\n        Input:\n            - tgt: nq, bs, d_model\n            - memory: hw, bs, d_model\n            - pos: hw, bs, d_model\n            - refpoints_unsigmoid: nq, bs, 2/4\n            - valid_ratios/spatial_shapes: bs, nlevel, 2\n        \"\"\"\n        output = tgt\n\n        intermediate = []\n        reference_points = refpoints_unsigmoid.sigmoid()\n        ref_points = [reference_points]\n\n        if 'lang_refpoint_embed' in extra.keys() and 'grounding_tokens' in extra.keys():\n            reference_points = torch.cat((reference_points, extra['lang_refpoint_embed'].transpose(0,1).sigmoid()), dim=0)\n            output = torch.cat((output, extra['grounding_tokens']), dim=0)\n\n        for layer_id, layer in enumerate(self.layers):            \n            # preprocess ref points\n            if self.training and self.decoder_query_perturber is not None and layer_id != 0:\n                reference_points = self.decoder_query_perturber(reference_points)\n\n            reference_points_input = reference_points[:, :, None] \\\n                                         * torch.cat([valid_ratios, valid_ratios], -1)[None, :].to(reference_points.dtype)  # nq, bs, nlevel, 4\n            # print('reference_points_input', reference_points_input.dtype)\n            # print('memory', memory.dtype)\n            # 
reference_points_input=reference_points_input.to(memory.dtype)\n            query_sine_embed = gen_sineembed_for_position(reference_points_input[:, :, 0, :], dim=output.shape[-1]//2) # nq, bs, 256*2\n            # import pdb; pdb.set_trace()\n            # query_sine_embed = query_sine_embed.to(self.ref_point_head.layers[0].weight.dtype)\n            raw_query_pos = self.ref_point_head(query_sine_embed)  # nq, bs, 256\n\n            pos_scale = self.query_scale(output) if self.query_scale is not None else 1\n            query_pos = pos_scale * raw_query_pos\n            output = layer(\n                        output,\n                        query_pos,\n                        query_sine_embed,\n                        tgt_key_padding_mask,\n                        reference_points_input, memory,\n                        memory_key_padding_mask,\n                        level_start_index,\n                        spatial_shapes,\n                        pos,\n                        tgt_mask,\n                        memory_mask,\n                        self.task_switch,\n                        extra,\n                        )\n\n            # grounding language token reference point will not update and saved\n            if (self.task_switch is not None) and (extra is not None) and (self.task_switch['grounding']) and ('grounding_len' in extra) and extra['task']=='seg':\n                _reference_points = reference_points[-extra['grounding_len']:]\n                reference_points = reference_points[:-extra['grounding_len']]\n                _output = output[-extra['grounding_len']:]\n                output = output[:-extra['grounding_len']]\n\n            # iter update\n            if self.bbox_embed is not None:\n                reference_before_sigmoid = inverse_sigmoid(reference_points)\n                # import pdb; pdb.set_trace()\n                output= output.to(query_sine_embed.dtype)\n                delta_unsig = self.bbox_embed[layer_id](output)\n 
               outputs_unsig = delta_unsig + reference_before_sigmoid\n                new_reference_points = outputs_unsig.sigmoid()\n\n                reference_points = new_reference_points.detach()\n                # if layer_id != self.num_layers - 1:\n                ref_points.append(new_reference_points)\n\n            intermediate.append(self.norm(output))\n\n            # add back grounding language token\n            if (self.task_switch is not None) and (extra is not None) and (self.task_switch['grounding']) and ('grounding_len' in extra) and extra['task']=='seg':\n                reference_points = torch.cat((reference_points, _reference_points))\n                output = torch.cat((output, _output))\n\n        return [\n            [itm_out.transpose(0, 1) for itm_out in intermediate],\n            [itm_refpoint.transpose(0, 1) for itm_refpoint in ref_points]\n        ]\n\n\nclass DeformableTransformerDecoderLayer(nn.Module):\n\n    def __init__(self, d_model=256, d_ffn=1024,\n                 dropout=0.1, activation=\"relu\",\n                 n_levels=4, n_heads=8, n_points=4,\n                 use_deformable_box_attn=False,\n                 key_aware_type=None,\n                 ):\n        super().__init__()\n\n        # cross attention\n        if use_deformable_box_attn:\n            raise NotImplementedError\n        else:\n            self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)\n        self.dropout1 = nn.Dropout(dropout)\n        self.norm1 = nn.LayerNorm(d_model)\n\n        # self attention\n        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)\n        self.dropout2 = nn.Dropout(dropout)\n        self.norm2 = nn.LayerNorm(d_model)\n\n        # ffn\n        self.linear1 = nn.Linear(d_model, d_ffn)\n        self.activation = _get_activation_fn(activation)\n        self.dropout3 = nn.Dropout(dropout)\n        self.linear2 = nn.Linear(d_ffn, d_model)\n        self.dropout4 = 
nn.Dropout(dropout)\n        self.norm3 = nn.LayerNorm(d_model)\n\n        self.key_aware_type = key_aware_type\n        self.key_aware_proj = None\n\n    def rm_self_attn_modules(self):\n        self.self_attn = None\n        self.dropout2 = None\n        self.norm2 = None\n\n    @staticmethod\n    def with_pos_embed(tensor, pos):\n        return tensor if pos is None else tensor + pos\n\n    def forward_ffn(self, tgt):\n        tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))\n        tgt = tgt + self.dropout4(tgt2)\n        tgt = self.norm3(tgt)\n        return tgt\n\n    @autocast(enabled=True)\n    def forward(self,\n                # for tgt\n                tgt: Optional[Tensor],  # nq, bs, d_model\n                tgt_query_pos: Optional[Tensor] = None,  # pos for query. MLP(Sine(pos))\n                tgt_query_sine_embed: Optional[Tensor] = None,  # pos for query. Sine(pos)\n                tgt_key_padding_mask: Optional[Tensor] = None,\n                tgt_reference_points: Optional[Tensor] = None,  # nq, bs, 4\n\n                # for memory\n                memory: Optional[Tensor] = None,  # hw, bs, d_model\n                memory_key_padding_mask: Optional[Tensor] = None,\n                memory_level_start_index: Optional[Tensor] = None,  # num_levels\n                memory_spatial_shapes: Optional[Tensor] = None,  # bs, num_levels, 2\n                memory_pos: Optional[Tensor] = None,  # pos for memory\n\n                # sa\n                self_attn_mask: Optional[Tensor] = None,  # mask used for self-attention\n                cross_attn_mask: Optional[Tensor] = None,  # mask used for cross-attention\n\n                # misc\n                task_switch: Optional[Tensor] = {}, # extra information                \n                extra: Optional[Tensor] = {}, # extra information\n                ):\n        \"\"\"\n        Input:\n            - tgt/tgt_query_pos: nq, bs, d_model\n            -\n        \"\"\"\n        # 
self attention\n        # import pdb;pdb.set_trace()\n        if self.self_attn is not None:\n            q = k = self.with_pos_embed(tgt, tgt_query_pos)\n            tgt2 = self.self_attn(q, k, tgt, attn_mask=self_attn_mask)[0]\n            tgt = tgt + self.dropout2(tgt2)\n            tgt = self.norm2(tgt)\n\n        # exclude grounding token for cross attention\n        if (task_switch is not None) and (extra is not None) and (task_switch['grounding']) and ('grounding_len' in extra) and extra['task']=='seg':\n            _grounding_lang_tokens = tgt[-extra['grounding_len']:,]\n            _grounding_lang_pos = tgt_query_pos[-extra['grounding_len']:,]\n            _grounding_ref_points = tgt_reference_points[-extra['grounding_len']:,]\n            tgt = tgt[:-extra['grounding_len'],]\n            tgt_query_pos = tgt_query_pos[:-extra['grounding_len'],]\n            tgt_reference_points = tgt_reference_points[:-extra['grounding_len'],]\n\n        # cross attention\n        if self.key_aware_type is not None:\n            if self.key_aware_type == 'mean':\n                tgt = tgt + memory.mean(0, keepdim=True)\n            elif self.key_aware_type == 'proj_mean':\n                tgt = tgt + self.key_aware_proj(memory).mean(0, keepdim=True)\n            else:\n                raise NotImplementedError(\"Unknown key_aware_type: {}\".format(self.key_aware_type))\n        # import pdb;pdb.set_trace()\n\n        tgt2 = self.cross_attn(self.with_pos_embed(tgt, tgt_query_pos).transpose(0, 1),\n                               tgt_reference_points.transpose(0, 1).contiguous(),\n                               memory.transpose(0, 1), memory_spatial_shapes, memory_level_start_index,\n                               memory_key_padding_mask).transpose(0, 1) # TODO: check whether add grounding lang token to cross attention is better\n        # import pdb;pdb.set_trace()\n\n        tgt = tgt + self.dropout1(tgt2)\n\n        # add back grounding token for self attention\n        if 
(task_switch is not None) and (extra is not None) and (task_switch['grounding']) and ('grounding_len' in extra) and extra['task']=='seg':\n            tgt = torch.cat((tgt, _grounding_lang_tokens))\n\n        tgt = self.norm1(tgt)\n        tgt = self.forward_ffn(tgt) # ffn\n        return tgt"
  },
  {
    "path": "llava/model/semsam/body/decoder/utils/utils.py",
    "content": "import torch\nimport copy\nfrom torch import nn, Tensor\nimport os\n\nimport math\nimport torch.nn.functional as F\nfrom torch import nn\n\n\nclass MLP(nn.Module):\n    \"\"\" Very simple multi-layer perceptron (also called FFN)\"\"\"\n\n    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):\n        super().__init__()\n        self.num_layers = num_layers\n        h = [hidden_dim] * (num_layers - 1)\n        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))\n\n    def forward(self, x):\n        for i, layer in enumerate(self.layers):\n            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)\n        return x\n\n\ndef inverse_sigmoid(x, eps=1e-5):\n    x = x.clamp(min=0, max=1)\n    x1 = x.clamp(min=eps)\n    x2 = (1 - x).clamp(min=eps)\n    return torch.log(x1/x2)\n\n\ndef gen_encoder_output_proposals(memory:Tensor, memory_padding_mask:Tensor, spatial_shapes:Tensor):\n    \"\"\"\n    Input:\n        - memory: bs, \\sum{hw}, d_model\n        - memory_padding_mask: bs, \\sum{hw}\n        - spatial_shapes: nlevel, 2\n    Output:\n        - output_memory: bs, \\sum{hw}, d_model\n        - output_proposals: bs, \\sum{hw}, 4\n    \"\"\"\n    N_, S_, C_ = memory.shape\n    base_scale = 4.0\n    proposals = []\n    _cur = 0\n    for lvl, (H_, W_) in enumerate(spatial_shapes):\n        mask_flatten_ = memory_padding_mask[:, _cur:(_cur + H_ * W_)].view(N_, H_, W_, 1)\n        valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)\n        valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)\n\n        grid_y, grid_x = torch.meshgrid(torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device),\n                                        torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device))\n        grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)\n\n        scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(N_, 1, 1, 
2)\n        grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale\n        wh = torch.ones_like(grid) * 0.05 * (2.0 ** lvl)\n        proposal = torch.cat((grid, wh), -1).view(N_, -1, 4)\n        proposals.append(proposal)\n        _cur += (H_ * W_)\n    output_proposals = torch.cat(proposals, 1)\n    output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True)\n    output_proposals = torch.log(output_proposals / (1 - output_proposals))\n    output_proposals = output_proposals.masked_fill(memory_padding_mask.unsqueeze(-1), float('inf'))\n    output_proposals = output_proposals.masked_fill(~output_proposals_valid, float('inf'))\n\n    output_memory = memory\n    output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0))\n    output_memory = output_memory.masked_fill(~output_proposals_valid, float(0))\n    return output_memory, output_proposals\n\n\ndef gen_sineembed_for_position(pos_tensor, dim=128):\n    # n_query, bs, _ = pos_tensor.size()\n    # sineembed_tensor = torch.zeros(n_query, bs, 256)\n    scale = 2 * math.pi\n    dim_t = torch.arange(dim, dtype=torch.float32, device=pos_tensor.device)\n    dim_t = 10000 ** (2 * (dim_t // 2) / dim)\n    x_embed = pos_tensor[:, :, 0] * scale\n    y_embed = pos_tensor[:, :, 1] * scale\n    pos_x = x_embed[:, :, None] / dim_t\n    pos_y = y_embed[:, :, None] / dim_t\n    pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)\n    pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2)\n    if pos_tensor.size(-1) == 2:\n        pos = torch.cat((pos_y, pos_x), dim=2)\n    elif pos_tensor.size(-1) == 4:\n        w_embed = pos_tensor[:, :, 2] * scale\n        pos_w = w_embed[:, :, None] / dim_t\n        pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2)\n\n        h_embed = pos_tensor[:, :, 3] * scale\n        pos_h = h_embed[:, :, None] / 
dim_t\n        pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2)\n\n        pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)\n    else:\n        raise ValueError(\"Unknown pos_tensor shape(-1):{}\".format(pos_tensor.size(-1)))\n    return pos.to(pos_tensor.dtype)\n\n\ndef _get_activation_fn(activation):\n    \"\"\"Return an activation function given a string\"\"\"\n    if activation == \"relu\":\n        return F.relu\n    if activation == \"gelu\":\n        return F.gelu\n    if activation == \"glu\":\n        return F.glu\n    if activation == \"prelu\":\n        return nn.PReLU()\n    if activation == \"selu\":\n        return F.selu\n    raise RuntimeError(F\"activation should be relu/gelu, not {activation}.\")\n\n\ndef _get_clones(module, N, layer_share=False):\n\n    if layer_share:\n        return nn.ModuleList([module for i in range(N)])\n    else:\n        return nn.ModuleList([copy.deepcopy(module) for i in range(N)])"
  },
  {
    "path": "llava/model/semsam/body/encoder/__init__.py",
    "content": "from .build import build_encoder"
  },
  {
    "path": "llava/model/semsam/body/encoder/build.py",
    "content": "from .registry import model_entrypoints\nfrom .registry import is_model\n\nfrom .transformer_encoder_fpn import *\nfrom .encoder_deform import *\n\ndef build_encoder(config, *args, **kwargs):\n    model_name = config['MODEL']['ENCODER']['NAME']\n\n    if not is_model(model_name):\n        raise ValueError(f'Unknown model: {model_name}')\n\n    return model_entrypoints(model_name)(config, *args, **kwargs)"
  },
  {
    "path": "llava/model/semsam/body/encoder/encoder_deform.py",
    "content": "# ------------------------------------------------------------------------\n# DINO\n# Copyright (c) 2022 IDEA. All Rights Reserved.\n# Licensed under the Apache License, Version 2.0 [see LICENSE for details]\n# ------------------------------------------------------------------------\n# Modified by Feng Li and Hao Zhang.\nimport logging\nimport numpy as np\nfrom typing import Callable, Dict, List, Optional, Tuple, Union\nimport fvcore.nn.weight_init as weight_init\n\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom torch.nn.init import xavier_uniform_, constant_, uniform_, normal_\nfrom torch.cuda.amp import autocast\n\nfrom detectron2.layers import Conv2d, ShapeSpec, get_norm\n# from detectron2.modeling import SEM_SEG_HEADS_REGISTRY\n\nfrom .registry import register_encoder\nfrom ...utils import configurable\nfrom ...modules import PositionEmbeddingSine\nfrom ..transformer_blocks import _get_clones, _get_activation_fn\nfrom .ops.modules import MSDeformAttn\nfrom torch.utils import checkpoint\n\n# MSDeformAttn Transformer encoder in deformable detr\nclass MSDeformAttnTransformerEncoderOnly(nn.Module):\n    def __init__(self, d_model=256, nhead=8,\n                 num_encoder_layers=6, dim_feedforward=1024, dropout=0.1,\n                 activation=\"relu\",\n                 num_feature_levels=4, enc_n_points=4,):\n        super().__init__()\n\n        self.d_model = d_model\n        self.nhead = nhead\n\n        encoder_layer = MSDeformAttnTransformerEncoderLayer(d_model, dim_feedforward,\n                                                            dropout, activation,\n                                                            num_feature_levels, nhead, enc_n_points)\n        self.encoder = MSDeformAttnTransformerEncoder(encoder_layer, num_encoder_layers)\n\n        self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))\n\n        self._reset_parameters()\n\n    def _reset_parameters(self):\n      
  for p in self.parameters():\n            if p.dim() > 1:\n                nn.init.xavier_uniform_(p)\n        for m in self.modules():\n            if isinstance(m, MSDeformAttn):\n                m._reset_parameters()\n        normal_(self.level_embed)\n\n    def get_valid_ratio(self, mask):\n        _, H, W = mask.shape\n        valid_H = torch.sum(~mask[:, :, 0], 1)\n        valid_W = torch.sum(~mask[:, 0, :], 1)\n        valid_ratio_h = valid_H.float() / H\n        valid_ratio_w = valid_W.float() / W\n        valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)\n        return valid_ratio\n\n    def forward(self, srcs, masks, pos_embeds, use_ckpt=False):\n\n        enable_mask=0\n        if masks is not None:\n            for src in srcs:\n                if src.size(2)%32 or src.size(3)%32:\n                    enable_mask = 1\n        if enable_mask==0:\n            masks = [torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) for x in srcs]\n        # prepare input for encoder\n        src_flatten = []\n        mask_flatten = []\n        lvl_pos_embed_flatten = []\n        spatial_shapes = []\n        for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):\n            bs, c, h, w = src.shape\n            spatial_shape = (h, w)\n            spatial_shapes.append(spatial_shape)\n            src = src.flatten(2).transpose(1, 2)\n            mask = mask.flatten(1)\n            pos_embed = pos_embed.flatten(2).transpose(1, 2)\n            lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)\n            lvl_pos_embed_flatten.append(lvl_pos_embed)\n            src_flatten.append(src)\n            mask_flatten.append(mask)\n        src_flatten = torch.cat(src_flatten, 1)\n        mask_flatten = torch.cat(mask_flatten, 1)\n        lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)\n        spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device)\n   
     level_start_index = torch.cat((spatial_shapes.new_zeros((1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))\n        valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)\n\n        # encoder\n        memory = self.encoder(src_flatten, spatial_shapes, level_start_index, valid_ratios, lvl_pos_embed_flatten, mask_flatten, use_ckpt=use_ckpt)\n        return memory, spatial_shapes, level_start_index\n\n\nclass MSDeformAttnTransformerEncoderLayer(nn.Module):\n    def __init__(self,\n                 d_model=256, d_ffn=1024,\n                 dropout=0.1, activation=\"relu\",\n                 n_levels=4, n_heads=8, n_points=4):\n        super().__init__()\n\n        # self attention\n        self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)\n        self.dropout1 = nn.Dropout(dropout)\n        self.norm1 = nn.LayerNorm(d_model)\n\n        # ffn\n        self.linear1 = nn.Linear(d_model, d_ffn)\n        self.activation = _get_activation_fn(activation)\n        self.dropout2 = nn.Dropout(dropout)\n        self.linear2 = nn.Linear(d_ffn, d_model)\n        self.dropout3 = nn.Dropout(dropout)\n        self.norm2 = nn.LayerNorm(d_model)\n\n    @staticmethod\n    def with_pos_embed(tensor, pos):\n        return tensor if pos is None else tensor + pos\n\n    def forward_ffn(self, src):\n        src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))\n        src = src + self.dropout3(src2)\n        src = self.norm2(src)\n        return src\n\n    def forward(self, src, pos, reference_points, spatial_shapes, level_start_index, padding_mask=None):\n        # self attention\n        src2 = self.self_attn(self.with_pos_embed(src, pos), reference_points, src, spatial_shapes, level_start_index, padding_mask)\n        src = src + self.dropout1(src2)\n        src = self.norm1(src)\n\n        # ffn\n        src = self.forward_ffn(src)\n\n        return src\n\n\nclass MSDeformAttnTransformerEncoder(nn.Module):\n    def __init__(self, 
encoder_layer, num_layers):\n        super().__init__()\n        self.layers = _get_clones(encoder_layer, num_layers)\n        self.num_layers = num_layers\n\n    @staticmethod\n    def get_reference_points(spatial_shapes, valid_ratios, device):\n        reference_points_list = []\n        for lvl, (H_, W_) in enumerate(spatial_shapes):\n\n            ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),\n                                          torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device))\n            ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)\n            ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)\n            ref = torch.stack((ref_x, ref_y), -1)\n            reference_points_list.append(ref)\n        reference_points = torch.cat(reference_points_list, 1)\n        reference_points = reference_points[:, :, None] * valid_ratios[:, None]\n        return reference_points\n\n    def forward(self, src, spatial_shapes, level_start_index, valid_ratios, pos=None, padding_mask=None, use_ckpt=False):\n        output = src\n        reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=src.device)\n        for _, layer in enumerate(self.layers):\n            use_ckpt = False\n            if use_ckpt:\n                output = checkpoint.checkpoint(layer,output, pos, reference_points, spatial_shapes, level_start_index, padding_mask)\n            else:\n                output = layer(output, pos, reference_points, spatial_shapes, level_start_index, padding_mask)\n\n        return output\n\n\nclass MaskDINOEncoder(nn.Module):\n    \"\"\"\n    This is the multi-scale encoder in detection models, also named as pixel decoder in segmentation models.\n    \"\"\"\n    @configurable\n    def __init__(\n        self,\n        input_shape: Dict[str, ShapeSpec],\n        *,\n        transformer_dropout: float,\n        
transformer_nheads: int,\n        transformer_dim_feedforward: int,\n        transformer_enc_layers: int,\n        conv_dim: int,\n        mask_dim: int,\n        norm: Optional[Union[str, Callable]] = None,\n        # deformable transformer encoder args\n        transformer_in_features: List[str],\n        common_stride: int,\n        num_feature_levels: int,\n        total_num_feature_levels: int,\n        feature_order: str,\n        use_ckpt=False,\n    ):\n        \"\"\"\n        NOTE: this interface is experimental.\n        Args:\n            input_shape: shapes (channels and stride) of the input features\n            transformer_dropout: dropout probability in transformer\n            transformer_nheads: number of heads in transformer\n            transformer_dim_feedforward: dimension of feedforward network\n            transformer_enc_layers: number of transformer encoder layers\n            conv_dims: number of output channels for the intermediate conv layers.\n            mask_dim: number of output channels for the final conv layer.\n            norm (str or callable): normalization for all conv layers\n            num_feature_levels: feature scales used\n            total_num_feature_levels: total feautre scales used (include the downsampled features)\n            feature_order: 'low2high' or 'high2low', i.e., 'low2high' means low-resolution features are put in the first.\n        \"\"\"\n        super().__init__()\n        self.use_ckpt = use_ckpt\n        transformer_input_shape = {\n            k: v for k, v in input_shape.items() if k in transformer_in_features\n        }\n        # this is the input shape of pixel decoder\n        input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)\n        self.in_features = [k for k, v in input_shape]  # starting from \"res2\" to \"res5\"\n        self.feature_strides = [v.stride for k, v in input_shape]\n        self.feature_channels = [v.channels for k, v in input_shape]\n        
self.feature_order = feature_order\n\n        if feature_order == \"low2high\":\n            transformer_input_shape = sorted(transformer_input_shape.items(), key=lambda x: -x[1].stride)\n        else:\n            transformer_input_shape = sorted(transformer_input_shape.items(), key=lambda x: x[1].stride)\n        self.transformer_in_features = [k for k, v in transformer_input_shape]  # starting from \"res2\" to \"res5\"\n        transformer_in_channels = [v.channels for k, v in transformer_input_shape]\n        self.transformer_feature_strides = [v.stride for k, v in transformer_input_shape]  # to decide extra FPN layers\n\n        self.maskdino_num_feature_levels = num_feature_levels  # always use 3 scales\n        self.total_num_feature_levels = total_num_feature_levels\n        self.common_stride = common_stride\n\n        self.transformer_num_feature_levels = len(self.transformer_in_features)\n        self.low_resolution_index = transformer_in_channels.index(max(transformer_in_channels))\n        self.high_resolution_index = 0 if self.feature_order == 'low2high' else -1\n        if self.transformer_num_feature_levels > 1:\n            input_proj_list = []\n            for in_channels in transformer_in_channels[::-1]:\n                input_proj_list.append(nn.Sequential(\n                    nn.Conv2d(in_channels, conv_dim, kernel_size=1),\n                    nn.GroupNorm(32, conv_dim),\n                ))\n            # input projections (stride-2 convs) for the extra downsampled levels\n            in_channels = max(transformer_in_channels)\n            for _ in range(self.total_num_feature_levels - self.transformer_num_feature_levels):  # exclude the res2\n                input_proj_list.append(nn.Sequential(\n                    nn.Conv2d(in_channels, conv_dim, kernel_size=3, stride=2, padding=1),\n                    nn.GroupNorm(32, conv_dim),\n                ))\n                in_channels = conv_dim\n            self.input_proj = nn.ModuleList(input_proj_list)\n        else:\n         
   self.input_proj = nn.ModuleList([\n                nn.Sequential(\n                    nn.Conv2d(transformer_in_channels[-1], conv_dim, kernel_size=1),\n                    nn.GroupNorm(32, conv_dim),\n                )])\n\n        for proj in self.input_proj:\n            nn.init.xavier_uniform_(proj[0].weight, gain=1)\n            nn.init.constant_(proj[0].bias, 0)\n\n        self.transformer = MSDeformAttnTransformerEncoderOnly(\n            d_model=conv_dim,\n            dropout=transformer_dropout,\n            nhead=transformer_nheads,\n            dim_feedforward=transformer_dim_feedforward,\n            num_encoder_layers=transformer_enc_layers,\n            num_feature_levels=self.total_num_feature_levels,\n        )\n        N_steps = conv_dim // 2\n        self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)\n\n        self.mask_dim = mask_dim\n        # use 1x1 conv instead\n        self.mask_features = Conv2d(\n            conv_dim,\n            mask_dim,\n            kernel_size=1,\n            stride=1,\n            padding=0,\n        )\n        weight_init.c2_xavier_fill(self.mask_features)\n        # extra fpn levels\n        stride = min(self.transformer_feature_strides)\n        self.num_fpn_levels = max(int(np.log2(stride) - np.log2(self.common_stride)), 1)\n\n        lateral_convs = []\n        output_convs = []\n\n        use_bias = norm == \"\"\n        for idx, in_channels in enumerate(self.feature_channels[:self.num_fpn_levels]):\n            lateral_norm = get_norm(norm, conv_dim)\n            output_norm = get_norm(norm, conv_dim)\n\n            lateral_conv = Conv2d(\n                in_channels, conv_dim, kernel_size=1, bias=use_bias, norm=lateral_norm\n            )\n            output_conv = Conv2d(\n                conv_dim,\n                conv_dim,\n                kernel_size=3,\n                stride=1,\n                padding=1,\n                bias=use_bias,\n                norm=output_norm,\n                
activation=F.relu,\n            )\n            weight_init.c2_xavier_fill(lateral_conv)\n            weight_init.c2_xavier_fill(output_conv)\n            self.add_module(\"adapter_{}\".format(idx + 1), lateral_conv)\n            self.add_module(\"layer_{}\".format(idx + 1), output_conv)\n\n            lateral_convs.append(lateral_conv)\n            output_convs.append(output_conv)\n        # Place convs into top-down order (from low to high resolution)\n        # to make the top-down computation in forward clearer.\n        self.lateral_convs = lateral_convs[::-1]\n        self.output_convs = output_convs[::-1]\n\n    @classmethod\n    def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec], *args, **kwargs):\n        enc_cfg = cfg['MODEL']['ENCODER']\n        dec_cfg = cfg['MODEL']['DECODER']\n\n        ret = {}\n        ret[\"input_shape\"] = {\n            k: v for k, v in input_shape.items() if k in enc_cfg['IN_FEATURES']\n        }\n        ret[\"conv_dim\"] = enc_cfg['CONVS_DIM']\n        ret[\"mask_dim\"] = enc_cfg['MASK_DIM']\n        ret[\"norm\"] = enc_cfg['NORM']\n        ret[\"transformer_dropout\"] = dec_cfg['DROPOUT']\n        ret[\"transformer_nheads\"] = dec_cfg['NHEADS']\n        ret[\"transformer_dim_feedforward\"] = dec_cfg['DIM_FEEDFORWARD']  # deformable transformer encoder\n        ret[\n            \"transformer_enc_layers\"\n        ] = enc_cfg['TRANSFORMER_ENC_LAYERS']  # a separate config\n        ret[\"transformer_in_features\"] = enc_cfg['DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES']  # ['res3', 'res4', 'res5']\n        ret[\"common_stride\"] = enc_cfg['COMMON_STRIDE']\n        ret[\"total_num_feature_levels\"] = enc_cfg['TOTAL_NUM_FEATURE_LEVELS']\n        ret[\"num_feature_levels\"] = enc_cfg['NUM_FEATURE_LEVELS']\n        ret[\"feature_order\"] = enc_cfg['FEATURE_ORDER']\n        ret[\"use_ckpt\"] = enc_cfg.get('USE_CKPT', False)\n        return ret\n\n    @autocast(enabled=True)\n    def forward_features(self, features, masks):\n  
      \"\"\"\n        :param features: multi-scale features from the backbone\n        :param masks: image mask\n        :return: enhanced multi-scale features and mask feature (1/4 resolution) for the decoder to produce binary mask\n        \"\"\"\n        # backbone features\n        srcs = []\n        pos = []\n        # additional downsampled features\n        srcsl = []\n        posl = []\n        # import pdb; pdb.set_trace()\n        if self.total_num_feature_levels > self.transformer_num_feature_levels:\n            smallest_feat = features[self.transformer_in_features[self.low_resolution_index]]#.float()\n            _len_srcs = self.transformer_num_feature_levels\n            for l in range(_len_srcs, self.total_num_feature_levels):\n                if l == _len_srcs:\n                    src = self.input_proj[l](smallest_feat)\n                else:\n                    src = self.input_proj[l](srcsl[-1])\n                srcsl.append(src)\n                posl.append(self.pe_layer(src))\n        srcsl = srcsl[::-1]\n        # Reverse feature maps\n        for idx, f in enumerate(self.transformer_in_features[::-1]):\n            x = features[f]#.float()  # deformable detr does not support half precision\n            srcs.append(self.input_proj[idx](x))\n            pos.append(self.pe_layer(x))\n        srcs.extend(srcsl) if self.feature_order == 'low2high' else srcsl.extend(srcs)\n        pos.extend(posl) if self.feature_order == 'low2high' else posl.extend(pos)\n        if self.feature_order != 'low2high':\n            srcs = srcsl\n            pos = posl\n        y, spatial_shapes, level_start_index = self.transformer(srcs, masks, pos, use_ckpt=self.use_ckpt)\n        bs = y.shape[0]\n\n        split_size_or_sections = [None] * self.total_num_feature_levels\n        for i in range(self.total_num_feature_levels):\n            if i < self.total_num_feature_levels - 1:\n                split_size_or_sections[i] = level_start_index[i + 1] - 
level_start_index[i]\n            else:\n                split_size_or_sections[i] = y.shape[1] - level_start_index[i]\n        y = torch.split(y, split_size_or_sections, dim=1)\n\n        out = []\n        multi_scale_features = []\n        num_cur_levels = 0\n        for i, z in enumerate(y):\n            out.append(z.transpose(1, 2).view(bs, -1, spatial_shapes[i][0], spatial_shapes[i][1]))\n\n        # append `out` with extra FPN levels\n        # Reverse feature maps into top-down order (from low to high resolution)\n        for idx, f in enumerate(self.in_features[:self.num_fpn_levels][::-1]):\n            x = features[f]#.float()\n            lateral_conv = self.lateral_convs[idx]\n            output_conv = self.output_convs[idx]\n            cur_fpn = lateral_conv(x)\n            # Unlike the reference FPN implementation (nearest upsampling), bilinear upsampling is used here\n            y = cur_fpn + F.interpolate(out[self.high_resolution_index], size=cur_fpn.shape[-2:], mode=\"bilinear\", align_corners=False)\n            y = output_conv(y)\n            out.append(y)\n        for o in out:\n            if num_cur_levels < self.total_num_feature_levels:\n                multi_scale_features.append(o)\n                num_cur_levels += 1\n        return self.mask_features(out[-1]), out[0], multi_scale_features\n\n\n\n@register_encoder\ndef get_maskdino_encoder_deform(cfg, input_shape):\n    \"\"\"\n    Build a MaskDINOEncoder (the deformable pixel decoder) from `cfg` and `input_shape`.\n    \"\"\"\n    model = MaskDINOEncoder(cfg, input_shape)\n    forward_features = getattr(model, \"forward_features\", None)\n    if not callable(forward_features):\n        raise ValueError(\n            \"Only SEM_SEG_HEADS with forward_features method can be used as pixel decoder. \"\n            f\"Please implement forward_features for {name} to only return mask features.\"\n        )\n    return model"
  },
  {
    "path": "llava/model/semsam/body/encoder/ops/functions/__init__.py",
    "content": "# ------------------------------------------------------------------------------------------------\n# Deformable DETR\n# Copyright (c) 2020 SenseTime. All Rights Reserved.\n# Licensed under the Apache License, Version 2.0 [see LICENSE for details]\n# ------------------------------------------------------------------------------------------------\n# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0\n# ------------------------------------------------------------------------------------------------\n\n# Copyright (c) Facebook, Inc. and its affiliates.\n# Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR\n\nfrom .ms_deform_attn_func import MSDeformAttnFunction\n\n"
  },
  {
    "path": "llava/model/semsam/body/encoder/ops/functions/ms_deform_attn_func.py",
    "content": "# ------------------------------------------------------------------------------------------------\n# Deformable DETR\n# Copyright (c) 2020 SenseTime. All Rights Reserved.\n# Licensed under the Apache License, Version 2.0 [see LICENSE for details]\n# ------------------------------------------------------------------------------------------------\n# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0\n# ------------------------------------------------------------------------------------------------\n\n# Copyright (c) Facebook, Inc. and its affiliates.\n# Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR\n\nfrom __future__ import absolute_import\nfrom __future__ import print_function\nfrom __future__ import division\n\nimport torch\nimport torch.nn.functional as F\nfrom torch.autograd import Function\nfrom torch.autograd.function import once_differentiable\n\ntry:\n    import MultiScaleDeformableAttention as MSDA\nexcept ModuleNotFoundError as e:\n    info_string = (\n        \"\\n\\nPlease compile MultiScaleDeformableAttention CUDA op with the following commands:\\n\"\n        \"\\t`cd mask2former/modeling/pixel_decoder/ops`\\n\"\n        \"\\t`sh make.sh`\\n\"\n    )\n    raise ModuleNotFoundError(info_string)\n\n\nclass MSDeformAttnFunction(Function):\n    @staticmethod\n    def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step):\n        ctx.im2col_step = im2col_step\n        output = MSDA.ms_deform_attn_forward(\n            value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step)\n        ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights)\n        return output\n\n    @staticmethod\n    @once_differentiable\n    def backward(ctx, grad_output):\n        value, value_spatial_shapes, 
value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors\n        grad_value, grad_sampling_loc, grad_attn_weight = \\\n            MSDA.ms_deform_attn_backward(\n                value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step)\n\n        return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None\n\n\ndef ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights):\n    # for debug and test only,\n    # need to use cuda version instead\n    N_, S_, M_, D_ = value.shape\n    _, Lq_, M_, L_, P_, _ = sampling_locations.shape\n    value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)\n    sampling_grids = 2 * sampling_locations - 1\n    sampling_value_list = []\n    for lid_, (H_, W_) in enumerate(value_spatial_shapes):\n        # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_\n        value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_)\n        # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2\n        sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)\n        # N_*M_, D_, Lq_, P_\n        sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_,\n                                          mode='bilinear', padding_mode='zeros', align_corners=False)\n        sampling_value_list.append(sampling_value_l_)\n    # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_)\n    attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_)\n    output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_)\n    return output.transpose(1, 2).contiguous()\n"
  },
  {
    "path": "llava/model/semsam/body/encoder/ops/make.sh",
    "content": "#!/usr/bin/env bash\n# ------------------------------------------------------------------------------------------------\n# Deformable DETR\n# Copyright (c) 2020 SenseTime. All Rights Reserved.\n# Licensed under the Apache License, Version 2.0 [see LICENSE for details]\n# ------------------------------------------------------------------------------------------------\n# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0\n# ------------------------------------------------------------------------------------------------\n\n# Copyright (c) Facebook, Inc. and its affiliates.\n# Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR\n\npython setup.py build install --user\n"
  },
  {
    "path": "llava/model/semsam/body/encoder/ops/modules/__init__.py",
    "content": "# ------------------------------------------------------------------------------------------------\n# Deformable DETR\n# Copyright (c) 2020 SenseTime. All Rights Reserved.\n# Licensed under the Apache License, Version 2.0 [see LICENSE for details]\n# ------------------------------------------------------------------------------------------------\n# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0\n# ------------------------------------------------------------------------------------------------\n\n# Copyright (c) Facebook, Inc. and its affiliates.\n# Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR\n\nfrom .ms_deform_attn import MSDeformAttn\n"
  },
  {
    "path": "llava/model/semsam/body/encoder/ops/modules/ms_deform_attn.py",
    "content": "# ------------------------------------------------------------------------------------------------\n# Deformable DETR\n# Copyright (c) 2020 SenseTime. All Rights Reserved.\n# Licensed under the Apache License, Version 2.0 [see LICENSE for details]\n# ------------------------------------------------------------------------------------------------\n# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0\n# ------------------------------------------------------------------------------------------------\n\n# Copyright (c) Facebook, Inc. and its affiliates.\n# Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR\n\nfrom __future__ import absolute_import\nfrom __future__ import print_function\nfrom __future__ import division\n\nimport warnings\nimport math\n\nimport torch\nfrom torch import nn\nimport torch.nn.functional as F\nfrom torch.nn.init import xavier_uniform_, constant_\n\nfrom ..functions import MSDeformAttnFunction\nfrom ..functions.ms_deform_attn_func import ms_deform_attn_core_pytorch\n\n\ndef _is_power_of_2(n):\n    if (not isinstance(n, int)) or (n < 0):\n        raise ValueError(\"invalid input for _is_power_of_2: {} (type: {})\".format(n, type(n)))\n    return (n & (n-1) == 0) and n != 0\n\n\nclass MSDeformAttn(nn.Module):\n    def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4):\n        \"\"\"\n        Multi-Scale Deformable Attention Module\n        :param d_model      hidden dimension\n        :param n_levels     number of feature levels\n        :param n_heads      number of attention heads\n        :param n_points     number of sampling points per attention head per feature level\n        \"\"\"\n        super().__init__()\n        if d_model % n_heads != 0:\n            raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads))\n        _d_per_head = d_model // n_heads\n        # you'd better set 
_d_per_head to a power of 2 which is more efficient in our CUDA implementation\n        if not _is_power_of_2(_d_per_head):\n            warnings.warn(\"You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 \"\n                          \"which is more efficient in our CUDA implementation.\")\n\n        self.im2col_step = 128\n\n        self.d_model = d_model\n        self.n_levels = n_levels\n        self.n_heads = n_heads\n        self.n_points = n_points\n\n        self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)\n        self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)\n        self.value_proj = nn.Linear(d_model, d_model)\n        self.output_proj = nn.Linear(d_model, d_model)\n\n        self._reset_parameters()\n\n    def _reset_parameters(self):\n        constant_(self.sampling_offsets.weight.data, 0.)\n        thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)\n        grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)\n        grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1)\n        for i in range(self.n_points):\n            grid_init[:, :, i, :] *= i + 1\n        with torch.no_grad():\n            self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))\n        constant_(self.attention_weights.weight.data, 0.)\n        constant_(self.attention_weights.bias.data, 0.)\n        xavier_uniform_(self.value_proj.weight.data)\n        constant_(self.value_proj.bias.data, 0.)\n        xavier_uniform_(self.output_proj.weight.data)\n        constant_(self.output_proj.bias.data, 0.)\n\n    def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None):\n        \"\"\"\n        :param query                       (N, Length_{query}, C)\n        :param 
reference_points            (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area\n                                        or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes\n        :param input_flatten               (N, \\sum_{l=0}^{L-1} H_l \\cdot W_l, C)\n        :param input_spatial_shapes        (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]\n        :param input_level_start_index     (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}]\n        :param input_padding_mask          (N, \\sum_{l=0}^{L-1} H_l \\cdot W_l), True for padding elements, False for non-padding elements\n\n        :return output                     (N, Length_{query}, C)\n        \"\"\"\n        N, Len_q, _ = query.shape\n        N, Len_in, _ = input_flatten.shape\n        assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in\n        # input_flatten=input_flatten.to(self.value_proj.bias.data.dtype)\n\n        value = self.value_proj(input_flatten)\n        if input_padding_mask is not None:\n            value = value.masked_fill(input_padding_mask[..., None], float(0))\n        value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)\n        sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2)\n        attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points)\n        attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points)\n        # N, Len_q, n_heads, n_levels, n_points, 2\n        if reference_points.shape[-1] == 2:\n            offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1)\n            sampling_locations = reference_points[:, :, None, :, None, :] \\\n               
                  + sampling_offsets / offset_normalizer[None, None, None, :, None, :]\n        elif reference_points.shape[-1] == 4:\n            sampling_locations = reference_points[:, :, None, :, None, :2] \\\n                                 + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5\n        else:\n            raise ValueError(\n                'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1]))\n        # try:\n        # print(value.dtype)\n        convert=False\n        # import pdb; pdb.set_trace()\n        dtype=value.dtype\n        if value.dtype== torch.bfloat16 or value.dtype== torch.float16:\n            value = value.float()\n            attention_weights = attention_weights.float()\n            sampling_locations = sampling_locations.float()\n            convert=True\n        # value = value.float()\n        # attention_weights = attention_weights.float()\n        # sampling_locations = sampling_locations.float()\n        # convert=True\n        output = MSDeformAttnFunction.apply(\n            value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step)\n        if convert:\n            output = output.to(dtype)\n        # except:\n        #     # CPU\n        #     output = ms_deform_attn_core_pytorch(value, input_spatial_shapes, sampling_locations, attention_weights)\n        # # For FLOPs calculation only\n        # output = ms_deform_attn_core_pytorch(value, input_spatial_shapes, sampling_locations, attention_weights)\n        output = self.output_proj(output)\n        return output\n"
  },
  {
    "path": "llava/model/semsam/body/encoder/ops/setup.py",
    "content": "# ------------------------------------------------------------------------------------------------\n# Deformable DETR\n# Copyright (c) 2020 SenseTime. All Rights Reserved.\n# Licensed under the Apache License, Version 2.0 [see LICENSE for details]\n# ------------------------------------------------------------------------------------------------\n# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0\n# ------------------------------------------------------------------------------------------------\n\n# Copyright (c) Facebook, Inc. and its affiliates.\n# Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR\n\nimport os\nimport glob\n\nimport torch\n\nfrom torch.utils.cpp_extension import CUDA_HOME\nfrom torch.utils.cpp_extension import CppExtension\nfrom torch.utils.cpp_extension import CUDAExtension\n\nfrom setuptools import find_packages\nfrom setuptools import setup\n\nrequirements = [\"torch\", \"torchvision\"]\n\ndef get_extensions():\n    this_dir = os.path.dirname(os.path.abspath(__file__))\n    extensions_dir = os.path.join(this_dir, \"src\")\n\n    main_file = glob.glob(os.path.join(extensions_dir, \"*.cpp\"))\n    source_cpu = glob.glob(os.path.join(extensions_dir, \"cpu\", \"*.cpp\"))\n    source_cuda = glob.glob(os.path.join(extensions_dir, \"cuda\", \"*.cu\"))\n\n    sources = main_file + source_cpu\n    extension = CppExtension\n    extra_compile_args = {\"cxx\": []}\n    define_macros = []\n\n    # Force cuda since torch ask for a device, not if cuda is in fact available.\n    if (os.environ.get('FORCE_CUDA') or torch.cuda.is_available()) and CUDA_HOME is not None:\n        extension = CUDAExtension\n        sources += source_cuda\n        define_macros += [(\"WITH_CUDA\", None)]\n        extra_compile_args[\"nvcc\"] = [\n            \"-DCUDA_HAS_FP16=1\",\n            \"-D__CUDA_NO_HALF_OPERATORS__\",\n            \"-D__CUDA_NO_HALF_CONVERSIONS__\",\n            
\"-D__CUDA_NO_HALF2_OPERATORS__\",\n        ]\n    else:\n        if CUDA_HOME is None:\n            raise NotImplementedError('CUDA_HOME is None. Please set environment variable CUDA_HOME.')\n        else:\n            raise NotImplementedError('No CUDA runtime is found. Please set FORCE_CUDA=1 or test it by running torch.cuda.is_available().')\n\n    sources = [os.path.join(extensions_dir, s) for s in sources]\n    include_dirs = [extensions_dir]\n    ext_modules = [\n        extension(\n            \"MultiScaleDeformableAttention\",\n            sources,\n            include_dirs=include_dirs,\n            define_macros=define_macros,\n            extra_compile_args=extra_compile_args,\n        )\n    ]\n    return ext_modules\n\nsetup(\n    name=\"MultiScaleDeformableAttention\",\n    version=\"1.0\",\n    author=\"Weijie Su\",\n    url=\"https://github.com/fundamentalvision/Deformable-DETR\",\n    description=\"PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention\",\n    packages=find_packages(exclude=(\"configs\", \"tests\",)),\n    ext_modules=get_extensions(),\n    cmdclass={\"build_ext\": torch.utils.cpp_extension.BuildExtension},\n)\n"
  },
  {
    "path": "llava/model/semsam/body/encoder/ops/src/cpu/ms_deform_attn_cpu.cpp",
    "content": "/*!\n**************************************************************************************************\n* Deformable DETR\n* Copyright (c) 2020 SenseTime. All Rights Reserved.\n* Licensed under the Apache License, Version 2.0 [see LICENSE for details]\n**************************************************************************************************\n* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0\n**************************************************************************************************\n*/\n\n/*!\n* Copyright (c) Facebook, Inc. and its affiliates.\n* Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR\n*/\n\n#include <vector>\n\n#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext.h>\n\n\nat::Tensor\nms_deform_attn_cpu_forward(\n    const at::Tensor &value, \n    const at::Tensor &spatial_shapes,\n    const at::Tensor &level_start_index,\n    const at::Tensor &sampling_loc,\n    const at::Tensor &attn_weight,\n    const int im2col_step)\n{\n    AT_ERROR(\"Not implement on cpu\");\n}\n\nstd::vector<at::Tensor>\nms_deform_attn_cpu_backward(\n    const at::Tensor &value, \n    const at::Tensor &spatial_shapes,\n    const at::Tensor &level_start_index,\n    const at::Tensor &sampling_loc,\n    const at::Tensor &attn_weight,\n    const at::Tensor &grad_output,\n    const int im2col_step)\n{\n    AT_ERROR(\"Not implement on cpu\");\n}\n\n"
  },
  {
    "path": "llava/model/semsam/body/encoder/ops/src/cpu/ms_deform_attn_cpu.h",
    "content": "/*!\n**************************************************************************************************\n* Deformable DETR\n* Copyright (c) 2020 SenseTime. All Rights Reserved.\n* Licensed under the Apache License, Version 2.0 [see LICENSE for details]\n**************************************************************************************************\n* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0\n**************************************************************************************************\n*/\n\n/*!\n* Copyright (c) Facebook, Inc. and its affiliates.\n* Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR\n*/\n\n#pragma once\n#include <torch/extension.h>\n\nat::Tensor\nms_deform_attn_cpu_forward(\n    const at::Tensor &value, \n    const at::Tensor &spatial_shapes,\n    const at::Tensor &level_start_index,\n    const at::Tensor &sampling_loc,\n    const at::Tensor &attn_weight,\n    const int im2col_step);\n\nstd::vector<at::Tensor>\nms_deform_attn_cpu_backward(\n    const at::Tensor &value, \n    const at::Tensor &spatial_shapes,\n    const at::Tensor &level_start_index,\n    const at::Tensor &sampling_loc,\n    const at::Tensor &attn_weight,\n    const at::Tensor &grad_output,\n    const int im2col_step);\n\n\n"
  },
  {
    "path": "llava/model/semsam/body/encoder/ops/src/cuda/ms_deform_attn_cuda.cu",
    "content": "/*!\n**************************************************************************************************\n* Deformable DETR\n* Copyright (c) 2020 SenseTime. All Rights Reserved.\n* Licensed under the Apache License, Version 2.0 [see LICENSE for details]\n**************************************************************************************************\n* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0\n**************************************************************************************************\n*/\n\n/*!\n* Copyright (c) Facebook, Inc. and its affiliates.\n* Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR\n*/\n\n#include <vector>\n#include \"cuda/ms_deform_im2col_cuda.cuh\"\n\n#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <cuda.h>\n#include <cuda_runtime.h>\n\n\nat::Tensor ms_deform_attn_cuda_forward(\n    const at::Tensor &value, \n    const at::Tensor &spatial_shapes,\n    const at::Tensor &level_start_index,\n    const at::Tensor &sampling_loc,\n    const at::Tensor &attn_weight,\n    const int im2col_step)\n{\n    AT_ASSERTM(value.is_contiguous(), \"value tensor has to be contiguous\");\n    AT_ASSERTM(spatial_shapes.is_contiguous(), \"spatial_shapes tensor has to be contiguous\");\n    AT_ASSERTM(level_start_index.is_contiguous(), \"level_start_index tensor has to be contiguous\");\n    AT_ASSERTM(sampling_loc.is_contiguous(), \"sampling_loc tensor has to be contiguous\");\n    AT_ASSERTM(attn_weight.is_contiguous(), \"attn_weight tensor has to be contiguous\");\n\n    AT_ASSERTM(value.type().is_cuda(), \"value must be a CUDA tensor\");\n    AT_ASSERTM(spatial_shapes.type().is_cuda(), \"spatial_shapes must be a CUDA tensor\");\n    AT_ASSERTM(level_start_index.type().is_cuda(), \"level_start_index must be a CUDA tensor\");\n    AT_ASSERTM(sampling_loc.type().is_cuda(), \"sampling_loc must be a CUDA tensor\");\n    
AT_ASSERTM(attn_weight.type().is_cuda(), \"attn_weight must be a CUDA tensor\");\n\n    const int batch = value.size(0);\n    const int spatial_size = value.size(1);\n    const int num_heads = value.size(2);\n    const int channels = value.size(3);\n\n    const int num_levels = spatial_shapes.size(0);\n\n    const int num_query = sampling_loc.size(1);\n    const int num_point = sampling_loc.size(4);\n\n    const int im2col_step_ = std::min(batch, im2col_step);\n\n    AT_ASSERTM(batch % im2col_step_ == 0, \"batch(%d) must divide im2col_step(%d)\", batch, im2col_step_);\n    \n    auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());\n\n    const int batch_n = im2col_step_;\n    auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});\n    auto per_value_size = spatial_size * num_heads * channels;\n    auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;\n    auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;\n    for (int n = 0; n < batch/im2col_step_; ++n)\n    {\n        auto columns = output_n.select(0, n);\n        AT_DISPATCH_FLOATING_TYPES(value.type(), \"ms_deform_attn_forward_cuda\", ([&] {\n            ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),\n                value.data<scalar_t>() + n * im2col_step_ * per_value_size,\n                spatial_shapes.data<int64_t>(),\n                level_start_index.data<int64_t>(),\n                sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,\n                attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,\n                batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,\n                columns.data<scalar_t>());\n\n        }));\n    }\n\n    output = output.view({batch, num_query, num_heads*channels});\n\n    return output;\n}\n\n\nstd::vector<at::Tensor> ms_deform_attn_cuda_backward(\n    const at::Tensor 
&value, \n    const at::Tensor &spatial_shapes,\n    const at::Tensor &level_start_index,\n    const at::Tensor &sampling_loc,\n    const at::Tensor &attn_weight,\n    const at::Tensor &grad_output,\n    const int im2col_step)\n{\n\n    AT_ASSERTM(value.is_contiguous(), \"value tensor has to be contiguous\");\n    AT_ASSERTM(spatial_shapes.is_contiguous(), \"spatial_shapes tensor has to be contiguous\");\n    AT_ASSERTM(level_start_index.is_contiguous(), \"level_start_index tensor has to be contiguous\");\n    AT_ASSERTM(sampling_loc.is_contiguous(), \"sampling_loc tensor has to be contiguous\");\n    AT_ASSERTM(attn_weight.is_contiguous(), \"attn_weight tensor has to be contiguous\");\n    AT_ASSERTM(grad_output.is_contiguous(), \"grad_output tensor has to be contiguous\");\n\n    AT_ASSERTM(value.type().is_cuda(), \"value must be a CUDA tensor\");\n    AT_ASSERTM(spatial_shapes.type().is_cuda(), \"spatial_shapes must be a CUDA tensor\");\n    AT_ASSERTM(level_start_index.type().is_cuda(), \"level_start_index must be a CUDA tensor\");\n    AT_ASSERTM(sampling_loc.type().is_cuda(), \"sampling_loc must be a CUDA tensor\");\n    AT_ASSERTM(attn_weight.type().is_cuda(), \"attn_weight must be a CUDA tensor\");\n    AT_ASSERTM(grad_output.type().is_cuda(), \"grad_output must be a CUDA tensor\");\n\n    const int batch = value.size(0);\n    const int spatial_size = value.size(1);\n    const int num_heads = value.size(2);\n    const int channels = value.size(3);\n\n    const int num_levels = spatial_shapes.size(0);\n\n    const int num_query = sampling_loc.size(1);\n    const int num_point = sampling_loc.size(4);\n\n    const int im2col_step_ = std::min(batch, im2col_step);\n\n    AT_ASSERTM(batch % im2col_step_ == 0, \"batch(%d) must divide im2col_step(%d)\", batch, im2col_step_);\n\n    auto grad_value = at::zeros_like(value);\n    auto grad_sampling_loc = at::zeros_like(sampling_loc);\n    auto grad_attn_weight = at::zeros_like(attn_weight);\n\n    const int batch_n = 
im2col_step_;\n    auto per_value_size = spatial_size * num_heads * channels;\n    auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;\n    auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;\n    auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});\n    \n    for (int n = 0; n < batch/im2col_step_; ++n)\n    {\n        auto grad_output_g = grad_output_n.select(0, n);\n        AT_DISPATCH_FLOATING_TYPES(value.type(), \"ms_deform_attn_backward_cuda\", ([&] {\n            ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),\n                                    grad_output_g.data<scalar_t>(),\n                                    value.data<scalar_t>() + n * im2col_step_ * per_value_size,\n                                    spatial_shapes.data<int64_t>(),\n                                    level_start_index.data<int64_t>(),\n                                    sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,\n                                    attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,\n                                    batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,\n                                    grad_value.data<scalar_t>() +  n * im2col_step_ * per_value_size,\n                                    grad_sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,\n                                    grad_attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size);\n\n        }));\n    }\n\n    return {\n        grad_value, grad_sampling_loc, grad_attn_weight\n    };\n}"
  },
  {
    "path": "llava/model/semsam/body/encoder/ops/src/cuda/ms_deform_attn_cuda.h",
    "content": "/*!\n**************************************************************************************************\n* Deformable DETR\n* Copyright (c) 2020 SenseTime. All Rights Reserved.\n* Licensed under the Apache License, Version 2.0 [see LICENSE for details]\n**************************************************************************************************\n* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0\n**************************************************************************************************\n*/\n\n/*!\n* Copyright (c) Facebook, Inc. and its affiliates.\n* Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR\n*/\n\n#pragma once\n#include <torch/extension.h>\n\nat::Tensor ms_deform_attn_cuda_forward(\n    const at::Tensor &value, \n    const at::Tensor &spatial_shapes,\n    const at::Tensor &level_start_index,\n    const at::Tensor &sampling_loc,\n    const at::Tensor &attn_weight,\n    const int im2col_step);\n\nstd::vector<at::Tensor> ms_deform_attn_cuda_backward(\n    const at::Tensor &value, \n    const at::Tensor &spatial_shapes,\n    const at::Tensor &level_start_index,\n    const at::Tensor &sampling_loc,\n    const at::Tensor &attn_weight,\n    const at::Tensor &grad_output,\n    const int im2col_step);\n\n"
  },
  {
    "path": "llava/model/semsam/body/encoder/ops/src/cuda/ms_deform_im2col_cuda.cuh",
    "content": "/*!\n**************************************************************************\n* Deformable DETR\n* Copyright (c) 2020 SenseTime. All Rights Reserved.\n* Licensed under the Apache License, Version 2.0 [see LICENSE for details]\n**************************************************************************\n* Modified from DCN (https://github.com/msracver/Deformable-ConvNets)\n* Copyright (c) 2018 Microsoft\n**************************************************************************\n*/\n\n/*!\n* Copyright (c) Facebook, Inc. and its affiliates.\n* Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR\n*/\n\n#include <cstdio>\n#include <algorithm>\n#include <cstring>\n\n#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext.h>\n\n#include <THC/THCAtomics.cuh>\n\n#define CUDA_KERNEL_LOOP(i, n)                          \\\n  for (int i = blockIdx.x * blockDim.x + threadIdx.x;   \\\n      i < (n);                                          \\\n      i += blockDim.x * gridDim.x)\n\nconst int CUDA_NUM_THREADS = 1024;\ninline int GET_BLOCKS(const int N, const int num_threads)\n{\n  return (N + num_threads - 1) / num_threads;\n}\n\n\ntemplate <typename scalar_t>\n__device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data, \n                                                   const int &height, const int &width, const int &nheads, const int &channels,\n                                                   const scalar_t &h, const scalar_t &w, const int &m, const int &c)\n{\n  const int h_low = floor(h);\n  const int w_low = floor(w);\n  const int h_high = h_low + 1;\n  const int w_high = w_low + 1;\n\n  const scalar_t lh = h - h_low;\n  const scalar_t lw = w - w_low;\n  const scalar_t hh = 1 - lh, hw = 1 - lw;\n\n  const int w_stride = nheads * channels;\n  const int h_stride = width * w_stride;\n  const int h_low_ptr_offset = h_low * h_stride;\n  const int h_high_ptr_offset = h_low_ptr_offset + h_stride;\n  const int 
w_low_ptr_offset = w_low * w_stride;\n  const int w_high_ptr_offset = w_low_ptr_offset + w_stride;\n  const int base_ptr = m * channels + c;\n\n  scalar_t v1 = 0;\n  if (h_low >= 0 && w_low >= 0)\n  {\n    const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;\n    v1 = bottom_data[ptr1];\n  }\n  scalar_t v2 = 0;\n  if (h_low >= 0 && w_high <= width - 1)\n  {\n    const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;\n    v2 = bottom_data[ptr2];\n  }\n  scalar_t v3 = 0;\n  if (h_high <= height - 1 && w_low >= 0)\n  {\n    const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;\n    v3 = bottom_data[ptr3];\n  }\n  scalar_t v4 = 0;\n  if (h_high <= height - 1 && w_high <= width - 1)\n  {\n    const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;\n    v4 = bottom_data[ptr4];\n  }\n\n  const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;\n\n  const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);\n  return val;\n}\n\n\ntemplate <typename scalar_t>\n__device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data, \n                                                   const int &height, const int &width, const int &nheads, const int &channels,\n                                                   const scalar_t &h, const scalar_t &w, const int &m, const int &c,\n                                                   const scalar_t &top_grad,\n                                                   const scalar_t &attn_weight,\n                                                   scalar_t* &grad_value, \n                                                   scalar_t* grad_sampling_loc,\n                                                   scalar_t* grad_attn_weight)\n{\n  const int h_low = floor(h);\n  const int w_low = floor(w);\n  const int h_high = h_low + 1;\n  const int w_high = w_low + 1;\n\n  const scalar_t lh = h - h_low;\n  const scalar_t lw = w - w_low;\n  const scalar_t hh = 1 - lh, hw = 1 - lw;\n\n  
const int w_stride = nheads * channels;\n  const int h_stride = width * w_stride;\n  const int h_low_ptr_offset = h_low * h_stride;\n  const int h_high_ptr_offset = h_low_ptr_offset + h_stride;\n  const int w_low_ptr_offset = w_low * w_stride;\n  const int w_high_ptr_offset = w_low_ptr_offset + w_stride;\n  const int base_ptr = m * channels + c;\n\n  const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;\n  const scalar_t top_grad_value = top_grad * attn_weight;\n  scalar_t grad_h_weight = 0, grad_w_weight = 0;\n\n  scalar_t v1 = 0;\n  if (h_low >= 0 && w_low >= 0)\n  {\n    const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;\n    v1 = bottom_data[ptr1];\n    grad_h_weight -= hw * v1;\n    grad_w_weight -= hh * v1;\n    atomicAdd(grad_value+ptr1, w1*top_grad_value);\n  }\n  scalar_t v2 = 0;\n  if (h_low >= 0 && w_high <= width - 1)\n  {\n    const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;\n    v2 = bottom_data[ptr2];\n    grad_h_weight -= lw * v2;\n    grad_w_weight += hh * v2;\n    atomicAdd(grad_value+ptr2, w2*top_grad_value);\n  }\n  scalar_t v3 = 0;\n  if (h_high <= height - 1 && w_low >= 0)\n  {\n    const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;\n    v3 = bottom_data[ptr3];\n    grad_h_weight += hw * v3;\n    grad_w_weight -= lh * v3;\n    atomicAdd(grad_value+ptr3, w3*top_grad_value); \n  }\n  scalar_t v4 = 0;\n  if (h_high <= height - 1 && w_high <= width - 1)\n  {\n    const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;\n    v4 = bottom_data[ptr4];\n    grad_h_weight += lw * v4;\n    grad_w_weight += lh * v4;\n    atomicAdd(grad_value+ptr4, w4*top_grad_value);\n  }\n\n  const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);\n  *grad_attn_weight = top_grad * val;\n  *grad_sampling_loc = width * grad_w_weight * top_grad_value;\n  *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value;\n}\n\n\ntemplate <typename scalar_t>\n__device__ void 
ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data, \n                                                   const int &height, const int &width, const int &nheads, const int &channels,\n                                                   const scalar_t &h, const scalar_t &w, const int &m, const int &c,\n                                                   const scalar_t &top_grad,\n                                                   const scalar_t &attn_weight,\n                                                   scalar_t* &grad_value, \n                                                   scalar_t* grad_sampling_loc,\n                                                   scalar_t* grad_attn_weight)\n{\n  const int h_low = floor(h);\n  const int w_low = floor(w);\n  const int h_high = h_low + 1;\n  const int w_high = w_low + 1;\n\n  const scalar_t lh = h - h_low;\n  const scalar_t lw = w - w_low;\n  const scalar_t hh = 1 - lh, hw = 1 - lw;\n\n  const int w_stride = nheads * channels;\n  const int h_stride = width * w_stride;\n  const int h_low_ptr_offset = h_low * h_stride;\n  const int h_high_ptr_offset = h_low_ptr_offset + h_stride;\n  const int w_low_ptr_offset = w_low * w_stride;\n  const int w_high_ptr_offset = w_low_ptr_offset + w_stride;\n  const int base_ptr = m * channels + c;\n\n  const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;\n  const scalar_t top_grad_value = top_grad * attn_weight;\n  scalar_t grad_h_weight = 0, grad_w_weight = 0;\n\n  scalar_t v1 = 0;\n  if (h_low >= 0 && w_low >= 0)\n  {\n    const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;\n    v1 = bottom_data[ptr1];\n    grad_h_weight -= hw * v1;\n    grad_w_weight -= hh * v1;\n    atomicAdd(grad_value+ptr1, w1*top_grad_value);\n  }\n  scalar_t v2 = 0;\n  if (h_low >= 0 && w_high <= width - 1)\n  {\n    const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;\n    v2 = bottom_data[ptr2];\n    grad_h_weight -= lw * v2;\n    grad_w_weight += hh * v2;\n 
   atomicAdd(grad_value+ptr2, w2*top_grad_value);\n  }\n  scalar_t v3 = 0;\n  if (h_high <= height - 1 && w_low >= 0)\n  {\n    const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;\n    v3 = bottom_data[ptr3];\n    grad_h_weight += hw * v3;\n    grad_w_weight -= lh * v3;\n    atomicAdd(grad_value+ptr3, w3*top_grad_value); \n  }\n  scalar_t v4 = 0;\n  if (h_high <= height - 1 && w_high <= width - 1)\n  {\n    const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;\n    v4 = bottom_data[ptr4];\n    grad_h_weight += lw * v4;\n    grad_w_weight += lh * v4;\n    atomicAdd(grad_value+ptr4, w4*top_grad_value);\n  }\n\n  const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);\n  atomicAdd(grad_attn_weight, top_grad * val); \n  atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value);\n  atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value);\n}\n\n\ntemplate <typename scalar_t>\n__global__ void ms_deformable_im2col_gpu_kernel(const int n,\n                                                const scalar_t *data_value, \n                                                const int64_t *data_spatial_shapes,\n                                                const int64_t *data_level_start_index, \n                                                const scalar_t *data_sampling_loc,\n                                                const scalar_t *data_attn_weight,\n                                                const int batch_size, \n                                                const int spatial_size, \n                                                const int num_heads,\n                                                const int channels, \n                                                const int num_levels,\n                                                const int num_query,\n                                                const int num_point,\n                                                scalar_t *data_col)\n{\n  
CUDA_KERNEL_LOOP(index, n)\n  {\n    int _temp = index;\n    const int c_col = _temp % channels;\n    _temp /= channels;\n    const int sampling_index = _temp; \n    const int m_col = _temp % num_heads;\n    _temp /= num_heads;\n    const int q_col = _temp % num_query;\n    _temp /= num_query;\n    const int b_col = _temp;\n\n    scalar_t *data_col_ptr = data_col + index;\n    int data_weight_ptr = sampling_index * num_levels * num_point;\n    int data_loc_w_ptr = data_weight_ptr << 1;\n    const int qid_stride = num_heads * channels;\n    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;\n    scalar_t col = 0;\n    \n    for (int l_col=0; l_col < num_levels; ++l_col)\n    {\n      const int level_start_id = data_level_start_index[l_col];\n      const int spatial_h_ptr = l_col << 1;\n      const int spatial_h = data_spatial_shapes[spatial_h_ptr];\n      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];\n      const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride);\n      for (int p_col=0; p_col < num_point; ++p_col)\n      {\n        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];\n        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];\n        const scalar_t weight = data_attn_weight[data_weight_ptr];\n\n        const scalar_t h_im = loc_h * spatial_h - 0.5;\n        const scalar_t w_im = loc_w * spatial_w - 0.5;\n\n        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)\n        {\n          col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight;\n        }\n\n        data_weight_ptr += 1;\n        data_loc_w_ptr += 2;\n      }\n    }\n    *data_col_ptr = col;\n  }\n}\n\ntemplate <typename scalar_t, unsigned int blockSize>\n__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n,\n                                              
  const scalar_t *grad_col,\n                                                const scalar_t *data_value,\n                                                const int64_t *data_spatial_shapes,\n                                                const int64_t *data_level_start_index, \n                                                const scalar_t *data_sampling_loc,\n                                                const scalar_t *data_attn_weight,\n                                                const int batch_size, \n                                                const int spatial_size, \n                                                const int num_heads,\n                                                const int channels, \n                                                const int num_levels,\n                                                const int num_query,\n                                                const int num_point,\n                                                scalar_t *grad_value,\n                                                scalar_t *grad_sampling_loc,\n                                                scalar_t *grad_attn_weight)\n{\n  CUDA_KERNEL_LOOP(index, n)\n  {\n    __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];\n    __shared__ scalar_t cache_grad_attn_weight[blockSize];\n    unsigned int tid = threadIdx.x;\n    int _temp = index;\n    const int c_col = _temp % channels;\n    _temp /= channels;\n    const int sampling_index = _temp; \n    const int m_col = _temp % num_heads;\n    _temp /= num_heads;\n    const int q_col = _temp % num_query;\n    _temp /= num_query;\n    const int b_col = _temp;\n\n    const scalar_t top_grad = grad_col[index];\n\n    int data_weight_ptr = sampling_index * num_levels * num_point;\n    int data_loc_w_ptr = data_weight_ptr << 1;\n    const int grad_sampling_ptr = data_weight_ptr;\n    grad_sampling_loc += grad_sampling_ptr << 1;\n    grad_attn_weight += grad_sampling_ptr;\n    const int 
grad_weight_stride = 1;\n    const int grad_loc_stride = 2;\n    const int qid_stride = num_heads * channels;\n    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;\n\n    for (int l_col=0; l_col < num_levels; ++l_col)\n    {\n      const int level_start_id = data_level_start_index[l_col];\n      const int spatial_h_ptr = l_col << 1;\n      const int spatial_h = data_spatial_shapes[spatial_h_ptr];\n      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];\n      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;\n      const scalar_t *data_value_ptr = data_value + value_ptr_offset;\n      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;\n\n      for (int p_col=0; p_col < num_point; ++p_col)\n      {\n        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];\n        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];\n        const scalar_t weight = data_attn_weight[data_weight_ptr];\n\n        const scalar_t h_im = loc_h * spatial_h - 0.5;\n        const scalar_t w_im = loc_w * spatial_w - 0.5;\n        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;\n        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;\n        *(cache_grad_attn_weight+threadIdx.x)=0;\n        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)\n        {\n          ms_deform_attn_col2im_bilinear(\n            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,\n            top_grad, weight, grad_value_ptr, \n            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);\n        }\n        \n        __syncthreads();\n        if (tid == 0)\n        {\n          scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];\n          int sid=2;\n          for (unsigned int tid = 1; tid < blockSize; ++tid)\n          {\n            _grad_w += 
cache_grad_sampling_loc[sid];\n            _grad_h += cache_grad_sampling_loc[sid + 1];\n            _grad_a += cache_grad_attn_weight[tid];\n            sid += 2;\n          }\n          \n          \n          *grad_sampling_loc = _grad_w;\n          *(grad_sampling_loc + 1) = _grad_h;\n          *grad_attn_weight = _grad_a;\n        }\n        __syncthreads();\n\n        data_weight_ptr += 1;\n        data_loc_w_ptr += 2;\n        grad_attn_weight += grad_weight_stride;\n        grad_sampling_loc += grad_loc_stride;\n      }\n    }\n  }\n}\n\n\ntemplate <typename scalar_t, unsigned int blockSize>\n__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n,\n                                                const scalar_t *grad_col,\n                                                const scalar_t *data_value,\n                                                const int64_t *data_spatial_shapes,\n                                                const int64_t *data_level_start_index, \n                                                const scalar_t *data_sampling_loc,\n                                                const scalar_t *data_attn_weight,\n                                                const int batch_size, \n                                                const int spatial_size, \n                                                const int num_heads,\n                                                const int channels, \n                                                const int num_levels,\n                                                const int num_query,\n                                                const int num_point,\n                                                scalar_t *grad_value,\n                                                scalar_t *grad_sampling_loc,\n                                                scalar_t *grad_attn_weight)\n{\n  CUDA_KERNEL_LOOP(index, n)\n  {\n    __shared__ scalar_t 
cache_grad_sampling_loc[blockSize * 2];\n    __shared__ scalar_t cache_grad_attn_weight[blockSize];\n    unsigned int tid = threadIdx.x;\n    int _temp = index;\n    const int c_col = _temp % channels;\n    _temp /= channels;\n    const int sampling_index = _temp; \n    const int m_col = _temp % num_heads;\n    _temp /= num_heads;\n    const int q_col = _temp % num_query;\n    _temp /= num_query;\n    const int b_col = _temp;\n\n    const scalar_t top_grad = grad_col[index];\n\n    int data_weight_ptr = sampling_index * num_levels * num_point;\n    int data_loc_w_ptr = data_weight_ptr << 1;\n    const int grad_sampling_ptr = data_weight_ptr;\n    grad_sampling_loc += grad_sampling_ptr << 1;\n    grad_attn_weight += grad_sampling_ptr;\n    const int grad_weight_stride = 1;\n    const int grad_loc_stride = 2;\n    const int qid_stride = num_heads * channels;\n    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;\n\n    for (int l_col=0; l_col < num_levels; ++l_col)\n    {\n      const int level_start_id = data_level_start_index[l_col];\n      const int spatial_h_ptr = l_col << 1;\n      const int spatial_h = data_spatial_shapes[spatial_h_ptr];\n      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];\n      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;\n      const scalar_t *data_value_ptr = data_value + value_ptr_offset;\n      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;\n\n      for (int p_col=0; p_col < num_point; ++p_col)\n      {\n        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];\n        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];\n        const scalar_t weight = data_attn_weight[data_weight_ptr];\n\n        const scalar_t h_im = loc_h * spatial_h - 0.5;\n        const scalar_t w_im = loc_w * spatial_w - 0.5;\n        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;\n        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;\n      
  *(cache_grad_attn_weight+threadIdx.x)=0;\n        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)\n        {\n          ms_deform_attn_col2im_bilinear(\n            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,\n            top_grad, weight, grad_value_ptr, \n            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);\n        }\n        \n        __syncthreads();\n\n        for (unsigned int s=blockSize/2; s>0; s>>=1)\n        {\n          if (tid < s) {\n            const unsigned int xid1 = tid << 1;\n            const unsigned int xid2 = (tid + s) << 1;\n            cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];\n            cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];\n            cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];\n          }\n          __syncthreads();\n        }\n\n        if (tid == 0)\n        { \n          *grad_sampling_loc = cache_grad_sampling_loc[0];\n          *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];\n          *grad_attn_weight = cache_grad_attn_weight[0];\n        }\n        __syncthreads();\n\n        data_weight_ptr += 1;\n        data_loc_w_ptr += 2;\n        grad_attn_weight += grad_weight_stride;\n        grad_sampling_loc += grad_loc_stride;\n      }\n    }\n  }\n}\n\n\ntemplate <typename scalar_t>\n__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n,\n                                                const scalar_t *grad_col,\n                                                const scalar_t *data_value,\n                                                const int64_t *data_spatial_shapes,\n                                                const int64_t *data_level_start_index, \n                                                const scalar_t *data_sampling_loc,\n                                                const scalar_t *data_attn_weight,\n      
                                          const int batch_size, \n                                                const int spatial_size, \n                                                const int num_heads,\n                                                const int channels, \n                                                const int num_levels,\n                                                const int num_query,\n                                                const int num_point,\n                                                scalar_t *grad_value,\n                                                scalar_t *grad_sampling_loc,\n                                                scalar_t *grad_attn_weight)\n{\n  CUDA_KERNEL_LOOP(index, n)\n  {\n    extern __shared__ int _s[];\n    scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;\n    scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;\n    unsigned int tid = threadIdx.x;\n    int _temp = index;\n    const int c_col = _temp % channels;\n    _temp /= channels;\n    const int sampling_index = _temp; \n    const int m_col = _temp % num_heads;\n    _temp /= num_heads;\n    const int q_col = _temp % num_query;\n    _temp /= num_query;\n    const int b_col = _temp;\n\n    const scalar_t top_grad = grad_col[index];\n\n    int data_weight_ptr = sampling_index * num_levels * num_point;\n    int data_loc_w_ptr = data_weight_ptr << 1;\n    const int grad_sampling_ptr = data_weight_ptr;\n    grad_sampling_loc += grad_sampling_ptr << 1;\n    grad_attn_weight += grad_sampling_ptr;\n    const int grad_weight_stride = 1;\n    const int grad_loc_stride = 2;\n    const int qid_stride = num_heads * channels;\n    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;\n\n    for (int l_col=0; l_col < num_levels; ++l_col)\n    {\n      const int level_start_id = data_level_start_index[l_col];\n      const int spatial_h_ptr = l_col << 1;\n      const int spatial_h = 
data_spatial_shapes[spatial_h_ptr];\n      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];\n      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;\n      const scalar_t *data_value_ptr = data_value + value_ptr_offset;\n      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;\n\n      for (int p_col=0; p_col < num_point; ++p_col)\n      {\n        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];\n        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];\n        const scalar_t weight = data_attn_weight[data_weight_ptr];\n\n        const scalar_t h_im = loc_h * spatial_h - 0.5;\n        const scalar_t w_im = loc_w * spatial_w - 0.5;\n        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;\n        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;\n        *(cache_grad_attn_weight+threadIdx.x)=0;\n        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)\n        {\n          ms_deform_attn_col2im_bilinear(\n            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,\n            top_grad, weight, grad_value_ptr, \n            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);\n        }\n        \n        __syncthreads();\n        if (tid == 0)\n        {\n          scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];\n          int sid=2;\n          for (unsigned int tid = 1; tid < blockDim.x; ++tid)\n          {\n            _grad_w += cache_grad_sampling_loc[sid];\n            _grad_h += cache_grad_sampling_loc[sid + 1];\n            _grad_a += cache_grad_attn_weight[tid];\n            sid += 2;\n          }\n          \n          \n          *grad_sampling_loc = _grad_w;\n          *(grad_sampling_loc + 1) = _grad_h;\n          *grad_attn_weight = _grad_a;\n        }\n        __syncthreads();\n\n        data_weight_ptr += 1;\n    
    data_loc_w_ptr += 2;\n        grad_attn_weight += grad_weight_stride;\n        grad_sampling_loc += grad_loc_stride;\n      }\n    }\n  }\n}\n\ntemplate <typename scalar_t>\n__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n,\n                                                const scalar_t *grad_col,\n                                                const scalar_t *data_value,\n                                                const int64_t *data_spatial_shapes,\n                                                const int64_t *data_level_start_index, \n                                                const scalar_t *data_sampling_loc,\n                                                const scalar_t *data_attn_weight,\n                                                const int batch_size, \n                                                const int spatial_size, \n                                                const int num_heads,\n                                                const int channels, \n                                                const int num_levels,\n                                                const int num_query,\n                                                const int num_point,\n                                                scalar_t *grad_value,\n                                                scalar_t *grad_sampling_loc,\n                                                scalar_t *grad_attn_weight)\n{\n  CUDA_KERNEL_LOOP(index, n)\n  {\n    extern __shared__ int _s[];\n    scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;\n    scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;\n    unsigned int tid = threadIdx.x;\n    int _temp = index;\n    const int c_col = _temp % channels;\n    _temp /= channels;\n    const int sampling_index = _temp; \n    const int m_col = _temp % num_heads;\n    _temp /= num_heads;\n    const int q_col = _temp % num_query;\n    _temp /= num_query;\n    const int b_col 
= _temp;\n\n    const scalar_t top_grad = grad_col[index];\n\n    int data_weight_ptr = sampling_index * num_levels * num_point;\n    int data_loc_w_ptr = data_weight_ptr << 1;\n    const int grad_sampling_ptr = data_weight_ptr;\n    grad_sampling_loc += grad_sampling_ptr << 1;\n    grad_attn_weight += grad_sampling_ptr;\n    const int grad_weight_stride = 1;\n    const int grad_loc_stride = 2;\n    const int qid_stride = num_heads * channels;\n    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;\n\n    for (int l_col=0; l_col < num_levels; ++l_col)\n    {\n      const int level_start_id = data_level_start_index[l_col];\n      const int spatial_h_ptr = l_col << 1;\n      const int spatial_h = data_spatial_shapes[spatial_h_ptr];\n      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];\n      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;\n      const scalar_t *data_value_ptr = data_value + value_ptr_offset;\n      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;\n\n      for (int p_col=0; p_col < num_point; ++p_col)\n      {\n        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];\n        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];\n        const scalar_t weight = data_attn_weight[data_weight_ptr];\n\n        const scalar_t h_im = loc_h * spatial_h - 0.5;\n        const scalar_t w_im = loc_w * spatial_w - 0.5;\n        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;\n        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;\n        *(cache_grad_attn_weight+threadIdx.x)=0;\n        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)\n        {\n          ms_deform_attn_col2im_bilinear(\n            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,\n            top_grad, weight, grad_value_ptr, \n            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);\n      
  }\n        \n        __syncthreads();\n\n        for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)\n        {\n          if (tid < s) {\n            const unsigned int xid1 = tid << 1;\n            const unsigned int xid2 = (tid + s) << 1;\n            cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];\n            cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];\n            cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];\n            if (tid + (s << 1) < spre)\n            {\n              cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];\n              cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];\n              cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];\n            } \n          }\n          __syncthreads();\n        }\n\n        if (tid == 0)\n        {\n          *grad_sampling_loc = cache_grad_sampling_loc[0];\n          *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];\n          *grad_attn_weight = cache_grad_attn_weight[0];\n        }\n        __syncthreads();\n\n        data_weight_ptr += 1;\n        data_loc_w_ptr += 2;\n        grad_attn_weight += grad_weight_stride;\n        grad_sampling_loc += grad_loc_stride;\n      }\n    }\n  }\n}\n\ntemplate <typename scalar_t>\n__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n,\n                                                const scalar_t *grad_col,\n                                                const scalar_t *data_value,\n                                                const int64_t *data_spatial_shapes,\n                                                const int64_t *data_level_start_index, \n                                                const scalar_t *data_sampling_loc,\n                                                const scalar_t *data_attn_weight,\n                                            
    const int batch_size, \n                                                const int spatial_size, \n                                                const int num_heads,\n                                                const int channels, \n                                                const int num_levels,\n                                                const int num_query,\n                                                const int num_point,\n                                                scalar_t *grad_value,\n                                                scalar_t *grad_sampling_loc,\n                                                scalar_t *grad_attn_weight)\n{\n  CUDA_KERNEL_LOOP(index, n)\n  {\n    extern __shared__ int _s[];\n    scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;\n    scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;\n    unsigned int tid = threadIdx.x;\n    int _temp = index;\n    const int c_col = _temp % channels;\n    _temp /= channels;\n    const int sampling_index = _temp; \n    const int m_col = _temp % num_heads;\n    _temp /= num_heads;\n    const int q_col = _temp % num_query;\n    _temp /= num_query;\n    const int b_col = _temp;\n\n    const scalar_t top_grad = grad_col[index];\n\n    int data_weight_ptr = sampling_index * num_levels * num_point;\n    int data_loc_w_ptr = data_weight_ptr << 1;\n    const int grad_sampling_ptr = data_weight_ptr;\n    grad_sampling_loc += grad_sampling_ptr << 1;\n    grad_attn_weight += grad_sampling_ptr;\n    const int grad_weight_stride = 1;\n    const int grad_loc_stride = 2;\n    const int qid_stride = num_heads * channels;\n    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;\n\n    for (int l_col=0; l_col < num_levels; ++l_col)\n    {\n      const int level_start_id = data_level_start_index[l_col];\n      const int spatial_h_ptr = l_col << 1;\n      const int spatial_h = data_spatial_shapes[spatial_h_ptr];\n      const int spatial_w = 
data_spatial_shapes[spatial_h_ptr + 1];\n      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;\n      const scalar_t *data_value_ptr = data_value + value_ptr_offset;\n      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;\n\n      for (int p_col=0; p_col < num_point; ++p_col)\n      {\n        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];\n        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];\n        const scalar_t weight = data_attn_weight[data_weight_ptr];\n\n        const scalar_t h_im = loc_h * spatial_h - 0.5;\n        const scalar_t w_im = loc_w * spatial_w - 0.5;\n        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;\n        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;\n        *(cache_grad_attn_weight+threadIdx.x)=0;\n        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)\n        {\n          ms_deform_attn_col2im_bilinear(\n            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,\n            top_grad, weight, grad_value_ptr, \n            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);\n        }\n        \n        __syncthreads();\n\n        for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)\n        {\n          if (tid < s) {\n            const unsigned int xid1 = tid << 1;\n            const unsigned int xid2 = (tid + s) << 1;\n            cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];\n            cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];\n            cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];\n            if (tid + (s << 1) < spre)\n            {\n              cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];\n              cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];\n              cache_grad_sampling_loc[xid1 + 1] += 
cache_grad_sampling_loc[xid2 + 1 + (s << 1)];\n            }\n          }\n          __syncthreads();\n        }\n\n        if (tid == 0)\n        {\n          atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]);\n          atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]);\n          atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]);\n        }\n        __syncthreads();\n\n        data_weight_ptr += 1;\n        data_loc_w_ptr += 2;\n        grad_attn_weight += grad_weight_stride;\n        grad_sampling_loc += grad_loc_stride;\n      }\n    }\n  }\n}\n\n\ntemplate <typename scalar_t>\n__global__ void ms_deformable_col2im_gpu_kernel_gm(const int n,\n                                                const scalar_t *grad_col,\n                                                const scalar_t *data_value,\n                                                const int64_t *data_spatial_shapes,\n                                                const int64_t *data_level_start_index, \n                                                const scalar_t *data_sampling_loc,\n                                                const scalar_t *data_attn_weight,\n                                                const int batch_size, \n                                                const int spatial_size, \n                                                const int num_heads,\n                                                const int channels, \n                                                const int num_levels,\n                                                const int num_query,\n                                                const int num_point,\n                                                scalar_t *grad_value,\n                                                scalar_t *grad_sampling_loc,\n                                                scalar_t *grad_attn_weight)\n{\n  CUDA_KERNEL_LOOP(index, n)\n  {\n    int _temp = index;\n    const int c_col = _temp % channels;\n   
 _temp /= channels;\n    const int sampling_index = _temp; \n    const int m_col = _temp % num_heads;\n    _temp /= num_heads;\n    const int q_col = _temp % num_query;\n    _temp /= num_query;\n    const int b_col = _temp;\n\n    const scalar_t top_grad = grad_col[index];\n\n    int data_weight_ptr = sampling_index * num_levels * num_point;\n    int data_loc_w_ptr = data_weight_ptr << 1;\n    const int grad_sampling_ptr = data_weight_ptr;\n    grad_sampling_loc += grad_sampling_ptr << 1;\n    grad_attn_weight += grad_sampling_ptr;\n    const int grad_weight_stride = 1;\n    const int grad_loc_stride = 2;\n    const int qid_stride = num_heads * channels;\n    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;\n\n    for (int l_col=0; l_col < num_levels; ++l_col)\n    {\n      const int level_start_id = data_level_start_index[l_col];\n      const int spatial_h_ptr = l_col << 1;\n      const int spatial_h = data_spatial_shapes[spatial_h_ptr];\n      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];\n      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;\n      const scalar_t *data_value_ptr = data_value + value_ptr_offset;\n      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;\n\n      for (int p_col=0; p_col < num_point; ++p_col)\n      {\n        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];\n        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];\n        const scalar_t weight = data_attn_weight[data_weight_ptr];\n\n        const scalar_t h_im = loc_h * spatial_h - 0.5;\n        const scalar_t w_im = loc_w * spatial_w - 0.5;\n        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)\n        {\n          ms_deform_attn_col2im_bilinear_gm(\n            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,\n            top_grad, weight, grad_value_ptr, \n            grad_sampling_loc, grad_attn_weight);\n       
 }\n        data_weight_ptr += 1;\n        data_loc_w_ptr += 2;\n        grad_attn_weight += grad_weight_stride;\n        grad_sampling_loc += grad_loc_stride;\n      }\n    }\n  }\n}\n\n\ntemplate <typename scalar_t>\nvoid ms_deformable_im2col_cuda(cudaStream_t stream,\n                              const scalar_t* data_value,\n                              const int64_t* data_spatial_shapes, \n                              const int64_t* data_level_start_index, \n                              const scalar_t* data_sampling_loc,\n                              const scalar_t* data_attn_weight,\n                              const int batch_size,\n                              const int spatial_size, \n                              const int num_heads, \n                              const int channels, \n                              const int num_levels, \n                              const int num_query,\n                              const int num_point,\n                              scalar_t* data_col)\n{\n  const int num_kernels = batch_size * num_query * num_heads * channels;\n  const int num_actual_kernels = batch_size * num_query * num_heads * channels;\n  const int num_threads = CUDA_NUM_THREADS;\n  ms_deformable_im2col_gpu_kernel<scalar_t>\n      <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,\n          0, stream>>>(\n      num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight, \n      batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col);\n  \n  cudaError_t err = cudaGetLastError();\n  if (err != cudaSuccess)\n  {\n    printf(\"error in ms_deformable_im2col_cuda: %s\\n\", cudaGetErrorString(err));\n  }\n\n}\n\ntemplate <typename scalar_t>\nvoid ms_deformable_col2im_cuda(cudaStream_t stream,\n                              const scalar_t* grad_col,\n                              const scalar_t* data_value,\n                              const int64_t * 
data_spatial_shapes,\n                              const int64_t * data_level_start_index,\n                              const scalar_t * data_sampling_loc,\n                              const scalar_t * data_attn_weight,\n                              const int batch_size, \n                              const int spatial_size, \n                              const int num_heads,\n                              const int channels, \n                              const int num_levels,\n                              const int num_query,\n                              const int num_point, \n                              scalar_t* grad_value,\n                              scalar_t* grad_sampling_loc,\n                              scalar_t* grad_attn_weight)\n{\n  const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels;\n  const int num_kernels = batch_size * num_query * num_heads * channels;\n  const int num_actual_kernels = batch_size * num_query * num_heads * channels;\n  if (channels > 1024)\n  {\n    if ((channels & 1023) == 0)\n    {\n      ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks<scalar_t>\n          <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,\n              num_threads*3*sizeof(scalar_t), stream>>>(\n                        num_kernels, \n                        grad_col,\n                        data_value,\n                        data_spatial_shapes,\n                        data_level_start_index, \n                        data_sampling_loc,\n                        data_attn_weight,\n                        batch_size, \n                        spatial_size, \n                        num_heads,\n                        channels, \n                        num_levels,\n                        num_query,\n                        num_point,\n                        grad_value,\n                        grad_sampling_loc,\n                        grad_attn_weight);\n    }\n    else\n    {\n      
ms_deformable_col2im_gpu_kernel_gm<scalar_t>\n        <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,\n            0, stream>>>(\n                      num_kernels, \n                      grad_col,\n                      data_value,\n                      data_spatial_shapes,\n                      data_level_start_index, \n                      data_sampling_loc,\n                      data_attn_weight,\n                      batch_size, \n                      spatial_size, \n                      num_heads,\n                      channels, \n                      num_levels,\n                      num_query,\n                      num_point,\n                      grad_value,\n                      grad_sampling_loc,\n                      grad_attn_weight);\n    }\n  }\n  else{\n    switch(channels)\n    {\n      case 1:\n        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 1>\n        <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,\n            0, stream>>>(\n                      num_kernels, \n                      grad_col,\n                      data_value,\n                      data_spatial_shapes,\n                      data_level_start_index, \n                      data_sampling_loc,\n                      data_attn_weight,\n                      batch_size, \n                      spatial_size, \n                      num_heads,\n                      channels, \n                      num_levels,\n                      num_query,\n                      num_point,\n                      grad_value,\n                      grad_sampling_loc,\n                      grad_attn_weight);\n        break;\n      case 2:\n        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 2>\n        <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,\n            0, stream>>>(\n                      num_kernels, \n                      grad_col,\n                      data_value,\n      
                data_spatial_shapes,\n                      data_level_start_index, \n                      data_sampling_loc,\n                      data_attn_weight,\n                      batch_size, \n                      spatial_size, \n                      num_heads,\n                      channels, \n                      num_levels,\n                      num_query,\n                      num_point,\n                      grad_value,\n                      grad_sampling_loc,\n                      grad_attn_weight);\n        break;\n      case 4:\n        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 4>\n        <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,\n            0, stream>>>(\n                      num_kernels, \n                      grad_col,\n                      data_value,\n                      data_spatial_shapes,\n                      data_level_start_index, \n                      data_sampling_loc,\n                      data_attn_weight,\n                      batch_size, \n                      spatial_size, \n                      num_heads,\n                      channels, \n                      num_levels,\n                      num_query,\n                      num_point,\n                      grad_value,\n                      grad_sampling_loc,\n                      grad_attn_weight);\n        break;\n      case 8:\n        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 8>\n        <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,\n            0, stream>>>(\n                      num_kernels, \n                      grad_col,\n                      data_value,\n                      data_spatial_shapes,\n                      data_level_start_index, \n                      data_sampling_loc,\n                      data_attn_weight,\n                      batch_size, \n                      spatial_size, \n                      num_heads,\n           
           channels, \n                      num_levels,\n                      num_query,\n                      num_point,\n                      grad_value,\n                      grad_sampling_loc,\n                      grad_attn_weight);\n        break;\n      case 16:\n        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 16>\n        <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,\n            0, stream>>>(\n                      num_kernels, \n                      grad_col,\n                      data_value,\n                      data_spatial_shapes,\n                      data_level_start_index, \n                      data_sampling_loc,\n                      data_attn_weight,\n                      batch_size, \n                      spatial_size, \n                      num_heads,\n                      channels, \n                      num_levels,\n                      num_query,\n                      num_point,\n                      grad_value,\n                      grad_sampling_loc,\n                      grad_attn_weight);\n        break;\n      case 32:\n        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 32>\n        <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,\n            0, stream>>>(\n                      num_kernels, \n                      grad_col,\n                      data_value,\n                      data_spatial_shapes,\n                      data_level_start_index, \n                      data_sampling_loc,\n                      data_attn_weight,\n                      batch_size, \n                      spatial_size, \n                      num_heads,\n                      channels, \n                      num_levels,\n                      num_query,\n                      num_point,\n                      grad_value,\n                      grad_sampling_loc,\n                      grad_attn_weight);\n        break;\n      case 64:\n       
 ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 64>\n        <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,\n            0, stream>>>(\n                      num_kernels, \n                      grad_col,\n                      data_value,\n                      data_spatial_shapes,\n                      data_level_start_index, \n                      data_sampling_loc,\n                      data_attn_weight,\n                      batch_size, \n                      spatial_size, \n                      num_heads,\n                      channels, \n                      num_levels,\n                      num_query,\n                      num_point,\n                      grad_value,\n                      grad_sampling_loc,\n                      grad_attn_weight);\n        break;\n      case 128:\n        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 128>\n        <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,\n            0, stream>>>(\n                      num_kernels, \n                      grad_col,\n                      data_value,\n                      data_spatial_shapes,\n                      data_level_start_index, \n                      data_sampling_loc,\n                      data_attn_weight,\n                      batch_size, \n                      spatial_size, \n                      num_heads,\n                      channels, \n                      num_levels,\n                      num_query,\n                      num_point,\n                      grad_value,\n                      grad_sampling_loc,\n                      grad_attn_weight);\n        break;\n      case 256:\n        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 256>\n        <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,\n            0, stream>>>(\n                      num_kernels, \n                      grad_col,\n                      
data_value,\n                      data_spatial_shapes,\n                      data_level_start_index, \n                      data_sampling_loc,\n                      data_attn_weight,\n                      batch_size, \n                      spatial_size, \n                      num_heads,\n                      channels, \n                      num_levels,\n                      num_query,\n                      num_point,\n                      grad_value,\n                      grad_sampling_loc,\n                      grad_attn_weight);\n        break;\n      case 512:\n        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 512>\n        <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,\n            0, stream>>>(\n                      num_kernels, \n                      grad_col,\n                      data_value,\n                      data_spatial_shapes,\n                      data_level_start_index, \n                      data_sampling_loc,\n                      data_attn_weight,\n                      batch_size, \n                      spatial_size, \n                      num_heads,\n                      channels, \n                      num_levels,\n                      num_query,\n                      num_point,\n                      grad_value,\n                      grad_sampling_loc,\n                      grad_attn_weight);\n        break;\n      case 1024:\n        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 1024>\n        <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,\n            0, stream>>>(\n                      num_kernels, \n                      grad_col,\n                      data_value,\n                      data_spatial_shapes,\n                      data_level_start_index, \n                      data_sampling_loc,\n                      data_attn_weight,\n                      batch_size, \n                      spatial_size, \n                
      num_heads,\n                      channels, \n                      num_levels,\n                      num_query,\n                      num_point,\n                      grad_value,\n                      grad_sampling_loc,\n                      grad_attn_weight);\n        break;\n      default:\n        if (channels < 64)\n        {\n          ms_deformable_col2im_gpu_kernel_shm_reduce_v1<scalar_t>\n          <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,\n              num_threads*3*sizeof(scalar_t), stream>>>(\n                        num_kernels, \n                        grad_col,\n                        data_value,\n                        data_spatial_shapes,\n                        data_level_start_index, \n                        data_sampling_loc,\n                        data_attn_weight,\n                        batch_size, \n                        spatial_size, \n                        num_heads,\n                        channels, \n                        num_levels,\n                        num_query,\n                        num_point,\n                        grad_value,\n                        grad_sampling_loc,\n                        grad_attn_weight);\n        }\n        else\n        {\n          ms_deformable_col2im_gpu_kernel_shm_reduce_v2<scalar_t>\n          <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,\n              num_threads*3*sizeof(scalar_t), stream>>>(\n                        num_kernels, \n                        grad_col,\n                        data_value,\n                        data_spatial_shapes,\n                        data_level_start_index, \n                        data_sampling_loc,\n                        data_attn_weight,\n                        batch_size, \n                        spatial_size, \n                        num_heads,\n                        channels, \n                        num_levels,\n                        num_query,\n                        
num_point,\n                        grad_value,\n                        grad_sampling_loc,\n                        grad_attn_weight);\n        }\n    }\n  }\n  cudaError_t err = cudaGetLastError();\n  if (err != cudaSuccess)\n  {\n    printf(\"error in ms_deformable_col2im_cuda: %s\\n\", cudaGetErrorString(err));\n  }\n\n}"
  },
  {
    "path": "llava/model/semsam/body/encoder/ops/src/ms_deform_attn.h",
    "content": "/*!\n**************************************************************************************************\n* Deformable DETR\n* Copyright (c) 2020 SenseTime. All Rights Reserved.\n* Licensed under the Apache License, Version 2.0 [see LICENSE for details]\n**************************************************************************************************\n* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0\n**************************************************************************************************\n*/\n\n/*!\n* Copyright (c) Facebook, Inc. and its affiliates.\n* Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR\n*/\n\n#pragma once\n\n#include \"cpu/ms_deform_attn_cpu.h\"\n\n#ifdef WITH_CUDA\n#include \"cuda/ms_deform_attn_cuda.h\"\n#endif\n\n\nat::Tensor\nms_deform_attn_forward(\n    const at::Tensor &value, \n    const at::Tensor &spatial_shapes,\n    const at::Tensor &level_start_index,\n    const at::Tensor &sampling_loc,\n    const at::Tensor &attn_weight,\n    const int im2col_step)\n{\n    if (value.type().is_cuda())\n    {\n#ifdef WITH_CUDA\n        return ms_deform_attn_cuda_forward(\n            value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);\n#else\n        AT_ERROR(\"Not compiled with GPU support\");\n#endif\n    }\n    AT_ERROR(\"Not implemented on the CPU\");\n}\n\nstd::vector<at::Tensor>\nms_deform_attn_backward(\n    const at::Tensor &value, \n    const at::Tensor &spatial_shapes,\n    const at::Tensor &level_start_index,\n    const at::Tensor &sampling_loc,\n    const at::Tensor &attn_weight,\n    const at::Tensor &grad_output,\n    const int im2col_step)\n{\n    if (value.type().is_cuda())\n    {\n#ifdef WITH_CUDA\n        return ms_deform_attn_cuda_backward(\n            value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);\n#else\n        AT_ERROR(\"Not compiled with GPU 
support\");\n#endif\n    }\n    AT_ERROR(\"Not implemented on the CPU\");\n}\n\n"
  },
  {
    "path": "llava/model/semsam/body/encoder/ops/src/vision.cpp",
    "content": "/*!\n**************************************************************************************************\n* Deformable DETR\n* Copyright (c) 2020 SenseTime. All Rights Reserved.\n* Licensed under the Apache License, Version 2.0 [see LICENSE for details]\n**************************************************************************************************\n* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0\n**************************************************************************************************\n*/\n\n/*!\n* Copyright (c) Facebook, Inc. and its affiliates.\n* Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR\n*/\n\n#include \"ms_deform_attn.h\"\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"ms_deform_attn_forward\", &ms_deform_attn_forward, \"ms_deform_attn_forward\");\n  m.def(\"ms_deform_attn_backward\", &ms_deform_attn_backward, \"ms_deform_attn_backward\");\n}\n"
  },
  {
    "path": "llava/model/semsam/body/encoder/ops/test.py",
    "content": "# ------------------------------------------------------------------------------------------------\n# Deformable DETR\n# Copyright (c) 2020 SenseTime. All Rights Reserved.\n# Licensed under the Apache License, Version 2.0 [see LICENSE for details]\n# ------------------------------------------------------------------------------------------------\n# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0\n# ------------------------------------------------------------------------------------------------\n\n# Copyright (c) Facebook, Inc. and its affiliates.\n# Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR\n\nfrom __future__ import absolute_import\nfrom __future__ import print_function\nfrom __future__ import division\n\nimport time\nimport torch\nimport torch.nn as nn\nfrom torch.autograd import gradcheck\n\nfrom functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch\n\n\nN, M, D = 1, 2, 2\nLq, L, P = 2, 2, 2\nshapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()\nlevel_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1]))\nS = sum([(H*W).item() for H, W in shapes])\n\n\ntorch.manual_seed(3)\n\n\n@torch.no_grad()\ndef check_forward_equal_with_pytorch_double():\n    value = torch.rand(N, S, M, D).cuda() * 0.01\n    sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()\n    attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5\n    attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)\n    im2col_step = 2\n    output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu()\n    output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu()\n    fwdok = torch.allclose(output_cuda, 
output_pytorch)\n    max_abs_err = (output_cuda - output_pytorch).abs().max()\n    max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()\n\n    print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')\n\n\n@torch.no_grad()\ndef check_forward_equal_with_pytorch_float():\n    value = torch.rand(N, S, M, D).cuda() * 0.01\n    sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()\n    attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5\n    attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)\n    im2col_step = 2\n    output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu()\n    output_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step).detach().cpu()\n    fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3)\n    max_abs_err = (output_cuda - output_pytorch).abs().max()\n    max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()\n\n    print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')\n\n\ndef check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True):\n\n    value = torch.rand(N, S, M, channels).cuda() * 0.01\n    sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()\n    attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5\n    attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)\n    im2col_step = 2\n    func = MSDeformAttnFunction.apply\n\n    value.requires_grad = grad_value\n    sampling_locations.requires_grad = grad_sampling_loc\n    attention_weights.requires_grad = grad_attn_weight\n\n    gradok = gradcheck(func, (value.double(), shapes, level_start_index, sampling_locations.double(), 
attention_weights.double(), im2col_step))\n\n    print(f'* {gradok} check_gradient_numerical(D={channels})')\n\n\nif __name__ == '__main__':\n    check_forward_equal_with_pytorch_double()\n    check_forward_equal_with_pytorch_float()\n\n    for channels in [30, 32, 64, 71, 1025, 2048, 3096]:\n        check_gradient_numerical(channels, True, True, True)\n\n\n\n"
  },
  {
    "path": "llava/model/semsam/body/encoder/registry.py",
    "content": "_model_entrypoints = {}\n\ndef register_encoder(fn):\n    module_name_split = fn.__module__.split('.')\n    model_name = module_name_split[-1]\n    _model_entrypoints[model_name] = fn\n    return fn\n\ndef model_entrypoints(model_name):\n    return _model_entrypoints[model_name]\n\ndef is_model(model_name):\n    return model_name in _model_entrypoints\n"
  },
  {
    "path": "llava/model/semsam/body/encoder/transformer_encoder_fpn.py",
    "content": "# Copyright (c) Facebook, Inc. and its affiliates.\nimport logging\nimport numpy as np\nfrom typing import Callable, Dict, List, Optional, Tuple, Union\n\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom torch.nn.init import xavier_uniform_, constant_, uniform_, normal_\nfrom torch.cuda.amp import autocast\n\nimport fvcore.nn.weight_init as weight_init\nfrom detectron2.layers import Conv2d, DeformConv, ShapeSpec, get_norm\n\nfrom .registry import register_encoder\nfrom ..transformer_blocks import TransformerEncoder, TransformerEncoderLayer, _get_clones, _get_activation_fn\nfrom ...modules import PositionEmbeddingSine\nfrom ...utils import configurable\n\n\n# This is a modified FPN decoder.\nclass BasePixelDecoder(nn.Module):\n    def __init__(\n        self,\n        input_shape: Dict[str, ShapeSpec],\n        *,\n        conv_dim: int,\n        mask_dim: int,\n        mask_on: bool,\n        norm: Optional[Union[str, Callable]] = None,\n    ):\n        \"\"\"\n        NOTE: this interface is experimental.\n        Args:\n            input_shape: shapes (channels and stride) of the input features\n            conv_dims: number of output channels for the intermediate conv layers.\n            mask_dim: number of output channels for the final conv layer.\n            norm (str or callable): normalization for all conv layers\n        \"\"\"\n        super().__init__()\n\n        input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)\n        self.in_features = [k for k, v in input_shape]  # starting from \"res2\" to \"res5\"\n        feature_channels = [v.channels for k, v in input_shape]\n\n        lateral_convs = []\n        output_convs = []\n\n        use_bias = norm == \"\"\n        for idx, in_channels in enumerate(feature_channels):\n            if idx == len(self.in_features) - 1:\n                output_norm = get_norm(norm, conv_dim)\n                output_conv = Conv2d(\n                    
in_channels,\n                    conv_dim,\n                    kernel_size=3,\n                    stride=1,\n                    padding=1,\n                    bias=use_bias,\n                    norm=output_norm,\n                    activation=F.relu,\n                )\n                weight_init.c2_xavier_fill(output_conv)\n                self.add_module(\"layer_{}\".format(idx + 1), output_conv)\n\n                lateral_convs.append(None)\n                output_convs.append(output_conv)\n            else:\n                lateral_norm = get_norm(norm, conv_dim)\n                output_norm = get_norm(norm, conv_dim)\n\n                lateral_conv = Conv2d(\n                    in_channels, conv_dim, kernel_size=1, bias=use_bias, norm=lateral_norm\n                )\n                output_conv = Conv2d(\n                    conv_dim,\n                    conv_dim,\n                    kernel_size=3,\n                    stride=1,\n                    padding=1,\n                    bias=use_bias,\n                    norm=output_norm,\n                    activation=F.relu,\n                )\n                weight_init.c2_xavier_fill(lateral_conv)\n                weight_init.c2_xavier_fill(output_conv)\n                self.add_module(\"adapter_{}\".format(idx + 1), lateral_conv)\n                self.add_module(\"layer_{}\".format(idx + 1), output_conv)\n\n                lateral_convs.append(lateral_conv)\n                output_convs.append(output_conv)\n        # Place convs into top-down order (from low to high resolution)\n        # to make the top-down computation in forward clearer.\n        self.lateral_convs = lateral_convs[::-1]\n        self.output_convs = output_convs[::-1]\n\n        self.mask_on = mask_on\n        if self.mask_on:\n            self.mask_dim = mask_dim\n            self.mask_features = Conv2d(\n                conv_dim,\n                mask_dim,\n                kernel_size=3,\n                stride=1,\n            
    padding=1,\n            )\n            weight_init.c2_xavier_fill(self.mask_features)\n\n        self.maskformer_num_feature_levels = 3  # always use 3 scales\n\n    @classmethod\n    def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):\n        enc_cfg = cfg['MODEL']['ENCODER']\n        ret = {}\n        ret[\"input_shape\"] = {\n            k: v for k, v in input_shape.items() if k in enc_cfg['IN_FEATURES']\n        }\n        ret[\"conv_dim\"] = enc_cfg['CONVS_DIM']\n        ret[\"mask_dim\"] = enc_cfg['MASK_DIM']\n        ret[\"norm\"] = enc_cfg['NORM']\n        return ret\n\n    def forward_features(self, features):\n        multi_scale_features = []\n        num_cur_levels = 0\n        # Reverse feature maps into top-down order (from low to high resolution)\n        for idx, f in enumerate(self.in_features[::-1]):\n            x = features[f]\n            lateral_conv = self.lateral_convs[idx]\n            output_conv = self.output_convs[idx]\n            if lateral_conv is None:\n                y = output_conv(x)\n            else:\n                cur_fpn = lateral_conv(x)\n                # Following FPN implementation, we use nearest upsampling here\n                y = cur_fpn + F.interpolate(y, size=cur_fpn.shape[-2:], mode=\"nearest\")\n                y = output_conv(y)\n            if num_cur_levels < self.maskformer_num_feature_levels:\n                multi_scale_features.append(y)\n                num_cur_levels += 1\n        \n        mask_features = self.mask_features(y) if self.mask_on else None\n        return mask_features, None, multi_scale_features\n\n    def forward(self, features, targets=None):\n        logger = logging.getLogger(__name__)\n        logger.warning(\"Calling forward() may cause unpredicted behavior of PixelDecoder module.\")\n        return self.forward_features(features)\n\n\nclass TransformerEncoderOnly(nn.Module):\n    def __init__(\n        self,\n        d_model=512,\n        nhead=8,\n        
num_encoder_layers=6,\n        dim_feedforward=2048,\n        dropout=0.1,\n        activation=\"relu\",\n        normalize_before=False,\n    ):\n        super().__init__()\n\n        encoder_layer = TransformerEncoderLayer(\n            d_model, nhead, dim_feedforward, dropout, activation, normalize_before\n        )\n        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None\n        self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)\n\n        self._reset_parameters()\n\n        self.d_model = d_model\n        self.nhead = nhead\n\n    def _reset_parameters(self):\n        for p in self.parameters():\n            if p.dim() > 1:\n                nn.init.xavier_uniform_(p)\n\n    def forward(self, src, mask, pos_embed):\n        # flatten NxCxHxW to HWxNxC\n        bs, c, h, w = src.shape\n        src = src.flatten(2).permute(2, 0, 1)\n        pos_embed = pos_embed.flatten(2).permute(2, 0, 1)\n        if mask is not None:\n            mask = mask.flatten(1)\n\n        memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)\n        return memory.permute(1, 2, 0).view(bs, c, h, w)\n\n\n# This is a modified FPN decoder with extra Transformer encoder that processes the lowest-resolution feature map.\nclass TransformerEncoderPixelDecoder(BasePixelDecoder):\n    @configurable\n    def __init__(\n        self,\n        input_shape: Dict[str, ShapeSpec],\n        *,\n        transformer_dropout: float,\n        transformer_nheads: int,\n        transformer_dim_feedforward: int,\n        transformer_enc_layers: int,\n        transformer_pre_norm: bool,\n        conv_dim: int,\n        mask_dim: int,\n        mask_on: int,\n        norm: Optional[Union[str, Callable]] = None,\n    ):\n        \"\"\"\n        NOTE: this interface is experimental.\n        Args:\n            input_shape: shapes (channels and stride) of the input features\n            transformer_dropout: dropout probability in transformer\n       
     transformer_nheads: number of heads in transformer\n            transformer_dim_feedforward: dimension of feedforward network\n            transformer_enc_layers: number of transformer encoder layers\n            transformer_pre_norm: whether to use pre-layernorm or not\n            conv_dims: number of output channels for the intermediate conv layers.\n            mask_dim: number of output channels for the final conv layer.\n            norm (str or callable): normalization for all conv layers\n        \"\"\"\n        super().__init__(input_shape, conv_dim=conv_dim, mask_dim=mask_dim, norm=norm, mask_on=mask_on)\n\n        input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)\n        self.in_features = [k for k, v in input_shape]  # starting from \"res2\" to \"res5\"\n        feature_strides = [v.stride for k, v in input_shape]\n        feature_channels = [v.channels for k, v in input_shape]\n\n        in_channels = feature_channels[len(self.in_features) - 1]\n        self.input_proj = Conv2d(in_channels, conv_dim, kernel_size=1)\n        weight_init.c2_xavier_fill(self.input_proj)\n        self.transformer = TransformerEncoderOnly(\n            d_model=conv_dim,\n            dropout=transformer_dropout,\n            nhead=transformer_nheads,\n            dim_feedforward=transformer_dim_feedforward,\n            num_encoder_layers=transformer_enc_layers,\n            normalize_before=transformer_pre_norm,\n        )\n        N_steps = conv_dim // 2\n        self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)\n\n        # update layer\n        use_bias = norm == \"\"\n        output_norm = get_norm(norm, conv_dim)\n        output_conv = Conv2d(\n            conv_dim,\n            conv_dim,\n            kernel_size=3,\n            stride=1,\n            padding=1,\n            bias=use_bias,\n            norm=output_norm,\n            activation=F.relu,\n        )\n        weight_init.c2_xavier_fill(output_conv)\n        
delattr(self, \"layer_{}\".format(len(self.in_features)))\n        self.add_module(\"layer_{}\".format(len(self.in_features)), output_conv)\n        self.output_convs[0] = output_conv\n\n    @classmethod\n    def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):\n        enc_cfg = cfg['MODEL']['ENCODER']\n        dec_cfg = cfg['MODEL']['DECODER']\n\n        ret = super().from_config(cfg, input_shape)\n        ret[\"transformer_dropout\"] = dec_cfg['DROPOUT']\n        ret[\"transformer_nheads\"] = dec_cfg['NHEADS']\n        ret[\"transformer_dim_feedforward\"] = dec_cfg['DIM_FEEDFORWARD']\n        ret[\"transformer_enc_layers\"] = enc_cfg['TRANSFORMER_ENC_LAYERS']  # a separate config\n        ret[\"transformer_pre_norm\"] = dec_cfg['PRE_NORM']\n\n        ret['mask_on'] = cfg['MODEL']['DECODER']['MASK']\n        return ret\n\n    def forward_features(self, features):\n        multi_scale_features = []\n        num_cur_levels = 0\n        \n        # Reverse feature maps into top-down order (from low to high resolution)\n        for idx, f in enumerate(self.in_features[::-1]):\n            x = features[f]\n            lateral_conv = self.lateral_convs[idx]\n            output_conv = self.output_convs[idx]\n            if lateral_conv is None:\n                transformer = self.input_proj(x)\n                pos = self.pe_layer(x)\n                transformer = self.transformer(transformer, None, pos)\n                y = output_conv(transformer)\n                # save intermediate feature as input to Transformer decoder\n                transformer_encoder_features = transformer\n            else:\n                cur_fpn = lateral_conv(x)\n                # Following FPN implementation, we use nearest upsampling here\n                y = cur_fpn + F.interpolate(y, size=cur_fpn.shape[-2:], mode=\"nearest\")\n                y = output_conv(y)\n            if num_cur_levels < self.maskformer_num_feature_levels:\n                
multi_scale_features.append(y)\n                num_cur_levels += 1\n\n        mask_features = self.mask_features(y) if self.mask_on else None\n        return mask_features, transformer_encoder_features, multi_scale_features\n\n    def forward(self, features, targets=None):\n        logger = logging.getLogger(__name__)\n        logger.warning(\"Calling forward() may cause unpredicted behavior of PixelDecoder module.\")\n        return self.forward_features(features)\n\n\n\n@register_encoder\ndef get_transformer_encoder_fpn(cfg, input_shape):\n    \"\"\"\n    Build a pixel decoder from `cfg.MODEL.MASK_FORMER.PIXEL_DECODER_NAME`.\n    \"\"\"\n    model = TransformerEncoderPixelDecoder(cfg, input_shape)    \n    forward_features = getattr(model, \"forward_features\", None)\n    if not callable(forward_features):\n        raise ValueError(\n            \"Only SEM_SEG_HEADS with forward_features method can be used as pixel decoder. \"\n            f\"Please implement forward_features for {name} to only return mask features.\"\n        )\n    return model"
  },
  {
    "path": "llava/model/semsam/body/openseed_head.py",
    "content": "# ------------------------------------------------------------------------\n# Copyright (c) 2022 IDEA. All Rights Reserved.\n# Licensed under the Apache License, Version 2.0 [see LICENSE for details]\n# ------------------------------------------------------------------------\n# Modified from Mask2Former https://github.com/facebookresearch/Mask2Former by Feng Li and Hao Zhang.\n# ------------------------------------------------------------------------------\nimport logging\nfrom typing import Callable, Dict, List, Optional, Tuple, Union\n\nfrom torch import nn\n\nfrom detectron2.layers import Conv2d, ShapeSpec, get_norm\nfrom detectron2.modeling import SEM_SEG_HEADS_REGISTRY\n\nfrom .registry import register_body\nfrom .encoder import build_encoder\nfrom .decoder import build_decoder\nfrom ..utils import configurable\n\n\nclass MaskDINOHead(nn.Module):\n    @configurable\n    def __init__(\n        self,\n        input_shape: Dict[str, ShapeSpec],\n        *,\n        num_classes: int,\n        pixel_decoder: nn.Module,\n        loss_weight: float = 1.0,\n        ignore_value: int = -1,\n        transformer_predictor: nn.Module,\n    ):\n        \"\"\"\n        Args:\n            input_shape: shapes (channels and stride) of the input features\n            num_classes: number of classes to predict\n            pixel_decoder: the pixel decoder module\n            loss_weight: loss weight\n            ignore_value: category id to be ignored during training.\n            transformer_predictor: the transformer decoder that makes prediction\n            transformer_in_feature: input feature name to the transformer_predictor\n        \"\"\"\n        super().__init__()\n        input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)\n        self.in_features = [k for k, v in input_shape]\n        self.ignore_value = ignore_value\n        self.common_stride = 4\n        self.loss_weight = loss_weight\n\n        self.pixel_decoder = 
pixel_decoder\n        self.predictor = transformer_predictor\n\n        self.num_classes = num_classes\n\n    @classmethod\n    def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec], lang_encoder: nn.Module, extra: dict):\n        enc_cfg = cfg['MODEL']['ENCODER']\n        dec_cfg = cfg['MODEL']['DECODER']\n        transformer_predictor_in_channels = enc_cfg['CONVS_DIM']\n\n        return {\n            \"input_shape\": {\n                k: v for k, v in input_shape.items() if k in enc_cfg['IN_FEATURES']\n            },\n            \"ignore_value\": enc_cfg['IGNORE_VALUE'],\n            \"num_classes\": enc_cfg.get('NUM_CLASSES', None),\n            \"pixel_decoder\": build_encoder(cfg, input_shape),\n            \"loss_weight\": enc_cfg['LOSS_WEIGHT'],\n            \"transformer_predictor\": build_decoder(\n                cfg,\n                transformer_predictor_in_channels,\n                lang_encoder,\n                mask_classification=True,\n                extra=extra,\n            ),\n        }\n\n    def forward(self, features, mask=None, targets=None, target_queries=None, target_vlp=None, task='seg', extra={}):\n        return self.layers(features, mask, targets=targets, target_queries=target_queries, target_vlp=target_vlp, task=task, extra=extra)\n\n    def layers(self, features, mask=None,targets=None, target_queries=None, target_vlp=None, prediction_switch=None, task='seg', extra={}):\n        mask_features, transformer_encoder_features, multi_scale_features = self.pixel_decoder.forward_features(features, mask)\n        if task == 'teacher':\n            predictions = self.predictor.forward_teacher(multi_scale_features, mask_features, mask, targets=targets,\n                                                         target_queries=target_queries, target_vlp=target_vlp,\n                                                         task=task, extra=extra)\n        else:\n            predictions = self.predictor(multi_scale_features, 
mask_features, mask, targets=targets,\n                                         target_queries=target_queries, target_vlp=target_vlp, task=task, extra=extra)\n        return predictions\n\n\n@register_body\ndef get_maskdino_head(cfg, input_shape, lang_encoder, extra):\n    return MaskDINOHead(cfg, input_shape, lang_encoder, extra)"
  },
  {
    "path": "llava/model/semsam/body/registry.py",
    "content": "_model_entrypoints = {}\n\n\ndef register_body(fn):\n    module_name_split = fn.__module__.split('.')\n    model_name = module_name_split[-1]\n    _model_entrypoints[model_name] = fn\n    return fn\n\ndef model_entrypoints(model_name):\n    return _model_entrypoints[model_name]\n\ndef is_model(model_name):\n    return model_name in _model_entrypoints"
  },
  {
    "path": "llava/model/semsam/body/transformer_blocks.py",
    "content": "# Copyright (c) Facebook, Inc. and its affiliates.\n# Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/transformer.py\n\"\"\"\nTransformer class.\n\nCopy-paste from torch.nn.Transformer with modifications:\n    * positional encodings are passed in MHattention\n    * extra LN at the end of encoder is removed\n    * decoder returns a stack of activations from all decoding layers\n\"\"\"\nimport copy\nfrom typing import List, Optional\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor, nn\n\n\nclass Transformer(nn.Module):\n    def __init__(\n        self,\n        d_model=512,\n        nhead=8,\n        num_encoder_layers=6,\n        num_decoder_layers=6,\n        dim_feedforward=2048,\n        dropout=0.1,\n        activation=\"relu\",\n        normalize_before=False,\n        return_intermediate_dec=False,\n    ):\n        super().__init__()\n\n        encoder_layer = TransformerEncoderLayer(\n            d_model, nhead, dim_feedforward, dropout, activation, normalize_before\n        )\n        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None\n        self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)\n\n        decoder_layer = TransformerDecoderLayer(\n            d_model, nhead, dim_feedforward, dropout, activation, normalize_before\n        )\n        decoder_norm = nn.LayerNorm(d_model)\n        self.decoder = TransformerDecoder(\n            decoder_layer,\n            num_decoder_layers,\n            decoder_norm,\n            return_intermediate=return_intermediate_dec,\n        )\n\n        self._reset_parameters()\n\n        self.d_model = d_model\n        self.nhead = nhead\n\n    def _reset_parameters(self):\n        for p in self.parameters():\n            if p.dim() > 1:\n                nn.init.xavier_uniform_(p)\n\n    def forward(self, src, mask, query_embed, pos_embed):\n        # flatten NxCxHxW to HWxNxC\n        bs, 
c, h, w = src.shape\n        src = src.flatten(2).permute(2, 0, 1)\n        pos_embed = pos_embed.flatten(2).permute(2, 0, 1)\n        query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)\n        if mask is not None:\n            mask = mask.flatten(1)\n\n        tgt = torch.zeros_like(query_embed)\n        memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)\n        hs = self.decoder(\n            tgt, memory, memory_key_padding_mask=mask, pos=pos_embed, query_pos=query_embed\n        )\n        return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w)\n\n\nclass TransformerEncoder(nn.Module):\n    def __init__(self, encoder_layer, num_layers, norm=None):\n        super().__init__()\n        self.layers = _get_clones(encoder_layer, num_layers)\n        self.num_layers = num_layers\n        self.norm = norm\n\n    def forward(\n        self,\n        src,\n        mask: Optional[Tensor] = None,\n        src_key_padding_mask: Optional[Tensor] = None,\n        pos: Optional[Tensor] = None,\n    ):\n        output = src\n\n        for layer in self.layers:\n            output = layer(\n                output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pos=pos\n            )\n\n        if self.norm is not None:\n            output = self.norm(output)\n\n        return output\n\n\nclass TransformerDecoder(nn.Module):\n    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):\n        super().__init__()\n        self.layers = _get_clones(decoder_layer, num_layers)\n        self.num_layers = num_layers\n        self.norm = norm\n        self.return_intermediate = return_intermediate\n\n    def forward(\n        self,\n        tgt,\n        memory,\n        tgt_mask: Optional[Tensor] = None,\n        memory_mask: Optional[Tensor] = None,\n        tgt_key_padding_mask: Optional[Tensor] = None,\n        memory_key_padding_mask: Optional[Tensor] = None,\n        pos: Optional[Tensor] = None,\n       
 query_pos: Optional[Tensor] = None,\n    ):\n        output = tgt\n\n        intermediate = []\n\n        for layer in self.layers:\n            output = layer(\n                output,\n                memory,\n                tgt_mask=tgt_mask,\n                memory_mask=memory_mask,\n                tgt_key_padding_mask=tgt_key_padding_mask,\n                memory_key_padding_mask=memory_key_padding_mask,\n                pos=pos,\n                query_pos=query_pos,\n            )\n            if self.return_intermediate:\n                intermediate.append(self.norm(output))\n\n        if self.norm is not None:\n            output = self.norm(output)\n            if self.return_intermediate:\n                intermediate.pop()\n                intermediate.append(output)\n\n        if self.return_intermediate:\n            return torch.stack(intermediate)\n\n        return output.unsqueeze(0)\n\n\nclass TransformerEncoderLayer(nn.Module):\n    def __init__(\n        self,\n        d_model,\n        nhead,\n        dim_feedforward=2048,\n        dropout=0.1,\n        activation=\"relu\",\n        normalize_before=False,\n    ):\n        super().__init__()\n        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)\n        # Implementation of Feedforward model\n        self.linear1 = nn.Linear(d_model, dim_feedforward)\n        self.dropout = nn.Dropout(dropout)\n        self.linear2 = nn.Linear(dim_feedforward, d_model)\n\n        self.norm1 = nn.LayerNorm(d_model)\n        self.norm2 = nn.LayerNorm(d_model)\n        self.dropout1 = nn.Dropout(dropout)\n        self.dropout2 = nn.Dropout(dropout)\n\n        self.activation = _get_activation_fn(activation)\n        self.normalize_before = normalize_before\n\n    def with_pos_embed(self, tensor, pos: Optional[Tensor]):\n        return tensor if pos is None else tensor + pos\n\n    def forward_post(\n        self,\n        src,\n        src_mask: Optional[Tensor] = None,\n        
src_key_padding_mask: Optional[Tensor] = None,\n        pos: Optional[Tensor] = None,\n    ):\n        q = k = self.with_pos_embed(src, pos)\n\n        src2 = self.self_attn(\n            q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask\n        )[0]\n        src = src + self.dropout1(src2)\n        src = self.norm1(src)\n        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))\n        src = src + self.dropout2(src2)\n        src = self.norm2(src)\n        return src\n\n    def forward_pre(\n        self,\n        src,\n        src_mask: Optional[Tensor] = None,\n        src_key_padding_mask: Optional[Tensor] = None,\n        pos: Optional[Tensor] = None,\n    ):\n        src2 = self.norm1(src)\n        q = k = self.with_pos_embed(src2, pos)\n        src2 = self.self_attn(\n            q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask\n        )[0]\n        src = src + self.dropout1(src2)\n        src2 = self.norm2(src)\n        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))\n        src = src + self.dropout2(src2)\n        return src\n\n    def forward(\n        self,\n        src,\n        src_mask: Optional[Tensor] = None,\n        src_key_padding_mask: Optional[Tensor] = None,\n        pos: Optional[Tensor] = None,\n    ):\n        if self.normalize_before:\n            return self.forward_pre(src, src_mask, src_key_padding_mask, pos)\n        return self.forward_post(src, src_mask, src_key_padding_mask, pos)\n\n\nclass TransformerDecoderLayer(nn.Module):\n    def __init__(\n        self,\n        d_model,\n        nhead,\n        dim_feedforward=2048,\n        dropout=0.1,\n        activation=\"relu\",\n        normalize_before=False,\n    ):\n        super().__init__()\n        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)\n        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)\n        # Implementation of 
Feedforward model\n        self.linear1 = nn.Linear(d_model, dim_feedforward)\n        self.dropout = nn.Dropout(dropout)\n        self.linear2 = nn.Linear(dim_feedforward, d_model)\n\n        self.norm1 = nn.LayerNorm(d_model)\n        self.norm2 = nn.LayerNorm(d_model)\n        self.norm3 = nn.LayerNorm(d_model)\n        self.dropout1 = nn.Dropout(dropout)\n        self.dropout2 = nn.Dropout(dropout)\n        self.dropout3 = nn.Dropout(dropout)\n\n        self.activation = _get_activation_fn(activation)\n        self.normalize_before = normalize_before\n\n    def with_pos_embed(self, tensor, pos: Optional[Tensor]):\n        return tensor if pos is None else tensor + pos\n\n    def forward_post(\n        self,\n        tgt,\n        memory,\n        tgt_mask: Optional[Tensor] = None,\n        memory_mask: Optional[Tensor] = None,\n        tgt_key_padding_mask: Optional[Tensor] = None,\n        memory_key_padding_mask: Optional[Tensor] = None,\n        pos: Optional[Tensor] = None,\n        query_pos: Optional[Tensor] = None,\n    ):\n        q = k = self.with_pos_embed(tgt, query_pos)\n        tgt2 = self.self_attn(\n            q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask\n        )[0]\n        tgt = tgt + self.dropout1(tgt2)\n        tgt = self.norm1(tgt)\n        tgt2 = self.multihead_attn(\n            query=self.with_pos_embed(tgt, query_pos),\n            key=self.with_pos_embed(memory, pos),\n            value=memory,\n            attn_mask=memory_mask,\n            key_padding_mask=memory_key_padding_mask,\n        )[0]\n        tgt = tgt + self.dropout2(tgt2)\n        tgt = self.norm2(tgt)\n        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))\n        tgt = tgt + self.dropout3(tgt2)\n        tgt = self.norm3(tgt)\n        return tgt\n\n    def forward_pre(\n        self,\n        tgt,\n        memory,\n        tgt_mask: Optional[Tensor] = None,\n        memory_mask: Optional[Tensor] = None,\n        
tgt_key_padding_mask: Optional[Tensor] = None,\n        memory_key_padding_mask: Optional[Tensor] = None,\n        pos: Optional[Tensor] = None,\n        query_pos: Optional[Tensor] = None,\n    ):\n        tgt2 = self.norm1(tgt)\n        q = k = self.with_pos_embed(tgt2, query_pos)\n        tgt2 = self.self_attn(\n            q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask\n        )[0]\n        tgt = tgt + self.dropout1(tgt2)\n        tgt2 = self.norm2(tgt)\n        tgt2 = self.multihead_attn(\n            query=self.with_pos_embed(tgt2, query_pos),\n            key=self.with_pos_embed(memory, pos),\n            value=memory,\n            attn_mask=memory_mask,\n            key_padding_mask=memory_key_padding_mask,\n        )[0]\n        tgt = tgt + self.dropout2(tgt2)\n        tgt2 = self.norm3(tgt)\n        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))\n        tgt = tgt + self.dropout3(tgt2)\n        return tgt\n\n    def forward(\n        self,\n        tgt,\n        memory,\n        tgt_mask: Optional[Tensor] = None,\n        memory_mask: Optional[Tensor] = None,\n        tgt_key_padding_mask: Optional[Tensor] = None,\n        memory_key_padding_mask: Optional[Tensor] = None,\n        pos: Optional[Tensor] = None,\n        query_pos: Optional[Tensor] = None,\n    ):\n        if self.normalize_before:\n            return self.forward_pre(\n                tgt,\n                memory,\n                tgt_mask,\n                memory_mask,\n                tgt_key_padding_mask,\n                memory_key_padding_mask,\n                pos,\n                query_pos,\n            )\n        return self.forward_post(\n            tgt,\n            memory,\n            tgt_mask,\n            memory_mask,\n            tgt_key_padding_mask,\n            memory_key_padding_mask,\n            pos,\n            query_pos,\n        )\n\n\ndef _get_clones(module, N):\n    return 
nn.ModuleList([copy.deepcopy(module) for i in range(N)])\n\n\ndef _get_activation_fn(activation):\n    \"\"\"Return an activation function given a string\"\"\"\n    if activation == \"relu\":\n        return F.relu\n    if activation == \"gelu\":\n        return F.gelu\n    if activation == \"glu\":\n        return F.glu\n    raise RuntimeError(f\"activation should be relu/gelu/glu, not {activation}.\")\n"
  },
  {
    "path": "llava/model/semsam/language/LangEncoder/__init__.py",
    "content": "from __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nfrom .build import build_lang_encoder\nfrom .build import build_tokenizer\n\nfrom .transformer import *"
  },
  {
    "path": "llava/model/semsam/language/LangEncoder/build.py",
    "content": "import os\n\nfrom transformers import CLIPTokenizer, CLIPTokenizerFast\nfrom transformers import AutoTokenizer\n\nfrom .registry import lang_encoders\nfrom .registry import is_lang_encoder\n\n\ndef build_lang_encoder(config_encoder, tokenizer, verbose, **kwargs):\n    model_name = config_encoder['NAME']\n\n    if not is_lang_encoder(model_name):\n        raise ValueError(f'Unknown model: {model_name}')\n\n    return lang_encoders(model_name)(config_encoder, tokenizer, verbose, **kwargs)\n\n\ndef build_tokenizer(config_encoder):\n    tokenizer = None\n    os.environ['TOKENIZERS_PARALLELISM'] = 'true'\n    if config_encoder['TOKENIZER'] == 'clip':\n        pretrained_tokenizer = config_encoder.get(\n            'PRETRAINED_TOKENIZER', 'openai/clip-vit-base-patch32'\n        )\n        tokenizer = CLIPTokenizer.from_pretrained(pretrained_tokenizer)\n        tokenizer.add_special_tokens({'cls_token': tokenizer.eos_token})\n    elif config_encoder['TOKENIZER'] == 'clip-fast':\n        pretrained_tokenizer = config_encoder.get(\n            'PRETRAINED_TOKENIZER', 'openai/clip-vit-base-patch32'\n        )\n        tokenizer = CLIPTokenizerFast.from_pretrained(pretrained_tokenizer, from_slow=True)\n    else:\n        tokenizer = AutoTokenizer.from_pretrained(config_encoder['TOKENIZER'])\n\n    return tokenizer\n"
  },
  {
    "path": "llava/model/semsam/language/LangEncoder/registry.py",
    "content": "_lang_encoders = {}\n\n\ndef register_lang_encoder(fn):\n    module_name_split = fn.__module__.split('.')\n    model_name = module_name_split[-1]\n\n    _lang_encoders[model_name] = fn\n\n    return fn\n\n\ndef lang_encoders(model_name):\n    return _lang_encoders[model_name]\n\n\ndef is_lang_encoder(model_name):\n    return model_name in _lang_encoders\n"
  },
  {
    "path": "llava/model/semsam/language/LangEncoder/transformer.py",
    "content": "from collections import OrderedDict\nfrom typing import Tuple, Union\nimport logging\nimport os\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\n\nfrom timm.models.layers import DropPath, trunc_normal_\n\nfrom .registry import register_lang_encoder\nfrom detectron2.utils.comm import is_main_process\nfrom utils.model import register_norm_module\n\nlogger = logging.getLogger(__name__)\n\n\n@register_norm_module\nclass LayerNorm(nn.Module):\n    def __init__(self, hidden_size, eps=1e-12):\n        \"\"\"Construct a layernorm module in the TF style (epsilon inside the square root).\n        \"\"\"\n        super(LayerNorm, self).__init__()\n        self.weight = nn.Parameter(torch.ones(hidden_size))\n        self.bias = nn.Parameter(torch.zeros(hidden_size))\n        self.variance_epsilon = eps\n\n    def forward(self, x):\n        pdtype = x.dtype\n        x = x.float()\n        u = x.mean(-1, keepdim=True)\n        s = (x - u).pow(2).mean(-1, keepdim=True)\n        x = (x - u) / torch.sqrt(s + self.variance_epsilon)\n        return self.weight * x.to(pdtype) + self.bias\n\n\nclass QuickGELU(nn.Module):\n    def forward(self, x: torch.Tensor):\n        return x * torch.sigmoid(1.702 * x)\n\n\nclass ResidualAttentionBlock(nn.Module):\n    def __init__(self,\n                 d_model: int,\n                 n_head: int,\n                 attn_mask: torch.Tensor = None,\n                 drop_path: float = 0.0):\n        super().__init__()\n\n        self.attn = nn.MultiheadAttention(d_model, n_head)\n        self.ln_1 = LayerNorm(d_model)\n        self.mlp = nn.Sequential(OrderedDict([\n            (\"c_fc\", nn.Linear(d_model, d_model * 4)),\n            (\"gelu\", QuickGELU()),\n            (\"c_proj\", nn.Linear(d_model * 4, d_model))\n        ]))\n        self.ln_2 = LayerNorm(d_model)\n        self.attn_mask = attn_mask\n        self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity()\n\n    def attention(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None):\n        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) \\\n            if self.attn_mask is not None else None\n\n\n        return self.attn(\n            x, x, x,\n            key_padding_mask=key_padding_mask,\n            need_weights=False,\n            attn_mask=self.attn_mask\n        )[0]\n\n    def forward(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None):\n        x = x + self.drop_path(self.attention(self.ln_1(x), key_padding_mask=key_padding_mask))\n        x = x + self.drop_path(self.mlp(self.ln_2(x)))\n        return x\n\n\nclass Transformer(nn.Module):\n    def __init__(self,\n                 context_length: int,\n                 vocab_size: int,\n                 width: int,\n                 layers: int,\n                 heads: int,\n                 drop_path: float = 0.0,\n                 autogressive: bool =True):\n        super().__init__()\n\n        self.token_embedding = nn.Embedding(vocab_size, width)\n\n        self.context_length = context_length\n        self.positional_embedding = nn.Parameter(\n            torch.empty(self.context_length, width)\n        )\n\n        self.width = width\n        self.layers = layers\n        self.autogressive = autogressive\n        attn_mask = self.build_attention_mask() if autogressive else None\n        dpr = [x.item() for x in torch.linspace(0, drop_path, layers)]  # stochastic depth decay rule\n        self.resblocks = nn.ModuleList(\n            [\n                ResidualAttentionBlock(width, heads, attn_mask, dpr[i])\n                for i in range(layers)\n            ]\n        )\n\n        self.ln_final = LayerNorm(width)\n\n        trunc_normal_(self.positional_embedding, std=.02)\n        # nn.init.normal_(self.token_embedding, std=.02)\n        trunc_normal_(self.token_embedding.weight, std=.02)\n        self.apply(self._init_weights)\n\n    @property\n 
   def dim_out(self):\n        return self.width\n\n    def build_attention_mask(self):\n        # lazily create causal attention mask, with full attention between the vision tokens\n        # pytorch uses additive attention mask; fill with -inf\n        mask = torch.empty(self.context_length, self.context_length)\n        mask.fill_(float(\"-inf\"))\n        mask.triu_(1)  # zero out the lower diagonal\n        return mask\n\n    def _init_weights(self, m):\n        if isinstance(m, (nn.Linear, nn.Conv2d)):\n            if is_main_process():\n                logger.info('=> init weight of Linear/Conv2d from trunc norm')\n            trunc_normal_(m.weight, std=0.02)\n            if m.bias is not None:\n                if is_main_process():\n                    logger.info('=> init bias of Linear/Conv2d to zeros')\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):\n            nn.init.constant_(m.bias, 0)\n\n    def load_pretrained(self, pretrained='', pretrained_layers=[], verbose=True):\n        if os.path.isfile(pretrained):\n            pretrained_dict = torch.load(pretrained, map_location='cpu')\n            logging.info(f'=> loading pretrained model {pretrained}')\n            model_dict = self.state_dict()\n            stripped_key = lambda x: x[13:] if x.startswith('lang_encoder.') else x\n            pretrained_dict = {\n                stripped_key(k): v for k, v in pretrained_dict.items()\n                if stripped_key(k) in model_dict.keys()\n            }\n            need_init_state_dict = {}\n            for k, v in pretrained_dict.items():\n                need_init = (\n                    k.split('.')[0] in pretrained_layers\n                    or pretrained_layers[0] == '*'\n                )\n                if need_init:\n                    if verbose:\n                        logger.info(f'=> init {k} from {pretrained}')\n\n                    if 'positional_embedding' in k and 
v.size() != model_dict[k].size():\n                        positional_embedding_pretrained = v\n                        positional_embedding_current = model_dict[k]\n                        L1, nH1 = positional_embedding_pretrained.size()\n                        L2, nH2 = positional_embedding_current.size()\n                        if nH1 != nH2:\n                            logger.info(f\"Error in loading {k}, passing\")\n                        else:\n                            if L1 != L2:\n                                logger.info(\n                                    '=> load_pretrained: resized variant: {} to {}'\n                                        .format((L1, nH1), (L2, nH2))\n                                )\n\n                                posemb = positional_embedding_pretrained.float()\n                                posemb_grid = posemb.unsqueeze(dim=0).permute(0, 2, 1)\n                                posemb_grid = torch.nn.functional.interpolate(posemb_grid, size=L2, mode='linear')\n                                posemb_grid = posemb_grid.permute(0, 2, 1).squeeze(dim=0)\n                                v = posemb_grid\n\n                    need_init_state_dict[k] = v\n\n            self.load_state_dict(need_init_state_dict, strict=False)\n\n\n    @torch.jit.ignore\n    def no_weight_decay(self):\n        return {\n            'positional_embedding',\n            'token_embedding',\n        }\n\n    def forward(self, input_ids, attention_mask=None):\n        key_padding_mask = (attention_mask == 0) if (not self.autogressive and attention_mask is not None) else None\n        # key_padding_mask = (input_ids == 0) if not self.autogressive else None\n        x = self.token_embedding(input_ids)  # [batch_size, n_ctx, d_model]\n        x = x + self.positional_embedding\n        x = x.permute(1, 0, 2)  # NLD -> LND\n        for block in self.resblocks:\n            x = block(x, key_padding_mask)\n        x = x.permute(1, 0, 2)  # LND -> 
NLD\n\n        x = self.ln_final(x)\n\n        return {'last_hidden_state': x}\n\n\n@register_lang_encoder\ndef lang_encoder(config_encoder, tokenizer, verbose, **kwargs):\n    transformer = Transformer(\n        context_length=config_encoder['CONTEXT_LENGTH'],\n        vocab_size=tokenizer.vocab_size,\n        width=config_encoder['WIDTH'],\n        layers=config_encoder['LAYERS'],\n        heads=config_encoder['HEADS'],\n        autogressive=config_encoder.get('AUTOGRESSIVE', True)\n    )\n\n    if config_encoder.get('LOAD_PRETRAINED', False):\n        transformer.load_pretrained(config_encoder['PRETRAINED'], config_encoder.get('PRETRAINED_LAYERS', ['*']))\n    return transformer\n"
  },
  {
    "path": "llava/model/semsam/language/__init__.py",
    "content": "# from .vlpencoder import *\n# from .encoder import *\n# from .fixencoder import *\n# from .loss import *\n# from .modeling_llama_os import LlamaForCausalLM\n# # from .modeling_llama_os_lora import LlamaForCausalLMLora\n# from .llama_encoder import *\n# from .build import build_language_encoder\n"
  },
  {
    "path": "llava/model/semsam/language/build.py",
    "content": "from .registry import model_entrypoints\nfrom .registry import is_model\n\n\ndef build_language_encoder(config, **kwargs):\n    model_name = config['MODEL']['TEXT']['ARCH']\n    if model_name=='noencoder':\n        return None\n\n    if not is_model(model_name):\n        raise ValueError(f'Unknown model: {model_name}')\n\n    return model_entrypoints(model_name)(config, **kwargs)\n"
  },
  {
    "path": "llava/model/semsam/language/encoder.py",
    "content": "import torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nfrom timm.models.layers import trunc_normal_\n\nfrom .registry import register_model\nfrom ..utils import configurable\nfrom .LangEncoder import build_tokenizer, build_lang_encoder\nfrom utils.prompt_engineering import prompt_engineering, get_prompt_templates\n\n\nclass LanguageEncoder(nn.Module):\n\n    @configurable\n    def __init__(\n        self,\n        tokenizer,\n        tokenizer_type,\n        lang_encoder,\n        lang_projection,\n        max_token_num,\n    ):\n        super().__init__()\n        self.tokenizer = tokenizer\n        self.tokenizer_type = tokenizer_type\n        self.lang_encoder = lang_encoder\n        self.lang_proj = lang_projection\n        self.max_token_num = max_token_num\n        self.logit_scale = nn.Parameter(torch.ones([]))\n\n    @classmethod\n    def from_config(cls, cfg):\n        # build up text encoder\n        tokenizer = build_tokenizer(cfg['MODEL']['TEXT'])\n        tokenizer_type = cfg['MODEL']['TEXT']['TOKENIZER']\n        lang_encoder = build_lang_encoder(cfg['MODEL']['TEXT'], tokenizer, cfg['VERBOSE'])\n        max_token_num = cfg['MODEL']['TEXT']['CONTEXT_LENGTH']\n        \n        dim_lang = cfg['MODEL']['TEXT']['WIDTH']\n        dim_projection = cfg['MODEL']['DIM_PROJ']\n        lang_projection = nn.Parameter(torch.empty(dim_lang, dim_projection))\n        trunc_normal_(lang_projection, std=.02)\n\n        return {\n            \"tokenizer\": tokenizer,\n            \"tokenizer_type\": tokenizer_type,\n            \"lang_encoder\": lang_encoder,\n            \"lang_projection\": lang_projection,\n            \"max_token_num\": max_token_num,\n        }\n\n    # @torch.no_grad()\n    def get_text_embeddings(self, class_names, name='default', is_eval=False, add_bgd=False, prompt=True, norm=True):\n        if not is_eval:\n            if prompt:\n                # randomly sample one template\n                
arbitary_concepts = [\n                    prompt_engineering(class_names[label].replace('-other','').replace('-merged','').replace('-stuff',''), topk=10000, suffix='.') \\\n                    for label in range(len(class_names))\n                ]\n                if add_bgd:\n                    arbitary_concepts.append(\"A background in coco.\")\n            else:\n                arbitary_concepts = class_names\n            \n            input_ids = []\n            attention_masks = []\n            for txt in arbitary_concepts:\n                tokens = self.tokenizer(\n                    txt, padding='max_length', truncation=True, max_length=self.max_token_num, return_tensors='pt'\n                )\n                tokens['input_ids'].squeeze_()\n                tokens['attention_mask'].squeeze_()\n\n                input_ids.append(tokens['input_ids'])\n                attention_masks.append(tokens['attention_mask'])\n\n            arbitary_tokens = torch.stack(input_ids)\n            arbitary_attention_masks = torch.stack(attention_masks)\n\n            text_emb = self.forward_language((arbitary_tokens.cuda(), arbitary_attention_masks.cuda()), norm=norm)\n            setattr(self, '{}_text_embeddings'.format(name), text_emb)\n        else:\n            with torch.no_grad():\n                def extract_mean_emb(txts):\n                    tokens = self.tokenizer(\n                        txts, padding='max_length', truncation=True, max_length=self.max_token_num, return_tensors='pt'\n                    )\n                    clss_embedding = self.forward_language((tokens['input_ids'].cuda(), tokens['attention_mask'].cuda()), norm=norm)\n                    clss_embedding = clss_embedding.mean(dim=0)\n                    clss_embedding /= clss_embedding.norm()\n                    return clss_embedding\n\n                templates = get_prompt_templates()\n                clss_embeddings = []\n                for clss in class_names:\n                    
txts = [template.format(clss.replace('-other','').replace('-merged','').replace('-stuff','')) for template in templates]\n                    clss_embeddings.append(extract_mean_emb(txts))\n\n                if add_bgd:\n                    txts = [\"A background in coco.\"]\n                    clss_embeddings.append(extract_mean_emb(txts))\n\n                text_emb = torch.stack(clss_embeddings, dim=0)\n                setattr(self, '{}_text_embeddings'.format(name), text_emb)\n\n    # @torch.no_grad()\n    def forward_language(self, texts, norm=True):\n        x = self.lang_encoder(*texts)\n        x = x['last_hidden_state']\n\n        if self.tokenizer_type == 'clip':\n            x = x[torch.arange(x.size(0)), texts[0].argmax(dim=-1)]\n        else:\n            x = x[:, 0]\n\n        x = x @ self.lang_proj\n        if norm:\n            x = x / (x.norm(dim=-1, keepdim=True) + 1e-7)\n        return x\n    \n    def compute_similarity(self, v_emb, name='default'):\n        v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)\n        t_emb = getattr(self, '{}_text_embeddings'.format(name))\n        output = self.logit_scale.exp() * v_emb @ t_emb.unsqueeze(0).transpose(1, 2)\n        return output\n\n\n@register_model\ndef get_language_model(cfg, **kwargs):\n    return LanguageEncoder(cfg)"
  },
  {
    "path": "llava/model/semsam/language/fixencoder.py",
    "content": "import torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nfrom timm.models.layers import trunc_normal_\n\nfrom .registry import register_model\nfrom ..utils import configurable\nfrom .LangEncoder import build_tokenizer, build_lang_encoder\nfrom utils.prompt_engineering import prompt_engineering, get_prompt_templates\n\n\nclass LanguageEncoder(nn.Module):\n\n    @configurable\n    def __init__(\n        self,\n        tokenizer,\n        tokenizer_type,\n        lang_encoder,\n        lang_projection,\n        max_token_num,\n    ):\n        super().__init__()\n        self.tokenizer = tokenizer\n        self.tokenizer_type = tokenizer_type\n        self.lang_encoder = lang_encoder\n        self.lang_proj = lang_projection\n        self.max_token_num = max_token_num\n        self.logit_scale = nn.Parameter(torch.ones([]))\n\n    @classmethod\n    def from_config(cls, cfg):\n        # build up text encoder\n        tokenizer = build_tokenizer(cfg['MODEL']['TEXT'])\n        tokenizer_type = cfg['MODEL']['TEXT']['TOKENIZER']\n        lang_encoder = build_lang_encoder(cfg['MODEL']['TEXT'], tokenizer, cfg['VERBOSE'])\n        max_token_num = cfg['MODEL']['TEXT']['CONTEXT_LENGTH']\n        \n        dim_lang = cfg['MODEL']['TEXT']['WIDTH']\n        dim_projection = cfg['MODEL']['DIM_PROJ']\n        lang_projection = nn.Parameter(torch.empty(dim_lang, dim_projection))\n        trunc_normal_(lang_projection, std=.02)\n\n        return {\n            \"tokenizer\": tokenizer,\n            \"tokenizer_type\": tokenizer_type,\n            \"lang_encoder\": lang_encoder,\n            \"lang_projection\": lang_projection,\n            \"max_token_num\": max_token_num,\n        }\n\n    @torch.no_grad()\n    def get_text_embeddings(self, class_names, name='default', is_eval=False, add_bgd=False, prompt=True, norm=True):\n        if not is_eval:\n            if prompt:\n                # randomly sample one template\n                
arbitary_concepts = [\n                    prompt_engineering(class_names[label].replace('-other','').replace('-merged','').replace('-stuff',''), topk=10000, suffix='.') \\\n                    for label in range(len(class_names))\n                ]\n                if add_bgd:\n                    arbitary_concepts.append(\"A background in coco.\")\n            else:\n                arbitary_concepts = class_names\n            \n            input_ids = []\n            attention_masks = []\n            for txt in arbitary_concepts:\n                tokens = self.tokenizer(\n                    txt, padding='max_length', truncation=True, max_length=self.max_token_num, return_tensors='pt'\n                )\n                tokens['input_ids'].squeeze_()\n                tokens['attention_mask'].squeeze_()\n\n                input_ids.append(tokens['input_ids'])\n                attention_masks.append(tokens['attention_mask'])\n\n            arbitary_tokens = torch.stack(input_ids)\n            arbitary_attention_masks = torch.stack(attention_masks)\n\n            text_emb = self.forward_language((arbitary_tokens.cuda(), arbitary_attention_masks.cuda()), norm=norm)\n            setattr(self, '{}_text_embeddings'.format(name), text_emb)\n        else:\n            with torch.no_grad():\n                def extract_mean_emb(txts):\n                    tokens = self.tokenizer(\n                        txts, padding='max_length', truncation=True, max_length=self.max_token_num, return_tensors='pt'\n                    )\n                    clss_embedding = self.forward_language((tokens['input_ids'].cuda(), tokens['attention_mask'].cuda()), norm=norm)\n                    clss_embedding = clss_embedding.mean(dim=0)\n                    clss_embedding /= clss_embedding.norm()\n                    return clss_embedding\n\n                templates = get_prompt_templates()\n                clss_embeddings = []\n                for clss in class_names:\n                    
txts = [template.format(clss.replace('-other','').replace('-merged','').replace('-stuff','')) for template in templates]\n                    clss_embeddings.append(extract_mean_emb(txts))\n\n                if add_bgd:\n                    txts = [\"A background in coco.\"]\n                    clss_embeddings.append(extract_mean_emb(txts))\n\n                text_emb = torch.stack(clss_embeddings, dim=0)\n                setattr(self, '{}_text_embeddings'.format(name), text_emb)\n\n    @torch.no_grad()\n    def forward_language(self, texts, norm=True):\n        x = self.lang_encoder(*texts)\n        x = x['last_hidden_state']\n\n        if self.tokenizer_type == 'clip':\n            x = x[torch.arange(x.size(0)), texts[0].argmax(dim=-1)]\n        else:\n            x = x[:, 0]\n\n        x = x @ self.lang_proj\n        if norm:\n            x = x / (x.norm(dim=-1, keepdim=True) + 1e-7)\n        return x\n    \n    @torch.no_grad() # FIXME hack to freeze all parameters\n    def compute_similarity(self, v_emb, name='default'):\n        v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)\n        t_emb = getattr(self, '{}_text_embeddings'.format(name))\n        output = self.logit_scale.exp() * v_emb @ t_emb.unsqueeze(0).transpose(1, 2)\n        return output\n\n\n@register_model\ndef get_language_model(cfg, **kwargs):\n    return LanguageEncoder(cfg)"
  },
  {
    "path": "llava/model/semsam/language/llama_encoder.py",
    "content": "# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:\n# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:\n#    Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li\n#\n#    Licensed under the Apache License, Version 2.0 (the \"License\");\n#    you may not use this file except in compliance with the License.\n#    You may obtain a copy of the License at\n#\n#        http://www.apache.org/licenses/LICENSE-2.0\n#\n#    Unless required by applicable law or agreed to in writing, software\n#    distributed under the License is distributed on an \"AS IS\" BASIS,\n#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#    See the License for the specific language governing permissions and\n#    limitations under the License.\n\nimport os\nimport copy\nfrom dataclasses import dataclass, field\nimport json\nimport logging\nimport pathlib\nfrom typing import Dict, Optional, Sequence\n\nimport torch\n\nimport transformers\nfrom torch.utils.data import Dataset\nfrom transformers import Trainer\n\nfrom llava import conversation as conversation_lib\n\nfrom PIL import Image\nimport torch.nn as nn\n# from openseed.BaseModel import BaseModel\n# from openseed import build_model\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nfrom timm.models.layers import trunc_normal_\n\nfrom .registry import register_model\nfrom ..utils import configurable\nfrom .LangEncoder import build_tokenizer, build_lang_encoder\nfrom utils.prompt_engineering import prompt_engineering, get_prompt_templates\nfrom openseed.language import LlamaForCausalLM\n\n# TODO: import and use code from ../data/dataset.py\n\nIGNORE_INDEX = -100\nDEFAULT_PAD_TOKEN = \"[PAD]\"\nDEFAULT_EOS_TOKEN = \"</s>\"\nDEFAULT_BOS_TOKEN = \"</s>\"\nDEFAULT_UNK_TOKEN = \"<unk>\"\nDEFAULT_IMAGE_TOKEN = \"<image>\"\nDEFAULT_IMAGE_PATCH_TOKEN = \"<im_patch>\"\nDEFAULT_IM_START_TOKEN = 
\"<im_start>\"\nDEFAULT_IM_END_TOKEN = \"<im_end>\"\nDEFAULT_OBJECT_START_TOKEN = \"<obj_start>\"\nDEFAULT_OBJECT_END_TOKEN = \"<obj_end>\"\nENC_LENS=[140*64,140*16,140*4,140]\nENC_ID=-1\n@dataclass\nclass ModelArguments:\n    model_name_or_path: Optional[str] = field(default=\"facebook/opt-125m\")\n    freeze_backbone: bool = field(default=False)\n    dbg: bool = field(default=False)\n    tune_mm_mlp_adapter: bool = field(default=False)\n    config_file: Optional[str] = field(default=\"\")\n    os_weights:Optional[str]=field(default=\"\")\n    vision_tower: Optional[str] = field(default=None)\n    mm_vision_select_layer: Optional[int] = field(default=-1)   # default to the last layer\n    pretrain_mm_mlp_adapter: Optional[str] = field(default=None)\n    pretrain_obj_mlp_adapter: Optional[str] = field(default=None)\n    mm_use_im_start_end: bool = field(default=False)\n\n\n@dataclass\nclass DataArguments:\n    data_path: str = field(default=None,\n                           metadata={\"help\": \"Path to the training data.\"})\n    lazy_preprocess: bool = False\n    is_multimodal: bool = False\n    image_token_len: int = 0\n    image_folder: Optional[str] = field(default=None)\n    image_aspect_ratio: str = 'square'\n\n\n@dataclass\nclass TrainingArguments(transformers.TrainingArguments):\n    cache_dir: Optional[str] = field(default=None)\n    optim: str = field(default=\"adamw_torch\")\n    remove_unused_columns: bool = field(default=False)\n    # dbg: bool = field(default=False)\n    model_max_length: int = field(\n        default=512,\n        metadata={\n            \"help\":\n            \"Maximum sequence length. 
Sequences will be right padded (and possibly truncated).\"\n        },\n    )\n\n\ndef safe_save_model_for_hf_trainer(trainer: transformers.Trainer,\n                                   output_dir: str):\n    \"\"\"Collects the state dict and dump to disk.\"\"\"\n    state_dict = trainer.model.state_dict()\n    if trainer.args.should_save:\n        cpu_state_dict = {\n            key: value.cpu()\n            for key, value in state_dict.items()\n        }\n        del state_dict\n        trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa\n\n\ndef smart_tokenizer_and_embedding_resize(\n    special_tokens_dict: Dict,\n    tokenizer: transformers.PreTrainedTokenizer,\n    model: transformers.PreTrainedModel,\n):\n    \"\"\"Resize tokenizer and embedding.\n\n    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.\n    \"\"\"\n    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)\n    model.resize_token_embeddings(len(tokenizer))\n\n    if num_new_tokens > 0:\n        input_embeddings = model.get_input_embeddings().weight.data\n        output_embeddings = model.get_output_embeddings().weight.data\n\n        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(\n            dim=0, keepdim=True)\n        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(\n            dim=0, keepdim=True)\n\n        input_embeddings[-num_new_tokens:] = input_embeddings_avg\n        output_embeddings[-num_new_tokens:] = output_embeddings_avg\n\n\ndef _tokenize_fn(strings: Sequence[str],\n                 tokenizer: transformers.PreTrainedTokenizer) -> Dict:\n    \"\"\"Tokenize a list of strings.\"\"\"\n    tokenized_list = [\n        tokenizer(\n            text,\n            return_tensors=\"pt\",\n            padding=\"longest\",\n            max_length=tokenizer.model_max_length,\n            truncation=True,\n        ) for text in strings\n    ]\n    input_ids = labels = [\n        
tokenized.input_ids[0] for tokenized in tokenized_list\n    ]\n    input_ids_lens = labels_lens = [\n        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()\n        for tokenized in tokenized_list\n    ]\n    return dict(\n        input_ids=input_ids,\n        labels=labels,\n        input_ids_lens=input_ids_lens,\n        labels_lens=labels_lens,\n    )\n\n\ndef _mask_targets(target, tokenized_lens, speakers):\n    # cur_idx = 0\n    cur_idx = tokenized_lens[0]\n    tokenized_lens = tokenized_lens[1:]\n    target[:cur_idx] = IGNORE_INDEX\n    for tokenized_len, speaker in zip(tokenized_lens, speakers):\n        if speaker == \"human\":\n            target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX\n        cur_idx += tokenized_len\n\n\ndef _add_speaker_and_signal(header, source, get_conversation=True):\n    \"\"\"Add speaker and start/end signal on each round.\"\"\"\n    BEGIN_SIGNAL = \"### \"\n    END_SIGNAL = \"\\n\"\n    conversation = header\n    for sentence in source:\n        from_str = sentence[\"from\"]\n        if from_str.lower() == \"human\":\n            from_str = conversation_lib.default_conversation.roles[0]\n        elif from_str.lower() == \"gpt\":\n            from_str = conversation_lib.default_conversation.roles[1]\n        else:\n            from_str = 'unknown'\n        sentence[\"value\"] = (BEGIN_SIGNAL + from_str + \": \" +\n                             sentence[\"value\"] + END_SIGNAL)\n        if get_conversation:\n            conversation += sentence[\"value\"]\n    conversation += BEGIN_SIGNAL\n    return conversation\n\n\ndef preprocess_multimodal(\n    sources: Sequence[str],\n    multimodal_cfg: dict,\n    cur_token_len: int,\n) -> Dict:\n    is_multimodal = multimodal_cfg['is_multimodal']\n    # image_token_len = multimodal_cfg['image_token_len']\n    image_token_len = cur_token_len\n    if not is_multimodal:\n        return sources\n\n    for source in sources:\n        for sentence in source:\n            
replace_token = DEFAULT_IMAGE_PATCH_TOKEN * image_token_len\n            if multimodal_cfg['use_im_start_end']:\n                replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN\n            sentence[\"value\"] = sentence[\"value\"].replace(DEFAULT_IMAGE_TOKEN, replace_token)\n\n    return sources\n\n\ndef preprocess(\n    sources: Sequence[str],\n    tokenizer: transformers.PreTrainedTokenizer,\n) -> Dict:\n    \"\"\"\n    Given a list of sources, each is a conversation list. This transform:\n    1. Add signal '### ' at the beginning each sentence, with end signal '\\n';\n    2. Concatenate conversations together;\n    3. Tokenize the concatenated conversation;\n    4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.\n    \"\"\"\n    # add end signal and concatenate together\n    conversations = []\n    for source in sources:\n        header = f\"{conversation_lib.default_conversation.system}\\n\\n\"\n        conversation = _add_speaker_and_signal(header, source)\n        conversations.append(conversation)\n    # tokenize conversations\n    conversations_tokenized = _tokenize_fn(conversations, tokenizer)\n    input_ids = conversations_tokenized[\"input_ids\"]\n    targets = copy.deepcopy(input_ids)\n    for target, source in zip(targets, sources):\n        tokenized_lens = _tokenize_fn([header] + [s[\"value\"] for s in source],\n                                      tokenizer)[\"input_ids_lens\"]\n        speakers = [sentence[\"from\"] for sentence in source]\n        _mask_targets(target, tokenized_lens, speakers)\n\n    return dict(input_ids=input_ids, labels=targets)\n\n\nclass SupervisedDataset(Dataset):\n    \"\"\"Dataset for supervised fine-tuning.\"\"\"\n\n    def __init__(self, data_path: str,\n                 tokenizer: transformers.PreTrainedTokenizer):\n        super(SupervisedDataset, self).__init__()\n        logging.warning(\"Loading data...\")\n        list_data_dict = json.load(open(data_path, 
\"r\"))\n\n        logging.warning(\"Formatting inputs...\")\n        sources = [example[\"conversations\"] for example in list_data_dict]\n        data_dict = preprocess(sources, tokenizer)\n\n        self.input_ids = data_dict[\"input_ids\"]\n        self.labels = data_dict[\"labels\"]\n\n    def __len__(self):\n        return len(self.input_ids)\n\n    def __getitem__(self, i) -> Dict[str, torch.Tensor]:\n        return dict(input_ids=self.input_ids[i], labels=self.labels[i])\n\n\nclass LazySupervisedDataset(Dataset):\n    \"\"\"Dataset for supervised fine-tuning.\"\"\"\n\n    def __init__(self, data_path: str,\n                 tokenizer: transformers.PreTrainedTokenizer,\n                 multimodal_cfg: dict):\n        super(LazySupervisedDataset, self).__init__()\n        logging.warning(\"Loading data...\")\n        list_data_dict = json.load(open(data_path, \"r\"))\n\n        logging.warning(\"Formatting inputs...Skip in lazy mode\")\n        self.tokenizer = tokenizer\n        self.list_data_dict = list_data_dict\n        self.multimodal_cfg = multimodal_cfg\n\n    def __len__(self):\n        return len(self.list_data_dict)\n\n    def __getitem__(self, i) -> Dict[str, torch.Tensor]:\n        sources = self.list_data_dict[i]\n        if isinstance(i, int):\n            sources = [sources]\n        assert len(sources) == 1, \"Don't know why it is wrapped to a list\"  # FIXME\n        try:\n            if 'image' in sources[0]:\n                image_file = self.list_data_dict[i]['image']\n                image_folder = self.multimodal_cfg['image_folder']\n                processor = self.multimodal_cfg['image_processor']\n                image = Image.open(os.path.join(image_folder, image_file))\n                if self.multimodal_cfg['image_aspect_ratio'] == 'keep':\n                    max_hw, min_hw = max(image.size), min(image.size)\n                    aspect_ratio = max_hw / min_hw\n                    max_len, min_len = 1333, 800\n                    
shortest_edge = int(min(max_len / aspect_ratio, min_len))\n                    try:\n                        image = processor.preprocess(image, return_tensors='pt', do_center_crop=False, size={\"shortest_edge\": shortest_edge}, do_rescale=False, do_normalize=False)['pixel_values'][0]\n                    except Exception:\n                        return self.__getitem__(i + 1)\n                else:\n                    # try:\n                    image = processor.preprocess(image, return_tensors='pt', do_rescale=False, do_normalize=False, do_center_crop=False,size=(640,64*14))[\n                            'pixel_values'][0]\n                    # except Exception:\n                    #     return self.__getitem__(i+1)\n                # FIXME: cur_token_len should be num_queries when using det\n                # cur_token_len = (image.shape[1]//14) * (image.shape[2]//14)   # FIXME: 14 is hardcoded patch size\n                cur_token_len = ENC_LENS[ENC_ID]\n                sources = preprocess_multimodal(\n                    copy.deepcopy([e[\"conversations\"] for e in sources]),\n                    self.multimodal_cfg, cur_token_len)\n            else:\n                sources = copy.deepcopy([e[\"conversations\"] for e in sources])\n        except Exception:\n            return self.__getitem__(i + 1)\n        data_dict = preprocess(\n            sources,\n            self.tokenizer)\n        if isinstance(i, int):\n            data_dict = dict(input_ids=data_dict[\"input_ids\"][0],\n                             labels=data_dict[\"labels\"][0])\n\n        # image exist in the data\n        if 'image' in self.list_data_dict[i]:\n            data_dict['image'] = image\n        elif self.multimodal_cfg['is_multimodal']:\n            # image does not exist in the data, but the model is multimodal\n            crop_size = self.multimodal_cfg['image_processor'].crop_size\n            data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width'])\n 
       return data_dict\n\n\n@dataclass\nclass DataCollatorForSupervisedDataset(object):\n    \"\"\"Collate examples for supervised fine-tuning.\"\"\"\n\n    tokenizer: transformers.PreTrainedTokenizer\n\n    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:\n        input_ids, labels = tuple([instance[key] for instance in instances]\n                                  for key in (\"input_ids\", \"labels\"))\n        input_ids = torch.nn.utils.rnn.pad_sequence(\n            input_ids,\n            batch_first=True,\n            padding_value=self.tokenizer.pad_token_id)\n        labels = torch.nn.utils.rnn.pad_sequence(labels,\n                                                 batch_first=True,\n                                                 padding_value=IGNORE_INDEX)\n        batch = dict(\n            input_ids=input_ids,\n            labels=labels,\n            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),\n        )\n\n        if 'image' in instances[0]:\n            images = [instance['image'] for instance in instances]\n            if all(x is not None and x.shape == images[0].shape for x in images):\n                batch['images'] = torch.stack(images)\n            else:\n                batch['images'] = images\n\n        return batch\n\n\ndef make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,\n                                data_args) -> Dict:\n    \"\"\"Make dataset and collator for supervised fine-tuning.\"\"\"\n    dataset_cls = (LazySupervisedDataset\n                   if data_args.lazy_preprocess else SupervisedDataset)\n    train_dataset = dataset_cls(tokenizer=tokenizer,\n                                data_path=data_args.data_path,\n                                multimodal_cfg=dict(\n                                    is_multimodal=data_args.is_multimodal,\n                                    image_token_len=data_args.image_token_len,\n                                    
image_folder=data_args.image_folder,\n                                    image_aspect_ratio=data_args.image_aspect_ratio,\n                                    use_im_start_end=getattr(data_args, 'mm_use_im_start_end', False),\n                                    image_processor=getattr(data_args, 'image_processor', None)))\n    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)\n    return dict(train_dataset=train_dataset,\n                eval_dataset=None,\n                data_collator=data_collator)\n\n# from detectron2.config import get_cfg, CfgNode\nfrom detectron2.config import LazyConfig, instantiate\n# from detectron2.utils.logger import setup_logger\n# from detectron2.engine import default_setup\ndef setup(config_file):\n    \"\"\"\n    Create configs and perform basic setups.\n    \"\"\"\n    cfg = LazyConfig.load(config_file)\n    # cfg = LazyConfig.apply_overrides(cfg, args.opts)\n    # cfg.freeze()\n    # default_setup(cfg, args)\n    # setup_logger(output=cfg.OUTPUT_DIR, distributed_rank=comm.get_rank(), name=\"maskdino\")\n    return cfg\n\n\n@register_model\ndef get_language_model(cfg, **kwargs):\n    llama_cfg = cfg['MODEL']['LLAMA']\n    if llama_cfg['load_fp16']:\n        return LlamaForCausalLM.from_pretrained(\n            llama_cfg['model_name_or_path'],\n            cache_dir=llama_cfg['cache_dir'],\n            torch_dtype=torch.float16\n        )\n    else:\n        return LlamaForCausalLM.from_pretrained(\n            llama_cfg['model_name_or_path'],\n            cache_dir=llama_cfg['cache_dir'],\n            # torch_dtype=torch.float16\n        )\n\ndef train():\n    parser = transformers.HfArgumentParser(\n        (ModelArguments, DataArguments, TrainingArguments))\n    model_args, data_args, training_args = parser.parse_args_into_dataclasses()\n    if model_args.dbg:\n        training_args._n_gpu = 1\n    ENC_ID=model_args.mm_vision_select_layer\n    model = LlamaForCausalLM.from_pretrained(\n        
model_args.model_name_or_path,\n        cache_dir=training_args.cache_dir,\n    )\n    cfg = setup(model_args.config_file)\n\n    if model_args.freeze_backbone:\n        model.model.requires_grad_(False)\n\n    tokenizer = transformers.AutoTokenizer.from_pretrained(\n        model_args.model_name_or_path,\n        cache_dir=training_args.cache_dir,\n        model_max_length=training_args.model_max_length,\n        padding_side=\"right\",\n        use_fast=False,\n    )\n    if tokenizer.pad_token is None:\n        smart_tokenizer_and_embedding_resize(\n            special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),\n            tokenizer=tokenizer,\n            model=model,\n        )\n    if \"llama\" in model_args.model_name_or_path:\n        tokenizer.add_special_tokens({\n            \"eos_token\": DEFAULT_EOS_TOKEN,\n            \"bos_token\": DEFAULT_BOS_TOKEN,\n            \"unk_token\": DEFAULT_UNK_TOKEN,\n        })\n    if model_args.vision_tower is not None:\n        model.config.mm_vision_tower = model_args.vision_tower\n\n        from transformers import CLIPVisionModel\n        from llava.train.image_processing_gptv import CLIPImageProcessor\n        dtype = torch.float32\n        if training_args.fp16:\n            dtype = torch.float16\n        if training_args.bf16:\n            dtype = torch.bfloat16\n        openseed_vision = BaseModel(cfg, build_model(cfg)).cuda()\n        # if not model_args.dbg:\n        checkpoint = torch.load(model_args.os_weights, map_location='cpu')\n        model_dict = openseed_vision.state_dict()\n        pretrained_dict = {\"model.\"+k: v for k, v in checkpoint.items() if \"model.\"+k in model_dict}\n        model_dict.update(pretrained_dict)\n        openseed_vision.load_state_dict(model_dict)\n        # openseed_vision.stat\n        if not hasattr(model.model, 'vision_tower'):\n            vision_tower = CLIPVisionModel.from_pretrained(model_args.vision_tower)\n        else:\n            vision_tower = 
model.model.vision_tower[0]\n\n        image_processor = CLIPImageProcessor.from_pretrained(model_args.vision_tower)\n        image_processor.size['shortest_edge'] = 800\n        vision_config = vision_tower.config\n        vision_tower=openseed_vision\n        vision_tower.config=vision_config\n        # vision_tower.num_queries=300\n        # vision_tower.idx=model_args.enc_idx\n\n        vision_tower.idx=ENC_ID\n        vision_tower.num_enc_tokens=ENC_LENS[vision_tower.idx]\n        vision_tower.dim_queries=256\n        # num_patches = (vision_config.image_size // vision_config.patch_size) ** 2\n        num_patches=vision_tower.num_enc_tokens\n        data_args.image_token_len = num_patches\n        data_args.image_processor = image_processor\n        data_args.is_multimodal = True\n\n        vision_tower.requires_grad_(False)\n        # model.model.vision_tower = vision_tower\n        # HACK: for FSDP\n        vision_tower.to(device=training_args.device)\n        model.model.vision_tower = [vision_tower]\n\n        model.config.use_mm_proj = True\n        model.config.mm_hidden_size = vision_config.hidden_size=vision_tower.dim_queries\n        model.config.mm_vision_select_layer = model_args.mm_vision_select_layer\n        if not hasattr(model.model, 'mm_projector') or model.model.mm_projector.weight.shape[1]!=vision_config.hidden_size:\n            mm_projector = nn.Linear(vision_config.hidden_size, model.config.hidden_size)\n            model.model.mm_projector = mm_projector\n        else:\n            mm_projector = model.model.mm_projector\n\n        if not hasattr(model.model, 'obj_projector'):\n            obj_projector = nn.Linear(vision_config.hidden_size+4, model.config.hidden_size)\n            model.model.obj_projector = obj_projector\n        else:\n            obj_projector = model.model.obj_projector\n\n        if model_args.pretrain_mm_mlp_adapter is not None:\n            mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, 
map_location='cpu')\n            mm_projector.load_state_dict({k.split('.')[-1]: v for k, v in mm_projector_weights.items()})\n\n        if model_args.pretrain_obj_mlp_adapter is not None:\n            obj_projector_weights = torch.load(model_args.pretrain_obj_mlp_adapter, map_location='cpu')\n            obj_projector.load_state_dict({k.split('.')[-1]: v for k, v in obj_projector_weights.items()})\n\n        model.config.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter\n        if model_args.tune_mm_mlp_adapter:\n            model.requires_grad_(False)\n            for p in mm_projector.parameters():\n                p.requires_grad = True\n            for p in obj_projector.parameters():\n                p.requires_grad = True\n        model.config.mm_use_im_start_end = model_args.mm_use_im_start_end\n        data_args.mm_use_im_start_end = model_args.mm_use_im_start_end\n        tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)\n        vision_config.use_im_start_end = model_args.mm_use_im_start_end\n        if model_args.mm_use_im_start_end:\n            num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_OBJECT_START_TOKEN,DEFAULT_OBJECT_END_TOKEN], special_tokens=True)\n            model.resize_token_embeddings(len(tokenizer))\n            vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])\n\n            if num_new_tokens > 0:\n                input_embeddings = model.get_input_embeddings().weight.data\n                output_embeddings = model.get_output_embeddings().weight.data\n\n                input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(\n                    dim=0, keepdim=True)\n                output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(\n                    dim=0, keepdim=True)\n\n                input_embeddings[-num_new_tokens:] = input_embeddings_avg\n       
         output_embeddings[-num_new_tokens:] = output_embeddings_avg\n\n            if model_args.tune_mm_mlp_adapter:\n                model.model.orig_embeds_params = [model.get_input_embeddings().weight.data.clone().to(device=training_args.device)]\n                for p in model.get_input_embeddings().parameters():\n                    p.requires_grad = True\n                for p in model.get_output_embeddings().parameters():\n                    p.requires_grad = False\n\n            if model_args.pretrain_mm_mlp_adapter:\n                mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')\n                embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']\n                assert input_embeddings.shape == embed_tokens_weight.shape\n                assert num_new_tokens == 2\n                input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]\n\n        vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]\n\n    data_module = make_supervised_data_module(tokenizer=tokenizer,\n                                              data_args=data_args)\n    trainer = Trainer(model=model,\n                    tokenizer=tokenizer,\n                    args=training_args,\n                    **data_module)\n\n    if list(pathlib.Path(training_args.output_dir).glob(\"checkpoint-*\")):\n        trainer.train(resume_from_checkpoint=True)\n    else:\n        trainer.train()\n    trainer.save_state()\n    safe_save_model_for_hf_trainer(trainer=trainer,\n                                   output_dir=training_args.output_dir)\n\n\nif __name__ == \"__main__\":\n    train()\n"
  },
  {
    "path": "llava/model/semsam/language/loss.py",
    "content": "import pickle\nfrom distutils import log\n\nimport torch\nimport torch.nn.functional as F\nimport torch.distributed as dist\n\nfrom einops import rearrange, repeat\nfrom timm.loss import SoftTargetCrossEntropy\n\nsoft_cross_entropy = SoftTargetCrossEntropy()\n\ndef is_dist_initialized():\n    return torch.distributed.is_initialized()\n\ndef get_world_size():\n    if is_dist_initialized():\n        return torch.distributed.get_world_size()\n    return 1\n\ndef get_rank():\n    if is_dist_initialized():\n        return dist.get_rank()\n    return 0\n\ndef all_gather_grad(x):\n    if get_world_size() > 1:\n        all_x = [torch.zeros_like(x) for _ in range(get_world_size())]\n        torch.distributed.all_gather(all_x, x)\n        all_x[torch.distributed.get_rank()] = x\n        x = torch.cat(all_x, dim=0)\n    return x\n\ndef vl_multilabel_contrastive_loss(image_feat, text_feat, temperature=1):\n    \"\"\"\n    Args:\n        image_feat (torch.Tensor): shape [B, L1, C] # B: batch_size, L1: 1, C: 256\n        text_feat (torch.Tensor): shape [B, L2, C] # B:batch_size, L2: number of selected nouns, C: 256\n\n    Returns:\n    \"\"\"\n    # [B, L1, C], L1 = 1\n    # image_feat = F.normalize(image_feat, dim=-1)\n    # [B, L2, C]\n    # text_feat = F.normalize(text_feat, dim=-1)\n    # HACK: normalize outside\n    \n    # [B, L1, L2]\n    dist_per_img = image_feat @ rearrange(text_feat, 'b l c -> b c l')    \n    # [B, L2, L1]\n    dist_per_text = text_feat @ rearrange(image_feat, 'b l c -> b c l')\n\n    batch = image_feat.shape[0]\n    img_len = image_feat.shape[1]\n    text_len = text_feat.shape[1]\n    # [B, L1, L2]\n    pos_labels_batch_img = rearrange(torch.ones_like(dist_per_text) / dist_per_text.size(1), 'b l2 l1 -> b l1 l2')\n    # [B, L2, L1]\n    pos_labels_batch_text = rearrange(torch.ones_like(dist_per_img) / dist_per_img.size(1), 'b l1 l2 -> b l2 l1')\n\n    image_x = rearrange(image_feat, 'b l c -> (b l) c')\n    text_x = 
rearrange(text_feat, 'b l c -> (b l) c')\n\n    logits_per_img = image_x @ all_gather_grad(text_x).t()\n    logits_per_text = text_x @ all_gather_grad(image_x).t()\n\n    # get label globally\n    # [B, L1, B, L2, W]\n    labels_per_img = F.one_hot(\n        torch.ones(batch, img_len, batch, text_len, dtype=torch.long, device=image_x.device) * get_rank(),\n        num_classes=get_world_size()).to(image_x.dtype)\n    labels_per_img *= rearrange(pos_labels_batch_img, 'b l1 l2 -> b l1 1 l2 1') * repeat(\n        torch.eye(batch, dtype=image_x.dtype, device=image_x.device), 'b1 b2 -> b1 1 b2 1 1')\n    # [BxL1, WxBxL2]\n    labels_per_img = rearrange(labels_per_img, 'b1 l1 b2 l2 w -> (b1 l1) (w b2 l2)')\n    # [B, L2, B, L1, W]\n    labels_per_text = F.one_hot(\n        torch.ones(batch, text_len, batch, img_len, dtype=torch.long, device=text_x.device) * get_rank(),\n        num_classes=get_world_size()).to(text_x.dtype)\n    labels_per_text *= rearrange(pos_labels_batch_text, 'b l2 l1 -> b l2 1 l1 1') * repeat(\n        torch.eye(batch, dtype=text_x.dtype, device=image_x.device), 'b2 b1 -> b2 1 b1 1 1')\n    # [BxL2, WxBxL1]\n    labels_per_text = rearrange(labels_per_text, 'b2 l2 b1 l1 w -> (b2 l2) (w b1 l1)')\n\n    logit_scale = temperature.exp().clamp(max=100)\n\n    loss_img = soft_cross_entropy(logit_scale * logits_per_img, labels_per_img)\n    loss_text = soft_cross_entropy(logit_scale * logits_per_text, labels_per_text)\n\n    loss = 0.5 * (loss_img + loss_text)\n    return loss\n\ndef vl_contrastive_loss(image_feat, text_feat, temperature=1):\n    # if image_id or text_id is None, it should be None across all GPUs\n    # image_feat = F.normalize(image_feat, dim=1)\n    # text_feat = F.normalize(text_feat, dim=1)\n    # handle normalization outside\n\n    # add the following 4 lines\n    image_feat = all_gather_grad(image_feat)\n    text_feat = all_gather_grad(text_feat)\n    \n    logits = torch.matmul(image_feat, text_feat.t())\n    logit_scale = 
temperature.exp().clamp(max=100)\n\n    gt = torch.arange(logits.shape[0], device=logits.device)\n    loss1 = F.cross_entropy(logit_scale * logits, gt)\n    loss2 = F.cross_entropy(logit_scale * logits.t(), gt)\n    return (loss1 + loss2) / 2 # scale it up by the number of GPUs\n\n\ndef all_gather_pickle(data, device):\n    \"\"\"\n    Run all_gather on arbitrary picklable data (not necessarily tensors)\n    Args:\n        data: any picklable object\n    Returns:\n        list[data]: list of data gathered from each rank\n    \"\"\"\n    world_size = get_world_size()\n    if world_size == 1:\n        return [data]\n\n    # serialized to a Tensor\n    buffer = pickle.dumps(data)\n    storage = torch.ByteStorage.from_buffer(buffer)\n    tensor = torch.ByteTensor(storage).to(device)\n\n    # obtain Tensor size of each rank\n    local_size = torch.LongTensor([tensor.numel()]).cuda()\n    size_list = [torch.LongTensor([0]).cuda() for _ in range(world_size)]\n    dist.all_gather(size_list, local_size)\n    size_list = [int(size.item()) for size in size_list]\n    max_size = max(size_list)\n\n    # receiving Tensor from all ranks\n    # we pad the tensor because torch all_gather does not support\n    # gathering tensors of different shapes\n    tensor_list = []\n    for _ in size_list:\n        tensor_list.append(torch.ByteTensor(size=(max_size,)).cuda())\n    if local_size != max_size:\n        padding = torch.ByteTensor(size=(max_size - local_size,)).cuda()\n        tensor = torch.cat((tensor, padding), dim=0)\n    dist.all_gather(tensor_list, tensor)\n\n    data_list = []\n    for size, tensor in zip(size_list, tensor_list):\n        buffer = tensor.cpu().numpy().tobytes()[:size]\n        data_list.append(pickle.loads(buffer))\n\n    return data_list\n\ndef all_gather_arbitary_tensor(tensor):\n    if get_world_size() > 1:\n        device = tensor.device\n        tensor_batch = all_gather_pickle(tensor.cpu(), device)\n        tensor_batch = [x.to(device) for x in 
tensor_batch]\n        tensor_batch[torch.distributed.get_rank()] = tensor\n        tensor_batch = torch.cat(tensor_batch, dim=0)\n    else:\n        tensor_batch = tensor\n    return tensor_batch\n\ndef ql_contrastive_loss(image_feat, text_feat, temperature=1):\n    # add the following 4 lines\n    image_feat = all_gather_arbitary_tensor(image_feat)\n    text_feat = all_gather_arbitary_tensor(text_feat)\n\n    logits = torch.matmul(image_feat, text_feat.t())\n    logit_scale = temperature.exp().clamp(max=100)\n\n    gt = torch.arange(logits.shape[0], device=logits.device)\n    loss1 = F.cross_entropy(logit_scale * logits, gt)\n    loss2 = F.cross_entropy(logit_scale * logits.t(), gt)\n    return (loss1 + loss2) / 2 # scale it up by the number of GPUs\n\ndef vl_similarity(image_feat, text_feat, temperature=1):\n    # Only support single GPU for now.\n    logits = torch.matmul(image_feat, text_feat.t())\n    logits = temperature.exp().clamp(max=100) * logits\n    return logits\n\ndef ql_multi_contrastive_loss(image_feat, text_feat, text_hash, temperature=1):\n    # add the following 4 lines\n    image_feat = all_gather_arbitary_tensor(image_feat)\n    text_feat = all_gather_arbitary_tensor(text_feat)\n\n    text_hash_batch = all_gather_pickle(text_hash, text_feat.device)\n    text_hash_all = torch.cat(text_hash_batch)\n    \n    text_hash_all_unique = torch.unique(text_hash_all).tolist()\n    gt = torch.zeros((image_feat.shape[0], len(text_hash_all_unique)), device=text_feat.device)\n    text_hash_all = text_hash_all.tolist()\n    text_feat_unique = torch.stack([text_feat[text_hash_all.index(txt)] for txt in text_hash_all_unique])\n\n    for idx, txt in enumerate(text_hash_all):\n        gt[idx][text_hash_all_unique.index(txt)] = 1\n    \n    logits = torch.matmul(image_feat, text_feat_unique.t())\n    logits = logits*temperature.exp().clamp(max=100)\n    \n    loss_img = soft_cross_entropy(logits, gt)\n    loss_text = soft_cross_entropy(logits.t(), gt.t() / 
gt.t().sum(-1, keepdim=True))\n\n    loss = 0.7 * loss_img + 0.3 * loss_text\n    return loss\n\ndef image_text_contrastive_loss_queue(image_feat_inp, text_feat_inp, lang_enc, training):\n    # add the following 4 lines\n    image_feat = all_gather_grad(image_feat_inp.contiguous())\n    text_feat = all_gather_grad(text_feat_inp.contiguous())\n\n    image_feat = image_feat / (image_feat.norm(dim=-1, keepdim=True) + 1e-7)\n    text_feat = text_feat / (text_feat.norm(dim=-1, keepdim=True) + 1e-7)\n\n    temperature = lang_enc.logit_scale\n    logits = torch.matmul(image_feat, text_feat.t())\n    logit_scale = temperature.exp().clamp(max=100)\n\n    gt = torch.arange(logits.shape[0], device=logits.device)\n    loss1 = F.cross_entropy(logit_scale * logits, gt)\n    loss2 = F.cross_entropy(logit_scale * logits.t(), gt)\n\n    return (loss1 + loss2) / 2 # scale it up by the number of GPUs"
  },
  {
    "path": "llava/model/semsam/language/misc.py",
    "content": "import random\n\nimport torch\nimport nltk\nnltk.data.path.append('/mnt/data/nltk_data')\nimport numpy as np\n\nfrom utils.constants import IMAGENET_DEFAULT_TEMPLATES\n\n\ndef vl_similarity(image_feat, text_feat, temperature=1):\n    # Only support single GPU for now.\n    logits = torch.matmul(image_feat, text_feat.t())\n    logits = temperature.exp().clamp(max=100) * logits\n    return logits\n\ndef get_tag(tokenized, tags):\n    if not isinstance(tags, (list, tuple)):\n        tags = [tags]\n    ret = []\n    for (word, pos) in nltk.pos_tag(tokenized):\n        for tag in tags:\n            if pos == tag:\n                ret.append(word)\n    return ret\n\ndef get_noun_phrase(tokenized):\n    # Taken from Su Nam Kim Paper...\n    grammar = r\"\"\"\n        NBAR:\n            {<NN.*|JJ>*<NN.*>}  # Nouns and Adjectives, terminated with Nouns\n\n        NP:\n            {<NBAR>}\n            {<NBAR><IN><NBAR>}  # Above, connected with in/of/etc...\n    \"\"\"\n    chunker = nltk.RegexpParser(grammar)\n\n    chunked = chunker.parse(nltk.pos_tag(tokenized))\n    continuous_chunk = []\n    current_chunk = []\n\n    for subtree in chunked:\n        if isinstance(subtree, nltk.Tree):\n            current_chunk.append(' '.join([token for token, pos in subtree.leaves()]))\n        elif current_chunk:\n            named_entity = ' '.join(current_chunk)\n            if named_entity not in continuous_chunk:\n                continuous_chunk.append(named_entity)\n                current_chunk = []\n        else:\n            continue\n\n    return continuous_chunk\n\ndef text_noun_with_prompt_all(text, phrase_prob=0.0, append_text=True):\n    tokenized = nltk.word_tokenize(text)\n    \n    if random.random() >= phrase_prob:\n        nouns = get_tag(tokenized, ['NN', 'NNS', 'NNP'])\n    else:\n        nouns = get_noun_phrase(tokenized)\n\n\n    prompt_texts = [np.random.choice(IMAGENET_DEFAULT_TEMPLATES).format(noun) for noun in nouns]\n    \n    if 
append_text:\n        prompt_texts += [text]\n        nouns += [text]\n    \n    return prompt_texts, nouns"
  },
  {
    "path": "llava/model/semsam/language/modeling_llama_os.py",
    "content": "# coding=utf-8\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\" PyTorch LLaMA model.\"\"\"\nimport math\nfrom typing import List, Optional, Tuple, Union\n\nimport torch\nimport torch.utils.checkpoint\nfrom torch import nn\nfrom torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss\n\nfrom transformers.activations import ACT2FN\nfrom transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast\nfrom transformers.modeling_utils import PreTrainedModel\nfrom transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings\nfrom transformers.models.llama.configuration_llama import LlamaConfig\n\n\nlogger = logging.get_logger(__name__)\n\n_CONFIG_FOR_DOC = \"LlamaConfig\"\n\n\ndef _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):\n    \"\"\"\n    Make causal mask used for bi-directional self-attention.\n    \"\"\"\n    bsz, tgt_len = input_ids_shape\n    mask = torch.full((tgt_len, tgt_len), 
torch.tensor(torch.finfo(dtype).min))\n    mask_cond = torch.arange(mask.size(-1))\n    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)\n    mask = mask.to(dtype)\n\n    if past_key_values_length > 0:\n        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)\n    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)\n\n\ndef _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):\n    \"\"\"\n    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.\n    \"\"\"\n    bsz, src_len = mask.size()\n    tgt_len = tgt_len if tgt_len is not None else src_len\n\n    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)\n\n    inverted_mask = 1.0 - expanded_mask\n\n    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)\n\n\nclass LlamaRMSNorm(nn.Module):\n    def __init__(self, hidden_size, eps=1e-6):\n        \"\"\"\n        LlamaRMSNorm is equivalent to T5LayerNorm\n        \"\"\"\n        super().__init__()\n        self.weight = nn.Parameter(torch.ones(hidden_size))\n        self.variance_epsilon = eps\n\n    def forward(self, hidden_states):\n        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)\n        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)\n\n        # convert into half-precision if necessary\n        if self.weight.dtype in [torch.float16, torch.bfloat16]:\n            hidden_states = hidden_states.to(self.weight.dtype)\n\n        return self.weight * hidden_states\n\n\nclass LlamaRotaryEmbedding(torch.nn.Module):\n    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):\n        super().__init__()\n        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))\n        self.register_buffer(\"inv_freq\", inv_freq)\n\n        # 
Build here to make `torch.jit.trace` work.\n        self.max_seq_len_cached = max_position_embeddings\n        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)\n        freqs = torch.einsum(\"i,j->ij\", t, self.inv_freq)\n        # Different from paper, but it uses a different permutation in order to obtain the same calculation\n        emb = torch.cat((freqs, freqs), dim=-1)\n        self.register_buffer(\"cos_cached\", emb.cos()[None, None, :, :], persistent=False)\n        self.register_buffer(\"sin_cached\", emb.sin()[None, None, :, :], persistent=False)\n\n    def forward(self, x, seq_len=None):\n        # x: [bs, num_attention_heads, seq_len, head_size]\n        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.\n        if seq_len > self.max_seq_len_cached:\n            self.max_seq_len_cached = seq_len\n            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)\n            freqs = torch.einsum(\"i,j->ij\", t, self.inv_freq)\n            # Different from paper, but it uses a different permutation in order to obtain the same calculation\n            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)\n            self.register_buffer(\"cos_cached\", emb.cos()[None, None, :, :], persistent=False)\n            self.register_buffer(\"sin_cached\", emb.sin()[None, None, :, :], persistent=False)\n        return (\n            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),\n            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),\n        )\n\n\ndef rotate_half(x):\n    \"\"\"Rotates half the hidden dims of the input.\"\"\"\n    x1 = x[..., : x.shape[-1] // 2]\n    x2 = x[..., x.shape[-1] // 2 :]\n    return torch.cat((-x2, x1), dim=-1)\n\n\ndef apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0):\n    cos = cos[..., offset : q.shape[-2] + offset, :]\n    sin = sin[..., offset : q.shape[-2] + 
offset, :]\n    q_embed = (q * cos) + (rotate_half(q) * sin)\n    k_embed = (k * cos) + (rotate_half(k) * sin)\n    return q_embed, k_embed\n\n\nclass LlamaMLP(nn.Module):\n    def __init__(\n        self,\n        hidden_size: int,\n        intermediate_size: int,\n        hidden_act: str,\n    ):\n        super().__init__()\n        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)\n        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)\n        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)\n        self.act_fn = ACT2FN[hidden_act]\n\n    def forward(self, x):\n        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))\n\n\nclass LlamaAttention(nn.Module):\n    \"\"\"Multi-headed attention from 'Attention Is All You Need' paper\"\"\"\n\n    def __init__(\n        self,\n        hidden_size: int,\n        num_heads: int,\n    ):\n        super().__init__()\n        self.hidden_size = hidden_size\n        self.num_heads = num_heads\n        self.head_dim = hidden_size // num_heads\n\n        if (self.head_dim * num_heads) != self.hidden_size:\n            raise ValueError(\n                f\"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}\"\n                f\" and `num_heads`: {num_heads}).\"\n            )\n        self.q_proj = nn.Linear(\n            hidden_size,\n            num_heads * self.head_dim,\n            bias=False,\n        )\n        self.k_proj = nn.Linear(\n            hidden_size,\n            num_heads * self.head_dim,\n            bias=False,\n        )\n        self.v_proj = nn.Linear(\n            hidden_size,\n            num_heads * self.head_dim,\n            bias=False,\n        )\n        self.o_proj = nn.Linear(\n            num_heads * self.head_dim,\n            hidden_size,\n            bias=False,\n        )\n        self.rotary_emb = LlamaRotaryEmbedding(self.head_dim)\n\n    def _shape(self, tensor: 
torch.Tensor, seq_len: int, bsz: int):\n        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        past_key_value: Optional[Tuple[torch.Tensor]] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        output_attentions: bool = False,\n        use_cache: bool = False,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        \"\"\"Input shape: Batch x Time x Channel\"\"\"\n\n        bsz, q_len, _ = hidden_states.size()\n\n        query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)\n        key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)\n        value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)\n\n        kv_seq_len = key_states.shape[-2]\n        offset = 0\n        if past_key_value is not None:\n            offset = past_key_value[0].shape[-2]\n            kv_seq_len += offset\n        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)\n        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, offset=offset)\n        # [bsz, nh, t, hd]\n\n        if past_key_value is not None:\n            # reuse k, v, self_attention\n            key_states = torch.cat([past_key_value[0], key_states], dim=2)\n            value_states = torch.cat([past_key_value[1], value_states], dim=2)\n\n        past_key_value = (key_states, value_states) if use_cache else None\n\n        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)\n\n        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):\n            raise ValueError(\n                f\"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is\"\n             
   f\" {attn_weights.size()}\"\n            )\n\n        if attention_mask is not None:\n            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):\n                raise ValueError(\n                    f\"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}\"\n                )\n            attn_weights = attn_weights + attention_mask\n            attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))\n\n        # upcast attention to fp32\n        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)\n        attn_output = torch.matmul(attn_weights, value_states)\n\n        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):\n            raise ValueError(\n                f\"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is\"\n                f\" {attn_output.size()}\"\n            )\n\n        attn_output = attn_output.transpose(1, 2)\n        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)\n\n        attn_output = self.o_proj(attn_output)\n\n        if not output_attentions:\n            attn_weights = None\n\n        return attn_output, attn_weights, past_key_value\n\n\nclass LlamaDecoderLayer(nn.Module):\n    def __init__(self, config: LlamaConfig):\n        super().__init__()\n        self.hidden_size = config.hidden_size\n        self.self_attn = LlamaAttention(\n            hidden_size=self.hidden_size,\n            num_heads=config.num_attention_heads,\n        )\n        self.mlp = LlamaMLP(\n            hidden_size=self.hidden_size,\n            intermediate_size=config.intermediate_size,\n            hidden_act=config.hidden_act,\n        )\n        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n\n    def forward(\n    
    self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        output_attentions: Optional[bool] = False,\n        use_cache: Optional[bool] = False,\n        past_key_value: Optional[Tuple[torch.Tensor]] = None,\n    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:\n        \"\"\"\n        Args:\n            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`\n            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size\n                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.\n            output_attentions (`bool`, *optional*):\n                Whether or not to return the attentions tensors of all attention layers. See `attentions` under\n                returned tensors for more detail.\n            use_cache (`bool`, *optional*):\n                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding\n                (see `past_key_values`).\n            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states\n        \"\"\"\n\n        residual = hidden_states\n\n        hidden_states = self.input_layernorm(hidden_states)\n\n        # Self Attention\n        hidden_states, self_attn_weights, present_key_value = self.self_attn(\n            hidden_states=hidden_states,\n            past_key_value=past_key_value,\n            attention_mask=attention_mask,\n            output_attentions=output_attentions,\n            use_cache=use_cache,\n        )\n        hidden_states = residual + hidden_states\n\n        # Fully Connected\n        residual = hidden_states\n        hidden_states = self.post_attention_layernorm(hidden_states)\n        hidden_states = self.mlp(hidden_states)\n        hidden_states = residual + hidden_states\n\n        outputs = 
(hidden_states,)\n\n        if output_attentions:\n            outputs += (self_attn_weights,)\n\n        if use_cache:\n            outputs += (present_key_value,)\n\n        return outputs\n\n\nLLAMA_START_DOCSTRING = r\"\"\"\n    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the\n    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads\n    etc.)\n\n    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.\n    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage\n    and behavior.\n\n    Parameters:\n        config ([`LlamaConfig`]):\n            Model configuration class with all the parameters of the model. Initializing with a config file does not\n            load the weights associated with the model, only the configuration. Check out the\n            [`~PreTrainedModel.from_pretrained`] method to load the model weights.\n\"\"\"\n\n\n@add_start_docstrings(\n    \"The bare LLaMA Model outputting raw hidden-states without any specific head on top.\",\n    LLAMA_START_DOCSTRING,\n)\nclass LlamaPreTrainedModel(PreTrainedModel):\n    config_class = LlamaConfig\n    base_model_prefix = \"model\"\n    supports_gradient_checkpointing = True\n    _no_split_modules = [\"LlamaDecoderLayer\"]\n    _keys_to_ignore_on_load_unexpected = [r\"decoder\\.version\"]\n\n    def _init_weights(self, module):\n        std = self.config.initializer_range\n        if isinstance(module, nn.Linear):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.bias is not None:\n                module.bias.data.zero_()\n        elif isinstance(module, nn.Embedding):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.padding_idx is not None:\n                
module.weight.data[module.padding_idx].zero_()\n\n    def _set_gradient_checkpointing(self, module, value=False):\n        if isinstance(module, LlamaModel):\n            module.gradient_checkpointing = value\n\n\nLLAMA_INPUTS_DOCSTRING = r\"\"\"\n    Args:\n        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):\n            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide\n            it.\n\n            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and\n            [`PreTrainedTokenizer.__call__`] for details.\n\n            [What are input IDs?](../glossary#input-ids)\n        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:\n\n            - 1 for tokens that are **not masked**,\n            - 0 for tokens that are **masked**.\n\n            [What are attention masks?](../glossary#attention-mask)\n\n            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and\n            [`PreTrainedTokenizer.__call__`] for details.\n\n            If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see\n            `past_key_values`).\n\n            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]\n            and modify to your needs. 
See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more\n            information on the default strategy.\n\n            - 1 indicates the head is **not masked**,\n            - 0 indicates the head is **masked**.\n\n        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):\n            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape\n            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape\n            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.\n\n            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention\n            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.\n\n            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that\n            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all\n            `decoder_input_ids` of shape `(batch_size, sequence_length)`.\n        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):\n            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This\n            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the\n            model's internal embedding lookup matrix.\n        use_cache (`bool`, *optional*):\n            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see\n            `past_key_values`).\n        output_attentions (`bool`, *optional*):\n            Whether or not to return the attentions tensors of all attention layers. 
See `attentions` under returned\n            tensors for more detail.\n        output_hidden_states (`bool`, *optional*):\n            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for\n            more detail.\n        return_dict (`bool`, *optional*):\n            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.\n\"\"\"\n\n\n@add_start_docstrings(\n    \"The bare LLaMA Model outputting raw hidden-states without any specific head on top.\",\n    LLAMA_START_DOCSTRING,\n)\nclass LlamaModel(LlamaPreTrainedModel):\n    \"\"\"\n    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]\n\n    Args:\n        config: LlamaConfig\n    \"\"\"\n\n    def __init__(self, config: LlamaConfig):\n        super().__init__(config)\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n\n        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)\n        self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])\n        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n\n        self.gradient_checkpointing = False\n        self.embed_tokens_out=None\n        # Initialize weights and apply final processing\n        self.post_init()\n\n        if hasattr(config, \"mm_vision_tower\"):\n            from transformers import CLIPVisionModel\n            self.vision_tower = [None]\n\n        if hasattr(config, \"use_mm_proj\"):\n            self.mm_projector = nn.Linear(config.mm_hidden_size, config.hidden_size)\n\n    def get_input_embeddings(self):\n        return self.embed_tokens\n\n    def get_output_embeddings(self):\n        return self.embed_tokens_out\n\n    def set_input_embeddings(self, value):\n        self.embed_tokens = value\n\n    # Copied from 
transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask\n    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):\n        # create causal mask\n        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]\n        combined_attention_mask = None\n        if input_shape[-1] > 1:\n            combined_attention_mask = _make_causal_mask(\n                input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length\n            ).to(inputs_embeds.device)\n\n        if attention_mask is not None:\n            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]\n            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(\n                inputs_embeds.device\n            )\n            combined_attention_mask = (\n                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask\n            )\n\n        return combined_attention_mask\n\n    # def find_pattern(self,src,pattern):\n    def find_pattern_list(self, pattern, src):\n        assert len(pattern) <= len(src)\n        i = len(pattern)-1\n        while True:\n            match = True\n            for j in range(len(pattern)):\n                if int(src[i - j]) != pattern[len(pattern) - 1 - j]:\n                    match = False\n                    break\n            if match:\n                return i\n            i += 1\n            if i >= len(src) - 1:\n                return -1\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        images: 
Optional[torch.FloatTensor] = None,\n        return_dict: Optional[bool] = None,\n        im_feats=None, obj_feats=None,fp16=True,mm_projector=None,obj_projector=None,obj_projector_out=None,obj_num=True,question_ref_queries=None,**kwargs\n\n    ) -> Union[Tuple, BaseModelOutputWithPast]:\n        r\"\"\"\n        Args:\n            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):\n                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you\n                provide it.\n\n                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and\n                [`PreTrainedTokenizer.__call__`] for details.\n\n                [What are input IDs?](../glossary#input-ids)\n            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):\n                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:\n\n                - 1 for tokens that are **not masked**,\n                - 0 for tokens that are **masked**.\n\n                [What are attention masks?](../glossary#attention-mask)\n            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):\n                Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of\n                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of\n\n                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the\n                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.\n\n                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those\n                that don't have their past key value states given to this 
model) of shape `(batch_size, 1)` instead of\n                all `decoder_input_ids` of shape `(batch_size, sequence_length)`.\n            use_cache (`bool`, *optional*):\n                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding\n                (see `past_key_values`).\n            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):\n                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.\n                This is useful if you want more control over how to convert `input_ids` indices into associated vectors\n                than the model's internal embedding lookup matrix.\n            output_attentions (`bool`, *optional*):\n                Whether or not to return the attentions tensors of all attention layers. See `attentions` under\n                returned tensors for more detail.\n            output_hidden_states (`bool`, *optional*):\n                Whether or not to return the hidden states of all layers. 
See `hidden_states` under returned tensors\n                for more detail.\n            return_dict (`bool`, *optional*):\n                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.\n        \"\"\"\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n        use_cache = use_cache if use_cache is not None else self.config.use_cache\n\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        # retrieve input_ids and inputs_embeds\n        if input_ids is not None and inputs_embeds is not None:\n            raise ValueError(\"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time\")\n        elif input_ids is not None:\n            batch_size, seq_length = input_ids.shape\n        elif inputs_embeds is not None:\n            batch_size, seq_length, _ = inputs_embeds.shape\n        else:\n            raise ValueError(\"You have to specify either decoder_input_ids or decoder_inputs_embeds\")\n        seq_length_with_past = seq_length\n        past_key_values_length = 0\n        if past_key_values is not None:\n            past_key_values_length = past_key_values[0][0].shape[2]\n            seq_length_with_past = seq_length_with_past + past_key_values_length\n\n        # HACK: replace back original embeddings for pretraining\n        orig_embeds_params = getattr(self, 'orig_embeds_params', None)\n        if orig_embeds_params is not None:\n            orig_embeds_params_in = orig_embeds_params[0]\n            orig_embeds_params_out = orig_embeds_params[1]\n            st=self.tokenizer.encode(\"<im_start>\")[1]\n            with torch.no_grad():\n                self.get_input_embeddings().weight[:st] = orig_embeds_params_in[:st].data\n             
   # if self.tokenizer.decode([len(self.tokenizer)-1])=='<seg>':\n                self.get_output_embeddings().weight[:st] = orig_embeds_params_out[:st].data\n\n            # if fp16:\n            #     self.get_input_embeddings().weight=self.get_input_embeddings().weight.to(torch.float16)\n\n        if inputs_embeds is None:\n            inputs_embeds = self.embed_tokens(input_ids)\n\n        if (input_ids.shape[1] != 1 or self.training) and im_feats is not None:\n            if mm_projector is None:\n                mm_projector=self.mm_projector\n            if obj_projector is None:\n                obj_projector=self.obj_projector\n            if fp16:\n                image_features = mm_projector(im_feats.to(torch.float16))\n                obj_features=[]\n                for feat in obj_feats:\n                    if feat is not None:\n                        obj_features.append(obj_projector(feat.to(torch.float16)))\n                    else:\n                        obj_features.append(None)\n                if question_ref_queries is not None:\n                    for i,q in enumerate(question_ref_queries):\n                        if q is not None:\n                            question_ref_queries[i]=obj_projector(q.to(torch.float16))\n                    # question_ref_queries=[obj_projector(q.to(torch.float16)) for q in question_ref_queries if q is not None]\n            else:\n                image_features = mm_projector(im_feats)\n                obj_features = []\n                for feat in obj_feats:\n                    if feat is not None:\n                        obj_features.append(obj_projector(feat))\n                    else:\n                        obj_features.append(None)\n                # import pdb;pdb.set_trace()\n                if question_ref_queries is not None:\n                    for i,q in enumerate(question_ref_queries):\n                        if q is not None:\n                            
question_ref_queries[i]=obj_projector(q)\n                    # question_ref_queries=[obj_projector(q) for q in question_ref_queries if q is not None]\n\n            new_input_embeds = []\n            cur_image_idx = 0\n            for cur_input_ids, cur_input_embeds in zip(input_ids, inputs_embeds):\n                cur_image_features = image_features[cur_image_idx]\n                if (cur_input_ids == self.im_start_token).sum() != (cur_input_ids == self.im_end_token).sum():\n                    raise ValueError(\"The number of im_start_token and im_end_token should be the same\")\n                image_start_tokens = torch.where(cur_input_ids == self.im_start_token)[0]\n                assert len(image_start_tokens)==1\n                image_start_token_pos=image_start_tokens[0]\n                # for image_start_token_pos in image_start_tokens: #currently only one image\n                cur_image_features = image_features[cur_image_idx]\n                num_patches = cur_image_features.shape[0]\n                if cur_input_ids[image_start_token_pos + num_patches + 1] != self.im_end_token:\n                    raise ValueError(\"Seems that the image is cut.\")\n                cur_input_embeds = torch.cat((cur_input_embeds[:image_start_token_pos+1], cur_image_features, cur_input_embeds[image_start_token_pos + num_patches + 1:]), dim=0)\n                ############## OBJ\n                cur_obj_features = obj_features[cur_image_idx]\n                if cur_obj_features is not None:\n                    if (cur_input_ids == self.obj_start_token).sum() != (cur_input_ids == self.obj_end_token).sum():\n                        raise ValueError(\"The number of obj_start_token and obj_end_token should be the same\")\n                    obj_start_tokens = torch.where(cur_input_ids == self.obj_start_token)[0]\n                    assert len(obj_start_tokens)==1\n                    obj_start_token_pos=obj_start_tokens[0]\n                    obj_end_tokens = 
torch.where(cur_input_ids == self.obj_end_token)[0]\n                    assert len(obj_end_tokens) == 1\n                    obj_end_token_pos = obj_end_tokens[0]\n                    # for image_start_token_pos in image_start_tokens: #currently only one image\n                    num_patches = cur_obj_features.shape[0]\n                    if obj_num:\n                        starts=[]\n                        for i_obj in range(num_patches):\n                            mark=self.tokenizer.encode(f\"{i_obj}.\")[1:]\n                            start_i=self.find_pattern_list(mark,cur_input_ids)+1\n                            assert cur_input_ids[start_i]==self.tokenizer.encode(\"<obj_patch>\")[1]\n                            if start_i!=-1 and start_i>obj_start_token_pos and start_i<obj_end_token_pos:\n                                starts.append(start_i)\n                        cur_input_embeds[starts]=cur_obj_features.to(cur_input_embeds.dtype)\n                        cur_new_input_embeds=cur_input_embeds\n                    else:\n                        if cur_input_ids[obj_start_token_pos + num_patches + 1] != self.obj_end_token:\n                            raise ValueError(\"Seems that the objs are cut.\")\n                        cur_new_input_embeds = torch.cat((cur_input_embeds[:obj_start_token_pos+1], cur_obj_features, cur_input_embeds[obj_start_token_pos + num_patches + 1:]), dim=0)\n                    if question_ref_queries is not None:\n                        # import pdb;pdb.set_trace()\n                        if question_ref_queries[cur_image_idx] is not None:\n                            obj_patch_tokens = torch.where(cur_input_ids == self.obj_patch_token)[0]\n                            obj_patch_tokens=[int(i) for i in obj_patch_tokens if i>obj_end_tokens[0]]\n                            if len(obj_patch_tokens)>0:\n                                
cur_new_input_embeds[obj_patch_tokens]=question_ref_queries[cur_image_idx].to(cur_input_embeds.dtype)\n                else:\n                    cur_new_input_embeds=cur_input_embeds\n\n                cur_image_idx += 1\n                new_input_embeds.append(cur_new_input_embeds)\n\n            inputs_embeds = torch.stack(new_input_embeds, dim=0)\n\n        # embed positions\n        if attention_mask is None:\n            attention_mask = torch.ones(\n                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device\n            )\n        attention_mask = self._prepare_decoder_attention_mask(\n            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length\n        )\n\n        hidden_states = inputs_embeds\n\n        if self.gradient_checkpointing and self.training:\n            if use_cache:\n                logger.warning_once(\n                    \"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...\"\n                )\n                use_cache = False\n\n        # decoder layers\n        all_hidden_states = () if output_hidden_states else None\n        all_self_attns = () if output_attentions else None\n        next_decoder_cache = () if use_cache else None\n\n        for idx, decoder_layer in enumerate(self.layers):\n            if output_hidden_states:\n                all_hidden_states += (hidden_states,)\n\n            past_key_value = past_key_values[idx] if past_key_values is not None else None\n\n            if self.gradient_checkpointing and self.training:\n\n                def create_custom_forward(module):\n                    def custom_forward(*inputs):\n                        # None for past_key_value\n                        return module(*inputs, output_attentions, None)\n\n                    return custom_forward\n\n                layer_outputs = torch.utils.checkpoint.checkpoint(\n                    
create_custom_forward(decoder_layer),\n                    hidden_states,\n                    attention_mask,\n                    None,\n                )\n            else:\n                layer_outputs = decoder_layer(\n                    hidden_states,\n                    attention_mask=attention_mask,\n                    past_key_value=past_key_value,\n                    output_attentions=output_attentions,\n                    use_cache=use_cache,\n                )\n\n            hidden_states = layer_outputs[0]\n\n            if use_cache:\n                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)\n\n            if output_attentions:\n                all_self_attns += (layer_outputs[1],)\n\n        hidden_states = self.norm(hidden_states)\n\n        # add hidden states from the last decoder layer\n        if output_hidden_states:\n            all_hidden_states += (hidden_states,)\n\n        next_cache = next_decoder_cache if use_cache else None\n        if not return_dict:\n            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)\n        return BaseModelOutputWithPast(\n            last_hidden_state=hidden_states,\n            past_key_values=next_cache,\n            hidden_states=all_hidden_states,\n            attentions=all_self_attns,\n        )\n\n\nclass LlamaForCausalLM(LlamaPreTrainedModel):\n    _keys_to_ignore_on_load_missing = [r\"lm_head.weight\"]\n\n    def __init__(self, config):\n        super().__init__(config)\n        self.model = LlamaModel(config)\n\n        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.model.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.model.embed_tokens = value\n\n    def get_output_embeddings(self):\n        return 
self.lm_head\n\n    def set_output_embeddings(self, new_embeddings):\n        self.lm_head = new_embeddings\n\n    def set_decoder(self, decoder):\n        self.model = decoder\n\n    def get_decoder(self):\n        return self.model\n\n    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        images: Optional[torch.FloatTensor] = None,\n        return_dict: Optional[bool] = None,\n        im_feats=None, obj_feats=None,fp16=True,tokenizer=None,training=True,return_hidden=False,reduce_loss=True,\n        **kwargs\n    ) -> Union[Tuple, CausalLMOutputWithPast]:\n        r\"\"\"\n        Args:\n            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):\n                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you\n                provide it.\n\n                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and\n                [`PreTrainedTokenizer.__call__`] for details.\n\n                [What are input IDs?](../glossary#input-ids)\n            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):\n                Mask to avoid performing attention on padding token indices. 
Mask values selected in `[0, 1]`:\n\n                - 1 for tokens that are **not masked**,\n                - 0 for tokens that are **masked**.\n\n                [What are attention masks?](../glossary#attention-mask)\n            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):\n                Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of\n                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of\n                shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional\n                tensors are only required when the model is used as a decoder in a Sequence to Sequence model.\n\n                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the\n                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.\n\n                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those\n                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of\n                all `decoder_input_ids` of shape `(batch_size, sequence_length)`.\n            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):\n                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.\n                This is useful if you want more control over how to convert `input_ids` indices into associated vectors\n                than the model's internal embedding lookup matrix.\n            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n                Labels for computing the masked language modeling loss. 
Indices should either be in `[0, ...,\n                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored\n                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.\n            use_cache (`bool`, *optional*):\n                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding\n                (see `past_key_values`).\n            output_attentions (`bool`, *optional*):\n                Whether or not to return the attentions tensors of all attention layers. See `attentions` under\n                returned tensors for more detail.\n            output_hidden_states (`bool`, *optional*):\n                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors\n                for more detail.\n            return_dict (`bool`, *optional*):\n                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.\n\n        Returns:\n\n        Example:\n\n        ```python\n        >>> from transformers import AutoTokenizer, LlamaForCausalLM\n\n        >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)\n        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)\n\n        >>> prompt = \"Hey, are you consciours? Can you talk to me?\"\n        >>> inputs = tokenizer(prompt, return_tensors=\"pt\")\n\n        >>> # Generate\n        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)\n        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]\n        \"Hey, are you consciours? 
Can you talk to me?\\nI'm not consciours, but I can talk to you.\"\n        ```\"\"\"\n\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n        self.model.tokenizer=tokenizer\n        if not training:\n            self.model.training=False\n        self.model.embed_tokens_out=self.get_output_embeddings()\n        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\n        outputs = self.model(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n            images=images,\n            im_feats=im_feats, obj_feats=obj_feats,fp16=fp16,**kwargs\n        )\n\n        hidden_states = outputs[0]\n        logits = self.lm_head(hidden_states)\n\n        loss = None\n        loss_ls=[]\n        if labels is not None:\n            # Shift so that tokens < n predict n\n            shift_logits = logits[..., :-1, :].contiguous()\n            shift_labels = labels[..., 1:].contiguous()\n            # Flatten the tokens\n            loss_fct = CrossEntropyLoss()\n            if not reduce_loss:\n                shift_logits = shift_logits.view(shift_logits.shape[0], -1,self.config.vocab_size)\n                shift_labels = shift_labels.view(shift_logits.shape[0],-1)\n                # Enable model/pipeline parallelism\n                shift_labels = shift_labels.to(shift_logits.device)\n                for shift_logits_,shift_labels_ in 
zip(shift_logits, shift_labels):\n                    loss_ls.append(loss_fct(shift_logits_, shift_labels_))\n                loss=loss_ls\n            else:\n                shift_logits = shift_logits.view(-1, self.config.vocab_size)\n                shift_labels = shift_labels.view(-1)\n                # Enable model/pipeline parallelism\n                shift_labels = shift_labels.to(shift_logits.device)\n                loss=loss_fct(shift_logits, shift_labels)\n\n        if not return_dict:\n            output = (logits,) + outputs[1:]\n            return (loss,) + output if loss is not None else output\n\n        return CausalLMOutputWithPast(\n            loss=loss,\n            logits=logits,\n            past_key_values=outputs.past_key_values,\n            hidden_states=outputs.last_hidden_state,\n            attentions=outputs.attentions,\n        )\n\n    def prepare_inputs_for_generation(\n        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs\n    ):\n        if past_key_values:\n            input_ids = input_ids[:, -1:]\n\n        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step\n        if inputs_embeds is not None and past_key_values is None:\n            model_inputs = {\"inputs_embeds\": inputs_embeds}\n        else:\n            model_inputs = {\"input_ids\": input_ids}\n\n        model_inputs.update(\n            {\n                \"past_key_values\": past_key_values,\n                \"use_cache\": kwargs.get(\"use_cache\"),\n                \"attention_mask\": attention_mask,\n                \"images\": kwargs.get(\"images\", None),\n            }\n        )\n        return model_inputs\n\n    @staticmethod\n    def _reorder_cache(past_key_values, beam_idx):\n        reordered_past = ()\n        for layer_past in past_key_values:\n            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)\n        return 
reordered_past\n\n\n@add_start_docstrings(\n    \"\"\"\n    The LLaMa Model transformer with a sequence classification head on top (linear layer).\n\n    [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models\n    (e.g. GPT-2) do.\n\n    Since it does classification on the last token, it requires to know the position of the last token. If a\n    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If\n    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the\n    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in\n    each row of the batch).\n    \"\"\",\n    LLAMA_START_DOCSTRING,\n)\nclass LlamaForSequenceClassification(LlamaPreTrainedModel):\n    _keys_to_ignore_on_load_missing = [r\"lm_head.weight\"]\n\n    def __init__(self, config):\n        super().__init__(config)\n        self.num_labels = config.num_labels\n        self.model = LlamaModel(config)\n        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.model.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.model.embed_tokens = value\n\n    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: 
Optional[bool] = None,\n    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:\n        r\"\"\"\n        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):\n            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,\n            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If\n            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).\n        \"\"\"\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        transformer_outputs = self.model(\n            input_ids,\n            past_key_values=past_key_values,\n            attention_mask=attention_mask,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n        hidden_states = transformer_outputs[0]\n        logits = self.score(hidden_states)\n\n        if input_ids is not None:\n            batch_size = input_ids.shape[0]\n        else:\n            batch_size = inputs_embeds.shape[0]\n\n        if self.config.pad_token_id is None and batch_size != 1:\n            raise ValueError(\"Cannot handle batch sizes > 1 if no padding token is defined.\")\n        if self.config.pad_token_id is None:\n            sequence_lengths = -1\n        else:\n            if input_ids is not None:\n                sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)\n            else:\n                sequence_lengths = -1\n\n        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]\n\n        loss = None\n        if labels is not None:\n            if self.config.problem_type is None:\n                if self.num_labels == 1:\n                    self.config.problem_type = 
\"regression\"\n                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):\n                    self.config.problem_type = \"single_label_classification\"\n                else:\n                    self.config.problem_type = \"multi_label_classification\"\n\n            if self.config.problem_type == \"regression\":\n                loss_fct = MSELoss()\n                if self.num_labels == 1:\n                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())\n                else:\n                    loss = loss_fct(pooled_logits, labels)\n            elif self.config.problem_type == \"single_label_classification\":\n                loss_fct = CrossEntropyLoss()\n                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))\n            elif self.config.problem_type == \"multi_label_classification\":\n                loss_fct = BCEWithLogitsLoss()\n                loss = loss_fct(pooled_logits, labels)\n        if not return_dict:\n            output = (pooled_logits,) + transformer_outputs[1:]\n            return ((loss,) + output) if loss is not None else output\n\n        return SequenceClassifierOutputWithPast(\n            loss=loss,\n            logits=pooled_logits,\n            past_key_values=transformer_outputs.past_key_values,\n            hidden_states=transformer_outputs.hidden_states,\n            attentions=transformer_outputs.attentions,\n        )\n"
  },
  {
    "path": "llava/model/semsam/language/registry.py",
    "content": "_model_entrypoints = {}\n\ndef register_model(fn):\n    module_name_split = fn.__module__.split('.')\n    model_name = module_name_split[-1]\n    _model_entrypoints[model_name] = fn\n    return fn\n\ndef model_entrypoints(model_name):\n    return _model_entrypoints[model_name]\n\ndef is_model(model_name):\n    return model_name in _model_entrypoints"
  },
  {
    "path": "llava/model/semsam/language/vlpencoder.py",
    "content": "# --------------------------------------------------------\n# X-Decoder -- Generalized Decoding for Pixel, Image, and Language\n# Copyright (c) 2022 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# Written by Xueyan Zou (xueyan@cs.wisc.edu)\n# --------------------------------------------------------\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nfrom timm.models.layers import trunc_normal_\n\nfrom .registry import register_model\nfrom ..utils import configurable\nfrom .LangEncoder import build_tokenizer, build_lang_encoder\nfrom utils.prompt_engineering import prompt_engineering, get_prompt_templates\n\n\nclass LanguageEncoder(nn.Module):\n\n    @configurable\n    def __init__(\n        self,\n        tokenizer,\n        tokenizer_type,\n        lang_encoder,\n        lang_projection,\n        max_token_num,\n        queue_operator,\n    ):\n        super().__init__()\n        # seg\n        self.tokenizer = tokenizer\n        self.tokenizer_type = tokenizer_type\n        self.lang_encoder = lang_encoder\n        self.lang_proj = lang_projection\n        self.max_token_num = max_token_num\n        self.logit_scale = nn.Parameter(torch.ones([]))\n        \n        # captioning & retrieval\n        for key, value in queue_operator.items():\n            self.register_buffer(key, value)\n            \n\n    @classmethod\n    def from_config(cls, cfg):\n        # build up text encoder for seg\n        tokenizer = build_tokenizer(cfg['MODEL']['TEXT'])\n        tokenizer_type = cfg['MODEL']['TEXT']['TOKENIZER']\n        lang_encoder = build_lang_encoder(cfg['MODEL']['TEXT'], tokenizer, cfg['VERBOSE'])\n        max_token_num = cfg['MODEL']['TEXT']['CONTEXT_LENGTH']\n        \n        dim_lang = cfg['MODEL']['TEXT']['WIDTH']\n        dim_projection = cfg['MODEL']['DIM_PROJ']\n        lang_projection = nn.Parameter(torch.empty(dim_lang, dim_projection))\n        trunc_normal_(lang_projection, std=.02)\n\n  
      # tested not working better      \n        queue_operator = {}\n\n        return {\n            \"tokenizer\": tokenizer,\n            \"tokenizer_type\": tokenizer_type,\n            \"lang_encoder\": lang_encoder,\n            \"lang_projection\": lang_projection,\n            \"max_token_num\": max_token_num,\n            \"queue_operator\": queue_operator,\n        }\n\n    def get_text_embeddings(self, class_names, name='default', is_eval=False, add_bgd=False, prompt=True, norm=True):\n        if not is_eval:\n            if prompt:\n                # randomly sample one template\n                arbitary_concepts = [\n                    prompt_engineering(class_names[label].replace('-other','').replace('-merged','').replace('-stuff',''), topk=10000, suffix='.') \\\n                    for label in range(len(class_names))\n                ]\n                if add_bgd:\n                    arbitary_concepts.append(\"A background in coco.\")\n            else:\n                arbitary_concepts = class_names\n            \n            input_ids = []\n            attention_masks = []\n            for txt in arbitary_concepts:\n                tokens = self.tokenizer(\n                    txt, padding='max_length', truncation=True, max_length=self.max_token_num, return_tensors='pt'\n                )\n                tokens['input_ids'].squeeze_()\n                tokens['attention_mask'].squeeze_()\n\n                input_ids.append(tokens['input_ids'])\n                attention_masks.append(tokens['attention_mask'])\n\n            arbitary_tokens = torch.stack(input_ids)\n            arbitary_attention_masks = torch.stack(attention_masks)\n\n            text_emb = self.forward_language((arbitary_tokens.cuda(), arbitary_attention_masks.cuda()), norm=norm)\n            setattr(self, '{}_text_embeddings'.format(name), text_emb)\n        else:\n            with torch.no_grad():\n                def extract_mean_emb(txts):\n                    tokens = 
self.tokenizer(\n                        txts, padding='max_length', truncation=True, max_length=self.max_token_num, return_tensors='pt'\n                    )\n                    clss_embedding = self.forward_language((tokens['input_ids'].cuda(), tokens['attention_mask'].cuda()), norm=norm)\n                    clss_embedding = clss_embedding.mean(dim=0)\n                    clss_embedding /= clss_embedding.norm()\n                    return clss_embedding\n\n                templates = get_prompt_templates()\n                clss_embeddings = []\n                if prompt:\n                    for clss in class_names:\n                        txts = [template.format(clss.replace('-other','').replace('-merged','').replace('-stuff','')) for template in templates]\n                        clss_embeddings.append(extract_mean_emb(txts))\n                else:\n                    clss_embeddings.append(extract_mean_emb(class_names))\n\n                if add_bgd:\n                    txts = [\"A background in coco.\"]\n                    clss_embeddings.append(extract_mean_emb(txts))\n\n                text_emb = torch.stack(clss_embeddings, dim=0)\n                setattr(self, '{}_text_embeddings'.format(name), text_emb)\n\n    def get_text_token_embeddings(self, txts, name='default', token=False, norm=False):\n        if not token:\n            tokens = self.tokenizer(\n                txts, padding='max_length', truncation=True, max_length=self.max_token_num, return_tensors='pt'\n            )\n            tokens = {key: value.cuda() for key, value in tokens.items()}\n        else:\n            tokens = txts\n        token_emb, class_emb = self.forward_language_token((tokens['input_ids'], tokens['attention_mask']), norm=norm)\n        ret = {\"tokens\": tokens,\n                \"token_emb\": token_emb,\n                \"class_emb\": class_emb,}\n        setattr(self, '{}_token_embeddings'.format(name), ret)\n        return ret\n\n    def forward_language(self, 
texts, norm=True):\n        x = self.lang_encoder(*texts)\n        x = x['last_hidden_state']\n\n        if self.tokenizer_type == 'clip':\n            x = x[torch.arange(x.size(0)), texts[0].argmax(dim=-1)]\n        else:\n            x = x[:, 0]\n\n        x = x @ self.lang_proj\n        if norm:\n            x = x / (x.norm(dim=-1, keepdim=True) + 1e-7)\n        return x\n    \n    def forward_language_token(self, texts, norm=False):\n        x = self.lang_encoder(*texts)\n        token_x = x['last_hidden_state']\n\n        if self.tokenizer_type == 'clip':\n            class_x = token_x[torch.arange(token_x.size(0)), texts[0].argmax(dim=-1)]\n        else:\n            class_x = token_x[:, 0]\n\n        class_x = class_x @ self.lang_proj\n        token_x = token_x @ self.lang_proj\n\n        if norm:\n            class_x = class_x / (class_x.norm(dim=-1, keepdim=True) + 1e-7)\n            token_x = token_x / (token_x.norm(dim=-1, keepdim=True) + 1e-7)\n\n        return token_x, class_x\n    \n    def compute_similarity(self, v_emb, name='default', fake=False):\n        if fake:\n            return None\n        v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)\n        t_emb = getattr(self, '{}_text_embeddings'.format(name))\n        output = self.logit_scale.exp() * v_emb @ t_emb.unsqueeze(0).transpose(1, 2)\n        return output\n\n\n@register_model\ndef get_language_model(cfg, **kwargs):\n    return LanguageEncoder(cfg)"
  },
  {
    "path": "llava/model/semsam/modules/__init__.py",
    "content": "from .point_features import *\nfrom .position_encoding import *\nfrom .postprocessing import *\nfrom .attention import *\nfrom .matcher import *\nfrom .criterion_id_llm import *\nfrom .hooks import HookBase"
  },
  {
    "path": "llava/model/semsam/modules/attention.py",
    "content": "import warnings\nfrom typing import Optional, Tuple\n\nimport torch\nimport torch.nn as nn\nfrom torch import Tensor\nfrom torch.nn.init import constant_, xavier_normal_, xavier_uniform_\nfrom torch.nn.parameter import Parameter\nfrom torch.overrides import has_torch_function, handle_torch_function\nfrom torch.nn.functional import pad, linear, softmax, dropout\n\n\ndef multi_head_attention_forward(\n    query: Tensor,\n    key: Tensor,\n    value: Tensor,\n    embed_dim_to_check: int,\n    num_heads: int,\n    in_proj_weight: Tensor,\n    in_proj_bias: Tensor,\n    bias_k: Optional[Tensor],\n    bias_v: Optional[Tensor],\n    add_zero_attn: bool,\n    dropout_p: float,\n    out_proj_weight: Tensor,\n    out_proj_bias: Tensor,\n    training: bool = True,\n    key_padding_mask: Optional[Tensor] = None,\n    need_weights: bool = True,\n    attn_mask: Optional[Tensor] = None,\n    use_separate_proj_weight: bool = False,\n    q_proj_weight: Optional[Tensor] = None,\n    k_proj_weight: Optional[Tensor] = None,\n    v_proj_weight: Optional[Tensor] = None,\n    static_k: Optional[Tensor] = None,\n    static_v: Optional[Tensor] = None,\n) -> Tuple[Tensor, Optional[Tensor]]:\n    r\"\"\"\n    Args:\n        query, key, value: map a query and a set of key-value pairs to an output.\n            See \"Attention Is All You Need\" for more details.\n        embed_dim_to_check: total dimension of the model.\n        num_heads: parallel attention heads.\n        in_proj_weight, in_proj_bias: input projection weight and bias.\n        bias_k, bias_v: bias of the key and value sequences to be added at dim=0.\n        add_zero_attn: add a new batch of zeros to the key and\n                       value sequences at dim=1.\n        dropout_p: probability of an element to be zeroed.\n        out_proj_weight, out_proj_bias: the output projection weight and bias.\n        training: apply dropout if is ``True``.\n        key_padding_mask: if provided, specified padding 
elements in the key will\n            be ignored by the attention. This is an binary mask. When the value is True,\n            the corresponding value on the attention layer will be filled with -inf.\n        need_weights: output attn_output_weights.\n        attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all\n            the batches while a 3D mask allows to specify a different mask for the entries of each batch.\n        use_separate_proj_weight: the function accept the proj. weights for query, key,\n            and value in different forms. If false, in_proj_weight will be used, which is\n            a combination of q_proj_weight, k_proj_weight, v_proj_weight.\n        q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.\n        static_k, static_v: static key and value used for attention operators.\n\n\n    Shape:\n        Inputs:\n        - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is\n          the embedding dimension.\n        - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is\n          the embedding dimension.\n        - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is\n          the embedding dimension.\n        - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.\n          If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions\n          will be unchanged. 
If a BoolTensor is provided, the positions with the\n          value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.\n        - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.\n          3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,\n          S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked\n          positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend\n          while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``\n          are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor\n          is provided, it will be added to the attention weight.\n        - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,\n          N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.\n        - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,\n          N is the batch size, E is the embedding dimension. 
E/num_heads is the head dimension.\n\n        Outputs:\n        - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,\n          E is the embedding dimension.\n        - attn_output_weights: :math:`(N, L, S)` where N is the batch size,\n          L is the target sequence length, S is the source sequence length.\n    \"\"\"\n    tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias)\n    if has_torch_function(tens_ops):\n        return handle_torch_function(\n            multi_head_attention_forward,\n            tens_ops,\n            query,\n            key,\n            value,\n            embed_dim_to_check,\n            num_heads,\n            in_proj_weight,\n            in_proj_bias,\n            bias_k,\n            bias_v,\n            add_zero_attn,\n            dropout_p,\n            out_proj_weight,\n            out_proj_bias,\n            training=training,\n            key_padding_mask=key_padding_mask,\n            need_weights=need_weights,\n            attn_mask=attn_mask,\n            use_separate_proj_weight=use_separate_proj_weight,\n            q_proj_weight=q_proj_weight,\n            k_proj_weight=k_proj_weight,\n            v_proj_weight=v_proj_weight,\n            static_k=static_k,\n            static_v=static_v,\n        )\n    tgt_len, bsz, embed_dim = query.size()\n    assert embed_dim == embed_dim_to_check\n    # allow MHA to have different sizes for the feature dimension\n    assert key.size(0) == value.size(0) and key.size(1) == value.size(1)\n\n    head_dim = embed_dim // num_heads\n    assert head_dim * num_heads == embed_dim, \"embed_dim must be divisible by num_heads\"\n    scaling = float(head_dim) ** -0.5\n\n    if not use_separate_proj_weight:\n        if (query is key or torch.equal(query, key)) and (key is value or torch.equal(key, value)):\n            # self-attention\n            q, k, v = linear(query, in_proj_weight, 
in_proj_bias).chunk(3, dim=-1)\n\n        elif key is value or torch.equal(key, value):\n            # encoder-decoder attention\n            # This is inline in_proj function with in_proj_weight and in_proj_bias\n            _b = in_proj_bias\n            _start = 0\n            _end = embed_dim\n            _w = in_proj_weight[_start:_end, :]\n            if _b is not None:\n                _b = _b[_start:_end]\n            q = linear(query, _w, _b)\n\n            if key is None:\n                assert value is None\n                k = None\n                v = None\n            else:\n\n                # This is inline in_proj function with in_proj_weight and in_proj_bias\n                _b = in_proj_bias\n                _start = embed_dim\n                _end = None\n                _w = in_proj_weight[_start:, :]\n                if _b is not None:\n                    _b = _b[_start:]\n                k, v = linear(key, _w, _b).chunk(2, dim=-1)\n\n        else:\n            # This is inline in_proj function with in_proj_weight and in_proj_bias\n            _b = in_proj_bias\n            _start = 0\n            _end = embed_dim\n            _w = in_proj_weight[_start:_end, :]\n            if _b is not None:\n                _b = _b[_start:_end]\n            q = linear(query, _w, _b)\n\n            # This is inline in_proj function with in_proj_weight and in_proj_bias\n            _b = in_proj_bias\n            _start = embed_dim\n            _end = embed_dim * 2\n            _w = in_proj_weight[_start:_end, :]\n            if _b is not None:\n                _b = _b[_start:_end]\n            k = linear(key, _w, _b)\n\n            # This is inline in_proj function with in_proj_weight and in_proj_bias\n            _b = in_proj_bias\n            _start = embed_dim * 2\n            _end = None\n            _w = in_proj_weight[_start:, :]\n            if _b is not None:\n                _b = _b[_start:]\n            v = linear(value, _w, _b)\n    else:\n       
 q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight)\n        len1, len2 = q_proj_weight_non_opt.size()\n        assert len1 == embed_dim and len2 == query.size(-1)\n\n        k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight)\n        len1, len2 = k_proj_weight_non_opt.size()\n        assert len1 == embed_dim and len2 == key.size(-1)\n\n        v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight)\n        len1, len2 = v_proj_weight_non_opt.size()\n        assert len1 == embed_dim and len2 == value.size(-1)\n\n        if in_proj_bias is not None:\n            q = linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim])\n            k = linear(key, k_proj_weight_non_opt, in_proj_bias[embed_dim : (embed_dim * 2)])\n            v = linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2) :])\n        else:\n            q = linear(query, q_proj_weight_non_opt, in_proj_bias)\n            k = linear(key, k_proj_weight_non_opt, in_proj_bias)\n            v = linear(value, v_proj_weight_non_opt, in_proj_bias)\n    q = q * scaling\n\n    if attn_mask is not None:\n        assert (\n            attn_mask.dtype == torch.float32\n            or attn_mask.dtype == torch.float64\n            or attn_mask.dtype == torch.float16\n            or attn_mask.dtype == torch.uint8\n            or attn_mask.dtype == torch.bool\n        ), \"Only float, byte, and bool types are supported for attn_mask, not {}\".format(attn_mask.dtype)\n        if attn_mask.dtype == torch.uint8:\n            warnings.warn(\"Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. 
Use bool tensor instead.\")\n            attn_mask = attn_mask.to(torch.bool)\n\n        if attn_mask.dim() == 2:\n            attn_mask = attn_mask.unsqueeze(0)\n            if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:\n                raise RuntimeError(\"The size of the 2D attn_mask is not correct.\")\n        elif attn_mask.dim() == 3:\n            if list(attn_mask.size()) != [bsz * num_heads, query.size(0), key.size(0)]:\n                raise RuntimeError(\"The size of the 3D attn_mask is not correct.\")\n        else:\n            raise RuntimeError(\"attn_mask's dimension {} is not supported\".format(attn_mask.dim()))\n        # attn_mask's dim is 3 now.\n\n    # convert ByteTensor key_padding_mask to bool\n    if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:\n        warnings.warn(\n            \"Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.\"\n        )\n        key_padding_mask = key_padding_mask.to(torch.bool)\n\n    if bias_k is not None and bias_v is not None:\n        if static_k is None and static_v is None:\n            k = torch.cat([k, bias_k.repeat(1, bsz, 1)])\n            v = torch.cat([v, bias_v.repeat(1, bsz, 1)])\n            if attn_mask is not None:\n                attn_mask = pad(attn_mask, (0, 1))\n            if key_padding_mask is not None:\n                key_padding_mask = pad(key_padding_mask, (0, 1))\n        else:\n            assert static_k is None, \"bias cannot be added to static key.\"\n            assert static_v is None, \"bias cannot be added to static value.\"\n    else:\n        assert bias_k is None\n        assert bias_v is None\n\n    q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)\n    if k is not None:\n        k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)\n    if v is not None:\n        v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)\n\n    if 
static_k is not None:\n        assert static_k.size(0) == bsz * num_heads\n        assert static_k.size(2) == head_dim\n        k = static_k\n\n    if static_v is not None:\n        assert static_v.size(0) == bsz * num_heads\n        assert static_v.size(2) == head_dim\n        v = static_v\n\n    src_len = k.size(1)\n\n    if key_padding_mask is not None:\n        # assert key_padding_mask.size(0) == bsz\n        assert key_padding_mask.size(1) == src_len\n\n    if add_zero_attn:\n        src_len += 1\n        k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1)\n        v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1)\n        if attn_mask is not None:\n            attn_mask = pad(attn_mask, (0, 1))\n        if key_padding_mask is not None:\n            key_padding_mask = pad(key_padding_mask, (0, 1))\n\n    attn_output_weights = torch.bmm(q, k.transpose(1, 2))\n    assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]\n\n    if attn_mask is not None:\n        if attn_mask.dtype == torch.bool:\n            attn_output_weights.masked_fill_(attn_mask, float(\"-inf\"))\n        else:\n            attn_output_weights += attn_mask\n\n    if key_padding_mask is not None:\n        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)\n        attn_output_weights = attn_output_weights.masked_fill(\n            key_padding_mask.unsqueeze(1),\n            float(\"-inf\"),\n        )\n        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)\n\n    attn_output_weights = softmax(attn_output_weights, dim=-1)\n    attn_output_weights = dropout(attn_output_weights, p=dropout_p, training=training)\n\n    attn_output = torch.bmm(attn_output_weights, v)\n    assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]\n    attn_output = attn_output.transpose(0, 
1).contiguous().view(tgt_len, bsz, embed_dim)\n    attn_output = linear(attn_output, out_proj_weight, out_proj_bias)\n\n    if need_weights:\n        # average attention weights over heads\n        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)\n        return attn_output, attn_output_weights.sum(dim=1) / num_heads\n    else:\n        return attn_output, None\n\n\n# This class exists solely for Transformer; it has an annotation stating\n# that bias is never None, which appeases TorchScript\nclass _LinearWithBias(nn.Linear):\n    bias: Tensor  # type: ignore\n\n    def __init__(self, in_features: int, out_features: int) -> None:\n        super().__init__(in_features, out_features, bias=True)  # type: ignore\n\n\nclass MultiheadAttention(nn.Module):\n    r\"\"\"Allows the model to jointly attend to information\n    from different representation subspaces.\n    See `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_\n\n    .. math::\n        \\text{MultiHead}(Q, K, V) = \\text{Concat}(head_1,\\dots,head_h)W^O\n\n    where :math:`head_i = \\text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.\n\n    Args:\n        embed_dim: total dimension of the model.\n        num_heads: parallel attention heads.\n        dropout: a Dropout layer on attn_output_weights. Default: 0.0.\n        bias: add bias as module parameter. Default: True.\n        add_bias_kv: add bias to the key and value sequences at dim=0.\n        add_zero_attn: add a new batch of zeros to the key and\n                       value sequences at dim=1.\n        kdim: total number of features in key. Default: None.\n        vdim: total number of features in value. 
Default: None.\n\n    Note that if :attr:`kdim` and :attr:`vdim` are None, they will be set\n    to :attr:`embed_dim` such that query, key, and value have the same\n    number of features.\n\n    Examples::\n\n        >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)\n        >>> attn_output, attn_output_weights = multihead_attn(query, key, value)\n    \"\"\"\n    bias_k: Optional[torch.Tensor]\n    bias_v: Optional[torch.Tensor]\n\n    def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None):\n        super(MultiheadAttention, self).__init__()\n        self.embed_dim = embed_dim\n        self.kdim = kdim if kdim is not None else embed_dim\n        self.vdim = vdim if vdim is not None else embed_dim\n        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim\n\n        self.num_heads = num_heads\n        self.dropout = dropout\n        self.head_dim = embed_dim // num_heads\n        assert self.head_dim * num_heads == self.embed_dim, \"embed_dim must be divisible by num_heads\"\n\n        if self._qkv_same_embed_dim is False:\n            self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))\n            self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))\n            self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))\n            self.register_parameter('in_proj_weight', None)\n        else:\n            self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))\n            self.register_parameter('q_proj_weight', None)\n            self.register_parameter('k_proj_weight', None)\n            self.register_parameter('v_proj_weight', None)\n\n        if bias:\n            self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))\n        else:\n            self.register_parameter('in_proj_bias', None)\n        self.out_proj = _LinearWithBias(embed_dim, embed_dim)\n\n        if add_bias_kv:\n            
self.bias_k = Parameter(torch.empty(1, 1, embed_dim))\n            self.bias_v = Parameter(torch.empty(1, 1, embed_dim))\n        else:\n            self.bias_k = self.bias_v = None\n\n        self.add_zero_attn = add_zero_attn\n\n        self._reset_parameters()\n\n    def _reset_parameters(self):\n        if self._qkv_same_embed_dim:\n            xavier_uniform_(self.in_proj_weight)\n        else:\n            xavier_uniform_(self.q_proj_weight)\n            xavier_uniform_(self.k_proj_weight)\n            xavier_uniform_(self.v_proj_weight)\n\n        if self.in_proj_bias is not None:\n            constant_(self.in_proj_bias, 0.)\n            constant_(self.out_proj.bias, 0.)\n        if self.bias_k is not None:\n            xavier_normal_(self.bias_k)\n        if self.bias_v is not None:\n            xavier_normal_(self.bias_v)\n\n    def __setstate__(self, state):\n        # Support loading old MultiheadAttention checkpoints generated by v1.1.0\n        if '_qkv_same_embed_dim' not in state:\n            state['_qkv_same_embed_dim'] = True\n\n        super(MultiheadAttention, self).__setstate__(state)\n\n    def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: Optional[Tensor] = None,\n                need_weights: bool = True, attn_mask: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:\n        r\"\"\"\n    Args:\n        query, key, value: map a query and a set of key-value pairs to an output.\n            See \"Attention Is All You Need\" for more details.\n        key_padding_mask: if provided, specified padding elements in the key will\n            be ignored by the attention. When given a binary mask and a value is True,\n            the corresponding value on the attention layer will be ignored. 
When given\n            a byte mask and a value is non-zero, the corresponding value on the attention\n            layer will be ignored\n        need_weights: output attn_output_weights.\n        attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all\n            the batches while a 3D mask allows to specify a different mask for the entries of each batch.\n\n    Shapes for inputs:\n        - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is\n          the embedding dimension.\n        - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is\n          the embedding dimension.\n        - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is\n          the embedding dimension.\n        - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.\n          If a ByteTensor is provided, the non-zero positions will be ignored while the position\n          with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the\n          value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.\n        - attn_mask: if a 2D mask: :math:`(L, S)` where L is the target sequence length, S is the\n          source sequence length.\n\n          If a 3D mask: :math:`(N\\cdot\\text{num\\_heads}, L, S)` where N is the batch size, L is the target sequence\n          length, S is the source sequence length. ``attn_mask`` ensure that position i is allowed to attend\n          the unmasked positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend\n          while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``\n          is not allowed to attend while ``False`` values will be unchanged. 
If a FloatTensor\n          is provided, it will be added to the attention weight.\n\n    Shapes for outputs:\n        - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,\n          E is the embedding dimension.\n        - attn_output_weights: :math:`(N, L, S)` where N is the batch size,\n          L is the target sequence length, S is the source sequence length.\n        \"\"\"\n        if not self._qkv_same_embed_dim:\n            return multi_head_attention_forward(\n                query, key, value, self.embed_dim, self.num_heads,\n                self.in_proj_weight, self.in_proj_bias,\n                self.bias_k, self.bias_v, self.add_zero_attn,\n                self.dropout, self.out_proj.weight, self.out_proj.bias,\n                training=self.training,\n                key_padding_mask=key_padding_mask, need_weights=need_weights,\n                attn_mask=attn_mask, use_separate_proj_weight=True,\n                q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,\n                v_proj_weight=self.v_proj_weight)\n        else:\n            return multi_head_attention_forward(\n                query, key, value, self.embed_dim, self.num_heads,\n                self.in_proj_weight, self.in_proj_bias,\n                self.bias_k, self.bias_v, self.add_zero_attn,\n                self.dropout, self.out_proj.weight, self.out_proj.bias,\n                training=self.training,\n                key_padding_mask=key_padding_mask, need_weights=need_weights,\n                attn_mask=attn_mask)"
  },
  {
    "path": "llava/model/semsam/modules/criterion_id_llm.py",
    "content": "# ------------------------------------------------------------------------\n# Copyright (c) IDEA, Inc. and its affiliates.\n# Modified from DINO https://github.com/IDEA-Research/DINO by Feng Li and Hao Zhang.\n# ------------------------------------------------------------------------\n\"\"\"\nMaskFormer criterion.\n\"\"\"\nimport logging\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\n\nfrom detectron2.utils.comm import get_world_size\nfrom detectron2.projects.point_rend.point_features import (\n    get_uncertain_point_coords_with_randomness,\n    point_sample,\n)\n\nfrom ..utils.misc import is_dist_avail_and_initialized, nested_tensor_from_tensor_list\nfrom ..utils import box_ops\nfrom utils.utils import slprint\n\ndef sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2):\n    \"\"\"\n    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.\n    Args:\n        inputs: A float tensor of arbitrary shape.\n                The predictions for each example.\n        targets: A float tensor with the same shape as inputs. Stores the binary\n                 classification label for each element in inputs\n                (0 for the negative class and 1 for the positive class).\n        alpha: (optional) Weighting factor in range (0,1) to balance\n                positive vs negative examples. 
Default = -1 (no weighting).\n        gamma: Exponent of the modulating factor (1 - p_t) to\n               balance easy vs hard examples.\n    Returns:\n        Loss tensor\n    \"\"\"\n    prob = inputs.sigmoid()\n    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction=\"none\")\n    p_t = prob * targets + (1 - prob) * (1 - targets)\n    loss = ce_loss * ((1 - p_t) ** gamma)\n\n    if alpha >= 0:\n        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)\n        loss = alpha_t * loss\n    return loss\n    # return loss.mean(1).sum() / num_boxes\n\n\ndef dice_loss(\n        inputs: torch.Tensor,\n        targets: torch.Tensor,\n        num_masks: float,\n):\n    \"\"\"\n    Compute the DICE loss, similar to generalized IOU for masks\n    Args:\n        inputs: A float tensor of arbitrary shape.\n                The predictions for each example.\n        targets: A float tensor with the same shape as inputs. Stores the binary\n                 classification label for each element in inputs\n                (0 for the negative class and 1 for the positive class).\n    \"\"\"\n    inputs = inputs.sigmoid()\n    inputs = inputs.flatten(1)\n    numerator = 2 * (inputs * targets).sum(-1)\n    denominator = inputs.sum(-1) + targets.sum(-1)\n    loss = 1 - (numerator + 1) / (denominator + 1)\n    # only match the lowest loss\n    # loss = loss.view(-1, 3)\n    # loss = loss.min(1)[0]\n    # return loss.sum() / num_masks\n    return loss\n\n\ndef iou_score_loss(inputs, targets):\n    ce_loss = F.mse_loss(inputs, targets, reduction=\"none\")\n    return ce_loss\n\n\ndice_loss_jit = torch.jit.script(\n    dice_loss\n)  # type: torch.jit.ScriptModule\n\n\ndef sigmoid_ce_loss(\n        inputs: torch.Tensor,\n        targets: torch.Tensor,\n        num_masks: float,\n):\n    \"\"\"\n    Args:\n        inputs: A float tensor of arbitrary shape.\n                The predictions for each example.\n        targets: A float tensor with the same shape as 
inputs. Stores the binary\n                 classification label for each element in inputs\n                (0 for the negative class and 1 for the positive class).\n    Returns:\n        Loss tensor\n    \"\"\"\n    loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction=\"none\")\n    loss = loss.mean(1)\n    # loss = loss.view(-1, 3).min(1)[0]\n\n    # return loss.sum() / num_masks\n    return loss\n\n\nsigmoid_ce_loss_jit = torch.jit.script(\n    sigmoid_ce_loss\n)  # type: torch.jit.ScriptModule\n\n\ndef calculate_uncertainty(logits):\n    \"\"\"\n    We estimate uncerainty as L1 distance between 0.0 and the logit prediction in 'logits' for the\n        foreground class in `classes`.\n    Args:\n        logits (Tensor): A tensor of shape (R, 1, ...) for class-specific or\n            class-agnostic, where R is the total number of predicted masks in all images and C is\n            the number of foreground classes. The values are logits.\n    Returns:\n        scores (Tensor): A tensor of shape (R, 1, ...) 
that contains uncertainty scores with\n            the most uncertain locations having the highest uncertainty score.\n    \"\"\"\n    assert logits.shape[1] == 1\n    gt_class_logits = logits.clone()\n    return -(torch.abs(gt_class_logits))\n\n\nclass SetCriterionLLM(nn.Module):\n    \"\"\"This class computes the loss for DETR.\n    The process happens in two steps:\n        1) we compute hungarian assignment between ground truth boxes and the outputs of the model\n        2) we supervise each pair of matched ground-truth / prediction (supervise class and box)\n    \"\"\"\n\n    def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses,\n                 num_points, oversample_ratio, importance_sample_ratio, dn=\"no\", dn_losses=[], panoptic_on=False,\n                 semantic_ce_loss=False, num_mask_tokens=3, iou_loss=True):\n        \"\"\"Create the criterion.\n        Parameters:\n            num_classes: number of object categories, omitting the special no-object category\n            matcher: module able to compute a matching between targets and proposals\n            weight_dict: dict containing as key the names of the losses and as values their relative weight.\n            eos_coef: relative classification weight applied to the no-object category\n            losses: list of all the losses to be applied. 
See get_loss for list of available losses.\n        \"\"\"\n        super().__init__()\n        self.num_classes = num_classes\n        self.num_classes_part = -1\n        self.matcher = matcher\n        self.weight_dict = weight_dict\n        self.eos_coef = eos_coef\n        self.losses = losses\n        self.dn = dn\n        self.dn_losses = dn_losses\n        empty_weight = torch.ones(self.num_classes + 1)\n        empty_weight[-1] = self.eos_coef\n        self.register_buffer(\"empty_weight\", empty_weight)\n\n        # pointwise mask loss parameters\n        self.num_points = num_points\n        self.oversample_ratio = oversample_ratio\n        self.importance_sample_ratio = importance_sample_ratio\n        self.focal_alpha = 0.25\n\n        self.panoptic_on = panoptic_on\n        self.semantic_ce_loss = semantic_ce_loss\n        self.num_mask_tokens = num_mask_tokens\n        self.index = None\n        self.iou_loss = iou_loss\n        self.prediction_switch = None\n        self.index_switch = {'part': torch.arange(0, self.num_mask_tokens - 1).cuda(),\n                             'whole': torch.arange(self.num_mask_tokens - 1, self.num_mask_tokens).cuda(),\n                             'all': torch.arange(0, self.num_mask_tokens).cuda(), }\n        # self.dbg_f=open(\"/comp_robot/zhanghao/model/idino_llama_coco/dbg\",\"a\")\n        self.keys=\"loss_bbox_0, loss_giou_0, loss_mask_bce_0, loss_mask_dice_0, iou_score_loss_0, loss_bbox_1, loss_giou_1, loss_mask_bce_1, loss_mask_dice_1, iou_score_loss_1, loss_bbox_2, loss_giou_2, loss_mask_bce_2, loss_mask_dice_2, iou_score_loss_2, loss_bbox_3, loss_giou_3, loss_mask_bce_3, loss_mask_dice_3, iou_score_loss_3, loss_bbox_4, loss_giou_4, loss_mask_bce_4, loss_mask_dice_4, iou_score_loss_4, loss_bbox_5, loss_giou_5, loss_mask_bce_5, loss_mask_dice_5, iou_score_loss_5, loss_bbox_6, loss_giou_6, loss_mask_bce_6, loss_mask_dice_6, iou_score_loss_6, loss_bbox_7, loss_giou_7, loss_mask_bce_7, loss_mask_dice_7, 
iou_score_loss_7, loss_bbox_8, loss_giou_8, loss_mask_bce_8, loss_mask_dice_8, iou_score_loss_8\"\n        self.keys=self.keys.split(\", \")\n        print(\"iou_loss is \", iou_loss)\n\n    def loss_labels_ce(self, outputs, targets, indices, num_masks):\n        \"\"\"Classification loss (NLL)\n        targets dicts must contain the key \"labels\" containing a tensor of dim [nb_target_boxes]\n        \"\"\"\n        assert \"pred_logits\" in outputs\n        src_logits = outputs[\"pred_logits\"]\n\n        idx = self._get_src_permutation_idx(indices)\n        target_classes_o = torch.cat([t[\"labels\"][J] for t, (_, J) in zip(targets, indices)])\n        target_classes = torch.full(\n            src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device\n        )\n        target_classes[idx] = target_classes_o\n\n        loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)\n        losses = {\"loss_ce\": loss_ce}\n        return losses\n\n    def loss_labels(self, outputs, targets, indices, num_boxes, log=True, key='gt_whole_classes'):\n        \"\"\"Classification loss (Binary focal loss)\n        targets dicts must contain the key \"labels\" containing a tensor of dim [nb_target_boxes]\n        \"\"\"\n        # assert 'pred_logits' in outputs\n        if self.prediction_switch is None or 'whole' not in self.prediction_switch.keys():\n            if 'labels' in targets[0].keys() and targets[0]['labels'] is not None:\n                key = 'labels'\n        else:\n            if not self.prediction_switch['whole']:\n                return {\"fake_no_loss_mask_cls_0\": 0.0}\n            elif key not in targets[0].keys():\n                # FIXME only consider batchsize=1 case\n                assert len(targets) == 1\n                return {\"loss_mask_cls_0\": 0.0 * outputs['pred_logits'].sum()}\n        src_logits = outputs['pred_logits']\n\n        idx = self._get_src_permutation_idx(indices)\n      
  target_classes_o = torch.cat([t[key][J] for t, (_, J) in zip(targets, indices)])\n        target_classes = torch.full(src_logits.shape[:2], self.num_classes,\n                                    dtype=torch.int64, device=src_logits.device)\n        target_classes[idx] = target_classes_o\n\n        target_classes_onehot = torch.zeros([src_logits.shape[0], src_logits.shape[1], src_logits.shape[2] + 1],\n                                            dtype=src_logits.dtype, layout=src_logits.layout, device=src_logits.device)\n        target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1)\n\n        target_classes_onehot = target_classes_onehot[:, :, :-1]\n        loss_ce = sigmoid_focal_loss(src_logits, target_classes_onehot, num_boxes, alpha=self.focal_alpha, gamma=2) * \\\n                  src_logits.shape[1]\n        losses = {}\n        loss_ce = loss_ce.sum(2)\n        losses[\"loss_mask_cls_0\"] = torch.gather(loss_ce.view(-1, 3), 1, self.index.unsqueeze(1)).mean().sum() / num_boxes\n        # losses = {\"loss_mask_cls_0\": loss_ce}\n        # losses={k:losses[k].to(torch.bfloat16) for k in losses.keys()}\n        return losses\n\n    def loss_labels_part(self, outputs, targets, indices, num_boxes, log=True, key='gt_part_classes'):\n        \"\"\"Classification loss (Binary focal loss)\n        targets dicts must contain the key \"labels\" containing a tensor of dim [nb_target_boxes]\n        \"\"\"\n        # assert 'pred_logits_part' in outputs\n        if not self.prediction_switch['part']:\n            return {\"fake_no_loss_mask_part_cls_0\": 0.0}\n        elif key not in targets[0].keys():\n            # FIXME only consider batchsize=1 case\n            assert len(targets) == 1\n            # return {\"loss_mask_whole_cls_0\": 0.0*outputs['pred_logits_part'].sum()}\n            return {\"loss_mask_part_cls_0\": 0.0 * outputs['pred_logits_part'].sum()}\n        src_logits = outputs['pred_logits_part']\n\n        idx = 
self._get_src_permutation_idx(indices)\n        target_classes_o = torch.cat([t[key][J] for t, (_, J) in zip(targets, indices)])\n        target_classes = torch.full(src_logits.shape[:2], self.num_classes_part,\n                                    dtype=torch.int64, device=src_logits.device)\n        target_classes[idx] = target_classes_o\n\n        target_classes_onehot = torch.zeros([src_logits.shape[0], src_logits.shape[1], src_logits.shape[2] + 1],\n                                            dtype=src_logits.dtype, layout=src_logits.layout, device=src_logits.device)\n        target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1)\n\n        target_classes_onehot = target_classes_onehot[:, :, :-1]\n        loss_ce = sigmoid_focal_loss(src_logits, target_classes_onehot, num_boxes, alpha=self.focal_alpha, gamma=2) * \\\n                  src_logits.shape[1]\n        losses = {}\n        loss_ce = loss_ce.sum(2)\n        losses[\"loss_mask_part_cls_0\"] = torch.gather(loss_ce.view(-1, 3), 1, self.index.unsqueeze(1)).mean().sum() / num_boxes\n        # losses = {\"loss_mask_part_cls_0\": loss_ce}\n\n        return losses\n\n    def loss_boxes_o365(self, outputs, targets, indices, num_boxes, layer_id=None, extra=None):\n        \"\"\"Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss\n           targets dicts must contain the key \"boxes\" containing a tensor of dim [nb_target_boxes, 4]\n           The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.\n        \"\"\"\n        assert 'pred_boxes' in outputs\n        if indices is None or len(targets) == 0:\n            loss = outputs['pred_boxes'].sum() * 0.0\n            losses = {\"loss_bbox_0\": loss, \"loss_giou_0\": loss}\n            return losses\n\n        idx = self._get_src_permutation_idx(indices)\n        src_boxes = outputs['pred_boxes'][idx]\n        target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in 
zip(targets, indices)], dim=0)\n\n        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')\n        losses = {}\n        losses['loss_bbox_0'] = loss_bbox.sum() / num_boxes\n\n        loss_giou = 1 - torch.diag(box_ops.generalized_box_iou(\n            box_ops.box_cxcywh_to_xyxy(src_boxes),\n            box_ops.box_cxcywh_to_xyxy(target_boxes)))\n        losses['loss_giou_0'] = loss_giou.sum() / num_boxes\n\n        return losses\n\n    def loss_boxes(self, outputs, targets, indices, num_boxes):\n        \"\"\"Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss\n           targets dicts must contain the key \"boxes\" containing a tensor of dim [nb_target_boxes, 4]\n           The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.\n        \"\"\"\n\n        assert 'pred_boxes' in outputs\n        if 'boxes' not in targets[0].keys():\n            # FIXME only consider batchsize=1 case\n            assert len(targets) == 1\n            return {\"loss_bbox_0\": 0.0 * outputs['pred_boxes'].sum(),\n                    \"loss_giou_0\": 0.0 * outputs['pred_boxes'].sum(), }\n        assert self.index is not None\n        idx = self._get_src_permutation_idx(indices)\n        src_boxes = outputs['pred_boxes'][idx]\n        target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)\n        # print(src_boxes)\n        # print(target_boxes)\n        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')\n        losses = {}\n        loss_bbox = loss_bbox.sum(1)\n        # losses[\"loss_bbox_0\"] = loss_bbox.sum() / num_boxes\n        try:\n            losses[\"loss_bbox_0\"] = torch.gather(loss_bbox.view(-1, 3), 1, self.index.unsqueeze(1)).sum() / num_boxes\n\n            loss_giou = 1 - torch.diag(box_ops.generalized_box_iou(\n                box_ops.box_cxcywh_to_xyxy(src_boxes),\n                box_ops.box_cxcywh_to_xyxy(target_boxes)))\n  
          losses[\"loss_giou_0\"] = torch.gather(loss_giou.view(-1, 3), 1, self.index.unsqueeze(1)).sum() / num_boxes\n        except:\n            losses[\"loss_bbox_0\"] = loss_bbox.sum()*0.0\n            losses[\"loss_giou_0\"] = loss_bbox.sum()*0.0\n            print(loss_bbox.view(-1, 3))\n            print(self.index.unsqueeze(1))\n        # losses={k:losses[k].to(torch.bfloat16) for k in losses.keys()}\n\n        return losses\n\n    def loss_boxes_panoptic(self, outputs, targets, indices, num_boxes):\n        \"\"\"Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss\n           targets dicts must contain the key \"boxes\" containing a tensor of dim [nb_target_boxes, 4]\n           The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.\n        \"\"\"\n        assert 'pred_boxes' in outputs\n        idx = self._get_src_permutation_idx(indices)\n        src_boxes = outputs['pred_boxes'][idx]\n        target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)\n        target_labels = torch.cat([t['labels'][i] for t, (_, i) in zip(targets, indices)], dim=0)\n        isthing = target_labels < 80\n        target_boxes = target_boxes[isthing]\n        src_boxes = src_boxes[isthing]\n\n        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')\n        losses = {}\n        losses[\"loss_bbox_0\"] = loss_bbox.sum() / num_boxes\n\n        loss_giou = 1 - torch.diag(box_ops.generalized_box_iou(\n            box_ops.box_cxcywh_to_xyxy(src_boxes),\n            box_ops.box_cxcywh_to_xyxy(target_boxes)))\n        losses[\"loss_giou_0\"] = loss_giou.sum() / num_boxes\n\n        return losses\n\n    def loss_masks(self, outputs, targets, indices, num_masks):\n        \"\"\"Compute the losses related to the masks: the focal loss and the dice loss.\n        targets dicts must contain the key \"masks\" containing a tensor of dim [nb_target_boxes, h, w]\n 
       \"\"\"\n        assert \"pred_masks\" in outputs\n        if 'masks' not in targets[0].keys():\n            # FIXME only consider batchsize=1 case\n            assert len(targets) == 1\n            return {\"loss_mask_bce_0\": 0.0 * outputs['pred_masks'].sum(),\n                    \"loss_mask_dice_0\": 0.0 * outputs['pred_masks'].sum(),\n                    \"iou_score_loss_0\": 0.0 * outputs['pred_masks'].sum(),\n                    }\n\n        src_idx = self._get_src_permutation_idx(indices)\n        tgt_idx = self._get_tgt_permutation_idx(indices)\n        src_masks = outputs[\"pred_masks\"]\n        src_masks = src_masks[src_idx]\n        masks = [t[\"masks\"] for t in targets]\n        # TODO use valid to mask invalid areas due to padding in loss\n        target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()\n        target_masks = target_masks.to(src_masks)\n        target_masks = target_masks[tgt_idx]\n\n        # No need to upsample predictions as we are using normalized coordinates :)\n        # N x 1 x H x W\n        # import pdb;pdb.set_trace()\n        src_masks = src_masks[:, None]\n        target_masks = target_masks[:, None]\n\n        with torch.no_grad():\n            # sample point_coords\n            point_coords = get_uncertain_point_coords_with_randomness(\n                src_masks.float(),\n                lambda logits: calculate_uncertainty(logits.float()),\n                self.num_points,\n                self.oversample_ratio,\n                self.importance_sample_ratio,\n            )\n            # get gt labels\n            point_labels = point_sample(\n                target_masks.float(),\n                point_coords.float(),\n                align_corners=False,\n            ).squeeze(1)\n\n        point_logits = point_sample(\n            src_masks.float(),\n            point_coords.float(),\n            align_corners=False,\n        ).squeeze(1)\n\n        losses = {\n            \"loss_mask_bce_0\": 
sigmoid_ce_loss_jit(point_logits, point_labels, num_masks),\n            # \"loss_mask_bce_0\": sigmoid_ce_loss(point_logits, point_labels, num_masks),\n            \"loss_mask_dice_0\": dice_loss_jit(point_logits, point_labels, num_masks),\n            # \"loss_mask_dice_0\": dice_loss(point_logits, point_labels, num_masks),\n        }\n        mask_loss = losses[\"loss_mask_bce_0\"] + losses[\"loss_mask_dice_0\"]\n        mask_loss, index = mask_loss.view(-1, 3).min(1)\n        # slprint(index)\n        # if len(targets)>1:   # FIXME starting box index is the same\n        #     assert targets[0]['box_start'] == targets[1]['box_start']\n        bs = outputs[\"pred_masks\"].shape[0]\n        box_start = targets[0]['box_start']\n        # index.view(bs, -1)[:, box_start:] = 0   # all the box index is set to 0\n        if self.index is None:\n            self.index = index\n        else:\n            index=self.index\n        losses[\"loss_mask_bce_0\"] = torch.gather(losses[\"loss_mask_bce_0\"].view(-1, 3), 1,\n                                                 index.unsqueeze(1)).sum() / num_masks\n        dice_loss = torch.gather(losses[\"loss_mask_dice_0\"].view(-1, 3), 1, index.unsqueeze(1))\n        losses[\"loss_mask_dice_0\"] = dice_loss.sum() / num_masks\n\n        target_iou = 1 - dice_loss\n        src_ious = outputs[\"pred_ious\"]\n        iou_idx = ([src_idx[0].view(bs, -1)[:, :src_ious.shape[1]].flatten(),\n                    src_idx[1].view(bs, -1)[:, :src_ious.shape[1]].flatten()])\n        # print(\"loss_masks1\")\n        # slprint(target_iou)\n        # print(\"loss_masks2\")\n        #\n        # slprint(src_ious)\n        # print(\"loss_masks3\")\n        #\n        # slprint(iou_idx)\n        # print(\"loss_masks4\")\n        #\n        # slprint(index.unsqueeze(1))\n        src_ious = src_ious[iou_idx]\n        src_ious = torch.gather(src_ious, 1, index.unsqueeze(1))\n        #\n        # if self.iou_loss:\n        losses['iou_score_loss_0'] = 
iou_score_loss(src_ious, target_iou).sum() / num_masks\n        # losses={k:losses[k].to(torch.bfloat16) for k in losses.keys()}\n        del src_masks\n        del target_masks\n        return losses\n\n    def loss_labels_o365(self, outputs, targets, indices, num_boxes, log=True, layer_id=None, extra=None):\n        \"\"\"Classification loss (Binary focal loss)\n        targets dicts must contain the key \"labels\" containing a tensor of dim [nb_target_boxes]\n        \"\"\"\n        assert 'pred_logits' in outputs\n        if indices is None or len(targets) == 0:\n            loss_ce = outputs['pred_logits'].sum() * 0.0\n            losses = {\"loss_mask_cls_0\": loss_ce}\n            return losses\n\n        src_logits = outputs['pred_logits']\n\n        idx = self._get_src_permutation_idx(indices)\n        target_classes_o = torch.cat([t[\"labels\"][J] for t, (_, J) in zip(targets, indices)])\n        target_classes = torch.full(src_logits.shape[:2], self.num_classes,\n                                    dtype=torch.int64, device=src_logits.device)\n        target_classes[idx] = target_classes_o\n\n        target_classes_onehot = torch.zeros([src_logits.shape[0], src_logits.shape[1], src_logits.shape[2]+1],\n                                            dtype=src_logits.dtype, layout=src_logits.layout, device=src_logits.device)\n        target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1)\n\n        target_classes_onehot = target_classes_onehot[:,:,:-1]\n        loss_ce = sigmoid_focal_loss(src_logits, target_classes_onehot, num_boxes, alpha=self.focal_alpha, gamma=2).mean(1).sum()\n        losses = {'loss_mask_cls_0': loss_ce}\n\n        return losses\n\n    def prep_for_dn(self, mask_dict):\n        output_known_lbs_bboxes = mask_dict['output_known_lbs_bboxes']\n\n        known_indice = mask_dict['known_indice']\n        scalar, pad_size = mask_dict['scalar'], mask_dict['pad_size']\n        assert pad_size % scalar == 0\n        single_pad = 
pad_size // scalar\n\n        num_tgt = known_indice.numel()\n        return output_known_lbs_bboxes, num_tgt, single_pad, scalar\n\n    def _get_src_permutation_idx(self, indices):\n        # permute predictions following indices\n        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])\n        src_idx = torch.cat([src for (src, _) in indices])\n        return batch_idx, src_idx\n\n    def _get_tgt_permutation_idx(self, indices):\n        # permute targets following indices\n        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])\n        tgt_idx = torch.cat([tgt for (_, tgt) in indices])\n        return batch_idx, tgt_idx\n\n    def get_loss(self, loss, outputs, targets, indices, num_masks):\n        loss_map = {\n            'labels': self.loss_labels_ce if self.semantic_ce_loss else self.loss_labels,\n            'labels_o365':  self.loss_labels_o365,\n            'labels_part': self.loss_labels_part,\n            'masks': self.loss_masks,\n            'boxes': self.loss_boxes_panoptic if self.panoptic_on else self.loss_boxes,\n            'boxes_o365': self.loss_boxes_o365,\n        }\n        assert loss in loss_map, f\"do you really want to compute {loss} loss?\"\n        return loss_map[loss](outputs, targets, indices, num_masks)\n\n    def forward(self, outputs, targets, mask_dict=None, task='sam', extra={}, return_idx=False):\n        \"\"\"This performs the loss computation.\n        Parameters:\n             outputs: dict of tensors, see the output specification of the model for the format\n             targets: list of dicts, such that len(targets) == batch_size.\n                      The expected keys in each dict depends on the losses applied, see each loss' doc\n        \"\"\"\n        # outputs_without_aux = {k: v for k, v in outputs.items() if k != \"aux_outputs\"}\n\n        # Retrieve the matching between the outputs of the last layer and the targets\n        # if 
self.dn is not \"no\" and mask_dict is not None:\n        #     output_known_lbs_bboxes,num_tgt,single_pad,scalar = self.prep_for_dn(mask_dict)\n        assert len(targets)==1, \"now only support one image training for interactive segmentation\"\n        prediction_switch = extra\n        self.prediction_switch = prediction_switch\n\n        exc_idx = []\n        key = 'pred_boxes'\n        for i in range(len(targets)):\n            if len(targets[i]['boxes']) > 0:\n                if task=='det':\n                    tgt_idx = torch.arange(0, len(targets[i]['boxes'])).long().cuda()\n                else:\n                    tgt_idx = torch.arange(0, len(targets[i]['boxes'])).long().cuda().repeat_interleave(\n                        self.num_mask_tokens)\n                src_idx = torch.arange(0, outputs[key].shape[1]).long().cuda()\n                # tgt_idx = t.flatten()\n                # output_idx = (torch.tensor(range(scalar)) * single_pad).long().cuda().unsqueeze(1) + t\n                # output_idx = output_idx.flatten()\n            else:\n                output_idx = tgt_idx = src_idx = torch.tensor([]).long().cuda()\n            exc_idx.append((src_idx, tgt_idx))\n        indices = exc_idx\n        # indices = self.matcher(outputs_without_aux, targets)\n        # Compute the average number of target boxes accross all nodes, for normalization purposes\n        num_masks = sum(len(t[\"boxes\"]) for t in targets)\n        num_masks = torch.as_tensor(\n            [num_masks], dtype=torch.float, device=outputs[key].device\n        )\n        if is_dist_avail_and_initialized():\n            torch.distributed.all_reduce(num_masks)\n        num_masks = torch.clamp(num_masks / get_world_size(), min=1).item()\n\n        # Compute all the requested losses\n        losses = {}\n        losses['num_masks'] = num_masks\n        if 'masks' in 'masks':\n            assert 'masks' in self.losses[0], \"must calculate mask loss first for match\"\n        # 
slprint(outputs)\n        # slprint(targets)\n        # slprint(indices)\n        for loss in self.losses:\n            if task=='det':\n                if loss=='labels_part':\n                    continue\n                if  loss=='labels':\n                    loss='labels_o365'\n                if  loss=='boxes':\n                    loss='boxes_o365'\n                if loss == 'masks':\n                    l_dict = dict()\n                    l_dict['loss_mask_bce_0'] = torch.as_tensor(0.).to('cuda')\n                    l_dict['loss_mask_dice_0'] = torch.as_tensor(0.).to('cuda')\n                    losses.update(l_dict)\n            else:\n                if 'labels' in loss:\n                    continue\n                losses.update(self.get_loss(loss, outputs, targets, indices, num_masks))\n        index=self.index\n        if \"aux_outputs\" in outputs:\n            for i, aux_outputs in enumerate(outputs[\"aux_outputs\"]):\n                # indices = self.matcher(aux_outputs, targets)\n                for loss in self.losses:\n                    if task == 'det':\n                        if loss == 'labels_part':\n                            continue\n                        if loss == 'labels':\n                            loss = 'labels_o365'\n                        if loss == 'boxes':\n                            loss = 'boxes_o365'\n                        if loss=='masks':\n                            l_dict=dict()\n                            l_dict['loss_mask_bce_0'] = torch.as_tensor(0.).to('cuda')\n                            l_dict['loss_mask_dice_0'] = torch.as_tensor(0.).to('cuda')\n                            l_dict = {k.replace('_0', f\"_{i + 1}\"): v for k, v in l_dict.items()}\n                            losses.update(l_dict)\n                    else:\n                        if 'labels' in loss:\n                            continue\n                        l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_masks)\n 
                       # l_dict = {k + f\"_{i}\": v for k, v in l_dict.items()}\n                        l_dict = {k.replace('_0', f\"_{i + 1}\"): v for k, v in l_dict.items()}\n                        losses.update(l_dict)\n        # totoal_loss = torch.tensor(0.0).cuda()\n        # for k,v in losses.items():\n        #     totoal_loss += v\n        # losses = dict()\n        # losses['all'] = totoal_loss\n        # assert \"iou_score_loss_0\" in losses, losses.keys()\n        # self.dbg_f.write(\", \".join(list(losses.keys()))+'\\n')\n        losses={k:losses[k] for k in losses.keys()}\n        if return_idx:\n            return losses,index\n        else:\n            return losses\n\n    def __repr__(self):\n        head = \"Criterion \" + self.__class__.__name__\n        body = [\n            \"matcher: {}\".format(self.matcher.__repr__(_repr_indent=8)),\n            \"losses: {}\".format(self.losses),\n            \"weight_dict: {}\".format(self.weight_dict),\n            \"num_classes: {}\".format(self.num_classes),\n            \"eos_coef: {}\".format(self.eos_coef),\n            \"num_points: {}\".format(self.num_points),\n            \"oversample_ratio: {}\".format(self.oversample_ratio),\n            \"importance_sample_ratio: {}\".format(self.importance_sample_ratio),\n        ]\n        _repr_indent = 4\n        lines = [head] + [\" \" * _repr_indent + line for line in body]\n        return \"\\n\".join(lines)\n"
  },
  {
    "path": "llava/model/semsam/modules/hooks.py",
    "content": "import logging\nimport numpy as np\nimport time\nimport weakref\nfrom typing import List, Mapping, Optional\nimport torch\nfrom torch.nn.parallel import DataParallel, DistributedDataParallel\n\nimport detectron2.utils.comm as comm\nfrom detectron2.utils.events import EventStorage, get_event_storage\nfrom detectron2.utils.logger import _log_api_usage\n\nclass HookBase:\n    \"\"\"\n    Base class for hooks that can be registered with :class:`TrainerBase`.\n\n    Each hook can implement 4 methods. The way they are called is demonstrated\n    in the following snippet:\n    ::\n        hook.before_train()\n        for iter in range(start_iter, max_iter):\n            hook.before_step()\n            trainer.run_step()\n            hook.after_step()\n        iter += 1\n        hook.after_train()\n\n    Notes:\n        1. In the hook method, users can access ``self.trainer`` to access more\n           properties about the context (e.g., model, current iteration, or config\n           if using :class:`DefaultTrainer`).\n\n        2. A hook that does something in :meth:`before_step` can often be\n           implemented equivalently in :meth:`after_step`.\n           If the hook takes non-trivial time, it is strongly recommended to\n           implement the hook in :meth:`after_step` instead of :meth:`before_step`.\n           The convention is that :meth:`before_step` should only take negligible time.\n\n           Following this convention will allow hooks that do care about the difference\n           between :meth:`before_step` and :meth:`after_step` (e.g., timer) to\n           function properly.\n\n    \"\"\"\n\n    trainer: \"TrainerBase\" = None\n    \"\"\"\n    A weak reference to the trainer object. 
Set by the trainer when the hook is registered.\n    \"\"\"\n\n    def before_train(self):\n        \"\"\"\n        Called before the first iteration.\n        \"\"\"\n        pass\n\n    def after_train(self):\n        \"\"\"\n        Called after the last iteration.\n        \"\"\"\n        pass\n\n    def before_step(self):\n        \"\"\"\n        Called before each iteration.\n        \"\"\"\n        pass\n\n    def after_step(self):\n        \"\"\"\n        Called after each iteration.\n        \"\"\"\n        pass\n\n    def state_dict(self):\n        \"\"\"\n        Hooks are stateless by default, but can be made checkpointable by\n        implementing `state_dict` and `load_state_dict`.\n        \"\"\"\n        return {}\n    \n# -*- coding: utf-8 -*-\n# Copyright (c) Facebook, Inc. and its affiliates.\n\nimport datetime\nimport itertools\nimport logging\nimport math\nimport operator\nimport os\nimport tempfile\nimport time\nimport warnings\nfrom collections import Counter\nimport torch\nfrom fvcore.common.checkpoint import Checkpointer\nfrom fvcore.common.checkpoint import PeriodicCheckpointer as _PeriodicCheckpointer\nfrom fvcore.common.param_scheduler import ParamScheduler\nfrom fvcore.common.timer import Timer\nfrom fvcore.nn.precise_bn import get_bn_modules, update_bn_stats\n\nimport detectron2.utils.comm as comm\nfrom detectron2.evaluation.testing import flatten_results_dict\nfrom detectron2.solver import LRMultiplier\nfrom detectron2.utils.events import EventStorage, EventWriter\nfrom detectron2.utils.file_io import PathManager\n\n# from .train_net_check import HookBase\n\n# __all__ = [\n#     \"CallbackHook\",\n#     \"IterationTimer\",\n#     \"PeriodicWriter\",\n#     \"PeriodicCheckpointer\",\n#     \"BestCheckpointer\",\n#     \"LRScheduler\",\n#     \"AutogradProfiler\",\n#     \"EvalHook\",\n#     \"PreciseBN\",\n#     \"TorchProfiler\",\n#     \"TorchMemoryStats\",\n# ]\n\n\n\"\"\"\nImplement some common hooks.\n\"\"\"\n\n\nclass 
CallbackHook(HookBase):\n    \"\"\"\n    Create a hook using callback functions provided by the user.\n    \"\"\"\n\n    def __init__(self, *, before_train=None, after_train=None, before_step=None, after_step=None):\n        \"\"\"\n        Each argument is a function that takes one argument: the trainer.\n        \"\"\"\n        self._before_train = before_train\n        self._before_step = before_step\n        self._after_step = after_step\n        self._after_train = after_train\n\n    def before_train(self):\n        if self._before_train:\n            self._before_train(self.trainer)\n\n    def after_train(self):\n        if self._after_train:\n            self._after_train(self.trainer)\n        # The functions may be closures that hold reference to the trainer\n        # Therefore, delete them to avoid circular reference.\n        del self._before_train, self._after_train\n        del self._before_step, self._after_step\n\n    def before_step(self):\n        if self._before_step:\n            self._before_step(self.trainer)\n\n    def after_step(self):\n        if self._after_step:\n            self._after_step(self.trainer)\n\n\nclass IterationTimer(HookBase):\n    \"\"\"\n    Track the time spent for each iteration (each run_step call in the trainer).\n    Print a summary in the end of training.\n\n    This hook uses the time between the call to its :meth:`before_step`\n    and :meth:`after_step` methods.\n    Under the convention that :meth:`before_step` of all hooks should only\n    take negligible amount of time, the :class:`IterationTimer` hook should be\n    placed at the beginning of the list of hooks to obtain accurate timing.\n    \"\"\"\n\n    def __init__(self, warmup_iter=3):\n        \"\"\"\n        Args:\n            warmup_iter (int): the number of iterations at the beginning to exclude\n                from timing.\n        \"\"\"\n        self._warmup_iter = warmup_iter\n        self._step_timer = Timer()\n        self._start_time = 
time.perf_counter()\n        self._total_timer = Timer()\n\n    def before_train(self):\n        self._start_time = time.perf_counter()\n        self._total_timer.reset()\n        self._total_timer.pause()\n\n    def after_train(self):\n        logger = logging.getLogger(__name__)\n        total_time = time.perf_counter() - self._start_time\n        total_time_minus_hooks = self._total_timer.seconds()\n        hook_time = total_time - total_time_minus_hooks\n\n        num_iter = self.trainer.storage.iter + 1 - self.trainer.start_iter - self._warmup_iter\n\n        if num_iter > 0 and total_time_minus_hooks > 0:\n            # Speed is meaningful only after warmup\n            # NOTE this format is parsed by grep in some scripts\n            logger.info(\n                \"Overall training speed: {} iterations in {} ({:.4f} s / it)\".format(\n                    num_iter,\n                    str(datetime.timedelta(seconds=int(total_time_minus_hooks))),\n                    total_time_minus_hooks / num_iter,\n                )\n            )\n\n        logger.info(\n            \"Total training time: {} ({} on hooks)\".format(\n                str(datetime.timedelta(seconds=int(total_time))),\n                str(datetime.timedelta(seconds=int(hook_time))),\n            )\n        )\n\n    def before_step(self):\n        self._step_timer.reset()\n        self._total_timer.resume()\n\n    def after_step(self):\n        # +1 because we're in after_step, the current step is done\n        # but not yet counted\n        iter_done = self.trainer.storage.iter - self.trainer.start_iter + 1\n        if iter_done >= self._warmup_iter:\n            sec = self._step_timer.seconds()\n            self.trainer.storage.put_scalars(time=sec)\n        else:\n            self._start_time = time.perf_counter()\n            self._total_timer.reset()\n\n        self._total_timer.pause()\n\n\nclass PeriodicWriter(HookBase):\n    \"\"\"\n    Write events to EventStorage (by calling 
``writer.write()``) periodically.\n\n    It is executed every ``period`` iterations and after the last iteration.\n    Note that ``period`` does not affect how data is smoothed by each writer.\n    \"\"\"\n\n    def __init__(self, writers, period=20):\n        \"\"\"\n        Args:\n            writers (list[EventWriter]): a list of EventWriter objects\n            period (int):\n        \"\"\"\n        self._writers = writers\n        for w in writers:\n            assert isinstance(w, EventWriter), w\n        self._period = period\n\n    def after_step(self):\n        if (self.trainer.iter + 1) % self._period == 0 or (\n            self.trainer.iter == self.trainer.max_iter - 1\n        ):\n            for writer in self._writers:\n                writer.write()\n\n    def after_train(self):\n        for writer in self._writers:\n            # If any new data is found (e.g. produced by other after_train),\n            # write them before closing\n            writer.write()\n            writer.close()\n\n\nclass PeriodicCheckpointer(_PeriodicCheckpointer, HookBase):\n    \"\"\"\n    Same as :class:`detectron2.checkpoint.PeriodicCheckpointer`, but as a hook.\n\n    Note that when used as a hook,\n    it is unable to save additional data other than what's defined\n    by the given `checkpointer`.\n\n    It is executed every ``period`` iterations and after the last iteration.\n    \"\"\"\n\n    def before_train(self):\n        self.max_iter = self.trainer.max_iter\n\n    def after_step(self):\n        # No way to use **kwargs\n        self.step(self.trainer.iter)\n\n\nclass BestCheckpointer(HookBase):\n    \"\"\"\n    Checkpoints best weights based off given metric.\n\n    This hook should be used in conjunction to and executed after the hook\n    that produces the metric, e.g. 
`EvalHook`.\n    \"\"\"\n\n    def __init__(\n        self,\n        eval_period: int,\n        checkpointer: Checkpointer,\n        val_metric: str,\n        mode: str = \"max\",\n        file_prefix: str = \"model_best\",\n    ) -> None:\n        \"\"\"\n        Args:\n            eval_period (int): the period `EvalHook` is set to run.\n            checkpointer: the checkpointer object used to save checkpoints.\n            val_metric (str): validation metric to track for best checkpoint, e.g. \"bbox/AP50\"\n            mode (str): one of {'max', 'min'}. controls whether the chosen val metric should be\n                maximized or minimized, e.g. for \"bbox/AP50\" it should be \"max\"\n            file_prefix (str): the prefix of checkpoint's filename, defaults to \"model_best\"\n        \"\"\"\n        self._logger = logging.getLogger(__name__)\n        self._period = eval_period\n        self._val_metric = val_metric\n        assert mode in [\n            \"max\",\n            \"min\",\n        ], f'Mode \"{mode}\" to `BestCheckpointer` is unknown. 
It should be one of {\"max\", \"min\"}.'\n        if mode == \"max\":\n            self._compare = operator.gt\n        else:\n            self._compare = operator.lt\n        self._checkpointer = checkpointer\n        self._file_prefix = file_prefix\n        self.best_metric = None\n        self.best_iter = None\n\n    def _update_best(self, val, iteration):\n        if math.isnan(val) or math.isinf(val):\n            return False\n        self.best_metric = val\n        self.best_iter = iteration\n        return True\n\n    def _best_checking(self):\n        metric_tuple = self.trainer.storage.latest().get(self._val_metric)\n        if metric_tuple is None:\n            self._logger.warning(\n                f\"Given val metric {self._val_metric} does not seem to be computed/stored.\"\n                \"Will not be checkpointing based on it.\"\n            )\n            return\n        else:\n            latest_metric, metric_iter = metric_tuple\n\n        if self.best_metric is None:\n            if self._update_best(latest_metric, metric_iter):\n                additional_state = {\"iteration\": metric_iter}\n                self._checkpointer.save(f\"{self._file_prefix}\", **additional_state)\n                self._logger.info(\n                    f\"Saved first model at {self.best_metric:0.5f} @ {self.best_iter} steps\"\n                )\n        elif self._compare(latest_metric, self.best_metric):\n            additional_state = {\"iteration\": metric_iter}\n            self._checkpointer.save(f\"{self._file_prefix}\", **additional_state)\n            self._logger.info(\n                f\"Saved best model as latest eval score for {self._val_metric} is \"\n                f\"{latest_metric:0.5f}, better than last best score \"\n                f\"{self.best_metric:0.5f} @ iteration {self.best_iter}.\"\n            )\n            self._update_best(latest_metric, metric_iter)\n        else:\n            self._logger.info(\n                f\"Not saving as 
latest eval score for {self._val_metric} is {latest_metric:0.5f}, \"\n                f\"not better than best score {self.best_metric:0.5f} @ iteration {self.best_iter}.\"\n            )\n\n    def after_step(self):\n        # same conditions as `EvalHook`\n        next_iter = self.trainer.iter + 1\n        if (\n            self._period > 0\n            and next_iter % self._period == 0\n            and next_iter != self.trainer.max_iter\n        ):\n            self._best_checking()\n\n    def after_train(self):\n        # same conditions as `EvalHook`\n        if self.trainer.iter + 1 >= self.trainer.max_iter:\n            self._best_checking()\n\n\nclass LRScheduler(HookBase):\n    \"\"\"\n    A hook which executes a torch builtin LR scheduler and summarizes the LR.\n    It is executed after every iteration.\n    \"\"\"\n\n    def __init__(self, optimizer=None, scheduler=None):\n        \"\"\"\n        Args:\n            optimizer (torch.optim.Optimizer):\n            scheduler (torch.optim.LRScheduler or fvcore.common.param_scheduler.ParamScheduler):\n                if a :class:`ParamScheduler` object, it defines the multiplier over the base LR\n                in the optimizer.\n\n        If any argument is not given, will try to obtain it from the trainer.\n        \"\"\"\n        self._optimizer = optimizer\n        self._scheduler = scheduler\n\n    def before_train(self):\n        self._optimizer = self._optimizer or self.trainer.optimizer\n        if isinstance(self.scheduler, ParamScheduler):\n            self._scheduler = LRMultiplier(\n                self._optimizer,\n                self.scheduler,\n                self.trainer.max_iter,\n                last_iter=self.trainer.iter - 1,\n            )\n        self._best_param_group_id = LRScheduler.get_best_param_group_id(self._optimizer)\n\n    @staticmethod\n    def get_best_param_group_id(optimizer):\n        # NOTE: some heuristics on what LR to summarize\n        # summarize the param group 
with most parameters\n        largest_group = max(len(g[\"params\"]) for g in optimizer.param_groups)\n\n        if largest_group == 1:\n            # If all groups have one parameter,\n            # then find the most common initial LR, and use it for summary\n            lr_count = Counter([g[\"lr\"] for g in optimizer.param_groups])\n            lr = lr_count.most_common()[0][0]\n            for i, g in enumerate(optimizer.param_groups):\n                if g[\"lr\"] == lr:\n                    return i\n        else:\n            for i, g in enumerate(optimizer.param_groups):\n                if len(g[\"params\"]) == largest_group:\n                    return i\n\n    def after_step(self):\n        lr = self._optimizer.param_groups[self._best_param_group_id][\"lr\"]\n        self.trainer.storage.put_scalar(\"lr\", lr, smoothing_hint=False)\n        self.scheduler.step()\n\n    @property\n    def scheduler(self):\n        return self._scheduler or self.trainer.scheduler\n\n    def state_dict(self):\n        if isinstance(self.scheduler, torch.optim.lr_scheduler._LRScheduler):\n            return self.scheduler.state_dict()\n        return {}\n\n    def load_state_dict(self, state_dict):\n        if isinstance(self.scheduler, torch.optim.lr_scheduler._LRScheduler):\n            logger = logging.getLogger(__name__)\n            logger.info(\"Loading scheduler from state_dict ...\")\n            self.scheduler.load_state_dict(state_dict)\n\n\nclass TorchProfiler(HookBase):\n    \"\"\"\n    A hook which runs `torch.profiler.profile`.\n\n    Examples:\n    ::\n        hooks.TorchProfiler(\n             lambda trainer: 10 < trainer.iter < 20, self.cfg.OUTPUT_DIR\n        )\n\n    The above example will run the profiler for iteration 10~20 and dump\n    results to ``OUTPUT_DIR``. 
We did not profile the first few iterations\n    because they are typically slower than the rest.\n    The result files can be loaded in the ``chrome://tracing`` page in chrome browser,\n    and the tensorboard visualizations can be visualized using\n    ``tensorboard --logdir OUTPUT_DIR/log``\n    \"\"\"\n\n    def __init__(self, enable_predicate, output_dir, *, activities=None, save_tensorboard=True):\n        \"\"\"\n        Args:\n            enable_predicate (callable[trainer -> bool]): a function which takes a trainer,\n                and returns whether to enable the profiler.\n                It will be called once every step, and can be used to select which steps to profile.\n            output_dir (str): the output directory to dump tracing files.\n            activities (iterable): same as in `torch.profiler.profile`.\n            save_tensorboard (bool): whether to save tensorboard visualizations at (output_dir)/log/\n        \"\"\"\n        self._enable_predicate = enable_predicate\n        self._activities = activities\n        self._output_dir = output_dir\n        self._save_tensorboard = save_tensorboard\n\n    def before_step(self):\n        if self._enable_predicate(self.trainer):\n            if self._save_tensorboard:\n                on_trace_ready = torch.profiler.tensorboard_trace_handler(\n                    os.path.join(\n                        self._output_dir,\n                        \"log\",\n                        \"profiler-tensorboard-iter{}\".format(self.trainer.iter),\n                    ),\n                    f\"worker{comm.get_rank()}\",\n                )\n            else:\n                on_trace_ready = None\n            self._profiler = torch.profiler.profile(\n                activities=self._activities,\n                on_trace_ready=on_trace_ready,\n                record_shapes=True,\n                profile_memory=True,\n                with_stack=True,\n                with_flops=True,\n            )\n         
   self._profiler.__enter__()\n        else:\n            self._profiler = None\n\n    def after_step(self):\n        if self._profiler is None:\n            return\n        self._profiler.__exit__(None, None, None)\n        if not self._save_tensorboard:\n            PathManager.mkdirs(self._output_dir)\n            out_file = os.path.join(\n                self._output_dir, \"profiler-trace-iter{}.json\".format(self.trainer.iter)\n            )\n            if \"://\" not in out_file:\n                self._profiler.export_chrome_trace(out_file)\n            else:\n                # Support non-posix filesystems\n                with tempfile.TemporaryDirectory(prefix=\"detectron2_profiler\") as d:\n                    tmp_file = os.path.join(d, \"tmp.json\")\n                    self._profiler.export_chrome_trace(tmp_file)\n                    with open(tmp_file) as f:\n                        content = f.read()\n                with PathManager.open(out_file, \"w\") as f:\n                    f.write(content)\n\n\nclass AutogradProfiler(TorchProfiler):\n    \"\"\"\n    A hook which runs `torch.autograd.profiler.profile`.\n\n    Examples:\n    ::\n        hooks.AutogradProfiler(\n             lambda trainer: 10 < trainer.iter < 20, self.cfg.OUTPUT_DIR\n        )\n\n    The above example will run the profiler for iteration 10~20 and dump\n    results to ``OUTPUT_DIR``. We did not profile the first few iterations\n    because they are typically slower than the rest.\n    The result files can be loaded in the ``chrome://tracing`` page in chrome browser.\n\n    Note:\n        When used together with NCCL on older version of GPUs,\n        autograd profiler may cause deadlock because it unnecessarily allocates\n        memory on every device it sees. 
The memory management calls, if\n        interleaved with NCCL calls, lead to deadlock on GPUs that do not\n        support ``cudaLaunchCooperativeKernelMultiDevice``.\n    \"\"\"\n\n    def __init__(self, enable_predicate, output_dir, *, use_cuda=True):\n        \"\"\"\n        Args:\n            enable_predicate (callable[trainer -> bool]): a function which takes a trainer,\n                and returns whether to enable the profiler.\n                It will be called once every step, and can be used to select which steps to profile.\n            output_dir (str): the output directory to dump tracing files.\n            use_cuda (bool): same as in `torch.autograd.profiler.profile`.\n        \"\"\"\n        warnings.warn(\"AutogradProfiler has been deprecated in favor of TorchProfiler.\")\n        self._enable_predicate = enable_predicate\n        self._use_cuda = use_cuda\n        self._output_dir = output_dir\n\n    def before_step(self):\n        if self._enable_predicate(self.trainer):\n            self._profiler = torch.autograd.profiler.profile(use_cuda=self._use_cuda)\n            self._profiler.__enter__()\n        else:\n            self._profiler = None\n\n\nclass EvalHook(HookBase):\n    \"\"\"\n    Run an evaluation function periodically, and at the end of training.\n\n    It is executed every ``eval_period`` iterations and after the last iteration.\n    \"\"\"\n\n    def __init__(self, eval_period, eval_function, eval_after_train=True):\n        \"\"\"\n        Args:\n            eval_period (int): the period to run `eval_function`. 
Set to 0 to\n                not evaluate periodically (but still evaluate after the last iteration\n                if `eval_after_train` is True).\n            eval_function (callable): a function which takes no arguments, and\n                returns a nested dict of evaluation metrics.\n            eval_after_train (bool): whether to evaluate after the last iteration\n\n        Note:\n            This hook must be enabled in all or none workers.\n            If you would like only certain workers to perform evaluation,\n            give other workers a no-op function (`eval_function=lambda: None`).\n        \"\"\"\n        self._period = eval_period\n        self._func = eval_function\n        self._eval_after_train = eval_after_train\n\n    def _do_eval(self):\n        results = self._func()\n\n        if results:\n            assert isinstance(\n                results, dict\n            ), \"Eval function must return a dict. Got {} instead.\".format(results)\n\n            flattened_results = flatten_results_dict(results)\n            for k, v in flattened_results.items():\n                try:\n                    v = float(v)\n                except Exception as e:\n                    raise ValueError(\n                        \"[EvalHook] eval_function should return a nested dict of float. 
\"\n                        \"Got '{}: {}' instead.\".format(k, v)\n                    ) from e\n            self.trainer.storage.put_scalars(**flattened_results, smoothing_hint=False)\n\n        # Evaluation may take different time among workers.\n        # A barrier make them start the next iteration together.\n        comm.synchronize()\n\n    def after_step(self):\n        next_iter = self.trainer.iter + 1\n        if self._period > 0 and next_iter % self._period == 0:\n            # do the last eval in after_train\n            if next_iter != self.trainer.max_iter:\n                self._do_eval()\n\n    def after_train(self):\n        # This condition is to prevent the eval from running after a failed training\n        if self._eval_after_train and self.trainer.iter + 1 >= self.trainer.max_iter:\n            self._do_eval()\n        # func is likely a closure that holds reference to the trainer\n        # therefore we clean it to avoid circular reference in the end\n        del self._func\n\n\nclass PreciseBN(HookBase):\n    \"\"\"\n    The standard implementation of BatchNorm uses EMA in inference, which is\n    sometimes suboptimal.\n    This class computes the true average of statistics rather than the moving average,\n    and put true averages to every BN layer in the given model.\n\n    It is executed every ``period`` iterations and after the last iteration.\n    \"\"\"\n\n    def __init__(self, period, model, data_loader, num_iter):\n        \"\"\"\n        Args:\n            period (int): the period this hook is run, or 0 to not run during training.\n                The hook will always run in the end of training.\n            model (nn.Module): a module whose all BN layers in training mode will be\n                updated by precise BN.\n                Note that user is responsible for ensuring the BN layers to be\n                updated are in training mode when this hook is triggered.\n            data_loader (iterable): it will produce data to 
be run by `model(data)`.\n            num_iter (int): number of iterations used to compute the precise\n                statistics.\n        \"\"\"\n        self._logger = logging.getLogger(__name__)\n        if len(get_bn_modules(model)) == 0:\n            self._logger.info(\n                \"PreciseBN is disabled because model does not contain BN layers in training mode.\"\n            )\n            self._disabled = True\n            return\n\n        self._model = model\n        self._data_loader = data_loader\n        self._num_iter = num_iter\n        self._period = period\n        self._disabled = False\n\n        self._data_iter = None\n\n    def after_step(self):\n        next_iter = self.trainer.iter + 1\n        is_final = next_iter == self.trainer.max_iter\n        if is_final or (self._period > 0 and next_iter % self._period == 0):\n            self.update_stats()\n\n    def update_stats(self):\n        \"\"\"\n        Update the model with precise statistics. Users can manually call this method.\n        \"\"\"\n        if self._disabled:\n            return\n\n        if self._data_iter is None:\n            self._data_iter = iter(self._data_loader)\n\n        def data_loader():\n            for num_iter in itertools.count(1):\n                if num_iter % 100 == 0:\n                    self._logger.info(\n                        \"Running precise-BN ... {}/{} iterations.\".format(num_iter, self._num_iter)\n                    )\n                # This way we can reuse the same iterator\n                yield next(self._data_iter)\n\n        with EventStorage():  # capture events in a new storage to discard them\n            self._logger.info(\n                \"Running precise-BN for {} iterations...  
\".format(self._num_iter)\n                + \"Note that this could produce different statistics every time.\"\n            )\n            update_bn_stats(self._model, data_loader(), self._num_iter)\n\n\nclass TorchMemoryStats(HookBase):\n    \"\"\"\n    Writes pytorch's cuda memory statistics periodically.\n    \"\"\"\n\n    def __init__(self, period=20, max_runs=10):\n        \"\"\"\n        Args:\n            period (int): Output stats each 'period' iterations\n            max_runs (int): Stop the logging after 'max_runs'\n        \"\"\"\n\n        self._logger = logging.getLogger(__name__)\n        self._period = period\n        self._max_runs = max_runs\n        self._runs = 0\n\n    def after_step(self):\n        if self._runs > self._max_runs:\n            return\n\n        if (self.trainer.iter + 1) % self._period == 0 or (\n            self.trainer.iter == self.trainer.max_iter - 1\n        ):\n            if torch.cuda.is_available():\n                max_reserved_mb = torch.cuda.max_memory_reserved() / 1024.0 / 1024.0\n                reserved_mb = torch.cuda.memory_reserved() / 1024.0 / 1024.0\n                max_allocated_mb = torch.cuda.max_memory_allocated() / 1024.0 / 1024.0\n                allocated_mb = torch.cuda.memory_allocated() / 1024.0 / 1024.0\n\n                self._logger.info(\n                    (\n                        \" iter: {} \"\n                        \" max_reserved_mem: {:.0f}MB \"\n                        \" reserved_mem: {:.0f}MB \"\n                        \" max_allocated_mem: {:.0f}MB \"\n                        \" allocated_mem: {:.0f}MB \"\n                    ).format(\n                        self.trainer.iter,\n                        max_reserved_mb,\n                        reserved_mb,\n                        max_allocated_mb,\n                        allocated_mb,\n                    )\n                )\n\n                self._runs += 1\n                if self._runs == self._max_runs:\n                 
   mem_summary = torch.cuda.memory_summary()\n                    self._logger.info(\"\\n\" + mem_summary)\n\n                torch.cuda.reset_peak_memory_stats()"
  },
  {
    "path": "llava/model/semsam/modules/matcher.py",
    "content": "# ------------------------------------------------------------------------\n# DINO\n# Copyright (c) 2022 IDEA. All Rights Reserved.\n# Licensed under the Apache License, Version 2.0 [see LICENSE for details]\n# ------------------------------------------------------------------------\n# Modified from DINO https://github.com/IDEA-Research/DINO by Feng Li and Hao Zhang.\n\n\"\"\"\nModules to compute the matching cost and solve the corresponding LSAP.\n\"\"\"\nimport torch\nimport torch.nn.functional as F\nimport numpy as np\nfrom scipy.optimize import linear_sum_assignment\nfrom torch import nn\nfrom torch.cuda.amp import autocast\n\nfrom detectron2.projects.point_rend.point_features import point_sample\nfrom ..utils.box_ops import generalized_box_iou,box_cxcywh_to_xyxy\n# from ..language.loss import vl_similarity\n\ndef batch_dice_loss(inputs: torch.Tensor, targets: torch.Tensor):\n    \"\"\"\n    Compute the DICE loss, similar to generalized IOU for masks\n    Args:\n        inputs: A float tensor of arbitrary shape.\n                The predictions for each example.\n        targets: A float tensor with the same shape as inputs. Stores the binary\n                 classification label for each element in inputs\n                (0 for the negative class and 1 for the positive class).\n    \"\"\"\n    inputs = inputs.sigmoid()\n    inputs = inputs.flatten(1)\n    numerator = 2 * torch.einsum(\"nc,mc->nm\", inputs, targets)\n    denominator = inputs.sum(-1)[:, None] + targets.sum(-1)[None, :]\n    loss = 1 - (numerator + 1) / (denominator + 1)\n    return loss\n\n\nbatch_dice_loss_jit = torch.jit.script(\n    batch_dice_loss\n)  # type: torch.jit.ScriptModule\n\n\ndef batch_sigmoid_ce_loss(inputs: torch.Tensor, targets: torch.Tensor):\n    \"\"\"\n    Args:\n        inputs: A float tensor of arbitrary shape.\n                The predictions for each example.\n        targets: A float tensor with the same shape as inputs. 
Stores the binary\n                 classification label for each element in inputs\n                (0 for the negative class and 1 for the positive class).\n    Returns:\n        Loss tensor\n    \"\"\"\n    hw = inputs.shape[1]\n\n    pos = F.binary_cross_entropy_with_logits(\n        inputs, torch.ones_like(inputs), reduction=\"none\"\n    )\n    neg = F.binary_cross_entropy_with_logits(\n        inputs, torch.zeros_like(inputs), reduction=\"none\"\n    )\n\n    loss = torch.einsum(\"nc,mc->nm\", pos, targets) + torch.einsum(\n        \"nc,mc->nm\", neg, (1 - targets)\n    )\n\n    return loss / hw\n\n\nbatch_sigmoid_ce_loss_jit = torch.jit.script(\n    batch_sigmoid_ce_loss\n)  # type: torch.jit.ScriptModule\n\n\nclass HungarianMatcher(nn.Module):\n    \"\"\"This class computes an assignment between the targets and the predictions of the network\n\n    For efficiency reasons, the targets don't include the no_object. Because of this, in general,\n    there are more predictions than targets. 
In this case, we do a 1-to-1 matching of the best predictions,\n    while the others are un-matched (and thus treated as non-objects).\n    \"\"\"\n\n    def __init__(self, cost_class: float = 1, cost_mask: float = 1, cost_dice: float = 1, num_points: int = 0,\n                 cost_box: float = 0, cost_giou: float = 0, panoptic_on: bool = False):\n        \"\"\"Creates the matcher\n\n        Params:\n            cost_class: This is the relative weight of the classification error in the matching cost\n            cost_mask: This is the relative weight of the focal loss of the binary mask in the matching cost\n            cost_dice: This is the relative weight of the dice loss of the binary mask in the matching cost\n        \"\"\"\n        super().__init__()\n        self.cost_class = cost_class\n        self.cost_mask = cost_mask\n        self.cost_dice = cost_dice\n        self.cost_box = cost_box\n        self.cost_giou = cost_giou\n\n        self.panoptic_on = panoptic_on\n\n        assert cost_class != 0 or cost_mask != 0 or cost_dice != 0, \"all costs cant be 0\"\n\n        self.num_points = num_points\n\n    @torch.no_grad()\n    def memory_efficient_forward(self, outputs, targets, cost=[\"cls\", \"box\", \"mask\"]):\n        \"\"\"More memory-friendly matching. 
Change cost to compute only certain loss in matching\"\"\"\n        bs, num_queries = outputs[\"pred_logits\"].shape[:2]\n\n        indices = []\n\n        # Iterate through batch size\n        for b in range(bs):\n            out_bbox = outputs[\"pred_boxes\"][b]\n            if 'box' in cost:\n                tgt_bbox=targets[b][\"boxes\"]\n                cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)\n                cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox))\n            else:\n                cost_bbox = torch.tensor(0).to(out_bbox)\n                cost_giou = torch.tensor(0).to(out_bbox)\n\n            out_prob = outputs[\"pred_logits\"][b].sigmoid()  # [num_queries, num_classes]\n            tgt_ids = targets[b][\"labels\"]\n            # focal loss\n            alpha = 0.25\n            gamma = 2.0\n            neg_cost_class = (1 - alpha) * (out_prob ** gamma) * (-(1 - out_prob + 1e-6).log())\n            pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-6).log())\n            cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids]\n\n            # Compute the classification cost. 
Contrary to the loss, we don't use the NLL,\n            # but approximate it in 1 - proba[target class].\n            # The 1 is a constant that doesn't change the matching, it can be ommitted.\n            # cost_class = -out_prob[:, tgt_ids]\n            if 'mask' in cost:\n                out_mask = outputs[\"pred_masks\"][b]  # [num_queries, H_pred, W_pred]\n                # gt masks are already padded when preparing target\n                tgt_mask = targets[b][\"masks\"].to(out_mask)\n\n                out_mask = out_mask[:, None]\n                tgt_mask = tgt_mask[:, None]\n                # all masks share the same set of points for efficient matching!\n                point_coords = torch.rand(1, self.num_points, 2, device=out_mask.device, dtype=tgt_mask.dtype)\n                # get gt labels\n                tgt_mask = point_sample(\n                    tgt_mask,\n                    point_coords.repeat(tgt_mask.shape[0], 1, 1),\n                    align_corners=False,\n                ).squeeze(1)\n\n                out_mask = point_sample(\n                    out_mask,\n                    point_coords.repeat(out_mask.shape[0], 1, 1),\n                    align_corners=False,\n                ).squeeze(1)\n\n                with autocast(enabled=False):\n                    out_mask = out_mask.float()\n                    tgt_mask = tgt_mask.float()\n                    # Compute the focal loss between masks\n                    cost_mask = batch_sigmoid_ce_loss_jit(out_mask, tgt_mask)\n\n                    # Compute the dice loss betwen masks\n                    cost_dice = batch_dice_loss_jit(out_mask, tgt_mask)\n            else:\n                cost_mask = torch.tensor(0).to(out_bbox)\n                cost_dice = torch.tensor(0).to(out_bbox)\n            \n            # Final cost matrix\n            if self.panoptic_on:\n                isthing = tgt_ids<80\n                cost_bbox[:, ~isthing] = cost_bbox[:, isthing].mean()\n           
     cost_giou[:, ~isthing] = cost_giou[:, isthing].mean()\n                cost_bbox[cost_bbox.isnan()] = 0.0\n                cost_giou[cost_giou.isnan()] = 0.0\n\n            C = (\n                self.cost_mask * cost_mask\n                + self.cost_class * cost_class\n                + self.cost_dice * cost_dice\n                + self.cost_box*cost_bbox\n                + self.cost_giou*cost_giou\n            )\n            C = C.reshape(num_queries, -1).cpu()\n            indices.append(linear_sum_assignment(C))\n\n        return [\n            (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))\n            for i, j in indices\n        ]\n\n    @torch.no_grad()\n    def grounding_forward(self, outputs, targets, extra):\n        \"\"\"More memory-friendly matching\"\"\"\n        bs, num_queries = outputs[\"pred_gmasks\"].shape[:2]\n        \n        if bs == 0 or len(targets) == 0:\n            return None\n\n        indices = []\n        # Iterate through batch size\n        for b in range(bs):\n            out_prob = outputs[\"pred_logits\"][b]\n            # Compute the classification cost. 
Contrary to the loss, we don't use the NLL,\n            # but approximate it in 1 - proba[target class].\n            # The 1 is a constant that doesn't change the matching, it can be ommitted.\n            cost_class = -out_prob.softmax(dim=0)\n\n            out_mask = outputs[\"pred_gmasks\"][b]  # [num_queries, H_pred, W_pred]\n            # gt masks are already padded when preparing target\n            tgt_mask = targets[b][\"grounding_masks\"].to(out_mask)\n\n            out_mask = out_mask[:, None]\n            tgt_mask = tgt_mask[:, None]\n            \n            # all masks share the same set of points for efficient matching!\n            point_coords = torch.rand(1, self.num_points, 2, device=out_mask.device, dtype=tgt_mask.dtype)\n            # get gt labels\n            tgt_mask = point_sample(\n                tgt_mask,\n                point_coords.repeat(tgt_mask.shape[0], 1, 1),\n                align_corners=False,\n            ).squeeze(1)\n\n            out_mask = point_sample(\n                out_mask,\n                point_coords.repeat(out_mask.shape[0], 1, 1),\n                align_corners=False,\n            ).squeeze(1)\n\n            with autocast(enabled=False):\n                out_mask = out_mask.float()\n                tgt_mask = tgt_mask.float()\n                # Compute the focal loss between masks\n                cost_mask = batch_sigmoid_ce_loss_jit(out_mask, tgt_mask)\n\n                # Compute the dice loss betwen masks\n                cost_dice = batch_dice_loss_jit(out_mask, tgt_mask)\n                \n            # Final cost matrix\n            C = (\n                self.cost_mask * cost_mask\n                + self.cost_class * cost_class\n                + self.cost_dice * cost_dice\n            )\n            C = C.reshape(num_queries, -1).cpu()\n            indices.append(linear_sum_assignment(C))\n\n        return [\n            (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))\n 
           for i, j in indices\n        ]\n\n\n    @torch.no_grad()\n    def caption_forward_womask(self, outputs, targets, extra):\n        \"\"\"More memory-friendly matching\"\"\"\n        bs, _ = outputs[\"pred_logits\"].shape[:2]\n\n        if bs == 0 or len(targets) == 0:\n            return None\n\n        indices = []\n        t_emb = torch.cat([t['captions'] for t in targets])\n        v_emb = outputs['unmatched_pred_captions']\n        caption_target_count = np.cumsum([0] + [len(t['captions']) for t in targets])\n\n        # Iterate through batch size\n        for b in range(bs):\n            v_emb[b] = v_emb[b] / (v_emb[b].norm(dim=-1, keepdim=True) + 1e-7)\n            num_queries = len(v_emb[b])\n            out_prob = vl_similarity(v_emb[b][None,], t_emb, temperature=extra['temperature']).softmax(-1)[0]\n            tgt_ids = [idx for idx in range(caption_target_count[b], caption_target_count[b+1])]\n\n            # Compute the classification cost. Contrary to the loss, we don't use the NLL,\n            # but approximate it in 1 - proba[target class].\n            # The 1 is a constant that doesn't change the matching, it can be ommitted.\n            cost_class = -out_prob[:, tgt_ids]\n\n            # Final cost matrix\n            C = (self.cost_class * cost_class)\n            C = C.reshape(num_queries, -1).cpu()\n            indices.append(linear_sum_assignment(C))\n\n        return [\n            (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))\n            for i, j in indices\n        ]\n\n\n    @torch.no_grad()\n    def caption_forward_wmask(self, outputs, targets, extra):\n        \"\"\"More memory-friendly matching\"\"\"\n        bs, _ = outputs[\"pred_logits\"].shape[:2]\n\n        if bs == 0 or len(targets) == 0:\n            return None\n\n        indices = []\n        t_emb = torch.cat([t['captions'] for t in targets])\n        v_emb = outputs['unmatched_pred_captions']\n        caption_target_count = 
np.cumsum([0] + [len(t['captions']) for t in targets])\n        \n        # Iterate through batch size\n        for b in range(bs):\n            v_emb[b] = v_emb[b] / (v_emb[b].norm(dim=-1, keepdim=True) + 1e-7)\n            num_queries = len(v_emb[b])\n            \n            out_prob = vl_similarity(v_emb[b][None,], t_emb, temperature=extra['temperature']).softmax(-1)[0]\n            tgt_ids = [idx for idx in range(caption_target_count[b], caption_target_count[b+1])]\n\n            # Compute the classification cost. Contrary to the loss, we don't use the NLL,\n            # but approximate it in 1 - proba[target class].\n            # The 1 is a constant that doesn't change the matching, it can be ommitted.\n            cost_class = -out_prob[:, tgt_ids]\n\n            out_mask = outputs[\"pred_masks\"][b]  # [num_queries, H_pred, W_pred]\n            # gt masks are already padded when preparing target\n            tgt_mask = targets[b][\"masks\"].to(out_mask)\n            \n            out_mask = out_mask[:, None]\n            tgt_mask = tgt_mask[:, None]\n            # all masks share the same set of points for efficient matching!\n            point_coords = torch.rand(1, self.num_points, 2, device=out_mask.device, dtype=tgt_mask.dtype)\n            # get gt labels\n            tgt_mask = point_sample(\n                tgt_mask,\n                point_coords.repeat(tgt_mask.shape[0], 1, 1),\n                align_corners=False,\n            ).squeeze(1)\n\n            out_mask = point_sample(\n                out_mask,\n                point_coords.repeat(out_mask.shape[0], 1, 1),\n                align_corners=False,\n            ).squeeze(1)\n\n            with autocast(enabled=False):\n                out_mask = out_mask.float()\n                tgt_mask = tgt_mask.float()\n                # Compute the focal loss between masks\n                cost_mask = batch_sigmoid_ce_loss_jit(out_mask, tgt_mask)\n\n                # Compute the dice loss betwen 
masks\n                cost_dice = batch_dice_loss_jit(out_mask, tgt_mask)\n\n            # Final cost matrix\n            C = (\n                self.cost_mask * cost_mask\n                + self.cost_class * cost_class\n                + self.cost_dice * cost_dice\n            )\n            C = C.reshape(num_queries, -1).cpu()\n            indices.append(linear_sum_assignment(C))\n\n        return [\n            (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))\n            for i, j in indices\n        ]\n\n    @torch.no_grad()\n    def forward(self, outputs, targets, cost=[\"cls\", \"box\", \"mask\"], mode='default', extra={}):\n        \"\"\"Performs the matching\n\n        Params:\n            outputs: This is a dict that contains at least these entries:\n                 \"pred_logits\": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits\n                 \"pred_masks\": Tensor of dim [batch_size, num_queries, H_pred, W_pred] with the predicted masks\n\n            targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:\n                 \"labels\": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth\n                           objects in the target) containing the class labels\n                 \"masks\": Tensor of dim [num_target_boxes, H_gt, W_gt] containing the target masks\n\n        Returns:\n            A list of size batch_size, containing tuples of (index_i, index_j) where:\n                - index_i is the indices of the selected predictions (in order)\n                - index_j is the indices of the corresponding selected targets (in order)\n            For each batch element, it holds:\n                len(index_i) = len(index_j) = min(num_queries, num_target_boxes)\n        \"\"\"\n        if mode == 'default':\n            return self.memory_efficient_forward(outputs, targets, cost)\n        elif mode 
== 'grounding':\n            return self.grounding_forward(outputs, targets, extra)\n        elif mode == 'caption_womask':\n            return self.caption_forward_womask(outputs, targets, extra)\n        elif mode == 'caption_wmask':\n            return self.caption_forward_wmask(outputs, targets, extra)\n        else:\n            assert False, \"Mode {} is not supported.\".format(mode)\n\n    def __repr__(self, _repr_indent=4):\n        head = \"Matcher \" + self.__class__.__name__\n        body = [\n            \"cost_class: {}\".format(self.cost_class),\n            \"cost_mask: {}\".format(self.cost_mask),\n            \"cost_dice: {}\".format(self.cost_dice),\n        ]\n        lines = [head] + [\" \" * _repr_indent + line for line in body]\n        return \"\\n\".join(lines)\n"
  },
  {
    "path": "llava/model/semsam/modules/point_features.py",
    "content": "# Copyright (c) Facebook, Inc. and its affiliates.\nimport torch\nfrom torch.nn import functional as F\n\nfrom detectron2.layers import cat, shapes_to_tensor\nfrom detectron2.structures import BitMasks, Boxes\n\n# from ..layers import cat, shapes_to_tensor\n# from ..structures import BitMasks, Boxes\n\n\"\"\"\nShape shorthand in this module:\n\n    N: minibatch dimension size, i.e. the number of RoIs for instance segmenation or the\n        number of images for semantic segmenation.\n    R: number of ROIs, combined over all images, in the minibatch\n    P: number of points\n\"\"\"\n\n\ndef point_sample(input, point_coords, **kwargs):\n    \"\"\"\n    A wrapper around :function:`torch.nn.functional.grid_sample` to support 3D point_coords tensors.\n    Unlike :function:`torch.nn.functional.grid_sample` it assumes `point_coords` to lie inside\n    [0, 1] x [0, 1] square.\n\n    Args:\n        input (Tensor): A tensor of shape (N, C, H, W) that contains features map on a H x W grid.\n        point_coords (Tensor): A tensor of shape (N, P, 2) or (N, Hgrid, Wgrid, 2) that contains\n        [0, 1] x [0, 1] normalized point coordinates.\n\n    Returns:\n        output (Tensor): A tensor of shape (N, C, P) or (N, C, Hgrid, Wgrid) that contains\n            features for points in `point_coords`. 
The features are obtained via bilinear\n            interplation from `input` the same way as :function:`torch.nn.functional.grid_sample`.\n    \"\"\"\n    add_dim = False\n    if point_coords.dim() == 3:\n        add_dim = True\n        point_coords = point_coords.unsqueeze(2)\n    output = F.grid_sample(input, 2.0 * point_coords - 1.0, **kwargs)\n    if add_dim:\n        output = output.squeeze(3)\n    return output\n\n\ndef generate_regular_grid_point_coords(R, side_size, device):\n    \"\"\"\n    Generate regular square grid of points in [0, 1] x [0, 1] coordinate space.\n\n    Args:\n        R (int): The number of grids to sample, one for each region.\n        side_size (int): The side size of the regular grid.\n        device (torch.device): Desired device of returned tensor.\n\n    Returns:\n        (Tensor): A tensor of shape (R, side_size^2, 2) that contains coordinates\n            for the regular grids.\n    \"\"\"\n    aff = torch.tensor([[[0.5, 0, 0.5], [0, 0.5, 0.5]]], device=device)\n    r = F.affine_grid(aff, torch.Size((1, 1, side_size, side_size)), align_corners=False)\n    return r.view(1, -1, 2).expand(R, -1, -1)\n\n\ndef get_uncertain_point_coords_with_randomness(\n    coarse_logits, uncertainty_func, num_points, oversample_ratio, importance_sample_ratio\n):\n    \"\"\"\n    Sample points in [0, 1] x [0, 1] coordinate space based on their uncertainty. 
The unceratinties\n        are calculated for each point using 'uncertainty_func' function that takes point's logit\n        prediction as input.\n    See PointRend paper for details.\n\n    Args:\n        coarse_logits (Tensor): A tensor of shape (N, C, Hmask, Wmask) or (N, 1, Hmask, Wmask) for\n            class-specific or class-agnostic prediction.\n        uncertainty_func: A function that takes a Tensor of shape (N, C, P) or (N, 1, P) that\n            contains logit predictions for P points and returns their uncertainties as a Tensor of\n            shape (N, 1, P).\n        num_points (int): The number of points P to sample.\n        oversample_ratio (int): Oversampling parameter.\n        importance_sample_ratio (float): Ratio of points that are sampled via importnace sampling.\n\n    Returns:\n        point_coords (Tensor): A tensor of shape (N, P, 2) that contains the coordinates of P\n            sampled points.\n    \"\"\"\n    assert oversample_ratio >= 1\n    assert importance_sample_ratio <= 1 and importance_sample_ratio >= 0\n    num_boxes = coarse_logits.shape[0]\n    num_sampled = int(num_points * oversample_ratio)\n    point_coords = torch.rand(num_boxes, num_sampled, 2, device=coarse_logits.device, dtype=coarse_logits.dtype)\n    point_logits = point_sample(coarse_logits, point_coords, align_corners=False)\n    # It is crucial to calculate uncertainty based on the sampled prediction value for the points.\n    # Calculating uncertainties of the coarse predictions first and sampling them for points leads\n    # to incorrect results.\n    # To illustrate this: assume uncertainty_func(logits)=-abs(logits), a sampled point between\n    # two coarse predictions with -1 and 1 logits has 0 logits, and therefore 0 uncertainty value.\n    # However, if we calculate uncertainties for the coarse predictions first,\n    # both will have -1 uncertainty, and the sampled point will get -1 uncertainty.\n    point_uncertainties = uncertainty_func(point_logits)\n 
   num_uncertain_points = int(importance_sample_ratio * num_points)\n    num_random_points = num_points - num_uncertain_points\n    idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]\n    shift = num_sampled * torch.arange(num_boxes, dtype=torch.long, device=coarse_logits.device)\n    idx += shift[:, None]\n    point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view(\n        num_boxes, num_uncertain_points, 2\n    )\n    if num_random_points > 0:\n        point_coords = cat(\n            [\n                point_coords,\n                torch.rand(num_boxes, num_random_points, 2, device=coarse_logits.device),\n            ],\n            dim=1,\n        )\n    return point_coords\n\n\ndef get_uncertain_point_coords_on_grid(uncertainty_map, num_points):\n    \"\"\"\n    Find `num_points` most uncertain points from `uncertainty_map` grid.\n\n    Args:\n        uncertainty_map (Tensor): A tensor of shape (N, 1, H, W) that contains uncertainty\n            values for a set of points on a regular H x W grid.\n        num_points (int): The number of points P to select.\n\n    Returns:\n        point_indices (Tensor): A tensor of shape (N, P) that contains indices from\n            [0, H x W) of the most uncertain points.\n        point_coords (Tensor): A tensor of shape (N, P, 2) that contains [0, 1] x [0, 1] normalized\n            coordinates of the most uncertain points from the H x W grid.\n    \"\"\"\n    R, _, H, W = uncertainty_map.shape\n    h_step = 1.0 / float(H)\n    w_step = 1.0 / float(W)\n\n    num_points = min(H * W, num_points)\n    point_indices = torch.topk(uncertainty_map.view(R, H * W), k=num_points, dim=1)[1]\n    point_coords = torch.zeros(R, num_points, 2, dtype=torch.float, device=uncertainty_map.device)\n    point_coords[:, :, 0] = w_step / 2.0 + (point_indices % W).to(torch.float) * w_step\n    point_coords[:, :, 1] = h_step / 2.0 + (point_indices // W).to(torch.float) * h_step\n    return point_indices, 
point_coords\n\n\ndef point_sample_fine_grained_features(features_list, feature_scales, boxes, point_coords):\n    \"\"\"\n    Get features from feature maps in `features_list` that correspond to specific point coordinates\n        inside each bounding box from `boxes`.\n\n    Args:\n        features_list (list[Tensor]): A list of feature map tensors to get features from.\n        feature_scales (list[float]): A list of scales for tensors in `features_list`.\n        boxes (list[Boxes]): A list of I Boxes  objects that contain R_1 + ... + R_I = R boxes all\n            together.\n        point_coords (Tensor): A tensor of shape (R, P, 2) that contains\n            [0, 1] x [0, 1] box-normalized coordinates of the P sampled points.\n\n    Returns:\n        point_features (Tensor): A tensor of shape (R, C, P) that contains features sampled\n            from all features maps in feature_list for P sampled points for all R boxes in `boxes`.\n        point_coords_wrt_image (Tensor): A tensor of shape (R, P, 2) that contains image-level\n            coordinates of P points.\n    \"\"\"\n    cat_boxes = Boxes.cat(boxes)\n    num_boxes = [b.tensor.size(0) for b in boxes]\n\n    point_coords_wrt_image = get_point_coords_wrt_image(cat_boxes.tensor, point_coords)\n    split_point_coords_wrt_image = torch.split(point_coords_wrt_image, num_boxes)\n\n    point_features = []\n    for idx_img, point_coords_wrt_image_per_image in enumerate(split_point_coords_wrt_image):\n        point_features_per_image = []\n        for idx_feature, feature_map in enumerate(features_list):\n            h, w = feature_map.shape[-2:]\n            scale = shapes_to_tensor([w, h]) / feature_scales[idx_feature]\n            point_coords_scaled = point_coords_wrt_image_per_image / scale.to(feature_map.device)\n            point_features_per_image.append(\n                point_sample(\n                    feature_map[idx_img].unsqueeze(0),\n                    point_coords_scaled.unsqueeze(0),\n         
           align_corners=False,\n                )\n                .squeeze(0)\n                .transpose(1, 0)\n            )\n        point_features.append(cat(point_features_per_image, dim=1))\n\n    return cat(point_features, dim=0), point_coords_wrt_image\n\n\ndef get_point_coords_wrt_image(boxes_coords, point_coords):\n    \"\"\"\n    Convert box-normalized [0, 1] x [0, 1] point cooordinates to image-level coordinates.\n\n    Args:\n        boxes_coords (Tensor): A tensor of shape (R, 4) that contains bounding boxes.\n            coordinates.\n        point_coords (Tensor): A tensor of shape (R, P, 2) that contains\n            [0, 1] x [0, 1] box-normalized coordinates of the P sampled points.\n\n    Returns:\n        point_coords_wrt_image (Tensor): A tensor of shape (R, P, 2) that contains\n            image-normalized coordinates of P sampled points.\n    \"\"\"\n    with torch.no_grad():\n        point_coords_wrt_image = point_coords.clone()\n        point_coords_wrt_image[:, :, 0] = point_coords_wrt_image[:, :, 0] * (\n            boxes_coords[:, None, 2] - boxes_coords[:, None, 0]\n        )\n        point_coords_wrt_image[:, :, 1] = point_coords_wrt_image[:, :, 1] * (\n            boxes_coords[:, None, 3] - boxes_coords[:, None, 1]\n        )\n        point_coords_wrt_image[:, :, 0] += boxes_coords[:, None, 0]\n        point_coords_wrt_image[:, :, 1] += boxes_coords[:, None, 1]\n    return point_coords_wrt_image\n\n\ndef sample_point_labels(instances, point_coords):\n    \"\"\"\n    Sample point labels from ground truth mask given point_coords.\n\n    Args:\n        instances (list[Instances]): A list of N Instances, where N is the number of images\n            in the batch. So, i_th elememt of the list contains R_i objects and R_1 + ... + R_N is\n            equal to R. 
The ground-truth gt_masks in each instance will be used to compute labels.\n        point_coords (Tensor): A tensor of shape (R, P, 2), where R is the total number of\n            instances and P is the number of points for each instance. The coordinates are in\n            the absolute image pixel coordinate space, i.e. (x, y) in [0, W] x [0, H].\n\n    Returns:\n        Tensor: A tensor of shape (R, P) that contains the labels of P sampled points.\n    \"\"\"\n    with torch.no_grad():\n        gt_mask_logits = []\n        point_coords_splits = torch.split(\n            point_coords, [len(instances_per_image) for instances_per_image in instances]\n        )\n        for i, instances_per_image in enumerate(instances):\n            if len(instances_per_image) == 0:\n                continue\n            assert isinstance(\n                instances_per_image.gt_masks, BitMasks\n            ), \"Point head works with GT in 'bitmask' format. Set INPUT.MASK_FORMAT to 'bitmask'.\"\n\n            gt_bit_masks = instances_per_image.gt_masks.tensor\n            h, w = instances_per_image.gt_masks.image_size\n            scale = torch.tensor([w, h], dtype=torch.float, device=gt_bit_masks.device)\n            points_coord_grid_sample_format = point_coords_splits[i] / scale\n            gt_mask_logits.append(\n                point_sample(\n                    gt_bit_masks.to(torch.float32).unsqueeze(1),\n                    points_coord_grid_sample_format,\n                    align_corners=False,\n                ).squeeze(1)\n            )\n\n    point_labels = cat(gt_mask_logits)\n    return point_labels\n"
  },
  {
    "path": "llava/model/semsam/modules/position_encoding.py",
    "content": "# Copyright (c) Facebook, Inc. and its affiliates.\n# # Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/position_encoding.py\n\"\"\"\nVarious positional encodings for the transformer.\n\"\"\"\nimport math\n\nimport torch\nfrom torch import nn\n\n\nclass PositionEmbeddingSine(nn.Module):\n    \"\"\"\n    This is a more standard version of the position embedding, very similar to the one\n    used by the Attention is all you need paper, generalized to work on images.\n    \"\"\"\n\n    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):\n        super().__init__()\n        self.num_pos_feats = num_pos_feats\n        self.temperature = temperature\n        self.normalize = normalize\n        if scale is not None and normalize is False:\n            raise ValueError(\"normalize should be True if scale is passed\")\n        if scale is None:\n            scale = 2 * math.pi\n        self.scale = scale\n\n    def forward(self, x, mask=None):\n        if mask is None:\n            mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)\n        not_mask = ~mask\n        y_embed = not_mask.cumsum(1, dtype=x.dtype)\n        x_embed = not_mask.cumsum(2, dtype=x.dtype)\n        if self.normalize:\n            eps = 1e-6\n            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale\n            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale\n\n        dim_t = torch.arange(self.num_pos_feats, dtype=x.dtype, device=x.device)\n        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)\n\n        pos_x = x_embed[:, :, :, None] / dim_t\n        pos_y = y_embed[:, :, :, None] / dim_t\n        pos_x = torch.stack(\n            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4\n        ).flatten(3)\n        pos_y = torch.stack(\n            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4\n     
   ).flatten(3)\n        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)\n        return pos\n    \n    def __repr__(self, _repr_indent=4):\n        head = \"Positional encoding \" + self.__class__.__name__\n        body = [\n            \"num_pos_feats: {}\".format(self.num_pos_feats),\n            \"temperature: {}\".format(self.temperature),\n            \"normalize: {}\".format(self.normalize),\n            \"scale: {}\".format(self.scale),\n        ]\n        # _repr_indent = 4\n        lines = [head] + [\" \" * _repr_indent + line for line in body]\n        return \"\\n\".join(lines)\n"
  },
  {
    "path": "llava/model/semsam/modules/postprocessing.py",
    "content": "# Copyright (c) Facebook, Inc. and its affiliates.\nimport torch\nfrom torch.nn import functional as F\n\nfrom detectron2.structures import Instances, ROIMasks\n\n\n# perhaps should rename to \"resize_instance\"\ndef detector_postprocess(\n    results: Instances, output_height: int, output_width: int, mask_threshold: float = 0.5\n):\n    \"\"\"\n    Resize the output instances.\n    The input images are often resized when entering an object detector.\n    As a result, we often need the outputs of the detector in a different\n    resolution from its inputs.\n\n    This function will resize the raw outputs of an R-CNN detector\n    to produce outputs according to the desired output resolution.\n\n    Args:\n        results (Instances): the raw outputs from the detector.\n            `results.image_size` contains the input image resolution the detector sees.\n            This object might be modified in-place.\n        output_height, output_width: the desired output resolution.\n\n    Returns:\n        Instances: the resized output from the model, based on the output resolution\n    \"\"\"\n    if isinstance(output_width, torch.Tensor):\n        # This shape might (but not necessarily) be tensors during tracing.\n        # Converts integer tensors to float temporaries to ensure true\n        # division is performed when computing scale_x and scale_y.\n        output_width_tmp = output_width.float()\n        output_height_tmp = output_height.float()\n        new_size = torch.stack([output_height, output_width])\n    else:\n        new_size = (output_height, output_width)\n        output_width_tmp = output_width\n        output_height_tmp = output_height\n\n    scale_x, scale_y = (\n        output_width_tmp / results.image_size[1],\n        output_height_tmp / results.image_size[0],\n    )\n    results = Instances(new_size, **results.get_fields())\n\n    if results.has(\"pred_boxes\"):\n        output_boxes = results.pred_boxes\n    elif 
results.has(\"proposal_boxes\"):\n        output_boxes = results.proposal_boxes\n    else:\n        output_boxes = None\n    assert output_boxes is not None, \"Predictions must contain boxes!\"\n\n    output_boxes.scale(scale_x, scale_y)\n    output_boxes.clip(results.image_size)\n\n    results = results[output_boxes.nonempty()]\n\n    if results.has(\"pred_masks\"):\n        if isinstance(results.pred_masks, ROIMasks):\n            roi_masks = results.pred_masks\n        else:\n            # pred_masks is a tensor of shape (N, 1, M, M)\n            roi_masks = ROIMasks(results.pred_masks[:, 0, :, :])\n        results.pred_masks = roi_masks.to_bitmasks(\n            results.pred_boxes, output_height, output_width, mask_threshold\n        ).tensor  # TODO return ROIMasks/BitMask object in the future\n\n    if results.has(\"pred_keypoints\"):\n        results.pred_keypoints[:, :, 0] *= scale_x\n        results.pred_keypoints[:, :, 1] *= scale_y\n\n    return results\n\ndef bbox_postprocess(result, input_size, img_size, output_height, output_width):\n    \"\"\"\n    result: [xc,yc,w,h] range [0,1] to [x1,y1,x2,y2] range [0,w], [0,h]\n    \"\"\"\n    if result is None:\n        return None\n    \n    scale = torch.tensor([input_size[1], input_size[0], input_size[1], input_size[0]])[None,:].to(result.device)\n    result = result.sigmoid() * scale\n    x1,y1,x2,y2 = result[:,0] - result[:,2]/2, result[:,1] - result[:,3]/2, result[:,0] + result[:,2]/2, result[:,1] + result[:,3]/2\n    h,w = img_size\n\n    x1 = x1.clamp(min=0, max=w)\n    y1 = y1.clamp(min=0, max=h)\n    x2 = x2.clamp(min=0, max=w)\n    y2 = y2.clamp(min=0, max=h)\n\n    box = torch.stack([x1,y1,x2,y2]).permute(1,0)\n    scale = torch.tensor([output_width/w, output_height/h, output_width/w, output_height/h])[None,:].to(result.device)\n    box = box*scale\n    return box\n\ndef sem_seg_postprocess(result, img_size, output_height, output_width):\n    \"\"\"\n    Return semantic segmentation predictions in 
the original resolution.\n\n    The input images are often resized when entering semantic segmentor. Moreover, in some\n    cases, they are also padded inside segmentor to be divisible by maximum network stride.\n    As a result, we often need the predictions of the segmentor in a different\n    resolution from its inputs.\n\n    Args:\n        result (Tensor): semantic segmentation prediction logits. A tensor of shape (C, H, W),\n            where C is the number of classes, and H, W are the height and width of the prediction.\n        img_size (tuple): image size that segmentor is taking as input.\n        output_height, output_width: the desired output resolution.\n\n    Returns:\n        semantic segmentation prediction (Tensor): A tensor of the shape\n            (C, output_height, output_width) that contains per-pixel soft predictions.\n    \"\"\"\n    result = result[:, : img_size[0], : img_size[1]].expand(1, -1, -1, -1)\n    result = F.interpolate(\n        result, size=(output_height, output_width), mode=\"bicubic\", align_corners=False\n    )[0]\n    return result\n"
  },
  {
    "path": "llava/model/semsam/utils/__init__.py",
    "content": "from .config import *\nfrom .misc import *\n# from .dist import *"
  },
  {
    "path": "llava/model/semsam/utils/box_ops.py",
    "content": "# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved\n\"\"\"\nUtilities for bounding box manipulation and GIoU.\n\"\"\"\nimport torch\nfrom torchvision.ops.boxes import box_area\n\n\ndef box_cxcywh_to_xyxy(x):\n    x_c, y_c, w, h = x.unbind(-1)\n    b = [(x_c - 0.5 * w), (y_c - 0.5 * h),\n         (x_c + 0.5 * w), (y_c + 0.5 * h)]\n    return torch.stack(b, dim=-1)\n\n\ndef box_xyxy_to_cxcywh(x):\n    x0, y0, x1, y1 = x.unbind(-1)\n    b = [(x0 + x1) / 2, (y0 + y1) / 2,\n         (x1 - x0), (y1 - y0)]\n    return torch.stack(b, dim=-1)\n\ndef box_xywh_to_xyxy(x):\n    x0, y0, x1, y1 = x.unbind(-1)\n    b = [x0, y0, (x0 + x1), (y0 + y1)]\n    return torch.stack(b, dim=-1)\n\n\n# modified from torchvision to also return the union\ndef box_iou(boxes1, boxes2):\n    area1 = box_area(boxes1)\n    area2 = box_area(boxes2)\n\n    lt = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]\n    rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]\n\n    wh = (rb - lt).clamp(min=0)  # [N,M,2]\n    inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]\n\n    union = area1[:, None] + area2 - inter\n\n    iou = inter / (union+1e-6)\n    return iou, union\n\n\ndef generalized_box_iou(boxes1, boxes2):\n    \"\"\"\n    Generalized IoU from https://giou.stanford.edu/\n\n    The boxes should be in [x0, y0, x1, y1] format\n\n    Returns a [N, M] pairwise matrix, where N = len(boxes1)\n    and M = len(boxes2)\n    \"\"\"\n    # degenerate boxes gives inf / nan results\n    # so do an early check\n    assert (boxes1[:, 2:] >= boxes1[:, :2]).all()\n    assert (boxes2[:, 2:] >= boxes2[:, :2]).all()\n    iou, union = box_iou(boxes1, boxes2)\n\n    lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])\n    rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])\n\n    wh = (rb - lt).clamp(min=0)  # [N,M,2]\n    area = wh[:, :, 0] * wh[:, :, 1]\n\n    return iou - (area - union) / (area+1e-6)\n\n\ndef masks_to_boxes(masks):\n    \"\"\"Compute the bounding boxes 
around the provided masks\n\n    The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.\n\n    Returns a [N, 4] tensors, with the boxes in xyxy format\n    \"\"\"\n    if masks.numel() == 0:\n        return torch.zeros((0, 4), device=masks.device)\n\n    h, w = masks.shape[-2:]\n\n    y = torch.arange(0, h, dtype=torch.float)\n    x = torch.arange(0, w, dtype=torch.float)\n    y, x = torch.meshgrid(y, x)\n\n    x_mask = (masks * x.unsqueeze(0))\n    x_max = x_mask.flatten(1).max(-1)[0]\n    x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]\n\n    y_mask = (masks * y.unsqueeze(0))\n    y_max = y_mask.flatten(1).max(-1)[0]\n    y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]\n\n    return torch.stack([x_min, y_min, x_max, y_max], 1)"
  },
  {
    "path": "llava/model/semsam/utils/config.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (c) Facebook, Inc. and its affiliates.\n\nimport functools\nimport inspect\n\ndef configurable(init_func=None, *, from_config=None):\n    \"\"\"\n    Decorate a function or a class's __init__ method so that it can be called\n    with a :class:`CfgNode` object using a :func:`from_config` function that translates\n    :class:`CfgNode` to arguments.\n\n    Examples:\n    ::\n        # Usage 1: Decorator on __init__:\n        class A:\n            @configurable\n            def __init__(self, a, b=2, c=3):\n                pass\n\n            @classmethod\n            def from_config(cls, cfg):   # 'cfg' must be the first argument\n                # Returns kwargs to be passed to __init__\n                return {\"a\": cfg.A, \"b\": cfg.B}\n\n        a1 = A(a=1, b=2)  # regular construction\n        a2 = A(cfg)       # construct with a cfg\n        a3 = A(cfg, b=3, c=4)  # construct with extra overwrite\n\n        # Usage 2: Decorator on any function. Needs an extra from_config argument:\n        @configurable(from_config=lambda cfg: {\"a: cfg.A, \"b\": cfg.B})\n        def a_func(a, b=2, c=3):\n            pass\n\n        a1 = a_func(a=1, b=2)  # regular call\n        a2 = a_func(cfg)       # call with a cfg\n        a3 = a_func(cfg, b=3, c=4)  # call with extra overwrite\n\n    Args:\n        init_func (callable): a class's ``__init__`` method in usage 1. The\n            class must have a ``from_config`` classmethod which takes `cfg` as\n            the first argument.\n        from_config (callable): the from_config function in usage 2. It must take `cfg`\n            as its first argument.\n    \"\"\"\n\n    if init_func is not None:\n        assert (\n            inspect.isfunction(init_func)\n            and from_config is None\n            and init_func.__name__ == \"__init__\"\n        ), \"Incorrect use of @configurable. 
Check API documentation for examples.\"\n\n        @functools.wraps(init_func)\n        def wrapped(self, *args, **kwargs):\n            try:\n                from_config_func = type(self).from_config\n            except AttributeError as e:\n                raise AttributeError(\n                    \"Class with @configurable must have a 'from_config' classmethod.\"\n                ) from e\n            if not inspect.ismethod(from_config_func):\n                raise TypeError(\"Class with @configurable must have a 'from_config' classmethod.\")\n\n            # import ipdb; ipdb.set_trace()\n            if _called_with_cfg(*args, **kwargs):\n                explicit_args = _get_args_from_config(from_config_func, *args, **kwargs)\n                init_func(self, **explicit_args)\n            else:\n                init_func(self, *args, **kwargs)\n\n        return wrapped\n\n    else:\n        if from_config is None:\n            return configurable  # @configurable() is made equivalent to @configurable\n        assert inspect.isfunction(\n            from_config\n        ), \"from_config argument of configurable must be a function!\"\n\n        def wrapper(orig_func):\n            @functools.wraps(orig_func)\n            def wrapped(*args, **kwargs):\n                if _called_with_cfg(*args, **kwargs):\n                    explicit_args = _get_args_from_config(from_config, *args, **kwargs)\n                    return orig_func(**explicit_args)\n                else:\n                    return orig_func(*args, **kwargs)\n\n            wrapped.from_config = from_config\n            return wrapped\n\n        return wrapper\n\ndef _called_with_cfg(*args, **kwargs):\n    \"\"\"\n    Returns:\n        bool: whether the arguments contain CfgNode and should be considered\n            forwarded to from_config.\n    \"\"\"\n    from omegaconf import DictConfig, OmegaConf, ListConfig\n    # from detectron2.config import LazyConfig\n\n    if len(args) and 
(isinstance(args[0], (dict)) or (isinstance(args[0], (DictConfig)))):\n        return True\n    if isinstance(kwargs.pop(\"cfg\", None), (dict)):\n        return True\n    # `from_config`'s first argument is forced to be \"cfg\".\n    # So the above check covers all cases.\n    return False\n\ndef _get_args_from_config(from_config_func, *args, **kwargs):\n    \"\"\"\n    Use `from_config` to obtain explicit arguments.\n\n    Returns:\n        dict: arguments to be used for cls.__init__\n    \"\"\"\n    signature = inspect.signature(from_config_func)\n    if list(signature.parameters.keys())[0] != \"cfg\":\n        if inspect.isfunction(from_config_func):\n            name = from_config_func.__name__\n        else:\n            name = f\"{from_config_func.__self__}.from_config\"\n        raise TypeError(f\"{name} must take 'cfg' as the first argument!\")\n    support_var_arg = any(\n        param.kind in [param.VAR_POSITIONAL, param.VAR_KEYWORD]\n        for param in signature.parameters.values()\n    )\n    if support_var_arg:  # forward all arguments to from_config, if from_config accepts them\n        ret = from_config_func(*args, **kwargs)\n    else:\n        # forward supported arguments to from_config\n        supported_arg_names = set(signature.parameters.keys())\n        extra_kwargs = {}\n        for name in list(kwargs.keys()):\n            if name not in supported_arg_names:\n                extra_kwargs[name] = kwargs.pop(name)\n        ret = from_config_func(*args, **kwargs)\n        # forward the other arguments to __init__\n        ret.update(extra_kwargs)\n    return ret"
  },
  {
    "path": "llava/model/semsam/utils/misc.py",
    "content": "# Copyright (c) Facebook, Inc. and its affiliates.\n# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/util/misc.py\n\n# --------------------------------------------------------\n# X-Decoder -- Generalized Decoding for Pixel, Image, and Language\n# Copyright (c) 2022 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# Modified by Xueyan Zou (xueyan@cs.wisc.edu)\n# --------------------------------------------------------\n\n\"\"\"\nMisc functions, including distributed helpers.\n\nMostly copy-paste from torchvision references.\n\"\"\"\nfrom typing import List, Optional\n\nimport torch\nimport torch.distributed as dist\nimport torchvision\nfrom torch import Tensor\n\nfrom utils.constants import *\n\ndef get_iou(gt_masks, pred_masks, ignore_label=-1):\n    rev_ignore_mask = ~(gt_masks == ignore_label)\n    gt_masks = gt_masks.bool()\n    n,h,w = gt_masks.shape\n    intersection = ((gt_masks & pred_masks) & rev_ignore_mask).reshape(n,h*w).sum(dim=-1)\n    union = ((gt_masks | pred_masks) & rev_ignore_mask).reshape(n,h*w).sum(dim=-1)\n    ious = (intersection / union)\n    return ious\n\ndef _max_by_axis(the_list):\n    # type: (List[List[int]]) -> List[int]\n    maxes = the_list[0]\n    for sublist in the_list[1:]:\n        for index, item in enumerate(sublist):\n            maxes[index] = max(maxes[index], item)\n    return maxes\n\nclass NestedTensor(object):\n    def __init__(self, tensors, mask: Optional[Tensor]):\n        self.tensors = tensors\n        self.mask = mask\n\n    def to(self, device):\n        # type: (Device) -> NestedTensor # noqa\n        cast_tensor = self.tensors.to(device)\n        mask = self.mask\n        if mask is not None:\n            assert mask is not None\n            cast_mask = mask.to(device)\n        else:\n            cast_mask = None\n        return NestedTensor(cast_tensor, cast_mask)\n\n    def decompose(self):\n        return self.tensors, self.mask\n\n    
def __repr__(self):\n        return str(self.tensors)\n\ndef nested_tensor_from_tensor_list(tensor_list: List[Tensor]):\n    # TODO make this more general\n    if tensor_list[0].ndim == 3:\n        if torchvision._is_tracing():\n            # nested_tensor_from_tensor_list() does not export well to ONNX\n            # call _onnx_nested_tensor_from_tensor_list() instead\n            return _onnx_nested_tensor_from_tensor_list(tensor_list)\n\n        # TODO make it support different-sized images\n        max_size = _max_by_axis([list(img.shape) for img in tensor_list])\n        # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))\n        batch_shape = [len(tensor_list)] + max_size\n        b, c, h, w = batch_shape\n        dtype = tensor_list[0].dtype\n        device = tensor_list[0].device\n        tensor = torch.zeros(batch_shape, dtype=dtype, device=device)\n        mask = torch.ones((b, h, w), dtype=torch.bool, device=device)\n        for img, pad_img, m in zip(tensor_list, tensor, mask):\n            pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)\n            m[: img.shape[1], : img.shape[2]] = False\n    elif tensor_list[0].ndim == 2:\n        if torchvision._is_tracing():\n            # nested_tensor_from_tensor_list() does not export well to ONNX\n            # call _onnx_nested_tensor_from_tensor_list() instead\n            return _onnx_nested_tensor_from_tensor_list(tensor_list)\n\n        # TODO make it support different-sized images\n        max_size = _max_by_axis([list(txt.shape) for txt in tensor_list])\n        # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))\n        batch_shape = [len(tensor_list)] + max_size\n        b, c, l = batch_shape\n        dtype = tensor_list[0].dtype\n        device = tensor_list[0].device\n        tensor = torch.zeros(batch_shape, dtype=dtype, device=device)\n        mask = torch.ones((b, l), dtype=torch.bool, device=device)\n        for txt, 
pad_txt, m in zip(tensor_list, tensor, mask):\n            pad_txt[: txt.shape[0], : txt.shape[1]] = txt\n            m[: txt.shape[1]] = False\n    else:\n        raise ValueError(\"not supported\")\n    return NestedTensor(tensor, mask)\n\ndef _collate_and_pad_divisibility(tensor_list: list, div=32):\n    max_size = []\n    for i in range(tensor_list[0].dim()):\n        max_size_i = torch.max(\n            torch.tensor([img.shape[i] for img in tensor_list]).to(torch.float32)\n        ).to(torch.int64)\n        max_size.append(max_size_i)\n    max_size = tuple(max_size)\n\n    c,h,w = max_size\n    pad_h = (div - h % div) if h % div != 0 else 0\n    pad_w = (div - w % div) if w % div != 0 else 0\n    max_size = (c,h+pad_h,w+pad_w)\n    \n    # work around for\n    # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)\n    # m[: img.shape[1], :img.shape[2]] = False\n    # which is not yet supported in onnx\n    padded_imgs = []\n    padded_masks = []\n    for img in tensor_list:\n        padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]\n        padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))\n        padded_imgs.append(padded_img)\n\n        m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)\n        padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), \"constant\", 1)\n        padded_masks.append(padded_mask.to(torch.bool))\n    \n    return padded_imgs\n\n# _onnx_nested_tensor_from_tensor_list() is an implementation of\n# nested_tensor_from_tensor_list() that is supported by ONNX tracing.\n@torch.jit.unused\ndef _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:\n    max_size = []\n    for i in range(tensor_list[0].dim()):\n        max_size_i = torch.max(\n            torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)\n        ).to(torch.int64)\n        max_size.append(max_size_i)\n    max_size = 
tuple(max_size)\n\n    # work around for\n    # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)\n    # m[: img.shape[1], :img.shape[2]] = False\n    # which is not yet supported in onnx\n    padded_imgs = []\n    padded_masks = []\n    for img in tensor_list:\n        padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]\n        padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))\n        padded_imgs.append(padded_img)\n\n        m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)\n        padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), \"constant\", 1)\n        padded_masks.append(padded_mask.to(torch.bool))\n\n    tensor = torch.stack(padded_imgs)\n    mask = torch.stack(padded_masks)\n\n    return NestedTensor(tensor, mask=mask)\n\ndef is_dist_avail_and_initialized():\n    if not dist.is_available():\n        return False\n    if not dist.is_initialized():\n        return False\n    return True\n\n\ndef get_class_names(name, background=True):\n    if name is None:\n        return None\n    if 'refcoco' in name:\n        class_names = [\"none\"]\n    elif 'pascal' in name:\n        class_names = PASCAL_CLASSES_PART + [\"background\"]\n    elif 'sam' in name:\n        class_names = ['foreground'] + [\"background\"]\n    elif 'coco' in name and 'pan' not in name:\n        class_names = COCO_INSTANCE_CLASSES + [\"background\"]\n    elif 'coco' in name:\n        class_names = COCO_PANOPTIC_CLASSES + [\"background\"]\n    elif 'ade20k_full' in name:\n        class_names = ADE20K_847 + [\"background\"]\n    elif 'ade' in name:\n        class_names = ADE_PANOPTIC_CLASSES + [\"background\"]\n    elif 'voc' in name:\n        class_names = PASCAL_CLASSES + [\"background\"]\n    elif 'vlp' in name:\n        class_names = [\"noun\"]\n    elif 'tsv' in name:\n        class_names = [\"noun\"]\n    elif 'phrasecut' in name:\n        class_names = [\"noun\"]\n    elif 
'mapillary' in name:\n        class_names =MAPILLARY_VISTAS_SEM_SEG_CATEGORIES\n    elif 'openimage' in name:\n        class_names = [\"noun\"]\n    elif 'imagenet' in name:\n        class_names = IMAGENET_CLASSES\n    elif 'context_459' in name:\n        class_names = PASCAL_CONTEXT_459 + [\"background\"]\n    elif 'context_59' in name:\n        class_names = PASCAL_CONTEXT_59 + [\"background\"]\n    elif 'context_33' in name:\n        class_names = PASCAL_CONTEXT_33\n    elif 'sunrgbd_37' in name:\n        class_names = SUN_RGBD_37\n    elif 'scannet_41' in name:\n        class_names = SCAN_40\n    elif 'scannet_38' in name:\n        class_names = SCAN_37\n    elif 'scannet_21' in name:\n        class_names = SCAN_20\n    elif 'object365' in name:\n        class_names = OBJECT365\n    elif 'lvis' in name:\n        class_names = LVIS_CATEGORIES\n    elif 'seginw' in name:\n        class_names = SEGINW_CATEGORIES[name.replace('_train', '').replace('_val', '')] + [\"background\"]\n    elif name == 'cityscapes_fine_sem_seg_val':\n        class_names = CITYSCAPES\n    elif name == 'cityscapes_fine_instance_seg_val':\n        class_names = CITYSCAPES_THING + [\"background\"]\n    elif name in ['cityscapes_fine_panoptic_val', 'cityscapes_fine_panoptic_train']:\n        class_names = CITYSCAPES + [\"background\"]\n    elif name == 'bdd10k_val_sem_seg':\n        class_names = BDD_SEM\n    elif name == 'bdd10k_40_panoptic_val':\n        class_names = BDD_PANO\n    else:\n        class_names=[\"none\"]\n        # assert False, \"text dataset name {} is not defined\".format(name)\n\n    if background == False and \"background\" in class_names:\n        class_names.pop(class_names.index(\"background\"))\n\n    return class_names\n\n# TODO: add background to \n# def get_class_names(name):\n#     if name is None:\n#         return None\n#     elif 'refcoco' in name:\n#         return [\"background\"]\n#     elif 'coco' in name:\n#         return COCO_PANOPTIC_CLASSES + 
[\"background\"]\n#     elif 'ade20k_full' in name:\n#         return ADE20K_847 + [\"background\"]\n#     elif 'ade' in name:\n#         return ADE_PANOPTIC_CLASSES + [\"background\"]\n#     elif 'scannet_41' in name:\n#         return SCAN_40\n#     elif 'scannet_21' in name:\n#         return SCAN_20\n#     elif 'sun' in name:\n#         return SUN_RGBD_37\n#     elif name == 'cityscapes_fine_sem_seg_val':\n#         return CITYSCAPES + [\"background\"]\n#     elif name == 'cityscapes_fine_instance_seg_val':\n#         return CITYSCAPES_THING + [\"background\"]\n#     elif name in ['cityscapes_fine_panoptic_val']:\n#         return CITYSCAPES + [\"background\"]\n#     elif name == 'bdd10k_val_sem_seg':\n#         return BDD_SEM + [\"background\"]\n#     elif name == 'bdd10k_40_panoptic_val':\n#         return BDD_PANO + [\"background\"]\n#     elif 'vlp' in name:\n#         return [\"background\"]\n#     else:\n#         assert False, \"text dataset name {} is not defined\".format(name)\n"
  },
  {
    "path": "llava/model/utils.py",
    "content": "import sys\n\nfrom transformers import AutoConfig\n\n\ndef auto_upgrade(config):\n    \"\"\"Interactively upgrade a LLaVA v0 checkpoint config to the current code base.\n\n    Args:\n        config: checkpoint location; its config is loaded with\n            AutoConfig.from_pretrained and, if upgraded, written back in place\n            with save_pretrained.\n\n    Exits the process with status 1 if the user declines the upgrade.\n    \"\"\"\n    cfg = AutoConfig.from_pretrained(config)\n    if 'llava' in config and 'llava' not in cfg.model_type:\n        assert cfg.model_type == 'llama'\n        print(\"You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.\")\n        print(\"You must upgrade the checkpoint to the new code base (this can be done automatically).\")\n        confirm = input(\"Please confirm that you want to upgrade the checkpoint. [Y/N]\")\n        if confirm.strip().lower() in [\"y\", \"yes\"]:\n            print(\"Upgrading checkpoint...\")\n            assert len(cfg.architectures) == 1\n            # Patched on the class, not the instance, presumably because transformers\n            # serializes model_type from the config class (verify for the pinned\n            # version). This changes model_type for every instance of that class.\n            setattr(cfg.__class__, \"model_type\", \"llava\")\n            cfg.architectures[0] = 'LlavaLlamaForCausalLM'\n            cfg.save_pretrained(config)\n            print(\"Checkpoint upgraded.\")\n        else:\n            print(\"Checkpoint upgrade aborted.\")\n            # sys.exit works everywhere; the bare exit() builtin is added by the\n            # site module and is missing under python -S or in frozen apps.\n            sys.exit(1)\n"
  },
  {
    "path": "llava/serve/__init__.py",
    "content": ""
  },
  {
    "path": "llava/serve/cli.py",
    "content": "import argparse\nimport torch\n\nfrom llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN\nfrom llava.conversation import conv_templates, SeparatorStyle\nfrom llava.model.builder import load_pretrained_model\nfrom llava.utils import disable_torch_init\nfrom llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria\n\nimport requests\nfrom PIL import Image\nfrom io import BytesIO\nfrom transformers import TextStreamer\n\n\ndef load_image(image_file):\n    \"\"\"Load an image from an http(s) URL or a local path and convert it to RGB.\"\"\"\n    if image_file.startswith(('http://', 'https://')):\n        response = requests.get(image_file)\n        # Fail with the HTTP error instead of a confusing PIL decode error.\n        response.raise_for_status()\n        image = Image.open(BytesIO(response.content)).convert('RGB')\n    else:\n        image = Image.open(image_file).convert('RGB')\n    return image\n\n\ndef main(args):\n    \"\"\"Interactive chat about the image given by --image-file.\n\n    The image is attached to the first user turn only; an empty line or EOF\n    ends the session. Decoding honours --temperature and --max-new-tokens.\n    \"\"\"\n    # Model\n    disable_torch_init()\n\n    model_name = get_model_name_from_path(args.model_path)\n    tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit)\n\n    if 'llama-2' in model_name.lower():\n        conv_mode = \"llava_llama_2\"\n    elif \"v1\" in model_name.lower():\n        conv_mode = \"llava_v1\"\n    elif \"mpt\" in model_name.lower():\n        conv_mode = \"mpt\"\n    else:\n        conv_mode = \"llava_v0\"\n\n    if args.conv_mode is not None and conv_mode != args.conv_mode:\n        print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode))\n    else:\n        args.conv_mode = conv_mode\n\n    conv = conv_templates[args.conv_mode].copy()\n    if \"mpt\" in model_name.lower():\n        roles = ('user', 'assistant')\n    else:\n        roles = conv.roles\n\n    image = load_image(args.image_file)\n    image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().cuda()\n\n    while True:\n        try:\n            inp = input(f\"{roles[0]}: \")\n        except EOFError:\n            inp = \"\"\n        if not inp:\n            print(\"exit...\")\n            break\n\n        print(f\"{roles[1]}: \", end=\"\")\n\n        if image is not None:\n            # first message: prepend the image placeholder token(s)\n            if model.config.mm_use_im_start_end:\n                inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\\n' + inp\n            else:\n                inp = DEFAULT_IMAGE_TOKEN + '\\n' + inp\n            conv.append_message(conv.roles[0], inp)\n            image = None\n        else:\n            # later messages\n            conv.append_message(conv.roles[0], inp)\n        conv.append_message(conv.roles[1], None)\n        prompt = conv.get_prompt()\n\n        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()\n        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2\n        keywords = [stop_str]\n        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)\n        streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)\n\n        with torch.inference_mode():\n            output_ids = model.generate(\n                input_ids,\n                images=image_tensor,\n                # Honour the CLI flags; a non-positive temperature means greedy decoding.\n                do_sample=args.temperature > 0,\n                temperature=args.temperature,\n                max_new_tokens=args.max_new_tokens,\n                streamer=streamer,\n                use_cache=True,\n                stopping_criteria=[stopping_criteria])\n\n        outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()\n        conv.messages[-1][-1] = outputs\n\n        if args.debug:\n            print(\"\\n\", {\"prompt\": prompt, \"outputs\": outputs}, \"\\n\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--model-path\", type=str, default=\"facebook/opt-350m\")\n    parser.add_argument(\"--model-base\", type=str, default=None)\n    parser.add_argument(\"--image-file\", type=str, required=True)\n    parser.add_argument(\"--num-gpus\", type=int, default=1)\n    parser.add_argument(\"--conv-mode\", type=str, default=None)\n    parser.add_argument(\"--temperature\", type=float, default=0.2)\n    # 1024 matches the generation length that used to be hard-coded in main().\n    parser.add_argument(\"--max-new-tokens\", type=int, default=1024)\n    parser.add_argument(\"--load-8bit\", action=\"store_true\")\n    parser.add_argument(\"--load-4bit\", action=\"store_true\")\n    parser.add_argument(\"--debug\", action=\"store_true\")\n    args = parser.parse_args()\n    main(args)\n"
  },
  {
    "path": "llava/serve/controller.py",
    "content": "\"\"\"\nA controller manages distributed workers.\nIt sends worker addresses to clients.\n\"\"\"\nimport argparse\nimport asyncio\nimport dataclasses\nfrom enum import Enum, auto\nimport json\nimport logging\nimport time\nfrom typing import List, Union\nimport threading\n\nfrom fastapi import FastAPI, Request\nfrom fastapi.responses import StreamingResponse\nimport numpy as np\nimport requests\nimport uvicorn\n\nfrom llava.constants import CONTROLLER_HEART_BEAT_EXPIRATION\nfrom llava.utils import build_logger, server_error_msg\n\n\nlogger = build_logger(\"controller\", \"controller.log\")\n\n\nclass DispatchMethod(Enum):\n    LOTTERY = auto()\n    SHORTEST_QUEUE = auto()\n\n    @classmethod\n    def from_str(cls, name):\n        if name == \"lottery\":\n            return cls.LOTTERY\n        elif name == \"shortest_queue\":\n            return cls.SHORTEST_QUEUE\n        else:\n            raise ValueError(f\"Invalid dispatch method: {name}\")\n\n\n@dataclasses.dataclass\nclass WorkerInfo:\n    model_names: List[str]\n    speed: int\n    queue_length: int\n    check_heart_beat: bool\n    # time.time() of the last registration or heart beat (a float, not a str)\n    last_heart_beat: float\n\n\ndef heart_beat_controller(controller):\n    # Runs forever on the controller's background heart-beat thread.\n    while True:\n        time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION)\n        controller.remove_stable_workers_by_expiration()\n\n\nclass Controller:\n    def __init__(self, dispatch_method: str):\n        # Dict[str -> WorkerInfo]\n        self.worker_info = {}\n        self.dispatch_method = DispatchMethod.from_str(dispatch_method)\n\n        # daemon=True: the heart-beat loop never returns, so a non-daemon thread\n        # would keep the process alive after uvicorn.run() exits.\n        self.heart_beat_thread = threading.Thread(\n            target=heart_beat_controller, args=(self,), daemon=True)\n        self.heart_beat_thread.start()\n\n        logger.info(\"Init controller\")\n\n    def register_worker(self, worker_name: str, check_heart_beat: bool,\n                        worker_status: dict):\n        if worker_name not in self.worker_info:\n            logger.info(f\"Register a new worker: {worker_name}\")\n        else:\n            logger.info(f\"Register an existing worker: {worker_name}\")\n\n        if not worker_status:\n            worker_status = self.get_worker_status(worker_name)\n        if not worker_status:\n            return False\n\n        self.worker_info[worker_name] = WorkerInfo(\n            worker_status[\"model_names\"], worker_status[\"speed\"], worker_status[\"queue_length\"],\n            check_heart_beat, time.time())\n\n        logger.info(f\"Register done: {worker_name}, {worker_status}\")\n        return True\n\n    def get_worker_status(self, worker_name: str):\n        try:\n            r = requests.post(worker_name + \"/worker_get_status\", timeout=5)\n        except requests.exceptions.RequestException as e:\n            logger.error(f\"Get status fails: {worker_name}, {e}\")\n            return None\n\n        if r.status_code != 200:\n            logger.error(f\"Get status fails: {worker_name}, {r}\")\n            return None\n\n        return r.json()\n\n    def remove_worker(self, worker_name: str):\n        # pop() instead of del: the heart-beat thread and request handlers\n        # (e.g. refresh_all_workers) may race to drop the same worker.\n        self.worker_info.pop(worker_name, None)\n\n    def refresh_all_workers(self):\n        old_info = dict(self.worker_info)\n        self.worker_info = {}\n\n        for w_name, w_info in old_info.items():\n            if not self.register_worker(w_name, w_info.check_heart_beat, None):\n                logger.info(f\"Remove stale worker: {w_name}\")\n\n    def list_models(self):\n        model_names = set()\n\n        for w_name, w_info in self.worker_info.items():\n            model_names.update(w_info.model_names)\n\n        return list(model_names)\n\n    def get_worker_address(self, model_name: str):\n        if self.dispatch_method == DispatchMethod.LOTTERY:\n            worker_names = []\n            worker_speeds = []\n            for w_name, w_info in self.worker_info.items():\n                if model_name in w_info.model_names:\n                    worker_names.append(w_name)\n                    worker_speeds.append(w_info.speed)\n            worker_speeds = np.array(worker_speeds, dtype=np.float32)\n            norm = np.sum(worker_speeds)\n            if norm < 1e-4:\n                return \"\"\n            worker_speeds = worker_speeds / norm\n            if True:  # Directly return address\n                pt = np.random.choice(np.arange(len(worker_names)),\n                    p=worker_speeds)\n                worker_name = worker_names[pt]\n                return worker_name\n\n            # Check status before returning\n            while True:\n                pt = np.random.choice(np.arange(len(worker_names)),\n                    p=worker_speeds)\n                worker_name = worker_names[pt]\n\n                if self.get_worker_status(worker_name):\n                    break\n                else:\n                    self.remove_worker(worker_name)\n                    worker_speeds[pt] = 0\n                    norm = np.sum(worker_speeds)\n                    if norm < 1e-4:\n                        return \"\"\n                    worker_speeds = worker_speeds / norm\n                    continue\n            return worker_name\n        elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE:\n            worker_names = []\n            worker_qlen = []\n            for w_name, w_info in self.worker_info.items():\n                if model_name in w_info.model_names:\n                    worker_names.append(w_name)\n                    worker_qlen.append(w_info.queue_length / w_info.speed)\n            if len(worker_names) == 0:\n                return \"\"\n            min_index = np.argmin(worker_qlen)\n            w_name = worker_names[min_index]\n            self.worker_info[w_name].queue_length += 1\n            logger.info(f\"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}\")\n            return w_name\n        else:\n            raise ValueError(f\"Invalid dispatch method: {self.dispatch_method}\")\n\n    def receive_heart_beat(self, worker_name: str, queue_length: int):\n        if worker_name not in self.worker_info:\n            logger.info(f\"Receive unknown heart beat. {worker_name}\")\n            return False\n\n        self.worker_info[worker_name].queue_length = queue_length\n        self.worker_info[worker_name].last_heart_beat = time.time()\n        logger.info(f\"Receive heart beat. {worker_name}\")\n        return True\n\n    def remove_stable_workers_by_expiration(self):\n        expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION\n        to_delete = []\n        # Iterate over a snapshot: this runs on the heart-beat thread while request\n        # handlers insert workers, and a dict resized mid-iteration raises a\n        # RuntimeError that would permanently stop expiry.\n        for worker_name, w_info in list(self.worker_info.items()):\n            if w_info.check_heart_beat and w_info.last_heart_beat < expire:\n                to_delete.append(worker_name)\n\n        for worker_name in to_delete:\n            self.remove_worker(worker_name)\n\n    def worker_api_generate_stream(self, params):\n        worker_addr = self.get_worker_address(params[\"model\"])\n        if not worker_addr:\n            logger.info(f\"no worker: {params['model']}\")\n            ret = {\n                \"text\": server_error_msg,\n                \"error_code\": 2,\n            }\n            yield json.dumps(ret).encode() + b\"\\0\"\n            # Nothing to forward to: without this return the code below would\n            # post to an address-less URL and emit a second error chunk.\n            return\n\n        try:\n            # The context manager releases the streamed connection when the\n            # generator finishes or is closed early by the client.\n            with requests.post(worker_addr + \"/worker_generate_stream\",\n                    json=params, stream=True, timeout=5) as response:\n                for chunk in response.iter_lines(decode_unicode=False, delimiter=b\"\\0\"):\n                    if chunk:\n                        yield chunk + b\"\\0\"\n        except requests.exceptions.RequestException as e:\n            logger.info(f\"worker timeout: {worker_addr}\")\n            ret = {\n                \"text\": server_error_msg,\n                \"error_code\": 3,\n            }\n            yield json.dumps(ret).encode() + b\"\\0\"\n\n\n    # Let the controller act as a worker to achieve hierarchical\n    # management. This can be used to connect isolated sub networks.\n    def worker_api_get_status(self):\n        model_names = set()\n        speed = 0\n        queue_length = 0\n\n        # Snapshot the keys: each status request below can block for seconds while\n        # the heart-beat thread removes expired workers from the dict.\n        for w_name in list(self.worker_info):\n            worker_status = self.get_worker_status(w_name)\n            if worker_status is not None:\n                model_names.update(worker_status[\"model_names\"])\n                speed += worker_status[\"speed\"]\n                queue_length += worker_status[\"queue_length\"]\n\n        return {\n            \"model_names\": list(model_names),\n            \"speed\": speed,\n            \"queue_length\": queue_length,\n        }\n\n\napp = FastAPI()\n\n\n@app.post(\"/register_worker\")\nasync def register_worker(request: Request):\n    data = await request.json()\n    controller.register_worker(\n        data[\"worker_name\"], data[\"check_heart_beat\"],\n        data.get(\"worker_status\", None))\n\n\n@app.post(\"/refresh_all_workers\")\nasync def refresh_all_workers():\n    controller.refresh_all_workers()\n\n\n@app.post(\"/list_models\")\nasync def list_models():\n    models = controller.list_models()\n    return {\"models\": models}\n\n\n@app.post(\"/get_worker_address\")\nasync def get_worker_address(request: Request):\n    data = await request.json()\n    addr = controller.get_worker_address(data[\"model\"])\n    return {\"address\": addr}\n\n\n@app.post(\"/receive_heart_beat\")\nasync def receive_heart_beat(request: Request):\n    data = await request.json()\n    exist = controller.receive_heart_beat(\n        data[\"worker_name\"], data[\"queue_length\"])\n    return {\"exist\": exist}\n\n\n@app.post(\"/worker_generate_stream\")\nasync def worker_api_generate_stream(request: Request):\n    params = await request.json()\n    generator = controller.worker_api_generate_stream(params)\n    return StreamingResponse(generator)\n\n\n@app.post(\"/worker_get_status\")\nasync def worker_api_get_status(request: Request):\n    return controller.worker_api_get_status()\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--host\", type=str, default=\"localhost\")\n    parser.add_argument(\"--port\", type=int, default=21001)\n    parser.add_argument(\"--dispatch-method\", type=str, choices=[\n        \"lottery\", \"shortest_queue\"], default=\"shortest_queue\")\n    args = parser.parse_args()\n    logger.info(f\"args: {args}\")\n\n    controller = Controller(args.dispatch_method)\n    uvicorn.run(app, host=args.host, port=args.port, log_level=\"info\")\n"
  },
  {
    "path": "llava/serve/gradio_web_server.py",
    "content": "import argparse\nimport datetime\nimport json\nimport os\nimport time\n\nimport gradio as gr\nimport requests\n\nfrom llava.conversation import (default_conversation, conv_templates,\n                                   SeparatorStyle)\nfrom llava.constants import LOGDIR\nfrom llava.utils import (build_logger, server_error_msg,\n    violates_moderation, moderation_msg)\nimport hashlib\n\n\nlogger = build_logger(\"gradio_web_server\", \"gradio_web_server.log\")\n\nheaders = {\"User-Agent\": \"LLaVA Client\"}\n\nno_change_btn = gr.Button.update()\nenable_btn = gr.Button.update(interactive=True)\ndisable_btn = gr.Button.update(interactive=False)\n\npriority = {\n    \"vicuna-13b\": \"aaaaaaa\",\n    \"koala-13b\": \"aaaaaab\",\n}\n\n\ndef get_conv_log_filename():\n    t = datetime.datetime.now()\n    name = os.path.join(LOGDIR, f\"{t.year}-{t.month:02d}-{t.day:02d}-conv.json\")\n    return name\n\n\ndef get_model_list():\n    ret = requests.post(args.controller_url + \"/refresh_all_workers\")\n    assert ret.status_code == 200\n    ret = requests.post(args.controller_url + \"/list_models\")\n    models = ret.json()[\"models\"]\n    models.sort(key=lambda x: priority.get(x, x))\n    logger.info(f\"Models: {models}\")\n    return models\n\n\nget_window_url_params = \"\"\"\nfunction() {\n    const params = new URLSearchParams(window.location.search);\n    url_params = Object.fromEntries(params);\n    console.log(url_params);\n    return url_params;\n    }\n\"\"\"\n\n\ndef load_demo(url_params, request: gr.Request):\n    logger.info(f\"load_demo. ip: {request.client.host}. 
params: {url_params}\")\n\n    dropdown_update = gr.Dropdown.update(visible=True)\n    if \"model\" in url_params:\n        model = url_params[\"model\"]\n        if model in models:\n            dropdown_update = gr.Dropdown.update(\n                value=model, visible=True)\n\n    state = default_conversation.copy()\n    return (state,\n            dropdown_update,\n            gr.Chatbot.update(visible=True),\n            gr.Textbox.update(visible=True),\n            gr.Button.update(visible=True),\n            gr.Row.update(visible=True),\n            gr.Accordion.update(visible=True))\n\n\ndef load_demo_refresh_model_list(request: gr.Request):\n    logger.info(f\"load_demo. ip: {request.client.host}\")\n    models = get_model_list()\n    state = default_conversation.copy()\n    return (state, gr.Dropdown.update(\n               choices=models,\n               value=models[0] if len(models) > 0 else \"\"),\n            gr.Chatbot.update(visible=True),\n            gr.Textbox.update(visible=True),\n            gr.Button.update(visible=True),\n            gr.Row.update(visible=True),\n            gr.Accordion.update(visible=True))\n\n\ndef vote_last_response(state, vote_type, model_selector, request: gr.Request):\n    with open(get_conv_log_filename(), \"a\") as fout:\n        data = {\n            \"tstamp\": round(time.time(), 4),\n            \"type\": vote_type,\n            \"model\": model_selector,\n            \"state\": state.dict(),\n            \"ip\": request.client.host,\n        }\n        fout.write(json.dumps(data) + \"\\n\")\n\n\ndef upvote_last_response(state, model_selector, request: gr.Request):\n    logger.info(f\"upvote. ip: {request.client.host}\")\n    vote_last_response(state, \"upvote\", model_selector, request)\n    return (\"\",) + (disable_btn,) * 3\n\n\ndef downvote_last_response(state, model_selector, request: gr.Request):\n    logger.info(f\"downvote. 
ip: {request.client.host}\")\n    vote_last_response(state, \"downvote\", model_selector, request)\n    return (\"\",) + (disable_btn,) * 3\n\n\ndef flag_last_response(state, model_selector, request: gr.Request):\n    logger.info(f\"flag. ip: {request.client.host}\")\n    vote_last_response(state, \"flag\", model_selector, request)\n    return (\"\",) + (disable_btn,) * 3\n\n\ndef regenerate(state, image_process_mode, request: gr.Request):\n    logger.info(f\"regenerate. ip: {request.client.host}\")\n    state.messages[-1][-1] = None\n    prev_human_msg = state.messages[-2]\n    if type(prev_human_msg[1]) in (tuple, list):\n        prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)\n    state.skip_next = False\n    return (state, state.to_gradio_chatbot(), \"\", None) + (disable_btn,) * 5\n\n\ndef clear_history(request: gr.Request):\n    logger.info(f\"clear_history. ip: {request.client.host}\")\n    state = default_conversation.copy()\n    return (state, state.to_gradio_chatbot(), \"\", None) + (disable_btn,) * 5\n\n\ndef add_text(state, text, image, image_process_mode, request: gr.Request):\n    logger.info(f\"add_text. ip: {request.client.host}. 
len: {len(text)}\")\n    if len(text) <= 0 and image is None:\n        state.skip_next = True\n        return (state, state.to_gradio_chatbot(), \"\", None) + (no_change_btn,) * 5\n    if args.moderate:\n        flagged = violates_moderation(text)\n        if flagged:\n            state.skip_next = True\n            return (state, state.to_gradio_chatbot(), moderation_msg, None) + (\n                no_change_btn,) * 5\n\n    text = text[:1536]  # Hard cut-off\n    if image is not None:\n        text = text[:1200]  # Hard cut-off for images\n        if '<image>' not in text:\n            # text = '<Image><image></Image>' + text\n            text = text + '\\n<image>'\n        text = (text, image, image_process_mode)\n        if len(state.get_images(return_pil=True)) > 0:\n            state = default_conversation.copy()\n    state.append_message(state.roles[0], text)\n    state.append_message(state.roles[1], None)\n    state.skip_next = False\n    return (state, state.to_gradio_chatbot(), \"\", None) + (disable_btn,) * 5\n\n\ndef http_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request):\n    logger.info(f\"http_bot. 
ip: {request.client.host}\")\n    start_tstamp = time.time()\n    model_name = model_selector\n\n    if state.skip_next:\n        # This generate call is skipped due to invalid inputs\n        yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5\n        return\n\n    if len(state.messages) == state.offset + 2:\n        # First round of conversation\n        if \"llava\" in model_name.lower():\n            if 'llama-2' in model_name.lower():\n                template_name = \"llava_llama_2\"\n            elif \"v1\" in model_name.lower():\n                if 'mmtag' in model_name.lower():\n                    template_name = \"v1_mmtag\"\n                elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():\n                    template_name = \"v1_mmtag\"\n                else:\n                    template_name = \"llava_v1\"\n            elif \"mpt\" in model_name.lower():\n                template_name = \"mpt\"\n            else:\n                if 'mmtag' in model_name.lower():\n                    template_name = \"v0_mmtag\"\n                elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():\n                    template_name = \"v0_mmtag\"\n                else:\n                    template_name = \"llava_v0\"\n        elif \"mpt\" in model_name:\n            template_name = \"mpt_text\"\n        elif \"llama-2\" in model_name:\n            template_name = \"llama_2\"\n        else:\n            template_name = \"vicuna_v1\"\n        new_state = conv_templates[template_name].copy()\n        new_state.append_message(new_state.roles[0], state.messages[-2][1])\n        new_state.append_message(new_state.roles[1], None)\n        state = new_state\n\n    # Query worker address\n    controller_url = args.controller_url\n    ret = requests.post(controller_url + \"/get_worker_address\",\n            json={\"model\": model_name})\n    worker_addr = ret.json()[\"address\"]\n    logger.info(f\"model_name: 
{model_name}, worker_addr: {worker_addr}\")\n\n    # No available worker\n    if worker_addr == \"\":\n        state.messages[-1][-1] = server_error_msg\n        yield (state, state.to_gradio_chatbot(), disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)\n        return\n\n    # Construct prompt\n    prompt = state.get_prompt()\n\n    all_images = state.get_images(return_pil=True)\n    all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]\n    for image, hash in zip(all_images, all_image_hash):\n        t = datetime.datetime.now()\n        filename = os.path.join(LOGDIR, \"serve_images\", f\"{t.year}-{t.month:02d}-{t.day:02d}\", f\"{hash}.jpg\")\n        if not os.path.isfile(filename):\n            os.makedirs(os.path.dirname(filename), exist_ok=True)\n            image.save(filename)\n\n    # Make requests\n    pload = {\n        \"model\": model_name,\n        \"prompt\": prompt,\n        \"temperature\": float(temperature),\n        \"top_p\": float(top_p),\n        \"max_new_tokens\": min(int(max_new_tokens), 1536),\n        \"stop\": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2,\n        \"images\": f'List of {len(state.get_images())} images: {all_image_hash}',\n    }\n    logger.info(f\"==== request ====\\n{pload}\")\n\n    pload['images'] = state.get_images()\n\n    state.messages[-1][-1] = \"▌\"\n    yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5\n\n    try:\n        # Stream output\n        response = requests.post(worker_addr + \"/worker_generate_stream\",\n            headers=headers, json=pload, stream=True, timeout=10)\n        for chunk in response.iter_lines(decode_unicode=False, delimiter=b\"\\0\"):\n            if chunk:\n                data = json.loads(chunk.decode())\n                if data[\"error_code\"] == 0:\n                    output = data[\"text\"][len(prompt):].strip()\n                    state.messages[-1][-1] = output + \"▌\"\n   
                 yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5\n                else:\n                    output = data[\"text\"] + f\" (error_code: {data['error_code']})\"\n                    state.messages[-1][-1] = output\n                    yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)\n                    return\n                time.sleep(0.03)\n    except requests.exceptions.RequestException as e:\n        state.messages[-1][-1] = server_error_msg\n        yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)\n        return\n\n    state.messages[-1][-1] = state.messages[-1][-1][:-1]\n    yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5\n\n    finish_tstamp = time.time()\n    logger.info(f\"{output}\")\n\n    with open(get_conv_log_filename(), \"a\") as fout:\n        data = {\n            \"tstamp\": round(finish_tstamp, 4),\n            \"type\": \"chat\",\n            \"model\": model_name,\n            \"start\": round(start_tstamp, 4),\n            \"finish\": round(start_tstamp, 4),\n            \"state\": state.dict(),\n            \"images\": all_image_hash,\n            \"ip\": request.client.host,\n        }\n        fout.write(json.dumps(data) + \"\\n\")\n\ntitle_markdown = (\"\"\"\n# 🌋 LLaVA: Large Language and Vision Assistant\n[[Project Page]](https://llava-vl.github.io) [[Paper]](https://arxiv.org/abs/2304.08485) [[Code]](https://github.com/haotian-liu/LLaVA) [[Model]](https://huggingface.co/liuhaotian/LLaVA-13b-delta-v0)\n\"\"\")\n\ntos_markdown = (\"\"\"\n### Terms of use\nBy using this service, users are required to agree to the following terms:\nThe service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. 
The service may collect user dialogue data for future research.\nPlease click the \"Flag\" button if you get any inappropriate answer! We will collect those to keep improving our moderator.\nFor an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.\n\"\"\")\n\n\nlearn_more_markdown = (\"\"\"\n### License\nThe service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.\n\"\"\")\n\n\ndef build_demo(embed_mode):\n    textbox = gr.Textbox(show_label=False, placeholder=\"Enter text and press ENTER\", visible=False, container=False)\n    with gr.Blocks(title=\"LLaVA\", theme=gr.themes.Base()) as demo:\n        state = gr.State()\n\n        if not embed_mode:\n            gr.Markdown(title_markdown)\n\n        with gr.Row():\n            with gr.Column(scale=3):\n                with gr.Row(elem_id=\"model_selector_row\"):\n                    model_selector = gr.Dropdown(\n                        choices=models,\n                        value=models[0] if len(models) > 0 else \"\",\n                        interactive=True,\n                        show_label=False,\n                        container=False)\n\n                imagebox = gr.Image(type=\"pil\")\n                image_process_mode = gr.Radio(\n                    [\"Crop\", \"Resize\", \"Pad\"],\n                    value=\"Crop\",\n                    label=\"Preprocess for non-square image\")\n\n                cur_dir = os.path.dirname(os.path.abspath(__file__))\n                gr.Examples(examples=[\n                    
[f\"{cur_dir}/examples/extreme_ironing.jpg\", \"What is unusual about this image?\"],\n                    [f\"{cur_dir}/examples/waterview.jpg\", \"What are the things I should be cautious about when I visit here?\"],\n                ], inputs=[imagebox, textbox])\n\n                with gr.Accordion(\"Parameters\", open=False, visible=False) as parameter_row:\n                    temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label=\"Temperature\",)\n                    top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label=\"Top P\",)\n                    max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label=\"Max output tokens\",)\n\n            with gr.Column(scale=6):\n                chatbot = gr.Chatbot(elem_id=\"chatbot\", label=\"LLaVA Chatbot\", visible=False, height=550)\n                with gr.Row():\n                    with gr.Column(scale=8):\n                        textbox.render()\n                    with gr.Column(scale=1, min_width=60):\n                        submit_btn = gr.Button(value=\"Submit\", visible=False)\n                with gr.Row(visible=False) as button_row:\n                    upvote_btn = gr.Button(value=\"👍  Upvote\", interactive=False)\n                    downvote_btn = gr.Button(value=\"👎  Downvote\", interactive=False)\n                    flag_btn = gr.Button(value=\"⚠️  Flag\", interactive=False)\n                    #stop_btn = gr.Button(value=\"⏹️  Stop Generation\", interactive=False)\n                    regenerate_btn = gr.Button(value=\"🔄  Regenerate\", interactive=False)\n                    clear_btn = gr.Button(value=\"🗑️  Clear history\", interactive=False)\n\n        if not embed_mode:\n            gr.Markdown(tos_markdown)\n            gr.Markdown(learn_more_markdown)\n        url_params = gr.JSON(visible=False)\n\n        # Register listeners\n        btn_list = [upvote_btn, 
downvote_btn, flag_btn, regenerate_btn, clear_btn]\n        upvote_btn.click(upvote_last_response,\n            [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn])\n        downvote_btn.click(downvote_last_response,\n            [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn])\n        flag_btn.click(flag_last_response,\n            [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn])\n        regenerate_btn.click(regenerate, [state, image_process_mode],\n            [state, chatbot, textbox, imagebox] + btn_list).then(\n            http_bot, [state, model_selector, temperature, top_p, max_output_tokens],\n            [state, chatbot] + btn_list)\n        clear_btn.click(clear_history, None, [state, chatbot, textbox, imagebox] + btn_list)\n\n        textbox.submit(add_text, [state, textbox, imagebox, image_process_mode], [state, chatbot, textbox, imagebox] + btn_list\n            ).then(http_bot, [state, model_selector, temperature, top_p, max_output_tokens],\n                   [state, chatbot] + btn_list)\n        submit_btn.click(add_text, [state, textbox, imagebox, image_process_mode], [state, chatbot, textbox, imagebox] + btn_list\n            ).then(http_bot, [state, model_selector, temperature, top_p, max_output_tokens],\n                   [state, chatbot] + btn_list)\n\n        if args.model_list_mode == \"once\":\n            demo.load(load_demo, [url_params], [state, model_selector,\n                chatbot, textbox, submit_btn, button_row, parameter_row],\n                _js=get_window_url_params)\n        elif args.model_list_mode == \"reload\":\n            demo.load(load_demo_refresh_model_list, None, [state, model_selector,\n                chatbot, textbox, submit_btn, button_row, parameter_row])\n        else:\n            raise ValueError(f\"Unknown model list mode: {args.model_list_mode}\")\n\n    return demo\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n 
   parser.add_argument(\"--host\", type=str, default=\"0.0.0.0\")\n    parser.add_argument(\"--port\", type=int)\n    parser.add_argument(\"--controller-url\", type=str, default=\"http://localhost:21001\")\n    parser.add_argument(\"--concurrency-count\", type=int, default=8)\n    parser.add_argument(\"--model-list-mode\", type=str, default=\"once\",\n        choices=[\"once\", \"reload\"])\n    parser.add_argument(\"--share\", action=\"store_true\")\n    parser.add_argument(\"--moderate\", action=\"store_true\")\n    parser.add_argument(\"--embed\", action=\"store_true\")\n    args = parser.parse_args()\n    logger.info(f\"args: {args}\")\n\n    models = get_model_list()\n\n    logger.info(args)\n    demo = build_demo(args.embed)\n    demo.queue(concurrency_count=args.concurrency_count, status_update_rate=10,\n               api_open=False).launch(\n        server_name=args.host, server_port=args.port, share=args.share)\n"
  },
  {
    "path": "llava/serve/register_worker.py",
    "content": "\"\"\"\nManually register workers.\n\nUsage:\npython3 -m fastchat.serve.register_worker --controller http://localhost:21001 --worker-name http://localhost:21002\n\"\"\"\n\nimport argparse\n\nimport requests\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--controller-address\", type=str)\n    parser.add_argument(\"--worker-name\", type=str)\n    parser.add_argument(\"--check-heart-beat\", action=\"store_true\")\n    args = parser.parse_args()\n\n    url = args.controller_address + \"/register_worker\"\n    data = {\n        \"worker_name\": args.worker_name,\n        \"check_heart_beat\": args.check_heart_beat,\n        \"worker_status\": None,\n    }\n    r = requests.post(url, json=data)\n    assert r.status_code == 200\n"
  },
  {
    "path": "llava/serve/test_message.py",
    "content": "import argparse\nimport json\n\nimport requests\n\nfrom llava.conversation import default_conversation\n\n\ndef main():\n    if args.worker_address:\n        worker_addr = args.worker_address\n    else:\n        controller_addr = args.controller_address\n        ret = requests.post(controller_addr + \"/refresh_all_workers\")\n        ret = requests.post(controller_addr + \"/list_models\")\n        models = ret.json()[\"models\"]\n        models.sort()\n        print(f\"Models: {models}\")\n\n        ret = requests.post(controller_addr + \"/get_worker_address\",\n            json={\"model\": args.model_name})\n        worker_addr = ret.json()[\"address\"]\n        print(f\"worker_addr: {worker_addr}\")\n\n    if worker_addr == \"\":\n        return\n\n    conv = default_conversation.copy()\n    conv.append_message(conv.roles[0], args.message)\n    prompt = conv.get_prompt()\n\n    headers = {\"User-Agent\": \"LLaVA Client\"}\n    pload = {\n        \"model\": args.model_name,\n        \"prompt\": prompt,\n        \"max_new_tokens\": args.max_new_tokens,\n        \"temperature\": 0.7,\n        \"stop\": conv.sep,\n    }\n    response = requests.post(worker_addr + \"/worker_generate_stream\", headers=headers,\n            json=pload, stream=True)\n\n    print(prompt.replace(conv.sep, \"\\n\"), end=\"\")\n    for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b\"\\0\"):\n        if chunk:\n            data = json.loads(chunk.decode(\"utf-8\"))\n            output = data[\"text\"].split(conv.sep)[-1]\n            print(output, end=\"\\r\")\n    print(\"\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--controller-address\", type=str, default=\"http://localhost:21001\")\n    parser.add_argument(\"--worker-address\", type=str)\n    parser.add_argument(\"--model-name\", type=str, default=\"facebook/opt-350m\")\n    parser.add_argument(\"--max-new-tokens\", type=int, 
default=32)\n    parser.add_argument(\"--message\", type=str, default=\n        \"Tell me a story with more than 1000 words.\")\n    args = parser.parse_args()\n\n    main()\n"
  },
  {
    "path": "llava/train/llama_flash_attn_monkey_patch.py",
    "content": "from typing import List, Optional, Tuple\nimport logging\n\nimport torch\nfrom torch import nn\n\nimport transformers\nfrom transformers.models.llama.modeling_llama import apply_rotary_pos_emb\n\nfrom einops import rearrange\n\ntry:\n    from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func\nexcept ImportError:\n    from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func\nfrom flash_attn.bert_padding import unpad_input, pad_input\n\n\ndef forward(\n    self,\n    hidden_states: torch.Tensor,\n    attention_mask: Optional[torch.Tensor] = None,\n    position_ids: Optional[torch.Tensor] = None,\n    past_key_value: Optional[Tuple[torch.Tensor]] = None,\n    output_attentions: bool = False,\n    use_cache: bool = False,\n) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n    \"\"\"Input shape: Batch x Time x Channel\n\n    attention_mask: [bsz, q_len]\n    \"\"\"\n    bsz, q_len, _ = hidden_states.size()\n\n    query_states = (\n        self.q_proj(hidden_states)\n        .view(bsz, q_len, self.num_heads, self.head_dim)\n        .transpose(1, 2)\n    )\n    key_states = (\n        self.k_proj(hidden_states)\n        .view(bsz, q_len, self.num_heads, self.head_dim)\n        .transpose(1, 2)\n    )\n    value_states = (\n        self.v_proj(hidden_states)\n        .view(bsz, q_len, self.num_heads, self.head_dim)\n        .transpose(1, 2)\n    )\n    # [bsz, q_len, nh, hd]\n    # [bsz, nh, q_len, hd]\n\n    kv_seq_len = key_states.shape[-2]\n    assert past_key_value is None, \"past_key_value is not supported\"\n\n    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)\n    query_states, key_states = apply_rotary_pos_emb(\n        query_states, key_states, cos, sin, position_ids\n    )\n    # [bsz, nh, t, hd]\n    assert not output_attentions, \"output_attentions is not supported\"\n    assert not use_cache, \"use_cache is not 
supported\"\n\n    # Flash attention codes from\n    # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py\n\n    # transform the data into the format required by flash attention\n    qkv = torch.stack(\n        [query_states, key_states, value_states], dim=2\n    )  # [bsz, nh, 3, q_len, hd]\n    qkv = qkv.transpose(1, 3)  # [bsz, q_len, 3, nh, hd]\n    # We have disabled _prepare_decoder_attention_mask in LlamaModel\n    # the attention_mask should be the same as the key_padding_mask\n    key_padding_mask = attention_mask\n\n    if key_padding_mask is None:\n        qkv = rearrange(qkv, \"b s ... -> (b s) ...\")\n        max_s = q_len\n        cu_q_lens = torch.arange(\n            0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device\n        )\n        output = flash_attn_unpadded_qkvpacked_func(\n            qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True\n        )\n        output = rearrange(output, \"(b s) ... -> b s ...\", b=bsz)\n    else:\n        nheads = qkv.shape[-2]\n        x = rearrange(qkv, \"b s three h d -> b s (three h d)\")\n        x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)\n        x_unpad = rearrange(\n            x_unpad, \"nnz (three h d) -> nnz three h d\", three=3, h=nheads\n        )\n        output_unpad = flash_attn_unpadded_qkvpacked_func(\n            x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True\n        )\n        output = rearrange(\n            pad_input(\n                rearrange(output_unpad, \"nnz h d -> nnz (h d)\"), indices, bsz, q_len\n            ),\n            \"b s (h d) -> b s h d\",\n            h=nheads,\n        )\n    return self.o_proj(rearrange(output, \"b s h d -> b s (h d)\")), None, None\n\n\n# Disable the transformation of the attention mask in LlamaModel as the flash attention\n# requires the attention mask to be the same as the key_padding_mask\ndef _prepare_decoder_attention_mask(\n    self, 
attention_mask, input_shape, inputs_embeds, past_key_values_length\n):\n    # [bsz, seq_len]\n    return attention_mask\n\n\ndef replace_llama_attn_with_flash_attn():\n    cuda_major, cuda_minor = torch.cuda.get_device_capability()\n    if cuda_major < 8:\n        logging.warning(\n            \"Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward.\"\n            \"ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593\"\n        )\n    transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (\n        _prepare_decoder_attention_mask\n    )\n    transformers.models.llama.modeling_llama.LlamaAttention.forward = forward\n"
  },
  {
    "path": "llava/train/llava_trainer.py",
    "content": "import os\nimport torch\n\nfrom transformers import Trainer\nfrom typing import Optional\n\n\ndef maybe_zero_3(param, ignore_status=False, name=None):\n    from deepspeed import zero\n    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus\n    if hasattr(param, \"ds_id\"):\n        if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:\n            if not ignore_status:\n                print(name, 'no ignore status')\n        with zero.GatheredParameters([param]):\n            param = param.data.detach().cpu().clone()\n    else:\n        param = param.detach().cpu().clone()\n    return param\n\n\ndef get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):\n    to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}\n    to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()}\n    return to_return\n\n\nclass LLaVATrainer(Trainer):\n\n    def _save_checkpoint(self, model, trial, metrics=None):\n        if getattr(self.args, 'tune_mm_mlp_adapter', False):\n            from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR\n            checkpoint_folder = f\"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}\"\n\n            run_dir = self._get_output_dir(trial=trial)\n            output_dir = os.path.join(run_dir, checkpoint_folder)\n\n            # Only save Adapter\n            keys_to_match = ['mm_projector']\n            if getattr(self.args, \"use_im_start_end\", False) or getattr(self.args, \"new_tokens\", False):\n                keys_to_match.extend(['embed_tokens', 'embed_in','lm_head'])\n            # import pdb; pdb.set_trace()\n            weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match)\n\n            if self.args.local_rank == 0 or self.args.local_rank == -1:\n                self.model.config.save_pretrained(output_dir)\n                torch.save(weight_to_save, 
os.path.join(output_dir, f'mm_projector.bin'))\n        else:\n            super(LLaVATrainer, self)._save_checkpoint(model, trial, metrics)\n\n    def _save(self, output_dir: Optional[str] = None, state_dict=None):\n        if getattr(self.args, 'tune_mm_mlp_adapter', False):\n            pass\n        else:\n            super(LLaVATrainer, self)._save(output_dir, state_dict)\n"
  },
  {
    "path": "llava/train/llava_trainer_gd.py",
    "content": "import os\nimport torch\nimport torch.nn as nn\nfrom torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler\n# from transformers import Trainer\nfrom typing import Optional\nfrom transformers.trainer import *\nfrom datasets_os import build_train_dataloader\ndef maybe_zero_3(param, ignore_status=False, name=None):\n    from deepspeed import zero\n    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus\n    if hasattr(param, \"ds_id\"):\n        if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:\n            if not ignore_status:\n                print(name, 'no ignore status')\n        with zero.GatheredParameters([param]):\n            param = param.data.detach().cpu().clone()\n    else:\n        param = param.detach().cpu().clone()\n    return param\n\n\ndef get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):\n    to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}\n    to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()}\n    return to_return\n\n\nclass TrainerLLavaGD(Trainer):\n    \"\"\"\n    Trainer is a simple but feature-complete training and eval loop for PyTorch, optimized for 🤗 Transformers.\n\n    Args:\n        model ([`PreTrainedModel`] or `torch.nn.Module`, *optional*):\n            The model to train, evaluate or use for predictions. If not provided, a `model_init` must be passed.\n\n            <Tip>\n\n            [`Trainer`] is optimized to work with the [`PreTrainedModel`] provided by the library. You can still use\n            your own models defined as `torch.nn.Module` as long as they work the same way as the 🤗 Transformers\n            models.\n\n            </Tip>\n\n        args ([`TrainingArguments`], *optional*):\n            The arguments to tweak for training. 
Will default to a basic instance of [`TrainingArguments`] with the\n            `output_dir` set to a directory named *tmp_trainer* in the current directory if not provided.\n        data_collator (`DataCollator`, *optional*):\n            The function to use to form a batch from a list of elements of `train_dataset` or `eval_dataset`. Will\n            default to [`default_data_collator`] if no `tokenizer` is provided, an instance of\n            [`DataCollatorWithPadding`] otherwise.\n        train_dataset (`torch.utils.data.Dataset` or `torch.utils.data.IterableDataset`, *optional*):\n            The dataset to use for training. If it is a [`~datasets.Dataset`], columns not accepted by the\n            `model.forward()` method are automatically removed.\n\n            Note that if it's a `torch.utils.data.IterableDataset` with some randomization and you are training in a\n            distributed fashion, your iterable dataset should either use a internal attribute `generator` that is a\n            `torch.Generator` for the randomization that must be identical on all processes (and the Trainer will\n            manually set the seed of this `generator` at each epoch) or have a `set_epoch()` method that internally\n            sets the seed of the RNGs used.\n        eval_dataset (Union[`torch.utils.data.Dataset`, Dict[str, `torch.utils.data.Dataset`]), *optional*):\n             The dataset to use for evaluation. If it is a [`~datasets.Dataset`], columns not accepted by the\n             `model.forward()` method are automatically removed. If it is a dictionary, it will evaluate on each\n             dataset prepending the dictionary key to the metric name.\n        tokenizer ([`PreTrainedTokenizerBase`], *optional*):\n            The tokenizer used to preprocess the data. 
If provided, will be used to automatically pad the inputs to the\n            maximum length when batching inputs, and it will be saved along the model to make it easier to rerun an\n            interrupted training or reuse the fine-tuned model.\n        model_init (`Callable[[], PreTrainedModel]`, *optional*):\n            A function that instantiates the model to be used. If provided, each call to [`~Trainer.train`] will start\n            from a new instance of the model as given by this function.\n\n            The function may have zero argument, or a single one containing the optuna/Ray Tune/SigOpt trial object, to\n            be able to choose different architectures according to hyper parameters (such as layer count, sizes of\n            inner layers, dropout probabilities etc).\n        compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*):\n            The function that will be used to compute metrics at evaluation. Must take a [`EvalPrediction`] and return\n            a dictionary string to metric values.\n        callbacks (List of [`TrainerCallback`], *optional*):\n            A list of callbacks to customize the training loop. Will add those to the list of default callbacks\n            detailed in [here](callback).\n\n            If you want to remove one of the default callbacks used, use the [`Trainer.remove_callback`] method.\n        optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*): A tuple\n            containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your model\n            and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.\n        preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*):\n            A function that preprocess the logits right before caching them at each evaluation step. 
Must take two\n            tensors, the logits and the labels, and return the logits once processed as desired. The modifications made\n            by this function will be reflected in the predictions received by `compute_metrics`.\n\n            Note that the labels (second parameter) will be `None` if the dataset does not have them.\n\n    Important attributes:\n\n        - **model** -- Always points to the core model. If using a transformers model, it will be a [`PreTrainedModel`]\n          subclass.\n        - **model_wrapped** -- Always points to the most external model in case one or more other modules wrap the\n          original model. This is the model that should be used for the forward pass. For example, under `DeepSpeed`,\n          the inner model is wrapped in `DeepSpeed` and then again in `torch.nn.DistributedDataParallel`. If the inner\n          model hasn't been wrapped, then `self.model_wrapped` is the same as `self.model`.\n        - **is_model_parallel** -- Whether or not a model has been switched to a model parallel mode (different from\n          data parallelism, this means some of the model layers are split on different GPUs).\n        - **place_model_on_device** -- Whether or not to automatically place the model on the device - it will be set\n          to `False` if model parallel or deepspeed is used, or if the default\n          `TrainingArguments.place_model_on_device` is overridden to return `False` .\n        - **is_in_train** -- Whether or not a model is currently running `train` (e.g. 
when `evaluate` is called while\n          in `train`)\n\n    \"\"\"\n\n    # Those are used as methods of the Trainer in examples.\n\n    def __init__(\n        self,\n        model: Union[PreTrainedModel, nn.Module] = None,\n        args: TrainingArguments = None,\n        data_collator: Optional[DataCollator] = None,\n        train_dataset: Optional[Dataset] = None,\n        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,\n        tokenizer: Optional[PreTrainedTokenizerBase] = None,\n        model_init: Optional[Callable[[], PreTrainedModel]] = None,\n        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,\n        callbacks: Optional[List[TrainerCallback]] = None,\n        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),\n        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,\n        data_loader_args=None,\n        cfg=None,\n    ):\n        self.cfg=cfg\n        if args is None:\n            output_dir = \"tmp_trainer\"\n            logger.info(f\"No `TrainingArguments` passed, using `output_dir={output_dir}`.\")\n            args = TrainingArguments(output_dir=output_dir)\n        self.args = args\n        # Seed must be set before instantiating the model when using model\n        enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)\n        self.hp_name = None\n        self.deepspeed = None\n        self.is_in_train = False\n        self.data_loader_args=data_loader_args\n        self.create_accelerator_and_postprocess()\n\n        # memory metrics - must set up as early as possible\n        self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)\n        self._memory_tracker.start()\n\n        # set the correct log level depending on the node\n        log_level = args.get_process_log_level()\n        logging.set_verbosity(log_level)\n\n        # force 
device and distributed setup init explicitly\n        args._setup_devices\n\n        if model is None:\n            if model_init is not None:\n                self.model_init = model_init\n                model = self.call_model_init()\n            else:\n                raise RuntimeError(\"`Trainer` requires either a `model` or `model_init` argument\")\n        else:\n            if model_init is not None:\n                warnings.warn(\n                    \"`Trainer` requires either a `model` or `model_init` argument, but not both. `model_init` will\"\n                    \" overwrite your model when calling the `train` method. This will become a fatal error in the next\"\n                    \" release.\",\n                    FutureWarning,\n                )\n            self.model_init = model_init\n\n        if model.__class__.__name__ in MODEL_MAPPING_NAMES:\n            raise ValueError(\n                f\"The model you have picked ({model.__class__.__name__}) cannot be used as is for training: it only \"\n                \"computes hidden states and does not accept any labels. 
You should choose a model with a head \"\n                \"suitable for your task like any of the `AutoModelForXxx` listed at \"\n                \"https://huggingface.co/docs/transformers/model_doc/auto.\"\n            )\n\n        if hasattr(model, \"is_parallelizable\") and model.is_parallelizable and model.model_parallel:\n            self.is_model_parallel = True\n        else:\n            self.is_model_parallel = False\n\n        if getattr(model, \"hf_device_map\", None) is not None:\n            devices = [device for device in set(model.hf_device_map.values()) if device not in [\"cpu\", \"disk\"]]\n            if len(devices) > 1:\n                self.is_model_parallel = True\n            else:\n                self.is_model_parallel = self.args.device != torch.device(devices[0])\n\n            # warn users\n            logger.info(\n                \"You have loaded a model on multiple GPUs. `is_model_parallel` attribute will be force-set\"\n                \" to `True` to avoid any unexpected behavior such as device placement mismatching.\"\n            )\n\n        # At this stage the model is already loaded\n        if getattr(model, \"is_quantized\", False):\n            if getattr(model, \"_is_quantized_training_enabled\", False):\n                logger.info(\n                    \"The model is loaded in 8-bit precision. To train this model you need to add additional modules\"\n                    \" inside the model such as adapters using `peft` library and freeze the model weights. Please\"\n                    \" check \"\n                    \" the examples in https://github.com/huggingface/peft for more details.\"\n                )\n            else:\n                raise ValueError(\n                    \"The model you want to train is loaded in 8-bit precision.  if you want to fine-tune an 8-bit\"\n                    \" model, please make sure that you have installed `bitsandbytes>=0.37.0`. 
\"\n                )\n\n        # Setup Sharded DDP training\n        self.sharded_ddp = None\n        if len(args.sharded_ddp) > 0:\n            if self.is_deepspeed_enabled:\n                raise ValueError(\n                    \"Using --sharded_ddp xxx together with --deepspeed is not possible, deactivate one of those flags.\"\n                )\n            if len(args.fsdp) > 0:\n                raise ValueError(\n                    \"Using --sharded_ddp xxx together with --fsdp is not possible, deactivate one of those flags.\"\n                )\n            if args.parallel_mode != ParallelMode.DISTRIBUTED:\n                raise ValueError(\"Using sharded DDP only works in distributed training.\")\n            elif not is_fairscale_available():\n                raise ImportError(\"Sharded DDP training requires fairscale: `pip install fairscale`.\")\n            elif ShardedDDPOption.SIMPLE not in args.sharded_ddp and FullyShardedDDP is None:\n                raise ImportError(\n                    \"Sharded DDP in a mode other than simple training requires fairscale version >= 0.3, found \"\n                    f\"{fairscale.__version__}. 
Upgrade your fairscale library: `pip install --upgrade fairscale`.\"\n                )\n            elif ShardedDDPOption.SIMPLE in args.sharded_ddp:\n                self.sharded_ddp = ShardedDDPOption.SIMPLE\n            elif ShardedDDPOption.ZERO_DP_2 in args.sharded_ddp:\n                self.sharded_ddp = ShardedDDPOption.ZERO_DP_2\n            elif ShardedDDPOption.ZERO_DP_3 in args.sharded_ddp:\n                self.sharded_ddp = ShardedDDPOption.ZERO_DP_3\n\n        self.fsdp = None\n        if len(args.fsdp) > 0:\n            if self.is_deepspeed_enabled:\n                raise ValueError(\n                    \"Using --fsdp xxx together with --deepspeed is not possible, deactivate one of those flags.\"\n                )\n            if not args.fsdp_config[\"xla\"] and args.parallel_mode != ParallelMode.DISTRIBUTED:\n                raise ValueError(\"Using fsdp only works in distributed training.\")\n\n            # dep_version_check(\"torch>=1.12.0\")\n            # Would have to update setup.py with torch>=1.12.0\n            # which isn't ideally given that it will force people not using FSDP to also use torch>=1.12.0\n            # below is the current alternative.\n            if version.parse(version.parse(torch.__version__).base_version) < version.parse(\"1.12.0\"):\n                raise ValueError(\"FSDP requires PyTorch >= 1.12.0\")\n\n            from torch.distributed.fsdp.fully_sharded_data_parallel import BackwardPrefetch, ShardingStrategy\n\n            if FSDPOption.FULL_SHARD in args.fsdp:\n                self.fsdp = ShardingStrategy.FULL_SHARD\n            elif FSDPOption.SHARD_GRAD_OP in args.fsdp:\n                self.fsdp = ShardingStrategy.SHARD_GRAD_OP\n            elif FSDPOption.NO_SHARD in args.fsdp:\n                self.fsdp = ShardingStrategy.NO_SHARD\n\n            self.backward_prefetch = BackwardPrefetch.BACKWARD_PRE\n            if \"backward_prefetch\" in self.args.fsdp_config and \"backward_post\" in 
self.args.fsdp_config.get(\n                \"backward_prefetch\", []\n            ):\n                self.backward_prefetch = BackwardPrefetch.BACKWARD_POST\n\n            self.forward_prefetch = False\n            if self.args.fsdp_config.get(\"forward_prefect\", False):\n                self.forward_prefetch = True\n\n            self.limit_all_gathers = False\n            if self.args.fsdp_config.get(\"limit_all_gathers\", False):\n                self.limit_all_gathers = True\n\n        # one place to sort out whether to place the model on device or not\n        # postpone switching model to cuda when:\n        # 1. MP - since we are trying to fit a much bigger than 1 gpu model\n        # 2. fp16-enabled DeepSpeed loads the model in half the size and it doesn't need .to() anyway,\n        #    and we only use deepspeed for training at the moment\n        # 3. full bf16 or fp16 eval - since the model needs to be cast to the right dtype first\n        # 4. Sharded DDP - same as MP\n        # 5. 
FSDP - same as MP\n        self.place_model_on_device = args.place_model_on_device\n        if (\n            self.is_model_parallel\n            or self.is_deepspeed_enabled\n            or ((args.fp16_full_eval or args.bf16_full_eval) and not args.do_train)\n            or (self.sharded_ddp in [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3])\n            or (self.fsdp is not None)\n            or self.is_fsdp_enabled\n        ):\n            self.place_model_on_device = False\n\n        default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)\n        self.data_collator = data_collator if data_collator is not None else default_collator\n        self.train_dataset = train_dataset\n        self.eval_dataset = eval_dataset\n        self.tokenizer = tokenizer\n\n        # Quantized models doesn't support `.to` operation.\n        if self.place_model_on_device and not getattr(model, \"is_quantized\", False):\n            self._move_model_to_device(model, args.device)\n\n        # Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs\n        if self.is_model_parallel:\n            self.args._n_gpu = 1\n\n        # later use `self.model is self.model_wrapped` to check if it's wrapped or not\n        self.model_wrapped = model\n        self.model = model\n\n        self.compute_metrics = compute_metrics\n        self.preprocess_logits_for_metrics = preprocess_logits_for_metrics\n        self.optimizer, self.lr_scheduler = optimizers\n        if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None):\n            raise RuntimeError(\n                \"Passing a `model_init` is incompatible with providing the `optimizers` argument. 
\"\n                \"You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method.\"\n            )\n        if is_torch_tpu_available() and self.optimizer is not None:\n            for param in self.model.parameters():\n                model_device = param.device\n                break\n            for param_group in self.optimizer.param_groups:\n                if len(param_group[\"params\"]) > 0:\n                    optimizer_device = param_group[\"params\"][0].device\n                    break\n            if model_device != optimizer_device:\n                raise ValueError(\n                    \"The model and the optimizer parameters are not on the same device, which probably means you\"\n                    \" created an optimizer around your model **before** putting on the device and passing it to the\"\n                    \" `Trainer`. Make sure the lines `import torch_xla.core.xla_model as xm` and\"\n                    \" `model.to(xm.xla_device())` is performed before the optimizer creation in your script.\"\n                )\n        if ((self.sharded_ddp is not None) or self.is_deepspeed_enabled or (self.fsdp is not None)) and (\n            self.optimizer is not None or self.lr_scheduler is not None\n        ):\n            raise RuntimeError(\n                \"Passing `optimizers` is not allowed if Fairscale, Deepspeed or PyTorch FSDP is enabled.\"\n                \"You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method.\"\n            )\n        default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)\n        callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks\n        self.callback_handler = CallbackHandler(\n            callbacks, self.model, self.tokenizer, self.optimizer, self.lr_scheduler\n        )\n        self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)\n\n    
    # Will be set to True by `self._setup_loggers()` on first call to `self.log()`.\n        self._loggers_initialized = False\n\n        # Create clone of distant repo and output directory if needed\n        if self.args.push_to_hub:\n            self.init_git_repo(at_init=True)\n            # In case of pull, we need to make sure every process has the latest.\n            if is_torch_tpu_available():\n                xm.rendezvous(\"init git repo\")\n            elif args.parallel_mode == ParallelMode.DISTRIBUTED:\n                dist.barrier()\n\n        if self.args.should_save:\n            os.makedirs(self.args.output_dir, exist_ok=True)\n\n        if not callable(self.data_collator) and callable(getattr(self.data_collator, \"collate_batch\", None)):\n            raise ValueError(\"The `data_collator` should be a simple callable (function, class with `__call__`).\")\n\n        if args.max_steps > 0:\n            logger.info(\"max_steps is given, it will override any value given in num_train_epochs\")\n\n        if train_dataset is not None and not has_length(train_dataset) and args.max_steps <= 0:\n            raise ValueError(\n                \"The train_dataset does not implement __len__, max_steps has to be specified. 
\"\n                \"The number of steps needs to be known in advance for the learning rate scheduler.\"\n            )\n\n        if (\n            train_dataset is not None\n            and isinstance(train_dataset, torch.utils.data.IterableDataset)\n            and args.group_by_length\n        ):\n            raise ValueError(\"the `--group_by_length` option is only available for `Dataset`, not `IterableDataset\")\n\n        self._signature_columns = None\n\n        # Mixed precision setup\n        self.use_apex = False\n        self.use_cuda_amp = False\n        self.use_cpu_amp = False\n\n        # Mixed precision setup for SageMaker Model Parallel\n        if is_sagemaker_mp_enabled():\n            # BF16 + model parallelism in SageMaker: currently not supported, raise an error\n            if args.bf16:\n                raise ValueError(\"SageMaker Model Parallelism does not support BF16 yet. Please use FP16 instead \")\n\n            if IS_SAGEMAKER_MP_POST_1_10:\n                # When there's mismatch between SMP config and trainer argument, use SMP config as truth\n                if args.fp16 != smp.state.cfg.fp16:\n                    logger.warning(\n                        f\"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16},\"\n                        f\"but FP16 provided in trainer argument is {args.fp16},\"\n                        f\"setting to {smp.state.cfg.fp16}\"\n                    )\n                    args.fp16 = smp.state.cfg.fp16\n            else:\n                # smp < 1.10 does not support fp16 in trainer.\n                if hasattr(smp.state.cfg, \"fp16\"):\n                    logger.warning(\n                        f\"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, \"\n                        \"but SageMaker Model Parallelism < 1.10 does not support FP16 in trainer.\"\n                    )\n\n        if (args.fp16 or args.bf16) and self.sharded_ddp is not None:\n            if 
args.half_precision_backend == \"auto\":\n                if args.device == torch.device(\"cpu\"):\n                    if args.fp16:\n                        raise ValueError(\"Tried to use `fp16` but it is not supported on cpu\")\n                    else:\n                        args.half_precision_backend = \"cpu_amp\"\n                else:\n                    args.half_precision_backend = \"cuda_amp\"\n\n            logger.info(f\"Using {args.half_precision_backend} half precision backend\")\n\n        self.do_grad_scaling = False\n        if (args.fp16 or args.bf16) and not (self.is_deepspeed_enabled or is_sagemaker_mp_enabled()):\n            # deepspeed and SageMaker Model Parallel manage their own half precision\n            if self.sharded_ddp is not None:\n                if args.half_precision_backend == \"cuda_amp\":\n                    self.use_cuda_amp = True\n                    self.amp_dtype = torch.float16 if args.fp16 else torch.bfloat16\n                    #  bf16 does not need grad scaling\n                    self.do_grad_scaling = self.amp_dtype == torch.float16\n                    if self.do_grad_scaling:\n                        if self.sharded_ddp is not None:\n                            self.scaler = ShardedGradScaler()\n                        elif self.fsdp is not None:\n                            from torch.distributed.fsdp.sharded_grad_scaler import (\n                                ShardedGradScaler as FSDPShardedGradScaler,\n                            )\n\n                            self.scaler = FSDPShardedGradScaler()\n                        elif is_torch_tpu_available():\n                            from torch_xla.amp import GradScaler\n\n                            self.scaler = GradScaler()\n                        else:\n                            self.scaler = torch.cuda.amp.GradScaler()\n                elif args.half_precision_backend == \"cpu_amp\":\n                    self.use_cpu_amp = True\n               
     self.amp_dtype = torch.bfloat16\n            elif args.half_precision_backend == \"apex\":\n                if not is_apex_available():\n                    raise ImportError(\n                        \"Using FP16 with APEX but APEX is not installed, please refer to\"\n                        \" https://www.github.com/nvidia/apex.\"\n                    )\n                self.use_apex = True\n\n        # FP16 + model parallelism in SageMaker: gradient clipping does not work for now so we raise a helpful error.\n        if (\n            is_sagemaker_mp_enabled()\n            and self.use_cuda_amp\n            and args.max_grad_norm is not None\n            and args.max_grad_norm > 0\n        ):\n            raise ValueError(\n                \"SageMaker Model Parallelism in mixed precision mode does not support gradient clipping yet. Pass \"\n                \"along 'max_grad_norm': 0 in your hyperparameters.\"\n            )\n\n        # Label smoothing\n        if self.args.label_smoothing_factor != 0:\n            self.label_smoother = LabelSmoother(epsilon=self.args.label_smoothing_factor)\n        else:\n            self.label_smoother = None\n\n        self.state = TrainerState(\n            is_local_process_zero=self.is_local_process_zero(),\n            is_world_process_zero=self.is_world_process_zero(),\n        )\n\n        self.control = TrainerControl()\n        # Internal variable to count flos in each process, will be accumulated in `self.state.total_flos` then\n        # returned to 0 every time flos need to be logged\n        self.current_flos = 0\n        self.hp_search_backend = None\n        self.use_tune_checkpoints = False\n        default_label_names = find_labels(self.model.__class__)\n        self.label_names = default_label_names if self.args.label_names is None else self.args.label_names\n        self.can_return_loss = can_return_loss(self.model.__class__)\n        self.control = self.callback_handler.on_init_end(self.args, 
self.state, self.control)\n\n        # Internal variables to help with automatic batch size reduction\n        self._train_batch_size = args.train_batch_size\n        self._created_lr_scheduler = False\n\n        # very last\n        self._memory_tracker.stop_and_update_metrics()\n\n        # torch.compile\n        if args.torch_compile and not is_torch_compile_available():\n            raise RuntimeError(\"Using torch.compile requires PyTorch 2.0 or higher.\")\n\n    def add_callback(self, callback):\n        \"\"\"\n        Add a callback to the current list of [`~transformer.TrainerCallback`].\n\n        Args:\n           callback (`type` or [`~transformer.TrainerCallback`]):\n               A [`~transformer.TrainerCallback`] class or an instance of a [`~transformer.TrainerCallback`]. In the\n               first case, will instantiate a member of that class.\n        \"\"\"\n        self.callback_handler.add_callback(callback)\n\n    def pop_callback(self, callback):\n        \"\"\"\n        Remove a callback from the current list of [`~transformer.TrainerCallback`] and returns it.\n\n        If the callback is not found, returns `None` (and no error is raised).\n\n        Args:\n           callback (`type` or [`~transformer.TrainerCallback`]):\n               A [`~transformer.TrainerCallback`] class or an instance of a [`~transformer.TrainerCallback`]. 
In the\n               first case, will pop the first member of that class found in the list of callbacks.\n\n        Returns:\n            [`~transformer.TrainerCallback`]: The callback removed, if found.\n        \"\"\"\n        return self.callback_handler.pop_callback(callback)\n\n    def remove_callback(self, callback):\n        \"\"\"\n        Remove a callback from the current list of [`~transformer.TrainerCallback`].\n\n        Args:\n           callback (`type` or [`~transformer.TrainerCallback`]):\n               A [`~transformer.TrainerCallback`] class or an instance of a [`~transformer.TrainerCallback`]. In the\n               first case, will remove the first member of that class found in the list of callbacks.\n        \"\"\"\n        self.callback_handler.remove_callback(callback)\n\n    def _move_model_to_device(self, model, device):\n        model = model.to(device)\n        # Moving a model to an XLA device disconnects the tied weights, so we have to retie them.\n        if self.args.parallel_mode == ParallelMode.TPU and hasattr(model, \"tie_weights\"):\n            model.tie_weights()\n\n    def _set_signature_columns_if_needed(self):\n        if self._signature_columns is None:\n            # Inspect model forward signature to keep only the arguments it accepts.\n            signature = inspect.signature(self.model.forward)\n            self._signature_columns = list(signature.parameters.keys())\n            # Labels may be named label or label_ids, the default data collator handles that.\n            self._signature_columns += list(set([\"label\", \"label_ids\"] + self.label_names))\n\n    def _remove_unused_columns(self, dataset: \"datasets.Dataset\", description: Optional[str] = None):\n        if not self.args.remove_unused_columns:\n            return dataset\n        self._set_signature_columns_if_needed()\n        signature_columns = self._signature_columns\n\n        ignored_columns = list(set(dataset.column_names) - 
set(signature_columns))\n        if len(ignored_columns) > 0:\n            dset_description = \"\" if description is None else f\"in the {description} set\"\n            logger.info(\n                f\"The following columns {dset_description} don't have a corresponding argument in \"\n                f\"`{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}.\"\n                f\" If {', '.join(ignored_columns)} are not expected by `{self.model.__class__.__name__}.forward`, \"\n                \" you can safely ignore this message.\"\n            )\n\n        columns = [k for k in signature_columns if k in dataset.column_names]\n\n        if version.parse(datasets.__version__) < version.parse(\"1.4.0\"):\n            dataset.set_format(\n                type=dataset.format[\"type\"], columns=columns, format_kwargs=dataset.format[\"format_kwargs\"]\n            )\n            return dataset\n        else:\n            return dataset.remove_columns(ignored_columns)\n\n    def _get_collator_with_removed_columns(\n        self, data_collator: Callable, description: Optional[str] = None\n    ) -> Callable:\n        \"\"\"Wrap the data collator in a callable removing unused columns.\"\"\"\n        if not self.args.remove_unused_columns:\n            return data_collator\n        self._set_signature_columns_if_needed()\n        signature_columns = self._signature_columns\n\n        remove_columns_collator = RemoveColumnsCollator(\n            data_collator=data_collator,\n            signature_columns=signature_columns,\n            logger=logger,\n            description=description,\n            model_name=self.model.__class__.__name__,\n        )\n        return remove_columns_collator\n\n    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:\n        if self.train_dataset is None or not has_length(self.train_dataset):\n            return None\n\n        # Build the sampler.\n        if 
self.args.group_by_length:\n            if is_datasets_available() and isinstance(self.train_dataset, datasets.Dataset):\n                lengths = (\n                    self.train_dataset[self.args.length_column_name]\n                    if self.args.length_column_name in self.train_dataset.column_names\n                    else None\n                )\n            else:\n                lengths = None\n            model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None\n            return LengthGroupedSampler(\n                self.args.train_batch_size * self.args.gradient_accumulation_steps,\n                dataset=self.train_dataset,\n                lengths=lengths,\n                model_input_name=model_input_name,\n            )\n\n        else:\n            return RandomSampler(self.train_dataset)\n\n    def get_train_dataloader(self) -> DataLoader:\n        \"\"\"\n        Returns the training [`~torch.utils.data.DataLoader`].\n\n        Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed\n        training if necessary) otherwise.\n\n        Subclass and override this method if you want to inject some custom behavior.\n        \"\"\"\n        if self.train_dataset is None:\n            raise ValueError(\"Trainer: training requires a train_dataset.\")\n\n        train_dataset = self.train_dataset\n\n        data_collator = self.data_collator\n        if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):\n            train_dataset = self._remove_unused_columns(train_dataset, description=\"training\")\n        else:\n            data_collator = self._get_collator_with_removed_columns(data_collator, description=\"training\")\n\n        dataloader_params = {\n            \"batch_size\": self._train_batch_size,\n            \"collate_fn\": data_collator,\n            \"num_workers\": self.args.dataloader_num_workers,\n            
\"pin_memory\": self.args.dataloader_pin_memory,\n        }\n\n        if not isinstance(train_dataset, torch.utils.data.IterableDataset):\n            dataloader_params[\"sampler\"] = self._get_train_sampler()\n            dataloader_params[\"drop_last\"] = self.args.dataloader_drop_last\n            dataloader_params[\"worker_init_fn\"] = seed_worker\n\n        return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))\n\n    def get_train_dataloaderd2(self) -> DataLoader:\n        return build_train_dataloader(self.cfg,tokenizer=self.data_loader_args[0],data_args=self.data_loader_args[1],preprocess=self.data_loader_args[2] )\n\n    def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]:\n        # Deprecated code\n        if self.args.use_legacy_prediction_loop:\n            if is_torch_tpu_available():\n                return SequentialDistributedSampler(\n                    eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()\n                )\n            elif is_sagemaker_mp_enabled():\n                return SequentialDistributedSampler(\n                    eval_dataset,\n                    num_replicas=smp.dp_size(),\n                    rank=smp.dp_rank(),\n                    batch_size=self.args.per_device_eval_batch_size,\n                )\n            else:\n                return SequentialSampler(eval_dataset)\n\n        if self.args.world_size <= 1:\n            return SequentialSampler(eval_dataset)\n        else:\n            return None\n\n    def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:\n        \"\"\"\n        Returns the evaluation [`~torch.utils.data.DataLoader`].\n\n        Subclass and override this method if you want to inject some custom behavior.\n\n        Args:\n            eval_dataset (`torch.utils.data.Dataset`, *optional*):\n                If provided, will override `self.eval_dataset`. 
If it is a [`~datasets.Dataset`], columns not accepted\n                by the `model.forward()` method are automatically removed. It must implement `__len__`.\n        \"\"\"\n        if eval_dataset is None and self.eval_dataset is None:\n            raise ValueError(\"Trainer: evaluation requires an eval_dataset.\")\n        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset\n        data_collator = self.data_collator\n\n        if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):\n            eval_dataset = self._remove_unused_columns(eval_dataset, description=\"evaluation\")\n        else:\n            data_collator = self._get_collator_with_removed_columns(data_collator, description=\"evaluation\")\n\n        dataloader_params = {\n            \"batch_size\": self.args.eval_batch_size,\n            \"collate_fn\": data_collator,\n            \"num_workers\": self.args.dataloader_num_workers,\n            \"pin_memory\": self.args.dataloader_pin_memory,\n        }\n\n        if not isinstance(eval_dataset, torch.utils.data.IterableDataset):\n            dataloader_params[\"sampler\"] = self._get_eval_sampler(eval_dataset)\n            dataloader_params[\"drop_last\"] = self.args.dataloader_drop_last\n\n        return self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params))\n\n    def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:\n        \"\"\"\n        Returns the test [`~torch.utils.data.DataLoader`].\n\n        Subclass and override this method if you want to inject some custom behavior.\n\n        Args:\n            test_dataset (`torch.utils.data.Dataset`, *optional*):\n                The test dataset to use. If it is a [`~datasets.Dataset`], columns not accepted by the\n                `model.forward()` method are automatically removed. 
It must implement `__len__`.\n        \"\"\"\n        data_collator = self.data_collator\n\n        if is_datasets_available() and isinstance(test_dataset, datasets.Dataset):\n            test_dataset = self._remove_unused_columns(test_dataset, description=\"test\")\n        else:\n            data_collator = self._get_collator_with_removed_columns(data_collator, description=\"test\")\n\n        dataloader_params = {\n            \"batch_size\": self.args.eval_batch_size,\n            \"collate_fn\": data_collator,\n            \"num_workers\": self.args.dataloader_num_workers,\n            \"pin_memory\": self.args.dataloader_pin_memory,\n        }\n\n        if not isinstance(test_dataset, torch.utils.data.IterableDataset):\n            dataloader_params[\"sampler\"] = self._get_eval_sampler(test_dataset)\n            dataloader_params[\"drop_last\"] = self.args.dataloader_drop_last\n\n        # We use the same batch_size as for eval.\n        return self.accelerator.prepare(DataLoader(test_dataset, **dataloader_params))\n\n    def create_optimizer_and_scheduler(self, num_training_steps: int):\n        \"\"\"\n        Setup the optimizer and the learning rate scheduler.\n\n        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the\n        Trainer's init through `optimizers`, or subclass and override this method (or `create_optimizer` and/or\n        `create_scheduler`) in a subclass.\n        \"\"\"\n        self.create_optimizer()\n        if IS_SAGEMAKER_MP_POST_1_10 and smp.state.cfg.fp16:\n            # If smp >= 1.10 and fp16 is enabled, we unwrap the optimizer\n            optimizer = self.optimizer.optimizer\n        else:\n            optimizer = self.optimizer\n        self.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)\n\n    def create_optimizer(self):\n        \"\"\"\n        Setup the optimizer.\n\n        We provide a reasonable default that works well. 
If you want to use something else, you can pass a tuple in the\n        Trainer's init through `optimizers`, or subclass and override this method in a subclass.\n        \"\"\"\n        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model\n\n        if self.optimizer is None:\n            decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)\n            decay_parameters = [name for name in decay_parameters if \"bias\" not in name]\n            # optimizer_grouped_parameters = [\n            #     {\n            #         \"params\": [\n            #             p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)\n            #         ],\n            #         \"weight_decay\": self.args.weight_decay,\n            #     },\n            #     {\n            #         \"params\": [\n            #             p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)\n            #         ],\n            #         \"weight_decay\": 0.0,\n            #     },\n            # ]\n            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)\n\n            def match_name_keywords(n, name_keywords):\n                out = False\n                for b in name_keywords:\n                    if b in n:\n                        out = True\n                        break\n                return out\n\n            lr_backbone_names=['backbone']\n            lr_linear_proj_names=['reference_points', 'sampling_offsets']\n            seg_model_names=['seg_model']\n            optimizer_grouped_parameters = [\n                {\n                    \"params\":\n                        [p for n, p in opt_model.named_parameters()\n                         if not (match_name_keywords(n, lr_backbone_names) and match_name_keywords(n, seg_model_names))\n                         and not (match_name_keywords(n, lr_linear_proj_names) and 
match_name_keywords(n, seg_model_names))\n                         and p.requires_grad],\n                    \"lr\": optimizer_kwargs['lr'],\n                },\n                {\n                    \"params\": [p for n, p in opt_model.named_parameters()\n                               if match_name_keywords(n, lr_backbone_names) and match_name_keywords(n, seg_model_names) and p.requires_grad],\n                    \"lr\": optimizer_kwargs['lr']*0.1,\n                },\n                {\n                    \"params\": [p for n, p in opt_model.named_parameters()\n                               if match_name_keywords(n, lr_linear_proj_names) and match_name_keywords(n, seg_model_names) and p.requires_grad],\n                    \"lr\": optimizer_kwargs['lr']*0.1,\n                },\n\n            ]\n            if not getattr(self.args, 'tune_mm_mlp_adapter', False):\n                optimizer_grouped_parameters[0] = {\n                        \"params\":\n                            [p for n, p in opt_model.named_parameters()\n                             if\n                             not (match_name_keywords(n, lr_backbone_names) and match_name_keywords(n, seg_model_names))\n                             and not (match_name_keywords(n, lr_linear_proj_names) and match_name_keywords(n,\n                                                                                                           seg_model_names))\n                             and match_name_keywords(n,seg_model_names)\n                             and p.requires_grad],\n                        \"lr\": optimizer_kwargs['lr'],\n                    }\n                llm_dict= {\n                    \"params\": [p for n, p in opt_model.named_parameters()\n                               if n.startswith('model.') and p.requires_grad],\n                    \"lr\": 2e-5,\n                }\n                optimizer_grouped_parameters.append(llm_dict)\n            if self.sharded_ddp == 
ShardedDDPOption.SIMPLE:\n                self.optimizer = OSS(\n                    params=optimizer_grouped_parameters,\n                    optim=optimizer_cls,\n                    **optimizer_kwargs,\n                )\n            else:\n                self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)\n                if optimizer_cls.__name__ == \"Adam8bit\":\n                    import bitsandbytes\n\n                    manager = bitsandbytes.optim.GlobalOptimManager.get_instance()\n\n                    skipped = 0\n                    for module in opt_model.modules():\n                        if isinstance(module, nn.Embedding):\n                            skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())\n                            logger.info(f\"skipped {module}: {skipped/2**20}M params\")\n                            manager.register_module_override(module, \"weight\", {\"optim_bits\": 32})\n                            logger.debug(f\"bitsandbytes: will optimize {module} in fp32\")\n                    logger.info(f\"skipped: {skipped/2**20}M params\")\n\n        if is_sagemaker_mp_enabled():\n            self.optimizer = smp.DistributedOptimizer(self.optimizer)\n\n        return self.optimizer\n\n    @staticmethod\n    def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]:\n        \"\"\"\n        Returns the optimizer class and optimizer parameters based on the training arguments.\n\n        Args:\n            args (`transformers.training_args.TrainingArguments`):\n                The training arguments for the training session.\n\n        \"\"\"\n\n        # parse args.optim_args\n        optim_args = {}\n        if args.optim_args:\n            for mapping in args.optim_args.replace(\" \", \"\").split(\",\"):\n                key, value = mapping.split(\"=\")\n                optim_args[key] = value\n\n        optimizer_kwargs = {\"lr\": args.learning_rate}\n\n 
       adam_kwargs = {\n            \"betas\": (args.adam_beta1, args.adam_beta2),\n            \"eps\": args.adam_epsilon,\n        }\n        if args.optim == OptimizerNames.ADAFACTOR:\n            optimizer_cls = Adafactor\n            optimizer_kwargs.update({\"scale_parameter\": False, \"relative_step\": False})\n        elif args.optim == OptimizerNames.ADAMW_HF:\n            from .optimization import AdamW\n\n            optimizer_cls = AdamW\n            optimizer_kwargs.update(adam_kwargs)\n        elif args.optim in [OptimizerNames.ADAMW_TORCH, OptimizerNames.ADAMW_TORCH_FUSED]:\n            from torch.optim import AdamW\n\n            optimizer_cls = AdamW\n            optimizer_kwargs.update(adam_kwargs)\n            if args.optim == OptimizerNames.ADAMW_TORCH_FUSED:\n                optimizer_kwargs.update({\"fused\": True})\n        elif args.optim == OptimizerNames.ADAMW_TORCH_XLA:\n            try:\n                from torch_xla.amp.syncfree import AdamW\n\n                optimizer_cls = AdamW\n                optimizer_kwargs.update(adam_kwargs)\n            except ImportError:\n                raise ValueError(\"Trainer failed to import syncfree AdamW from torch_xla.\")\n        elif args.optim == OptimizerNames.ADAMW_APEX_FUSED:\n            try:\n                from apex.optimizers import FusedAdam\n\n                optimizer_cls = FusedAdam\n                optimizer_kwargs.update(adam_kwargs)\n            except ImportError:\n                raise ValueError(\"Trainer tried to instantiate apex FusedAdam but apex is not installed!\")\n        elif args.optim in [\n            OptimizerNames.ADAMW_BNB,\n            OptimizerNames.ADAMW_8BIT,\n            OptimizerNames.PAGED_ADAMW,\n            OptimizerNames.PAGED_ADAMW_8BIT,\n            OptimizerNames.LION,\n            OptimizerNames.LION_8BIT,\n            OptimizerNames.PAGED_LION,\n            OptimizerNames.PAGED_LION_8BIT,\n        ]:\n            try:\n                from 
bitsandbytes.optim import AdamW, Lion\n\n                is_paged = False\n                optim_bits = 32\n                optimizer_cls = None\n                additional_optim_kwargs = adam_kwargs\n                if \"paged\" in args.optim:\n                    is_paged = True\n                if \"8bit\" in args.optim:\n                    optim_bits = 8\n                if \"adam\" in args.optim:\n                    optimizer_cls = AdamW\n                elif \"lion\" in args.optim:\n                    optimizer_cls = Lion\n                    additional_optim_kwargs = {\"betas\": (args.adam_beta1, args.adam_beta2)}\n\n                bnb_kwargs = {\"is_paged\": is_paged, \"optim_bits\": optim_bits}\n                optimizer_kwargs.update(additional_optim_kwargs)\n                optimizer_kwargs.update(bnb_kwargs)\n            except ImportError:\n                raise ValueError(\"Trainer tried to instantiate bnb optimizer but bnb is not installed!\")\n        elif args.optim == OptimizerNames.ADAMW_BNB:\n            try:\n                from bitsandbytes.optim import Adam8bit\n\n                optimizer_cls = Adam8bit\n                optimizer_kwargs.update(adam_kwargs)\n            except ImportError:\n                raise ValueError(\"Trainer tried to instantiate bnb Adam8bit but bnb is not installed!\")\n        elif args.optim == OptimizerNames.ADAMW_ANYPRECISION:\n            try:\n                from torchdistx.optimizers import AnyPrecisionAdamW\n\n                optimizer_cls = AnyPrecisionAdamW\n                optimizer_kwargs.update(adam_kwargs)\n\n                # TODO Change dtypes back to M=FP32, Var = BF16, Kahan = False once they can be cast together in torchdistx.\n                optimizer_kwargs.update(\n                    {\n                        \"use_kahan_summation\": strtobool(optim_args.get(\"use_kahan_summation\", \"False\")),\n                        \"momentum_dtype\": getattr(torch, 
optim_args.get(\"momentum_dtype\", \"float32\")),\n                        \"variance_dtype\": getattr(torch, optim_args.get(\"variance_dtype\", \"float32\")),\n                        \"compensation_buffer_dtype\": getattr(\n                            torch, optim_args.get(\"compensation_buffer_dtype\", \"bfloat16\")\n                        ),\n                    }\n                )\n            except ImportError:\n                raise ValueError(\"Please install https://github.com/pytorch/torchdistx\")\n        elif args.optim == OptimizerNames.SGD:\n            optimizer_cls = torch.optim.SGD\n        elif args.optim == OptimizerNames.ADAGRAD:\n            optimizer_cls = torch.optim.Adagrad\n        else:\n            raise ValueError(f\"Trainer cannot instantiate unsupported optimizer: {args.optim}\")\n        return optimizer_cls, optimizer_kwargs\n\n    def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):\n        \"\"\"\n        Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or\n        passed as an argument.\n\n        Args:\n            num_training_steps (int): The number of training steps to do.\n        \"\"\"\n        if self.lr_scheduler is None:\n            self.lr_scheduler = get_scheduler(\n                self.args.lr_scheduler_type,\n                optimizer=self.optimizer if optimizer is None else optimizer,\n                num_warmup_steps=self.args.get_warmup_steps(num_training_steps),\n                num_training_steps=num_training_steps,\n            )\n            self._created_lr_scheduler = True\n        return self.lr_scheduler\n\n    def num_examples(self, dataloader: DataLoader) -> int:\n        \"\"\"\n        Helper to get number of samples in a [`~torch.utils.data.DataLoader`] by accessing its dataset. 
When\n        dataloader.dataset does not exist or has no length, estimates as best it can\n        \"\"\"\n        try:\n            dataset = dataloader.dataset\n            # Special case for IterableDatasetShard, we need to dig deeper\n            if isinstance(dataset, IterableDatasetShard):\n                return len(dataloader.dataset.dataset)\n            return len(dataloader.dataset)\n        except (NameError, AttributeError, TypeError):  # no dataset or length, estimate by length of dataloader\n            return len(dataloader) * self.args.per_device_train_batch_size\n\n    def _hp_search_setup(self, trial: Union[\"optuna.Trial\", Dict[str, Any]]):\n        \"\"\"HP search setup code\"\"\"\n        self._trial = trial\n\n        if self.hp_search_backend is None or trial is None:\n            return\n        if self.hp_search_backend == HPSearchBackend.OPTUNA:\n            params = self.hp_space(trial)\n        elif self.hp_search_backend == HPSearchBackend.RAY:\n            params = trial\n            params.pop(\"wandb\", None)\n        elif self.hp_search_backend == HPSearchBackend.SIGOPT:\n            params = {k: int(v) if isinstance(v, str) else v for k, v in trial.assignments.items()}\n        elif self.hp_search_backend == HPSearchBackend.WANDB:\n            params = trial\n\n        for key, value in params.items():\n            if not hasattr(self.args, key):\n                logger.warning(\n                    f\"Trying to set {key} in the hyperparameter search but there is no corresponding field in\"\n                    \" `TrainingArguments`.\"\n                )\n                continue\n            old_attr = getattr(self.args, key, None)\n            # Casting value to the proper type\n            if old_attr is not None:\n                value = type(old_attr)(value)\n            setattr(self.args, key, value)\n        if self.hp_search_backend == HPSearchBackend.OPTUNA:\n            logger.info(f\"Trial: {trial.params}\")\n        
if self.hp_search_backend == HPSearchBackend.SIGOPT:\n            logger.info(f\"SigOpt Assignments: {trial.assignments}\")\n        if self.hp_search_backend == HPSearchBackend.WANDB:\n            logger.info(f\"W&B Sweep parameters: {trial}\")\n        if self.is_deepspeed_enabled:\n            if self.args.deepspeed is None:\n                raise ValueError(\"For sweeps with deepspeed, `args.deepspeed` must be set\")\n            # Rebuild the deepspeed config to reflect the updated training parameters\n            from accelerate.utils import DeepSpeedPlugin\n\n            from transformers.deepspeed import HfTrainerDeepSpeedConfig\n\n            self.args.hf_deepspeed_config = HfTrainerDeepSpeedConfig(self.args.deepspeed)\n            self.args.hf_deepspeed_config.trainer_config_process(self.args)\n            self.args.deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=self.args.hf_deepspeed_config)\n        self.create_accelerator_and_postprocess()\n\n    def _report_to_hp_search(self, trial: Union[\"optuna.Trial\", Dict[str, Any]], step: int, metrics: Dict[str, float]):\n        if self.hp_search_backend is None or trial is None:\n            return\n        self.objective = self.compute_objective(metrics.copy())\n        if self.hp_search_backend == HPSearchBackend.OPTUNA:\n            import optuna\n\n            trial.report(self.objective, step)\n            if trial.should_prune():\n                self.callback_handler.on_train_end(self.args, self.state, self.control)\n                raise optuna.TrialPruned()\n        elif self.hp_search_backend == HPSearchBackend.RAY:\n            from ray import tune\n\n            if self.control.should_save:\n                self._tune_save_checkpoint()\n            tune.report(objective=self.objective, **metrics)\n\n    def _tune_save_checkpoint(self):\n        from ray import tune\n\n        if not self.use_tune_checkpoints:\n            return\n        with tune.checkpoint_dir(step=self.state.global_step) as 
checkpoint_dir:\n            output_dir = os.path.join(checkpoint_dir, f\"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}\")\n            self.save_model(output_dir, _internal_call=True)\n            if self.args.should_save:\n                self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))\n                torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))\n                torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))\n\n    def call_model_init(self, trial=None):\n        model_init_argcount = number_of_arguments(self.model_init)\n        if model_init_argcount == 0:\n            model = self.model_init()\n        elif model_init_argcount == 1:\n            model = self.model_init(trial)\n        else:\n            raise RuntimeError(\"model_init should have 0 or 1 argument.\")\n\n        if model is None:\n            raise RuntimeError(\"model_init should not return None.\")\n\n        return model\n\n    def torch_jit_model_eval(self, model, dataloader, training=False):\n        if not training:\n            if dataloader is None:\n                logger.warning(\"failed to use PyTorch jit mode due to current dataloader is none.\")\n                return model\n            example_batch = next(iter(dataloader))\n            example_batch = self._prepare_inputs(example_batch)\n            try:\n                jit_model = copy.copy(model)\n                jit_model.eval()\n                original_forward = jit_model.__dict__.pop(\"_original_forward\", None)\n                # remove mixed precision hooks from the model\n                if original_forward:\n                    jit_model.forward = original_forward\n                with self.accelerator.autocast(cache_enabled=False), torch.no_grad():\n                    if version.parse(version.parse(torch.__version__).base_version) >= version.parse(\"2.0.0\"):\n                        if isinstance(example_batch, 
dict):\n                            jit_model = torch.jit.trace(jit_model, example_kwarg_inputs=example_batch, strict=False)\n                        else:\n                            jit_model = torch.jit.trace(\n                                jit_model,\n                                example_kwarg_inputs={key: example_batch[key] for key in example_batch},\n                                strict=False,\n                            )\n                    else:\n                        jit_inputs = []\n                        for key in example_batch:\n                            example_tensor = torch.ones_like(example_batch[key])\n                            jit_inputs.append(example_tensor)\n                        jit_inputs = tuple(jit_inputs)\n                        jit_model = torch.jit.trace(jit_model, jit_inputs, strict=False)\n                jit_model = torch.jit.freeze(jit_model)\n                with torch.no_grad():\n                    jit_model(**example_batch)\n                    jit_model(**example_batch)\n                model = jit_model\n                self.use_cpu_amp = False\n                self.use_cuda_amp = False\n            except (RuntimeError, TypeError, ValueError, NameError, IndexError) as e:\n                logger.warning(f\"failed to use PyTorch jit mode due to: {e}.\")\n\n        return model\n\n    def ipex_optimize_model(self, model, training=False, dtype=torch.float32):\n        if not is_ipex_available():\n            raise ImportError(\n                \"Using IPEX but IPEX is not installed or IPEX's version does not match current PyTorch, please refer\"\n                \" to https://github.com/intel/intel-extension-for-pytorch.\"\n            )\n\n        import intel_extension_for_pytorch as ipex\n\n        if not training:\n            model.eval()\n            dtype = torch.bfloat16 if not self.is_in_train and self.args.bf16_full_eval else dtype\n            # conv_bn_folding is disabled as it fails in symbolic 
tracing, resulting in ipex warnings\n            model = ipex.optimize(model, dtype=dtype, level=\"O1\", conv_bn_folding=False, inplace=not self.is_in_train)\n        else:\n            if not model.training:\n                model.train()\n            model, self.optimizer = ipex.optimize(\n                model, dtype=dtype, optimizer=self.optimizer, inplace=True, level=\"O1\"\n            )\n\n        return model\n\n    def _wrap_model(self, model, training=True, dataloader=None):\n        if self.args.use_ipex:\n            dtype = torch.bfloat16 if self.use_cpu_amp else torch.float32\n            model = self.ipex_optimize_model(model, training, dtype=dtype)\n\n        if is_sagemaker_mp_enabled():\n            # Wrapping the base model twice in a DistributedModel will raise an error.\n            if isinstance(self.model_wrapped, smp.model.DistributedModel):\n                return self.model_wrapped\n            return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps)\n\n        # train/eval could be run multiple-times - if already wrapped, don't re-wrap it again\n        if unwrap_model(model) is not model:\n            return model\n\n        # Mixed precision training with apex (torch < 1.6)\n        if self.use_apex and training:\n            model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)\n\n        # Multi-gpu training (should be after apex fp16 initialization) / 8bit models does not support DDP\n        if self.args.n_gpu > 1 and not getattr(model, \"is_loaded_in_8bit\", False):\n            model = nn.DataParallel(model)\n\n        if self.args.jit_mode_eval:\n            start_time = time.time()\n            model = self.torch_jit_model_eval(model, dataloader, training)\n            self.jit_compilation_time = round(time.time() - start_time, 4)\n\n        # Note: in torch.distributed mode, there's no point in wrapping the model\n        # inside a 
DistributedDataParallel as we'll be under `no_grad` anyways.\n        if not training:\n            return model\n\n        # Distributed training (should be after apex fp16 initialization)\n        if self.sharded_ddp is not None:\n            # Sharded DDP!\n            if self.sharded_ddp == ShardedDDPOption.SIMPLE:\n                model = ShardedDDP(model, self.optimizer)\n            else:\n                mixed_precision = self.args.fp16 or self.args.bf16\n                cpu_offload = ShardedDDPOption.OFFLOAD in self.args.sharded_ddp\n                zero_3 = self.sharded_ddp == ShardedDDPOption.ZERO_DP_3\n                # XXX: Breaking the self.model convention but I see no way around it for now.\n                if ShardedDDPOption.AUTO_WRAP in self.args.sharded_ddp:\n                    model = auto_wrap(model)\n                self.model = model = FullyShardedDDP(\n                    model,\n                    mixed_precision=mixed_precision,\n                    reshard_after_forward=zero_3,\n                    cpu_offload=cpu_offload,\n                ).to(self.args.device)\n        # Distributed training using PyTorch FSDP\n        elif self.fsdp is not None and self.args.fsdp_config[\"xla\"]:\n            try:\n                from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP\n                from torch_xla.distributed.fsdp import checkpoint_module\n                from torch_xla.distributed.fsdp.wrap import (\n                    size_based_auto_wrap_policy,\n                    transformer_auto_wrap_policy,\n                )\n            except ImportError:\n                raise ImportError(\"Missing XLA FSDP related module; please make sure to use torch-xla >= 2.0.\")\n            auto_wrap_policy = None\n            auto_wrapper_callable = None\n            if self.args.fsdp_config[\"fsdp_min_num_params\"] > 0:\n                auto_wrap_policy = functools.partial(\n                    size_based_auto_wrap_policy, 
min_num_params=self.args.fsdp_config[\"fsdp_min_num_params\"]\n                )\n            elif self.args.fsdp_config.get(\"fsdp_transformer_layer_cls_to_wrap\", None) is not None:\n                transformer_cls_to_wrap = set()\n                for layer_class in self.args.fsdp_config[\"fsdp_transformer_layer_cls_to_wrap\"]:\n                    transformer_cls = get_module_class_from_name(model, layer_class)\n                    if transformer_cls is None:\n                        raise Exception(\"Could not find the transformer layer class to wrap in the model.\")\n                    else:\n                        transformer_cls_to_wrap.add(transformer_cls)\n                auto_wrap_policy = functools.partial(\n                    transformer_auto_wrap_policy,\n                    # Transformer layer class to wrap\n                    transformer_layer_cls=transformer_cls_to_wrap,\n                )\n            fsdp_kwargs = self.args.xla_fsdp_config\n            if self.args.fsdp_config[\"xla_fsdp_grad_ckpt\"]:\n                # Apply gradient checkpointing to auto-wrapped sub-modules if specified\n                def auto_wrapper_callable(m, *args, **kwargs):\n                    return FSDP(checkpoint_module(m), *args, **kwargs)\n\n            # Wrap the base model with an outer FSDP wrapper\n            self.model = model = FSDP(\n                model,\n                auto_wrap_policy=auto_wrap_policy,\n                auto_wrapper_callable=auto_wrapper_callable,\n                **fsdp_kwargs,\n            )\n\n            # Patch `xm.optimizer_step` should not reduce gradients in this case,\n            # as FSDP does not need gradient reduction over sharded parameters.\n            def patched_optimizer_step(optimizer, barrier=False, optimizer_args={}):\n                loss = optimizer.step(**optimizer_args)\n                if barrier:\n                    xm.mark_step()\n                return loss\n\n            xm.optimizer_step = 
patched_optimizer_step\n        elif is_sagemaker_dp_enabled():\n            model = nn.parallel.DistributedDataParallel(\n                model, device_ids=[int(os.getenv(\"SMDATAPARALLEL_LOCAL_RANK\"))]\n            )\n        elif self.args.parallel_mode == ParallelMode.DISTRIBUTED:\n            if is_torch_neuroncore_available():\n                return model\n            kwargs = {}\n            if self.args.ddp_find_unused_parameters is not None:\n                kwargs[\"find_unused_parameters\"] = self.args.ddp_find_unused_parameters\n            elif isinstance(model, PreTrainedModel):\n                # find_unused_parameters breaks checkpointing as per\n                # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021\n                kwargs[\"find_unused_parameters\"] = not model.is_gradient_checkpointing\n            else:\n                kwargs[\"find_unused_parameters\"] = True\n\n            if self.args.ddp_bucket_cap_mb is not None:\n                kwargs[\"bucket_cap_mb\"] = self.args.ddp_bucket_cap_mb\n\n            if self.args.ddp_broadcast_buffers is not None:\n                kwargs[\"broadcast_buffers\"] = self.args.ddp_broadcast_buffers\n\n            self.accelerator.ddp_handler = DistributedDataParallelKwargs(**kwargs)\n\n        return model\n\n    def train(\n        self,\n        resume_from_checkpoint: Optional[Union[str, bool]] = None,\n        trial: Union[\"optuna.Trial\", Dict[str, Any]] = None,\n        ignore_keys_for_eval: Optional[List[str]] = None,\n        **kwargs,\n    ):\n        \"\"\"\n        Main training entry point.\n\n        Args:\n            resume_from_checkpoint (`str` or `bool`, *optional*):\n                If a `str`, local path to a saved checkpoint as saved by a previous instance of [`Trainer`]. If a\n                `bool` and equals `True`, load the last checkpoint in *args.output_dir* as saved by a previous instance\n                of [`Trainer`]. 
If present, training will resume from the model/optimizer/scheduler states loaded here.\n            trial (`optuna.Trial` or `Dict[str, Any]`, *optional*):\n                The trial run or the hyperparameter dictionary for hyperparameter search.\n            ignore_keys_for_eval (`List[str]`, *optional*)\n                A list of keys in the output of your model (if it is a dictionary) that should be ignored when\n                gathering predictions for evaluation during the training.\n            kwargs (`Dict[str, Any]`, *optional*):\n                Additional keyword arguments used to hide deprecated arguments\n        \"\"\"\n        if resume_from_checkpoint is False:\n            resume_from_checkpoint = None\n\n        # memory metrics - must set up as early as possible\n        self._memory_tracker.start()\n\n        args = self.args\n\n        self.is_in_train = True\n\n        # do_train is not a reliable argument, as it might not be set and .train() still called, so\n        # the following is a workaround:\n        if (args.fp16_full_eval or args.bf16_full_eval) and not args.do_train:\n            self._move_model_to_device(self.model, args.device)\n\n        if \"model_path\" in kwargs:\n            resume_from_checkpoint = kwargs.pop(\"model_path\")\n            warnings.warn(\n                \"`model_path` is deprecated and will be removed in a future version. 
Use `resume_from_checkpoint` \"\n                \"instead.\",\n                FutureWarning,\n            )\n        if len(kwargs) > 0:\n            raise TypeError(f\"train() received got unexpected keyword arguments: {', '.join(list(kwargs.keys()))}.\")\n        # This might change the seed so needs to run first.\n        self._hp_search_setup(trial)\n        self._train_batch_size = self.args.train_batch_size\n\n        # Model re-init\n        model_reloaded = False\n        if self.model_init is not None:\n            # Seed must be set before instantiating the model when using model_init.\n            enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)\n            self.model = self.call_model_init(trial)\n            model_reloaded = True\n            # Reinitializes optimizer and scheduler\n            self.optimizer, self.lr_scheduler = None, None\n\n        # Load potential model checkpoint\n        if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint:\n            resume_from_checkpoint = get_last_checkpoint(args.output_dir)\n            if resume_from_checkpoint is None:\n                raise ValueError(f\"No valid checkpoint found in output directory ({args.output_dir})\")\n        if resume_from_checkpoint is not None and not is_sagemaker_mp_enabled() and not self.is_deepspeed_enabled:\n            self._load_from_checkpoint(resume_from_checkpoint)\n        # If model was re-initialized, put it on the right device and update self.model_wrapped\n        if model_reloaded:\n            if self.place_model_on_device:\n                self._move_model_to_device(self.model, args.device)\n            self.model_wrapped = self.model\n\n        inner_training_loop = find_executable_batch_size(\n            self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size\n        )\n        return inner_training_loop(\n            args=args,\n            
resume_from_checkpoint=resume_from_checkpoint,\n            trial=trial,\n            ignore_keys_for_eval=ignore_keys_for_eval,\n        )\n\n    def _inner_training_loop(\n        self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None\n    ):\n        self.accelerator.free_memory()\n        self._train_batch_size = batch_size\n        logger.debug(f\"Currently training with a batch size of: {self._train_batch_size}\")\n        # Data loader and number of training steps\n        train_dataloader = self.get_train_dataloaderd2()\n\n        # Setting up training control variables:\n        # number of training epochs: num_train_epochs\n        # number of training steps per epoch: num_update_steps_per_epoch\n        # total number of training steps to execute: max_steps\n        total_train_batch_size = self._train_batch_size * args.gradient_accumulation_steps * args.world_size\n        len_dataloader = None\n        if args.max_steps<0:\n            args.max_steps=100\n        if has_length(train_dataloader):\n            len_dataloader = len(train_dataloader)\n            num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps\n            num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)\n            num_examples = self.num_examples(train_dataloader)\n            if args.max_steps > 0:\n                max_steps = args.max_steps\n                num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(\n                    args.max_steps % num_update_steps_per_epoch > 0\n                )\n                # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's\n                # the best we can do.\n                num_train_samples = args.max_steps * total_train_batch_size\n            else:\n                max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)\n                num_train_epochs = 
math.ceil(args.num_train_epochs)\n                num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs\n        elif args.max_steps > 0:  # Rely on max_steps when dataloader does not have a working size\n            max_steps = args.max_steps\n            # Setting a very large number of epochs so we go as many times as necessary over the iterator.\n            num_train_epochs = sys.maxsize\n            num_update_steps_per_epoch = max_steps\n            num_examples = total_train_batch_size * args.max_steps\n            num_train_samples = args.max_steps * total_train_batch_size\n        else:\n            raise ValueError(\n                \"args.max_steps must be set to a positive value if dataloader does not have a length, was\"\n                f\" {args.max_steps}\"\n            )\n\n        # Compute absolute values for logging, eval, and save if given as ratio\n        if args.logging_steps and args.logging_steps < 1:\n            args.logging_steps = math.ceil(max_steps * args.logging_steps)\n        if args.eval_steps and args.eval_steps < 1:\n            args.eval_steps = math.ceil(max_steps * args.eval_steps)\n        if args.save_steps and args.save_steps < 1:\n            args.save_steps = math.ceil(max_steps * args.save_steps)\n\n        if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:\n            if self.args.n_gpu > 1:\n                # nn.DataParallel(model) replicates the model, creating new variables and module\n                # references registered here no longer work on other gpus, breaking the module\n                raise ValueError(\n                    \"Currently --debug underflow_overflow is not supported under DP. 
Please use DDP\"\n                    \" (torch.distributed.launch).\"\n                )\n            else:\n                debug_overflow = DebugUnderflowOverflow(self.model)  # noqa\n\n        delay_optimizer_creation = (\n            self.sharded_ddp is not None\n            and self.sharded_ddp != ShardedDDPOption.SIMPLE\n            or is_sagemaker_mp_enabled()\n            or self.fsdp is not None\n        )\n\n        # We need to reset the scheduler, as its parameters may be different on subsequent calls\n        if self._created_lr_scheduler:\n            self.lr_scheduler = None\n            self._created_lr_scheduler = False\n\n        if self.is_deepspeed_enabled:\n            self.optimizer, self.lr_scheduler = deepspeed_init(self, num_training_steps=max_steps)\n\n        if not delay_optimizer_creation:\n            self.create_optimizer_and_scheduler(num_training_steps=max_steps)\n\n        self.state = TrainerState()\n        self.state.is_hyper_param_search = trial is not None\n\n        # Activate gradient checkpointing if needed\n        if args.gradient_checkpointing:\n            self.model.gradient_checkpointing_enable()\n\n        model = self._wrap_model(self.model_wrapped)\n\n        if is_sagemaker_mp_enabled() and resume_from_checkpoint is not None:\n            self._load_from_checkpoint(resume_from_checkpoint, model)\n\n        # as the model is wrapped, don't use `accelerator.prepare`\n        # this is for unhandled cases such as\n        # Fairscale Sharded DDP, FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX\n        use_accelerator_prepare = True if model is self.model else False\n\n        if delay_optimizer_creation:\n            self.create_optimizer_and_scheduler(num_training_steps=max_steps)\n\n        # prepare using `accelerator` prepare\n        if use_accelerator_prepare:\n            self.model.train()\n            if hasattr(self.lr_scheduler, \"step\"):\n                if self.use_apex:\n                    model = 
self.accelerator.prepare(self.model)\n                else:\n                    model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)\n            else:\n                # to handle cases wherein we pass \"DummyScheduler\" such as when it is specified in DeepSpeed config.\n                model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(\n                    self.model, self.optimizer, self.lr_scheduler\n                )\n\n        if self.is_fsdp_enabled:\n            self.model = model\n\n        # for the rest of this function `model` is the outside model, whether it was wrapped or not\n        if model is not self.model:\n            self.model_wrapped = model\n\n        # backward compatibility\n        if self.is_deepspeed_enabled:\n            self.deepspeed = self.model_wrapped\n\n        # deepspeed ckpt loading\n        if resume_from_checkpoint is not None and self.is_deepspeed_enabled:\n            deepspeed_load_checkpoint(self.model_wrapped, resume_from_checkpoint, load_optimizer_states=self.args.load_optimizer_states, load_lr_scheduler_states=self.args.load_lr_scheduler_states)\n        # Check if saved optimizer or scheduler states exist\n        self._load_optimizer_and_scheduler(resume_from_checkpoint)\n        # important: at this point:\n        # self.model         is the Transformers Model\n        # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), etc.\n\n        # Train!\n        logger.info(\"***** Running training *****\")\n        logger.info(f\"  Num examples = {num_examples:,}\")\n        logger.info(f\"  Num Epochs = {num_train_epochs:,}\")\n        logger.info(f\"  Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}\")\n        if self.args.per_device_train_batch_size != self._train_batch_size:\n            logger.info(f\"  Training with DataParallel so batch size has been adjusted to: {self._train_batch_size:,}\")\n        
logger.info(f\"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size:,}\")\n        logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n        logger.info(f\"  Total optimization steps = {max_steps:,}\")\n        logger.info(f\"  Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}\")\n\n        self.state.epoch = 0\n        start_time = time.time()\n        epochs_trained = 0\n        steps_trained_in_current_epoch = 0\n        steps_trained_progress_bar = None\n        # Check if continuing training from a checkpoint\n        if resume_from_checkpoint is not None and os.path.isfile(\n            os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)\n        ):\n            self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))\n            epochs_trained = self.state.global_step // num_update_steps_per_epoch\n            if not args.ignore_data_skip:\n                steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)\n                steps_trained_in_current_epoch *= args.gradient_accumulation_steps\n            else:\n                steps_trained_in_current_epoch = 0\n\n            logger.info(\"  Continuing training from checkpoint, will skip to saved global_step\")\n            logger.info(f\"  Continuing training from epoch {epochs_trained}\")\n            logger.info(f\"  Continuing training from global step {self.state.global_step}\")\n            if not args.ignore_data_skip:\n                logger.info(\n                    f\"  Will skip the first {epochs_trained} epochs then the first\"\n                    f\" {steps_trained_in_current_epoch} batches in the first epoch.\"\n                )\n        # Update the references\n        self.callback_handler.model = self.model\n        self.callback_handler.optimizer = self.optimizer\n        
self.callback_handler.lr_scheduler = self.lr_scheduler\n        self.callback_handler.train_dataloader = train_dataloader\n        if self.hp_name is not None and self._trial is not None:\n            # use self._trial because the SigOpt/Optuna hpo only call `_hp_search_setup(trial)` instead of passing trial\n            # parameter to Train when using DDP.\n            self.state.trial_name = self.hp_name(self._trial)\n        if trial is not None:\n            assignments = trial.assignments if self.hp_search_backend == HPSearchBackend.SIGOPT else trial\n            self.state.trial_params = hp_params(assignments)\n        else:\n            self.state.trial_params = None\n        # This should be the same if the state has been saved but in case the training arguments changed, it's safer\n        # to set this after the load.\n        self.state.max_steps = max_steps\n        self.state.num_train_epochs = num_train_epochs\n        self.state.is_local_process_zero = self.is_local_process_zero()\n        self.state.is_world_process_zero = self.is_world_process_zero()\n\n        # tr_loss is a tensor to avoid synchronization of TPUs through .item()\n        tr_loss_ = torch.tensor(0.0).to(args.device)\n        # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses\n        self._total_loss_scalar = 0.0\n        self._globalstep_last_logged = self.state.global_step\n        model.zero_grad()\n\n        self.control = self.callback_handler.on_train_begin(args, self.state, self.control)\n\n        # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.\n        if not args.ignore_data_skip:\n            for epoch in range(epochs_trained):\n                for _ in train_dataloader:\n                    break\n\n        total_batched_samples = 0\n        tr_loss = dict()\n        for epoch in range(epochs_trained, num_train_epochs):\n            epoch_iterator = 
train_dataloader\n\n            # Reset the past mems state at the beginning of each epoch if necessary.\n            if args.past_index >= 0:\n                self._past = None\n\n            steps_in_epoch = (\n                len(epoch_iterator)\n                if len_dataloader is not None\n                else args.max_steps * args.gradient_accumulation_steps\n            )\n            self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)\n\n            if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:\n                self._load_rng_state(resume_from_checkpoint)\n\n            rng_to_sync = False\n            steps_skipped = 0\n            # if steps_trained_in_current_epoch > 0:\n            #     epoch_iterator = skip_first_batches(epoch_iterator, steps_trained_in_current_epoch)\n            #     steps_skipped = steps_trained_in_current_epoch\n            #     steps_trained_in_current_epoch = 0\n            #     rng_to_sync = True\n\n            step = -1\n            for step, inputs in enumerate(epoch_iterator):\n                total_batched_samples += 1\n                if rng_to_sync:\n                    self._load_rng_state(resume_from_checkpoint)\n                    rng_to_sync = False\n\n                # Skip past any already trained steps if resuming training\n                if steps_trained_in_current_epoch > 0:\n                    steps_trained_in_current_epoch =0\n                    if steps_trained_progress_bar is not None:\n                        steps_trained_progress_bar.update(steps_trained_in_current_epoch)\n                    if steps_trained_in_current_epoch == 0:\n                        self._load_rng_state(resume_from_checkpoint)\n                    continue\n                elif steps_trained_progress_bar is not None:\n                    steps_trained_progress_bar.close()\n                    steps_trained_progress_bar = None\n       
         if step % args.gradient_accumulation_steps == 0:\n                    self.control = self.callback_handler.on_step_begin(args, self.state, self.control)\n                with self.accelerator.accumulate(model):\n                    tr_loss_step = self.training_step(model, inputs)\n                if len(tr_loss)==0:\n                    tr_loss={k:tr_loss_.clone() for k in tr_loss_step.keys()}\n                for k, loss in tr_loss.items():\n                    if (\n                        args.logging_nan_inf_filter\n                        and not is_torch_tpu_available()\n                        and (torch.isnan(tr_loss_step[k]) or torch.isinf(tr_loss_step[k]))\n                    ):\n                        # if loss is nan or inf simply add the average of previous logged losses\n                        tr_loss[k] += loss / (1 + self.state.global_step - self._globalstep_last_logged)\n                    else:\n                        tr_loss[k] += tr_loss_step[k]\n\n                # if (\n                #     args.logging_nan_inf_filter\n                #     and not is_torch_tpu_available()\n                #     and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))\n                # ):\n                #     # if loss is nan or inf simply add the average of previous logged losses\n                #     tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)\n                # else:\n                #     tr_loss += tr_loss_step\n\n                self.current_flos += float(self.floating_point_ops(inputs))\n\n                is_last_step_and_steps_less_than_grad_acc = (\n                    steps_in_epoch <= args.gradient_accumulation_steps and (step + 1) == steps_in_epoch\n                )\n\n                if (\n                    total_batched_samples % args.gradient_accumulation_steps == 0\n                    or\n                    # last step in epoch but step is always smaller than 
gradient_accumulation_steps\n                    is_last_step_and_steps_less_than_grad_acc\n                ):\n                    # the `or` condition of `is_last_step_and_steps_less_than_grad_acc` is not covered\n                    # in accelerate. So, explicitly enable sync gradients to True in that case.\n                    if is_last_step_and_steps_less_than_grad_acc or (\n                        version.parse(accelerate_version) <= version.parse(\"0.20.3\")\n                    ):\n                        self.accelerator.gradient_state._set_sync_gradients(True)\n\n                    # Gradient clipping\n                    if args.max_grad_norm is not None and args.max_grad_norm > 0:\n                        # deepspeed does its own clipping\n\n                        if self.do_grad_scaling:\n                            # Reduce gradients first for XLA\n                            if is_torch_tpu_available():\n                                gradients = xm._fetch_gradients(self.optimizer)\n                                xm.all_reduce(\"sum\", gradients, scale=1.0 / xm.xrt_world_size())\n                            # AMP: gradients need unscaling\n                            self.scaler.unscale_(self.optimizer)\n\n                        if is_sagemaker_mp_enabled() and args.fp16:\n                            self.optimizer.clip_master_grads(args.max_grad_norm)\n                        elif hasattr(self.optimizer, \"clip_grad_norm\"):\n                            # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping\n                            self.optimizer.clip_grad_norm(args.max_grad_norm)\n                        elif hasattr(model, \"clip_grad_norm_\"):\n                            # Some models (like FullyShardedDDP) have a specific way to do gradient clipping\n                            model.clip_grad_norm_(args.max_grad_norm)\n                        elif self.use_apex:\n                            # 
Revert to normal clipping otherwise, handling Apex or full precision\n                            nn.utils.clip_grad_norm_(\n                                amp.master_params(self.optimizer),\n                                args.max_grad_norm,\n                            )\n                        else:\n                            self.accelerator.clip_grad_norm_(\n                                model.parameters(),\n                                args.max_grad_norm,\n                            )\n\n                    # Optimizer step\n                    optimizer_was_run = True\n                    if is_torch_tpu_available():\n                        if self.do_grad_scaling:\n                            self.scaler.step(self.optimizer)\n                            self.scaler.update()\n                        else:\n                            # tpu-comment: accelerate wrapped optimizers call xm.optimizer_step\n                            self.optimizer.step()\n                    elif self.do_grad_scaling:\n                        scale_before = self.scaler.get_scale()\n                        self.scaler.step(self.optimizer)\n                        self.scaler.update()\n                        scale_after = self.scaler.get_scale()\n                        optimizer_was_run = scale_before <= scale_after\n                    else:\n                        self.optimizer.step()\n                        optimizer_was_run = not self.accelerator.optimizer_step_was_skipped\n\n                    if optimizer_was_run:\n                        # Delay optimizer scheduling until metrics are generated\n                        if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):\n                            self.lr_scheduler.step()\n\n                    model.zero_grad()\n                    self.state.global_step += 1\n                    self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch\n                    
self.control = self.callback_handler.on_step_end(args, self.state, self.control)\n\n                    self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)\n                else:\n                    self.control = self.callback_handler.on_substep_end(args, self.state, self.control)\n\n                if self.control.should_epoch_stop or self.control.should_training_stop:\n                    break\n            if step < 0:\n                logger.warning(\n                    \"There seems to be not a single sample in your epoch_iterator, stopping training at step\"\n                    f\" {self.state.global_step}! This is expected if you're using an IterableDataset and set\"\n                    f\" num_steps ({max_steps}) higher than the number of available samples.\"\n                )\n                self.control.should_training_stop = True\n\n            self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)\n            self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)\n\n            if DebugOption.TPU_METRICS_DEBUG in self.args.debug:\n                if is_torch_tpu_available():\n                    # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)\n                    xm.master_print(met.metrics_report())\n                else:\n                    logger.warning(\n                        \"You enabled PyTorch/XLA debug metrics but you don't have a TPU \"\n                        \"configured. Check your training configuration if this is unexpected.\"\n                    )\n            if self.control.should_training_stop:\n                break\n\n        if args.past_index and hasattr(self, \"_past\"):\n            # Clean the state at the end of training\n            delattr(self, \"_past\")\n\n        logger.info(\"\\n\\nTraining completed. 
Do not forget to share your model on huggingface.co/models =)\\n\\n\")\n        if args.load_best_model_at_end and self.state.best_model_checkpoint is not None:\n            # Wait for everyone to get here so we are sur the model has been saved by process 0.\n            if is_torch_tpu_available():\n                xm.rendezvous(\"load_best_model_at_end\")\n            elif args.parallel_mode == ParallelMode.DISTRIBUTED:\n                dist.barrier()\n            elif is_sagemaker_mp_enabled():\n                smp.barrier()\n\n            self._load_best_model()\n\n        # add remaining tr_loss\n        # self._total_loss_scalar += tr_loss.item()\n        self._total_loss_scalar += tr_loss['loss_total'].item()\n\n        train_loss = self._total_loss_scalar / self.state.global_step\n\n        metrics = speed_metrics(\"train\", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps)\n        self.store_flos()\n        metrics[\"total_flos\"] = self.state.total_flos\n        metrics[\"train_loss\"] = train_loss\n\n        self.is_in_train = False\n\n        self._memory_tracker.stop_and_update_metrics(metrics)\n\n        self.log(metrics)\n\n        run_dir = self._get_output_dir(trial)\n        checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir)\n\n        # Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint and process allowed to save.\n        if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1:\n            for checkpoint in checkpoints_sorted:\n                if checkpoint != self.state.best_model_checkpoint:\n                    logger.info(f\"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit\")\n                    shutil.rmtree(checkpoint)\n\n        self.control = self.callback_handler.on_train_end(args, self.state, self.control)\n\n        return TrainOutput(self.state.global_step, 
train_loss, metrics)\n\n    def _get_output_dir(self, trial):\n        if self.hp_search_backend is not None and trial is not None:\n            if self.hp_search_backend == HPSearchBackend.OPTUNA:\n                run_id = trial.number\n            elif self.hp_search_backend == HPSearchBackend.RAY:\n                from ray import tune\n\n                run_id = tune.get_trial_id()\n            elif self.hp_search_backend == HPSearchBackend.SIGOPT:\n                run_id = trial.id\n            elif self.hp_search_backend == HPSearchBackend.WANDB:\n                import wandb\n\n                run_id = wandb.run.id\n            run_name = self.hp_name(trial) if self.hp_name is not None else f\"run-{run_id}\"\n            run_dir = os.path.join(self.args.output_dir, run_name)\n        else:\n            run_dir = self.args.output_dir\n        return run_dir\n\n    def _load_from_checkpoint(self, resume_from_checkpoint, model=None):\n        if model is None:\n            model = self.model\n\n        config_file = os.path.join(resume_from_checkpoint, CONFIG_NAME)\n        adapter_weights_file = os.path.join(resume_from_checkpoint, ADAPTER_WEIGHTS_NAME)\n        adapter_safe_weights_file = os.path.join(resume_from_checkpoint, ADAPTER_SAFE_WEIGHTS_NAME)\n        weights_file = os.path.join(resume_from_checkpoint, WEIGHTS_NAME)\n        weights_index_file = os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME)\n        safe_weights_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_NAME)\n        safe_weights_index_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_INDEX_NAME)\n\n        if not any(\n            os.path.isfile(f)\n            for f in [\n                weights_file,\n                safe_weights_file,\n                weights_index_file,\n                safe_weights_index_file,\n                adapter_weights_file,\n                adapter_safe_weights_file,\n            ]\n        ):\n            raise ValueError(f\"Can't find 
a valid checkpoint at {resume_from_checkpoint}\")\n\n        logger.info(f\"Loading model from {resume_from_checkpoint}.\")\n        if os.path.isfile(config_file):\n            config = PretrainedConfig.from_json_file(config_file)\n            checkpoint_version = config.transformers_version\n            if checkpoint_version is not None and checkpoint_version != __version__:\n                logger.warning(\n                    f\"You are resuming training from a checkpoint trained with {checkpoint_version} of \"\n                    f\"Transformers but your current version is {__version__}. This is not recommended and could \"\n                    \"yield to errors or unwanted behaviors.\"\n                )\n        if os.path.isfile(weights_file) or os.path.isfile(safe_weights_file):\n            # If the model is on the GPU, it still works!\n            if is_sagemaker_mp_enabled():\n                if os.path.isfile(os.path.join(resume_from_checkpoint, \"user_content.pt\")):\n                    # If the 'user_content.pt' file exists, load with the new smp api.\n                    # Checkpoint must have been saved with the new smp api.\n                    smp.resume_from_checkpoint(\n                        path=resume_from_checkpoint, tag=WEIGHTS_NAME, partial=False, load_optimizer=False\n                    )\n                else:\n                    # If the 'user_content.pt' file does NOT exist, load with the old smp api.\n                    # Checkpoint must have been saved with the old smp api.\n                    if hasattr(self.args, \"fp16\") and self.args.fp16 is True:\n                        logger.warning(\n                            \"Enabling FP16 and loading from smp < 1.10 checkpoint together is not suppported.\"\n                        )\n                    state_dict = torch.load(weights_file, map_location=\"cpu\")\n                    # Required for smp to not auto-translate state_dict from hf to smp (is already smp).\n           
         state_dict[\"_smp_is_partial\"] = False\n                    load_result = model.load_state_dict(state_dict, strict=True)\n                    # release memory\n                    del state_dict\n            elif self.is_fsdp_enabled:\n                load_fsdp_model(self.accelerator.state.fsdp_plugin, self.accelerator, model, resume_from_checkpoint)\n            else:\n                # We load the model state dict on the CPU to avoid an OOM error.\n                if self.args.save_safetensors and os.path.isfile(safe_weights_file):\n                    state_dict = safetensors.torch.load_file(safe_weights_file, device=\"cpu\")\n                else:\n                    state_dict = torch.load(weights_file, map_location=\"cpu\")\n\n                # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963\n                # which takes *args instead of **kwargs\n                load_result = model.load_state_dict(state_dict, False)\n                # release memory\n                del state_dict\n                self._issue_warnings_after_load(load_result)\n\n        # Load adapters following PR # 24096\n        elif is_peft_available() and isinstance(model, PeftModel):\n            # If train a model using PEFT & LoRA, assume that adapter have been saved properly.\n            if hasattr(model, \"active_adapter\") and hasattr(model, \"load_adapter\"):\n                if os.path.exists(resume_from_checkpoint):\n                    model.load_adapter(resume_from_checkpoint, model.active_adapter, is_trainable=True)\n                else:\n                    logger.warning(\n                        \"The intermediate checkpoints of PEFT may not be saved correctly, \"\n                        f\"consider using a custom callback to save {ADAPTER_WEIGHTS_NAME} in corresponding saving folders. 
\"\n                        \"Check some examples here: https://github.com/huggingface/peft/issues/96\"\n                    )\n            else:\n                logger.warning(\"Could not load adapter model, make sure to have `peft>=0.3.0` installed\")\n        else:\n            # We load the sharded checkpoint\n            load_result = load_sharded_checkpoint(\n                model, resume_from_checkpoint, strict=is_sagemaker_mp_enabled(), prefer_safe=self.args.save_safetensors\n            )\n            if not is_sagemaker_mp_enabled():\n                self._issue_warnings_after_load(load_result)\n\n    def _load_best_model(self):\n        logger.info(f\"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).\")\n        best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)\n        best_safe_model_path = os.path.join(self.state.best_model_checkpoint, SAFE_WEIGHTS_NAME)\n        best_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_WEIGHTS_NAME)\n        best_safe_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_SAFE_WEIGHTS_NAME)\n\n        model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model\n        if (\n            os.path.exists(best_model_path)\n            or os.path.exists(best_safe_model_path)\n            or os.path.exists(best_adapter_model_path)\n            or os.path.exists(best_safe_adapter_model_path)\n        ):\n            if self.is_deepspeed_enabled:\n                deepspeed_load_checkpoint(self.model_wrapped, self.state.best_model_checkpoint)\n            else:\n                has_been_loaded = True\n                if is_sagemaker_mp_enabled():\n                    if os.path.isfile(os.path.join(self.state.best_model_checkpoint, \"user_content.pt\")):\n                        # If the 'user_content.pt' file exists, load with the new smp api.\n                        # Checkpoint must have been 
saved with the new smp api.\n                        smp.resume_from_checkpoint(\n                            path=self.state.best_model_checkpoint,\n                            tag=WEIGHTS_NAME,\n                            partial=False,\n                            load_optimizer=False,\n                        )\n                    else:\n                        # If the 'user_content.pt' file does NOT exist, load with the old smp api.\n                        # Checkpoint must have been saved with the old smp api.\n                        if self.args.save_safetensors and os.path.isfile(best_safe_model_path):\n                            state_dict = safetensors.torch.load_file(best_safe_model_path, device=\"cpu\")\n                        else:\n                            state_dict = torch.load(best_model_path, map_location=\"cpu\")\n\n                        state_dict[\"_smp_is_partial\"] = False\n                        load_result = model.load_state_dict(state_dict, strict=True)\n                elif self.is_fsdp_enabled:\n                    load_fsdp_model(\n                        self.accelerator.state.fsdp_plugin, self.accelerator, model, self.state.best_model_checkpoint\n                    )\n                else:\n                    if is_peft_available() and isinstance(model, PeftModel):\n                        # If train a model using PEFT & LoRA, assume that adapter have been saved properly.\n                        if hasattr(model, \"active_adapter\") and hasattr(model, \"load_adapter\"):\n                            if os.path.exists(best_adapter_model_path) or os.path.exists(best_safe_adapter_model_path):\n                                model.load_adapter(self.state.best_model_checkpoint, model.active_adapter)\n                                # Load_adapter has no return value present, modify it when appropriate.\n                                from torch.nn.modules.module import _IncompatibleKeys\n\n                                
load_result = _IncompatibleKeys([], [])\n                            else:\n                                logger.warning(\n                                    \"The intermediate checkpoints of PEFT may not be saved correctly, \"\n                                    f\"consider using a custom callback to save {ADAPTER_WEIGHTS_NAME} in corresponding saving folders. \"\n                                    \"Check some examples here: https://github.com/huggingface/peft/issues/96\"\n                                )\n                                has_been_loaded = False\n                        else:\n                            logger.warning(\"Could not load adapter model, make sure to have `peft>=0.3.0` installed\")\n                            has_been_loaded = False\n                    else:\n                        # We load the model state dict on the CPU to avoid an OOM error.\n                        if self.args.save_safetensors and os.path.isfile(best_safe_model_path):\n                            state_dict = safetensors.torch.load_file(best_safe_model_path, device=\"cpu\")\n                        else:\n                            state_dict = torch.load(best_model_path, map_location=\"cpu\")\n\n                        # If the model is on the GPU, it still works!\n                        # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963\n                        # which takes *args instead of **kwargs\n                        load_result = model.load_state_dict(state_dict, False)\n                if not is_sagemaker_mp_enabled() and has_been_loaded:\n                    self._issue_warnings_after_load(load_result)\n        elif os.path.exists(os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)):\n            load_result = load_sharded_checkpoint(\n                model, self.state.best_model_checkpoint, strict=is_sagemaker_mp_enabled()\n            )\n            if not is_sagemaker_mp_enabled():\n                
self._issue_warnings_after_load(load_result)\n        else:\n            logger.warning(\n                f\"Could not locate the best model at {best_model_path}, if you are running a distributed training \"\n                \"on multiple nodes, you should activate `--save_on_each_node`.\"\n            )\n\n    def _issue_warnings_after_load(self, load_result):\n        if len(load_result.missing_keys) != 0:\n            if self.model._keys_to_ignore_on_save is not None and set(load_result.missing_keys) == set(\n                self.model._keys_to_ignore_on_save\n            ):\n                self.model.tie_weights()\n            else:\n                logger.warning(f\"There were missing keys in the checkpoint model loaded: {load_result.missing_keys}.\")\n        if len(load_result.unexpected_keys) != 0:\n            logger.warning(\n                f\"There were unexpected keys in the checkpoint model loaded: {load_result.unexpected_keys}.\"\n            )\n\n    def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval):\n        if self.control.should_log:\n            if is_torch_tpu_available():\n                xm.mark_step()\n\n            logs: Dict[str, float] = {}\n\n            # all_gather + mean() to get average loss over all processes\n            # tr_loss_scalar = self._nested_gather(tr_loss).mean().item()\n            tr_loss_scalar = {k: self._nested_gather(tr_loss[k]).mean().item() for k in tr_loss.keys()}\n\n            # reset tr_loss to zero\n            for _,loss in tr_loss.items():\n                loss -= loss\n            # tr_loss -= tr_loss\n\n            # logs[\"loss\"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)\n            for k,loss in tr_loss_scalar.items():\n                logs[k]=round(loss / (self.state.global_step - self._globalstep_last_logged), 4)\n            logs[\"learning_rate\"] = self._get_learning_rate()\n\n            self._total_loss_scalar += 
tr_loss_scalar['loss_total']\n            self._globalstep_last_logged = self.state.global_step\n            self.store_flos()\n\n            self.log(logs)\n\n        metrics = None\n        if self.control.should_evaluate:\n            if isinstance(self.eval_dataset, dict):\n                metrics = {}\n                for eval_dataset_name, eval_dataset in self.eval_dataset.items():\n                    dataset_metrics = self.evaluate(\n                        eval_dataset=eval_dataset,\n                        ignore_keys=ignore_keys_for_eval,\n                        metric_key_prefix=f\"eval_{eval_dataset_name}\",\n                    )\n                    metrics.update(dataset_metrics)\n            else:\n                metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)\n            self._report_to_hp_search(trial, self.state.global_step, metrics)\n\n            # Run delayed LR scheduler now that metrics are populated\n            if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):\n                metric_to_check = self.args.metric_for_best_model\n                if not metric_to_check.startswith(\"eval_\"):\n                    metric_to_check = f\"eval_{metric_to_check}\"\n                self.lr_scheduler.step(metrics[metric_to_check])\n\n        if self.control.should_save:\n            self._save_checkpoint(model, trial, metrics=metrics)\n            self.control = self.callback_handler.on_save(self.args, self.state, self.control)\n\n    def _load_rng_state(self, checkpoint):\n        # Load RNG states from `checkpoint`\n        if checkpoint is None:\n            return\n\n        if self.args.world_size > 1:\n            process_index = self.args.process_index\n            rng_file = os.path.join(checkpoint, f\"rng_state_{process_index}.pth\")\n            if not os.path.isfile(rng_file):\n                logger.info(\n                    f\"Didn't find an RNG file for process {process_index}, if you are 
resuming a training that \"\n                    \"wasn't launched in a distributed fashion, reproducibility is not guaranteed.\"\n                )\n                return\n        else:\n            rng_file = os.path.join(checkpoint, \"rng_state.pth\")\n            if not os.path.isfile(rng_file):\n                logger.info(\n                    \"Didn't find an RNG file, if you are resuming a training that was launched in a distributed \"\n                    \"fashion, reproducibility is not guaranteed.\"\n                )\n                return\n\n        checkpoint_rng_state = torch.load(rng_file)\n        random.setstate(checkpoint_rng_state[\"python\"])\n        np.random.set_state(checkpoint_rng_state[\"numpy\"])\n        torch.random.set_rng_state(checkpoint_rng_state[\"cpu\"])\n        if torch.cuda.is_available():\n            if self.args.parallel_mode == ParallelMode.DISTRIBUTED:\n                torch.cuda.random.set_rng_state_all(checkpoint_rng_state[\"cuda\"])\n            else:\n                try:\n                    torch.cuda.random.set_rng_state(checkpoint_rng_state[\"cuda\"])\n                except Exception as e:\n                    logger.info(\n                        f\"Didn't manage to set back the RNG states of the GPU because of the following error:\\n {e}\"\n                        \"\\nThis won't yield the same results as if the training had not been interrupted.\"\n                    )\n        if is_torch_tpu_available():\n            xm.set_rng_state(checkpoint_rng_state[\"xla\"])\n\n    def _save_checkpoint(self, model, trial, metrics=None):\n        # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we\n        # want to save except FullyShardedDDP.\n        # assert unwrap_model(model) is self.model, \"internal model should be a reference to self.model\"\n\n        # Save model checkpoint\n        checkpoint_folder = f\"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}\"\n\n       
 if self.hp_search_backend is None and trial is None:\n            self.store_flos()\n\n        run_dir = self._get_output_dir(trial=trial)\n        output_dir = os.path.join(run_dir, checkpoint_folder)\n        self.save_model(output_dir, _internal_call=True)\n        if self.is_deepspeed_enabled:\n            # under zero3 model file itself doesn't get saved since it's bogus! Unless deepspeed\n            # config `stage3_gather_16bit_weights_on_model_save` is True\n            self.model_wrapped.save_checkpoint(output_dir)\n\n        # Save optimizer and scheduler\n        if self.sharded_ddp == ShardedDDPOption.SIMPLE:\n            self.optimizer.consolidate_state_dict()\n\n        if self.fsdp or self.is_fsdp_enabled:\n            if self.is_fsdp_enabled:\n                save_fsdp_optimizer(\n                    self.accelerator.state.fsdp_plugin, self.accelerator, self.optimizer, self.model, output_dir\n                )\n            else:\n                # FSDP has a different interface for saving optimizer states.\n                # Needs to be called on all ranks to gather all states.\n                # full_optim_state_dict will be deprecated after Pytorch 2.2!\n                full_osd = self.model.__class__.full_optim_state_dict(self.model, self.optimizer)\n\n        if is_torch_tpu_available():\n            xm.rendezvous(\"saving_optimizer_states\")\n            xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))\n            with warnings.catch_warnings(record=True) as caught_warnings:\n                xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))\n                reissue_pt_warnings(caught_warnings)\n        elif is_sagemaker_mp_enabled():\n            opt_state_dict = self.optimizer.local_state_dict(gather_if_shard=False)\n            smp.barrier()\n            if smp.rdp_rank() == 0 or smp.state.cfg.shard_optimizer_state:\n                smp.save(\n                    
opt_state_dict,\n                    os.path.join(output_dir, OPTIMIZER_NAME),\n                    partial=True,\n                    v3=smp.state.cfg.shard_optimizer_state,\n                )\n            if self.args.should_save:\n                with warnings.catch_warnings(record=True) as caught_warnings:\n                    torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))\n                reissue_pt_warnings(caught_warnings)\n                if self.do_grad_scaling:\n                    torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))\n        elif self.args.should_save and not self.is_deepspeed_enabled:\n            # deepspeed.save_checkpoint above saves model/optim/sched\n            if self.fsdp and not self.is_fsdp_enabled:\n                torch.save(full_osd, os.path.join(output_dir, OPTIMIZER_NAME))\n            else:\n                torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))\n\n            with warnings.catch_warnings(record=True) as caught_warnings:\n                torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))\n            reissue_pt_warnings(caught_warnings)\n            if self.do_grad_scaling:\n                torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))\n\n        # Determine the new best metric / best model checkpoint\n        if metrics is not None and self.args.metric_for_best_model is not None:\n            metric_to_check = self.args.metric_for_best_model\n            if not metric_to_check.startswith(\"eval_\"):\n                metric_to_check = f\"eval_{metric_to_check}\"\n            metric_value = metrics[metric_to_check]\n\n            operator = np.greater if self.args.greater_is_better else np.less\n            if (\n                self.state.best_metric is None\n                or self.state.best_model_checkpoint is None\n                or 
operator(metric_value, self.state.best_metric)\n            ):\n                self.state.best_metric = metric_value\n                self.state.best_model_checkpoint = output_dir\n\n        # Save the Trainer state\n        if self.args.should_save:\n            self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))\n\n        # Save RNG state in non-distributed training\n        rng_states = {\n            \"python\": random.getstate(),\n            \"numpy\": np.random.get_state(),\n            \"cpu\": torch.random.get_rng_state(),\n        }\n        if torch.cuda.is_available():\n            if self.args.parallel_mode == ParallelMode.DISTRIBUTED:\n                # In non distributed, we save the global CUDA RNG state (will take care of DataParallel)\n                rng_states[\"cuda\"] = torch.cuda.random.get_rng_state_all()\n            else:\n                rng_states[\"cuda\"] = torch.cuda.random.get_rng_state()\n\n        if is_torch_tpu_available():\n            rng_states[\"xla\"] = xm.get_rng_state()\n\n        # A process can arrive here before the process 0 has a chance to save the model, in which case output_dir may\n        # not yet exist.\n        os.makedirs(output_dir, exist_ok=True)\n\n        if self.args.world_size <= 1:\n            torch.save(rng_states, os.path.join(output_dir, \"rng_state.pth\"))\n        else:\n            torch.save(rng_states, os.path.join(output_dir, f\"rng_state_{self.args.process_index}.pth\"))\n\n        if self.args.push_to_hub:\n            self._push_from_checkpoint(output_dir)\n\n        # Maybe delete some older checkpoints.\n        if self.args.should_save:\n            self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)\n\n    def _load_optimizer_and_scheduler(self, checkpoint):\n        \"\"\"If optimizer and scheduler states exist, load them.\"\"\"\n        if checkpoint is None:\n            return\n\n        if self.is_deepspeed_enabled:\n            # deepspeed loads 
optimizer/lr_scheduler together with the model in deepspeed_init\n            return\n\n        checkpoint_file_exists = (\n            glob.glob(os.path.join(checkpoint, OPTIMIZER_NAME) + \"_*\")\n            if is_sagemaker_mp_enabled()\n            else os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME))\n        )\n        if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)):\n            # Load in optimizer and scheduler states\n            if is_torch_tpu_available():\n                # On TPU we have to take some extra precautions to properly load the states on the right device.\n                optimizer_state = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=\"cpu\")\n                with warnings.catch_warnings(record=True) as caught_warnings:\n                    lr_scheduler_state = torch.load(os.path.join(checkpoint, SCHEDULER_NAME), map_location=\"cpu\")\n                reissue_pt_warnings(caught_warnings)\n\n                xm.send_cpu_data_to_device(optimizer_state, self.args.device)\n                xm.send_cpu_data_to_device(lr_scheduler_state, self.args.device)\n\n                self.optimizer.load_state_dict(optimizer_state)\n                self.lr_scheduler.load_state_dict(lr_scheduler_state)\n            else:\n                if is_sagemaker_mp_enabled():\n                    if os.path.isfile(os.path.join(checkpoint, \"user_content.pt\")):\n                        # Optimizer checkpoint was saved with smp >= 1.10\n                        def opt_load_hook(mod, opt):\n                            opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True))\n\n                    else:\n                        # Optimizer checkpoint was saved with smp < 1.10\n                        def opt_load_hook(mod, opt):\n                            if IS_SAGEMAKER_MP_POST_1_10:\n                                opt.load_state_dict(\n                                    
smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True, back_compat=True)\n                                )\n                            else:\n                                opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True))\n\n                    self.model_wrapped.register_post_step_hook(opt_load_hook)\n                else:\n                    # We use the CPU when training on one GPU to avoid OOM for GPU RAM when training big models.\n                    # In distributed training however, we load directly on each GPU and risk the GPU OOM as it's more\n                    # likely to get OOM on CPU (since we load num_gpu times the optimizer state\n                    map_location = self.args.device if self.args.world_size > 1 else \"cpu\"\n                    if self.fsdp or self.is_fsdp_enabled:\n                        if self.is_fsdp_enabled:\n                            load_fsdp_optimizer(\n                                self.accelerator.state.fsdp_plugin,\n                                self.accelerator,\n                                self.optimizer,\n                                self.model,\n                                checkpoint,\n                            )\n                        else:\n                            full_osd = None\n                            # In FSDP, we need to load the full optimizer state dict on rank 0 and then shard it\n                            if self.args.process_index == 0:\n                                full_osd = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME))\n                            # call scatter_full_optim_state_dict on all ranks\n                            sharded_osd = self.model.__class__.scatter_full_optim_state_dict(full_osd, self.model)\n                            self.optimizer.load_state_dict(sharded_osd)\n                    else:\n                        self.optimizer.load_state_dict(\n                            
torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)\n                        )\n                with warnings.catch_warnings(record=True) as caught_warnings:\n                    self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))\n                reissue_pt_warnings(caught_warnings)\n                if self.do_grad_scaling and os.path.isfile(os.path.join(checkpoint, SCALER_NAME)):\n                    self.scaler.load_state_dict(torch.load(os.path.join(checkpoint, SCALER_NAME)))\n\n    def hyperparameter_search(\n        self,\n        hp_space: Optional[Callable[[\"optuna.Trial\"], Dict[str, float]]] = None,\n        compute_objective: Optional[Callable[[Dict[str, float]], float]] = None,\n        n_trials: int = 20,\n        direction: str = \"minimize\",\n        backend: Optional[Union[\"str\", HPSearchBackend]] = None,\n        hp_name: Optional[Callable[[\"optuna.Trial\"], str]] = None,\n        **kwargs,\n    ) -> BestRun:\n        \"\"\"\n        Launch an hyperparameter search using `optuna` or `Ray Tune` or `SigOpt`. The optimized quantity is determined\n        by `compute_objective`, which defaults to a function returning the evaluation loss when no metric is provided,\n        the sum of all metrics otherwise.\n\n        <Tip warning={true}>\n\n        To use this method, you need to have provided a `model_init` when initializing your [`Trainer`]: we need to\n        reinitialize the model at each new run. This is incompatible with the `optimizers` argument, so you need to\n        subclass [`Trainer`] and override the method [`~Trainer.create_optimizer_and_scheduler`] for custom\n        optimizer/scheduler.\n\n        </Tip>\n\n        Args:\n            hp_space (`Callable[[\"optuna.Trial\"], Dict[str, float]]`, *optional*):\n                A function that defines the hyperparameter search space. 
Will default to\n                [`~trainer_utils.default_hp_space_optuna`] or [`~trainer_utils.default_hp_space_ray`] or\n                [`~trainer_utils.default_hp_space_sigopt`] depending on your backend.\n            compute_objective (`Callable[[Dict[str, float]], float]`, *optional*):\n                A function computing the objective to minimize or maximize from the metrics returned by the `evaluate`\n                method. Will default to [`~trainer_utils.default_compute_objective`].\n            n_trials (`int`, *optional*, defaults to 20):\n                The number of trial runs to test.\n            direction (`str`, *optional*, defaults to `\"minimize\"`):\n                Whether to minimize or maximize the objective. Can be `\"minimize\"` or `\"maximize\"`, you should pick\n                `\"minimize\"` when optimizing the validation loss, `\"maximize\"` when optimizing one or several metrics.\n            backend (`str` or [`~trainer_utils.HPSearchBackend`], *optional*):\n                The backend to use for hyperparameter search. Will default to optuna or Ray Tune or SigOpt, depending\n                on which one is installed. If all are installed, will default to optuna.\n            hp_name (`Callable[[\"optuna.Trial\"], str]`, *optional*):\n                A function that defines the trial/run name. Will default to None.\n            kwargs (`Dict[str, Any]`, *optional*):\n                Additional keyword arguments passed along to `optuna.create_study` or `ray.tune.run`. 
For more\n                information see:\n\n                - the documentation of\n                  [optuna.create_study](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.create_study.html)\n                - the documentation of [tune.run](https://docs.ray.io/en/latest/tune/api_docs/execution.html#tune-run)\n                - the documentation of [sigopt](https://app.sigopt.com/docs/endpoints/experiments/create)\n\n        Returns:\n            [`trainer_utils.BestRun`]: All the information about the best run. Experiment summary can be found in\n            `run_summary` attribute for Ray backend.\n        \"\"\"\n        if backend is None:\n            backend = default_hp_search_backend()\n        backend = HPSearchBackend(backend)\n        backend_obj = ALL_HYPERPARAMETER_SEARCH_BACKENDS[backend]()\n        backend_obj.ensure_available()\n        self.hp_search_backend = backend\n        if self.model_init is None:\n            raise RuntimeError(\n                \"To use hyperparameter search, you need to pass your model through a model_init function.\"\n            )\n\n        self.hp_space = backend_obj.default_hp_space if hp_space is None else hp_space\n        self.hp_name = hp_name\n        self.compute_objective = default_compute_objective if compute_objective is None else compute_objective\n\n        best_run = backend_obj.run(self, n_trials, direction, **kwargs)\n\n        self.hp_search_backend = None\n        return best_run\n\n    def log(self, logs: Dict[str, float]) -> None:\n        \"\"\"\n        Log `logs` on the various objects watching training.\n\n        Subclass and override this method to inject custom behavior.\n\n        Args:\n            logs (`Dict[str, float]`):\n                The values to log.\n        \"\"\"\n        if self.state.epoch is not None:\n            logs[\"epoch\"] = round(self.state.epoch, 2)\n\n        output = {**logs, **{\"step\": self.state.global_step}}\n        
self.state.log_history.append(output)\n        self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)\n\n    def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torch.Tensor, Any]:\n        \"\"\"\n        Prepares one `data` before feeding it to the model, be it a tensor or a nested list/dictionary of tensors.\n        \"\"\"\n        if isinstance(data, Mapping):\n            return type(data)({k: self._prepare_input(v) for k, v in data.items()})\n        elif isinstance(data, (tuple, list)):\n            return type(data)(self._prepare_input(v) for v in data)\n        elif isinstance(data, torch.Tensor):\n            kwargs = {\"device\": self.args.device}\n            if self.is_deepspeed_enabled and (torch.is_floating_point(data) or torch.is_complex(data)):\n                # NLP models inputs are int/uint and those get adjusted to the right dtype of the\n                # embedding. Other models such as wav2vec2's inputs are already float and thus\n                # may need special handling to match the dtypes of the model\n                kwargs.update({\"dtype\": self.accelerator.state.deepspeed_plugin.hf_ds_config.dtype()})\n            return data.to(**kwargs)\n        return data\n\n    def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:\n        \"\"\"\n        Prepare `inputs` before feeding them to the model, converting them to tensors if they are not already and\n        handling potential state.\n        \"\"\"\n        inputs = self._prepare_input(inputs)\n        if len(inputs) == 0:\n            raise ValueError(\n                \"The batch received was empty, your model won't be able to train on it. 
Double-check that your \"\n                f\"training dataset contains keys expected by the model: {','.join(self._signature_columns)}.\"\n            )\n        if self.args.past_index >= 0 and self._past is not None:\n            inputs[\"mems\"] = self._past\n\n        return inputs\n\n    def compute_loss_context_manager(self):\n        \"\"\"\n        A helper wrapper to group together context managers.\n        \"\"\"\n        return self.autocast_smart_context_manager()\n\n    def autocast_smart_context_manager(self, cache_enabled: Optional[bool] = True):\n        \"\"\"\n        A helper wrapper that creates an appropriate context manager for `autocast` while feeding it the desired\n        arguments, depending on the situation.\n        \"\"\"\n        if self.use_cuda_amp or self.use_cpu_amp:\n            ctx_manager = (\n                torch.cpu.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)\n                if self.use_cpu_amp\n                else torch.cuda.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)\n            )\n        else:\n            ctx_manager = contextlib.nullcontext() if sys.version_info >= (3, 7) else contextlib.suppress()\n\n        return ctx_manager\n\n    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:\n        \"\"\"\n        Perform a training step on a batch of inputs.\n\n        Subclass and override to inject custom behavior.\n\n        Args:\n            model (`nn.Module`):\n                The model to train.\n            inputs (`Dict[str, Union[torch.Tensor, Any]]`):\n                The inputs and targets of the model.\n\n                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the\n                argument `labels`. 
Check your model's documentation for all accepted arguments.\n\n        Return:\n            `torch.Tensor`: The tensor with training loss on this batch.\n        \"\"\"\n        model.train()\n        inputs = self._prepare_inputs(inputs)\n\n        if is_sagemaker_mp_enabled():\n            loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)\n            return loss_mb.reduce_mean().detach().to(self.args.device)\n\n        with self.compute_loss_context_manager():\n            loss = self.compute_loss(model, inputs)\n\n\n        if self.args.n_gpu > 1:\n            for k, ls in loss.items():\n                loss[k] = loss[k].mean()  # mean() to average on multi-gpu parallel training\n\n        if self.do_grad_scaling:\n            self.scaler.scale(loss['loss_total']).backward()\n        elif self.use_apex:\n            with amp.scale_loss(loss['loss_total'], self.optimizer) as scaled_loss:\n                scaled_loss.backward()\n        else:\n            self.accelerator.backward(loss['loss_total'])\n\n        # return loss.detach() / self.args.gradient_accumulation_steps\n        return {k:v.detach()/self.args.gradient_accumulation_steps for k,v in loss.items()}\n\n    def compute_loss(self, model, inputs, return_outputs=False):\n        \"\"\"\n        How the loss is computed by Trainer. 
By default, all models return the loss in the first element.\n\n        Subclass and override for custom behavior.\n        \"\"\"\n        if self.label_smoother is not None and \"labels\" in inputs:\n            labels = inputs.pop(\"labels\")\n        else:\n            labels = None\n        outputs = model(**inputs)\n        # Save past state if it exists\n        # TODO: this needs to be fixed and made cleaner later.\n        if self.args.past_index >= 0:\n            self._past = outputs[self.args.past_index]\n\n        if labels is not None:\n            if unwrap_model(model)._get_name() in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():\n                loss = self.label_smoother(outputs, labels, shift_labels=True)\n            else:\n                loss = self.label_smoother(outputs, labels)\n        else:\n            if isinstance(outputs, dict) and \"loss\" not in outputs:\n                raise ValueError(\n                    \"The model did not return a loss from the inputs, only the following keys: \"\n                    f\"{','.join(outputs.keys())}. 
For reference, the inputs it received are {','.join(inputs.keys())}.\"\n                )\n            # We don't use .loss here since the model may return tuples instead of ModelOutput.\n            loss = outputs[\"loss\"] if isinstance(outputs, dict) else outputs[0]\n\n        return (loss, outputs) if return_outputs else loss\n\n    def is_local_process_zero(self) -> bool:\n        \"\"\"\n        Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on several\n        machines) main process.\n        \"\"\"\n        return self.args.local_process_index == 0\n\n    def is_world_process_zero(self) -> bool:\n        \"\"\"\n        Whether or not this process is the global main process (when training in a distributed fashion on several\n        machines, this is only going to be `True` for one process).\n        \"\"\"\n        # Special case for SageMaker ModelParallel since there process_index is dp_process_index, not the global\n        # process index.\n        if is_sagemaker_mp_enabled():\n            return smp.rank() == 0\n        else:\n            return self.args.process_index == 0\n\n    def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):\n        \"\"\"\n        Will save the model, so you can reload it using `from_pretrained()`.\n\n        Will only save from the main process.\n        \"\"\"\n\n        if output_dir is None:\n            output_dir = self.args.output_dir\n\n        if is_torch_tpu_available():\n            self._save_tpu(output_dir)\n        elif is_sagemaker_mp_enabled():\n            # Calling the state_dict needs to be done on the wrapped model and on all processes.\n            os.makedirs(output_dir, exist_ok=True)\n            state_dict = self.model_wrapped.state_dict()\n            if self.args.should_save:\n                self._save(output_dir, state_dict=state_dict)\n            if IS_SAGEMAKER_MP_POST_1_10:\n                # 
'user_content.pt' indicates model state_dict saved with smp >= 1.10\n                Path(os.path.join(output_dir, \"user_content.pt\")).touch()\n        elif (\n            ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp\n            or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp\n            or self.fsdp is not None\n            or self.is_fsdp_enabled\n        ):\n            state_dict = self.model.state_dict()\n            if self.args.should_save:\n                self._save(output_dir, state_dict=state_dict)\n            if self.is_fsdp_enabled:\n                save_fsdp_model(self.accelerator.state.fsdp_plugin, self.accelerator, self.model, output_dir)\n\n        elif self.is_deepspeed_enabled:\n            # this takes care of everything as long as we aren't under zero3\n            if version.parse(accelerate_version) <= version.parse(\"0.20.3\"):\n                raise ValueError(\"Install Accelerate from main branch\")\n            try:\n                state_dict = self.accelerator.get_state_dict(self.deepspeed)\n                if self.args.should_save:\n                    self._save(output_dir, state_dict=state_dict)\n            except ValueError:\n                logger.warning(\n                    \" stage3_gather_16bit_weights_on_model_save=false. 
Saving the full checkpoint instead, use\"\n                    \" zero_to_fp32.py to recover weights\"\n                )\n                self.model_wrapped.save_checkpoint(output_dir)\n\n        elif self.args.should_save:\n            self._save(output_dir)\n\n        # Push to the Hub when `save_model` is called by the user.\n        if self.args.push_to_hub and not _internal_call:\n            self.push_to_hub(commit_message=\"Model save\")\n\n    def _save_tpu(self, output_dir: Optional[str] = None):\n        output_dir = output_dir if output_dir is not None else self.args.output_dir\n        logger.info(f\"Saving model checkpoint to {output_dir}\")\n\n        if xm.is_master_ordinal():\n            os.makedirs(output_dir, exist_ok=True)\n            torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))\n\n        # Save a trained model and configuration using `save_pretrained()`.\n        # They can then be reloaded using `from_pretrained()`\n        xm.rendezvous(\"saving_checkpoint\")\n        if not isinstance(self.model, PreTrainedModel):\n            if isinstance(unwrap_model(self.model), PreTrainedModel):\n                unwrap_model(self.model).save_pretrained(\n                    output_dir,\n                    is_main_process=self.args.should_save,\n                    state_dict=self.model.state_dict(),\n                    save_function=xm.save,\n                )\n            else:\n                logger.info(\"Trainer.model is not a `PreTrainedModel`, only saving its state dict.\")\n                state_dict = self.model.state_dict()\n                xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))\n        else:\n            self.model.save_pretrained(output_dir, is_main_process=self.args.should_save, save_function=xm.save)\n        if self.tokenizer is not None and self.args.should_save:\n            self.tokenizer.save_pretrained(output_dir)\n\n    def _save(self, output_dir: Optional[str] = None, 
state_dict=None):\n        # If we are executing this function, we are the process zero, so we don't check for that.\n        output_dir = output_dir if output_dir is not None else self.args.output_dir\n        os.makedirs(output_dir, exist_ok=True)\n        logger.info(f\"Saving model checkpoint to {output_dir}\")\n\n        supported_classes = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel)\n        # Save a trained model and configuration using `save_pretrained()`.\n        # They can then be reloaded using `from_pretrained()`\n        if not isinstance(self.model, supported_classes):\n            if state_dict is None:\n                state_dict = self.model.state_dict()\n\n            if isinstance(unwrap_model(self.model), supported_classes):\n                unwrap_model(self.model).save_pretrained(\n                    output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors\n                )\n            else:\n                logger.info(\"Trainer.model is not a `PreTrainedModel`, only saving its state dict.\")\n                if self.args.save_safetensors:\n                    safetensors.torch.save_file(state_dict, os.path.join(output_dir, SAFE_WEIGHTS_NAME))\n                else:\n                    torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))\n        else:\n            self.model.save_pretrained(\n                output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors\n            )\n\n        if self.tokenizer is not None:\n            self.tokenizer.save_pretrained(output_dir)\n\n        # Good practice: save your training arguments together with the trained model\n        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))\n\n    def store_flos(self):\n        # Storing the number of floating-point operations that went into the model\n        if self.args.parallel_mode == ParallelMode.DISTRIBUTED:\n            
self.state.total_flos += (\n                distributed_broadcast_scalars([self.current_flos], device=self.args.device).sum().item()\n            )\n            self.current_flos = 0\n        else:\n            self.state.total_flos += self.current_flos\n            self.current_flos = 0\n\n    def _sorted_checkpoints(\n        self, output_dir=None, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False\n    ) -> List[str]:\n        ordering_and_checkpoint_path = []\n\n        glob_checkpoints = [str(x) for x in Path(output_dir).glob(f\"{checkpoint_prefix}-*\") if os.path.isdir(x)]\n\n        for path in glob_checkpoints:\n            if use_mtime:\n                ordering_and_checkpoint_path.append((os.path.getmtime(path), path))\n            else:\n                regex_match = re.match(f\".*{checkpoint_prefix}-([0-9]+)\", path)\n                if regex_match is not None and regex_match.groups() is not None:\n                    ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))\n\n        checkpoints_sorted = sorted(ordering_and_checkpoint_path)\n        checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]\n        # Make sure we don't delete the best model.\n        if self.state.best_model_checkpoint is not None:\n            best_model_index = checkpoints_sorted.index(str(Path(self.state.best_model_checkpoint)))\n            for i in range(best_model_index, len(checkpoints_sorted) - 2):\n                checkpoints_sorted[i], checkpoints_sorted[i + 1] = checkpoints_sorted[i + 1], checkpoints_sorted[i]\n        return checkpoints_sorted\n\n    def _rotate_checkpoints(self, use_mtime=False, output_dir=None) -> None:\n        if self.args.save_total_limit is None or self.args.save_total_limit <= 0:\n            return\n\n        # Check if we should delete older checkpoint(s)\n        checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime, output_dir=output_dir)\n        if len(checkpoints_sorted) <= 
self.args.save_total_limit:\n            return\n\n        # If save_total_limit=1 with load_best_model_at_end=True, we could end up deleting the last checkpoint, which\n        # we don't do to allow resuming.\n        save_total_limit = self.args.save_total_limit\n        if (\n            self.state.best_model_checkpoint is not None\n            and self.args.save_total_limit == 1\n            and checkpoints_sorted[-1] != self.state.best_model_checkpoint\n        ):\n            save_total_limit = 2\n\n        number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - save_total_limit)\n        checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]\n        for checkpoint in checkpoints_to_be_deleted:\n            logger.info(f\"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit\")\n            shutil.rmtree(checkpoint, ignore_errors=True)\n\n    def evaluate(\n        self,\n        eval_dataset: Optional[Dataset] = None,\n        ignore_keys: Optional[List[str]] = None,\n        metric_key_prefix: str = \"eval\",\n    ) -> Dict[str, float]:\n        \"\"\"\n        Run evaluation and returns metrics.\n\n        The calling script will be responsible for providing a method to compute metrics, as they are task-dependent\n        (pass it to the init `compute_metrics` argument).\n\n        You can also subclass and override this method to inject custom behavior.\n\n        Args:\n            eval_dataset (`Dataset`, *optional*):\n                Pass a dataset if you wish to override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns\n                not accepted by the `model.forward()` method are automatically removed. 
It must implement the `__len__`\n                method.\n            ignore_keys (`List[str]`, *optional*):\n                A list of keys in the output of your model (if it is a dictionary) that should be ignored when\n                gathering predictions.\n            metric_key_prefix (`str`, *optional*, defaults to `\"eval\"`):\n                An optional prefix to be used as the metrics key prefix. For example the metrics \"bleu\" will be named\n                \"eval_bleu\" if the prefix is \"eval\" (default)\n\n        Returns:\n            A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The\n            dictionary also contains the epoch number which comes from the training state.\n        \"\"\"\n        # memory metrics - must set up as early as possible\n        self._memory_tracker.start()\n\n        eval_dataloader = self.get_eval_dataloader(eval_dataset)\n        start_time = time.time()\n\n        eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop\n        output = eval_loop(\n            eval_dataloader,\n            description=\"Evaluation\",\n            # No point gathering the predictions if there are no metrics, otherwise we defer to\n            # self.args.prediction_loss_only\n            prediction_loss_only=True if self.compute_metrics is None else None,\n            ignore_keys=ignore_keys,\n            metric_key_prefix=metric_key_prefix,\n        )\n\n        total_batch_size = self.args.eval_batch_size * self.args.world_size\n        if f\"{metric_key_prefix}_jit_compilation_time\" in output.metrics:\n            start_time += output.metrics[f\"{metric_key_prefix}_jit_compilation_time\"]\n        output.metrics.update(\n            speed_metrics(\n                metric_key_prefix,\n                start_time,\n                num_samples=output.num_samples,\n                num_steps=math.ceil(output.num_samples / 
total_batch_size),\n            )\n        )\n\n        self.log(output.metrics)\n\n        if DebugOption.TPU_METRICS_DEBUG in self.args.debug:\n            # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)\n            xm.master_print(met.metrics_report())\n\n        self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)\n\n        self._memory_tracker.stop_and_update_metrics(output.metrics)\n\n        return output.metrics\n\n    def predict(\n        self, test_dataset: Dataset, ignore_keys: Optional[List[str]] = None, metric_key_prefix: str = \"test\"\n    ) -> PredictionOutput:\n        \"\"\"\n        Run prediction and returns predictions and potential metrics.\n\n        Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method\n        will also return metrics, like in `evaluate()`.\n\n        Args:\n            test_dataset (`Dataset`):\n                Dataset to run the predictions on. If it is an `datasets.Dataset`, columns not accepted by the\n                `model.forward()` method are automatically removed. Has to implement the method `__len__`\n            ignore_keys (`List[str]`, *optional*):\n                A list of keys in the output of your model (if it is a dictionary) that should be ignored when\n                gathering predictions.\n            metric_key_prefix (`str`, *optional*, defaults to `\"test\"`):\n                An optional prefix to be used as the metrics key prefix. For example the metrics \"bleu\" will be named\n                \"test_bleu\" if the prefix is \"test\" (default)\n\n        <Tip>\n\n        If your predictions or labels have different sequence length (for instance because you're doing dynamic padding\n        in a token classification task) the predictions will be padded (on the right) to allow for concatenation into\n        one array. 
The padding index is -100.\n\n        </Tip>\n\n        Returns: *NamedTuple* A namedtuple with the following keys:\n\n            - predictions (`np.ndarray`): The predictions on `test_dataset`.\n            - label_ids (`np.ndarray`, *optional*): The labels (if the dataset contained some).\n            - metrics (`Dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained\n              labels).\n        \"\"\"\n        # memory metrics - must set up as early as possible\n        self._memory_tracker.start()\n\n        test_dataloader = self.get_test_dataloader(test_dataset)\n        start_time = time.time()\n\n        eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop\n        output = eval_loop(\n            test_dataloader, description=\"Prediction\", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix\n        )\n        total_batch_size = self.args.eval_batch_size * self.args.world_size\n        if f\"{metric_key_prefix}_jit_compilation_time\" in output.metrics:\n            start_time += output.metrics[f\"{metric_key_prefix}_jit_compilation_time\"]\n        output.metrics.update(\n            speed_metrics(\n                metric_key_prefix,\n                start_time,\n                num_samples=output.num_samples,\n                num_steps=math.ceil(output.num_samples / total_batch_size),\n            )\n        )\n\n        self.control = self.callback_handler.on_predict(self.args, self.state, self.control, output.metrics)\n        self._memory_tracker.stop_and_update_metrics(output.metrics)\n\n        return PredictionOutput(predictions=output.predictions, label_ids=output.label_ids, metrics=output.metrics)\n\n    def evaluation_loop(\n        self,\n        dataloader: DataLoader,\n        description: str,\n        prediction_loss_only: Optional[bool] = None,\n        ignore_keys: Optional[List[str]] = None,\n        metric_key_prefix: str = \"eval\",\n    
) -> EvalLoopOutput:\n        \"\"\"\n        Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.\n\n        Works both with or without labels.\n        \"\"\"\n        args = self.args\n\n        prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only\n\n        # if eval is called w/o train, handle model prep here\n        if self.is_deepspeed_enabled and self.deepspeed is None:\n            _, _ = deepspeed_init(self, num_training_steps=0, inference=True)\n\n        model = self._wrap_model(self.model, training=False, dataloader=dataloader)\n\n        if len(self.accelerator._models) == 0 and model is self.model:\n            model = (\n                self.accelerator.prepare(model)\n                if self.is_deepspeed_enabled\n                else self.accelerator.prepare_model(model, evaluation_mode=True)\n            )\n\n            if self.is_fsdp_enabled:\n                self.model = model\n\n            # for the rest of this function `model` is the outside model, whether it was wrapped or not\n            if model is not self.model:\n                self.model_wrapped = model\n\n            # backward compatibility\n            if self.is_deepspeed_enabled:\n                self.deepspeed = self.model_wrapped\n\n        # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called\n        # while ``train`` is running, cast it to the right dtype first and then put on device\n        if not self.is_in_train:\n            if args.fp16_full_eval:\n                model = model.to(dtype=torch.float16, device=args.device)\n            elif args.bf16_full_eval:\n                model = model.to(dtype=torch.bfloat16, device=args.device)\n\n        batch_size = self.args.eval_batch_size\n\n        logger.info(f\"***** Running {description} *****\")\n        if has_length(dataloader):\n            logger.info(f\"  Num examples = 
{self.num_examples(dataloader)}\")\n        else:\n            logger.info(\"  Num examples: Unknown\")\n        logger.info(f\"  Batch size = {batch_size}\")\n\n        model.eval()\n\n        self.callback_handler.eval_dataloader = dataloader\n        # Do this before wrapping.\n        eval_dataset = getattr(dataloader, \"dataset\", None)\n\n        if args.past_index >= 0:\n            self._past = None\n\n        # Initialize containers\n        # losses/preds/labels on GPU/TPU (accumulated for eval_accumulation_steps)\n        losses_host = None\n        preds_host = None\n        labels_host = None\n        inputs_host = None\n\n        # losses/preds/labels on CPU (final containers)\n        all_losses = None\n        all_preds = None\n        all_labels = None\n        all_inputs = None\n        # Will be useful when we have an iterable dataset so don't know its length.\n\n        observed_num_examples = 0\n        # Main evaluation loop\n        for step, inputs in enumerate(dataloader):\n            # Update the observed num examples\n            observed_batch_size = find_batch_size(inputs)\n            if observed_batch_size is not None:\n                observed_num_examples += observed_batch_size\n                # For batch samplers, batch_size is not known by the dataloader in advance.\n                if batch_size is None:\n                    batch_size = observed_batch_size\n\n            # Prediction step\n            loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)\n            inputs_decode = self._prepare_input(inputs[\"input_ids\"]) if args.include_inputs_for_metrics else None\n\n            if is_torch_tpu_available():\n                xm.mark_step()\n\n            # Update containers on host\n            if loss is not None:\n                losses = self.accelerator.gather_for_metrics((loss.repeat(batch_size)))\n                losses_host = losses if losses_host is None else 
nested_concat(losses_host, losses, padding_index=-100)\n            if labels is not None:\n                labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100)\n            if inputs_decode is not None:\n                inputs_decode = self.accelerator.pad_across_processes(inputs_decode, dim=1, pad_index=-100)\n                inputs_decode = self.accelerator.gather_for_metrics((inputs_decode))\n                inputs_host = (\n                    inputs_decode\n                    if inputs_host is None\n                    else nested_concat(inputs_host, inputs_decode, padding_index=-100)\n                )\n            if logits is not None:\n                logits = self.accelerator.pad_across_processes(logits, dim=1, pad_index=-100)\n                if self.preprocess_logits_for_metrics is not None:\n                    logits = self.preprocess_logits_for_metrics(logits, labels)\n                logits = self.accelerator.gather_for_metrics((logits))\n                preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)\n\n            if labels is not None:\n                labels = self.accelerator.gather_for_metrics((labels))\n                labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)\n\n            self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)\n\n            # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.\n            if args.eval_accumulation_steps is not None and self.accelerator.sync_gradients:\n                if losses_host is not None:\n                    losses = nested_numpify(losses_host)\n                    all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)\n                if preds_host is not None:\n                    logits = nested_numpify(preds_host)\n                    all_preds = 
logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)\n                if inputs_host is not None:\n                    inputs_decode = nested_numpify(inputs_host)\n                    all_inputs = (\n                        inputs_decode\n                        if all_inputs is None\n                        else nested_concat(all_inputs, inputs_decode, padding_index=-100)\n                    )\n                if labels_host is not None:\n                    labels = nested_numpify(labels_host)\n                    all_labels = (\n                        labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)\n                    )\n\n                # Set back to None to begin a new accumulation\n                losses_host, preds_host, inputs_host, labels_host = None, None, None, None\n\n        if args.past_index and hasattr(self, \"_past\"):\n            # Clean the state at the end of the evaluation loop\n            delattr(self, \"_past\")\n\n        # Gather all remaining tensors and put them back on the CPU\n        if losses_host is not None:\n            losses = nested_numpify(losses_host)\n            all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)\n        if preds_host is not None:\n            logits = nested_numpify(preds_host)\n            all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)\n        if inputs_host is not None:\n            inputs_decode = nested_numpify(inputs_host)\n            all_inputs = (\n                inputs_decode if all_inputs is None else nested_concat(all_inputs, inputs_decode, padding_index=-100)\n            )\n        if labels_host is not None:\n            labels = nested_numpify(labels_host)\n            all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)\n\n        # Number of samples\n        if 
has_length(eval_dataset):\n            num_samples = len(eval_dataset)\n        # The instance check is weird and does not actually check for the type, but whether the dataset has the right\n        # methods. Therefore we need to make sure it also has the attribute.\n        elif isinstance(eval_dataset, IterableDatasetShard) and getattr(eval_dataset, \"num_examples\", 0) > 0:\n            num_samples = eval_dataset.num_examples\n        else:\n            if has_length(dataloader):\n                num_samples = self.num_examples(dataloader)\n            else:  # both len(dataloader.dataset) and len(dataloader) fail\n                num_samples = observed_num_examples\n        if num_samples == 0 and observed_num_examples > 0:\n            num_samples = observed_num_examples\n\n        # Metrics!\n        if self.compute_metrics is not None and all_preds is not None and all_labels is not None:\n            if args.include_inputs_for_metrics:\n                metrics = self.compute_metrics(\n                    EvalPrediction(predictions=all_preds, label_ids=all_labels, inputs=all_inputs)\n                )\n            else:\n                metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))\n        else:\n            metrics = {}\n\n        # To be JSON-serializable, we need to remove numpy types or zero-d tensors\n        metrics = denumpify_detensorize(metrics)\n\n        if all_losses is not None:\n            metrics[f\"{metric_key_prefix}_loss\"] = all_losses.mean().item()\n        if hasattr(self, \"jit_compilation_time\"):\n            metrics[f\"{metric_key_prefix}_jit_compilation_time\"] = self.jit_compilation_time\n\n        # Prefix all keys with metric_key_prefix + '_'\n        for key in list(metrics.keys()):\n            if not key.startswith(f\"{metric_key_prefix}_\"):\n                metrics[f\"{metric_key_prefix}_{key}\"] = metrics.pop(key)\n\n        return EvalLoopOutput(predictions=all_preds, 
label_ids=all_labels, metrics=metrics, num_samples=num_samples)\n\n    def _nested_gather(self, tensors, name=None):\n        \"\"\"\n        Gather value of `tensors` (tensor or list/tuple of nested tensors) and convert them to numpy before\n        concatenating them to `gathered`\n        \"\"\"\n        if tensors is None:\n            return\n        if is_torch_tpu_available():\n            if name is None:\n                name = \"nested_gather\"\n            tensors = nested_xla_mesh_reduce(tensors, name)\n        elif is_sagemaker_mp_enabled():\n            tensors = smp_gather(tensors)\n        elif (self.args.distributed_state is not None and self.args.distributed_state.distributed_type != \"NO\") or (\n            self.args.distributed_state is None and self.args.local_rank != -1\n        ):\n            tensors = distributed_concat(tensors)\n        return tensors\n\n    def prediction_step(\n        self,\n        model: nn.Module,\n        inputs: Dict[str, Union[torch.Tensor, Any]],\n        prediction_loss_only: bool,\n        ignore_keys: Optional[List[str]] = None,\n    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:\n        \"\"\"\n        Perform an evaluation step on `model` using `inputs`.\n\n        Subclass and override to inject custom behavior.\n\n        Args:\n            model (`nn.Module`):\n                The model to evaluate.\n            inputs (`Dict[str, Union[torch.Tensor, Any]]`):\n                The inputs and targets of the model.\n\n                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the\n                argument `labels`. 
Check your model's documentation for all accepted arguments.\n            prediction_loss_only (`bool`):\n                Whether or not to return the loss only.\n            ignore_keys (`List[str]`, *optional*):\n                A list of keys in the output of your model (if it is a dictionary) that should be ignored when\n                gathering predictions.\n\n        Return:\n            Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss,\n            logits and labels (each being optional).\n        \"\"\"\n        has_labels = False if len(self.label_names) == 0 else all(inputs.get(k) is not None for k in self.label_names)\n        # For CLIP-like models capable of returning loss values.\n        # If `return_loss` is not specified or being `None` in `inputs`, we check if the default value of `return_loss`\n        # is `True` in `model.forward`.\n        return_loss = inputs.get(\"return_loss\", None)\n        if return_loss is None:\n            return_loss = self.can_return_loss\n        loss_without_labels = True if len(self.label_names) == 0 and return_loss else False\n\n        inputs = self._prepare_inputs(inputs)\n        if ignore_keys is None:\n            if hasattr(self.model, \"config\"):\n                ignore_keys = getattr(self.model.config, \"keys_to_ignore_at_inference\", [])\n            else:\n                ignore_keys = []\n\n        # labels may be popped when computing the loss (label smoothing for instance) so we grab them first.\n        if has_labels or loss_without_labels:\n            labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))\n            if len(labels) == 1:\n                labels = labels[0]\n        else:\n            labels = None\n\n        with torch.no_grad():\n            if is_sagemaker_mp_enabled():\n                raw_outputs = smp_forward_only(model, inputs)\n                if has_labels or loss_without_labels:\n                
    if isinstance(raw_outputs, dict):\n                        loss_mb = raw_outputs[\"loss\"]\n                        logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys + [\"loss\"])\n                    else:\n                        loss_mb = raw_outputs[0]\n                        logits_mb = raw_outputs[1:]\n\n                    loss = loss_mb.reduce_mean().detach().cpu()\n                    logits = smp_nested_concat(logits_mb)\n                else:\n                    loss = None\n                    if isinstance(raw_outputs, dict):\n                        logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys)\n                    else:\n                        logits_mb = raw_outputs\n                    logits = smp_nested_concat(logits_mb)\n            else:\n                if has_labels or loss_without_labels:\n                    with self.compute_loss_context_manager():\n                        loss, outputs = self.compute_loss(model, inputs, return_outputs=True)\n                    loss = loss.mean().detach()\n\n                    if isinstance(outputs, dict):\n                        logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + [\"loss\"])\n                    else:\n                        logits = outputs[1:]\n                else:\n                    loss = None\n                    with self.compute_loss_context_manager():\n                        outputs = model(**inputs)\n                    if isinstance(outputs, dict):\n                        logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)\n                    else:\n                        logits = outputs\n                    # TODO: this needs to be fixed and made cleaner later.\n                    if self.args.past_index >= 0:\n                        self._past = outputs[self.args.past_index - 1]\n\n        if prediction_loss_only:\n            return (loss, None, None)\n\n        
logits = nested_detach(logits)\n        if len(logits) == 1:\n            logits = logits[0]\n\n        return (loss, logits, labels)\n\n    def floating_point_ops(self, inputs: Dict[str, Union[torch.Tensor, Any]]):\n        \"\"\"\n        For models that inherit from [`PreTrainedModel`], uses that method to compute the number of floating point\n        operations for every backward + forward pass. If using another model, either implement such a method in the\n        model or subclass and override this method.\n\n        Args:\n            inputs (`Dict[str, Union[torch.Tensor, Any]]`):\n                The inputs and targets of the model.\n\n        Returns:\n            `int`: The number of floating-point operations.\n        \"\"\"\n        if hasattr(self.model, \"floating_point_ops\"):\n            return self.model.floating_point_ops(inputs)\n        else:\n            return 0\n\n    def init_git_repo(self, at_init: bool = False):\n        \"\"\"\n        Initializes a git repo in `self.args.hub_model_id`.\n\n        Args:\n            at_init (`bool`, *optional*, defaults to `False`):\n                Whether this function is called before any training or not. 
If `self.args.overwrite_output_dir` is\n                `True` and `at_init` is `True`, the path to the repo (which is `self.args.output_dir`) might be wiped\n                out.\n        \"\"\"\n        if not self.is_world_process_zero():\n            return\n        if self.args.hub_model_id is None:\n            repo_name = Path(self.args.output_dir).absolute().name\n        else:\n            repo_name = self.args.hub_model_id\n        if \"/\" not in repo_name:\n            repo_name = get_full_repo_name(repo_name, token=self.args.hub_token)\n\n        # Make sure the repo exists.\n        create_repo(repo_name, token=self.args.hub_token, private=self.args.hub_private_repo, exist_ok=True)\n        try:\n            self.repo = Repository(self.args.output_dir, clone_from=repo_name, token=self.args.hub_token)\n        except EnvironmentError:\n            if self.args.overwrite_output_dir and at_init:\n                # Try again after wiping output_dir\n                shutil.rmtree(self.args.output_dir)\n                self.repo = Repository(self.args.output_dir, clone_from=repo_name, token=self.args.hub_token)\n            else:\n                raise\n\n        self.repo.git_pull()\n\n        # By default, ignore the checkpoint folders\n        if (\n            not os.path.exists(os.path.join(self.args.output_dir, \".gitignore\"))\n            and self.args.hub_strategy != HubStrategy.ALL_CHECKPOINTS\n        ):\n            with open(os.path.join(self.args.output_dir, \".gitignore\"), \"w\", encoding=\"utf-8\") as writer:\n                writer.writelines([\"checkpoint-*/\"])\n\n        # Add \"*.sagemaker\" to .gitignore if using SageMaker\n        if os.environ.get(\"SM_TRAINING_ENV\"):\n            self._add_sm_patterns_to_gitignore()\n\n        self.push_in_progress = None\n\n    def create_model_card(\n        self,\n        language: Optional[str] = None,\n        license: Optional[str] = None,\n        tags: Union[str, List[str], None] = None,\n 
       model_name: Optional[str] = None,\n        finetuned_from: Optional[str] = None,\n        tasks: Union[str, List[str], None] = None,\n        dataset_tags: Union[str, List[str], None] = None,\n        dataset: Union[str, List[str], None] = None,\n        dataset_args: Union[str, List[str], None] = None,\n    ):\n        \"\"\"\n        Creates a draft of a model card using the information available to the `Trainer`.\n\n        Args:\n            language (`str`, *optional*):\n                The language of the model (if applicable)\n            license (`str`, *optional*):\n                The license of the model. Will default to the license of the pretrained model used, if the original\n                model given to the `Trainer` comes from a repo on the Hub.\n            tags (`str` or `List[str]`, *optional*):\n                Some tags to be included in the metadata of the model card.\n            model_name (`str`, *optional*):\n                The name of the model.\n            finetuned_from (`str`, *optional*):\n                The name of the model used to fine-tune this one (if applicable). 
Will default to the name of the repo\n                of the original model given to the `Trainer` (if it comes from the Hub).\n            tasks (`str` or `List[str]`, *optional*):\n                One or several task identifiers, to be included in the metadata of the model card.\n            dataset_tags (`str` or `List[str]`, *optional*):\n                One or several dataset tags, to be included in the metadata of the model card.\n            dataset (`str` or `List[str]`, *optional*):\n                One or several dataset identifiers, to be included in the metadata of the model card.\n            dataset_args (`str` or `List[str]`, *optional*):\n               One or several dataset arguments, to be included in the metadata of the model card.\n        \"\"\"\n        if not self.is_world_process_zero():\n            return\n\n        training_summary = TrainingSummary.from_trainer(\n            self,\n            language=language,\n            license=license,\n            tags=tags,\n            model_name=model_name,\n            finetuned_from=finetuned_from,\n            tasks=tasks,\n            dataset_tags=dataset_tags,\n            dataset=dataset,\n            dataset_args=dataset_args,\n        )\n        model_card = training_summary.to_model_card()\n        with open(os.path.join(self.args.output_dir, \"README.md\"), \"w\") as f:\n            f.write(model_card)\n\n    def _push_from_checkpoint(self, checkpoint_folder):\n        # Only push from one node.\n        if not self.is_world_process_zero() or self.args.hub_strategy == HubStrategy.END:\n            return\n        # If we haven't finished the last push, we don't do this one.\n        if self.push_in_progress is not None and not self.push_in_progress.is_done:\n            return\n\n        output_dir = self.args.output_dir\n        # To avoid a new synchronization of all model weights, we just copy the file from the checkpoint folder\n        modeling_files = [CONFIG_NAME, 
WEIGHTS_NAME, SAFE_WEIGHTS_NAME]\n        if is_peft_available():\n            modeling_files.extend([ADAPTER_CONFIG_NAME, ADAPTER_WEIGHTS_NAME, ADAPTER_SAFE_WEIGHTS_NAME])\n        for modeling_file in modeling_files:\n            if os.path.isfile(os.path.join(checkpoint_folder, modeling_file)):\n                shutil.copy(os.path.join(checkpoint_folder, modeling_file), os.path.join(output_dir, modeling_file))\n        # Saving the tokenizer is fast and we don't know how many files it may have spawned, so we resave it to be sure.\n        if self.tokenizer is not None:\n            self.tokenizer.save_pretrained(output_dir)\n        # Same for the training arguments\n        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))\n\n        try:\n            if self.args.hub_strategy == HubStrategy.CHECKPOINT:\n                # Temporarily move the checkpoint just saved for the push\n                tmp_checkpoint = os.path.join(output_dir, \"last-checkpoint\")\n                # We have to remove the \"last-checkpoint\" dir if it exists, otherwise the checkpoint is moved as a\n                # subfolder.\n                if os.path.isdir(tmp_checkpoint):\n                    shutil.rmtree(tmp_checkpoint)\n                shutil.move(checkpoint_folder, tmp_checkpoint)\n\n            if self.args.save_strategy == IntervalStrategy.STEPS:\n                commit_message = f\"Training in progress, step {self.state.global_step}\"\n            else:\n                commit_message = f\"Training in progress, epoch {int(self.state.epoch)}\"\n            push_work = self.repo.push_to_hub(commit_message=commit_message, blocking=False, auto_lfs_prune=True)\n            # Return type of `Repository.push_to_hub` is either None or a tuple.\n            if push_work is not None:\n                self.push_in_progress = push_work[1]\n        except Exception as e:\n            logger.error(f\"Error when pushing to hub: {e}\")\n        finally:\n            if 
self.args.hub_strategy == HubStrategy.CHECKPOINT:\n                # Move back the checkpoint to its place\n                shutil.move(tmp_checkpoint, checkpoint_folder)\n\n    def push_to_hub(self, commit_message: Optional[str] = \"End of training\", blocking: bool = True, **kwargs) -> str:\n        \"\"\"\n        Upload *self.model* and *self.tokenizer* to the 🤗 model hub on the repo *self.args.hub_model_id*.\n\n        Parameters:\n            commit_message (`str`, *optional*, defaults to `\"End of training\"`):\n                Message to commit while pushing.\n            blocking (`bool`, *optional*, defaults to `True`):\n                Whether the function should return only when the `git push` has finished.\n            kwargs (`Dict[str, Any]`, *optional*):\n                Additional keyword arguments passed along to [`~Trainer.create_model_card`].\n\n        Returns:\n            The url of the commit of your model in the given repository if `blocking=False`, a tuple with the url of\n            the commit and an object to track the progress of the commit if `blocking=True`\n        \"\"\"\n        # If a user calls manually `push_to_hub` with `self.args.push_to_hub = False`, we try to create the repo but\n        # it might fail.\n        if not hasattr(self, \"repo\"):\n            self.init_git_repo()\n\n        model_name = kwargs.pop(\"model_name\", None)\n        if model_name is None and self.args.should_save:\n            if self.args.hub_model_id is None:\n                model_name = Path(self.args.output_dir).name\n            else:\n                model_name = self.args.hub_model_id.split(\"/\")[-1]\n\n        # Needs to be executed on all processes for TPU training, but will only save on the processed determined by\n        # self.args.should_save.\n        self.save_model(_internal_call=True)\n\n        # Only push from one node.\n        if not self.is_world_process_zero():\n            return\n\n        # Cancel any async push in 
progress if blocking=True. The commits will all be pushed together.\n        if blocking and self.push_in_progress is not None and not self.push_in_progress.is_done:\n            self.push_in_progress._process.kill()\n            self.push_in_progress = None\n\n        git_head_commit_url = self.repo.push_to_hub(\n            commit_message=commit_message, blocking=blocking, auto_lfs_prune=True\n        )\n        # push separately the model card to be independant from the rest of the model\n        if self.args.should_save:\n            self.create_model_card(model_name=model_name, **kwargs)\n            try:\n                self.repo.push_to_hub(\n                    commit_message=\"update model card README.md\", blocking=blocking, auto_lfs_prune=True\n                )\n            except EnvironmentError as exc:\n                logger.error(f\"Error pushing update to the model card. Please read logs and retry.\\n${exc}\")\n\n        return git_head_commit_url\n\n    #\n    # Deprecated code\n    #\n\n    def prediction_loop(\n        self,\n        dataloader: DataLoader,\n        description: str,\n        prediction_loss_only: Optional[bool] = None,\n        ignore_keys: Optional[List[str]] = None,\n        metric_key_prefix: str = \"eval\",\n    ) -> EvalLoopOutput:\n        \"\"\"\n        Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.\n\n        Works both with or without labels.\n        \"\"\"\n        args = self.args\n\n        if not has_length(dataloader):\n            raise ValueError(\"dataloader must implement a working __len__\")\n\n        prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only\n\n        # if eval is called w/o train, handle model prep here\n        if self.is_deepspeed_enabled and self.deepspeed is None:\n            _, _ = deepspeed_init(self, num_training_steps=0, inference=True)\n\n        model = self._wrap_model(self.model, 
training=False, dataloader=dataloader)\n\n        if len(self.accelerator._models) == 0 and model is self.model:\n            model = (\n                self.accelerator.prepare(model)\n                if self.is_deepspeed_enabled\n                else self.accelerator.prepare_model(model, evaluation_mode=True)\n            )\n\n            if self.is_fsdp_enabled:\n                self.model = model\n\n            # for the rest of this function `model` is the outside model, whether it was wrapped or not\n            if model is not self.model:\n                self.model_wrapped = model\n\n            # backward compatibility\n            if self.is_deepspeed_enabled:\n                self.deepspeed = self.model_wrapped\n\n        # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called\n        # while ``train`` is running, cast it to the right dtype first and then put on device\n        if not self.is_in_train:\n            if args.fp16_full_eval:\n                model = model.to(dtype=torch.float16, device=args.device)\n            elif args.bf16_full_eval:\n                model = model.to(dtype=torch.bfloat16, device=args.device)\n\n        batch_size = dataloader.batch_size\n        num_examples = self.num_examples(dataloader)\n        logger.info(f\"***** Running {description} *****\")\n        logger.info(f\"  Num examples = {num_examples}\")\n        logger.info(f\"  Batch size = {batch_size}\")\n        losses_host: torch.Tensor = None\n        preds_host: Union[torch.Tensor, List[torch.Tensor]] = None\n        labels_host: Union[torch.Tensor, List[torch.Tensor]] = None\n        inputs_host: Union[torch.Tensor, List[torch.Tensor]] = None\n\n        world_size = max(1, args.world_size)\n\n        eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size)\n        if not prediction_loss_only:\n            # The actual number of eval_sample can be greater than num_examples in 
distributed settings (when we pass\n            # a batch size to the sampler)\n            make_multiple_of = None\n            if hasattr(dataloader, \"sampler\") and isinstance(dataloader.sampler, SequentialDistributedSampler):\n                make_multiple_of = dataloader.sampler.batch_size\n            preds_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)\n            labels_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)\n            inputs_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)\n\n        model.eval()\n\n        if args.past_index >= 0:\n            self._past = None\n\n        self.callback_handler.eval_dataloader = dataloader\n\n        for step, inputs in enumerate(dataloader):\n            loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)\n            inputs_decode = self._prepare_input(inputs[\"input_ids\"]) if args.include_inputs_for_metrics else None\n\n            if loss is not None:\n                losses = loss.repeat(batch_size)\n                losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)\n            if logits is not None:\n                preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)\n            if labels is not None:\n                labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)\n            if inputs_decode is not None:\n                inputs_host = (\n                    inputs_decode\n                    if inputs_host is None\n                    else nested_concat(inputs_host, inputs_decode, padding_index=-100)\n                )\n            self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)\n\n            # Gather all tensors and 
put them back on the CPU if we have done enough accumulation steps.\n            if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:\n                eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, \"eval_losses\"))\n                if not prediction_loss_only:\n                    preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, \"eval_preds\"))\n                    labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, \"eval_label_ids\"))\n                    inputs_gatherer.add_arrays(self._gather_and_numpify(inputs_host, \"eval_inputs_ids\"))\n\n                # Set back to None to begin a new accumulation\n                losses_host, preds_host, labels_host, inputs_host = None, None, None, None\n\n        if args.past_index and hasattr(self, \"_past\"):\n            # Clean the state at the end of the evaluation loop\n            delattr(self, \"_past\")\n\n        # Gather all remaining tensors and put them back on the CPU\n        eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, \"eval_losses\"))\n        if not prediction_loss_only:\n            preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, \"eval_preds\"))\n            labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, \"eval_label_ids\"))\n            inputs_gatherer.add_arrays(self._gather_and_numpify(inputs_host, \"eval_inputs_ids\"))\n\n        eval_loss = eval_losses_gatherer.finalize()\n        preds = preds_gatherer.finalize() if not prediction_loss_only else None\n        label_ids = labels_gatherer.finalize() if not prediction_loss_only else None\n        inputs_ids = inputs_gatherer.finalize() if not prediction_loss_only else None\n\n        if self.compute_metrics is not None and preds is not None and label_ids is not None:\n            if args.include_inputs_for_metrics:\n                metrics = self.compute_metrics(\n                    
EvalPrediction(predictions=preds, label_ids=label_ids, inputs=inputs_ids)\n                )\n            else:\n                metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))\n        else:\n            metrics = {}\n\n        # To be JSON-serializable, we need to remove numpy types or zero-d tensors\n        metrics = denumpify_detensorize(metrics)\n\n        if eval_loss is not None:\n            metrics[f\"{metric_key_prefix}_loss\"] = eval_loss.mean().item()\n\n        # Prefix all keys with metric_key_prefix + '_'\n        for key in list(metrics.keys()):\n            if not key.startswith(f\"{metric_key_prefix}_\"):\n                metrics[f\"{metric_key_prefix}_{key}\"] = metrics.pop(key)\n\n        return EvalLoopOutput(predictions=preds, label_ids=label_ids, metrics=metrics, num_samples=num_examples)\n\n    def _gather_and_numpify(self, tensors, name):\n        \"\"\"\n        Gather value of `tensors` (tensor or list/tuple of nested tensors) and convert them to numpy before\n        concatenating them to `gathered`\n        \"\"\"\n        if tensors is None:\n            return\n        if is_torch_tpu_available():\n            tensors = nested_xla_mesh_reduce(tensors, name)\n        elif is_sagemaker_mp_enabled():\n            tensors = smp_gather(tensors)\n        elif self.args.parallel_mode == ParallelMode.DISTRIBUTED:\n            tensors = distributed_concat(tensors)\n\n        return nested_numpify(tensors)\n\n    def _add_sm_patterns_to_gitignore(self) -> None:\n        \"\"\"Add SageMaker Checkpointing patterns to .gitignore file.\"\"\"\n        # Make sure we only do this on the main process\n        if not self.is_world_process_zero():\n            return\n\n        patterns = [\"*.sagemaker-uploading\", \"*.sagemaker-uploaded\"]\n\n        # Get current .gitignore content\n        if os.path.exists(os.path.join(self.repo.local_dir, \".gitignore\")):\n            with 
open(os.path.join(self.repo.local_dir, \".gitignore\"), \"r\") as f:\n                current_content = f.read()\n        else:\n            current_content = \"\"\n\n        # Add the patterns to .gitignore\n        content = current_content\n        for pattern in patterns:\n            if pattern not in content:\n                if content.endswith(\"\\n\"):\n                    content += pattern\n                else:\n                    content += f\"\\n{pattern}\"\n\n        # Write the .gitignore file if it has changed\n        if content != current_content:\n            with open(os.path.join(self.repo.local_dir, \".gitignore\"), \"w\") as f:\n                logger.debug(f\"Writing .gitignore file. Content: {content}\")\n                f.write(content)\n\n        self.repo.git_add(\".gitignore\")\n\n        # avoid race condition with git status\n        time.sleep(0.5)\n\n        if not self.repo.is_repo_clean():\n            self.repo.git_commit(\"Add *.sagemaker patterns to .gitignore.\")\n            self.repo.git_push()\n\n    def create_accelerator_and_postprocess(self):\n        grad_acc_kwargs = {\"num_steps\": self.args.gradient_accumulation_steps}\n        if version.parse(accelerate_version) > version.parse(\"0.20.3\"):\n            grad_acc_kwargs[\"sync_with_dataloader\"] = False\n        gradient_accumulation_plugin = GradientAccumulationPlugin(**grad_acc_kwargs)\n\n        # create accelerator object\n        self.accelerator = Accelerator(\n            deepspeed_plugin=self.args.deepspeed_plugin, gradient_accumulation_plugin=gradient_accumulation_plugin\n        )\n\n        # deepspeed and accelerate flags covering both trainer args and accelerate launcher\n        self.is_deepspeed_enabled = getattr(self.accelerator.state, \"deepspeed_plugin\", None) is not None\n        self.is_fsdp_enabled = getattr(self.accelerator.state, \"fsdp_plugin\", None) is not None\n\n        # post accelerator creation setup\n        if 
self.is_fsdp_enabled:\n            fsdp_plugin = self.accelerator.state.fsdp_plugin\n            fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get(\n                \"limit_all_gathers\", fsdp_plugin.limit_all_gathers\n            )\n            fsdp_plugin.use_orig_params = self.args.fsdp_config.get(\"use_orig_params\", fsdp_plugin.use_orig_params)\n\n        if self.is_deepspeed_enabled:\n            if getattr(self.args, \"hf_deepspeed_config\", None) is None:\n                from transformers.deepspeed import HfTrainerDeepSpeedConfig\n\n                ds_plugin = self.accelerator.state.deepspeed_plugin\n\n                ds_plugin.hf_ds_config = HfTrainerDeepSpeedConfig(ds_plugin.hf_ds_config.config)\n                ds_plugin.deepspeed_config = ds_plugin.hf_ds_config.config\n                ds_plugin.hf_ds_config.trainer_config_process(self.args)\n\n\n\nclass LLaVATrainer(TrainerLLavaGD):\n\n    def _save_checkpoint(self, model, trial, metrics=None):\n        # if getattr(self.args, 'tune_mm_mlp_adapter', False):\n        #     from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR\n        #     checkpoint_folder = f\"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}\"\n        #\n        #     run_dir = self._get_output_dir(trial=trial)\n        #     output_dir = os.path.join(run_dir, checkpoint_folder)\n        #\n        #     # Only save Adapter\n        #     keys_to_match = ['mm_projector']\n        #     if getattr(self.args, \"use_im_start_end\", False) or getattr(self.args, \"new_tokens\", False):\n        #         keys_to_match.extend(['embed_tokens', 'embed_in','lm_head'])\n        #     # import pdb; pdb.set_trace()\n        #     weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match)\n        #\n        #     if self.args.local_rank == 0 or self.args.local_rank == -1:\n        #         self.model.config.save_pretrained(output_dir)\n        #         torch.save(weight_to_save, 
os.path.join(output_dir, f'mm_projector.bin'))\n        # else:\n        super(LLaVATrainer, self)._save_checkpoint(model, trial, metrics)\n\n    def _save(self, output_dir: Optional[str] = None, state_dict=None):\n        # if getattr(self.args, 'tune_mm_mlp_adapter', False):\n        #     pass\n        # else:\n        super(LLaVATrainer, self)._save(output_dir, state_dict)\n"
  },
  {
    "path": "llava/train/llava_trainer_joint_train.py",
    "content": "import os\nimport torch\nimport torch.nn as nn\nfrom torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler\n# from transformers import Trainer\nfrom typing import Optional\nfrom transformers.trainer import *\nfrom datasets_os import build_train_dataloader\nfrom dataclasses import dataclass, field\nimport transformers\nfrom typing import Dict, Optional, Sequence, List\n\ndef maybe_zero_3(param, ignore_status=False, name=None):\n    from deepspeed import zero\n    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus\n    if hasattr(param, \"ds_id\"):\n        if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:\n            if not ignore_status:\n                print(name, 'no ignore status')\n        with zero.GatheredParameters([param]):\n            param = param.data.detach().cpu().clone()\n    else:\n        param = param.detach().cpu().clone()\n    return param\n\n\ndef get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):\n    to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}\n    to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()}\n    return to_return\n\n@dataclass\nclass DataCollatorForSupervisedDatasetEmpty(object):\n    \"\"\"Collate examples for supervised fine-tuning.\"\"\"\n\n    tokenizer: transformers.PreTrainedTokenizer\n\n    def __call__(self, instances: Sequence[Dict]):\n        return instances\n        # input_ids, labels = tuple([instance[key] for instance in instances]\n        #                           for key in (\"input_ids\", \"labels\"))\n        # input_ids = torch.nn.utils.rnn.pad_sequence(\n        #     input_ids,\n        #     batch_first=True,\n        #     padding_value=self.tokenizer.pad_token_id)\n        # labels = torch.nn.utils.rnn.pad_sequence(labels,\n        #                                          batch_first=True,\n        #                                
          padding_value=IGNORE_INDEX)\n        # input_ids = input_ids[:, :self.tokenizer.model_max_length]\n        # labels = labels[:, :self.tokenizer.model_max_length]\n        # batch = dict(\n        #     input_ids=input_ids,\n        #     labels=labels,\n        #     attention_mask=input_ids.ne(self.tokenizer.pad_token_id),\n        # )\n        #\n        # if 'image' in instances[0]:\n        #     images = [instance['image'] for instance in instances]\n        #     if all(x is not None and x.shape == images[0].shape for x in images):\n        #         batch['images'] = torch.stack(images)\n        #     else:\n        #         batch['images'] = images\n        #\n        # return batch\n\nclass TrainerLLavaGD(Trainer):\n    \"\"\"\n    Trainer is a simple but feature-complete training and eval loop for PyTorch, optimized for 🤗 Transformers.\n\n    Args:\n        model ([`PreTrainedModel`] or `torch.nn.Module`, *optional*):\n            The model to train, evaluate or use for predictions. If not provided, a `model_init` must be passed.\n\n            <Tip>\n\n            [`Trainer`] is optimized to work with the [`PreTrainedModel`] provided by the library. You can still use\n            your own models defined as `torch.nn.Module` as long as they work the same way as the 🤗 Transformers\n            models.\n\n            </Tip>\n\n        args ([`TrainingArguments`], *optional*):\n            The arguments to tweak for training. Will default to a basic instance of [`TrainingArguments`] with the\n            `output_dir` set to a directory named *tmp_trainer* in the current directory if not provided.\n        data_collator (`DataCollator`, *optional*):\n            The function to use to form a batch from a list of elements of `train_dataset` or `eval_dataset`. 
Will\n            default to [`default_data_collator`] if no `tokenizer` is provided, an instance of\n            [`DataCollatorWithPadding`] otherwise.\n        train_dataset (`torch.utils.data.Dataset` or `torch.utils.data.IterableDataset`, *optional*):\n            The dataset to use for training. If it is a [`~datasets.Dataset`], columns not accepted by the\n            `model.forward()` method are automatically removed.\n\n            Note that if it's a `torch.utils.data.IterableDataset` with some randomization and you are training in a\n            distributed fashion, your iterable dataset should either use a internal attribute `generator` that is a\n            `torch.Generator` for the randomization that must be identical on all processes (and the Trainer will\n            manually set the seed of this `generator` at each epoch) or have a `set_epoch()` method that internally\n            sets the seed of the RNGs used.\n        eval_dataset (Union[`torch.utils.data.Dataset`, Dict[str, `torch.utils.data.Dataset`]), *optional*):\n             The dataset to use for evaluation. If it is a [`~datasets.Dataset`], columns not accepted by the\n             `model.forward()` method are automatically removed. If it is a dictionary, it will evaluate on each\n             dataset prepending the dictionary key to the metric name.\n        tokenizer ([`PreTrainedTokenizerBase`], *optional*):\n            The tokenizer used to preprocess the data. If provided, will be used to automatically pad the inputs to the\n            maximum length when batching inputs, and it will be saved along the model to make it easier to rerun an\n            interrupted training or reuse the fine-tuned model.\n        model_init (`Callable[[], PreTrainedModel]`, *optional*):\n            A function that instantiates the model to be used. 
If provided, each call to [`~Trainer.train`] will start\n            from a new instance of the model as given by this function.\n\n            The function may have zero argument, or a single one containing the optuna/Ray Tune/SigOpt trial object, to\n            be able to choose different architectures according to hyper parameters (such as layer count, sizes of\n            inner layers, dropout probabilities etc).\n        compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*):\n            The function that will be used to compute metrics at evaluation. Must take a [`EvalPrediction`] and return\n            a dictionary string to metric values.\n        callbacks (List of [`TrainerCallback`], *optional*):\n            A list of callbacks to customize the training loop. Will add those to the list of default callbacks\n            detailed in [here](callback).\n\n            If you want to remove one of the default callbacks used, use the [`Trainer.remove_callback`] method.\n        optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*): A tuple\n            containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your model\n            and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.\n        preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*):\n            A function that preprocess the logits right before caching them at each evaluation step. Must take two\n            tensors, the logits and the labels, and return the logits once processed as desired. The modifications made\n            by this function will be reflected in the predictions received by `compute_metrics`.\n\n            Note that the labels (second parameter) will be `None` if the dataset does not have them.\n\n    Important attributes:\n\n        - **model** -- Always points to the core model. 
If using a transformers model, it will be a [`PreTrainedModel`]\n          subclass.\n        - **model_wrapped** -- Always points to the most external model in case one or more other modules wrap the\n          original model. This is the model that should be used for the forward pass. For example, under `DeepSpeed`,\n          the inner model is wrapped in `DeepSpeed` and then again in `torch.nn.DistributedDataParallel`. If the inner\n          model hasn't been wrapped, then `self.model_wrapped` is the same as `self.model`.\n        - **is_model_parallel** -- Whether or not a model has been switched to a model parallel mode (different from\n          data parallelism, this means some of the model layers are split on different GPUs).\n        - **place_model_on_device** -- Whether or not to automatically place the model on the device - it will be set\n          to `False` if model parallel or deepspeed is used, or if the default\n          `TrainingArguments.place_model_on_device` is overridden to return `False` .\n        - **is_in_train** -- Whether or not a model is currently running `train` (e.g. 
when `evaluate` is called while\n          in `train`)\n\n    \"\"\"\n\n    # Those are used as methods of the Trainer in examples.\n\n    def __init__(\n        self,\n        model: Union[PreTrainedModel, nn.Module] = None,\n        args: TrainingArguments = None,\n        data_collator: Optional[DataCollator] = None,\n        train_dataset: Optional[Dataset] = None,\n        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,\n        tokenizer: Optional[PreTrainedTokenizerBase] = None,\n        model_init: Optional[Callable[[], PreTrainedModel]] = None,\n        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,\n        callbacks: Optional[List[TrainerCallback]] = None,\n        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),\n        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,\n        data_loader_args=None,\n        cfg=None,\n    ):\n        self.cfg=cfg\n        if args is None:\n            output_dir = \"tmp_trainer\"\n            logger.info(f\"No `TrainingArguments` passed, using `output_dir={output_dir}`.\")\n            args = TrainingArguments(output_dir=output_dir)\n        self.args = args\n        # Seed must be set before instantiating the model when using model\n        enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)\n        self.hp_name = None\n        self.deepspeed = None\n        self.is_in_train = False\n        self.data_loader_args=data_loader_args\n        self.create_accelerator_and_postprocess()\n\n        # memory metrics - must set up as early as possible\n        self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)\n        self._memory_tracker.start()\n\n        # set the correct log level depending on the node\n        log_level = args.get_process_log_level()\n        logging.set_verbosity(log_level)\n\n        # force 
device and distributed setup init explicitly\n        args._setup_devices\n\n        if model is None:\n            if model_init is not None:\n                self.model_init = model_init\n                model = self.call_model_init()\n            else:\n                raise RuntimeError(\"`Trainer` requires either a `model` or `model_init` argument\")\n        else:\n            if model_init is not None:\n                warnings.warn(\n                    \"`Trainer` requires either a `model` or `model_init` argument, but not both. `model_init` will\"\n                    \" overwrite your model when calling the `train` method. This will become a fatal error in the next\"\n                    \" release.\",\n                    FutureWarning,\n                )\n            self.model_init = model_init\n\n        if model.__class__.__name__ in MODEL_MAPPING_NAMES:\n            raise ValueError(\n                f\"The model you have picked ({model.__class__.__name__}) cannot be used as is for training: it only \"\n                \"computes hidden states and does not accept any labels. 
You should choose a model with a head \"\n                \"suitable for your task like any of the `AutoModelForXxx` listed at \"\n                \"https://huggingface.co/docs/transformers/model_doc/auto.\"\n            )\n\n        if hasattr(model, \"is_parallelizable\") and model.is_parallelizable and model.model_parallel:\n            self.is_model_parallel = True\n        else:\n            self.is_model_parallel = False\n\n        if getattr(model, \"hf_device_map\", None) is not None:\n            devices = [device for device in set(model.hf_device_map.values()) if device not in [\"cpu\", \"disk\"]]\n            if len(devices) > 1:\n                self.is_model_parallel = True\n            else:\n                self.is_model_parallel = self.args.device != torch.device(devices[0])\n\n            # warn users\n            logger.info(\n                \"You have loaded a model on multiple GPUs. `is_model_parallel` attribute will be force-set\"\n                \" to `True` to avoid any unexpected behavior such as device placement mismatching.\"\n            )\n\n        # At this stage the model is already loaded\n        if getattr(model, \"is_quantized\", False):\n            if getattr(model, \"_is_quantized_training_enabled\", False):\n                logger.info(\n                    \"The model is loaded in 8-bit precision. To train this model you need to add additional modules\"\n                    \" inside the model such as adapters using `peft` library and freeze the model weights. Please\"\n                    \" check \"\n                    \" the examples in https://github.com/huggingface/peft for more details.\"\n                )\n            else:\n                raise ValueError(\n                    \"The model you want to train is loaded in 8-bit precision.  if you want to fine-tune an 8-bit\"\n                    \" model, please make sure that you have installed `bitsandbytes>=0.37.0`. 
\"\n                )\n\n        # Setup Sharded DDP training\n        self.sharded_ddp = None\n        if len(args.sharded_ddp) > 0:\n            if self.is_deepspeed_enabled:\n                raise ValueError(\n                    \"Using --sharded_ddp xxx together with --deepspeed is not possible, deactivate one of those flags.\"\n                )\n            if len(args.fsdp) > 0:\n                raise ValueError(\n                    \"Using --sharded_ddp xxx together with --fsdp is not possible, deactivate one of those flags.\"\n                )\n            if args.parallel_mode != ParallelMode.DISTRIBUTED:\n                raise ValueError(\"Using sharded DDP only works in distributed training.\")\n            elif not is_fairscale_available():\n                raise ImportError(\"Sharded DDP training requires fairscale: `pip install fairscale`.\")\n            elif ShardedDDPOption.SIMPLE not in args.sharded_ddp and FullyShardedDDP is None:\n                raise ImportError(\n                    \"Sharded DDP in a mode other than simple training requires fairscale version >= 0.3, found \"\n                    f\"{fairscale.__version__}. 
Upgrade your fairscale library: `pip install --upgrade fairscale`.\"\n                )\n            elif ShardedDDPOption.SIMPLE in args.sharded_ddp:\n                self.sharded_ddp = ShardedDDPOption.SIMPLE\n            elif ShardedDDPOption.ZERO_DP_2 in args.sharded_ddp:\n                self.sharded_ddp = ShardedDDPOption.ZERO_DP_2\n            elif ShardedDDPOption.ZERO_DP_3 in args.sharded_ddp:\n                self.sharded_ddp = ShardedDDPOption.ZERO_DP_3\n\n        self.fsdp = None\n        if len(args.fsdp) > 0:\n            if self.is_deepspeed_enabled:\n                raise ValueError(\n                    \"Using --fsdp xxx together with --deepspeed is not possible, deactivate one of those flags.\"\n                )\n            if not args.fsdp_config[\"xla\"] and args.parallel_mode != ParallelMode.DISTRIBUTED:\n                raise ValueError(\"Using fsdp only works in distributed training.\")\n\n            # dep_version_check(\"torch>=1.12.0\")\n            # Would have to update setup.py with torch>=1.12.0\n            # which isn't ideally given that it will force people not using FSDP to also use torch>=1.12.0\n            # below is the current alternative.\n            if version.parse(version.parse(torch.__version__).base_version) < version.parse(\"1.12.0\"):\n                raise ValueError(\"FSDP requires PyTorch >= 1.12.0\")\n\n            from torch.distributed.fsdp.fully_sharded_data_parallel import BackwardPrefetch, ShardingStrategy\n\n            if FSDPOption.FULL_SHARD in args.fsdp:\n                self.fsdp = ShardingStrategy.FULL_SHARD\n            elif FSDPOption.SHARD_GRAD_OP in args.fsdp:\n                self.fsdp = ShardingStrategy.SHARD_GRAD_OP\n            elif FSDPOption.NO_SHARD in args.fsdp:\n                self.fsdp = ShardingStrategy.NO_SHARD\n\n            self.backward_prefetch = BackwardPrefetch.BACKWARD_PRE\n            if \"backward_prefetch\" in self.args.fsdp_config and \"backward_post\" in 
self.args.fsdp_config.get(\n                \"backward_prefetch\", []\n            ):\n                self.backward_prefetch = BackwardPrefetch.BACKWARD_POST\n\n            self.forward_prefetch = False\n            if self.args.fsdp_config.get(\"forward_prefect\", False):\n                self.forward_prefetch = True\n\n            self.limit_all_gathers = False\n            if self.args.fsdp_config.get(\"limit_all_gathers\", False):\n                self.limit_all_gathers = True\n\n        # one place to sort out whether to place the model on device or not\n        # postpone switching model to cuda when:\n        # 1. MP - since we are trying to fit a much bigger than 1 gpu model\n        # 2. fp16-enabled DeepSpeed loads the model in half the size and it doesn't need .to() anyway,\n        #    and we only use deepspeed for training at the moment\n        # 3. full bf16 or fp16 eval - since the model needs to be cast to the right dtype first\n        # 4. Sharded DDP - same as MP\n        # 5. 
FSDP - same as MP\n        self.place_model_on_device = args.place_model_on_device\n        if (\n            self.is_model_parallel\n            or self.is_deepspeed_enabled\n            or ((args.fp16_full_eval or args.bf16_full_eval) and not args.do_train)\n            or (self.sharded_ddp in [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3])\n            or (self.fsdp is not None)\n            or self.is_fsdp_enabled\n        ):\n            self.place_model_on_device = False\n\n        default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)\n        self.data_collator = data_collator if data_collator is not None else default_collator\n        self.train_dataset = train_dataset\n        self.eval_dataset = eval_dataset\n        self.tokenizer = tokenizer\n\n        # Quantized models doesn't support `.to` operation.\n        if self.place_model_on_device and not getattr(model, \"is_quantized\", False):\n            self._move_model_to_device(model, args.device)\n\n        # Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs\n        if self.is_model_parallel:\n            self.args._n_gpu = 1\n\n        # later use `self.model is self.model_wrapped` to check if it's wrapped or not\n        self.model_wrapped = model\n        self.model = model\n\n        self.compute_metrics = compute_metrics\n        self.preprocess_logits_for_metrics = preprocess_logits_for_metrics\n        self.optimizer, self.lr_scheduler = optimizers\n        if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None):\n            raise RuntimeError(\n                \"Passing a `model_init` is incompatible with providing the `optimizers` argument. 
\"\n                \"You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method.\"\n            )\n        if is_torch_tpu_available() and self.optimizer is not None:\n            for param in self.model.parameters():\n                model_device = param.device\n                break\n            for param_group in self.optimizer.param_groups:\n                if len(param_group[\"params\"]) > 0:\n                    optimizer_device = param_group[\"params\"][0].device\n                    break\n            if model_device != optimizer_device:\n                raise ValueError(\n                    \"The model and the optimizer parameters are not on the same device, which probably means you\"\n                    \" created an optimizer around your model **before** putting on the device and passing it to the\"\n                    \" `Trainer`. Make sure the lines `import torch_xla.core.xla_model as xm` and\"\n                    \" `model.to(xm.xla_device())` is performed before the optimizer creation in your script.\"\n                )\n        if ((self.sharded_ddp is not None) or self.is_deepspeed_enabled or (self.fsdp is not None)) and (\n            self.optimizer is not None or self.lr_scheduler is not None\n        ):\n            raise RuntimeError(\n                \"Passing `optimizers` is not allowed if Fairscale, Deepspeed or PyTorch FSDP is enabled.\"\n                \"You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method.\"\n            )\n        default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)\n        callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks\n        self.callback_handler = CallbackHandler(\n            callbacks, self.model, self.tokenizer, self.optimizer, self.lr_scheduler\n        )\n        self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)\n\n    
    # Will be set to True by `self._setup_loggers()` on first call to `self.log()`.\n        self._loggers_initialized = False\n\n        # Create clone of distant repo and output directory if needed\n        if self.args.push_to_hub:\n            self.init_git_repo(at_init=True)\n            # In case of pull, we need to make sure every process has the latest.\n            if is_torch_tpu_available():\n                xm.rendezvous(\"init git repo\")\n            elif args.parallel_mode == ParallelMode.DISTRIBUTED:\n                dist.barrier()\n\n        if self.args.should_save:\n            os.makedirs(self.args.output_dir, exist_ok=True)\n\n        if not callable(self.data_collator) and callable(getattr(self.data_collator, \"collate_batch\", None)):\n            raise ValueError(\"The `data_collator` should be a simple callable (function, class with `__call__`).\")\n\n        if args.max_steps > 0:\n            logger.info(\"max_steps is given, it will override any value given in num_train_epochs\")\n\n        if train_dataset is not None and not has_length(train_dataset) and args.max_steps <= 0:\n            raise ValueError(\n                \"The train_dataset does not implement __len__, max_steps has to be specified. 
\"\n                \"The number of steps needs to be known in advance for the learning rate scheduler.\"\n            )\n\n        if (\n            train_dataset is not None\n            and isinstance(train_dataset, torch.utils.data.IterableDataset)\n            and args.group_by_length\n        ):\n            raise ValueError(\"the `--group_by_length` option is only available for `Dataset`, not `IterableDataset\")\n\n        self._signature_columns = None\n\n        # Mixed precision setup\n        self.use_apex = False\n        self.use_cuda_amp = False\n        self.use_cpu_amp = False\n\n        # Mixed precision setup for SageMaker Model Parallel\n        if is_sagemaker_mp_enabled():\n            # BF16 + model parallelism in SageMaker: currently not supported, raise an error\n            if args.bf16:\n                raise ValueError(\"SageMaker Model Parallelism does not support BF16 yet. Please use FP16 instead \")\n\n            if IS_SAGEMAKER_MP_POST_1_10:\n                # When there's mismatch between SMP config and trainer argument, use SMP config as truth\n                if args.fp16 != smp.state.cfg.fp16:\n                    logger.warning(\n                        f\"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16},\"\n                        f\"but FP16 provided in trainer argument is {args.fp16},\"\n                        f\"setting to {smp.state.cfg.fp16}\"\n                    )\n                    args.fp16 = smp.state.cfg.fp16\n            else:\n                # smp < 1.10 does not support fp16 in trainer.\n                if hasattr(smp.state.cfg, \"fp16\"):\n                    logger.warning(\n                        f\"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, \"\n                        \"but SageMaker Model Parallelism < 1.10 does not support FP16 in trainer.\"\n                    )\n\n        if (args.fp16 or args.bf16) and self.sharded_ddp is not None:\n            if 
args.half_precision_backend == \"auto\":\n                if args.device == torch.device(\"cpu\"):\n                    if args.fp16:\n                        raise ValueError(\"Tried to use `fp16` but it is not supported on cpu\")\n                    else:\n                        args.half_precision_backend = \"cpu_amp\"\n                else:\n                    args.half_precision_backend = \"cuda_amp\"\n\n            logger.info(f\"Using {args.half_precision_backend} half precision backend\")\n\n        self.do_grad_scaling = False\n        if (args.fp16 or args.bf16) and not (self.is_deepspeed_enabled or is_sagemaker_mp_enabled()):\n            # deepspeed and SageMaker Model Parallel manage their own half precision\n            if self.sharded_ddp is not None:\n                if args.half_precision_backend == \"cuda_amp\":\n                    self.use_cuda_amp = True\n                    self.amp_dtype = torch.float16 if args.fp16 else torch.bfloat16\n                    #  bf16 does not need grad scaling\n                    self.do_grad_scaling = self.amp_dtype == torch.float16\n                    if self.do_grad_scaling:\n                        if self.sharded_ddp is not None:\n                            self.scaler = ShardedGradScaler()\n                        elif self.fsdp is not None:\n                            from torch.distributed.fsdp.sharded_grad_scaler import (\n                                ShardedGradScaler as FSDPShardedGradScaler,\n                            )\n\n                            self.scaler = FSDPShardedGradScaler()\n                        elif is_torch_tpu_available():\n                            from torch_xla.amp import GradScaler\n\n                            self.scaler = GradScaler()\n                        else:\n                            self.scaler = torch.cuda.amp.GradScaler()\n                elif args.half_precision_backend == \"cpu_amp\":\n                    self.use_cpu_amp = True\n               
     self.amp_dtype = torch.bfloat16\n            elif args.half_precision_backend == \"apex\":\n                if not is_apex_available():\n                    raise ImportError(\n                        \"Using FP16 with APEX but APEX is not installed, please refer to\"\n                        \" https://www.github.com/nvidia/apex.\"\n                    )\n                self.use_apex = True\n\n        # FP16 + model parallelism in SageMaker: gradient clipping does not work for now so we raise a helpful error.\n        if (\n            is_sagemaker_mp_enabled()\n            and self.use_cuda_amp\n            and args.max_grad_norm is not None\n            and args.max_grad_norm > 0\n        ):\n            raise ValueError(\n                \"SageMaker Model Parallelism in mixed precision mode does not support gradient clipping yet. Pass \"\n                \"along 'max_grad_norm': 0 in your hyperparameters.\"\n            )\n\n        # Label smoothing\n        if self.args.label_smoothing_factor != 0:\n            self.label_smoother = LabelSmoother(epsilon=self.args.label_smoothing_factor)\n        else:\n            self.label_smoother = None\n\n        self.state = TrainerState(\n            is_local_process_zero=self.is_local_process_zero(),\n            is_world_process_zero=self.is_world_process_zero(),\n        )\n\n        self.control = TrainerControl()\n        # Internal variable to count flos in each process, will be accumulated in `self.state.total_flos` then\n        # returned to 0 every time flos need to be logged\n        self.current_flos = 0\n        self.hp_search_backend = None\n        self.use_tune_checkpoints = False\n        default_label_names = find_labels(self.model.__class__)\n        self.label_names = default_label_names if self.args.label_names is None else self.args.label_names\n        self.can_return_loss = can_return_loss(self.model.__class__)\n        self.control = self.callback_handler.on_init_end(self.args, 
self.state, self.control)\n\n        # Internal variables to help with automatic batch size reduction\n        self._train_batch_size = args.train_batch_size\n        self._created_lr_scheduler = False\n\n        # very last\n        self._memory_tracker.stop_and_update_metrics()\n\n        # torch.compile\n        if args.torch_compile and not is_torch_compile_available():\n            raise RuntimeError(\"Using torch.compile requires PyTorch 2.0 or higher.\")\n\n    def add_callback(self, callback):\n        \"\"\"\n        Add a callback to the current list of [`~transformer.TrainerCallback`].\n\n        Args:\n           callback (`type` or [`~transformer.TrainerCallback`]):\n               A [`~transformer.TrainerCallback`] class or an instance of a [`~transformer.TrainerCallback`]. In the\n               first case, will instantiate a member of that class.\n        \"\"\"\n        self.callback_handler.add_callback(callback)\n\n    def pop_callback(self, callback):\n        \"\"\"\n        Remove a callback from the current list of [`~transformer.TrainerCallback`] and returns it.\n\n        If the callback is not found, returns `None` (and no error is raised).\n\n        Args:\n           callback (`type` or [`~transformer.TrainerCallback`]):\n               A [`~transformer.TrainerCallback`] class or an instance of a [`~transformer.TrainerCallback`]. 
In the\n               first case, will pop the first member of that class found in the list of callbacks.\n\n        Returns:\n            [`~transformer.TrainerCallback`]: The callback removed, if found.\n        \"\"\"\n        return self.callback_handler.pop_callback(callback)\n\n    def remove_callback(self, callback):\n        \"\"\"\n        Remove a callback from the current list of [`~transformer.TrainerCallback`].\n\n        Args:\n           callback (`type` or [`~transformer.TrainerCallback`]):\n               A [`~transformer.TrainerCallback`] class or an instance of a [`~transformer.TrainerCallback`]. In the\n               first case, will remove the first member of that class found in the list of callbacks.\n        \"\"\"\n        self.callback_handler.remove_callback(callback)\n\n    def _move_model_to_device(self, model, device):\n        model = model.to(device)\n        # Moving a model to an XLA device disconnects the tied weights, so we have to retie them.\n        if self.args.parallel_mode == ParallelMode.TPU and hasattr(model, \"tie_weights\"):\n            model.tie_weights()\n\n    def _set_signature_columns_if_needed(self):\n        if self._signature_columns is None:\n            # Inspect model forward signature to keep only the arguments it accepts.\n            signature = inspect.signature(self.model.forward)\n            self._signature_columns = list(signature.parameters.keys())\n            # Labels may be named label or label_ids, the default data collator handles that.\n            self._signature_columns += list(set([\"label\", \"label_ids\"] + self.label_names))\n\n    def _remove_unused_columns(self, dataset: \"datasets.Dataset\", description: Optional[str] = None):\n        if not self.args.remove_unused_columns:\n            return dataset\n        self._set_signature_columns_if_needed()\n        signature_columns = self._signature_columns\n\n        ignored_columns = list(set(dataset.column_names) - 
set(signature_columns))\n        if len(ignored_columns) > 0:\n            dset_description = \"\" if description is None else f\"in the {description} set\"\n            logger.info(\n                f\"The following columns {dset_description} don't have a corresponding argument in \"\n                f\"`{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}.\"\n                f\" If {', '.join(ignored_columns)} are not expected by `{self.model.__class__.__name__}.forward`, \"\n                \" you can safely ignore this message.\"\n            )\n\n        columns = [k for k in signature_columns if k in dataset.column_names]\n\n        if version.parse(datasets.__version__) < version.parse(\"1.4.0\"):\n            dataset.set_format(\n                type=dataset.format[\"type\"], columns=columns, format_kwargs=dataset.format[\"format_kwargs\"]\n            )\n            return dataset\n        else:\n            return dataset.remove_columns(ignored_columns)\n\n    def _get_collator_with_removed_columns(\n        self, data_collator: Callable, description: Optional[str] = None\n    ) -> Callable:\n        \"\"\"Wrap the data collator in a callable removing unused columns.\"\"\"\n        if not self.args.remove_unused_columns:\n            return data_collator\n        self._set_signature_columns_if_needed()\n        signature_columns = self._signature_columns\n\n        remove_columns_collator = RemoveColumnsCollator(\n            data_collator=data_collator,\n            signature_columns=signature_columns,\n            logger=logger,\n            description=description,\n            model_name=self.model.__class__.__name__,\n        )\n        return remove_columns_collator\n\n    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:\n        if self.train_dataset is None or not has_length(self.train_dataset):\n            return None\n\n        # Build the sampler.\n        if 
self.args.group_by_length:\n            if is_datasets_available() and isinstance(self.train_dataset, datasets.Dataset):\n                lengths = (\n                    self.train_dataset[self.args.length_column_name]\n                    if self.args.length_column_name in self.train_dataset.column_names\n                    else None\n                )\n            else:\n                lengths = None\n            model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None\n            return LengthGroupedSampler(\n                self.args.train_batch_size * self.args.gradient_accumulation_steps,\n                dataset=self.train_dataset,\n                lengths=lengths,\n                model_input_name=model_input_name,\n            )\n\n        else:\n            return RandomSampler(self.train_dataset)\n\n    def get_train_dataloader(self) -> DataLoader:\n        \"\"\"\n        Returns the training [`~torch.utils.data.DataLoader`].\n\n        Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed\n        training if necessary) otherwise.\n\n        Subclass and override this method if you want to inject some custom behavior.\n        \"\"\"\n        if self.train_dataset is None:\n            raise ValueError(\"Trainer: training requires a train_dataset.\")\n\n        train_dataset = self.train_dataset\n        # data_collator = self.data_collator\n        data_collator = DataCollatorForSupervisedDatasetEmpty(tokenizer=self.tokenizer)\n        # datacolator=\n        if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):\n            train_dataset = self._remove_unused_columns(train_dataset, description=\"training\")\n        else:\n            data_collator = self._get_collator_with_removed_columns(data_collator, description=\"training\")\n\n        dataloader_params = {\n            \"batch_size\": self._train_batch_size,\n            
\"collate_fn\": data_collator,\n            \"num_workers\": self.args.dataloader_num_workers,\n            \"pin_memory\": self.args.dataloader_pin_memory,\n        }\n\n        if not isinstance(train_dataset, torch.utils.data.IterableDataset):\n            dataloader_params[\"sampler\"] = self._get_train_sampler()\n            dataloader_params[\"drop_last\"] = self.args.dataloader_drop_last\n            dataloader_params[\"worker_init_fn\"] = seed_worker\n\n        return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))\n\n    def get_train_dataloaderd2(self) -> DataLoader:\n        llava_cap_loader=self.get_train_dataloader()\n        return build_train_dataloader(self.cfg,tokenizer=self.data_loader_args[0],data_args=self.data_loader_args[1],preprocess=self.data_loader_args[2],llava_cap_loader=llava_cap_loader)\n\n    def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]:\n        # Deprecated code\n        if self.args.use_legacy_prediction_loop:\n            if is_torch_tpu_available():\n                return SequentialDistributedSampler(\n                    eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()\n                )\n            elif is_sagemaker_mp_enabled():\n                return SequentialDistributedSampler(\n                    eval_dataset,\n                    num_replicas=smp.dp_size(),\n                    rank=smp.dp_rank(),\n                    batch_size=self.args.per_device_eval_batch_size,\n                )\n            else:\n                return SequentialSampler(eval_dataset)\n\n        if self.args.world_size <= 1:\n            return SequentialSampler(eval_dataset)\n        else:\n            return None\n\n    def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:\n        \"\"\"\n        Returns the evaluation [`~torch.utils.data.DataLoader`].\n\n        Subclass and override this method if you want to inject some 
custom behavior.\n\n        Args:\n            eval_dataset (`torch.utils.data.Dataset`, *optional*):\n                If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted\n                by the `model.forward()` method are automatically removed. It must implement `__len__`.\n        \"\"\"\n        if eval_dataset is None and self.eval_dataset is None:\n            raise ValueError(\"Trainer: evaluation requires an eval_dataset.\")\n        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset\n        data_collator = self.data_collator\n\n        if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):\n            eval_dataset = self._remove_unused_columns(eval_dataset, description=\"evaluation\")\n        else:\n            data_collator = self._get_collator_with_removed_columns(data_collator, description=\"evaluation\")\n\n        dataloader_params = {\n            \"batch_size\": self.args.eval_batch_size,\n            \"collate_fn\": data_collator,\n            \"num_workers\": self.args.dataloader_num_workers,\n            \"pin_memory\": self.args.dataloader_pin_memory,\n        }\n\n        if not isinstance(eval_dataset, torch.utils.data.IterableDataset):\n            dataloader_params[\"sampler\"] = self._get_eval_sampler(eval_dataset)\n            dataloader_params[\"drop_last\"] = self.args.dataloader_drop_last\n\n        return self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params))\n\n    def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:\n        \"\"\"\n        Returns the test [`~torch.utils.data.DataLoader`].\n\n        Subclass and override this method if you want to inject some custom behavior.\n\n        Args:\n            test_dataset (`torch.utils.data.Dataset`, *optional*):\n                The test dataset to use. 
If it is a [`~datasets.Dataset`], columns not accepted by the\n                `model.forward()` method are automatically removed. It must implement `__len__`.\n        \"\"\"\n        data_collator = self.data_collator\n\n        if is_datasets_available() and isinstance(test_dataset, datasets.Dataset):\n            test_dataset = self._remove_unused_columns(test_dataset, description=\"test\")\n        else:\n            data_collator = self._get_collator_with_removed_columns(data_collator, description=\"test\")\n\n        dataloader_params = {\n            \"batch_size\": self.args.eval_batch_size,\n            \"collate_fn\": data_collator,\n            \"num_workers\": self.args.dataloader_num_workers,\n            \"pin_memory\": self.args.dataloader_pin_memory,\n        }\n\n        if not isinstance(test_dataset, torch.utils.data.IterableDataset):\n            dataloader_params[\"sampler\"] = self._get_eval_sampler(test_dataset)\n            dataloader_params[\"drop_last\"] = self.args.dataloader_drop_last\n\n        # We use the same batch_size as for eval.\n        return self.accelerator.prepare(DataLoader(test_dataset, **dataloader_params))\n\n    def create_optimizer_and_scheduler(self, num_training_steps: int):\n        \"\"\"\n        Setup the optimizer and the learning rate scheduler.\n\n        We provide a reasonable default that works well. 
If you want to use something else, you can pass a tuple in the\n        Trainer's init through `optimizers`, or subclass and override this method (or `create_optimizer` and/or\n        `create_scheduler`) in a subclass.\n        \"\"\"\n        self.create_optimizer()\n        if IS_SAGEMAKER_MP_POST_1_10 and smp.state.cfg.fp16:\n            # If smp >= 1.10 and fp16 is enabled, we unwrap the optimizer\n            optimizer = self.optimizer.optimizer\n        else:\n            optimizer = self.optimizer\n        self.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)\n\n    def create_optimizer(self):\n        \"\"\"\n        Setup the optimizer.\n\n        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the\n        Trainer's init through `optimizers`, or subclass and override this method in a subclass.\n        \"\"\"\n        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model\n\n        if self.optimizer is None:\n            decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)\n            decay_parameters = [name for name in decay_parameters if \"bias\" not in name]\n            # optimizer_grouped_parameters = [\n            #     {\n            #         \"params\": [\n            #             p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)\n            #         ],\n            #         \"weight_decay\": self.args.weight_decay,\n            #     },\n            #     {\n            #         \"params\": [\n            #             p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)\n            #         ],\n            #         \"weight_decay\": 0.0,\n            #     },\n            # ]\n            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)\n\n            def match_name_keywords(n, 
name_keywords):\n                out = False\n                for b in name_keywords:\n                    if b in n:\n                        out = True\n                        break\n                return out\n\n            lr_backbone_names=['backbone']\n            lr_linear_proj_names=['reference_points', 'sampling_offsets']\n            seg_model_names=['seg_model']\n            optimizer_grouped_parameters = [\n                {\n                    \"params\":\n                        [p for n, p in opt_model.named_parameters()\n                         if not (match_name_keywords(n, lr_backbone_names) and match_name_keywords(n, seg_model_names))\n                         and not (match_name_keywords(n, lr_linear_proj_names) and match_name_keywords(n, seg_model_names))\n                         and p.requires_grad],\n                    \"lr\": optimizer_kwargs['lr'],\n                },\n                {\n                    \"params\": [p for n, p in opt_model.named_parameters()\n                               if match_name_keywords(n, lr_backbone_names) and match_name_keywords(n, seg_model_names) and p.requires_grad],\n                    \"lr\": optimizer_kwargs['lr']*0.1,\n                },\n                {\n                    \"params\": [p for n, p in opt_model.named_parameters()\n                               if match_name_keywords(n, lr_linear_proj_names) and match_name_keywords(n, seg_model_names) and p.requires_grad],\n                    \"lr\": optimizer_kwargs['lr']*0.1,\n                },\n\n            ]\n            if not getattr(self.args, 'tune_mm_mlp_adapter', False):\n                optimizer_grouped_parameters[0] = {\n                        \"params\":\n                            [p for n, p in opt_model.named_parameters()\n                             if\n                             not (match_name_keywords(n, lr_backbone_names) and match_name_keywords(n, seg_model_names))\n                             and not 
(match_name_keywords(n, lr_linear_proj_names) and match_name_keywords(n,\n                                                                                                           seg_model_names))\n                             and match_name_keywords(n,seg_model_names)\n                             and p.requires_grad],\n                        \"lr\": optimizer_kwargs['lr'],\n                    }\n                llm_dict= {\n                    \"params\": [p for n, p in opt_model.named_parameters()\n                               if n.startswith('model.') and p.requires_grad],\n                    \"lr\": 2e-5,\n                }\n                optimizer_grouped_parameters.append(llm_dict)\n            if getattr(self.args, 'train_interactive', False):\n                interactive_model_names=['interactive_model']\n                optimizer_grouped_parameters_inter = [\n                    {\n                        \"params\":\n                            [p for n, p in opt_model.named_parameters()\n                             if\n                             match_name_keywords(n, interactive_model_names) and\n                             not (match_name_keywords(n, lr_backbone_names) and match_name_keywords(n, interactive_model_names))\n                             and not (match_name_keywords(n, lr_linear_proj_names) and match_name_keywords(n,\n                                                                                                           interactive_model_names))\n                             and p.requires_grad],\n                        \"lr\": optimizer_kwargs['lr'],\n                    },\n                    {\n                        \"params\": [p for n, p in opt_model.named_parameters()\n                                   if match_name_keywords(n, lr_backbone_names) and match_name_keywords(n,\n                                                                                                        interactive_model_names) and 
p.requires_grad],\n                        \"lr\": optimizer_kwargs['lr'] * 0.1,\n                    },\n                    {\n                        \"params\": [p for n, p in opt_model.named_parameters()\n                                   if match_name_keywords(n, lr_linear_proj_names) and match_name_keywords(n,\n                                                                                                           interactive_model_names) and p.requires_grad],\n                        \"lr\": optimizer_kwargs['lr'] * 0.1,\n                    },\n\n                ]\n                optimizer_grouped_parameters.extend(optimizer_grouped_parameters_inter)\n            if self.sharded_ddp == ShardedDDPOption.SIMPLE:\n                self.optimizer = OSS(\n                    params=optimizer_grouped_parameters,\n                    optim=optimizer_cls,\n                    **optimizer_kwargs,\n                )\n            else:\n                self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)\n                if optimizer_cls.__name__ == \"Adam8bit\":\n                    import bitsandbytes\n\n                    manager = bitsandbytes.optim.GlobalOptimManager.get_instance()\n\n                    skipped = 0\n                    for module in opt_model.modules():\n                        if isinstance(module, nn.Embedding):\n                            skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())\n                            logger.info(f\"skipped {module}: {skipped/2**20}M params\")\n                            manager.register_module_override(module, \"weight\", {\"optim_bits\": 32})\n                            logger.debug(f\"bitsandbytes: will optimize {module} in fp32\")\n                    logger.info(f\"skipped: {skipped/2**20}M params\")\n\n        if is_sagemaker_mp_enabled():\n            self.optimizer = smp.DistributedOptimizer(self.optimizer)\n\n        return 
self.optimizer\n\n    @staticmethod\n    def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]:\n        \"\"\"\n        Returns the optimizer class and optimizer parameters based on the training arguments.\n\n        Args:\n            args (`transformers.training_args.TrainingArguments`):\n                The training arguments for the training session.\n\n        \"\"\"\n\n        # parse args.optim_args\n        optim_args = {}\n        if args.optim_args:\n            for mapping in args.optim_args.replace(\" \", \"\").split(\",\"):\n                key, value = mapping.split(\"=\")\n                optim_args[key] = value\n\n        optimizer_kwargs = {\"lr\": args.learning_rate}\n\n        adam_kwargs = {\n            \"betas\": (args.adam_beta1, args.adam_beta2),\n            \"eps\": args.adam_epsilon,\n        }\n        if args.optim == OptimizerNames.ADAFACTOR:\n            optimizer_cls = Adafactor\n            optimizer_kwargs.update({\"scale_parameter\": False, \"relative_step\": False})\n        elif args.optim == OptimizerNames.ADAMW_HF:\n            from .optimization import AdamW\n\n            optimizer_cls = AdamW\n            optimizer_kwargs.update(adam_kwargs)\n        elif args.optim in [OptimizerNames.ADAMW_TORCH, OptimizerNames.ADAMW_TORCH_FUSED]:\n            from torch.optim import AdamW\n\n            optimizer_cls = AdamW\n            optimizer_kwargs.update(adam_kwargs)\n            if args.optim == OptimizerNames.ADAMW_TORCH_FUSED:\n                optimizer_kwargs.update({\"fused\": True})\n        elif args.optim == OptimizerNames.ADAMW_TORCH_XLA:\n            try:\n                from torch_xla.amp.syncfree import AdamW\n\n                optimizer_cls = AdamW\n                optimizer_kwargs.update(adam_kwargs)\n            except ImportError:\n                raise ValueError(\"Trainer failed to import syncfree AdamW from torch_xla.\")\n        elif args.optim == OptimizerNames.ADAMW_APEX_FUSED:\n       
     try:\n                from apex.optimizers import FusedAdam\n\n                optimizer_cls = FusedAdam\n                optimizer_kwargs.update(adam_kwargs)\n            except ImportError:\n                raise ValueError(\"Trainer tried to instantiate apex FusedAdam but apex is not installed!\")\n        elif args.optim in [\n            OptimizerNames.ADAMW_BNB,\n            OptimizerNames.ADAMW_8BIT,\n            OptimizerNames.PAGED_ADAMW,\n            OptimizerNames.PAGED_ADAMW_8BIT,\n            OptimizerNames.LION,\n            OptimizerNames.LION_8BIT,\n            OptimizerNames.PAGED_LION,\n            OptimizerNames.PAGED_LION_8BIT,\n        ]:\n            try:\n                from bitsandbytes.optim import AdamW, Lion\n\n                is_paged = False\n                optim_bits = 32\n                optimizer_cls = None\n                additional_optim_kwargs = adam_kwargs\n                if \"paged\" in args.optim:\n                    is_paged = True\n                if \"8bit\" in args.optim:\n                    optim_bits = 8\n                if \"adam\" in args.optim:\n                    optimizer_cls = AdamW\n                elif \"lion\" in args.optim:\n                    optimizer_cls = Lion\n                    additional_optim_kwargs = {\"betas\": (args.adam_beta1, args.adam_beta2)}\n\n                bnb_kwargs = {\"is_paged\": is_paged, \"optim_bits\": optim_bits}\n                optimizer_kwargs.update(additional_optim_kwargs)\n                optimizer_kwargs.update(bnb_kwargs)\n            except ImportError:\n                raise ValueError(\"Trainer tried to instantiate bnb optimizer but bnb is not installed!\")\n        elif args.optim == OptimizerNames.ADAMW_BNB:\n            try:\n                from bitsandbytes.optim import Adam8bit\n\n                optimizer_cls = Adam8bit\n                optimizer_kwargs.update(adam_kwargs)\n            except ImportError:\n                raise ValueError(\"Trainer tried 
to instantiate bnb Adam8bit but bnb is not installed!\")\n        elif args.optim == OptimizerNames.ADAMW_ANYPRECISION:\n            try:\n                from torchdistx.optimizers import AnyPrecisionAdamW\n\n                optimizer_cls = AnyPrecisionAdamW\n                optimizer_kwargs.update(adam_kwargs)\n\n                # TODO Change dtypes back to M=FP32, Var = BF16, Kahan = False once they can be cast together in torchdistx.\n                optimizer_kwargs.update(\n                    {\n                        \"use_kahan_summation\": strtobool(optim_args.get(\"use_kahan_summation\", \"False\")),\n                        \"momentum_dtype\": getattr(torch, optim_args.get(\"momentum_dtype\", \"float32\")),\n                        \"variance_dtype\": getattr(torch, optim_args.get(\"variance_dtype\", \"float32\")),\n                        \"compensation_buffer_dtype\": getattr(\n                            torch, optim_args.get(\"compensation_buffer_dtype\", \"bfloat16\")\n                        ),\n                    }\n                )\n            except ImportError:\n                raise ValueError(\"Please install https://github.com/pytorch/torchdistx\")\n        elif args.optim == OptimizerNames.SGD:\n            optimizer_cls = torch.optim.SGD\n        elif args.optim == OptimizerNames.ADAGRAD:\n            optimizer_cls = torch.optim.Adagrad\n        else:\n            raise ValueError(f\"Trainer cannot instantiate unsupported optimizer: {args.optim}\")\n        return optimizer_cls, optimizer_kwargs\n\n    def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):\n        \"\"\"\n        Setup the scheduler. 
The optimizer of the trainer must have been set up either before this method is called or\n        passed as an argument.\n\n        Args:\n            num_training_steps (int): The number of training steps to do.\n        \"\"\"\n        if self.lr_scheduler is None:\n            self.lr_scheduler = get_scheduler(\n                self.args.lr_scheduler_type,\n                optimizer=self.optimizer if optimizer is None else optimizer,\n                num_warmup_steps=self.args.get_warmup_steps(num_training_steps),\n                num_training_steps=num_training_steps,\n            )\n            self._created_lr_scheduler = True\n        return self.lr_scheduler\n\n    def num_examples(self, dataloader: DataLoader) -> int:\n        \"\"\"\n        Helper to get number of samples in a [`~torch.utils.data.DataLoader`] by accessing its dataset. When\n        dataloader.dataset does not exist or has no length, estimates as best it can\n        \"\"\"\n        try:\n            dataset = dataloader.dataset\n            # Special case for IterableDatasetShard, we need to dig deeper\n            if isinstance(dataset, IterableDatasetShard):\n                return len(dataloader.dataset.dataset)\n            return len(dataloader.dataset)\n        except (NameError, AttributeError, TypeError):  # no dataset or length, estimate by length of dataloader\n            return len(dataloader) * self.args.per_device_train_batch_size\n\n    def _hp_search_setup(self, trial: Union[\"optuna.Trial\", Dict[str, Any]]):\n        \"\"\"HP search setup code\"\"\"\n        self._trial = trial\n\n        if self.hp_search_backend is None or trial is None:\n            return\n        if self.hp_search_backend == HPSearchBackend.OPTUNA:\n            params = self.hp_space(trial)\n        elif self.hp_search_backend == HPSearchBackend.RAY:\n            params = trial\n            params.pop(\"wandb\", None)\n        elif self.hp_search_backend == HPSearchBackend.SIGOPT:\n            
params = {k: int(v) if isinstance(v, str) else v for k, v in trial.assignments.items()}\n        elif self.hp_search_backend == HPSearchBackend.WANDB:\n            params = trial\n\n        for key, value in params.items():\n            if not hasattr(self.args, key):\n                logger.warning(\n                    f\"Trying to set {key} in the hyperparameter search but there is no corresponding field in\"\n                    \" `TrainingArguments`.\"\n                )\n                continue\n            old_attr = getattr(self.args, key, None)\n            # Casting value to the proper type\n            if old_attr is not None:\n                value = type(old_attr)(value)\n            setattr(self.args, key, value)\n        if self.hp_search_backend == HPSearchBackend.OPTUNA:\n            logger.info(f\"Trial: {trial.params}\")\n        if self.hp_search_backend == HPSearchBackend.SIGOPT:\n            logger.info(f\"SigOpt Assignments: {trial.assignments}\")\n        if self.hp_search_backend == HPSearchBackend.WANDB:\n            logger.info(f\"W&B Sweep parameters: {trial}\")\n        if self.is_deepspeed_enabled:\n            if self.args.deepspeed is None:\n                raise ValueError(\"For sweeps with deepspeed, `args.deepspeed` must be set\")\n            # Rebuild the deepspeed config to reflect the updated training parameters\n            from accelerate.utils import DeepSpeedPlugin\n\n            from transformers.deepspeed import HfTrainerDeepSpeedConfig\n\n            self.args.hf_deepspeed_config = HfTrainerDeepSpeedConfig(self.args.deepspeed)\n            self.args.hf_deepspeed_config.trainer_config_process(self.args)\n            self.args.deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=self.args.hf_deepspeed_config)\n        self.create_accelerator_and_postprocess()\n\n    def _report_to_hp_search(self, trial: Union[\"optuna.Trial\", Dict[str, Any]], step: int, metrics: Dict[str, float]):\n        if self.hp_search_backend is None 
or trial is None:\n            return\n        self.objective = self.compute_objective(metrics.copy())\n        if self.hp_search_backend == HPSearchBackend.OPTUNA:\n            import optuna\n\n            trial.report(self.objective, step)\n            if trial.should_prune():\n                self.callback_handler.on_train_end(self.args, self.state, self.control)\n                raise optuna.TrialPruned()\n        elif self.hp_search_backend == HPSearchBackend.RAY:\n            from ray import tune\n\n            if self.control.should_save:\n                self._tune_save_checkpoint()\n            tune.report(objective=self.objective, **metrics)\n\n    def _tune_save_checkpoint(self):\n        from ray import tune\n\n        if not self.use_tune_checkpoints:\n            return\n        with tune.checkpoint_dir(step=self.state.global_step) as checkpoint_dir:\n            output_dir = os.path.join(checkpoint_dir, f\"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}\")\n            self.save_model(output_dir, _internal_call=True)\n            if self.args.should_save:\n                self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))\n                torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))\n                torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))\n\n    def call_model_init(self, trial=None):\n        model_init_argcount = number_of_arguments(self.model_init)\n        if model_init_argcount == 0:\n            model = self.model_init()\n        elif model_init_argcount == 1:\n            model = self.model_init(trial)\n        else:\n            raise RuntimeError(\"model_init should have 0 or 1 argument.\")\n\n        if model is None:\n            raise RuntimeError(\"model_init should not return None.\")\n\n        return model\n\n    def torch_jit_model_eval(self, model, dataloader, training=False):\n        if not training:\n            if dataloader is 
None:\n                logger.warning(\"failed to use PyTorch jit mode due to current dataloader is none.\")\n                return model\n            example_batch = next(iter(dataloader))\n            example_batch = self._prepare_inputs(example_batch)\n            try:\n                jit_model = copy.copy(model)\n                jit_model.eval()\n                original_forward = jit_model.__dict__.pop(\"_original_forward\", None)\n                # remove mixed precision hooks from the model\n                if original_forward:\n                    jit_model.forward = original_forward\n                with self.accelerator.autocast(cache_enabled=False), torch.no_grad():\n                    if version.parse(version.parse(torch.__version__).base_version) >= version.parse(\"2.0.0\"):\n                        if isinstance(example_batch, dict):\n                            jit_model = torch.jit.trace(jit_model, example_kwarg_inputs=example_batch, strict=False)\n                        else:\n                            jit_model = torch.jit.trace(\n                                jit_model,\n                                example_kwarg_inputs={key: example_batch[key] for key in example_batch},\n                                strict=False,\n                            )\n                    else:\n                        jit_inputs = []\n                        for key in example_batch:\n                            example_tensor = torch.ones_like(example_batch[key])\n                            jit_inputs.append(example_tensor)\n                        jit_inputs = tuple(jit_inputs)\n                        jit_model = torch.jit.trace(jit_model, jit_inputs, strict=False)\n                jit_model = torch.jit.freeze(jit_model)\n                with torch.no_grad():\n                    jit_model(**example_batch)\n                    jit_model(**example_batch)\n                model = jit_model\n                self.use_cpu_amp = False\n                
self.use_cuda_amp = False\n            except (RuntimeError, TypeError, ValueError, NameError, IndexError) as e:\n                logger.warning(f\"failed to use PyTorch jit mode due to: {e}.\")\n\n        return model\n\n    def ipex_optimize_model(self, model, training=False, dtype=torch.float32):\n        if not is_ipex_available():\n            raise ImportError(\n                \"Using IPEX but IPEX is not installed or IPEX's version does not match current PyTorch, please refer\"\n                \" to https://github.com/intel/intel-extension-for-pytorch.\"\n            )\n\n        import intel_extension_for_pytorch as ipex\n\n        if not training:\n            model.eval()\n            dtype = torch.bfloat16 if not self.is_in_train and self.args.bf16_full_eval else dtype\n            # conv_bn_folding is disabled as it fails in symbolic tracing, resulting in ipex warnings\n            model = ipex.optimize(model, dtype=dtype, level=\"O1\", conv_bn_folding=False, inplace=not self.is_in_train)\n        else:\n            if not model.training:\n                model.train()\n            model, self.optimizer = ipex.optimize(\n                model, dtype=dtype, optimizer=self.optimizer, inplace=True, level=\"O1\"\n            )\n\n        return model\n\n    def _wrap_model(self, model, training=True, dataloader=None):\n        if self.args.use_ipex:\n            dtype = torch.bfloat16 if self.use_cpu_amp else torch.float32\n            model = self.ipex_optimize_model(model, training, dtype=dtype)\n\n        if is_sagemaker_mp_enabled():\n            # Wrapping the base model twice in a DistributedModel will raise an error.\n            if isinstance(self.model_wrapped, smp.model.DistributedModel):\n                return self.model_wrapped\n            return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps)\n\n        # train/eval could be run multiple-times - if already wrapped, don't re-wrap it again\n        
if unwrap_model(model) is not model:\n            return model\n\n        # Mixed precision training with apex (torch < 1.6)\n        if self.use_apex and training:\n            model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)\n\n        # Multi-gpu training (should be after apex fp16 initialization) / 8bit models does not support DDP\n        if self.args.n_gpu > 1 and not getattr(model, \"is_loaded_in_8bit\", False):\n            model = nn.DataParallel(model)\n\n        if self.args.jit_mode_eval:\n            start_time = time.time()\n            model = self.torch_jit_model_eval(model, dataloader, training)\n            self.jit_compilation_time = round(time.time() - start_time, 4)\n\n        # Note: in torch.distributed mode, there's no point in wrapping the model\n        # inside a DistributedDataParallel as we'll be under `no_grad` anyways.\n        if not training:\n            return model\n\n        # Distributed training (should be after apex fp16 initialization)\n        if self.sharded_ddp is not None:\n            # Sharded DDP!\n            if self.sharded_ddp == ShardedDDPOption.SIMPLE:\n                model = ShardedDDP(model, self.optimizer)\n            else:\n                mixed_precision = self.args.fp16 or self.args.bf16\n                cpu_offload = ShardedDDPOption.OFFLOAD in self.args.sharded_ddp\n                zero_3 = self.sharded_ddp == ShardedDDPOption.ZERO_DP_3\n                # XXX: Breaking the self.model convention but I see no way around it for now.\n                if ShardedDDPOption.AUTO_WRAP in self.args.sharded_ddp:\n                    model = auto_wrap(model)\n                self.model = model = FullyShardedDDP(\n                    model,\n                    mixed_precision=mixed_precision,\n                    reshard_after_forward=zero_3,\n                    cpu_offload=cpu_offload,\n                ).to(self.args.device)\n        # Distributed training using 
PyTorch FSDP\n        elif self.fsdp is not None and self.args.fsdp_config[\"xla\"]:\n            try:\n                from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP\n                from torch_xla.distributed.fsdp import checkpoint_module\n                from torch_xla.distributed.fsdp.wrap import (\n                    size_based_auto_wrap_policy,\n                    transformer_auto_wrap_policy,\n                )\n            except ImportError:\n                raise ImportError(\"Missing XLA FSDP related module; please make sure to use torch-xla >= 2.0.\")\n            auto_wrap_policy = None\n            auto_wrapper_callable = None\n            if self.args.fsdp_config[\"fsdp_min_num_params\"] > 0:\n                auto_wrap_policy = functools.partial(\n                    size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config[\"fsdp_min_num_params\"]\n                )\n            elif self.args.fsdp_config.get(\"fsdp_transformer_layer_cls_to_wrap\", None) is not None:\n                transformer_cls_to_wrap = set()\n                for layer_class in self.args.fsdp_config[\"fsdp_transformer_layer_cls_to_wrap\"]:\n                    transformer_cls = get_module_class_from_name(model, layer_class)\n                    if transformer_cls is None:\n                        raise Exception(\"Could not find the transformer layer class to wrap in the model.\")\n                    else:\n                        transformer_cls_to_wrap.add(transformer_cls)\n                auto_wrap_policy = functools.partial(\n                    transformer_auto_wrap_policy,\n                    # Transformer layer class to wrap\n                    transformer_layer_cls=transformer_cls_to_wrap,\n                )\n            fsdp_kwargs = self.args.xla_fsdp_config\n            if self.args.fsdp_config[\"xla_fsdp_grad_ckpt\"]:\n                # Apply gradient checkpointing to auto-wrapped sub-modules if specified\n                
def auto_wrapper_callable(m, *args, **kwargs):\n                    return FSDP(checkpoint_module(m), *args, **kwargs)\n\n            # Wrap the base model with an outer FSDP wrapper\n            self.model = model = FSDP(\n                model,\n                auto_wrap_policy=auto_wrap_policy,\n                auto_wrapper_callable=auto_wrapper_callable,\n                **fsdp_kwargs,\n            )\n\n            # Patch `xm.optimizer_step` should not reduce gradients in this case,\n            # as FSDP does not need gradient reduction over sharded parameters.\n            def patched_optimizer_step(optimizer, barrier=False, optimizer_args={}):\n                loss = optimizer.step(**optimizer_args)\n                if barrier:\n                    xm.mark_step()\n                return loss\n\n            xm.optimizer_step = patched_optimizer_step\n        elif is_sagemaker_dp_enabled():\n            model = nn.parallel.DistributedDataParallel(\n                model, device_ids=[int(os.getenv(\"SMDATAPARALLEL_LOCAL_RANK\"))]\n            )\n        elif self.args.parallel_mode == ParallelMode.DISTRIBUTED:\n            if is_torch_neuroncore_available():\n                return model\n            kwargs = {}\n            if self.args.ddp_find_unused_parameters is not None:\n                kwargs[\"find_unused_parameters\"] = self.args.ddp_find_unused_parameters\n            elif isinstance(model, PreTrainedModel):\n                # find_unused_parameters breaks checkpointing as per\n                # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021\n                kwargs[\"find_unused_parameters\"] = not model.is_gradient_checkpointing\n            else:\n                kwargs[\"find_unused_parameters\"] = True\n\n            if self.args.ddp_bucket_cap_mb is not None:\n                kwargs[\"bucket_cap_mb\"] = self.args.ddp_bucket_cap_mb\n\n            if self.args.ddp_broadcast_buffers is not None:\n                
kwargs[\"broadcast_buffers\"] = self.args.ddp_broadcast_buffers\n\n            self.accelerator.ddp_handler = DistributedDataParallelKwargs(**kwargs)\n\n        return model\n\n    def train(\n        self,\n        resume_from_checkpoint: Optional[Union[str, bool]] = None,\n        trial: Union[\"optuna.Trial\", Dict[str, Any]] = None,\n        ignore_keys_for_eval: Optional[List[str]] = None,\n        **kwargs,\n    ):\n        \"\"\"\n        Main training entry point.\n\n        Args:\n            resume_from_checkpoint (`str` or `bool`, *optional*):\n                If a `str`, local path to a saved checkpoint as saved by a previous instance of [`Trainer`]. If a\n                `bool` and equals `True`, load the last checkpoint in *args.output_dir* as saved by a previous instance\n                of [`Trainer`]. If present, training will resume from the model/optimizer/scheduler states loaded here.\n            trial (`optuna.Trial` or `Dict[str, Any]`, *optional*):\n                The trial run or the hyperparameter dictionary for hyperparameter search.\n            ignore_keys_for_eval (`List[str]`, *optional*)\n                A list of keys in the output of your model (if it is a dictionary) that should be ignored when\n                gathering predictions for evaluation during the training.\n            kwargs (`Dict[str, Any]`, *optional*):\n                Additional keyword arguments used to hide deprecated arguments\n        \"\"\"\n        if resume_from_checkpoint is False:\n            resume_from_checkpoint = None\n\n        # memory metrics - must set up as early as possible\n        self._memory_tracker.start()\n\n        args = self.args\n\n        self.is_in_train = True\n\n        # do_train is not a reliable argument, as it might not be set and .train() still called, so\n        # the following is a workaround:\n        if (args.fp16_full_eval or args.bf16_full_eval) and not args.do_train:\n            
self._move_model_to_device(self.model, args.device)\n\n        if \"model_path\" in kwargs:\n            resume_from_checkpoint = kwargs.pop(\"model_path\")\n            warnings.warn(\n                \"`model_path` is deprecated and will be removed in a future version. Use `resume_from_checkpoint` \"\n                \"instead.\",\n                FutureWarning,\n            )\n        if len(kwargs) > 0:\n            raise TypeError(f\"train() received got unexpected keyword arguments: {', '.join(list(kwargs.keys()))}.\")\n        # This might change the seed so needs to run first.\n        self._hp_search_setup(trial)\n        self._train_batch_size = self.args.train_batch_size\n\n        # Model re-init\n        model_reloaded = False\n        if self.model_init is not None:\n            # Seed must be set before instantiating the model when using model_init.\n            enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)\n            self.model = self.call_model_init(trial)\n            model_reloaded = True\n            # Reinitializes optimizer and scheduler\n            self.optimizer, self.lr_scheduler = None, None\n\n        # Load potential model checkpoint\n        if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint:\n            resume_from_checkpoint = get_last_checkpoint(args.output_dir)\n            if resume_from_checkpoint is None:\n                raise ValueError(f\"No valid checkpoint found in output directory ({args.output_dir})\")\n        if resume_from_checkpoint is not None and not is_sagemaker_mp_enabled() and not self.is_deepspeed_enabled:\n            self._load_from_checkpoint(resume_from_checkpoint)\n        # If model was re-initialized, put it on the right device and update self.model_wrapped\n        if model_reloaded:\n            if self.place_model_on_device:\n                self._move_model_to_device(self.model, args.device)\n            self.model_wrapped = 
self.model\n\n        inner_training_loop = find_executable_batch_size(\n            self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size\n        )\n        return inner_training_loop(\n            args=args,\n            resume_from_checkpoint=resume_from_checkpoint,\n            trial=trial,\n            ignore_keys_for_eval=ignore_keys_for_eval,\n        )\n\n    def _inner_training_loop(\n        self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None\n    ):\n        self.accelerator.free_memory()\n        self._train_batch_size = batch_size\n        logger.debug(f\"Currently training with a batch size of: {self._train_batch_size}\")\n        # Data loader and number of training steps\n        train_dataloader = self.get_train_dataloaderd2()\n\n        # Setting up training control variables:\n        # number of training epochs: num_train_epochs\n        # number of training steps per epoch: num_update_steps_per_epoch\n        # total number of training steps to execute: max_steps\n        total_train_batch_size = self._train_batch_size * args.gradient_accumulation_steps * args.world_size\n        len_dataloader = None\n        if args.max_steps<0:\n            args.max_steps=100\n        if has_length(train_dataloader):\n            len_dataloader = len(train_dataloader)\n            num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps\n            num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)\n            num_examples = self.num_examples(train_dataloader)\n            if args.max_steps > 0:\n                max_steps = args.max_steps\n                num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(\n                    args.max_steps % num_update_steps_per_epoch > 0\n                )\n                # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's\n                # the 
best we can do.\n                num_train_samples = args.max_steps * total_train_batch_size\n            else:\n                max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)\n                num_train_epochs = math.ceil(args.num_train_epochs)\n                num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs\n        elif args.max_steps > 0:  # Rely on max_steps when dataloader does not have a working size\n            max_steps = args.max_steps\n            # Setting a very large number of epochs so we go as many times as necessary over the iterator.\n            num_train_epochs = sys.maxsize\n            num_update_steps_per_epoch = max_steps\n            num_examples = total_train_batch_size * args.max_steps\n            num_train_samples = args.max_steps * total_train_batch_size\n        else:\n            raise ValueError(\n                \"args.max_steps must be set to a positive value if dataloader does not have a length, was\"\n                f\" {args.max_steps}\"\n            )\n\n        # Compute absolute values for logging, eval, and save if given as ratio\n        if args.logging_steps and args.logging_steps < 1:\n            args.logging_steps = math.ceil(max_steps * args.logging_steps)\n        if args.eval_steps and args.eval_steps < 1:\n            args.eval_steps = math.ceil(max_steps * args.eval_steps)\n        if args.save_steps and args.save_steps < 1:\n            args.save_steps = math.ceil(max_steps * args.save_steps)\n\n        if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:\n            if self.args.n_gpu > 1:\n                # nn.DataParallel(model) replicates the model, creating new variables and module\n                # references registered here no longer work on other gpus, breaking the module\n                raise ValueError(\n                    \"Currently --debug underflow_overflow is not supported under DP. 
Please use DDP\"\n                    \" (torch.distributed.launch).\"\n                )\n            else:\n                debug_overflow = DebugUnderflowOverflow(self.model)  # noqa\n\n        delay_optimizer_creation = (\n            self.sharded_ddp is not None\n            and self.sharded_ddp != ShardedDDPOption.SIMPLE\n            or is_sagemaker_mp_enabled()\n            or self.fsdp is not None\n        )\n\n        # We need to reset the scheduler, as its parameters may be different on subsequent calls\n        if self._created_lr_scheduler:\n            self.lr_scheduler = None\n            self._created_lr_scheduler = False\n\n        if self.is_deepspeed_enabled:\n            self.optimizer, self.lr_scheduler = deepspeed_init(self, num_training_steps=max_steps)\n\n        if not delay_optimizer_creation:\n            self.create_optimizer_and_scheduler(num_training_steps=max_steps)\n\n        self.state = TrainerState()\n        self.state.is_hyper_param_search = trial is not None\n\n        # Activate gradient checkpointing if needed\n        if args.gradient_checkpointing:\n            self.model.gradient_checkpointing_enable()\n\n        model = self._wrap_model(self.model_wrapped)\n\n        if is_sagemaker_mp_enabled() and resume_from_checkpoint is not None:\n            self._load_from_checkpoint(resume_from_checkpoint, model)\n\n        # as the model is wrapped, don't use `accelerator.prepare`\n        # this is for unhandled cases such as\n        # Fairscale Sharded DDP, FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX\n        use_accelerator_prepare = True if model is self.model else False\n\n        if delay_optimizer_creation:\n            self.create_optimizer_and_scheduler(num_training_steps=max_steps)\n\n        # prepare using `accelerator` prepare\n        if use_accelerator_prepare:\n            self.model.train()\n            if hasattr(self.lr_scheduler, \"step\"):\n                if self.use_apex:\n                    model = 
self.accelerator.prepare(self.model)\n                else:\n                    model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)\n            else:\n                # to handle cases wherein we pass \"DummyScheduler\" such as when it is specified in DeepSpeed config.\n                model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(\n                    self.model, self.optimizer, self.lr_scheduler\n                )\n\n        if self.is_fsdp_enabled:\n            self.model = model\n\n        # for the rest of this function `model` is the outside model, whether it was wrapped or not\n        if model is not self.model:\n            self.model_wrapped = model\n\n        # backward compatibility\n        if self.is_deepspeed_enabled:\n            self.deepspeed = self.model_wrapped\n\n        # deepspeed ckpt loading\n        if resume_from_checkpoint is not None and self.is_deepspeed_enabled:\n            deepspeed_load_checkpoint(self.model_wrapped, resume_from_checkpoint)#, load_optimizer_states=self.args.load_optimizer_states, load_lr_scheduler_states=self.args.load_lr_scheduler_states)\n        # Check if saved optimizer or scheduler states exist\n        self._load_optimizer_and_scheduler(resume_from_checkpoint)\n        # important: at this point:\n        # self.model         is the Transformers Model\n        # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), etc.\n\n        # Train!\n        logger.info(\"***** Running training *****\")\n        logger.info(f\"  Num examples = {num_examples:,}\")\n        logger.info(f\"  Num Epochs = {num_train_epochs:,}\")\n        logger.info(f\"  Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}\")\n        if self.args.per_device_train_batch_size != self._train_batch_size:\n            logger.info(f\"  Training with DataParallel so batch size has been adjusted to: {self._train_batch_size:,}\")\n        
logger.info(f\"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size:,}\")\n        logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n        logger.info(f\"  Total optimization steps = {max_steps:,}\")\n        logger.info(f\"  Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}\")\n\n        self.state.epoch = 0\n        start_time = time.time()\n        epochs_trained = 0\n        steps_trained_in_current_epoch = 0\n        steps_trained_progress_bar = None\n        # Check if continuing training from a checkpoint\n        if resume_from_checkpoint is not None and os.path.isfile(\n            os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)\n        ):\n            self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))\n            epochs_trained = self.state.global_step // num_update_steps_per_epoch\n            if not args.ignore_data_skip:\n                steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)\n                steps_trained_in_current_epoch *= args.gradient_accumulation_steps\n            else:\n                steps_trained_in_current_epoch = 0\n\n            logger.info(\"  Continuing training from checkpoint, will skip to saved global_step\")\n            logger.info(f\"  Continuing training from epoch {epochs_trained}\")\n            logger.info(f\"  Continuing training from global step {self.state.global_step}\")\n            if not args.ignore_data_skip:\n                logger.info(\n                    f\"  Will skip the first {epochs_trained} epochs then the first\"\n                    f\" {steps_trained_in_current_epoch} batches in the first epoch.\"\n                )\n        # Update the references\n        self.callback_handler.model = self.model\n        self.callback_handler.optimizer = self.optimizer\n        
self.callback_handler.lr_scheduler = self.lr_scheduler\n        self.callback_handler.train_dataloader = train_dataloader\n        if self.hp_name is not None and self._trial is not None:\n            # use self._trial because the SigOpt/Optuna hpo only call `_hp_search_setup(trial)` instead of passing trial\n            # parameter to Train when using DDP.\n            self.state.trial_name = self.hp_name(self._trial)\n        if trial is not None:\n            assignments = trial.assignments if self.hp_search_backend == HPSearchBackend.SIGOPT else trial\n            self.state.trial_params = hp_params(assignments)\n        else:\n            self.state.trial_params = None\n        # This should be the same if the state has been saved but in case the training arguments changed, it's safer\n        # to set this after the load.\n        self.state.max_steps = max_steps\n        self.state.num_train_epochs = num_train_epochs\n        self.state.is_local_process_zero = self.is_local_process_zero()\n        self.state.is_world_process_zero = self.is_world_process_zero()\n\n        # tr_loss is a tensor to avoid synchronization of TPUs through .item()\n        tr_loss_ = torch.tensor(0.0).to(args.device)\n        # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses\n        self._total_loss_scalar = 0.0\n        self._globalstep_last_logged = self.state.global_step\n        model.zero_grad()\n\n        self.control = self.callback_handler.on_train_begin(args, self.state, self.control)\n\n        # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.\n        if not args.ignore_data_skip:\n            for epoch in range(epochs_trained):\n                for _ in train_dataloader:\n                    break\n\n        total_batched_samples = 0\n        tr_loss = dict()\n        for epoch in range(epochs_trained, num_train_epochs):\n            epoch_iterator = 
train_dataloader\n\n            # Reset the past mems state at the beginning of each epoch if necessary.\n            if args.past_index >= 0:\n                self._past = None\n\n            steps_in_epoch = (\n                len(epoch_iterator)\n                if len_dataloader is not None\n                else args.max_steps * args.gradient_accumulation_steps\n            )\n            self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)\n\n            if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:\n                self._load_rng_state(resume_from_checkpoint)\n\n            rng_to_sync = False\n            steps_skipped = 0\n            # if steps_trained_in_current_epoch > 0:\n            #     epoch_iterator = skip_first_batches(epoch_iterator, steps_trained_in_current_epoch)\n            #     steps_skipped = steps_trained_in_current_epoch\n            #     steps_trained_in_current_epoch = 0\n            #     rng_to_sync = True\n\n            step = -1\n            for step, inputs in enumerate(epoch_iterator):\n                total_batched_samples += 1\n                if rng_to_sync:\n                    self._load_rng_state(resume_from_checkpoint)\n                    rng_to_sync = False\n\n                # Skip past any already trained steps if resuming training\n                if steps_trained_in_current_epoch > 0:\n                    steps_trained_in_current_epoch =0\n                    if steps_trained_progress_bar is not None:\n                        steps_trained_progress_bar.update(steps_trained_in_current_epoch)\n                    if steps_trained_in_current_epoch == 0:\n                        self._load_rng_state(resume_from_checkpoint)\n                    continue\n                elif steps_trained_progress_bar is not None:\n                    steps_trained_progress_bar.close()\n                    steps_trained_progress_bar = None\n       
         if step % args.gradient_accumulation_steps == 0:\n                    self.control = self.callback_handler.on_step_begin(args, self.state, self.control)\n                with self.accelerator.accumulate(model):\n                    tr_loss_step = self.training_step(model, inputs)\n                if len(tr_loss)==0:\n                    tr_loss={k:tr_loss_.clone() for k in tr_loss_step.keys()}\n                for k, loss in tr_loss.items():\n                    if (\n                        args.logging_nan_inf_filter\n                        and not is_torch_tpu_available()\n                        and (torch.isnan(tr_loss_step[k]) or torch.isinf(tr_loss_step[k]))\n                    ):\n                        # if loss is nan or inf simply add the average of previous logged losses\n                        tr_loss[k] += loss / (1 + self.state.global_step - self._globalstep_last_logged)\n                    else:\n                        tr_loss[k] += tr_loss_step[k]\n\n                # if (\n                #     args.logging_nan_inf_filter\n                #     and not is_torch_tpu_available()\n                #     and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))\n                # ):\n                #     # if loss is nan or inf simply add the average of previous logged losses\n                #     tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)\n                # else:\n                #     tr_loss += tr_loss_step\n\n                self.current_flos += float(self.floating_point_ops(inputs))\n\n                is_last_step_and_steps_less_than_grad_acc = (\n                    steps_in_epoch <= args.gradient_accumulation_steps and (step + 1) == steps_in_epoch\n                )\n\n                if (\n                    total_batched_samples % args.gradient_accumulation_steps == 0\n                    or\n                    # last step in epoch but step is always smaller than 
gradient_accumulation_steps\n                    is_last_step_and_steps_less_than_grad_acc\n                ):\n                    # the `or` condition of `is_last_step_and_steps_less_than_grad_acc` is not covered\n                    # in accelerate. So, explicitly enable sync gradients to True in that case.\n                    if is_last_step_and_steps_less_than_grad_acc or (\n                        version.parse(accelerate_version) <= version.parse(\"0.20.3\")\n                    ):\n                        self.accelerator.gradient_state._set_sync_gradients(True)\n\n                    # Gradient clipping\n                    if args.max_grad_norm is not None and args.max_grad_norm > 0:\n                        # deepspeed does its own clipping\n\n                        if self.do_grad_scaling:\n                            # Reduce gradients first for XLA\n                            if is_torch_tpu_available():\n                                gradients = xm._fetch_gradients(self.optimizer)\n                                xm.all_reduce(\"sum\", gradients, scale=1.0 / xm.xrt_world_size())\n                            # AMP: gradients need unscaling\n                            self.scaler.unscale_(self.optimizer)\n\n                        if is_sagemaker_mp_enabled() and args.fp16:\n                            self.optimizer.clip_master_grads(args.max_grad_norm)\n                        elif hasattr(self.optimizer, \"clip_grad_norm\"):\n                            # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping\n                            self.optimizer.clip_grad_norm(args.max_grad_norm)\n                        elif hasattr(model, \"clip_grad_norm_\"):\n                            # Some models (like FullyShardedDDP) have a specific way to do gradient clipping\n                            model.clip_grad_norm_(args.max_grad_norm)\n                        elif self.use_apex:\n                            # 
Revert to normal clipping otherwise, handling Apex or full precision\n                            nn.utils.clip_grad_norm_(\n                                amp.master_params(self.optimizer),\n                                args.max_grad_norm,\n                            )\n                        else:\n                            self.accelerator.clip_grad_norm_(\n                                model.parameters(),\n                                args.max_grad_norm,\n                            )\n\n                    # Optimizer step\n                    optimizer_was_run = True\n                    if is_torch_tpu_available():\n                        if self.do_grad_scaling:\n                            self.scaler.step(self.optimizer)\n                            self.scaler.update()\n                        else:\n                            # tpu-comment: accelerate wrapped optimizers call xm.optimizer_step\n                            self.optimizer.step()\n                    elif self.do_grad_scaling:\n                        scale_before = self.scaler.get_scale()\n                        self.scaler.step(self.optimizer)\n                        self.scaler.update()\n                        scale_after = self.scaler.get_scale()\n                        optimizer_was_run = scale_before <= scale_after\n                    else:\n                        self.optimizer.step()\n                        optimizer_was_run = not self.accelerator.optimizer_step_was_skipped\n\n                    if optimizer_was_run:\n                        # Delay optimizer scheduling until metrics are generated\n                        if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):\n                            self.lr_scheduler.step()\n\n                    model.zero_grad()\n                    self.state.global_step += 1\n                    self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch\n                    
self.control = self.callback_handler.on_step_end(args, self.state, self.control)\n\n                    self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)\n                else:\n                    self.control = self.callback_handler.on_substep_end(args, self.state, self.control)\n\n                if self.control.should_epoch_stop or self.control.should_training_stop:\n                    break\n            if step < 0:\n                logger.warning(\n                    \"There seems to be not a single sample in your epoch_iterator, stopping training at step\"\n                    f\" {self.state.global_step}! This is expected if you're using an IterableDataset and set\"\n                    f\" num_steps ({max_steps}) higher than the number of available samples.\"\n                )\n                self.control.should_training_stop = True\n\n            self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)\n            self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)\n\n            if DebugOption.TPU_METRICS_DEBUG in self.args.debug:\n                if is_torch_tpu_available():\n                    # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)\n                    xm.master_print(met.metrics_report())\n                else:\n                    logger.warning(\n                        \"You enabled PyTorch/XLA debug metrics but you don't have a TPU \"\n                        \"configured. Check your training configuration if this is unexpected.\"\n                    )\n            if self.control.should_training_stop:\n                break\n\n        if args.past_index and hasattr(self, \"_past\"):\n            # Clean the state at the end of training\n            delattr(self, \"_past\")\n\n        logger.info(\"\\n\\nTraining completed. 
Do not forget to share your model on huggingface.co/models =)\\n\\n\")\n        if args.load_best_model_at_end and self.state.best_model_checkpoint is not None:\n            # Wait for everyone to get here so we are sur the model has been saved by process 0.\n            if is_torch_tpu_available():\n                xm.rendezvous(\"load_best_model_at_end\")\n            elif args.parallel_mode == ParallelMode.DISTRIBUTED:\n                dist.barrier()\n            elif is_sagemaker_mp_enabled():\n                smp.barrier()\n\n            self._load_best_model()\n\n        # add remaining tr_loss\n        # self._total_loss_scalar += tr_loss.item()\n        self._total_loss_scalar += tr_loss['loss_total'].item()\n\n        train_loss = self._total_loss_scalar / self.state.global_step\n\n        metrics = speed_metrics(\"train\", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps)\n        self.store_flos()\n        metrics[\"total_flos\"] = self.state.total_flos\n        metrics[\"train_loss\"] = train_loss\n\n        self.is_in_train = False\n\n        self._memory_tracker.stop_and_update_metrics(metrics)\n\n        self.log(metrics)\n\n        run_dir = self._get_output_dir(trial)\n        checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir)\n\n        # Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint and process allowed to save.\n        if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1:\n            for checkpoint in checkpoints_sorted:\n                if checkpoint != self.state.best_model_checkpoint:\n                    logger.info(f\"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit\")\n                    shutil.rmtree(checkpoint)\n\n        self.control = self.callback_handler.on_train_end(args, self.state, self.control)\n\n        return TrainOutput(self.state.global_step, 
train_loss, metrics)\n\n    def _get_output_dir(self, trial):\n        if self.hp_search_backend is not None and trial is not None:\n            if self.hp_search_backend == HPSearchBackend.OPTUNA:\n                run_id = trial.number\n            elif self.hp_search_backend == HPSearchBackend.RAY:\n                from ray import tune\n\n                run_id = tune.get_trial_id()\n            elif self.hp_search_backend == HPSearchBackend.SIGOPT:\n                run_id = trial.id\n            elif self.hp_search_backend == HPSearchBackend.WANDB:\n                import wandb\n\n                run_id = wandb.run.id\n            run_name = self.hp_name(trial) if self.hp_name is not None else f\"run-{run_id}\"\n            run_dir = os.path.join(self.args.output_dir, run_name)\n        else:\n            run_dir = self.args.output_dir\n        return run_dir\n\n    def _load_from_checkpoint(self, resume_from_checkpoint, model=None):\n        if model is None:\n            model = self.model\n\n        config_file = os.path.join(resume_from_checkpoint, CONFIG_NAME)\n        adapter_weights_file = os.path.join(resume_from_checkpoint, ADAPTER_WEIGHTS_NAME)\n        adapter_safe_weights_file = os.path.join(resume_from_checkpoint, ADAPTER_SAFE_WEIGHTS_NAME)\n        weights_file = os.path.join(resume_from_checkpoint, WEIGHTS_NAME)\n        weights_index_file = os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME)\n        safe_weights_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_NAME)\n        safe_weights_index_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_INDEX_NAME)\n\n        if not any(\n            os.path.isfile(f)\n            for f in [\n                weights_file,\n                safe_weights_file,\n                weights_index_file,\n                safe_weights_index_file,\n                adapter_weights_file,\n                adapter_safe_weights_file,\n            ]\n        ):\n            raise ValueError(f\"Can't find 
a valid checkpoint at {resume_from_checkpoint}\")\n\n        logger.info(f\"Loading model from {resume_from_checkpoint}.\")\n        if os.path.isfile(config_file):\n            config = PretrainedConfig.from_json_file(config_file)\n            checkpoint_version = config.transformers_version\n            if checkpoint_version is not None and checkpoint_version != __version__:\n                logger.warning(\n                    f\"You are resuming training from a checkpoint trained with {checkpoint_version} of \"\n                    f\"Transformers but your current version is {__version__}. This is not recommended and could \"\n                    \"yield to errors or unwanted behaviors.\"\n                )\n        if os.path.isfile(weights_file) or os.path.isfile(safe_weights_file):\n            # If the model is on the GPU, it still works!\n            if is_sagemaker_mp_enabled():\n                if os.path.isfile(os.path.join(resume_from_checkpoint, \"user_content.pt\")):\n                    # If the 'user_content.pt' file exists, load with the new smp api.\n                    # Checkpoint must have been saved with the new smp api.\n                    smp.resume_from_checkpoint(\n                        path=resume_from_checkpoint, tag=WEIGHTS_NAME, partial=False, load_optimizer=False\n                    )\n                else:\n                    # If the 'user_content.pt' file does NOT exist, load with the old smp api.\n                    # Checkpoint must have been saved with the old smp api.\n                    if hasattr(self.args, \"fp16\") and self.args.fp16 is True:\n                        logger.warning(\n                            \"Enabling FP16 and loading from smp < 1.10 checkpoint together is not suppported.\"\n                        )\n                    state_dict = torch.load(weights_file, map_location=\"cpu\")\n                    # Required for smp to not auto-translate state_dict from hf to smp (is already smp).\n           
         state_dict[\"_smp_is_partial\"] = False\n                    load_result = model.load_state_dict(state_dict, strict=True)\n                    # release memory\n                    del state_dict\n            elif self.is_fsdp_enabled:\n                load_fsdp_model(self.accelerator.state.fsdp_plugin, self.accelerator, model, resume_from_checkpoint)\n            else:\n                # We load the model state dict on the CPU to avoid an OOM error.\n                if self.args.save_safetensors and os.path.isfile(safe_weights_file):\n                    state_dict = safetensors.torch.load_file(safe_weights_file, device=\"cpu\")\n                else:\n                    state_dict = torch.load(weights_file, map_location=\"cpu\")\n\n                # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963\n                # which takes *args instead of **kwargs\n                load_result = model.load_state_dict(state_dict, False)\n                # release memory\n                del state_dict\n                self._issue_warnings_after_load(load_result)\n\n        # Load adapters following PR # 24096\n        elif is_peft_available() and isinstance(model, PeftModel):\n            # If train a model using PEFT & LoRA, assume that adapter have been saved properly.\n            if hasattr(model, \"active_adapter\") and hasattr(model, \"load_adapter\"):\n                if os.path.exists(resume_from_checkpoint):\n                    model.load_adapter(resume_from_checkpoint, model.active_adapter, is_trainable=True)\n                else:\n                    logger.warning(\n                        \"The intermediate checkpoints of PEFT may not be saved correctly, \"\n                        f\"consider using a custom callback to save {ADAPTER_WEIGHTS_NAME} in corresponding saving folders. 
\"\n                        \"Check some examples here: https://github.com/huggingface/peft/issues/96\"\n                    )\n            else:\n                logger.warning(\"Could not load adapter model, make sure to have `peft>=0.3.0` installed\")\n        else:\n            # We load the sharded checkpoint\n            load_result = load_sharded_checkpoint(\n                model, resume_from_checkpoint, strict=is_sagemaker_mp_enabled(), prefer_safe=self.args.save_safetensors\n            )\n            if not is_sagemaker_mp_enabled():\n                self._issue_warnings_after_load(load_result)\n\n    def _load_best_model(self):\n        logger.info(f\"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).\")\n        best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)\n        best_safe_model_path = os.path.join(self.state.best_model_checkpoint, SAFE_WEIGHTS_NAME)\n        best_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_WEIGHTS_NAME)\n        best_safe_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_SAFE_WEIGHTS_NAME)\n\n        model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model\n        if (\n            os.path.exists(best_model_path)\n            or os.path.exists(best_safe_model_path)\n            or os.path.exists(best_adapter_model_path)\n            or os.path.exists(best_safe_adapter_model_path)\n        ):\n            if self.is_deepspeed_enabled:\n                deepspeed_load_checkpoint(self.model_wrapped, self.state.best_model_checkpoint)\n            else:\n                has_been_loaded = True\n                if is_sagemaker_mp_enabled():\n                    if os.path.isfile(os.path.join(self.state.best_model_checkpoint, \"user_content.pt\")):\n                        # If the 'user_content.pt' file exists, load with the new smp api.\n                        # Checkpoint must have been 
saved with the new smp api.\n                        smp.resume_from_checkpoint(\n                            path=self.state.best_model_checkpoint,\n                            tag=WEIGHTS_NAME,\n                            partial=False,\n                            load_optimizer=False,\n                        )\n                    else:\n                        # If the 'user_content.pt' file does NOT exist, load with the old smp api.\n                        # Checkpoint must have been saved with the old smp api.\n                        if self.args.save_safetensors and os.path.isfile(best_safe_model_path):\n                            state_dict = safetensors.torch.load_file(best_safe_model_path, device=\"cpu\")\n                        else:\n                            state_dict = torch.load(best_model_path, map_location=\"cpu\")\n\n                        state_dict[\"_smp_is_partial\"] = False\n                        load_result = model.load_state_dict(state_dict, strict=True)\n                elif self.is_fsdp_enabled:\n                    load_fsdp_model(\n                        self.accelerator.state.fsdp_plugin, self.accelerator, model, self.state.best_model_checkpoint\n                    )\n                else:\n                    if is_peft_available() and isinstance(model, PeftModel):\n                        # If train a model using PEFT & LoRA, assume that adapter have been saved properly.\n                        if hasattr(model, \"active_adapter\") and hasattr(model, \"load_adapter\"):\n                            if os.path.exists(best_adapter_model_path) or os.path.exists(best_safe_adapter_model_path):\n                                model.load_adapter(self.state.best_model_checkpoint, model.active_adapter)\n                                # Load_adapter has no return value present, modify it when appropriate.\n                                from torch.nn.modules.module import _IncompatibleKeys\n\n                                
load_result = _IncompatibleKeys([], [])\n                            else:\n                                logger.warning(\n                                    \"The intermediate checkpoints of PEFT may not be saved correctly, \"\n                                    f\"consider using a custom callback to save {ADAPTER_WEIGHTS_NAME} in corresponding saving folders. \"\n                                    \"Check some examples here: https://github.com/huggingface/peft/issues/96\"\n                                )\n                                has_been_loaded = False\n                        else:\n                            logger.warning(\"Could not load adapter model, make sure to have `peft>=0.3.0` installed\")\n                            has_been_loaded = False\n                    else:\n                        # We load the model state dict on the CPU to avoid an OOM error.\n                        if self.args.save_safetensors and os.path.isfile(best_safe_model_path):\n                            state_dict = safetensors.torch.load_file(best_safe_model_path, device=\"cpu\")\n                        else:\n                            state_dict = torch.load(best_model_path, map_location=\"cpu\")\n\n                        # If the model is on the GPU, it still works!\n                        # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963\n                        # which takes *args instead of **kwargs\n                        load_result = model.load_state_dict(state_dict, False)\n                if not is_sagemaker_mp_enabled() and has_been_loaded:\n                    self._issue_warnings_after_load(load_result)\n        elif os.path.exists(os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)):\n            load_result = load_sharded_checkpoint(\n                model, self.state.best_model_checkpoint, strict=is_sagemaker_mp_enabled()\n            )\n            if not is_sagemaker_mp_enabled():\n                
self._issue_warnings_after_load(load_result)\n        else:\n            logger.warning(\n                f\"Could not locate the best model at {best_model_path}, if you are running a distributed training \"\n                \"on multiple nodes, you should activate `--save_on_each_node`.\"\n            )\n\n    def _issue_warnings_after_load(self, load_result):\n        if len(load_result.missing_keys) != 0:\n            if self.model._keys_to_ignore_on_save is not None and set(load_result.missing_keys) == set(\n                self.model._keys_to_ignore_on_save\n            ):\n                self.model.tie_weights()\n            else:\n                logger.warning(f\"There were missing keys in the checkpoint model loaded: {load_result.missing_keys}.\")\n        if len(load_result.unexpected_keys) != 0:\n            logger.warning(\n                f\"There were unexpected keys in the checkpoint model loaded: {load_result.unexpected_keys}.\"\n            )\n\n    def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval):\n        if self.control.should_log:\n            if is_torch_tpu_available():\n                xm.mark_step()\n\n            logs: Dict[str, float] = {}\n\n            # all_gather + mean() to get average loss over all processes\n            # tr_loss_scalar = self._nested_gather(tr_loss).mean().item()\n            tr_loss_scalar = {k: self._nested_gather(tr_loss[k]).mean().item() for k in tr_loss.keys()}\n\n            # reset tr_loss to zero\n            for _,loss in tr_loss.items():\n                loss -= loss\n            # tr_loss -= tr_loss\n\n            # logs[\"loss\"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)\n            for k,loss in tr_loss_scalar.items():\n                logs[k]=round(loss / (self.state.global_step - self._globalstep_last_logged), 4)\n            logs[\"learning_rate\"] = self._get_learning_rate()\n\n            self._total_loss_scalar += 
tr_loss_scalar['loss_total']\n            self._globalstep_last_logged = self.state.global_step\n            self.store_flos()\n\n            self.log(logs)\n\n        metrics = None\n        if self.control.should_evaluate:\n            if isinstance(self.eval_dataset, dict):\n                metrics = {}\n                for eval_dataset_name, eval_dataset in self.eval_dataset.items():\n                    dataset_metrics = self.evaluate(\n                        eval_dataset=eval_dataset,\n                        ignore_keys=ignore_keys_for_eval,\n                        metric_key_prefix=f\"eval_{eval_dataset_name}\",\n                    )\n                    metrics.update(dataset_metrics)\n            else:\n                metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)\n            self._report_to_hp_search(trial, self.state.global_step, metrics)\n\n            # Run delayed LR scheduler now that metrics are populated\n            if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):\n                metric_to_check = self.args.metric_for_best_model\n                if not metric_to_check.startswith(\"eval_\"):\n                    metric_to_check = f\"eval_{metric_to_check}\"\n                self.lr_scheduler.step(metrics[metric_to_check])\n\n        if self.control.should_save:\n            self._save_checkpoint(model, trial, metrics=metrics)\n            self.control = self.callback_handler.on_save(self.args, self.state, self.control)\n\n    def _load_rng_state(self, checkpoint):\n        # Load RNG states from `checkpoint`\n        if checkpoint is None:\n            return\n\n        if self.args.world_size > 1:\n            process_index = self.args.process_index\n            rng_file = os.path.join(checkpoint, f\"rng_state_{process_index}.pth\")\n            if not os.path.isfile(rng_file):\n                logger.info(\n                    f\"Didn't find an RNG file for process {process_index}, if you are 
resuming a training that \"\n                    \"wasn't launched in a distributed fashion, reproducibility is not guaranteed.\"\n                )\n                return\n        else:\n            rng_file = os.path.join(checkpoint, \"rng_state.pth\")\n            if not os.path.isfile(rng_file):\n                logger.info(\n                    \"Didn't find an RNG file, if you are resuming a training that was launched in a distributed \"\n                    \"fashion, reproducibility is not guaranteed.\"\n                )\n                return\n\n        checkpoint_rng_state = torch.load(rng_file)\n        random.setstate(checkpoint_rng_state[\"python\"])\n        np.random.set_state(checkpoint_rng_state[\"numpy\"])\n        torch.random.set_rng_state(checkpoint_rng_state[\"cpu\"])\n        if torch.cuda.is_available():\n            if self.args.parallel_mode == ParallelMode.DISTRIBUTED:\n                torch.cuda.random.set_rng_state_all(checkpoint_rng_state[\"cuda\"])\n            else:\n                try:\n                    torch.cuda.random.set_rng_state(checkpoint_rng_state[\"cuda\"])\n                except Exception as e:\n                    logger.info(\n                        f\"Didn't manage to set back the RNG states of the GPU because of the following error:\\n {e}\"\n                        \"\\nThis won't yield the same results as if the training had not been interrupted.\"\n                    )\n        if is_torch_tpu_available():\n            xm.set_rng_state(checkpoint_rng_state[\"xla\"])\n\n    def _save_checkpoint(self, model, trial, metrics=None):\n        # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we\n        # want to save except FullyShardedDDP.\n        # assert unwrap_model(model) is self.model, \"internal model should be a reference to self.model\"\n\n        # Save model checkpoint\n        checkpoint_folder = f\"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}\"\n\n       
 if self.hp_search_backend is None and trial is None:\n            self.store_flos()\n\n        run_dir = self._get_output_dir(trial=trial)\n        output_dir = os.path.join(run_dir, checkpoint_folder)\n        self.save_model(output_dir, _internal_call=True)\n        if self.is_deepspeed_enabled:\n            # under zero3 model file itself doesn't get saved since it's bogus! Unless deepspeed\n            # config `stage3_gather_16bit_weights_on_model_save` is True\n            self.model_wrapped.save_checkpoint(output_dir)\n\n        # Save optimizer and scheduler\n        if self.sharded_ddp == ShardedDDPOption.SIMPLE:\n            self.optimizer.consolidate_state_dict()\n\n        if self.fsdp or self.is_fsdp_enabled:\n            if self.is_fsdp_enabled:\n                save_fsdp_optimizer(\n                    self.accelerator.state.fsdp_plugin, self.accelerator, self.optimizer, self.model, output_dir\n                )\n            else:\n                # FSDP has a different interface for saving optimizer states.\n                # Needs to be called on all ranks to gather all states.\n                # full_optim_state_dict will be deprecated after Pytorch 2.2!\n                full_osd = self.model.__class__.full_optim_state_dict(self.model, self.optimizer)\n\n        if is_torch_tpu_available():\n            xm.rendezvous(\"saving_optimizer_states\")\n            xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))\n            with warnings.catch_warnings(record=True) as caught_warnings:\n                xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))\n                reissue_pt_warnings(caught_warnings)\n        elif is_sagemaker_mp_enabled():\n            opt_state_dict = self.optimizer.local_state_dict(gather_if_shard=False)\n            smp.barrier()\n            if smp.rdp_rank() == 0 or smp.state.cfg.shard_optimizer_state:\n                smp.save(\n                    
opt_state_dict,\n                    os.path.join(output_dir, OPTIMIZER_NAME),\n                    partial=True,\n                    v3=smp.state.cfg.shard_optimizer_state,\n                )\n            if self.args.should_save:\n                with warnings.catch_warnings(record=True) as caught_warnings:\n                    torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))\n                reissue_pt_warnings(caught_warnings)\n                if self.do_grad_scaling:\n                    torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))\n        elif self.args.should_save and not self.is_deepspeed_enabled:\n            # deepspeed.save_checkpoint above saves model/optim/sched\n            if self.fsdp and not self.is_fsdp_enabled:\n                torch.save(full_osd, os.path.join(output_dir, OPTIMIZER_NAME))\n            else:\n                torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))\n\n            with warnings.catch_warnings(record=True) as caught_warnings:\n                torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))\n            reissue_pt_warnings(caught_warnings)\n            if self.do_grad_scaling:\n                torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))\n\n        # Determine the new best metric / best model checkpoint\n        if metrics is not None and self.args.metric_for_best_model is not None:\n            metric_to_check = self.args.metric_for_best_model\n            if not metric_to_check.startswith(\"eval_\"):\n                metric_to_check = f\"eval_{metric_to_check}\"\n            metric_value = metrics[metric_to_check]\n\n            operator = np.greater if self.args.greater_is_better else np.less\n            if (\n                self.state.best_metric is None\n                or self.state.best_model_checkpoint is None\n                or 
operator(metric_value, self.state.best_metric)\n            ):\n                self.state.best_metric = metric_value\n                self.state.best_model_checkpoint = output_dir\n\n        # Save the Trainer state\n        if self.args.should_save:\n            self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))\n\n        # Save RNG state in non-distributed training\n        rng_states = {\n            \"python\": random.getstate(),\n            \"numpy\": np.random.get_state(),\n            \"cpu\": torch.random.get_rng_state(),\n        }\n        if torch.cuda.is_available():\n            if self.args.parallel_mode == ParallelMode.DISTRIBUTED:\n                # In non distributed, we save the global CUDA RNG state (will take care of DataParallel)\n                rng_states[\"cuda\"] = torch.cuda.random.get_rng_state_all()\n            else:\n                rng_states[\"cuda\"] = torch.cuda.random.get_rng_state()\n\n        if is_torch_tpu_available():\n            rng_states[\"xla\"] = xm.get_rng_state()\n\n        # A process can arrive here before the process 0 has a chance to save the model, in which case output_dir may\n        # not yet exist.\n        os.makedirs(output_dir, exist_ok=True)\n\n        if self.args.world_size <= 1:\n            torch.save(rng_states, os.path.join(output_dir, \"rng_state.pth\"))\n        else:\n            torch.save(rng_states, os.path.join(output_dir, f\"rng_state_{self.args.process_index}.pth\"))\n\n        if self.args.push_to_hub:\n            self._push_from_checkpoint(output_dir)\n\n        # Maybe delete some older checkpoints.\n        if self.args.should_save:\n            self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)\n\n    def _load_optimizer_and_scheduler(self, checkpoint):\n        \"\"\"If optimizer and scheduler states exist, load them.\"\"\"\n        if checkpoint is None:\n            return\n\n        if self.is_deepspeed_enabled:\n            # deepspeed loads 
optimizer/lr_scheduler together with the model in deepspeed_init\n            return\n\n        checkpoint_file_exists = (\n            glob.glob(os.path.join(checkpoint, OPTIMIZER_NAME) + \"_*\")\n            if is_sagemaker_mp_enabled()\n            else os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME))\n        )\n        if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)):\n            # Load in optimizer and scheduler states\n            if is_torch_tpu_available():\n                # On TPU we have to take some extra precautions to properly load the states on the right device.\n                optimizer_state = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=\"cpu\")\n                with warnings.catch_warnings(record=True) as caught_warnings:\n                    lr_scheduler_state = torch.load(os.path.join(checkpoint, SCHEDULER_NAME), map_location=\"cpu\")\n                reissue_pt_warnings(caught_warnings)\n\n                xm.send_cpu_data_to_device(optimizer_state, self.args.device)\n                xm.send_cpu_data_to_device(lr_scheduler_state, self.args.device)\n\n                self.optimizer.load_state_dict(optimizer_state)\n                self.lr_scheduler.load_state_dict(lr_scheduler_state)\n            else:\n                if is_sagemaker_mp_enabled():\n                    if os.path.isfile(os.path.join(checkpoint, \"user_content.pt\")):\n                        # Optimizer checkpoint was saved with smp >= 1.10\n                        def opt_load_hook(mod, opt):\n                            opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True))\n\n                    else:\n                        # Optimizer checkpoint was saved with smp < 1.10\n                        def opt_load_hook(mod, opt):\n                            if IS_SAGEMAKER_MP_POST_1_10:\n                                opt.load_state_dict(\n                                    
smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True, back_compat=True)\n                                )\n                            else:\n                                opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True))\n\n                    self.model_wrapped.register_post_step_hook(opt_load_hook)\n                else:\n                    # We use the CPU when training on one GPU to avoid OOM for GPU RAM when training big models.\n                    # In distributed training however, we load directly on each GPU and risk the GPU OOM as it's more\n                    # likely to get OOM on CPU (since we load num_gpu times the optimizer state\n                    map_location = self.args.device if self.args.world_size > 1 else \"cpu\"\n                    if self.fsdp or self.is_fsdp_enabled:\n                        if self.is_fsdp_enabled:\n                            load_fsdp_optimizer(\n                                self.accelerator.state.fsdp_plugin,\n                                self.accelerator,\n                                self.optimizer,\n                                self.model,\n                                checkpoint,\n                            )\n                        else:\n                            full_osd = None\n                            # In FSDP, we need to load the full optimizer state dict on rank 0 and then shard it\n                            if self.args.process_index == 0:\n                                full_osd = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME))\n                            # call scatter_full_optim_state_dict on all ranks\n                            sharded_osd = self.model.__class__.scatter_full_optim_state_dict(full_osd, self.model)\n                            self.optimizer.load_state_dict(sharded_osd)\n                    else:\n                        self.optimizer.load_state_dict(\n                            
torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)\n                        )\n                with warnings.catch_warnings(record=True) as caught_warnings:\n                    self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))\n                reissue_pt_warnings(caught_warnings)\n                if self.do_grad_scaling and os.path.isfile(os.path.join(checkpoint, SCALER_NAME)):\n                    self.scaler.load_state_dict(torch.load(os.path.join(checkpoint, SCALER_NAME)))\n\n    def hyperparameter_search(\n        self,\n        hp_space: Optional[Callable[[\"optuna.Trial\"], Dict[str, float]]] = None,\n        compute_objective: Optional[Callable[[Dict[str, float]], float]] = None,\n        n_trials: int = 20,\n        direction: str = \"minimize\",\n        backend: Optional[Union[\"str\", HPSearchBackend]] = None,\n        hp_name: Optional[Callable[[\"optuna.Trial\"], str]] = None,\n        **kwargs,\n    ) -> BestRun:\n        \"\"\"\n        Launch an hyperparameter search using `optuna` or `Ray Tune` or `SigOpt`. The optimized quantity is determined\n        by `compute_objective`, which defaults to a function returning the evaluation loss when no metric is provided,\n        the sum of all metrics otherwise.\n\n        <Tip warning={true}>\n\n        To use this method, you need to have provided a `model_init` when initializing your [`Trainer`]: we need to\n        reinitialize the model at each new run. This is incompatible with the `optimizers` argument, so you need to\n        subclass [`Trainer`] and override the method [`~Trainer.create_optimizer_and_scheduler`] for custom\n        optimizer/scheduler.\n\n        </Tip>\n\n        Args:\n            hp_space (`Callable[[\"optuna.Trial\"], Dict[str, float]]`, *optional*):\n                A function that defines the hyperparameter search space. 
Will default to\n                [`~trainer_utils.default_hp_space_optuna`] or [`~trainer_utils.default_hp_space_ray`] or\n                [`~trainer_utils.default_hp_space_sigopt`] depending on your backend.\n            compute_objective (`Callable[[Dict[str, float]], float]`, *optional*):\n                A function computing the objective to minimize or maximize from the metrics returned by the `evaluate`\n                method. Will default to [`~trainer_utils.default_compute_objective`].\n            n_trials (`int`, *optional*, defaults to 20):\n                The number of trial runs to test.\n            direction (`str`, *optional*, defaults to `\"minimize\"`):\n                Whether to minimize or maximize the objective. Can be `\"minimize\"` or `\"maximize\"`, you should pick\n                `\"minimize\"` when optimizing the validation loss, `\"maximize\"` when optimizing one or several metrics.\n            backend (`str` or [`~trainer_utils.HPSearchBackend`], *optional*):\n                The backend to use for hyperparameter search. Will default to optuna or Ray Tune or SigOpt, depending\n                on which one is installed. If all are installed, will default to optuna.\n            hp_name (`Callable[[\"optuna.Trial\"], str]`, *optional*):\n                A function that defines the trial/run name. Will default to None.\n            kwargs (`Dict[str, Any]`, *optional*):\n                Additional keyword arguments passed along to `optuna.create_study` or `ray.tune.run`. 
For more\n                information see:\n\n                - the documentation of\n                  [optuna.create_study](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.create_study.html)\n                - the documentation of [tune.run](https://docs.ray.io/en/latest/tune/api_docs/execution.html#tune-run)\n                - the documentation of [sigopt](https://app.sigopt.com/docs/endpoints/experiments/create)\n\n        Returns:\n            [`trainer_utils.BestRun`]: All the information about the best run. Experiment summary can be found in\n            `run_summary` attribute for Ray backend.\n        \"\"\"\n        if backend is None:\n            backend = default_hp_search_backend()\n        backend = HPSearchBackend(backend)\n        backend_obj = ALL_HYPERPARAMETER_SEARCH_BACKENDS[backend]()\n        backend_obj.ensure_available()\n        self.hp_search_backend = backend\n        if self.model_init is None:\n            raise RuntimeError(\n                \"To use hyperparameter search, you need to pass your model through a model_init function.\"\n            )\n\n        self.hp_space = backend_obj.default_hp_space if hp_space is None else hp_space\n        self.hp_name = hp_name\n        self.compute_objective = default_compute_objective if compute_objective is None else compute_objective\n\n        best_run = backend_obj.run(self, n_trials, direction, **kwargs)\n\n        self.hp_search_backend = None\n        return best_run\n\n    def log(self, logs: Dict[str, float]) -> None:\n        \"\"\"\n        Log `logs` on the various objects watching training.\n\n        Subclass and override this method to inject custom behavior.\n\n        Args:\n            logs (`Dict[str, float]`):\n                The values to log.\n        \"\"\"\n        if self.state.epoch is not None:\n            logs[\"epoch\"] = round(self.state.epoch, 2)\n\n        output = {**logs, **{\"step\": self.state.global_step}}\n        
self.state.log_history.append(output)\n        self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)\n\n    def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torch.Tensor, Any]:\n        \"\"\"\n        Prepares one `data` before feeding it to the model, be it a tensor or a nested list/dictionary of tensors.\n        \"\"\"\n        if isinstance(data, Mapping):\n            return type(data)({k: self._prepare_input(v) for k, v in data.items()})\n        elif isinstance(data, (tuple, list)):\n            return type(data)(self._prepare_input(v) for v in data)\n        elif isinstance(data, torch.Tensor):\n            kwargs = {\"device\": self.args.device}\n            if self.is_deepspeed_enabled and (torch.is_floating_point(data) or torch.is_complex(data)):\n                # NLP models inputs are int/uint and those get adjusted to the right dtype of the\n                # embedding. Other models such as wav2vec2's inputs are already float and thus\n                # may need special handling to match the dtypes of the model\n                kwargs.update({\"dtype\": self.accelerator.state.deepspeed_plugin.hf_ds_config.dtype()})\n            return data.to(**kwargs)\n        return data\n\n    def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:\n        \"\"\"\n        Prepare `inputs` before feeding them to the model, converting them to tensors if they are not already and\n        handling potential state.\n        \"\"\"\n        inputs = self._prepare_input(inputs)\n        if len(inputs) == 0:\n            raise ValueError(\n                \"The batch received was empty, your model won't be able to train on it. 
Double-check that your \"\n                f\"training dataset contains keys expected by the model: {','.join(self._signature_columns)}.\"\n            )\n        if self.args.past_index >= 0 and self._past is not None:\n            inputs[\"mems\"] = self._past\n\n        return inputs\n\n    def compute_loss_context_manager(self):\n        \"\"\"\n        A helper wrapper to group together context managers.\n        \"\"\"\n        return self.autocast_smart_context_manager()\n\n    def autocast_smart_context_manager(self, cache_enabled: Optional[bool] = True):\n        \"\"\"\n        A helper wrapper that creates an appropriate context manager for `autocast` while feeding it the desired\n        arguments, depending on the situation.\n        \"\"\"\n        if self.use_cuda_amp or self.use_cpu_amp:\n            ctx_manager = (\n                torch.cpu.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)\n                if self.use_cpu_amp\n                else torch.cuda.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)\n            )\n        else:\n            ctx_manager = contextlib.nullcontext() if sys.version_info >= (3, 7) else contextlib.suppress()\n\n        return ctx_manager\n\n    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:\n        \"\"\"\n        Perform a training step on a batch of inputs.\n\n        Subclass and override to inject custom behavior.\n\n        Args:\n            model (`nn.Module`):\n                The model to train.\n            inputs (`Dict[str, Union[torch.Tensor, Any]]`):\n                The inputs and targets of the model.\n\n                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the\n                argument `labels`. 
Check your model's documentation for all accepted arguments.\n\n        Return:\n            `torch.Tensor`: The tensor with training loss on this batch.\n        \"\"\"\n        model.train()\n        inputs = self._prepare_inputs(inputs)\n\n        if is_sagemaker_mp_enabled():\n            loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)\n            return loss_mb.reduce_mean().detach().to(self.args.device)\n\n        with self.compute_loss_context_manager():\n            loss = self.compute_loss(model, inputs)\n\n\n        if self.args.n_gpu > 1:\n            for k, ls in loss.items():\n                loss[k] = loss[k].mean()  # mean() to average on multi-gpu parallel training\n\n        if self.do_grad_scaling:\n            self.scaler.scale(loss['loss_total']).backward()\n        elif self.use_apex:\n            with amp.scale_loss(loss['loss_total'], self.optimizer) as scaled_loss:\n                scaled_loss.backward()\n        else:\n            self.accelerator.backward(loss['loss_total'])\n\n        # return loss.detach() / self.args.gradient_accumulation_steps\n        return {k:v.detach()/self.args.gradient_accumulation_steps for k,v in loss.items()}\n\n    def compute_loss(self, model, inputs, return_outputs=False):\n        \"\"\"\n        How the loss is computed by Trainer. 
By default, all models return the loss in the first element.\n\n        Subclass and override for custom behavior.\n        \"\"\"\n        if self.label_smoother is not None and \"labels\" in inputs:\n            labels = inputs.pop(\"labels\")\n        else:\n            labels = None\n        outputs = model(**inputs)\n        # Save past state if it exists\n        # TODO: this needs to be fixed and made cleaner later.\n        if self.args.past_index >= 0:\n            self._past = outputs[self.args.past_index]\n\n        if labels is not None:\n            if unwrap_model(model)._get_name() in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():\n                loss = self.label_smoother(outputs, labels, shift_labels=True)\n            else:\n                loss = self.label_smoother(outputs, labels)\n        else:\n            if isinstance(outputs, dict) and \"loss\" not in outputs:\n                raise ValueError(\n                    \"The model did not return a loss from the inputs, only the following keys: \"\n                    f\"{','.join(outputs.keys())}. 
For reference, the inputs it received are {','.join(inputs.keys())}.\"\n                )\n            # We don't use .loss here since the model may return tuples instead of ModelOutput.\n            loss = outputs[\"loss\"] if isinstance(outputs, dict) else outputs[0]\n\n        return (loss, outputs) if return_outputs else loss\n\n    def is_local_process_zero(self) -> bool:\n        \"\"\"\n        Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on several\n        machines) main process.\n        \"\"\"\n        return self.args.local_process_index == 0\n\n    def is_world_process_zero(self) -> bool:\n        \"\"\"\n        Whether or not this process is the global main process (when training in a distributed fashion on several\n        machines, this is only going to be `True` for one process).\n        \"\"\"\n        # Special case for SageMaker ModelParallel since there process_index is dp_process_index, not the global\n        # process index.\n        if is_sagemaker_mp_enabled():\n            return smp.rank() == 0\n        else:\n            return self.args.process_index == 0\n\n    def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):\n        \"\"\"\n        Will save the model, so you can reload it using `from_pretrained()`.\n\n        Will only save from the main process.\n        \"\"\"\n\n        if output_dir is None:\n            output_dir = self.args.output_dir\n\n        if is_torch_tpu_available():\n            self._save_tpu(output_dir)\n        elif is_sagemaker_mp_enabled():\n            # Calling the state_dict needs to be done on the wrapped model and on all processes.\n            os.makedirs(output_dir, exist_ok=True)\n            state_dict = self.model_wrapped.state_dict()\n            if self.args.should_save:\n                self._save(output_dir, state_dict=state_dict)\n            if IS_SAGEMAKER_MP_POST_1_10:\n                # 
'user_content.pt' indicates model state_dict saved with smp >= 1.10\n                Path(os.path.join(output_dir, \"user_content.pt\")).touch()\n        elif (\n            ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp\n            or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp\n            or self.fsdp is not None\n            or self.is_fsdp_enabled\n        ):\n            state_dict = self.model.state_dict()\n            if self.args.should_save:\n                self._save(output_dir, state_dict=state_dict)\n            if self.is_fsdp_enabled:\n                save_fsdp_model(self.accelerator.state.fsdp_plugin, self.accelerator, self.model, output_dir)\n\n        elif self.is_deepspeed_enabled:\n            # this takes care of everything as long as we aren't under zero3\n            if version.parse(accelerate_version) <= version.parse(\"0.20.3\"):\n                raise ValueError(\"Install Accelerate from main branch\")\n            try:\n                state_dict = self.accelerator.get_state_dict(self.deepspeed)\n                if self.args.should_save:\n                    self._save(output_dir, state_dict=state_dict)\n            except ValueError:\n                logger.warning(\n                    \" stage3_gather_16bit_weights_on_model_save=false. 
Saving the full checkpoint instead, use\"\n                    \" zero_to_fp32.py to recover weights\"\n                )\n                self.model_wrapped.save_checkpoint(output_dir)\n\n        elif self.args.should_save:\n            self._save(output_dir)\n\n        # Push to the Hub when `save_model` is called by the user.\n        if self.args.push_to_hub and not _internal_call:\n            self.push_to_hub(commit_message=\"Model save\")\n\n    def _save_tpu(self, output_dir: Optional[str] = None):\n        output_dir = output_dir if output_dir is not None else self.args.output_dir\n        logger.info(f\"Saving model checkpoint to {output_dir}\")\n\n        if xm.is_master_ordinal():\n            os.makedirs(output_dir, exist_ok=True)\n            torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))\n\n        # Save a trained model and configuration using `save_pretrained()`.\n        # They can then be reloaded using `from_pretrained()`\n        xm.rendezvous(\"saving_checkpoint\")\n        if not isinstance(self.model, PreTrainedModel):\n            if isinstance(unwrap_model(self.model), PreTrainedModel):\n                unwrap_model(self.model).save_pretrained(\n                    output_dir,\n                    is_main_process=self.args.should_save,\n                    state_dict=self.model.state_dict(),\n                    save_function=xm.save,\n                )\n            else:\n                logger.info(\"Trainer.model is not a `PreTrainedModel`, only saving its state dict.\")\n                state_dict = self.model.state_dict()\n                xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))\n        else:\n            self.model.save_pretrained(output_dir, is_main_process=self.args.should_save, save_function=xm.save)\n        if self.tokenizer is not None and self.args.should_save:\n            self.tokenizer.save_pretrained(output_dir)\n\n    def _save(self, output_dir: Optional[str] = None, 
state_dict=None):\n        # If we are executing this function, we are the process zero, so we don't check for that.\n        output_dir = output_dir if output_dir is not None else self.args.output_dir\n        os.makedirs(output_dir, exist_ok=True)\n        logger.info(f\"Saving model checkpoint to {output_dir}\")\n\n        supported_classes = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel)\n        # Save a trained model and configuration using `save_pretrained()`.\n        # They can then be reloaded using `from_pretrained()`\n        if not isinstance(self.model, supported_classes):\n            if state_dict is None:\n                state_dict = self.model.state_dict()\n\n            if isinstance(unwrap_model(self.model), supported_classes):\n                unwrap_model(self.model).save_pretrained(\n                    output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors\n                )\n            else:\n                logger.info(\"Trainer.model is not a `PreTrainedModel`, only saving its state dict.\")\n                if self.args.save_safetensors:\n                    safetensors.torch.save_file(state_dict, os.path.join(output_dir, SAFE_WEIGHTS_NAME))\n                else:\n                    torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))\n        else:\n            self.model.save_pretrained(\n                output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors\n            )\n\n        if self.tokenizer is not None:\n            self.tokenizer.save_pretrained(output_dir)\n\n        # Good practice: save your training arguments together with the trained model\n        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))\n\n    def store_flos(self):\n        # Storing the number of floating-point operations that went into the model\n        if self.args.parallel_mode == ParallelMode.DISTRIBUTED:\n            
self.state.total_flos += (\n                distributed_broadcast_scalars([self.current_flos], device=self.args.device).sum().item()\n            )\n            self.current_flos = 0\n        else:\n            self.state.total_flos += self.current_flos\n            self.current_flos = 0\n\n    def _sorted_checkpoints(\n        self, output_dir=None, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False\n    ) -> List[str]:\n        ordering_and_checkpoint_path = []\n\n        glob_checkpoints = [str(x) for x in Path(output_dir).glob(f\"{checkpoint_prefix}-*\") if os.path.isdir(x)]\n\n        for path in glob_checkpoints:\n            if use_mtime:\n                ordering_and_checkpoint_path.append((os.path.getmtime(path), path))\n            else:\n                regex_match = re.match(f\".*{checkpoint_prefix}-([0-9]+)\", path)\n                if regex_match is not None and regex_match.groups() is not None:\n                    ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))\n\n        checkpoints_sorted = sorted(ordering_and_checkpoint_path)\n        checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]\n        # Make sure we don't delete the best model.\n        if self.state.best_model_checkpoint is not None:\n            best_model_index = checkpoints_sorted.index(str(Path(self.state.best_model_checkpoint)))\n            for i in range(best_model_index, len(checkpoints_sorted) - 2):\n                checkpoints_sorted[i], checkpoints_sorted[i + 1] = checkpoints_sorted[i + 1], checkpoints_sorted[i]\n        return checkpoints_sorted\n\n    def _rotate_checkpoints(self, use_mtime=False, output_dir=None) -> None:\n        if self.args.save_total_limit is None or self.args.save_total_limit <= 0:\n            return\n\n        # Check if we should delete older checkpoint(s)\n        checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime, output_dir=output_dir)\n        if len(checkpoints_sorted) <= 
self.args.save_total_limit:\n            return\n\n        # If save_total_limit=1 with load_best_model_at_end=True, we could end up deleting the last checkpoint, which\n        # we don't do to allow resuming.\n        save_total_limit = self.args.save_total_limit\n        if (\n            self.state.best_model_checkpoint is not None\n            and self.args.save_total_limit == 1\n            and checkpoints_sorted[-1] != self.state.best_model_checkpoint\n        ):\n            save_total_limit = 2\n\n        number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - save_total_limit)\n        checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]\n        for checkpoint in checkpoints_to_be_deleted:\n            logger.info(f\"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit\")\n            shutil.rmtree(checkpoint, ignore_errors=True)\n\n    def evaluate(\n        self,\n        eval_dataset: Optional[Dataset] = None,\n        ignore_keys: Optional[List[str]] = None,\n        metric_key_prefix: str = \"eval\",\n    ) -> Dict[str, float]:\n        \"\"\"\n        Run evaluation and returns metrics.\n\n        The calling script will be responsible for providing a method to compute metrics, as they are task-dependent\n        (pass it to the init `compute_metrics` argument).\n\n        You can also subclass and override this method to inject custom behavior.\n\n        Args:\n            eval_dataset (`Dataset`, *optional*):\n                Pass a dataset if you wish to override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns\n                not accepted by the `model.forward()` method are automatically removed. 
It must implement the `__len__`\n                method.\n            ignore_keys (`List[str]`, *optional*):\n                A list of keys in the output of your model (if it is a dictionary) that should be ignored when\n                gathering predictions.\n            metric_key_prefix (`str`, *optional*, defaults to `\"eval\"`):\n                An optional prefix to be used as the metrics key prefix. For example the metrics \"bleu\" will be named\n                \"eval_bleu\" if the prefix is \"eval\" (default)\n\n        Returns:\n            A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The\n            dictionary also contains the epoch number which comes from the training state.\n        \"\"\"\n        # memory metrics - must set up as early as possible\n        self._memory_tracker.start()\n\n        eval_dataloader = self.get_eval_dataloader(eval_dataset)\n        start_time = time.time()\n\n        eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop\n        output = eval_loop(\n            eval_dataloader,\n            description=\"Evaluation\",\n            # No point gathering the predictions if there are no metrics, otherwise we defer to\n            # self.args.prediction_loss_only\n            prediction_loss_only=True if self.compute_metrics is None else None,\n            ignore_keys=ignore_keys,\n            metric_key_prefix=metric_key_prefix,\n        )\n\n        total_batch_size = self.args.eval_batch_size * self.args.world_size\n        if f\"{metric_key_prefix}_jit_compilation_time\" in output.metrics:\n            start_time += output.metrics[f\"{metric_key_prefix}_jit_compilation_time\"]\n        output.metrics.update(\n            speed_metrics(\n                metric_key_prefix,\n                start_time,\n                num_samples=output.num_samples,\n                num_steps=math.ceil(output.num_samples / 
total_batch_size),\n            )\n        )\n\n        self.log(output.metrics)\n\n        if DebugOption.TPU_METRICS_DEBUG in self.args.debug:\n            # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)\n            xm.master_print(met.metrics_report())\n\n        self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)\n\n        self._memory_tracker.stop_and_update_metrics(output.metrics)\n\n        return output.metrics\n\n    def predict(\n        self, test_dataset: Dataset, ignore_keys: Optional[List[str]] = None, metric_key_prefix: str = \"test\"\n    ) -> PredictionOutput:\n        \"\"\"\n        Run prediction and returns predictions and potential metrics.\n\n        Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method\n        will also return metrics, like in `evaluate()`.\n\n        Args:\n            test_dataset (`Dataset`):\n                Dataset to run the predictions on. If it is an `datasets.Dataset`, columns not accepted by the\n                `model.forward()` method are automatically removed. Has to implement the method `__len__`\n            ignore_keys (`List[str]`, *optional*):\n                A list of keys in the output of your model (if it is a dictionary) that should be ignored when\n                gathering predictions.\n            metric_key_prefix (`str`, *optional*, defaults to `\"test\"`):\n                An optional prefix to be used as the metrics key prefix. For example the metrics \"bleu\" will be named\n                \"test_bleu\" if the prefix is \"test\" (default)\n\n        <Tip>\n\n        If your predictions or labels have different sequence length (for instance because you're doing dynamic padding\n        in a token classification task) the predictions will be padded (on the right) to allow for concatenation into\n        one array. 
The padding index is -100.\n\n        </Tip>\n\n        Returns: *NamedTuple* A namedtuple with the following keys:\n\n            - predictions (`np.ndarray`): The predictions on `test_dataset`.\n            - label_ids (`np.ndarray`, *optional*): The labels (if the dataset contained some).\n            - metrics (`Dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained\n              labels).\n        \"\"\"\n        # memory metrics - must set up as early as possible\n        self._memory_tracker.start()\n\n        test_dataloader = self.get_test_dataloader(test_dataset)\n        start_time = time.time()\n\n        eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop\n        output = eval_loop(\n            test_dataloader, description=\"Prediction\", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix\n        )\n        total_batch_size = self.args.eval_batch_size * self.args.world_size\n        if f\"{metric_key_prefix}_jit_compilation_time\" in output.metrics:\n            start_time += output.metrics[f\"{metric_key_prefix}_jit_compilation_time\"]\n        output.metrics.update(\n            speed_metrics(\n                metric_key_prefix,\n                start_time,\n                num_samples=output.num_samples,\n                num_steps=math.ceil(output.num_samples / total_batch_size),\n            )\n        )\n\n        self.control = self.callback_handler.on_predict(self.args, self.state, self.control, output.metrics)\n        self._memory_tracker.stop_and_update_metrics(output.metrics)\n\n        return PredictionOutput(predictions=output.predictions, label_ids=output.label_ids, metrics=output.metrics)\n\n    def evaluation_loop(\n        self,\n        dataloader: DataLoader,\n        description: str,\n        prediction_loss_only: Optional[bool] = None,\n        ignore_keys: Optional[List[str]] = None,\n        metric_key_prefix: str = \"eval\",\n    
) -> EvalLoopOutput:\n        \"\"\"\n        Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.\n\n        Works both with or without labels.\n        \"\"\"\n        args = self.args\n\n        prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only\n\n        # if eval is called w/o train, handle model prep here\n        if self.is_deepspeed_enabled and self.deepspeed is None:\n            _, _ = deepspeed_init(self, num_training_steps=0, inference=True)\n\n        model = self._wrap_model(self.model, training=False, dataloader=dataloader)\n\n        if len(self.accelerator._models) == 0 and model is self.model:\n            model = (\n                self.accelerator.prepare(model)\n                if self.is_deepspeed_enabled\n                else self.accelerator.prepare_model(model, evaluation_mode=True)\n            )\n\n            if self.is_fsdp_enabled:\n                self.model = model\n\n            # for the rest of this function `model` is the outside model, whether it was wrapped or not\n            if model is not self.model:\n                self.model_wrapped = model\n\n            # backward compatibility\n            if self.is_deepspeed_enabled:\n                self.deepspeed = self.model_wrapped\n\n        # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called\n        # while ``train`` is running, cast it to the right dtype first and then put on device\n        if not self.is_in_train:\n            if args.fp16_full_eval:\n                model = model.to(dtype=torch.float16, device=args.device)\n            elif args.bf16_full_eval:\n                model = model.to(dtype=torch.bfloat16, device=args.device)\n\n        batch_size = self.args.eval_batch_size\n\n        logger.info(f\"***** Running {description} *****\")\n        if has_length(dataloader):\n            logger.info(f\"  Num examples = 
{self.num_examples(dataloader)}\")\n        else:\n            logger.info(\"  Num examples: Unknown\")\n        logger.info(f\"  Batch size = {batch_size}\")\n\n        model.eval()\n\n        self.callback_handler.eval_dataloader = dataloader\n        # Do this before wrapping.\n        eval_dataset = getattr(dataloader, \"dataset\", None)\n\n        if args.past_index >= 0:\n            self._past = None\n\n        # Initialize containers\n        # losses/preds/labels on GPU/TPU (accumulated for eval_accumulation_steps)\n        losses_host = None\n        preds_host = None\n        labels_host = None\n        inputs_host = None\n\n        # losses/preds/labels on CPU (final containers)\n        all_losses = None\n        all_preds = None\n        all_labels = None\n        all_inputs = None\n        # Will be useful when we have an iterable dataset so don't know its length.\n\n        observed_num_examples = 0\n        # Main evaluation loop\n        for step, inputs in enumerate(dataloader):\n            # Update the observed num examples\n            observed_batch_size = find_batch_size(inputs)\n            if observed_batch_size is not None:\n                observed_num_examples += observed_batch_size\n                # For batch samplers, batch_size is not known by the dataloader in advance.\n                if batch_size is None:\n                    batch_size = observed_batch_size\n\n            # Prediction step\n            loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)\n            inputs_decode = self._prepare_input(inputs[\"input_ids\"]) if args.include_inputs_for_metrics else None\n\n            if is_torch_tpu_available():\n                xm.mark_step()\n\n            # Update containers on host\n            if loss is not None:\n                losses = self.accelerator.gather_for_metrics((loss.repeat(batch_size)))\n                losses_host = losses if losses_host is None else 
nested_concat(losses_host, losses, padding_index=-100)\n            if labels is not None:\n                labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100)\n            if inputs_decode is not None:\n                inputs_decode = self.accelerator.pad_across_processes(inputs_decode, dim=1, pad_index=-100)\n                inputs_decode = self.accelerator.gather_for_metrics((inputs_decode))\n                inputs_host = (\n                    inputs_decode\n                    if inputs_host is None\n                    else nested_concat(inputs_host, inputs_decode, padding_index=-100)\n                )\n            if logits is not None:\n                logits = self.accelerator.pad_across_processes(logits, dim=1, pad_index=-100)\n                if self.preprocess_logits_for_metrics is not None:\n                    logits = self.preprocess_logits_for_metrics(logits, labels)\n                logits = self.accelerator.gather_for_metrics((logits))\n                preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)\n\n            if labels is not None:\n                labels = self.accelerator.gather_for_metrics((labels))\n                labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)\n\n            self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)\n\n            # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.\n            if args.eval_accumulation_steps is not None and self.accelerator.sync_gradients:\n                if losses_host is not None:\n                    losses = nested_numpify(losses_host)\n                    all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)\n                if preds_host is not None:\n                    logits = nested_numpify(preds_host)\n                    all_preds = 
logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)\n                if inputs_host is not None:\n                    inputs_decode = nested_numpify(inputs_host)\n                    all_inputs = (\n                        inputs_decode\n                        if all_inputs is None\n                        else nested_concat(all_inputs, inputs_decode, padding_index=-100)\n                    )\n                if labels_host is not None:\n                    labels = nested_numpify(labels_host)\n                    all_labels = (\n                        labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)\n                    )\n\n                # Set back to None to begin a new accumulation\n                losses_host, preds_host, inputs_host, labels_host = None, None, None, None\n\n        if args.past_index and hasattr(self, \"_past\"):\n            # Clean the state at the end of the evaluation loop\n            delattr(self, \"_past\")\n\n        # Gather all remaining tensors and put them back on the CPU\n        if losses_host is not None:\n            losses = nested_numpify(losses_host)\n            all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)\n        if preds_host is not None:\n            logits = nested_numpify(preds_host)\n            all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)\n        if inputs_host is not None:\n            inputs_decode = nested_numpify(inputs_host)\n            all_inputs = (\n                inputs_decode if all_inputs is None else nested_concat(all_inputs, inputs_decode, padding_index=-100)\n            )\n        if labels_host is not None:\n            labels = nested_numpify(labels_host)\n            all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)\n\n        # Number of samples\n        if 
has_length(eval_dataset):\n            num_samples = len(eval_dataset)\n        # The instance check is weird and does not actually check for the type, but whether the dataset has the right\n        # methods. Therefore we need to make sure it also has the attribute.\n        elif isinstance(eval_dataset, IterableDatasetShard) and getattr(eval_dataset, \"num_examples\", 0) > 0:\n            num_samples = eval_dataset.num_examples\n        else:\n            if has_length(dataloader):\n                num_samples = self.num_examples(dataloader)\n            else:  # both len(dataloader.dataset) and len(dataloader) fail\n                num_samples = observed_num_examples\n        if num_samples == 0 and observed_num_examples > 0:\n            num_samples = observed_num_examples\n\n        # Metrics!\n        if self.compute_metrics is not None and all_preds is not None and all_labels is not None:\n            if args.include_inputs_for_metrics:\n                metrics = self.compute_metrics(\n                    EvalPrediction(predictions=all_preds, label_ids=all_labels, inputs=all_inputs)\n                )\n            else:\n                metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))\n        else:\n            metrics = {}\n\n        # To be JSON-serializable, we need to remove numpy types or zero-d tensors\n        metrics = denumpify_detensorize(metrics)\n\n        if all_losses is not None:\n            metrics[f\"{metric_key_prefix}_loss\"] = all_losses.mean().item()\n        if hasattr(self, \"jit_compilation_time\"):\n            metrics[f\"{metric_key_prefix}_jit_compilation_time\"] = self.jit_compilation_time\n\n        # Prefix all keys with metric_key_prefix + '_'\n        for key in list(metrics.keys()):\n            if not key.startswith(f\"{metric_key_prefix}_\"):\n                metrics[f\"{metric_key_prefix}_{key}\"] = metrics.pop(key)\n\n        return EvalLoopOutput(predictions=all_preds, 
label_ids=all_labels, metrics=metrics, num_samples=num_samples)\n\n    def _nested_gather(self, tensors, name=None):\n        \"\"\"\n        Gather value of `tensors` (tensor or list/tuple of nested tensors) and convert them to numpy before\n        concatenating them to `gathered`\n        \"\"\"\n        if tensors is None:\n            return\n        if is_torch_tpu_available():\n            if name is None:\n                name = \"nested_gather\"\n            tensors = nested_xla_mesh_reduce(tensors, name)\n        elif is_sagemaker_mp_enabled():\n            tensors = smp_gather(tensors)\n        elif (self.args.distributed_state is not None and self.args.distributed_state.distributed_type != \"NO\") or (\n            self.args.distributed_state is None and self.args.local_rank != -1\n        ):\n            tensors = distributed_concat(tensors)\n        return tensors\n\n    def prediction_step(\n        self,\n        model: nn.Module,\n        inputs: Dict[str, Union[torch.Tensor, Any]],\n        prediction_loss_only: bool,\n        ignore_keys: Optional[List[str]] = None,\n    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:\n        \"\"\"\n        Perform an evaluation step on `model` using `inputs`.\n\n        Subclass and override to inject custom behavior.\n\n        Args:\n            model (`nn.Module`):\n                The model to evaluate.\n            inputs (`Dict[str, Union[torch.Tensor, Any]]`):\n                The inputs and targets of the model.\n\n                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the\n                argument `labels`. 
Check your model's documentation for all accepted arguments.\n            prediction_loss_only (`bool`):\n                Whether or not to return the loss only.\n            ignore_keys (`List[str]`, *optional*):\n                A list of keys in the output of your model (if it is a dictionary) that should be ignored when\n                gathering predictions.\n\n        Return:\n            Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss,\n            logits and labels (each being optional).\n        \"\"\"\n        has_labels = False if len(self.label_names) == 0 else all(inputs.get(k) is not None for k in self.label_names)\n        # For CLIP-like models capable of returning loss values.\n        # If `return_loss` is not specified or being `None` in `inputs`, we check if the default value of `return_loss`\n        # is `True` in `model.forward`.\n        return_loss = inputs.get(\"return_loss\", None)\n        if return_loss is None:\n            return_loss = self.can_return_loss\n        loss_without_labels = True if len(self.label_names) == 0 and return_loss else False\n\n        inputs = self._prepare_inputs(inputs)\n        if ignore_keys is None:\n            if hasattr(self.model, \"config\"):\n                ignore_keys = getattr(self.model.config, \"keys_to_ignore_at_inference\", [])\n            else:\n                ignore_keys = []\n\n        # labels may be popped when computing the loss (label smoothing for instance) so we grab them first.\n        if has_labels or loss_without_labels:\n            labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))\n            if len(labels) == 1:\n                labels = labels[0]\n        else:\n            labels = None\n\n        with torch.no_grad():\n            if is_sagemaker_mp_enabled():\n                raw_outputs = smp_forward_only(model, inputs)\n                if has_labels or loss_without_labels:\n                
    if isinstance(raw_outputs, dict):\n                        loss_mb = raw_outputs[\"loss\"]\n                        logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys + [\"loss\"])\n                    else:\n                        loss_mb = raw_outputs[0]\n                        logits_mb = raw_outputs[1:]\n\n                    loss = loss_mb.reduce_mean().detach().cpu()\n                    logits = smp_nested_concat(logits_mb)\n                else:\n                    loss = None\n                    if isinstance(raw_outputs, dict):\n                        logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys)\n                    else:\n                        logits_mb = raw_outputs\n                    logits = smp_nested_concat(logits_mb)\n            else:\n                if has_labels or loss_without_labels:\n                    with self.compute_loss_context_manager():\n                        loss, outputs = self.compute_loss(model, inputs, return_outputs=True)\n                    loss = loss.mean().detach()\n\n                    if isinstance(outputs, dict):\n                        logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + [\"loss\"])\n                    else:\n                        logits = outputs[1:]\n                else:\n                    loss = None\n                    with self.compute_loss_context_manager():\n                        outputs = model(**inputs)\n                    if isinstance(outputs, dict):\n                        logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)\n                    else:\n                        logits = outputs\n                    # TODO: this needs to be fixed and made cleaner later.\n                    if self.args.past_index >= 0:\n                        self._past = outputs[self.args.past_index - 1]\n\n        if prediction_loss_only:\n            return (loss, None, None)\n\n        
logits = nested_detach(logits)\n        if len(logits) == 1:\n            logits = logits[0]\n\n        return (loss, logits, labels)\n\n    def floating_point_ops(self, inputs: Dict[str, Union[torch.Tensor, Any]]):\n        \"\"\"\n        For models that inherit from [`PreTrainedModel`], uses that method to compute the number of floating point\n        operations for every backward + forward pass. If using another model, either implement such a method in the\n        model or subclass and override this method.\n\n        Args:\n            inputs (`Dict[str, Union[torch.Tensor, Any]]`):\n                The inputs and targets of the model.\n\n        Returns:\n            `int`: The number of floating-point operations.\n        \"\"\"\n        if hasattr(self.model, \"floating_point_ops\"):\n            return self.model.floating_point_ops(inputs)\n        else:\n            return 0\n\n    def init_git_repo(self, at_init: bool = False):\n        \"\"\"\n        Initializes a git repo in `self.args.hub_model_id`.\n\n        Args:\n            at_init (`bool`, *optional*, defaults to `False`):\n                Whether this function is called before any training or not. 
If `self.args.overwrite_output_dir` is\n                `True` and `at_init` is `True`, the path to the repo (which is `self.args.output_dir`) might be wiped\n                out.\n        \"\"\"\n        if not self.is_world_process_zero():\n            return\n        if self.args.hub_model_id is None:\n            repo_name = Path(self.args.output_dir).absolute().name\n        else:\n            repo_name = self.args.hub_model_id\n        if \"/\" not in repo_name:\n            repo_name = get_full_repo_name(repo_name, token=self.args.hub_token)\n\n        # Make sure the repo exists.\n        create_repo(repo_name, token=self.args.hub_token, private=self.args.hub_private_repo, exist_ok=True)\n        try:\n            self.repo = Repository(self.args.output_dir, clone_from=repo_name, token=self.args.hub_token)\n        except EnvironmentError:\n            if self.args.overwrite_output_dir and at_init:\n                # Try again after wiping output_dir\n                shutil.rmtree(self.args.output_dir)\n                self.repo = Repository(self.args.output_dir, clone_from=repo_name, token=self.args.hub_token)\n            else:\n                raise\n\n        self.repo.git_pull()\n\n        # By default, ignore the checkpoint folders\n        if (\n            not os.path.exists(os.path.join(self.args.output_dir, \".gitignore\"))\n            and self.args.hub_strategy != HubStrategy.ALL_CHECKPOINTS\n        ):\n            with open(os.path.join(self.args.output_dir, \".gitignore\"), \"w\", encoding=\"utf-8\") as writer:\n                writer.writelines([\"checkpoint-*/\"])\n\n        # Add \"*.sagemaker\" to .gitignore if using SageMaker\n        if os.environ.get(\"SM_TRAINING_ENV\"):\n            self._add_sm_patterns_to_gitignore()\n\n        self.push_in_progress = None\n\n    def create_model_card(\n        self,\n        language: Optional[str] = None,\n        license: Optional[str] = None,\n        tags: Union[str, List[str], None] = None,\n 
       model_name: Optional[str] = None,\n        finetuned_from: Optional[str] = None,\n        tasks: Union[str, List[str], None] = None,\n        dataset_tags: Union[str, List[str], None] = None,\n        dataset: Union[str, List[str], None] = None,\n        dataset_args: Union[str, List[str], None] = None,\n    ):\n        \"\"\"\n        Creates a draft of a model card using the information available to the `Trainer`.\n\n        Args:\n            language (`str`, *optional*):\n                The language of the model (if applicable)\n            license (`str`, *optional*):\n                The license of the model. Will default to the license of the pretrained model used, if the original\n                model given to the `Trainer` comes from a repo on the Hub.\n            tags (`str` or `List[str]`, *optional*):\n                Some tags to be included in the metadata of the model card.\n            model_name (`str`, *optional*):\n                The name of the model.\n            finetuned_from (`str`, *optional*):\n                The name of the model used to fine-tune this one (if applicable). 
Will default to the name of the repo\n                of the original model given to the `Trainer` (if it comes from the Hub).\n            tasks (`str` or `List[str]`, *optional*):\n                One or several task identifiers, to be included in the metadata of the model card.\n            dataset_tags (`str` or `List[str]`, *optional*):\n                One or several dataset tags, to be included in the metadata of the model card.\n            dataset (`str` or `List[str]`, *optional*):\n                One or several dataset identifiers, to be included in the metadata of the model card.\n            dataset_args (`str` or `List[str]`, *optional*):\n               One or several dataset arguments, to be included in the metadata of the model card.\n        \"\"\"\n        if not self.is_world_process_zero():\n            return\n\n        training_summary = TrainingSummary.from_trainer(\n            self,\n            language=language,\n            license=license,\n            tags=tags,\n            model_name=model_name,\n            finetuned_from=finetuned_from,\n            tasks=tasks,\n            dataset_tags=dataset_tags,\n            dataset=dataset,\n            dataset_args=dataset_args,\n        )\n        model_card = training_summary.to_model_card()\n        with open(os.path.join(self.args.output_dir, \"README.md\"), \"w\") as f:\n            f.write(model_card)\n\n    def _push_from_checkpoint(self, checkpoint_folder):\n        # Only push from one node.\n        if not self.is_world_process_zero() or self.args.hub_strategy == HubStrategy.END:\n            return\n        # If we haven't finished the last push, we don't do this one.\n        if self.push_in_progress is not None and not self.push_in_progress.is_done:\n            return\n\n        output_dir = self.args.output_dir\n        # To avoid a new synchronization of all model weights, we just copy the file from the checkpoint folder\n        modeling_files = [CONFIG_NAME, 
WEIGHTS_NAME, SAFE_WEIGHTS_NAME]\n        if is_peft_available():\n            modeling_files.extend([ADAPTER_CONFIG_NAME, ADAPTER_WEIGHTS_NAME, ADAPTER_SAFE_WEIGHTS_NAME])\n        for modeling_file in modeling_files:\n            if os.path.isfile(os.path.join(checkpoint_folder, modeling_file)):\n                shutil.copy(os.path.join(checkpoint_folder, modeling_file), os.path.join(output_dir, modeling_file))\n        # Saving the tokenizer is fast and we don't know how many files it may have spawned, so we resave it to be sure.\n        if self.tokenizer is not None:\n            self.tokenizer.save_pretrained(output_dir)\n        # Same for the training arguments\n        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))\n\n        try:\n            if self.args.hub_strategy == HubStrategy.CHECKPOINT:\n                # Temporarily move the checkpoint just saved for the push\n                tmp_checkpoint = os.path.join(output_dir, \"last-checkpoint\")\n                # We have to remove the \"last-checkpoint\" dir if it exists, otherwise the checkpoint is moved as a\n                # subfolder.\n                if os.path.isdir(tmp_checkpoint):\n                    shutil.rmtree(tmp_checkpoint)\n                shutil.move(checkpoint_folder, tmp_checkpoint)\n\n            if self.args.save_strategy == IntervalStrategy.STEPS:\n                commit_message = f\"Training in progress, step {self.state.global_step}\"\n            else:\n                commit_message = f\"Training in progress, epoch {int(self.state.epoch)}\"\n            push_work = self.repo.push_to_hub(commit_message=commit_message, blocking=False, auto_lfs_prune=True)\n            # Return type of `Repository.push_to_hub` is either None or a tuple.\n            if push_work is not None:\n                self.push_in_progress = push_work[1]\n        except Exception as e:\n            logger.error(f\"Error when pushing to hub: {e}\")\n        finally:\n            if 
self.args.hub_strategy == HubStrategy.CHECKPOINT:\n                # Move back the checkpoint to its place\n                shutil.move(tmp_checkpoint, checkpoint_folder)\n\n    def push_to_hub(self, commit_message: Optional[str] = \"End of training\", blocking: bool = True, **kwargs) -> str:\n        \"\"\"\n        Upload *self.model* and *self.tokenizer* to the 🤗 model hub on the repo *self.args.hub_model_id*.\n\n        Parameters:\n            commit_message (`str`, *optional*, defaults to `\"End of training\"`):\n                Message to commit while pushing.\n            blocking (`bool`, *optional*, defaults to `True`):\n                Whether the function should return only when the `git push` has finished.\n            kwargs (`Dict[str, Any]`, *optional*):\n                Additional keyword arguments passed along to [`~Trainer.create_model_card`].\n\n        Returns:\n            The URL of the commit of your model in the given repository if `blocking=True`, or a tuple with the URL of\n            the commit and an object to track the progress of the commit if `blocking=False`.\n        \"\"\"\n        # If a user manually calls `push_to_hub` with `self.args.push_to_hub = False`, we try to create the repo but\n        # it might fail.\n        if not hasattr(self, \"repo\"):\n            self.init_git_repo()\n\n        model_name = kwargs.pop(\"model_name\", None)\n        if model_name is None and self.args.should_save:\n            if self.args.hub_model_id is None:\n                model_name = Path(self.args.output_dir).name\n            else:\n                model_name = self.args.hub_model_id.split(\"/\")[-1]\n\n        # Needs to be executed on all processes for TPU training, but will only save on the process determined by\n        # self.args.should_save.\n        self.save_model(_internal_call=True)\n\n        # Only push from one node.\n        if not self.is_world_process_zero():\n            return\n\n        # Cancel any async push in 
progress if blocking=True. The commits will all be pushed together.\n        if blocking and self.push_in_progress is not None and not self.push_in_progress.is_done:\n            self.push_in_progress._process.kill()\n            self.push_in_progress = None\n\n        git_head_commit_url = self.repo.push_to_hub(\n            commit_message=commit_message, blocking=blocking, auto_lfs_prune=True\n        )\n        # Push the model card separately so that it is independent from the rest of the model\n        if self.args.should_save:\n            self.create_model_card(model_name=model_name, **kwargs)\n            try:\n                self.repo.push_to_hub(\n                    commit_message=\"update model card README.md\", blocking=blocking, auto_lfs_prune=True\n                )\n            except EnvironmentError as exc:\n                logger.error(f\"Error pushing update to the model card. Please read logs and retry.\\n${exc}\")\n\n        return git_head_commit_url\n\n    #\n    # Deprecated code\n    #\n\n    def prediction_loop(\n        self,\n        dataloader: DataLoader,\n        description: str,\n        prediction_loss_only: Optional[bool] = None,\n        ignore_keys: Optional[List[str]] = None,\n        metric_key_prefix: str = \"eval\",\n    ) -> EvalLoopOutput:\n        \"\"\"\n        Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.\n\n        Works both with or without labels.\n        \"\"\"\n        args = self.args\n\n        if not has_length(dataloader):\n            raise ValueError(\"dataloader must implement a working __len__\")\n\n        prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only\n\n        # if eval is called w/o train, handle model prep here\n        if self.is_deepspeed_enabled and self.deepspeed is None:\n            _, _ = deepspeed_init(self, num_training_steps=0, inference=True)\n\n        model = self._wrap_model(self.model, 
training=False, dataloader=dataloader)\n\n        if len(self.accelerator._models) == 0 and model is self.model:\n            model = (\n                self.accelerator.prepare(model)\n                if self.is_deepspeed_enabled\n                else self.accelerator.prepare_model(model, evaluation_mode=True)\n            )\n\n            if self.is_fsdp_enabled:\n                self.model = model\n\n            # for the rest of this function `model` is the outside model, whether it was wrapped or not\n            if model is not self.model:\n                self.model_wrapped = model\n\n            # backward compatibility\n            if self.is_deepspeed_enabled:\n                self.deepspeed = self.model_wrapped\n\n        # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called\n        # while ``train`` is running, cast it to the right dtype first and then put on device\n        if not self.is_in_train:\n            if args.fp16_full_eval:\n                model = model.to(dtype=torch.float16, device=args.device)\n            elif args.bf16_full_eval:\n                model = model.to(dtype=torch.bfloat16, device=args.device)\n\n        batch_size = dataloader.batch_size\n        num_examples = self.num_examples(dataloader)\n        logger.info(f\"***** Running {description} *****\")\n        logger.info(f\"  Num examples = {num_examples}\")\n        logger.info(f\"  Batch size = {batch_size}\")\n        losses_host: torch.Tensor = None\n        preds_host: Union[torch.Tensor, List[torch.Tensor]] = None\n        labels_host: Union[torch.Tensor, List[torch.Tensor]] = None\n        inputs_host: Union[torch.Tensor, List[torch.Tensor]] = None\n\n        world_size = max(1, args.world_size)\n\n        eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size)\n        if not prediction_loss_only:\n            # The actual number of eval_sample can be greater than num_examples in 
distributed settings (when we pass\n            # a batch size to the sampler)\n            make_multiple_of = None\n            if hasattr(dataloader, \"sampler\") and isinstance(dataloader.sampler, SequentialDistributedSampler):\n                make_multiple_of = dataloader.sampler.batch_size\n            preds_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)\n            labels_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)\n            inputs_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)\n\n        model.eval()\n\n        if args.past_index >= 0:\n            self._past = None\n\n        self.callback_handler.eval_dataloader = dataloader\n\n        for step, inputs in enumerate(dataloader):\n            loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)\n            inputs_decode = self._prepare_input(inputs[\"input_ids\"]) if args.include_inputs_for_metrics else None\n\n            if loss is not None:\n                losses = loss.repeat(batch_size)\n                losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)\n            if logits is not None:\n                preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)\n            if labels is not None:\n                labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)\n            if inputs_decode is not None:\n                inputs_host = (\n                    inputs_decode\n                    if inputs_host is None\n                    else nested_concat(inputs_host, inputs_decode, padding_index=-100)\n                )\n            self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)\n\n            # Gather all tensors and 
put them back on the CPU if we have done enough accumulation steps.\n            if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:\n                eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, \"eval_losses\"))\n                if not prediction_loss_only:\n                    preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, \"eval_preds\"))\n                    labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, \"eval_label_ids\"))\n                    inputs_gatherer.add_arrays(self._gather_and_numpify(inputs_host, \"eval_inputs_ids\"))\n\n                # Set back to None to begin a new accumulation\n                losses_host, preds_host, labels_host, inputs_host = None, None, None, None\n\n        if args.past_index and hasattr(self, \"_past\"):\n            # Clean the state at the end of the evaluation loop\n            delattr(self, \"_past\")\n\n        # Gather all remaining tensors and put them back on the CPU\n        eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, \"eval_losses\"))\n        if not prediction_loss_only:\n            preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, \"eval_preds\"))\n            labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, \"eval_label_ids\"))\n            inputs_gatherer.add_arrays(self._gather_and_numpify(inputs_host, \"eval_inputs_ids\"))\n\n        eval_loss = eval_losses_gatherer.finalize()\n        preds = preds_gatherer.finalize() if not prediction_loss_only else None\n        label_ids = labels_gatherer.finalize() if not prediction_loss_only else None\n        inputs_ids = inputs_gatherer.finalize() if not prediction_loss_only else None\n\n        if self.compute_metrics is not None and preds is not None and label_ids is not None:\n            if args.include_inputs_for_metrics:\n                metrics = self.compute_metrics(\n                    
EvalPrediction(predictions=preds, label_ids=label_ids, inputs=inputs_ids)\n                )\n            else:\n                metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))\n        else:\n            metrics = {}\n\n        # To be JSON-serializable, we need to remove numpy types or zero-d tensors\n        metrics = denumpify_detensorize(metrics)\n\n        if eval_loss is not None:\n            metrics[f\"{metric_key_prefix}_loss\"] = eval_loss.mean().item()\n\n        # Prefix all keys with metric_key_prefix + '_'\n        for key in list(metrics.keys()):\n            if not key.startswith(f\"{metric_key_prefix}_\"):\n                metrics[f\"{metric_key_prefix}_{key}\"] = metrics.pop(key)\n\n        return EvalLoopOutput(predictions=preds, label_ids=label_ids, metrics=metrics, num_samples=num_examples)\n\n    def _gather_and_numpify(self, tensors, name):\n        \"\"\"\n        Gather value of `tensors` (tensor or list/tuple of nested tensors) and convert them to numpy before\n        concatenating them to `gathered`\n        \"\"\"\n        if tensors is None:\n            return\n        if is_torch_tpu_available():\n            tensors = nested_xla_mesh_reduce(tensors, name)\n        elif is_sagemaker_mp_enabled():\n            tensors = smp_gather(tensors)\n        elif self.args.parallel_mode == ParallelMode.DISTRIBUTED:\n            tensors = distributed_concat(tensors)\n\n        return nested_numpify(tensors)\n\n    def _add_sm_patterns_to_gitignore(self) -> None:\n        \"\"\"Add SageMaker Checkpointing patterns to .gitignore file.\"\"\"\n        # Make sure we only do this on the main process\n        if not self.is_world_process_zero():\n            return\n\n        patterns = [\"*.sagemaker-uploading\", \"*.sagemaker-uploaded\"]\n\n        # Get current .gitignore content\n        if os.path.exists(os.path.join(self.repo.local_dir, \".gitignore\")):\n            with 
open(os.path.join(self.repo.local_dir, \".gitignore\"), \"r\") as f:\n                current_content = f.read()\n        else:\n            current_content = \"\"\n\n        # Add the patterns to .gitignore\n        content = current_content\n        for pattern in patterns:\n            if pattern not in content:\n                if content.endswith(\"\\n\"):\n                    content += pattern\n                else:\n                    content += f\"\\n{pattern}\"\n\n        # Write the .gitignore file if it has changed\n        if content != current_content:\n            with open(os.path.join(self.repo.local_dir, \".gitignore\"), \"w\") as f:\n                logger.debug(f\"Writing .gitignore file. Content: {content}\")\n                f.write(content)\n\n        self.repo.git_add(\".gitignore\")\n\n        # avoid race condition with git status\n        time.sleep(0.5)\n\n        if not self.repo.is_repo_clean():\n            self.repo.git_commit(\"Add *.sagemaker patterns to .gitignore.\")\n            self.repo.git_push()\n\n    def create_accelerator_and_postprocess(self):\n        grad_acc_kwargs = {\"num_steps\": self.args.gradient_accumulation_steps}\n        if version.parse(accelerate_version) > version.parse(\"0.20.3\"):\n            grad_acc_kwargs[\"sync_with_dataloader\"] = False\n        gradient_accumulation_plugin = GradientAccumulationPlugin(**grad_acc_kwargs)\n\n        # create accelerator object\n        self.accelerator = Accelerator(\n            deepspeed_plugin=self.args.deepspeed_plugin, gradient_accumulation_plugin=gradient_accumulation_plugin\n        )\n\n        # deepspeed and accelerate flags covering both trainer args and accelerate launcher\n        self.is_deepspeed_enabled = getattr(self.accelerator.state, \"deepspeed_plugin\", None) is not None\n        self.is_fsdp_enabled = getattr(self.accelerator.state, \"fsdp_plugin\", None) is not None\n\n        # post accelerator creation setup\n        if 
self.is_fsdp_enabled:\n            fsdp_plugin = self.accelerator.state.fsdp_plugin\n            fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get(\n                \"limit_all_gathers\", fsdp_plugin.limit_all_gathers\n            )\n            fsdp_plugin.use_orig_params = self.args.fsdp_config.get(\"use_orig_params\", fsdp_plugin.use_orig_params)\n\n        if self.is_deepspeed_enabled:\n            if getattr(self.args, \"hf_deepspeed_config\", None) is None:\n                from transformers.deepspeed import HfTrainerDeepSpeedConfig\n\n                ds_plugin = self.accelerator.state.deepspeed_plugin\n\n                ds_plugin.hf_ds_config = HfTrainerDeepSpeedConfig(ds_plugin.hf_ds_config.config)\n                ds_plugin.deepspeed_config = ds_plugin.hf_ds_config.config\n                ds_plugin.hf_ds_config.trainer_config_process(self.args)\n\n\n\nclass LLaVATrainer(TrainerLLavaGD):\n\n    def _save_checkpoint(self, model, trial, metrics=None):\n        # if getattr(self.args, 'tune_mm_mlp_adapter', False):\n        #     from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR\n        #     checkpoint_folder = f\"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}\"\n        #\n        #     run_dir = self._get_output_dir(trial=trial)\n        #     output_dir = os.path.join(run_dir, checkpoint_folder)\n        #\n        #     # Only save Adapter\n        #     keys_to_match = ['mm_projector']\n        #     if getattr(self.args, \"use_im_start_end\", False) or getattr(self.args, \"new_tokens\", False):\n        #         keys_to_match.extend(['embed_tokens', 'embed_in','lm_head'])\n        #     # import pdb; pdb.set_trace()\n        #     weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match)\n        #\n        #     if self.args.local_rank == 0 or self.args.local_rank == -1:\n        #         self.model.config.save_pretrained(output_dir)\n        #         torch.save(weight_to_save, 
os.path.join(output_dir, f'mm_projector.bin'))\n        # else:\n        super(LLaVATrainer, self)._save_checkpoint(model, trial, metrics)\n\n    def _save(self, output_dir: Optional[str] = None, state_dict=None):\n        # if getattr(self.args, 'tune_mm_mlp_adapter', False):\n        #     pass\n        # else:\n        super(LLaVATrainer, self)._save(output_dir, state_dict)\n"
  },
  {
    "path": "llava/train/train.py",
    "content": "# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:\n# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:\n#    Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li\n#\n#    Licensed under the Apache License, Version 2.0 (the \"License\");\n#    you may not use this file except in compliance with the License.\n#    You may obtain a copy of the License at\n#\n#        http://www.apache.org/licenses/LICENSE-2.0\n#\n#    Unless required by applicable law or agreed to in writing, software\n#    distributed under the License is distributed on an \"AS IS\" BASIS,\n#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#    See the License for the specific language governing permissions and\n#    limitations under the License.\n\nimport os\nimport copy\nfrom dataclasses import dataclass, field\nimport json\nimport logging\nimport pathlib\nfrom typing import Dict, Optional, Sequence, List\n\nimport torch\n\nimport transformers\n\nfrom llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN\nfrom torch.utils.data import Dataset\nfrom llava.train.llava_trainer import LLaVATrainer\n\nfrom llava import conversation as conversation_lib\nfrom llava.model import *\nfrom llava.mm_utils import tokenizer_image_token\n\nfrom PIL import Image\n\n\nlocal_rank = None\n\n\ndef rank0_print(*args):\n    if local_rank == 0:\n        print(*args)\n\n\n@dataclass\nclass ModelArguments:\n    model_name_or_path: Optional[str] = field(default=\"facebook/opt-125m\")\n    version: Optional[str] = field(default=\"v0\")\n    freeze_backbone: bool = field(default=False)\n    tune_mm_mlp_adapter: bool = field(default=False)\n    vision_tower: Optional[str] = field(default=None)\n    mm_vision_select_layer: Optional[int] = field(default=-1)   # default to the last layer\n    pretrain_mm_mlp_adapter: Optional[str] = 
field(default=None)\n    mm_use_im_start_end: bool = field(default=False)\n    mm_use_im_patch_token: bool = field(default=True)\n    mm_vision_select_feature: Optional[str] = field(default=\"patch\")\n\n\n@dataclass\nclass DataArguments:\n    data_path: str = field(default=None,\n                           metadata={\"help\": \"Path to the training data.\"})\n    lazy_preprocess: bool = False\n    is_multimodal: bool = False\n    image_folder: Optional[str] = field(default=None)\n    image_aspect_ratio: str = 'square'\n    image_grid_pinpoints: Optional[str] = field(default=None)\n\n\n@dataclass\nclass TrainingArguments(transformers.TrainingArguments):\n    cache_dir: Optional[str] = field(default=None)\n    optim: str = field(default=\"adamw_torch\")\n    remove_unused_columns: bool = field(default=False)\n    freeze_mm_mlp_adapter: bool = field(default=False)\n    mpt_attn_impl: Optional[str] = field(default=\"triton\")\n    model_max_length: int = field(\n        default=512,\n        metadata={\n            \"help\":\n            \"Maximum sequence length. Sequences will be right padded (and possibly truncated).\"\n        },\n    )\n    double_quant: bool = field(\n        default=True,\n        metadata={\"help\": \"Compress the quantization statistics through double quantization.\"}\n    )\n    quant_type: str = field(\n        default=\"nf4\",\n        metadata={\"help\": \"Quantization data type to use. 
Should be one of `fp4` or `nf4`.\"}\n    )\n    bits: int = field(\n        default=16,\n        metadata={\"help\": \"How many bits to use.\"}\n    )\n    lora_enable: bool = False\n    lora_r: int = 64\n    lora_alpha: int = 16\n    lora_dropout: float = 0.05\n    lora_weight_path: str = \"\"\n    lora_bias: str = \"none\"\n    dbg: bool = False\n\n\ndef maybe_zero_3(param, ignore_status=False, name=None):\n    from deepspeed import zero\n    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus\n    if hasattr(param, \"ds_id\"):\n        if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:\n            if not ignore_status:\n                logging.warning(f\"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}\")\n        with zero.GatheredParameters([param]):\n            param = param.data.detach().cpu().clone()\n    else:\n        param = param.detach().cpu().clone()\n    return param\n\n\n# Borrowed from peft.utils.get_peft_model_state_dict\ndef get_peft_state_maybe_zero_3(named_params, bias):\n    if bias == \"none\":\n        to_return = {k: t for k, t in named_params if \"lora_\" in k}\n    elif bias == \"all\":\n        to_return = {k: t for k, t in named_params if \"lora_\" in k or \"bias\" in k}\n    elif bias == \"lora_only\":\n        to_return = {}\n        maybe_lora_bias = {}\n        lora_bias_names = set()\n        for k, t in named_params:\n            if \"lora_\" in k:\n                to_return[k] = t\n                bias_name = k.split(\"lora_\")[0] + \"bias\"\n                lora_bias_names.add(bias_name)\n            elif \"bias\" in k:\n                maybe_lora_bias[k] = t\n        for k, t in maybe_lora_bias:\n            if bias_name in lora_bias_names:\n                to_return[bias_name] = t\n    else:\n        raise NotImplementedError\n    to_return = {k: maybe_zero_3(v, name=k) for k, v in to_return.items()}\n    return to_return\n\n\ndef 
get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):\n    to_return = {k: t for k, t in named_params if \"lora_\" not in k}\n    if require_grad_only:\n        to_return = {k: t for k, t in to_return.items() if t.requires_grad}\n    to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}\n    return to_return\n\n\ndef get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):\n    to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}\n    to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}\n    return to_return\n\n\ndef find_all_linear_names(model):\n    cls = torch.nn.Linear\n    lora_module_names = set()\n    for name, module in model.named_modules():\n        if isinstance(module, cls):\n            names = name.split('.')\n            lora_module_names.add(names[0] if len(names) == 1 else names[-1])\n\n\n    if 'lm_head' in lora_module_names: # needed for 16-bit\n        lora_module_names.remove('lm_head')\n    return list(lora_module_names)\n\n\ndef safe_save_model_for_hf_trainer(trainer: transformers.Trainer,\n                                   output_dir: str):\n    \"\"\"Collects the state dict and dump to disk.\"\"\"\n\n    if trainer.deepspeed:\n        torch.cuda.synchronize()\n        trainer.save_model(output_dir)\n        return\n\n    state_dict = trainer.model.state_dict()\n    if trainer.args.should_save:\n        cpu_state_dict = {\n            key: value.cpu()\n            for key, value in state_dict.items()\n        }\n        del state_dict\n        trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa\n\n\ndef smart_tokenizer_and_embedding_resize(\n    special_tokens_dict: Dict,\n    tokenizer: transformers.PreTrainedTokenizer,\n    model: transformers.PreTrainedModel,\n):\n    \"\"\"Resize tokenizer and embedding.\n\n    Note: This is the unoptimized version that may make your embedding size not 
be divisible by 64.\n    \"\"\"\n    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)\n    model.resize_token_embeddings(len(tokenizer))\n\n    if num_new_tokens > 0:\n        input_embeddings = model.get_input_embeddings().weight.data\n        output_embeddings = model.get_output_embeddings().weight.data\n\n        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(\n            dim=0, keepdim=True)\n        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(\n            dim=0, keepdim=True)\n\n        input_embeddings[-num_new_tokens:] = input_embeddings_avg\n        output_embeddings[-num_new_tokens:] = output_embeddings_avg\n\n\ndef _tokenize_fn(strings: Sequence[str],\n                 tokenizer: transformers.PreTrainedTokenizer) -> Dict:\n    \"\"\"Tokenize a list of strings.\"\"\"\n    tokenized_list = [\n        tokenizer(\n            text,\n            return_tensors=\"pt\",\n            padding=\"longest\",\n            max_length=tokenizer.model_max_length,\n            truncation=True,\n        ) for text in strings\n    ]\n    input_ids = labels = [\n        tokenized.input_ids[0] for tokenized in tokenized_list\n    ]\n    input_ids_lens = labels_lens = [\n        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()\n        for tokenized in tokenized_list\n    ]\n    return dict(\n        input_ids=input_ids,\n        labels=labels,\n        input_ids_lens=input_ids_lens,\n        labels_lens=labels_lens,\n    )\n\n\ndef _mask_targets(target, tokenized_lens, speakers):\n    # cur_idx = 0\n    cur_idx = tokenized_lens[0]\n    tokenized_lens = tokenized_lens[1:]\n    target[:cur_idx] = IGNORE_INDEX\n    for tokenized_len, speaker in zip(tokenized_lens, speakers):\n        if speaker == \"human\":\n            target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX\n        cur_idx += tokenized_len\n\n\ndef _add_speaker_and_signal(header, source, get_conversation=True):\n    \"\"\"Add speaker and 
start/end signal on each round.\"\"\"\n    BEGIN_SIGNAL = \"### \"\n    END_SIGNAL = \"\\n\"\n    conversation = header\n    for sentence in source:\n        from_str = sentence[\"from\"]\n        if from_str.lower() == \"human\":\n            from_str = conversation_lib.default_conversation.roles[0]\n        elif from_str.lower() == \"gpt\":\n            from_str = conversation_lib.default_conversation.roles[1]\n        else:\n            from_str = 'unknown'\n        sentence[\"value\"] = (BEGIN_SIGNAL + from_str + \": \" +\n                             sentence[\"value\"] + END_SIGNAL)\n        if get_conversation:\n            conversation += sentence[\"value\"]\n    conversation += BEGIN_SIGNAL\n    return conversation\n\n\ndef preprocess_multimodal(\n    sources: Sequence[str],\n    data_args: DataArguments\n) -> Dict:\n    is_multimodal = data_args.is_multimodal\n    if not is_multimodal:\n        return sources\n\n    for source in sources:\n        for sentence in source:\n            if DEFAULT_IMAGE_TOKEN in sentence['value']:\n                sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip()\n                sentence['value'] = DEFAULT_IMAGE_TOKEN + '\\n' + sentence['value']\n                sentence['value'] = sentence['value'].strip()\n                if \"mmtag\" in conversation_lib.default_conversation.version:\n                    sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '<Image>' + DEFAULT_IMAGE_TOKEN + '</Image>')\n            replace_token = DEFAULT_IMAGE_TOKEN\n            if data_args.mm_use_im_start_end:\n                replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN\n            sentence[\"value\"] = sentence[\"value\"].replace(DEFAULT_IMAGE_TOKEN, replace_token)\n\n    return sources\n\n\ndef preprocess_llama_2(\n    sources,\n    tokenizer: transformers.PreTrainedTokenizer,\n    has_image: bool = False\n) -> Dict:\n    conv = 
conversation_lib.default_conversation.copy()\n    roles = {\"human\": conv.roles[0], \"gpt\": conv.roles[1]}\n\n    # Apply prompt templates\n    conversations = []\n    for i, source in enumerate(sources):\n        if roles[source[0][\"from\"]] != conv.roles[0]:\n            # Skip the first one if it is not from human\n            source = source[1:]\n\n        conv.messages = []\n        for j, sentence in enumerate(source):\n            role = roles[sentence[\"from\"]]\n            assert role == conv.roles[j % 2], f\"{i}\"\n            conv.append_message(role, sentence[\"value\"])\n        conversations.append(conv.get_prompt())\n\n    # Tokenize conversations\n\n    if has_image:\n        input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)\n    else:\n        input_ids = tokenizer(\n            conversations,\n            return_tensors=\"pt\",\n            padding=\"longest\",\n            max_length=tokenizer.model_max_length,\n            truncation=True,\n        ).input_ids\n\n    targets = input_ids.clone()\n\n    assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2\n\n    # Mask targets\n    sep = \"[/INST] \"\n    for conversation, target in zip(conversations, targets):\n        total_len = int(target.ne(tokenizer.pad_token_id).sum())\n\n        rounds = conversation.split(conv.sep2)\n        cur_len = 1\n        target[:cur_len] = IGNORE_INDEX\n        for i, rou in enumerate(rounds):\n            if rou == \"\":\n                break\n\n            parts = rou.split(sep)\n            if len(parts) != 2:\n                break\n            parts[0] += sep\n\n            if has_image:\n                round_len = len(tokenizer_image_token(rou, tokenizer))\n                instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2\n            else:\n                round_len = len(tokenizer(rou).input_ids)\n                instruction_len = 
len(tokenizer(parts[0]).input_ids) - 2\n\n            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX\n\n            cur_len += round_len\n        target[cur_len:] = IGNORE_INDEX\n\n        if cur_len < tokenizer.model_max_length:\n            if cur_len != total_len:\n                target[:] = IGNORE_INDEX\n                print(\n                    f\"WARNING: tokenization mismatch: {cur_len} vs. {total_len}.\"\n                    f\" (ignored)\"\n                )\n\n    return dict(\n        input_ids=input_ids,\n        labels=targets,\n    )\n\n\ndef preprocess_v1(\n    sources,\n    tokenizer: transformers.PreTrainedTokenizer,\n    has_image: bool = False\n) -> Dict:\n    conv = conversation_lib.default_conversation.copy()\n    roles = {\"human\": conv.roles[0], \"gpt\": conv.roles[1]}\n\n    # Apply prompt templates\n    conversations = []\n    for i, source in enumerate(sources):\n        if roles[source[0][\"from\"]] != conv.roles[0]:\n            # Skip the first one if it is not from human\n            source = source[1:]\n\n        conv.messages = []\n        for j, sentence in enumerate(source):\n            role = roles[sentence[\"from\"]]\n            assert role == conv.roles[j % 2], f\"{i}\"\n            conv.append_message(role, sentence[\"value\"])\n        conversations.append(conv.get_prompt())\n\n    # Tokenize conversations\n\n    if has_image:\n        input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)\n    else:\n        input_ids = tokenizer(\n            conversations,\n            return_tensors=\"pt\",\n            padding=\"longest\",\n            max_length=tokenizer.model_max_length,\n            truncation=True,\n        ).input_ids\n\n    targets = input_ids.clone()\n\n    assert conv.sep_style == conversation_lib.SeparatorStyle.TWO\n\n    # Mask targets\n    sep = conv.sep + conv.roles[1] + \": \"\n    for conversation, target in 
zip(conversations, targets):\n        total_len = int(target.ne(tokenizer.pad_token_id).sum())\n\n        rounds = conversation.split(conv.sep2)\n        cur_len = 1\n        target[:cur_len] = IGNORE_INDEX\n        for i, rou in enumerate(rounds):\n            if rou == \"\":\n                break\n\n            parts = rou.split(sep)\n            if len(parts) != 2:\n                break\n            parts[0] += sep\n\n            if has_image:\n                round_len = len(tokenizer_image_token(rou, tokenizer))\n                instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2\n            else:\n                round_len = len(tokenizer(rou).input_ids)\n                instruction_len = len(tokenizer(parts[0]).input_ids) - 2\n\n            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX\n\n            cur_len += round_len\n        target[cur_len:] = IGNORE_INDEX\n\n        if cur_len < tokenizer.model_max_length:\n            if cur_len != total_len:\n                target[:] = IGNORE_INDEX\n                print(\n                    f\"WARNING: tokenization mismatch: {cur_len} vs. 
{total_len}.\"\n                    f\" (ignored)\"\n                )\n\n    return dict(\n        input_ids=input_ids,\n        labels=targets,\n    )\n\n\ndef preprocess_mpt(\n    sources,\n    tokenizer: transformers.PreTrainedTokenizer,\n) -> Dict:\n    conv = conversation_lib.default_conversation.copy()\n    roles = {\"human\": conv.roles[0], \"gpt\": conv.roles[1]}\n\n    # Apply prompt templates\n    conversations = []\n    for i, source in enumerate(sources):\n        if roles[source[0][\"from\"]] != conv.roles[0]:\n            # Skip the first one if it is not from human\n            source = source[1:]\n\n        conv.messages = []\n        for j, sentence in enumerate(source):\n            role = roles[sentence[\"from\"]]\n            assert role == conv.roles[j % 2], f\"{i}\"\n            conv.append_message(role, sentence[\"value\"])\n        conversations.append(conv.get_prompt())\n\n    # Tokenize conversations\n    input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)\n    targets = input_ids.clone()\n    assert conv.sep_style == conversation_lib.SeparatorStyle.MPT\n\n    # Mask targets\n    sep = conv.sep + conv.roles[1]\n    for conversation, target in zip(conversations, targets):\n        total_len = int(target.ne(tokenizer.pad_token_id).sum())\n\n        rounds = conversation.split(conv.sep)\n        re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt\n        for conv_idx in range(3, len(rounds), 2):\n            re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx+2]))    # user + gpt\n        cur_len = 0\n        target[:cur_len] = IGNORE_INDEX\n        for i, rou in enumerate(re_rounds):\n            if rou == \"\":\n                break\n\n            parts = rou.split(sep)\n            if len(parts) != 2:\n                break\n            parts[0] += sep\n            round_len = len(tokenizer_image_token(rou, tokenizer)) + 
len(tokenizer_image_token(conv.sep, tokenizer))\n            instruction_len = len(tokenizer_image_token(parts[0], tokenizer))\n            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX\n\n            cur_len += round_len\n        target[cur_len:] = IGNORE_INDEX\n\n        if cur_len < tokenizer.model_max_length:\n            if cur_len != total_len:\n                target[:] = IGNORE_INDEX\n                print(\n                    f\"WARNING: tokenization mismatch: {cur_len} vs. {total_len}.\"\n                    f\" (ignored)\"\n                )\n\n    return dict(\n        input_ids=input_ids,\n        labels=targets,\n    )\n\n\ndef preprocess_plain(\n    sources: Sequence[str],\n    tokenizer: transformers.PreTrainedTokenizer,\n) -> Dict:\n    # add end signal and concatenate together\n    conversations = []\n    for source in sources:\n        assert len(source) == 2\n        assert DEFAULT_IMAGE_TOKEN in source[0]['value']\n        source[0]['value'] = DEFAULT_IMAGE_TOKEN\n        conversation = source[0]['value'] + source[1]['value'] + conversation_lib.default_conversation.sep\n        conversations.append(conversation)\n    # tokenize conversations\n    input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]\n    targets = copy.deepcopy(input_ids)\n    for target, source in zip(targets, sources):\n        tokenized_len = len(tokenizer_image_token(source[0]['value'], tokenizer))\n        target[:tokenized_len] = IGNORE_INDEX\n\n    return dict(input_ids=input_ids, labels=targets)\n\n\ndef preprocess(\n    sources: Sequence[str],\n    tokenizer: transformers.PreTrainedTokenizer,\n    has_image: bool = False\n) -> Dict:\n    \"\"\"\n    Given a list of sources, each is a conversation list. This transform:\n    1. Add signal '### ' at the beginning each sentence, with end signal '\\n';\n    2. Concatenate conversations together;\n    3. Tokenize the concatenated conversation;\n    4. 
Make a deepcopy as the target. Mask human words with IGNORE_INDEX.\n    \"\"\"\n    if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN:\n        return preprocess_plain(sources, tokenizer)\n    if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.LLAMA_2:\n        return preprocess_llama_2(sources, tokenizer, has_image=has_image)\n    if conversation_lib.default_conversation.version.startswith(\"v1\"):\n        return preprocess_v1(sources, tokenizer, has_image=has_image)\n    if conversation_lib.default_conversation.version == \"mpt\":\n        return preprocess_mpt(sources, tokenizer)\n    # add end signal and concatenate together\n    conversations = []\n    for source in sources:\n        header = f\"{conversation_lib.default_conversation.system}\\n\\n\"\n        conversation = _add_speaker_and_signal(header, source)\n        conversations.append(conversation)\n    # tokenize conversations\n    def get_tokenize_len(prompts):\n        return [len(tokenizer_image_token(prompt, tokenizer)) for prompt in prompts]\n\n    if has_image:\n        input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]\n    else:\n        conversations_tokenized = _tokenize_fn(conversations, tokenizer)\n        input_ids = conversations_tokenized[\"input_ids\"]\n\n    targets = copy.deepcopy(input_ids)\n    for target, source in zip(targets, sources):\n        if has_image:\n            tokenized_lens = get_tokenize_len([header] + [s[\"value\"] for s in source])\n        else:\n            tokenized_lens = _tokenize_fn([header] + [s[\"value\"] for s in source], tokenizer)[\"input_ids_lens\"]\n        speakers = [sentence[\"from\"] for sentence in source]\n        _mask_targets(target, tokenized_lens, speakers)\n\n    return dict(input_ids=input_ids, labels=targets)\n\n\nclass LazySupervisedDataset(Dataset):\n    \"\"\"Dataset for supervised fine-tuning.\"\"\"\n\n    
def __init__(self, data_path: str,\n                 tokenizer: transformers.PreTrainedTokenizer,\n                 data_args: DataArguments):\n        super(LazySupervisedDataset, self).__init__()\n        list_data_dict = json.load(open(data_path, \"r\"))\n\n        rank0_print(\"Formatting inputs...Skip in lazy mode\")\n        self.tokenizer = tokenizer\n        self.list_data_dict = list_data_dict\n        self.data_args = data_args\n\n    def __len__(self):\n        return len(self.list_data_dict)\n\n    def __getitem__(self, i) -> Dict[str, torch.Tensor]:\n        try:\n            sources = self.list_data_dict[i]\n            # print(1,'\\n')\n\n            if isinstance(i, int):\n                sources = [sources]\n            assert len(sources) == 1, \"Don't know why it is wrapped to a list\"  # FIXME\n            if 'image' in sources[0]:\n                # print(2)\n                # print(2, '\\n')\n\n                image_file = self.list_data_dict[i]['image']\n                image_folder = self.data_args.image_folder\n                processor = self.data_args.image_processor\n                image = Image.open(os.path.join(image_folder, image_file)).convert('RGB')\n                if self.data_args.image_aspect_ratio == 'pad':\n                    def expand2square(pil_img, background_color):\n                        width, height = pil_img.size\n                        if width == height:\n                            return pil_img\n                        elif width > height:\n                            result = Image.new(pil_img.mode, (width, width), background_color)\n                            result.paste(pil_img, (0, (width - height) // 2))\n                            return result\n                        else:\n                            result = Image.new(pil_img.mode, (height, height), background_color)\n                            result.paste(pil_img, ((height - width) // 2, 0))\n                            return result\n        
            image = expand2square(image, tuple(int(x*255) for x in processor.image_mean))\n                    image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]\n                else:\n                    # print(3, '\\n')\n\n                    image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]\n                sources = preprocess_multimodal(\n                    copy.deepcopy([e[\"conversations\"] for e in sources]),\n                    self.data_args)\n            else:\n                sources = copy.deepcopy([e[\"conversations\"] for e in sources])\n            data_dict = preprocess(\n                sources,\n                self.tokenizer,\n                has_image=('image' in self.list_data_dict[i]))\n            # print(4,'\\n')\n\n            if isinstance(i, int):\n                data_dict = dict(input_ids=data_dict[\"input_ids\"][0],\n                                 labels=data_dict[\"labels\"][0])\n\n            # image exist in the data\n            if 'image' in self.list_data_dict[i]:\n                data_dict['image'] = image\n            elif self.data_args.is_multimodal:\n                # image does not exist in the data, but the model is multimodal\n                crop_size = self.data_args.image_processor.crop_size\n                data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width'])\n            # print(6,'\\n')\n\n            return data_dict\n        except Exception as e:\n            print(self.list_data_dict[i]['image'], \"failed\")\n            return self.__getitem__(i + 1)\n\n\n@dataclass\nclass DataCollatorForSupervisedDataset(object):\n    \"\"\"Collate examples for supervised fine-tuning.\"\"\"\n\n    tokenizer: transformers.PreTrainedTokenizer\n\n    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:\n        input_ids, labels = tuple([instance[key] for instance in instances]\n                                  for key in 
(\"input_ids\", \"labels\"))\n        input_ids = torch.nn.utils.rnn.pad_sequence(\n            input_ids,\n            batch_first=True,\n            padding_value=self.tokenizer.pad_token_id)\n        labels = torch.nn.utils.rnn.pad_sequence(labels,\n                                                 batch_first=True,\n                                                 padding_value=IGNORE_INDEX)\n        input_ids = input_ids[:, :self.tokenizer.model_max_length]\n        labels = labels[:, :self.tokenizer.model_max_length]\n        batch = dict(\n            input_ids=input_ids,\n            labels=labels,\n            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),\n        )\n\n        if 'image' in instances[0]:\n            images = [instance['image'] for instance in instances]\n            if all(x is not None and x.shape == images[0].shape for x in images):\n                batch['images'] = torch.stack(images)\n            else:\n                batch['images'] = images\n\n        return batch\n\n\ndef make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,\n                                data_args) -> Dict:\n    \"\"\"Make dataset and collator for supervised fine-tuning.\"\"\"\n    train_dataset = LazySupervisedDataset(tokenizer=tokenizer,\n                                data_path=data_args.data_path,\n                                data_args=data_args)\n    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)\n    return dict(train_dataset=train_dataset,\n                eval_dataset=None,\n                data_collator=data_collator)\n\n\ndef train():\n    global local_rank\n\n    parser = transformers.HfArgumentParser(\n        (ModelArguments, DataArguments, TrainingArguments))\n    model_args, data_args, training_args = parser.parse_args_into_dataclasses()\n    local_rank = training_args.local_rank\n    compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else 
torch.float32))\n\n    bnb_model_from_pretrained_args = {}\n    if training_args.bits in [4, 8]:\n        from transformers import BitsAndBytesConfig\n        bnb_model_from_pretrained_args.update(dict(\n            device_map={\"\": training_args.device},\n            load_in_4bit=training_args.bits == 4,\n            load_in_8bit=training_args.bits == 8,\n            quantization_config=BitsAndBytesConfig(\n                load_in_4bit=training_args.bits == 4,\n                load_in_8bit=training_args.bits == 8,\n                llm_int8_threshold=6.0,\n                llm_int8_has_fp16_weight=False,\n                bnb_4bit_compute_dtype=compute_dtype,\n                bnb_4bit_use_double_quant=training_args.double_quant,\n                bnb_4bit_quant_type=training_args.quant_type # {'fp4', 'nf4'}\n            )\n        ))\n\n    if model_args.vision_tower is not None:\n        if 'mpt' in model_args.model_name_or_path:\n            config = transformers.AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)\n            config.attn_config['attn_impl'] = training_args.mpt_attn_impl\n            model = LlavaMPTForCausalLM.from_pretrained(\n                model_args.model_name_or_path,\n                config=config,\n                cache_dir=training_args.cache_dir,\n                **bnb_model_from_pretrained_args\n            )\n        else:\n            model = LlavaLlamaForCausalLM.from_pretrained(\n                model_args.model_name_or_path,\n                cache_dir=training_args.cache_dir,\n                **bnb_model_from_pretrained_args\n            )\n    else:\n        model = transformers.LlamaForCausalLM.from_pretrained(\n            model_args.model_name_or_path,\n            cache_dir=training_args.cache_dir,\n            **bnb_model_from_pretrained_args\n        )\n    model.config.use_cache = False\n\n    if model_args.freeze_backbone:\n        model.model.requires_grad_(False)\n\n    if training_args.bits 
in [4, 8]:\n        from peft import prepare_model_for_kbit_training\n        model.config.torch_dtype=(torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))\n        model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)\n\n    if training_args.gradient_checkpointing:\n        if hasattr(model, \"enable_input_require_grads\"):\n            model.enable_input_require_grads()\n        else:\n            def make_inputs_require_grad(module, input, output):\n                output.requires_grad_(True)\n            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)\n\n    if training_args.lora_enable:\n        from peft import LoraConfig, get_peft_model\n        lora_config = LoraConfig(\n            r=training_args.lora_r,\n            lora_alpha=training_args.lora_alpha,\n            target_modules=find_all_linear_names(model),\n            lora_dropout=training_args.lora_dropout,\n            bias=training_args.lora_bias,\n            task_type=\"CAUSAL_LM\",\n        )\n        if training_args.bits == 16:\n            if training_args.bf16:\n                model.to(torch.bfloat16)\n            if training_args.fp16:\n                model.to(torch.float16)\n        rank0_print(\"Adding LoRA adapters...\")\n        model = get_peft_model(model, lora_config)\n\n    if 'mpt' in model_args.model_name_or_path:\n        tokenizer = transformers.AutoTokenizer.from_pretrained(\n            model_args.model_name_or_path,\n            cache_dir=training_args.cache_dir,\n            model_max_length=training_args.model_max_length,\n            padding_side=\"right\"\n        )\n    else:\n        tokenizer = transformers.AutoTokenizer.from_pretrained(\n            model_args.model_name_or_path,\n            cache_dir=training_args.cache_dir,\n            model_max_length=training_args.model_max_length,\n            padding_side=\"right\",\n        
    use_fast=False,\n        )\n\n    if model_args.version == \"v0\":\n        if tokenizer.pad_token is None:\n            smart_tokenizer_and_embedding_resize(\n                special_tokens_dict=dict(pad_token=\"[PAD]\"),\n                tokenizer=tokenizer,\n                model=model,\n            )\n    elif model_args.version == \"v0.5\":\n        tokenizer.pad_token = tokenizer.unk_token\n    else:\n        tokenizer.pad_token = tokenizer.unk_token\n        if model_args.version in conversation_lib.conv_templates:\n            conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version]\n        else:\n            conversation_lib.default_conversation = conversation_lib.conv_templates[\"vicuna_v1\"]\n\n    if model_args.vision_tower is not None:\n        model.get_model().initialize_vision_modules(\n            model_args=model_args,\n            fsdp=training_args.fsdp\n        )\n        \n        vision_tower = model.get_vision_tower()\n        vision_tower.to(dtype=torch.float16, device=training_args.device)\n\n        data_args.image_processor = vision_tower.image_processor\n        data_args.is_multimodal = True\n\n        model.config.image_aspect_ratio = data_args.image_aspect_ratio\n        model.config.image_grid_pinpoints = data_args.image_grid_pinpoints\n\n        model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter\n        if model_args.tune_mm_mlp_adapter:\n            model.requires_grad_(False)\n            for p in model.get_model().mm_projector.parameters():\n                p.requires_grad = True\n\n        model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter\n        if training_args.freeze_mm_mlp_adapter:\n            for p in model.get_model().mm_projector.parameters():\n                p.requires_grad = False\n\n        if training_args.bits in [4, 8]:\n            model.get_model().mm_projector.to(dtype=compute_dtype, 
device=training_args.device)\n\n        model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = model_args.mm_use_im_start_end\n        training_args.use_im_start_end = model_args.mm_use_im_start_end\n        model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token\n        model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer)\n\n    if training_args.bits in [4, 8]:\n        from peft.tuners.lora import LoraLayer\n        for name, module in model.named_modules():\n            if isinstance(module, LoraLayer):\n                if training_args.bf16:\n                    module = module.to(torch.bfloat16)\n            if 'norm' in name:\n                module = module.to(torch.float32)\n            if 'lm_head' in name or 'embed_tokens' in name:\n                if hasattr(module, 'weight'):\n                    if training_args.bf16 and module.weight.dtype == torch.float32:\n                        module = module.to(torch.bfloat16)\n\n    data_module = make_supervised_data_module(tokenizer=tokenizer,\n                                              data_args=data_args)\n    trainer = LLaVATrainer(model=model,\n                    tokenizer=tokenizer,\n                    args=training_args,\n                    **data_module)\n\n    if list(pathlib.Path(training_args.output_dir).glob(\"checkpoint-*\")):\n        trainer.train(resume_from_checkpoint=True)\n    else:\n        trainer.train()\n    trainer.save_state()\n\n    model.config.use_cache = True\n\n    if training_args.lora_enable:\n        state_dict = get_peft_state_maybe_zero_3(\n            model.named_parameters(), training_args.lora_bias\n        )\n        non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(\n            model.named_parameters()\n        )\n        if training_args.local_rank == 0 or training_args.local_rank == -1:\n            model.config.save_pretrained(training_args.output_dir)\n            model.save_pretrained(training_args.output_dir, 
state_dict=state_dict)\n            torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin'))\n    else:\n        safe_save_model_for_hf_trainer(trainer=trainer,\n                                       output_dir=training_args.output_dir)\n\n\nif __name__ == \"__main__\":\n    train()\n"
  },
  {
    "path": "llava/train/train_grounding_1st.py",
    "content": "# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:\n# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:\n#    Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li\n#\n#    Licensed under the Apache License, Version 2.0 (the \"License\");\n#    you may not use this file except in compliance with the License.\n#    You may obtain a copy of the License at\n#\n#        http://www.apache.org/licenses/LICENSE-2.0\n#\n#    Unless required by applicable law or agreed to in writing, software\n#    distributed under the License is distributed on an \"AS IS\" BASIS,\n#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#    See the License for the specific language governing permissions and\n#    limitations under the License.\nfrom llava.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn\nreplace_llama_attn_with_flash_attn()\nimport os\nimport copy\nfrom dataclasses import dataclass, field\nimport json\nimport logging\nimport pathlib\nfrom typing import Dict, Optional, Sequence, List\n\nimport torch\n\nimport transformers\n\nfrom llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN\nfrom torch.utils.data import Dataset\nfrom llava.train.llava_trainer_gd import LLaVATrainer\n\nfrom llava import conversation as conversation_lib\nfrom llava.model import *\nfrom llava.mm_utils import tokenizer_image_token\n\nfrom PIL import Image\n\n\nlocal_rank = None\n\n\ndef rank0_print(*args):\n    if local_rank == 0:\n        print(*args)\n\n\n@dataclass\nclass ModelArguments:\n    model_name_or_path: Optional[str] = field(default=\"facebook/opt-125m\")\n    whole_model: Optional[str] = field(default=\"facebook/opt-125m\")\n    version: Optional[str] = field(default=\"v0\")\n    freeze_backbone: bool = field(default=False)\n    tune_mm_mlp_adapter: bool = 
field(default=False)\n    vision_tower: Optional[str] = field(default=None)\n    mm_vision_select_layer: Optional[int] = field(default=-1)   # default to the last layer\n    pretrain_mm_mlp_adapter: Optional[str] = field(default=None)\n    mm_use_im_start_end: bool = field(default=False)\n    load_model: bool = field(default=False)\n    mm_use_im_patch_token: bool = field(default=True)\n    mm_vision_select_feature: Optional[str] = field(default=\"patch\")\n    opt: Optional[str] = field(default=\"\")\n    config_file: Optional[str] = field(default=\"\")\n\n\n@dataclass\nclass DataArguments:\n    data_path: str = field(default=None,\n                           metadata={\"help\": \"Path to the training data.\"})\n    lazy_preprocess: bool = False\n    is_multimodal: bool = False\n    image_folder: Optional[str] = field(default=None)\n    image_aspect_ratio: str = 'square'\n    image_grid_pinpoints: Optional[str] = field(default=None)\n\n\n@dataclass\nclass TrainingArguments(transformers.TrainingArguments):\n    cache_dir: Optional[str] = field(default=None)\n    optim: str = field(default=\"adamw_torch\")\n    remove_unused_columns: bool = field(default=False)\n    freeze_mm_mlp_adapter: bool = field(default=False)\n    mpt_attn_impl: Optional[str] = field(default=\"triton\")\n    model_max_length: int = field(\n        default=512,\n        metadata={\n            \"help\":\n            \"Maximum sequence length. Sequences will be right padded (and possibly truncated).\"\n        },\n    )\n    double_quant: bool = field(\n        default=True,\n        metadata={\"help\": \"Compress the quantization statistics through double quantization.\"}\n    )\n    quant_type: str = field(\n        default=\"nf4\",\n        metadata={\"help\": \"Quantization data type to use. 
Should be one of `fp4` or `nf4`.\"}\n    )\n    bits: int = field(\n        default=16,\n        metadata={\"help\": \"How many bits to use.\"}\n    )\n    lora_enable: bool = False\n    new_tokens: bool = True\n    lora_r: int = 64\n    lora_alpha: int = 16\n    lora_dropout: float = 0.05\n    lora_weight_path: str = \"\"\n    lora_bias: str = \"none\"\n    dbg: bool = False\n    load_optimizer_states: bool = True\n    load_lr_scheduler_states: bool = True\n    freeze_segmentation: bool = False\n\n\ndef maybe_zero_3(param, ignore_status=False, name=None):\n    from deepspeed import zero\n    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus\n    if hasattr(param, \"ds_id\"):\n        if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:\n            if not ignore_status:\n                logging.warning(f\"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}\")\n        with zero.GatheredParameters([param]):\n            param = param.data.detach().cpu().clone()\n    else:\n        param = param.detach().cpu().clone()\n    return param\n\n\n# Borrowed from peft.utils.get_peft_model_state_dict\ndef get_peft_state_maybe_zero_3(named_params, bias):\n    if bias == \"none\":\n        to_return = {k: t for k, t in named_params if \"lora_\" in k}\n    elif bias == \"all\":\n        to_return = {k: t for k, t in named_params if \"lora_\" in k or \"bias\" in k}\n    elif bias == \"lora_only\":\n        to_return = {}\n        maybe_lora_bias = {}\n        lora_bias_names = set()\n        for k, t in named_params:\n            if \"lora_\" in k:\n                to_return[k] = t\n                bias_name = k.split(\"lora_\")[0] + \"bias\"\n                lora_bias_names.add(bias_name)\n            elif \"bias\" in k:\n                maybe_lora_bias[k] = t\n        for k, t in maybe_lora_bias:\n            if bias_name in lora_bias_names:\n                to_return[bias_name] = t\n    else:\n        raise NotImplementedError\n  
  to_return = {k: maybe_zero_3(v, name=k) for k, v in to_return.items()}\n    return to_return\n\n\ndef get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):\n    to_return = {k: t for k, t in named_params if \"lora_\" not in k}\n    if require_grad_only:\n        to_return = {k: t for k, t in to_return.items() if t.requires_grad}\n    to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}\n    return to_return\n\n\ndef get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):\n    to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}\n    to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}\n    return to_return\n\n\ndef find_all_linear_names(model):\n    cls = torch.nn.Linear\n    lora_module_names = set()\n    for name, module in model.named_modules():\n        if isinstance(module, cls):\n            names = name.split('.')\n            lora_module_names.add(names[0] if len(names) == 1 else names[-1])\n\n\n    if 'lm_head' in lora_module_names: # needed for 16-bit\n        lora_module_names.remove('lm_head')\n    return list(lora_module_names)\n\n\ndef safe_save_model_for_hf_trainer(trainer: transformers.Trainer,\n                                   output_dir: str):\n    \"\"\"Collects the state dict and dump to disk.\"\"\"\n\n    if trainer.deepspeed:\n        torch.cuda.synchronize()\n        trainer.save_model(output_dir)\n        return\n\n    state_dict = trainer.model.state_dict()\n    if trainer.args.should_save:\n        cpu_state_dict = {\n            key: value.cpu()\n            for key, value in state_dict.items()\n        }\n        del state_dict\n        trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa\n\n\ndef smart_tokenizer_and_embedding_resize(\n    special_tokens_dict: Dict,\n    tokenizer: transformers.PreTrainedTokenizer,\n    model: transformers.PreTrainedModel,\n):\n    \"\"\"Resize 
tokenizer and embedding.\n\n    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.\n    \"\"\"\n    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)\n    model.resize_token_embeddings(len(tokenizer))\n\n    if num_new_tokens > 0:\n        input_embeddings = model.get_input_embeddings().weight.data\n        output_embeddings = model.get_output_embeddings().weight.data\n\n        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(\n            dim=0, keepdim=True)\n        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(\n            dim=0, keepdim=True)\n\n        input_embeddings[-num_new_tokens:] = input_embeddings_avg\n        output_embeddings[-num_new_tokens:] = output_embeddings_avg\n\n\ndef _tokenize_fn(strings: Sequence[str],\n                 tokenizer: transformers.PreTrainedTokenizer) -> Dict:\n    \"\"\"Tokenize a list of strings.\"\"\"\n    tokenized_list = [\n        tokenizer(\n            text,\n            return_tensors=\"pt\",\n            padding=\"longest\",\n            max_length=tokenizer.model_max_length,\n            truncation=True,\n        ) for text in strings\n    ]\n    input_ids = labels = [\n        tokenized.input_ids[0] for tokenized in tokenized_list\n    ]\n    input_ids_lens = labels_lens = [\n        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()\n        for tokenized in tokenized_list\n    ]\n    return dict(\n        input_ids=input_ids,\n        labels=labels,\n        input_ids_lens=input_ids_lens,\n        labels_lens=labels_lens,\n    )\n\n\ndef _mask_targets(target, tokenized_lens, speakers):\n    # cur_idx = 0\n    cur_idx = tokenized_lens[0]\n    tokenized_lens = tokenized_lens[1:]\n    target[:cur_idx] = IGNORE_INDEX\n    for tokenized_len, speaker in zip(tokenized_lens, speakers):\n        if speaker == \"human\":\n            target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX\n        cur_idx += 
tokenized_len\n\n\ndef _add_speaker_and_signal(header, source, get_conversation=True):\n    \"\"\"Add speaker and start/end signal on each round.\"\"\"\n    BEGIN_SIGNAL = \"### \"\n    END_SIGNAL = \"\\n\"\n    conversation = header\n    for sentence in source:\n        from_str = sentence[\"from\"]\n        if from_str.lower() == \"human\":\n            from_str = conversation_lib.default_conversation.roles[0]\n        elif from_str.lower() == \"gpt\":\n            from_str = conversation_lib.default_conversation.roles[1]\n        else:\n            from_str = 'unknown'\n        sentence[\"value\"] = (BEGIN_SIGNAL + from_str + \": \" +\n                             sentence[\"value\"] + END_SIGNAL)\n        if get_conversation:\n            conversation += sentence[\"value\"]\n    conversation += BEGIN_SIGNAL\n    return conversation\n\n\ndef preprocess_multimodal(\n    sources: Sequence[str],\n    data_args: DataArguments\n) -> Dict:\n    is_multimodal = data_args.is_multimodal\n    if not is_multimodal:\n        return sources\n\n    for source in sources:\n        for sentence in source:\n            if DEFAULT_IMAGE_TOKEN in sentence['value']:\n                sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip()\n                sentence['value'] = DEFAULT_IMAGE_TOKEN + '\\n' + sentence['value']\n                sentence['value'] = sentence['value'].strip()\n                if \"mmtag\" in conversation_lib.default_conversation.version:\n                    sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '<Image>' + DEFAULT_IMAGE_TOKEN + '</Image>')\n            replace_token = DEFAULT_IMAGE_TOKEN\n            if data_args.mm_use_im_start_end:\n                replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN\n            sentence[\"value\"] = sentence[\"value\"].replace(DEFAULT_IMAGE_TOKEN, replace_token)\n\n    return sources\n\n\ndef preprocess_llama_2(\n    sources,\n    tokenizer: 
transformers.PreTrainedTokenizer,\n    has_image: bool = False\n) -> Dict:\n    conv = conversation_lib.default_conversation.copy()\n    roles = {\"human\": conv.roles[0], \"gpt\": conv.roles[1]}\n\n    # Apply prompt templates\n    conversations = []\n    for i, source in enumerate(sources):\n        if roles[source[0][\"from\"]] != conv.roles[0]:\n            # Skip the first one if it is not from human\n            source = source[1:]\n\n        conv.messages = []\n        for j, sentence in enumerate(source):\n            role = roles[sentence[\"from\"]]\n            assert role == conv.roles[j % 2], f\"{i}\"\n            conv.append_message(role, sentence[\"value\"])\n        conversations.append(conv.get_prompt())\n\n    # Tokenize conversations\n\n    if has_image:\n        input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)\n    else:\n        input_ids = tokenizer(\n            conversations,\n            return_tensors=\"pt\",\n            padding=\"longest\",\n            max_length=tokenizer.model_max_length,\n            truncation=True,\n        ).input_ids\n\n    targets = input_ids.clone()\n\n    assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2\n\n    # Mask targets\n    sep = \"[/INST] \"\n    for conversation, target in zip(conversations, targets):\n        total_len = int(target.ne(tokenizer.pad_token_id).sum())\n\n        rounds = conversation.split(conv.sep2)\n        cur_len = 1\n        target[:cur_len] = IGNORE_INDEX\n        for i, rou in enumerate(rounds):\n            if rou == \"\":\n                break\n\n            parts = rou.split(sep)\n            if len(parts) != 2:\n                break\n            parts[0] += sep\n\n            if has_image:\n                round_len = len(tokenizer_image_token(rou, tokenizer))\n                instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2\n            else:\n                
round_len = len(tokenizer(rou).input_ids)\n                instruction_len = len(tokenizer(parts[0]).input_ids) - 2\n\n            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX\n\n            cur_len += round_len\n        target[cur_len:] = IGNORE_INDEX\n\n        if cur_len < tokenizer.model_max_length:\n            if cur_len != total_len:\n                target[:] = IGNORE_INDEX\n                print(\n                    f\"WARNING: tokenization mismatch: {cur_len} vs. {total_len}.\"\n                    f\" (ignored)\"\n                )\n\n    return dict(\n        input_ids=input_ids,\n        labels=targets,\n    )\n\n\ndef preprocess_v1(\n    sources,\n    tokenizer: transformers.PreTrainedTokenizer,\n    has_image: bool = False\n) -> Dict:\n    conv = conversation_lib.default_conversation.copy()\n    roles = {\"human\": conv.roles[0], \"gpt\": conv.roles[1]}\n\n    # Apply prompt templates\n    conversations = []\n    for i, source in enumerate(sources):\n        if roles[source[0][\"from\"]] != conv.roles[0]:\n            # Skip the first one if it is not from human\n            source = source[1:]\n\n        conv.messages = []\n        for j, sentence in enumerate(source):\n            role = roles[sentence[\"from\"]]\n            assert role == conv.roles[j % 2], f\"{i}\"\n            conv.append_message(role, sentence[\"value\"])\n        conversations.append(conv.get_prompt())\n\n    # Tokenize conversations\n\n    if has_image:\n        input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)\n    else:\n        input_ids = tokenizer(\n            conversations,\n            return_tensors=\"pt\",\n            padding=\"longest\",\n            max_length=tokenizer.model_max_length,\n            truncation=True,\n        ).input_ids\n\n    targets = input_ids.clone()\n\n    assert conv.sep_style == conversation_lib.SeparatorStyle.TWO\n\n    # Mask targets\n    sep = 
conv.sep + conv.roles[1] + \": \"\n    for conversation, target in zip(conversations, targets):\n        total_len = int(target.ne(tokenizer.pad_token_id).sum())\n\n        rounds = conversation.split(conv.sep2)\n        cur_len = 1\n        target[:cur_len] = IGNORE_INDEX\n        for i, rou in enumerate(rounds):\n            if rou == \"\":\n                break\n\n            parts = rou.split(sep)\n            if len(parts) != 2:\n                break\n            parts[0] += sep\n\n            if has_image:\n                round_len = len(tokenizer_image_token(rou, tokenizer))\n                instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2\n            else:\n                round_len = len(tokenizer(rou).input_ids)\n                instruction_len = len(tokenizer(parts[0]).input_ids) - 2\n\n            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX\n\n            cur_len += round_len\n        target[cur_len:] = IGNORE_INDEX\n\n        if cur_len < tokenizer.model_max_length:\n            if cur_len != total_len:\n                target[:] = IGNORE_INDEX\n                print(\n                    f\"WARNING: tokenization mismatch: {cur_len} vs. 
{total_len}.\"\n                    f\" (ignored)\"\n                )\n\n    return dict(\n        input_ids=input_ids,\n        labels=targets,\n    )\n\n\ndef preprocess_mpt(\n    sources,\n    tokenizer: transformers.PreTrainedTokenizer,\n) -> Dict:\n    conv = conversation_lib.default_conversation.copy()\n    roles = {\"human\": conv.roles[0], \"gpt\": conv.roles[1]}\n\n    # Apply prompt templates\n    conversations = []\n    for i, source in enumerate(sources):\n        if roles[source[0][\"from\"]] != conv.roles[0]:\n            # Skip the first one if it is not from human\n            source = source[1:]\n\n        conv.messages = []\n        for j, sentence in enumerate(source):\n            role = roles[sentence[\"from\"]]\n            assert role == conv.roles[j % 2], f\"{i}\"\n            conv.append_message(role, sentence[\"value\"])\n        conversations.append(conv.get_prompt())\n\n    # Tokenize conversations\n    input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)\n    targets = input_ids.clone()\n    assert conv.sep_style == conversation_lib.SeparatorStyle.MPT\n\n    # Mask targets\n    sep = conv.sep + conv.roles[1]\n    for conversation, target in zip(conversations, targets):\n        total_len = int(target.ne(tokenizer.pad_token_id).sum())\n\n        rounds = conversation.split(conv.sep)\n        re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt\n        for conv_idx in range(3, len(rounds), 2):\n            re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx+2]))    # user + gpt\n        cur_len = 0\n        target[:cur_len] = IGNORE_INDEX\n        for i, rou in enumerate(re_rounds):\n            if rou == \"\":\n                break\n\n            parts = rou.split(sep)\n            if len(parts) != 2:\n                break\n            parts[0] += sep\n            round_len = len(tokenizer_image_token(rou, tokenizer)) + 
len(tokenizer_image_token(conv.sep, tokenizer))\n            instruction_len = len(tokenizer_image_token(parts[0], tokenizer))\n            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX\n\n            cur_len += round_len\n        target[cur_len:] = IGNORE_INDEX\n\n        if cur_len < tokenizer.model_max_length:\n            if cur_len != total_len:\n                target[:] = IGNORE_INDEX\n                print(\n                    f\"WARNING: tokenization mismatch: {cur_len} vs. {total_len}.\"\n                    f\" (ignored)\"\n                )\n\n    return dict(\n        input_ids=input_ids,\n        labels=targets,\n    )\n\n\ndef preprocess_plain(\n    sources: Sequence[str],\n    tokenizer: transformers.PreTrainedTokenizer,\n) -> Dict:\n    # add end signal and concatenate together\n    conversations = []\n    for source in sources:\n        assert len(source) == 2\n        assert DEFAULT_IMAGE_TOKEN in source[0]['value']\n        source[0]['value'] = DEFAULT_IMAGE_TOKEN\n        conversation = source[0]['value'] + source[1]['value'] + conversation_lib.default_conversation.sep\n        conversations.append(conversation)\n    # tokenize conversations\n    input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]\n    targets = copy.deepcopy(input_ids)\n    for target, source in zip(targets, sources):\n        tokenized_len = len(tokenizer_image_token(source[0]['value'], tokenizer))\n        target[:tokenized_len] = IGNORE_INDEX\n\n    return dict(input_ids=input_ids, labels=targets)\n\n\ndef preprocess(\n    sources: Sequence[str],\n    tokenizer: transformers.PreTrainedTokenizer,\n    has_image: bool = False\n) -> Dict:\n    \"\"\"\n    Given a list of sources, each is a conversation list. This transform:\n    1. Add signal '### ' at the beginning each sentence, with end signal '\\n';\n    2. Concatenate conversations together;\n    3. Tokenize the concatenated conversation;\n    4. 
Make a deepcopy as the target. Mask human words with IGNORE_INDEX.\n    \"\"\"\n    if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN:\n        return preprocess_plain(sources, tokenizer)\n    if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.LLAMA_2:\n        return preprocess_llama_2(sources, tokenizer, has_image=has_image)\n    if conversation_lib.default_conversation.version.startswith(\"v1\"):\n        return preprocess_v1(sources, tokenizer, has_image=has_image)\n    if conversation_lib.default_conversation.version == \"mpt\":\n        return preprocess_mpt(sources, tokenizer)\n    # add end signal and concatenate together\n    conversations = []\n    for source in sources:\n        header = f\"{conversation_lib.default_conversation.system}\\n\\n\"\n        conversation = _add_speaker_and_signal(header, source)\n        conversations.append(conversation)\n    # tokenize conversations\n    def get_tokenize_len(prompts):\n        return [len(tokenizer_image_token(prompt, tokenizer)) for prompt in prompts]\n\n    if has_image:\n        input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]\n    else:\n        conversations_tokenized = _tokenize_fn(conversations, tokenizer)\n        input_ids = conversations_tokenized[\"input_ids\"]\n\n    targets = copy.deepcopy(input_ids)\n    for target, source in zip(targets, sources):\n        if has_image:\n            tokenized_lens = get_tokenize_len([header] + [s[\"value\"] for s in source])\n        else:\n            tokenized_lens = _tokenize_fn([header] + [s[\"value\"] for s in source], tokenizer)[\"input_ids_lens\"]\n        speakers = [sentence[\"from\"] for sentence in source]\n        _mask_targets(target, tokenized_lens, speakers)\n\n    return dict(input_ids=input_ids, labels=targets)\n\n\nclass LazySupervisedDataset(Dataset):\n    \"\"\"Dataset for supervised fine-tuning.\"\"\"\n\n    
def __init__(self, data_path: str,\n                 tokenizer: transformers.PreTrainedTokenizer,\n                 data_args: DataArguments):\n        super(LazySupervisedDataset, self).__init__()\n        list_data_dict = json.load(open(data_path, \"r\"))\n\n        rank0_print(\"Formatting inputs...Skip in lazy mode\")\n        self.tokenizer = tokenizer\n        self.list_data_dict = list_data_dict\n        self.data_args = data_args\n\n    def __len__(self):\n        return len(self.list_data_dict)\n\n    def __getitem__(self, i) -> Dict[str, torch.Tensor]:\n        try:\n            sources = self.list_data_dict[i]\n            if isinstance(i, int):\n                sources = [sources]\n            assert len(sources) == 1, \"Don't know why it is wrapped to a list\"  # FIXME\n            if 'image' in sources[0]:\n                image_file = self.list_data_dict[i]['image']\n                image_folder = self.data_args.image_folder\n                processor = self.data_args.image_processor\n                image = Image.open(os.path.join(image_folder, image_file)).convert('RGB')\n                if self.data_args.image_aspect_ratio == 'pad':\n                    def expand2square(pil_img, background_color):\n                        width, height = pil_img.size\n                        if width == height:\n                            return pil_img\n                        elif width > height:\n                            result = Image.new(pil_img.mode, (width, width), background_color)\n                            result.paste(pil_img, (0, (width - height) // 2))\n                            return result\n                        else:\n                            result = Image.new(pil_img.mode, (height, height), background_color)\n                            result.paste(pil_img, ((height - width) // 2, 0))\n                            return result\n                    image = expand2square(image, tuple(int(x*255) for x in processor.image_mean))\n      
              image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]\n                else:\n                    image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]\n                sources = preprocess_multimodal(\n                    copy.deepcopy([e[\"conversations\"] for e in sources]),\n                    self.data_args)\n            else:\n                sources = copy.deepcopy([e[\"conversations\"] for e in sources])\n            data_dict = preprocess(\n                sources,\n                self.tokenizer,\n                has_image=('image' in self.list_data_dict[i]))\n            if isinstance(i, int):\n                data_dict = dict(input_ids=data_dict[\"input_ids\"][0],\n                                 labels=data_dict[\"labels\"][0])\n\n            # image exist in the data\n            if 'image' in self.list_data_dict[i]:\n                data_dict['image'] = image\n            elif self.data_args.is_multimodal:\n                # image does not exist in the data, but the model is multimodal\n                crop_size = self.data_args.image_processor.crop_size\n                data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width'])\n            return data_dict\n        except Exception:\n            print(self.list_data_dict[i], \"failed\")\n            return self.__getitem__(i + 1)\n\n\n@dataclass\nclass DataCollatorForSupervisedDataset(object):\n    \"\"\"Collate examples for supervised fine-tuning.\"\"\"\n\n    tokenizer: transformers.PreTrainedTokenizer\n\n    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:\n        input_ids, labels = tuple([instance[key] for instance in instances]\n                                  for key in (\"input_ids\", \"labels\"))\n        input_ids = torch.nn.utils.rnn.pad_sequence(\n            input_ids,\n            batch_first=True,\n            padding_value=self.tokenizer.pad_token_id)\n        labels = 
torch.nn.utils.rnn.pad_sequence(labels,\n                                                 batch_first=True,\n                                                 padding_value=IGNORE_INDEX)\n        input_ids = input_ids[:, :self.tokenizer.model_max_length]\n        labels = labels[:, :self.tokenizer.model_max_length]\n        batch = dict(\n            input_ids=input_ids,\n            labels=labels,\n            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),\n        )\n\n        if 'image' in instances[0]:\n            images = [instance['image'] for instance in instances]\n            if all(x is not None and x.shape == images[0].shape for x in images):\n                batch['images'] = torch.stack(images)\n            else:\n                batch['images'] = images\n\n        return batch\n\n\ndef make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,\n                                data_args) -> Dict:\n    \"\"\"Make dataset and collator for supervised fine-tuning.\"\"\"\n    train_dataset = LazySupervisedDataset(tokenizer=tokenizer,\n                                data_path=data_args.data_path,\n                                data_args=data_args)\n    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)\n    return dict(train_dataset=train_dataset,\n                eval_dataset=None,\n                data_collator=data_collator)\n\nfrom detectron2.config import LazyConfig, instantiate\n\ndef setup(args):\n    \"\"\"\n    Create configs and perform basic setups.\n    \"\"\"\n    cfg = LazyConfig.load(args.config_file)\n    # import pdb;pdb.set_trace()\n    opt=args.opt.split(',')\n    cfg = LazyConfig.apply_overrides(cfg, opt)\n    # cfg.freeze()\n    # default_setup(cfg, args)\n    # setup_logger(output=cfg.OUTPUT_DIR, distributed_rank=comm.get_rank(), name=\"maskdino\")\n    return cfg\n\ndef train():\n    global local_rank\n\n    parser = transformers.HfArgumentParser(\n        (ModelArguments, DataArguments, 
TrainingArguments))\n    model_args, data_args, training_args = parser.parse_args_into_dataclasses()\n    local_rank = training_args.local_rank\n    compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))\n    cfg=setup(model_args)\n    bnb_model_from_pretrained_args = {}\n    if training_args.bits in [4, 8]:\n        from transformers import BitsAndBytesConfig\n        bnb_model_from_pretrained_args.update(dict(\n            device_map={\"\": training_args.device},\n            load_in_4bit=training_args.bits == 4,\n            load_in_8bit=training_args.bits == 8,\n            quantization_config=BitsAndBytesConfig(\n                load_in_4bit=training_args.bits == 4,\n                load_in_8bit=training_args.bits == 8,\n                llm_int8_threshold=6.0,\n                llm_int8_has_fp16_weight=False,\n                bnb_4bit_compute_dtype=compute_dtype,\n                bnb_4bit_use_double_quant=training_args.double_quant,\n                bnb_4bit_quant_type=training_args.quant_type # {'fp4', 'nf4'}\n            )\n        ))\n\n    if model_args.vision_tower is not None:\n        if 'mpt' in model_args.model_name_or_path:\n            config = transformers.AutoConfig.from_pretrained(model_args.model_name_or_path,cache_dir=\"/comp_robot/zhanghao/.cache/hugging_face/\", trust_remote_code=True)\n            config.attn_config['attn_impl'] = training_args.mpt_attn_impl\n            model = LlavaMPTForCausalLM.from_pretrained(\n                model_args.model_name_or_path,\n                config=config,\n                cache_dir=\"/comp_robot/zhanghao/.cache/hugging_face/\",\n                **bnb_model_from_pretrained_args\n            )\n        else:\n            model = LlavaLlamaForCausalLM_gd.from_pretrained(\n                model_args.model_name_or_path,\n                cache_dir=\"/comp_robot/zhanghao/.cache/hugging_face/\",\n                **bnb_model_from_pretrained_args\n      
      )\n    else:\n        model = transformers.LlamaForCausalLM.from_pretrained(\n            model_args.model_name_or_path,\n            cache_dir=\"/comp_robot/zhanghao/.cache/hugging_face/\",\n            **bnb_model_from_pretrained_args\n        )\n    model.config.use_cache = False\n\n    if model_args.freeze_backbone:\n        model.model.requires_grad_(False)\n\n    if training_args.bits in [4, 8]:\n        from peft import prepare_model_for_kbit_training\n        model.config.torch_dtype=(torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))\n        model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)\n\n    if training_args.gradient_checkpointing:\n        if hasattr(model, \"enable_input_require_grads\"):\n            model.enable_input_require_grads()\n        else:\n            def make_inputs_require_grad(module, input, output):\n                output.requires_grad_(True)\n            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)\n\n    if training_args.lora_enable:\n        from peft import LoraConfig, get_peft_model\n        lora_config = LoraConfig(\n            r=training_args.lora_r,\n            lora_alpha=training_args.lora_alpha,\n            target_modules=find_all_linear_names(model),\n            lora_dropout=training_args.lora_dropout,\n            bias=training_args.lora_bias,\n            task_type=\"CAUSAL_LM\",\n        )\n        if training_args.bits == 16:\n            if training_args.bf16:\n                model.to(torch.bfloat16)\n            if training_args.fp16:\n                model.to(torch.float16)\n        rank0_print(\"Adding LoRA adapters...\")\n        model = get_peft_model(model, lora_config)\n\n    if 'mpt' in model_args.model_name_or_path:\n        tokenizer = transformers.AutoTokenizer.from_pretrained(\n            
model_args.model_name_or_path,cache_dir=\"/comp_robot/zhanghao/.cache/hugging_face/\",\n            model_max_length=training_args.model_max_length,\n            padding_side=\"right\"\n        )\n    else:\n        tokenizer = transformers.AutoTokenizer.from_pretrained(\n            model_args.model_name_or_path,\n            cache_dir=\"/comp_robot/zhanghao/.cache/hugging_face/\",\n            model_max_length=training_args.model_max_length,\n            padding_side=\"right\",\n            use_fast=False,\n        )\n\n    if model_args.version == \"v0\":\n        if tokenizer.pad_token is None:\n            smart_tokenizer_and_embedding_resize(\n                special_tokens_dict=dict(pad_token=\"[PAD]\"),\n                tokenizer=tokenizer,\n                model=model,\n            )\n    elif model_args.version == \"v0.5\":\n        tokenizer.pad_token = tokenizer.unk_token\n    else:\n        tokenizer.pad_token = tokenizer.unk_token\n        if model_args.version in conversation_lib.conv_templates:\n            conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version]\n        else:\n            conversation_lib.default_conversation = conversation_lib.conv_templates[\"vicuna_v1\"]\n\n    if model_args.vision_tower is not None:\n        model.get_model().initialize_vision_modules(\n            model_args=model_args,\n            fsdp=training_args.fsdp\n        )\n        \n        vision_tower = model.get_vision_tower()\n        vision_tower.to(dtype=torch.float16, device=training_args.device)\n\n        data_args.image_processor = vision_tower.image_processor\n        data_args.is_multimodal = True\n\n        model.config.image_aspect_ratio = data_args.image_aspect_ratio\n        model.config.image_grid_pinpoints = data_args.image_grid_pinpoints\n\n        model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter\n        if model_args.tune_mm_mlp_adapter or 
training_args.dbg:\n            model.requires_grad_(False)\n            for p in model.get_model().mm_projector.parameters():\n                p.requires_grad = True\n\n        model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter\n        if training_args.freeze_mm_mlp_adapter:\n            for p in model.get_model().mm_projector.parameters():\n                p.requires_grad = False\n\n        if training_args.bits in [4, 8]:\n            model.get_model().mm_projector.to(dtype=compute_dtype, device=training_args.device)\n\n        model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = model_args.mm_use_im_start_end\n        training_args.use_im_start_end = model_args.mm_use_im_start_end\n        model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token\n        model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer)\n\n    model.initialize_seg_modules(\n        cfg=cfg,\n    )\n    if training_args.freeze_segmentation:\n        model.freeze_seg_modules()\n\n\n    if training_args.bits in [4, 8]:\n        from peft.tuners.lora import LoraLayer\n        for name, module in model.named_modules():\n            if isinstance(module, LoraLayer):\n                if training_args.bf16:\n                    module = module.to(torch.bfloat16)\n            if 'norm' in name:\n                module = module.to(torch.float32)\n            if 'lm_head' in name or 'embed_tokens' in name:\n                if hasattr(module, 'weight'):\n                    if training_args.bf16 and module.weight.dtype == torch.float32:\n                        module = module.to(torch.bfloat16)\n\n    data_module = make_supervised_data_module(tokenizer=tokenizer,\n                                              data_args=data_args)\n    print(model)\n    if model_args.load_model:\n        loaded_dict = dict()\n        if \"stage1\" in model_args.whole_model:\n            old_emb_in=model.get_input_embeddings().weight.clone()\n            
old_emb_out=model.get_output_embeddings().weight.clone()\n        for model_file in os.listdir(model_args.whole_model):\n            if model_file.endswith('.bin') and model_file.startswith('pytorch_model'):\n                loaded_dict.update(torch.load(os.path.join(model_args.whole_model, model_file), map_location='cpu'))\n        model.load_state_dict(loaded_dict, strict=False)\n        if \"stage1\" in model_args.whole_model:\n            with torch.no_grad():\n                model.get_input_embeddings().weight[:-3]=old_emb_in[:-3]\n                model.get_output_embeddings().weight[:-3]=old_emb_out[:-3]\n        print(loaded_dict.keys())\n\n    trainer = LLaVATrainer(model=model,\n                    tokenizer=tokenizer,\n                    args=training_args,cfg=cfg,data_loader_args=(tokenizer, data_args,preprocess),\n                    **data_module)\n\n    if list(pathlib.Path(training_args.output_dir).glob(\"checkpoint-*\")):\n        trainer.train(resume_from_checkpoint=True)\n    else:\n        trainer.train()\n    trainer.save_state()\n\n    model.config.use_cache = True\n\n    if training_args.lora_enable:\n        state_dict = get_peft_state_maybe_zero_3(\n            model.named_parameters(), training_args.lora_bias\n        )\n        non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(\n            model.named_parameters()\n        )\n        if training_args.local_rank == 0 or training_args.local_rank == -1:\n            model.config.save_pretrained(training_args.output_dir)\n            model.save_pretrained(training_args.output_dir, state_dict=state_dict)\n            torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin'))\n    else:\n        safe_save_model_for_hf_trainer(trainer=trainer,\n                                       output_dir=training_args.output_dir)\n\n\nif __name__ == \"__main__\":\n    train()\n"
  },
  {
    "path": "llava/train/train_joint_1st.py",
    "content": "# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:\n# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:\n#    Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li\n#\n#    Licensed under the Apache License, Version 2.0 (the \"License\");\n#    you may not use this file except in compliance with the License.\n#    You may obtain a copy of the License at\n#\n#        http://www.apache.org/licenses/LICENSE-2.0\n#\n#    Unless required by applicable law or agreed to in writing, software\n#    distributed under the License is distributed on an \"AS IS\" BASIS,\n#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#    See the License for the specific language governing permissions and\n#    limitations under the License.\nfrom llava.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn\nreplace_llama_attn_with_flash_attn()\nimport os\nimport copy\nfrom dataclasses import dataclass, field\nimport json\nimport logging\nimport pathlib\nfrom typing import Dict, Optional, Sequence, List\n\nimport torch\n\nimport transformers\n\nfrom llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN\nfrom torch.utils.data import Dataset\nfrom llava.train.llava_trainer_joint_train import LLaVATrainer\n\nfrom llava import conversation as conversation_lib\nfrom llava.model import *\nfrom llava.mm_utils import tokenizer_image_token\n\nfrom PIL import Image\n\n\nlocal_rank = None\n\n\ndef rank0_print(*args):\n    if local_rank == 0:\n        print(*args)\n\n\n@dataclass\nclass ModelArguments:\n    model_name_or_path: Optional[str] = field(default=\"facebook/opt-125m\")\n    whole_model: Optional[str] = field(default=\"\")\n    version: Optional[str] = field(default=\"v0\")\n    freeze_backbone: bool = field(default=False)\n    tune_mm_mlp_adapter: bool = field(default=False)\n 
   vision_tower: Optional[str] = field(default=None)\n    mm_vision_select_layer: Optional[int] = field(default=-1)   # default to the last layer\n    pretrain_mm_mlp_adapter: Optional[str] = field(default=None)\n    mm_use_im_start_end: bool = field(default=False)\n    load_model: bool = field(default=False)\n    mm_use_im_patch_token: bool = field(default=True)\n    mm_vision_select_feature: Optional[str] = field(default=\"patch\")\n    opt: Optional[str] = field(default=\"\")\n    config_file: Optional[str] = field(default=\"\")\n\n\n@dataclass\nclass DataArguments:\n    data_path: str = field(default=None,\n                           metadata={\"help\": \"Path to the training data.\"})\n    lazy_preprocess: bool = False\n    is_multimodal: bool = False\n    image_folder: Optional[str] = field(default=None)\n    image_aspect_ratio: str = 'square'\n    image_grid_pinpoints: Optional[str] = field(default=None)\n\n\n@dataclass\nclass TrainingArguments(transformers.TrainingArguments):\n    cache_dir: Optional[str] = field(default=None)\n    optim: str = field(default=\"adamw_torch\")\n    remove_unused_columns: bool = field(default=False)\n    freeze_mm_mlp_adapter: bool = field(default=False)\n    mpt_attn_impl: Optional[str] = field(default=\"triton\")\n    model_max_length: int = field(\n        default=512,\n        metadata={\n            \"help\":\n            \"Maximum sequence length. Sequences will be right padded (and possibly truncated).\"\n        },\n    )\n    double_quant: bool = field(\n        default=True,\n        metadata={\"help\": \"Compress the quantization statistics through double quantization.\"}\n    )\n    quant_type: str = field(\n        default=\"nf4\",\n        metadata={\"help\": \"Quantization data type to use. 
Should be one of `fp4` or `nf4`.\"}\n    )\n    bits: int = field(\n        default=16,\n        metadata={\"help\": \"How many bits to use.\"}\n    )\n    lora_enable: bool = False\n    new_tokens: bool = True\n    lora_r: int = 64\n    lora_alpha: int = 16\n    lora_dropout: float = 0.05\n    lora_weight_path: str = \"\"\n    lora_bias: str = \"none\"\n    dbg: bool = False\n    load_optimizer_states: bool = True\n    load_lr_scheduler_states: bool = True\n\n\ndef maybe_zero_3(param, ignore_status=False, name=None):\n    from deepspeed import zero\n    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus\n    if hasattr(param, \"ds_id\"):\n        if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:\n            if not ignore_status:\n                logging.warning(f\"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}\")\n        with zero.GatheredParameters([param]):\n            param = param.data.detach().cpu().clone()\n    else:\n        param = param.detach().cpu().clone()\n    return param\n\n\n# Borrowed from peft.utils.get_peft_model_state_dict\ndef get_peft_state_maybe_zero_3(named_params, bias):\n    if bias == \"none\":\n        to_return = {k: t for k, t in named_params if \"lora_\" in k}\n    elif bias == \"all\":\n        to_return = {k: t for k, t in named_params if \"lora_\" in k or \"bias\" in k}\n    elif bias == \"lora_only\":\n        to_return = {}\n        maybe_lora_bias = {}\n        lora_bias_names = set()\n        for k, t in named_params:\n            if \"lora_\" in k:\n                to_return[k] = t\n                bias_name = k.split(\"lora_\")[0] + \"bias\"\n                lora_bias_names.add(bias_name)\n            elif \"bias\" in k:\n                maybe_lora_bias[k] = t\n        for k, t in maybe_lora_bias:\n            if bias_name in lora_bias_names:\n                to_return[bias_name] = t\n    else:\n        raise NotImplementedError\n    to_return = {k: maybe_zero_3(v, 
name=k) for k, v in to_return.items()}\n    return to_return\n\n\ndef get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):\n    to_return = {k: t for k, t in named_params if \"lora_\" not in k}\n    if require_grad_only:\n        to_return = {k: t for k, t in to_return.items() if t.requires_grad}\n    to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}\n    return to_return\n\n\ndef get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):\n    to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}\n    to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}\n    return to_return\n\n\ndef find_all_linear_names(model):\n    cls = torch.nn.Linear\n    lora_module_names = set()\n    for name, module in model.named_modules():\n        if isinstance(module, cls):\n            names = name.split('.')\n            lora_module_names.add(names[0] if len(names) == 1 else names[-1])\n\n\n    if 'lm_head' in lora_module_names: # needed for 16-bit\n        lora_module_names.remove('lm_head')\n    return list(lora_module_names)\n\n\ndef safe_save_model_for_hf_trainer(trainer: transformers.Trainer,\n                                   output_dir: str):\n    \"\"\"Collects the state dict and dump to disk.\"\"\"\n\n    if trainer.deepspeed:\n        torch.cuda.synchronize()\n        trainer.save_model(output_dir)\n        return\n\n    state_dict = trainer.model.state_dict()\n    if trainer.args.should_save:\n        cpu_state_dict = {\n            key: value.cpu()\n            for key, value in state_dict.items()\n        }\n        del state_dict\n        trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa\n\n\ndef smart_tokenizer_and_embedding_resize(\n    special_tokens_dict: Dict,\n    tokenizer: transformers.PreTrainedTokenizer,\n    model: transformers.PreTrainedModel,\n):\n    \"\"\"Resize tokenizer and embedding.\n\n    Note: 
This is the unoptimized version that may make your embedding size not be divisible by 64.\n    \"\"\"\n    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)\n    model.resize_token_embeddings(len(tokenizer))\n\n    if num_new_tokens > 0:\n        input_embeddings = model.get_input_embeddings().weight.data\n        output_embeddings = model.get_output_embeddings().weight.data\n\n        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(\n            dim=0, keepdim=True)\n        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(\n            dim=0, keepdim=True)\n\n        input_embeddings[-num_new_tokens:] = input_embeddings_avg\n        output_embeddings[-num_new_tokens:] = output_embeddings_avg\n\n\ndef _tokenize_fn(strings: Sequence[str],\n                 tokenizer: transformers.PreTrainedTokenizer) -> Dict:\n    \"\"\"Tokenize a list of strings.\"\"\"\n    tokenized_list = [\n        tokenizer(\n            text,\n            return_tensors=\"pt\",\n            padding=\"longest\",\n            max_length=tokenizer.model_max_length,\n            truncation=True,\n        ) for text in strings\n    ]\n    input_ids = labels = [\n        tokenized.input_ids[0] for tokenized in tokenized_list\n    ]\n    input_ids_lens = labels_lens = [\n        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()\n        for tokenized in tokenized_list\n    ]\n    return dict(\n        input_ids=input_ids,\n        labels=labels,\n        input_ids_lens=input_ids_lens,\n        labels_lens=labels_lens,\n    )\n\n\ndef _mask_targets(target, tokenized_lens, speakers):\n    # cur_idx = 0\n    cur_idx = tokenized_lens[0]\n    tokenized_lens = tokenized_lens[1:]\n    target[:cur_idx] = IGNORE_INDEX\n    for tokenized_len, speaker in zip(tokenized_lens, speakers):\n        if speaker == \"human\":\n            target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX\n        cur_idx += tokenized_len\n\n\ndef 
_add_speaker_and_signal(header, source, get_conversation=True):\n    \"\"\"Add speaker and start/end signal on each round.\"\"\"\n    BEGIN_SIGNAL = \"### \"\n    END_SIGNAL = \"\\n\"\n    conversation = header\n    for sentence in source:\n        from_str = sentence[\"from\"]\n        if from_str.lower() == \"human\":\n            from_str = conversation_lib.default_conversation.roles[0]\n        elif from_str.lower() == \"gpt\":\n            from_str = conversation_lib.default_conversation.roles[1]\n        else:\n            from_str = 'unknown'\n        sentence[\"value\"] = (BEGIN_SIGNAL + from_str + \": \" +\n                             sentence[\"value\"] + END_SIGNAL)\n        if get_conversation:\n            conversation += sentence[\"value\"]\n    conversation += BEGIN_SIGNAL\n    return conversation\n\n\ndef preprocess_multimodal(\n    sources: Sequence[str],\n    data_args: DataArguments\n) -> Dict:\n    is_multimodal = data_args.is_multimodal\n    if not is_multimodal:\n        return sources\n\n    for source in sources:\n        for sentence in source:\n            if DEFAULT_IMAGE_TOKEN in sentence['value']:\n                sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip()\n                sentence['value'] = DEFAULT_IMAGE_TOKEN + '\\n' + sentence['value']\n                sentence['value'] = sentence['value'].strip()\n                if \"mmtag\" in conversation_lib.default_conversation.version:\n                    sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '<Image>' + DEFAULT_IMAGE_TOKEN + '</Image>')\n            replace_token = DEFAULT_IMAGE_TOKEN\n            if data_args.mm_use_im_start_end:\n                replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN\n            sentence[\"value\"] = sentence[\"value\"].replace(DEFAULT_IMAGE_TOKEN, replace_token)\n\n    return sources\n\n\ndef preprocess_llama_2(\n    sources,\n    tokenizer: 
transformers.PreTrainedTokenizer,\n    has_image: bool = False\n) -> Dict:\n    conv = conversation_lib.default_conversation.copy()\n    roles = {\"human\": conv.roles[0], \"gpt\": conv.roles[1]}\n\n    # Apply prompt templates\n    conversations = []\n    for i, source in enumerate(sources):\n        if roles[source[0][\"from\"]] != conv.roles[0]:\n            # Skip the first one if it is not from human\n            source = source[1:]\n\n        conv.messages = []\n        for j, sentence in enumerate(source):\n            role = roles[sentence[\"from\"]]\n            assert role == conv.roles[j % 2], f\"{i}\"\n            conv.append_message(role, sentence[\"value\"])\n        conversations.append(conv.get_prompt())\n\n    # Tokenize conversations\n\n    if has_image:\n        input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)\n    else:\n        input_ids = tokenizer(\n            conversations,\n            return_tensors=\"pt\",\n            padding=\"longest\",\n            max_length=tokenizer.model_max_length,\n            truncation=True,\n        ).input_ids\n\n    targets = input_ids.clone()\n\n    assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2\n\n    # Mask targets\n    sep = \"[/INST] \"\n    for conversation, target in zip(conversations, targets):\n        total_len = int(target.ne(tokenizer.pad_token_id).sum())\n\n        rounds = conversation.split(conv.sep2)\n        cur_len = 1\n        target[:cur_len] = IGNORE_INDEX\n        for i, rou in enumerate(rounds):\n            if rou == \"\":\n                break\n\n            parts = rou.split(sep)\n            if len(parts) != 2:\n                break\n            parts[0] += sep\n\n            if has_image:\n                round_len = len(tokenizer_image_token(rou, tokenizer))\n                instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2\n            else:\n                
round_len = len(tokenizer(rou).input_ids)\n                instruction_len = len(tokenizer(parts[0]).input_ids) - 2\n\n            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX\n\n            cur_len += round_len\n        target[cur_len:] = IGNORE_INDEX\n\n        if cur_len < tokenizer.model_max_length:\n            if cur_len != total_len:\n                target[:] = IGNORE_INDEX\n                print(\n                    f\"WARNING: tokenization mismatch: {cur_len} vs. {total_len}.\"\n                    f\" (ignored)\"\n                )\n\n    return dict(\n        input_ids=input_ids,\n        labels=targets,\n    )\n\n\ndef preprocess_v1(\n    sources,\n    tokenizer: transformers.PreTrainedTokenizer,\n    has_image: bool = False\n) -> Dict:\n    conv = conversation_lib.default_conversation.copy()\n    roles = {\"human\": conv.roles[0], \"gpt\": conv.roles[1]}\n\n    # Apply prompt templates\n    conversations = []\n    for i, source in enumerate(sources):\n        if roles[source[0][\"from\"]] != conv.roles[0]:\n            # Skip the first one if it is not from human\n            source = source[1:]\n\n        conv.messages = []\n        for j, sentence in enumerate(source):\n            role = roles[sentence[\"from\"]]\n            assert role == conv.roles[j % 2], f\"{i}\"\n            conv.append_message(role, sentence[\"value\"])\n        conversations.append(conv.get_prompt())\n\n    # Tokenize conversations\n\n    if has_image:\n        input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)\n    else:\n        input_ids = tokenizer(\n            conversations,\n            return_tensors=\"pt\",\n            padding=\"longest\",\n            max_length=tokenizer.model_max_length,\n            truncation=True,\n        ).input_ids\n\n    targets = input_ids.clone()\n\n    assert conv.sep_style == conversation_lib.SeparatorStyle.TWO\n\n    # Mask targets\n    sep = 
conv.sep + conv.roles[1] + \": \"\n    for conversation, target in zip(conversations, targets):\n        total_len = int(target.ne(tokenizer.pad_token_id).sum())\n\n        rounds = conversation.split(conv.sep2)\n        cur_len = 1\n        target[:cur_len] = IGNORE_INDEX\n        for i, rou in enumerate(rounds):\n            if rou == \"\":\n                break\n\n            parts = rou.split(sep)\n            if len(parts) != 2:\n                break\n            parts[0] += sep\n\n            if has_image:\n                round_len = len(tokenizer_image_token(rou, tokenizer))\n                instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2\n            else:\n                round_len = len(tokenizer(rou).input_ids)\n                instruction_len = len(tokenizer(parts[0]).input_ids) - 2\n\n            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX\n\n            cur_len += round_len\n        target[cur_len:] = IGNORE_INDEX\n\n        if cur_len < tokenizer.model_max_length:\n            if cur_len != total_len:\n                target[:] = IGNORE_INDEX\n                print(\n                    f\"WARNING: tokenization mismatch: {cur_len} vs. 
{total_len}.\"\n                    f\" (ignored)\"\n                )\n\n    return dict(\n        input_ids=input_ids,\n        labels=targets,\n    )\n\n\ndef preprocess_mpt(\n    sources,\n    tokenizer: transformers.PreTrainedTokenizer,\n) -> Dict:\n    conv = conversation_lib.default_conversation.copy()\n    roles = {\"human\": conv.roles[0], \"gpt\": conv.roles[1]}\n\n    # Apply prompt templates\n    conversations = []\n    for i, source in enumerate(sources):\n        if roles[source[0][\"from\"]] != conv.roles[0]:\n            # Skip the first one if it is not from human\n            source = source[1:]\n\n        conv.messages = []\n        for j, sentence in enumerate(source):\n            role = roles[sentence[\"from\"]]\n            assert role == conv.roles[j % 2], f\"{i}\"\n            conv.append_message(role, sentence[\"value\"])\n        conversations.append(conv.get_prompt())\n\n    # Tokenize conversations\n    input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)\n    targets = input_ids.clone()\n    assert conv.sep_style == conversation_lib.SeparatorStyle.MPT\n\n    # Mask targets\n    sep = conv.sep + conv.roles[1]\n    for conversation, target in zip(conversations, targets):\n        total_len = int(target.ne(tokenizer.pad_token_id).sum())\n\n        rounds = conversation.split(conv.sep)\n        re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt\n        for conv_idx in range(3, len(rounds), 2):\n            re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx+2]))    # user + gpt\n        cur_len = 0\n        target[:cur_len] = IGNORE_INDEX\n        for i, rou in enumerate(re_rounds):\n            if rou == \"\":\n                break\n\n            parts = rou.split(sep)\n            if len(parts) != 2:\n                break\n            parts[0] += sep\n            round_len = len(tokenizer_image_token(rou, tokenizer)) + 
len(tokenizer_image_token(conv.sep, tokenizer))\n            instruction_len = len(tokenizer_image_token(parts[0], tokenizer))\n            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX\n\n            cur_len += round_len\n        target[cur_len:] = IGNORE_INDEX\n\n        if cur_len < tokenizer.model_max_length:\n            if cur_len != total_len:\n                target[:] = IGNORE_INDEX\n                print(\n                    f\"WARNING: tokenization mismatch: {cur_len} vs. {total_len}.\"\n                    f\" (ignored)\"\n                )\n\n    return dict(\n        input_ids=input_ids,\n        labels=targets,\n    )\n\n\ndef preprocess_plain(\n    sources: Sequence[str],\n    tokenizer: transformers.PreTrainedTokenizer,\n) -> Dict:\n    # add end signal and concatenate together\n    conversations = []\n    for source in sources:\n        assert len(source) == 2\n        assert DEFAULT_IMAGE_TOKEN in source[0]['value']\n        source[0]['value'] = DEFAULT_IMAGE_TOKEN\n        conversation = source[0]['value'] + source[1]['value'] + conversation_lib.default_conversation.sep\n        conversations.append(conversation)\n    # tokenize conversations\n    input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]\n    targets = copy.deepcopy(input_ids)\n    for target, source in zip(targets, sources):\n        tokenized_len = len(tokenizer_image_token(source[0]['value'], tokenizer))\n        target[:tokenized_len] = IGNORE_INDEX\n\n    return dict(input_ids=input_ids, labels=targets)\n\n\ndef preprocess(\n    sources: Sequence[str],\n    tokenizer: transformers.PreTrainedTokenizer,\n    has_image: bool = False\n) -> Dict:\n    \"\"\"\n    Given a list of sources, each is a conversation list. This transform:\n    1. Add signal '### ' at the beginning each sentence, with end signal '\\n';\n    2. Concatenate conversations together;\n    3. Tokenize the concatenated conversation;\n    4. 
Make a deepcopy as the target. Mask human words with IGNORE_INDEX.\n    \"\"\"\n    if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN:\n        return preprocess_plain(sources, tokenizer)\n    if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.LLAMA_2:\n        return preprocess_llama_2(sources, tokenizer, has_image=has_image)\n    if conversation_lib.default_conversation.version.startswith(\"v1\"):\n        return preprocess_v1(sources, tokenizer, has_image=has_image)\n    if conversation_lib.default_conversation.version == \"mpt\":\n        return preprocess_mpt(sources, tokenizer)\n    # add end signal and concatenate together\n    conversations = []\n    for source in sources:\n        header = f\"{conversation_lib.default_conversation.system}\\n\\n\"\n        conversation = _add_speaker_and_signal(header, source)\n        conversations.append(conversation)\n    # tokenize conversations\n    def get_tokenize_len(prompts):\n        return [len(tokenizer_image_token(prompt, tokenizer)) for prompt in prompts]\n\n    if has_image:\n        input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]\n    else:\n        conversations_tokenized = _tokenize_fn(conversations, tokenizer)\n        input_ids = conversations_tokenized[\"input_ids\"]\n\n    targets = copy.deepcopy(input_ids)\n    for target, source in zip(targets, sources):\n        if has_image:\n            tokenized_lens = get_tokenize_len([header] + [s[\"value\"] for s in source])\n        else:\n            tokenized_lens = _tokenize_fn([header] + [s[\"value\"] for s in source], tokenizer)[\"input_ids_lens\"]\n        speakers = [sentence[\"from\"] for sentence in source]\n        _mask_targets(target, tokenized_lens, speakers)\n\n    return dict(input_ids=input_ids, labels=targets)\n\n\nclass LazySupervisedDataset(Dataset):\n    \"\"\"Dataset for supervised fine-tuning.\"\"\"\n\n    
def __init__(self, data_path: str,\n                 tokenizer: transformers.PreTrainedTokenizer,\n                 data_args: DataArguments):\n        super(LazySupervisedDataset, self).__init__()\n        list_data_dict = json.load(open(data_path, \"r\"))\n\n        rank0_print(\"Formatting inputs...Skip in lazy mode\")\n        self.tokenizer = tokenizer\n        self.list_data_dict = list_data_dict\n        self.data_args = data_args\n\n    def __len__(self):\n        return len(self.list_data_dict)\n\n    def __getitem__(self, i) -> Dict[str, torch.Tensor]:\n        try:\n            sources = self.list_data_dict[i]\n            if isinstance(i, int):\n                sources = [sources]\n            assert len(sources) == 1, \"Don't know why it is wrapped to a list\"  # FIXME\n            if 'image' in sources[0]:\n                image_file = self.list_data_dict[i]['image']\n                image_folder = self.data_args.image_folder\n                processor = self.data_args.image_processor\n                image = Image.open(os.path.join(image_folder, image_file)).convert('RGB')\n                if self.data_args.image_aspect_ratio == 'pad':\n                    def expand2square(pil_img, background_color):\n                        width, height = pil_img.size\n                        if width == height:\n                            return pil_img\n                        elif width > height:\n                            result = Image.new(pil_img.mode, (width, width), background_color)\n                            result.paste(pil_img, (0, (width - height) // 2))\n                            return result\n                        else:\n                            result = Image.new(pil_img.mode, (height, height), background_color)\n                            result.paste(pil_img, ((height - width) // 2, 0))\n                            return result\n                    image = expand2square(image, tuple(int(x*255) for x in processor.image_mean))\n      
              image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]\n                else:\n                    image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]\n                sources = preprocess_multimodal(\n                    copy.deepcopy([e[\"conversations\"] for e in sources]),\n                    self.data_args)\n            else:\n                sources = copy.deepcopy([e[\"conversations\"] for e in sources])\n            data_dict = preprocess(\n                sources,\n                self.tokenizer,\n                has_image=('image' in self.list_data_dict[i]))\n            if isinstance(i, int):\n                data_dict = dict(input_ids=data_dict[\"input_ids\"][0],\n                                 labels=data_dict[\"labels\"][0])\n\n            # image exist in the data\n            if 'image' in self.list_data_dict[i]:\n                data_dict['image'] = image\n            elif self.data_args.is_multimodal:\n                # image does not exist in the data, but the model is multimodal\n                crop_size = self.data_args.image_processor.crop_size\n                data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width'])\n            return data_dict\n        except Exception:\n            print(self.list_data_dict[i], \"failed\")\n            return self.__getitem__(i + 1)\n\n\n@dataclass\nclass DataCollatorForSupervisedDataset(object):\n    \"\"\"Collate examples for supervised fine-tuning.\"\"\"\n\n    tokenizer: transformers.PreTrainedTokenizer\n\n    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:\n        input_ids, labels = tuple([instance[key] for instance in instances]\n                                  for key in (\"input_ids\", \"labels\"))\n        input_ids = torch.nn.utils.rnn.pad_sequence(\n            input_ids,\n            batch_first=True,\n            padding_value=self.tokenizer.pad_token_id)\n        labels = 
torch.nn.utils.rnn.pad_sequence(labels,\n                                                 batch_first=True,\n                                                 padding_value=IGNORE_INDEX)\n        input_ids = input_ids[:, :self.tokenizer.model_max_length]\n        labels = labels[:, :self.tokenizer.model_max_length]\n        batch = dict(\n            input_ids=input_ids,\n            labels=labels,\n            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),\n        )\n\n        if 'image' in instances[0]:\n            images = [instance['image'] for instance in instances]\n            if all(x is not None and x.shape == images[0].shape for x in images):\n                batch['images'] = torch.stack(images)\n            else:\n                batch['images'] = images\n\n        return batch\n\n@dataclass\nclass DataCollatorForSupervisedDatasetEmpty(object):\n    \"\"\"Collate examples for supervised fine-tuning.\"\"\"\n\n    tokenizer: transformers.PreTrainedTokenizer\n\n    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:\n        return instances\n        # input_ids, labels = tuple([instance[key] for instance in instances]\n        #                           for key in (\"input_ids\", \"labels\"))\n        # input_ids = torch.nn.utils.rnn.pad_sequence(\n        #     input_ids,\n        #     batch_first=True,\n        #     padding_value=self.tokenizer.pad_token_id)\n        # labels = torch.nn.utils.rnn.pad_sequence(labels,\n        #                                          batch_first=True,\n        #                                          padding_value=IGNORE_INDEX)\n        # input_ids = input_ids[:, :self.tokenizer.model_max_length]\n        # labels = labels[:, :self.tokenizer.model_max_length]\n        # batch = dict(\n        #     input_ids=input_ids,\n        #     labels=labels,\n        #     attention_mask=input_ids.ne(self.tokenizer.pad_token_id),\n        # )\n        #\n        # if 'image' in 
instances[0]:\n        #     images = [instance['image'] for instance in instances]\n        #     if all(x is not None and x.shape == images[0].shape for x in images):\n        #         batch['images'] = torch.stack(images)\n        #     else:\n        #         batch['images'] = images\n        #\n        # return batch\n\ndef make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,\n                                data_args) -> Dict:\n    \"\"\"Make dataset and collator for supervised fine-tuning.\"\"\"\n    train_dataset = LazySupervisedDataset(tokenizer=tokenizer,\n                                data_path=data_args.data_path,\n                                data_args=data_args)\n    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)\n    # data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)\n    return dict(train_dataset=train_dataset,\n                eval_dataset=None,\n                data_collator=data_collator)\n\nfrom detectron2.config import LazyConfig, instantiate\n\ndef setup(args):\n    \"\"\"\n    Create configs and perform basic setups.\n    \"\"\"\n    cfg = LazyConfig.load(args.config_file)\n    # import pdb;pdb.set_trace()\n    opt=args.opt.split(',')\n    cfg = LazyConfig.apply_overrides(cfg, opt)\n    # cfg.freeze()\n    # default_setup(cfg, args)\n    # setup_logger(output=cfg.OUTPUT_DIR, distributed_rank=comm.get_rank(), name=\"maskdino\")\n    return cfg\n\ndef train():\n    global local_rank\n\n    parser = transformers.HfArgumentParser(\n        (ModelArguments, DataArguments, TrainingArguments))\n    model_args, data_args, training_args = parser.parse_args_into_dataclasses()\n    local_rank = training_args.local_rank\n    compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))\n    cfg=setup(model_args)\n    bnb_model_from_pretrained_args = {}\n    if training_args.bits in [4, 8]:\n        from transformers import 
BitsAndBytesConfig\n        bnb_model_from_pretrained_args.update(dict(\n            device_map={\"\": training_args.device},\n            load_in_4bit=training_args.bits == 4,\n            load_in_8bit=training_args.bits == 8,\n            quantization_config=BitsAndBytesConfig(\n                load_in_4bit=training_args.bits == 4,\n                load_in_8bit=training_args.bits == 8,\n                llm_int8_threshold=6.0,\n                llm_int8_has_fp16_weight=False,\n                bnb_4bit_compute_dtype=compute_dtype,\n                bnb_4bit_use_double_quant=training_args.double_quant,\n                bnb_4bit_quant_type=training_args.quant_type # {'fp4', 'nf4'}\n            )\n        ))\n\n    if model_args.vision_tower is not None:\n        if 'mpt' in model_args.model_name_or_path:\n            config = transformers.AutoConfig.from_pretrained(model_args.model_name_or_path,cache_dir=\"/comp_robot/zhanghao/.cache/hugging_face/\", trust_remote_code=True)\n            config.attn_config['attn_impl'] = training_args.mpt_attn_impl\n            model = LlavaMPTForCausalLM.from_pretrained(\n                model_args.model_name_or_path,\n                config=config,\n                cache_dir=\"/comp_robot/zhanghao/.cache/hugging_face/\",\n                **bnb_model_from_pretrained_args\n            )\n        else:\n            model = LlavaLlamaForCausalLM_joint.from_pretrained(\n                model_args.model_name_or_path,\n                cache_dir=\"/comp_robot/zhanghao/.cache/hugging_face/\",\n                **bnb_model_from_pretrained_args\n            )\n    else:\n        model = transformers.LlamaForCausalLM.from_pretrained(\n            model_args.model_name_or_path,\n            cache_dir=\"/comp_robot/zhanghao/.cache/hugging_face/\",\n            **bnb_model_from_pretrained_args\n        )\n    model.config.use_cache = False\n\n    if model_args.freeze_backbone:\n        model.model.requires_grad_(False)\n\n    if training_args.bits in 
[4, 8]:\n        from peft import prepare_model_for_kbit_training\n        model.config.torch_dtype=(torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))\n        model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)\n\n    if training_args.gradient_checkpointing:\n        if hasattr(model, \"enable_input_require_grads\"):\n            model.enable_input_require_grads()\n        else:\n            def make_inputs_require_grad(module, input, output):\n                output.requires_grad_(True)\n            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)\n\n    if training_args.lora_enable:\n        from peft import LoraConfig, get_peft_model\n        lora_config = LoraConfig(\n            r=training_args.lora_r,\n            lora_alpha=training_args.lora_alpha,\n            target_modules=find_all_linear_names(model),\n            lora_dropout=training_args.lora_dropout,\n            bias=training_args.lora_bias,\n            task_type=\"CAUSAL_LM\",\n        )\n        if training_args.bits == 16:\n            if training_args.bf16:\n                model.to(torch.bfloat16)\n            if training_args.fp16:\n                model.to(torch.float16)\n        rank0_print(\"Adding LoRA adapters...\")\n        model = get_peft_model(model, lora_config)\n\n    if 'mpt' in model_args.model_name_or_path:\n        tokenizer = transformers.AutoTokenizer.from_pretrained(\n            model_args.model_name_or_path,cache_dir=\"/comp_robot/zhanghao/.cache/hugging_face/\",\n            model_max_length=training_args.model_max_length,\n            padding_side=\"right\"\n        )\n    else:\n        tokenizer = transformers.AutoTokenizer.from_pretrained(\n            model_args.model_name_or_path,\n            cache_dir=\"/comp_robot/zhanghao/.cache/hugging_face/\",\n            model_max_length=training_args.model_max_length,\n            
padding_side=\"right\",\n            use_fast=False,\n        )\n\n    if model_args.version == \"v0\":\n        if tokenizer.pad_token is None:\n            smart_tokenizer_and_embedding_resize(\n                special_tokens_dict=dict(pad_token=\"[PAD]\"),\n                tokenizer=tokenizer,\n                model=model,\n            )\n    elif model_args.version == \"v0.5\":\n        tokenizer.pad_token = tokenizer.unk_token\n    else:\n        tokenizer.pad_token = tokenizer.unk_token\n        if model_args.version in conversation_lib.conv_templates:\n            conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version]\n        else:\n            conversation_lib.default_conversation = conversation_lib.conv_templates[\"vicuna_v1\"]\n\n    if model_args.vision_tower is not None:\n        model.get_model().initialize_vision_modules(\n            model_args=model_args,\n            fsdp=training_args.fsdp\n        )\n        \n        vision_tower = model.get_vision_tower()\n        vision_tower.to(dtype=torch.float16, device=training_args.device)\n\n        data_args.image_processor = vision_tower.image_processor\n        data_args.is_multimodal = True\n\n        model.config.image_aspect_ratio = data_args.image_aspect_ratio\n        model.config.image_grid_pinpoints = data_args.image_grid_pinpoints\n\n        model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter\n        if model_args.tune_mm_mlp_adapter or training_args.dbg:\n            model.requires_grad_(False)\n            for p in model.get_model().mm_projector.parameters():\n                p.requires_grad = True\n\n        model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter\n        if training_args.freeze_mm_mlp_adapter:\n            for p in model.get_model().mm_projector.parameters():\n                p.requires_grad = False\n\n        if training_args.bits in [4, 8]:\n            
model.get_model().mm_projector.to(dtype=compute_dtype, device=training_args.device)\n\n        model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = model_args.mm_use_im_start_end\n        training_args.use_im_start_end = model_args.mm_use_im_start_end\n        model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token\n        model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer)\n    cfg.MODEL.DIM_PROJ=model.get_model().config.hidden_size\n    model.initialize_seg_modules(\n        cfg=cfg\n    )\n\n    if training_args.bits in [4, 8]:\n        from peft.tuners.lora import LoraLayer\n        for name, module in model.named_modules():\n            if isinstance(module, LoraLayer):\n                if training_args.bf16:\n                    module = module.to(torch.bfloat16)\n            if 'norm' in name:\n                module = module.to(torch.float32)\n            if 'lm_head' in name or 'embed_tokens' in name:\n                if hasattr(module, 'weight'):\n                    if training_args.bf16 and module.weight.dtype == torch.float32:\n                        module = module.to(torch.bfloat16)\n\n    data_module = make_supervised_data_module(tokenizer=tokenizer,\n                                              data_args=data_args)\n    print(model)\n    if model_args.load_model:\n        loaded_dict = dict()\n        if \"stage1\" in model_args.whole_model:\n            old_emb_in=model.get_input_embeddings().weight.clone()\n            old_emb_out=model.get_output_embeddings().weight.clone()\n        for model_file in os.listdir(model_args.whole_model):\n            if model_file.endswith('.bin') and model_file.startswith('pytorch_model'):\n                loaded_dict.update(torch.load(os.path.join(model_args.whole_model, model_file), map_location='cpu'))\n\n        model.load_state_dict(loaded_dict, strict=False)\n        if \"stage1\" in model_args.whole_model:\n            with torch.no_grad():\n                
model.get_input_embeddings().weight[:-3]=old_emb_in[:-3]\n                model.get_output_embeddings().weight[:-3]=old_emb_out[:-3]\n        print(loaded_dict.keys())\n\n    trainer = LLaVATrainer(model=model,\n                    tokenizer=tokenizer,\n                    args=training_args,cfg=cfg,data_loader_args=(tokenizer, data_args,preprocess),\n                    **data_module)\n\n    if list(pathlib.Path(training_args.output_dir).glob(\"checkpoint-*\")):\n        trainer.train(resume_from_checkpoint=True)\n    else:\n        trainer.train()\n    trainer.save_state()\n\n    model.config.use_cache = True\n\n    if training_args.lora_enable:\n        state_dict = get_peft_state_maybe_zero_3(\n            model.named_parameters(), training_args.lora_bias\n        )\n        non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(\n            model.named_parameters()\n        )\n        if training_args.local_rank == 0 or training_args.local_rank == -1:\n            model.config.save_pretrained(training_args.output_dir)\n            model.save_pretrained(training_args.output_dir, state_dict=state_dict)\n            torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin'))\n    else:\n        safe_save_model_for_hf_trainer(trainer=trainer,\n                                       output_dir=training_args.output_dir)\n\n\nif __name__ == \"__main__\":\n    train()\n"
  },
  {
    "path": "llava/train/train_joint_2st.py",
    "content": "# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:\n# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:\n#    Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li\n#\n#    Licensed under the Apache License, Version 2.0 (the \"License\");\n#    you may not use this file except in compliance with the License.\n#    You may obtain a copy of the License at\n#\n#        http://www.apache.org/licenses/LICENSE-2.0\n#\n#    Unless required by applicable law or agreed to in writing, software\n#    distributed under the License is distributed on an \"AS IS\" BASIS,\n#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#    See the License for the specific language governing permissions and\n#    limitations under the License.\nfrom llava.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn\nreplace_llama_attn_with_flash_attn()\nimport os\nimport copy\nfrom dataclasses import dataclass, field\nimport json\nimport logging\nimport pathlib\nfrom typing import Dict, Optional, Sequence, List\n\nimport torch\n\nimport transformers\n\nfrom llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN\nfrom torch.utils.data import Dataset\nfrom llava.train.llava_trainer_joint_train import LLaVATrainer\n\nfrom llava import conversation as conversation_lib\nfrom llava.model import *\nfrom llava.mm_utils import tokenizer_image_token\n\nfrom PIL import Image\n\n\nlocal_rank = None\n\n\ndef rank0_print(*args):\n    if local_rank == 0:\n        print(*args)\n\n\n@dataclass\nclass ModelArguments:\n    model_name_or_path: Optional[str] = field(default=\"facebook/opt-125m\")\n    whole_model: Optional[str] = field(default=\"facebook/opt-125m\")\n    version: Optional[str] = field(default=\"v0\")\n    freeze_backbone: bool = field(default=False)\n    tune_mm_mlp_adapter: bool = 
field(default=False)\n    vision_tower: Optional[str] = field(default=None)\n    mm_vision_select_layer: Optional[int] = field(default=-1)   # default to the last layer\n    pretrain_mm_mlp_adapter: Optional[str] = field(default=None)\n    mm_use_im_start_end: bool = field(default=False)\n    load_model: bool = field(default=False)\n    mm_use_im_patch_token: bool = field(default=True)\n    mm_vision_select_feature: Optional[str] = field(default=\"patch\")\n    opt: Optional[str] = field(default=\"\")\n    config_file: Optional[str] = field(default=\"\")\n\n\n@dataclass\nclass DataArguments:\n    data_path: str = field(default=None,\n                           metadata={\"help\": \"Path to the training data.\"})\n    lazy_preprocess: bool = False\n    is_multimodal: bool = False\n    image_folder: Optional[str] = field(default=None)\n    image_aspect_ratio: str = 'square'\n    image_grid_pinpoints: Optional[str] = field(default=None)\n\n\n@dataclass\nclass TrainingArguments(transformers.TrainingArguments):\n    cache_dir: Optional[str] = field(default=None)\n    optim: str = field(default=\"adamw_torch\")\n    remove_unused_columns: bool = field(default=False)\n    freeze_mm_mlp_adapter: bool = field(default=False)\n    mpt_attn_impl: Optional[str] = field(default=\"triton\")\n    model_max_length: int = field(\n        default=512,\n        metadata={\n            \"help\":\n            \"Maximum sequence length. Sequences will be right padded (and possibly truncated).\"\n        },\n    )\n    double_quant: bool = field(\n        default=True,\n        metadata={\"help\": \"Compress the quantization statistics through double quantization.\"}\n    )\n    quant_type: str = field(\n        default=\"nf4\",\n        metadata={\"help\": \"Quantization data type to use. 
Should be one of `fp4` or `nf4`.\"}\n    )\n    bits: int = field(\n        default=16,\n        metadata={\"help\": \"How many bits to use.\"}\n    )\n    lora_enable: bool = False\n    new_tokens: bool = True\n    lora_r: int = 64\n    lora_alpha: int = 16\n    lora_dropout: float = 0.05\n    lora_weight_path: str = \"\"\n    lora_bias: str = \"none\"\n    dbg: bool = False\n    load_optimizer_states: bool = True\n    load_lr_scheduler_states: bool = True\n\n\ndef maybe_zero_3(param, ignore_status=False, name=None):\n    from deepspeed import zero\n    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus\n    if hasattr(param, \"ds_id\"):\n        if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:\n            if not ignore_status:\n                logging.warning(f\"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}\")\n        with zero.GatheredParameters([param]):\n            param = param.data.detach().cpu().clone()\n    else:\n        param = param.detach().cpu().clone()\n    return param\n\n\n# Borrowed from peft.utils.get_peft_model_state_dict\ndef get_peft_state_maybe_zero_3(named_params, bias):\n    if bias == \"none\":\n        to_return = {k: t for k, t in named_params if \"lora_\" in k}\n    elif bias == \"all\":\n        to_return = {k: t for k, t in named_params if \"lora_\" in k or \"bias\" in k}\n    elif bias == \"lora_only\":\n        to_return = {}\n        maybe_lora_bias = {}\n        lora_bias_names = set()\n        for k, t in named_params:\n            if \"lora_\" in k:\n                to_return[k] = t\n                bias_name = k.split(\"lora_\")[0] + \"bias\"\n                lora_bias_names.add(bias_name)\n            elif \"bias\" in k:\n                maybe_lora_bias[k] = t\n        for k, t in maybe_lora_bias:\n            if bias_name in lora_bias_names:\n                to_return[bias_name] = t\n    else:\n        raise NotImplementedError\n    to_return = {k: maybe_zero_3(v, 
name=k) for k, v in to_return.items()}\n    return to_return\n\n\ndef get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):\n    to_return = {k: t for k, t in named_params if \"lora_\" not in k}\n    if require_grad_only:\n        to_return = {k: t for k, t in to_return.items() if t.requires_grad}\n    to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}\n    return to_return\n\n\ndef get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):\n    to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}\n    to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}\n    return to_return\n\n\ndef find_all_linear_names(model):\n    cls = torch.nn.Linear\n    lora_module_names = set()\n    for name, module in model.named_modules():\n        if isinstance(module, cls):\n            names = name.split('.')\n            lora_module_names.add(names[0] if len(names) == 1 else names[-1])\n\n\n    if 'lm_head' in lora_module_names: # needed for 16-bit\n        lora_module_names.remove('lm_head')\n    return list(lora_module_names)\n\n\ndef safe_save_model_for_hf_trainer(trainer: transformers.Trainer,\n                                   output_dir: str):\n    \"\"\"Collects the state dict and dump to disk.\"\"\"\n\n    if trainer.deepspeed:\n        torch.cuda.synchronize()\n        trainer.save_model(output_dir)\n        return\n\n    state_dict = trainer.model.state_dict()\n    if trainer.args.should_save:\n        cpu_state_dict = {\n            key: value.cpu()\n            for key, value in state_dict.items()\n        }\n        del state_dict\n        trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa\n\n\ndef smart_tokenizer_and_embedding_resize(\n    special_tokens_dict: Dict,\n    tokenizer: transformers.PreTrainedTokenizer,\n    model: transformers.PreTrainedModel,\n):\n    \"\"\"Resize tokenizer and embedding.\n\n    Note: 
This is the unoptimized version that may make your embedding size not be divisible by 64.\n    \"\"\"\n    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)\n    model.resize_token_embeddings(len(tokenizer))\n\n    if num_new_tokens > 0:\n        input_embeddings = model.get_input_embeddings().weight.data\n        output_embeddings = model.get_output_embeddings().weight.data\n\n        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(\n            dim=0, keepdim=True)\n        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(\n            dim=0, keepdim=True)\n\n        input_embeddings[-num_new_tokens:] = input_embeddings_avg\n        output_embeddings[-num_new_tokens:] = output_embeddings_avg\n\n\ndef _tokenize_fn(strings: Sequence[str],\n                 tokenizer: transformers.PreTrainedTokenizer) -> Dict:\n    \"\"\"Tokenize a list of strings.\"\"\"\n    tokenized_list = [\n        tokenizer(\n            text,\n            return_tensors=\"pt\",\n            padding=\"longest\",\n            max_length=tokenizer.model_max_length,\n            truncation=True,\n        ) for text in strings\n    ]\n    input_ids = labels = [\n        tokenized.input_ids[0] for tokenized in tokenized_list\n    ]\n    input_ids_lens = labels_lens = [\n        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()\n        for tokenized in tokenized_list\n    ]\n    return dict(\n        input_ids=input_ids,\n        labels=labels,\n        input_ids_lens=input_ids_lens,\n        labels_lens=labels_lens,\n    )\n\n\ndef _mask_targets(target, tokenized_lens, speakers):\n    # cur_idx = 0\n    cur_idx = tokenized_lens[0]\n    tokenized_lens = tokenized_lens[1:]\n    target[:cur_idx] = IGNORE_INDEX\n    for tokenized_len, speaker in zip(tokenized_lens, speakers):\n        if speaker == \"human\":\n            target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX\n        cur_idx += tokenized_len\n\n\ndef 
_add_speaker_and_signal(header, source, get_conversation=True):\n    \"\"\"Add speaker and start/end signal on each round.\"\"\"\n    BEGIN_SIGNAL = \"### \"\n    END_SIGNAL = \"\\n\"\n    conversation = header\n    for sentence in source:\n        from_str = sentence[\"from\"]\n        if from_str.lower() == \"human\":\n            from_str = conversation_lib.default_conversation.roles[0]\n        elif from_str.lower() == \"gpt\":\n            from_str = conversation_lib.default_conversation.roles[1]\n        else:\n            from_str = 'unknown'\n        sentence[\"value\"] = (BEGIN_SIGNAL + from_str + \": \" +\n                             sentence[\"value\"] + END_SIGNAL)\n        if get_conversation:\n            conversation += sentence[\"value\"]\n    conversation += BEGIN_SIGNAL\n    return conversation\n\n\ndef preprocess_multimodal(\n    sources: Sequence[str],\n    data_args: DataArguments\n) -> Dict:\n    is_multimodal = data_args.is_multimodal\n    if not is_multimodal:\n        return sources\n\n    for source in sources:\n        for sentence in source:\n            if DEFAULT_IMAGE_TOKEN in sentence['value']:\n                sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip()\n                sentence['value'] = DEFAULT_IMAGE_TOKEN + '\\n' + sentence['value']\n                sentence['value'] = sentence['value'].strip()\n                if \"mmtag\" in conversation_lib.default_conversation.version:\n                    sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '<Image>' + DEFAULT_IMAGE_TOKEN + '</Image>')\n            replace_token = DEFAULT_IMAGE_TOKEN\n            if data_args.mm_use_im_start_end:\n                replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN\n            sentence[\"value\"] = sentence[\"value\"].replace(DEFAULT_IMAGE_TOKEN, replace_token)\n\n    return sources\n\n\ndef preprocess_llama_2(\n    sources,\n    tokenizer: 
transformers.PreTrainedTokenizer,\n    has_image: bool = False\n) -> Dict:\n    conv = conversation_lib.default_conversation.copy()\n    roles = {\"human\": conv.roles[0], \"gpt\": conv.roles[1]}\n\n    # Apply prompt templates\n    conversations = []\n    for i, source in enumerate(sources):\n        if roles[source[0][\"from\"]] != conv.roles[0]:\n            # Skip the first one if it is not from human\n            source = source[1:]\n\n        conv.messages = []\n        for j, sentence in enumerate(source):\n            role = roles[sentence[\"from\"]]\n            assert role == conv.roles[j % 2], f\"{i}\"\n            conv.append_message(role, sentence[\"value\"])\n        conversations.append(conv.get_prompt())\n\n    # Tokenize conversations\n\n    if has_image:\n        input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)\n    else:\n        input_ids = tokenizer(\n            conversations,\n            return_tensors=\"pt\",\n            padding=\"longest\",\n            max_length=tokenizer.model_max_length,\n            truncation=True,\n        ).input_ids\n\n    targets = input_ids.clone()\n\n    assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2\n\n    # Mask targets\n    sep = \"[/INST] \"\n    for conversation, target in zip(conversations, targets):\n        total_len = int(target.ne(tokenizer.pad_token_id).sum())\n\n        rounds = conversation.split(conv.sep2)\n        cur_len = 1\n        target[:cur_len] = IGNORE_INDEX\n        for i, rou in enumerate(rounds):\n            if rou == \"\":\n                break\n\n            parts = rou.split(sep)\n            if len(parts) != 2:\n                break\n            parts[0] += sep\n\n            if has_image:\n                round_len = len(tokenizer_image_token(rou, tokenizer))\n                instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2\n            else:\n                
round_len = len(tokenizer(rou).input_ids)\n                instruction_len = len(tokenizer(parts[0]).input_ids) - 2\n\n            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX\n\n            cur_len += round_len\n        target[cur_len:] = IGNORE_INDEX\n\n        if cur_len < tokenizer.model_max_length:\n            if cur_len != total_len:\n                target[:] = IGNORE_INDEX\n                print(\n                    f\"WARNING: tokenization mismatch: {cur_len} vs. {total_len}.\"\n                    f\" (ignored)\"\n                )\n\n    return dict(\n        input_ids=input_ids,\n        labels=targets,\n    )\n\n\ndef preprocess_v1(\n    sources,\n    tokenizer: transformers.PreTrainedTokenizer,\n    has_image: bool = False\n) -> Dict:\n    conv = conversation_lib.default_conversation.copy()\n    roles = {\"human\": conv.roles[0], \"gpt\": conv.roles[1]}\n\n    # Apply prompt templates\n    conversations = []\n    for i, source in enumerate(sources):\n        if roles[source[0][\"from\"]] != conv.roles[0]:\n            # Skip the first one if it is not from human\n            source = source[1:]\n\n        conv.messages = []\n        for j, sentence in enumerate(source):\n            role = roles[sentence[\"from\"]]\n            assert role == conv.roles[j % 2], f\"{i}\"\n            conv.append_message(role, sentence[\"value\"])\n        conversations.append(conv.get_prompt())\n\n    # Tokenize conversations\n\n    if has_image:\n        input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)\n    else:\n        input_ids = tokenizer(\n            conversations,\n            return_tensors=\"pt\",\n            padding=\"longest\",\n            max_length=tokenizer.model_max_length,\n            truncation=True,\n        ).input_ids\n\n    targets = input_ids.clone()\n\n    assert conv.sep_style == conversation_lib.SeparatorStyle.TWO\n\n    # Mask targets\n    sep = 
conv.sep + conv.roles[1] + \": \"\n    for conversation, target in zip(conversations, targets):\n        total_len = int(target.ne(tokenizer.pad_token_id).sum())\n\n        rounds = conversation.split(conv.sep2)\n        cur_len = 1\n        target[:cur_len] = IGNORE_INDEX\n        for i, rou in enumerate(rounds):\n            if rou == \"\":\n                break\n\n            parts = rou.split(sep)\n            if len(parts) != 2:\n                break\n            parts[0] += sep\n\n            if has_image:\n                round_len = len(tokenizer_image_token(rou, tokenizer))\n                instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2\n            else:\n                round_len = len(tokenizer(rou).input_ids)\n                instruction_len = len(tokenizer(parts[0]).input_ids) - 2\n\n            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX\n\n            cur_len += round_len\n        target[cur_len:] = IGNORE_INDEX\n\n        if cur_len < tokenizer.model_max_length:\n            if cur_len != total_len:\n                target[:] = IGNORE_INDEX\n                print(\n                    f\"WARNING: tokenization mismatch: {cur_len} vs. 
{total_len}.\"\n                    f\" (ignored)\"\n                )\n\n    return dict(\n        input_ids=input_ids,\n        labels=targets,\n    )\n\n\ndef preprocess_mpt(\n    sources,\n    tokenizer: transformers.PreTrainedTokenizer,\n) -> Dict:\n    conv = conversation_lib.default_conversation.copy()\n    roles = {\"human\": conv.roles[0], \"gpt\": conv.roles[1]}\n\n    # Apply prompt templates\n    conversations = []\n    for i, source in enumerate(sources):\n        if roles[source[0][\"from\"]] != conv.roles[0]:\n            # Skip the first one if it is not from human\n            source = source[1:]\n\n        conv.messages = []\n        for j, sentence in enumerate(source):\n            role = roles[sentence[\"from\"]]\n            assert role == conv.roles[j % 2], f\"{i}\"\n            conv.append_message(role, sentence[\"value\"])\n        conversations.append(conv.get_prompt())\n\n    # Tokenize conversations\n    input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)\n    targets = input_ids.clone()\n    assert conv.sep_style == conversation_lib.SeparatorStyle.MPT\n\n    # Mask targets\n    sep = conv.sep + conv.roles[1]\n    for conversation, target in zip(conversations, targets):\n        total_len = int(target.ne(tokenizer.pad_token_id).sum())\n\n        rounds = conversation.split(conv.sep)\n        re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt\n        for conv_idx in range(3, len(rounds), 2):\n            re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx+2]))    # user + gpt\n        cur_len = 0\n        target[:cur_len] = IGNORE_INDEX\n        for i, rou in enumerate(re_rounds):\n            if rou == \"\":\n                break\n\n            parts = rou.split(sep)\n            if len(parts) != 2:\n                break\n            parts[0] += sep\n            round_len = len(tokenizer_image_token(rou, tokenizer)) + 
len(tokenizer_image_token(conv.sep, tokenizer))\n            instruction_len = len(tokenizer_image_token(parts[0], tokenizer))\n            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX\n\n            cur_len += round_len\n        target[cur_len:] = IGNORE_INDEX\n\n        if cur_len < tokenizer.model_max_length:\n            if cur_len != total_len:\n                target[:] = IGNORE_INDEX\n                print(\n                    f\"WARNING: tokenization mismatch: {cur_len} vs. {total_len}.\"\n                    f\" (ignored)\"\n                )\n\n    return dict(\n        input_ids=input_ids,\n        labels=targets,\n    )\n\n\ndef preprocess_plain(\n    sources: Sequence[str],\n    tokenizer: transformers.PreTrainedTokenizer,\n) -> Dict:\n    # add end signal and concatenate together\n    conversations = []\n    for source in sources:\n        assert len(source) == 2\n        assert DEFAULT_IMAGE_TOKEN in source[0]['value']\n        source[0]['value'] = DEFAULT_IMAGE_TOKEN\n        conversation = source[0]['value'] + source[1]['value'] + conversation_lib.default_conversation.sep\n        conversations.append(conversation)\n    # tokenize conversations\n    input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]\n    targets = copy.deepcopy(input_ids)\n    for target, source in zip(targets, sources):\n        tokenized_len = len(tokenizer_image_token(source[0]['value'], tokenizer))\n        target[:tokenized_len] = IGNORE_INDEX\n\n    return dict(input_ids=input_ids, labels=targets)\n\n\ndef preprocess(\n    sources: Sequence[str],\n    tokenizer: transformers.PreTrainedTokenizer,\n    has_image: bool = False\n) -> Dict:\n    \"\"\"\n    Given a list of sources, each is a conversation list. This transform:\n    1. Add signal '### ' at the beginning each sentence, with end signal '\\n';\n    2. Concatenate conversations together;\n    3. Tokenize the concatenated conversation;\n    4. 
Make a deepcopy as the target. Mask human words with IGNORE_INDEX.\n    \"\"\"\n    if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN:\n        return preprocess_plain(sources, tokenizer)\n    if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.LLAMA_2:\n        return preprocess_llama_2(sources, tokenizer, has_image=has_image)\n    if conversation_lib.default_conversation.version.startswith(\"v1\"):\n        return preprocess_v1(sources, tokenizer, has_image=has_image)\n    if conversation_lib.default_conversation.version == \"mpt\":\n        return preprocess_mpt(sources, tokenizer)\n    # add end signal and concatenate together\n    conversations = []\n    for source in sources:\n        header = f\"{conversation_lib.default_conversation.system}\\n\\n\"\n        conversation = _add_speaker_and_signal(header, source)\n        conversations.append(conversation)\n    # tokenize conversations\n    def get_tokenize_len(prompts):\n        return [len(tokenizer_image_token(prompt, tokenizer)) for prompt in prompts]\n\n    if has_image:\n        input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]\n    else:\n        conversations_tokenized = _tokenize_fn(conversations, tokenizer)\n        input_ids = conversations_tokenized[\"input_ids\"]\n\n    targets = copy.deepcopy(input_ids)\n    for target, source in zip(targets, sources):\n        if has_image:\n            tokenized_lens = get_tokenize_len([header] + [s[\"value\"] for s in source])\n        else:\n            tokenized_lens = _tokenize_fn([header] + [s[\"value\"] for s in source], tokenizer)[\"input_ids_lens\"]\n        speakers = [sentence[\"from\"] for sentence in source]\n        _mask_targets(target, tokenized_lens, speakers)\n\n    return dict(input_ids=input_ids, labels=targets)\n\n\nclass LazySupervisedDataset(Dataset):\n    \"\"\"Dataset for supervised fine-tuning.\"\"\"\n\n    
def __init__(self, data_path: str,\n                 tokenizer: transformers.PreTrainedTokenizer,\n                 data_args: DataArguments):\n        super(LazySupervisedDataset, self).__init__()\n        list_data_dict = json.load(open(data_path, \"r\"))\n\n        rank0_print(\"Formatting inputs...Skip in lazy mode\")\n        self.tokenizer = tokenizer\n        self.list_data_dict = list_data_dict\n        self.data_args = data_args\n\n    def __len__(self):\n        return len(self.list_data_dict)\n\n    def __getitem__(self, i) -> Dict[str, torch.Tensor]:\n        try:\n            sources = self.list_data_dict[i]\n            if isinstance(i, int):\n                sources = [sources]\n            assert len(sources) == 1, \"Don't know why it is wrapped to a list\"  # FIXME\n            if 'image' in sources[0]:\n                image_file = self.list_data_dict[i]['image']\n                image_folder = self.data_args.image_folder\n                processor = self.data_args.image_processor\n                image = Image.open(os.path.join(image_folder, image_file)).convert('RGB')\n                if self.data_args.image_aspect_ratio == 'pad':\n                    def expand2square(pil_img, background_color):\n                        width, height = pil_img.size\n                        if width == height:\n                            return pil_img\n                        elif width > height:\n                            result = Image.new(pil_img.mode, (width, width), background_color)\n                            result.paste(pil_img, (0, (width - height) // 2))\n                            return result\n                        else:\n                            result = Image.new(pil_img.mode, (height, height), background_color)\n                            result.paste(pil_img, ((height - width) // 2, 0))\n                            return result\n                    image = expand2square(image, tuple(int(x*255) for x in processor.image_mean))\n      
              image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]\n                else:\n                    image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]\n                sources = preprocess_multimodal(\n                    copy.deepcopy([e[\"conversations\"] for e in sources]),\n                    self.data_args)\n            else:\n                sources = copy.deepcopy([e[\"conversations\"] for e in sources])\n            data_dict = preprocess(\n                sources,\n                self.tokenizer,\n                has_image=('image' in self.list_data_dict[i]))\n            if isinstance(i, int):\n                data_dict = dict(input_ids=data_dict[\"input_ids\"][0],\n                                 labels=data_dict[\"labels\"][0])\n\n            # image exist in the data\n            if 'image' in self.list_data_dict[i]:\n                data_dict['image'] = image\n            elif self.data_args.is_multimodal:\n                # image does not exist in the data, but the model is multimodal\n                crop_size = self.data_args.image_processor.crop_size\n                data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width'])\n            return data_dict\n        except Exception:\n            print(self.list_data_dict[i], \"failed\")\n            return self.__getitem__(i + 1)\n\n\n@dataclass\nclass DataCollatorForSupervisedDataset(object):\n    \"\"\"Collate examples for supervised fine-tuning.\"\"\"\n\n    tokenizer: transformers.PreTrainedTokenizer\n\n    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:\n        input_ids, labels = tuple([instance[key] for instance in instances]\n                                  for key in (\"input_ids\", \"labels\"))\n        input_ids = torch.nn.utils.rnn.pad_sequence(\n            input_ids,\n            batch_first=True,\n            padding_value=self.tokenizer.pad_token_id)\n        labels = 
torch.nn.utils.rnn.pad_sequence(labels,\n                                                 batch_first=True,\n                                                 padding_value=IGNORE_INDEX)\n        input_ids = input_ids[:, :self.tokenizer.model_max_length]\n        labels = labels[:, :self.tokenizer.model_max_length]\n        batch = dict(\n            input_ids=input_ids,\n            labels=labels,\n            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),\n        )\n\n        if 'image' in instances[0]:\n            images = [instance['image'] for instance in instances]\n            if all(x is not None and x.shape == images[0].shape for x in images):\n                batch['images'] = torch.stack(images)\n            else:\n                batch['images'] = images\n\n        return batch\n\n@dataclass\nclass DataCollatorForSupervisedDatasetEmpty(object):\n    \"\"\"Collate examples for supervised fine-tuning.\"\"\"\n\n    tokenizer: transformers.PreTrainedTokenizer\n\n    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:\n        return instances\n        # input_ids, labels = tuple([instance[key] for instance in instances]\n        #                           for key in (\"input_ids\", \"labels\"))\n        # input_ids = torch.nn.utils.rnn.pad_sequence(\n        #     input_ids,\n        #     batch_first=True,\n        #     padding_value=self.tokenizer.pad_token_id)\n        # labels = torch.nn.utils.rnn.pad_sequence(labels,\n        #                                          batch_first=True,\n        #                                          padding_value=IGNORE_INDEX)\n        # input_ids = input_ids[:, :self.tokenizer.model_max_length]\n        # labels = labels[:, :self.tokenizer.model_max_length]\n        # batch = dict(\n        #     input_ids=input_ids,\n        #     labels=labels,\n        #     attention_mask=input_ids.ne(self.tokenizer.pad_token_id),\n        # )\n        #\n        # if 'image' in 
instances[0]:\n        #     images = [instance['image'] for instance in instances]\n        #     if all(x is not None and x.shape == images[0].shape for x in images):\n        #         batch['images'] = torch.stack(images)\n        #     else:\n        #         batch['images'] = images\n        #\n        # return batch\n\ndef make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,\n                                data_args) -> Dict:\n    \"\"\"Make dataset and collator for supervised fine-tuning.\"\"\"\n    train_dataset = LazySupervisedDataset(tokenizer=tokenizer,\n                                data_path=data_args.data_path,\n                                data_args=data_args)\n    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)\n    # data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)\n    return dict(train_dataset=train_dataset,\n                eval_dataset=None,\n                data_collator=data_collator)\n\nfrom detectron2.config import LazyConfig, instantiate\n\ndef setup(args):\n    \"\"\"\n    Create configs and perform basic setups.\n    \"\"\"\n    cfg = LazyConfig.load(args.config_file)\n    # import pdb;pdb.set_trace()\n    opt=args.opt.split(',')\n    cfg = LazyConfig.apply_overrides(cfg, opt)\n    # cfg.freeze()\n    # default_setup(cfg, args)\n    # setup_logger(output=cfg.OUTPUT_DIR, distributed_rank=comm.get_rank(), name=\"maskdino\")\n    return cfg\n\ndef train():\n    global local_rank\n\n    parser = transformers.HfArgumentParser(\n        (ModelArguments, DataArguments, TrainingArguments))\n    model_args, data_args, training_args = parser.parse_args_into_dataclasses()\n    local_rank = training_args.local_rank\n    compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))\n    cfg=setup(model_args)\n    bnb_model_from_pretrained_args = {}\n    if training_args.bits in [4, 8]:\n        from transformers import 
BitsAndBytesConfig\n        bnb_model_from_pretrained_args.update(dict(\n            device_map={\"\": training_args.device},\n            load_in_4bit=training_args.bits == 4,\n            load_in_8bit=training_args.bits == 8,\n            quantization_config=BitsAndBytesConfig(\n                load_in_4bit=training_args.bits == 4,\n                load_in_8bit=training_args.bits == 8,\n                llm_int8_threshold=6.0,\n                llm_int8_has_fp16_weight=False,\n                bnb_4bit_compute_dtype=compute_dtype,\n                bnb_4bit_use_double_quant=training_args.double_quant,\n                bnb_4bit_quant_type=training_args.quant_type # {'fp4', 'nf4'}\n            )\n        ))\n\n    if model_args.vision_tower is not None:\n        if 'mpt' in model_args.model_name_or_path:\n            config = transformers.AutoConfig.from_pretrained(model_args.model_name_or_path,cache_dir=\"/comp_robot/zhanghao/.cache/hugging_face/\", trust_remote_code=True)\n            config.attn_config['attn_impl'] = training_args.mpt_attn_impl\n            model = LlavaMPTForCausalLM.from_pretrained(\n                model_args.model_name_or_path,\n                config=config,\n                cache_dir=\"/comp_robot/zhanghao/.cache/hugging_face/\",\n                **bnb_model_from_pretrained_args\n            )\n        else:\n            model = LlavaLlamaForCausalLM_joint_2st.from_pretrained(\n                model_args.model_name_or_path,\n                cache_dir=\"/comp_robot/zhanghao/.cache/hugging_face/\",\n                **bnb_model_from_pretrained_args\n            )\n    else:\n        model = transformers.LlamaForCausalLM.from_pretrained(\n            model_args.model_name_or_path,\n            cache_dir=\"/comp_robot/zhanghao/.cache/hugging_face/\",\n            **bnb_model_from_pretrained_args\n        )\n    model.config.use_cache = False\n\n    if model_args.freeze_backbone:\n        model.model.requires_grad_(False)\n\n    if 
training_args.bits in [4, 8]:\n        from peft import prepare_model_for_kbit_training\n        model.config.torch_dtype=(torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))\n        model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)\n\n    if training_args.gradient_checkpointing:\n        if hasattr(model, \"enable_input_require_grads\"):\n            model.enable_input_require_grads()\n        else:\n            def make_inputs_require_grad(module, input, output):\n                output.requires_grad_(True)\n            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)\n\n    if training_args.lora_enable:\n        from peft import LoraConfig, get_peft_model\n        lora_config = LoraConfig(\n            r=training_args.lora_r,\n            lora_alpha=training_args.lora_alpha,\n            target_modules=find_all_linear_names(model),\n            lora_dropout=training_args.lora_dropout,\n            bias=training_args.lora_bias,\n            task_type=\"CAUSAL_LM\",\n        )\n        if training_args.bits == 16:\n            if training_args.bf16:\n                model.to(torch.bfloat16)\n            if training_args.fp16:\n                model.to(torch.float16)\n        rank0_print(\"Adding LoRA adapters...\")\n        model = get_peft_model(model, lora_config)\n\n    if 'mpt' in model_args.model_name_or_path:\n        tokenizer = transformers.AutoTokenizer.from_pretrained(\n            model_args.model_name_or_path,cache_dir=\"/comp_robot/zhanghao/.cache/hugging_face/\",\n            model_max_length=training_args.model_max_length,\n            padding_side=\"right\"\n        )\n    else:\n        tokenizer = transformers.AutoTokenizer.from_pretrained(\n            model_args.model_name_or_path,\n            cache_dir=\"/comp_robot/zhanghao/.cache/hugging_face/\",\n            
model_max_length=training_args.model_max_length,\n            padding_side=\"right\",\n            use_fast=False,\n        )\n\n    if model_args.version == \"v0\":\n        if tokenizer.pad_token is None:\n            smart_tokenizer_and_embedding_resize(\n                special_tokens_dict=dict(pad_token=\"[PAD]\"),\n                tokenizer=tokenizer,\n                model=model,\n            )\n    elif model_args.version == \"v0.5\":\n        tokenizer.pad_token = tokenizer.unk_token\n    else:\n        tokenizer.pad_token = tokenizer.unk_token\n        if model_args.version in conversation_lib.conv_templates:\n            conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version]\n        else:\n            conversation_lib.default_conversation = conversation_lib.conv_templates[\"vicuna_v1\"]\n\n    if model_args.vision_tower is not None:\n        model.get_model().initialize_vision_modules(\n            model_args=model_args,\n            fsdp=training_args.fsdp\n        )\n        \n        vision_tower = model.get_vision_tower()\n        vision_tower.to(dtype=torch.float16, device=training_args.device)\n\n        data_args.image_processor = vision_tower.image_processor\n        data_args.is_multimodal = True\n\n        model.config.image_aspect_ratio = data_args.image_aspect_ratio\n        model.config.image_grid_pinpoints = data_args.image_grid_pinpoints\n\n        model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter\n        if model_args.tune_mm_mlp_adapter or training_args.dbg:\n            model.requires_grad_(False)\n            for p in model.get_model().mm_projector.parameters():\n                p.requires_grad = True\n\n        model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter\n        if training_args.freeze_mm_mlp_adapter:\n            for p in model.get_model().mm_projector.parameters():\n                p.requires_grad = False\n\n      
  if training_args.bits in [4, 8]:\n            model.get_model().mm_projector.to(dtype=compute_dtype, device=training_args.device)\n\n        model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = model_args.mm_use_im_start_end\n        training_args.use_im_start_end = model_args.mm_use_im_start_end\n        model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token\n        model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer)\n    cfg.MODEL.DIM_PROJ=model.get_model().config.hidden_size\n    model.initialize_seg_modules(\n        cfg=cfg\n    )\n\n    if training_args.bits in [4, 8]:\n        from peft.tuners.lora import LoraLayer\n        for name, module in model.named_modules():\n            if isinstance(module, LoraLayer):\n                if training_args.bf16:\n                    module = module.to(torch.bfloat16)\n            if 'norm' in name:\n                module = module.to(torch.float32)\n            if 'lm_head' in name or 'embed_tokens' in name:\n                if hasattr(module, 'weight'):\n                    if training_args.bf16 and module.weight.dtype == torch.float32:\n                        module = module.to(torch.bfloat16)\n\n    data_module = make_supervised_data_module(tokenizer=tokenizer,\n                                              data_args=data_args)\n    print(model)\n    if model_args.load_model:\n        loaded_dict = dict()\n        if \"stage1\" in model_args.whole_model:\n            old_emb_in=model.get_input_embeddings().weight.clone()\n            old_emb_out=model.get_output_embeddings().weight.clone()\n        for model_file in os.listdir(model_args.whole_model):\n            if model_file.endswith('.bin') and model_file.startswith('pytorch_model'):\n                loaded_dict.update(torch.load(os.path.join(model_args.whole_model, model_file), map_location='cpu'))\n        model.load_state_dict(loaded_dict, strict=False)\n        if \"stage1\" in model_args.whole_model:\n         
   with torch.no_grad():\n                model.get_input_embeddings().weight[:-3]=old_emb_in[:-3]\n                model.get_output_embeddings().weight[:-3]=old_emb_out[:-3]\n        print(loaded_dict.keys())\n\n    trainer = LLaVATrainer(model=model,\n                    tokenizer=tokenizer,\n                    args=training_args,cfg=cfg,data_loader_args=(tokenizer, data_args,preprocess),\n                    **data_module)\n\n    if list(pathlib.Path(training_args.output_dir).glob(\"checkpoint-*\")):\n        trainer.train(resume_from_checkpoint=True)\n    else:\n        trainer.train()\n    trainer.save_state()\n\n    model.config.use_cache = True\n\n    if training_args.lora_enable:\n        state_dict = get_peft_state_maybe_zero_3(\n            model.named_parameters(), training_args.lora_bias\n        )\n        non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(\n            model.named_parameters()\n        )\n        if training_args.local_rank == 0 or training_args.local_rank == -1:\n            model.config.save_pretrained(training_args.output_dir)\n            model.save_pretrained(training_args.output_dir, state_dict=state_dict)\n            torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin'))\n    else:\n        safe_save_model_for_hf_trainer(trainer=trainer,\n                                       output_dir=training_args.output_dir)\n\n\nif __name__ == \"__main__\":\n    train()\n"
  },
  {
    "path": "llava/train/train_joint_2st_interactive_refcoco_coco_instruction.py",
    "content": "# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:\n# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:\n#    Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li\n#\n#    Licensed under the Apache License, Version 2.0 (the \"License\");\n#    you may not use this file except in compliance with the License.\n#    You may obtain a copy of the License at\n#\n#        http://www.apache.org/licenses/LICENSE-2.0\n#\n#    Unless required by applicable law or agreed to in writing, software\n#    distributed under the License is distributed on an \"AS IS\" BASIS,\n#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#    See the License for the specific language governing permissions and\n#    limitations under the License.\nfrom llava.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn\nreplace_llama_attn_with_flash_attn()\nimport os\nimport copy\nfrom dataclasses import dataclass, field\nimport json\nimport logging\nimport pathlib\nfrom typing import Dict, Optional, Sequence, List\n\nimport torch\n\nimport transformers\n\nfrom llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN\nfrom torch.utils.data import Dataset\nfrom llava.train.llava_trainer_joint_train import LLaVATrainer\n\nfrom llava import conversation as conversation_lib\nfrom llava.model import *\nfrom llava.mm_utils import tokenizer_image_token,tokenizer_image_token_inter\n\nfrom PIL import Image\n\n\nlocal_rank = None\n\n\ndef rank0_print(*args):\n    if local_rank == 0:\n        print(*args)\n\n\n@dataclass\nclass ModelArguments:\n    model_name_or_path: Optional[str] = field(default=\"facebook/opt-125m\")\n    whole_model: Optional[str] = field(default=\"facebook/opt-125m\")\n    version: Optional[str] = field(default=\"v0\")\n    freeze_backbone: bool = field(default=False)\n    
tune_mm_mlp_adapter: bool = field(default=False)\n    tune_prompt_adapter: bool = field(default=False)\n    vision_tower: Optional[str] = field(default=None)\n    mm_vision_select_layer: Optional[int] = field(default=-1)   # default to the last layer\n    pretrain_mm_mlp_adapter: Optional[str] = field(default=None)\n    mm_use_im_start_end: bool = field(default=False)\n    load_model: bool = field(default=False)\n    mm_use_im_patch_token: bool = field(default=True)\n    mm_vision_select_feature: Optional[str] = field(default=\"patch\")\n    opt: Optional[str] = field(default=\"\")\n    config_file_gd: Optional[str] = field(default=\"\")\n    config_file_it: Optional[str] = field(default=\"\")\n\n\n@dataclass\nclass DataArguments:\n    data_path: str = field(default=None,\n                           metadata={\"help\": \"Path to the training data.\"})\n    lazy_preprocess: bool = False\n    is_multimodal: bool = False\n    image_folder: Optional[str] = field(default=None)\n    image_aspect_ratio: str = 'square'\n    image_grid_pinpoints: Optional[str] = field(default=None)\n\n\n@dataclass\nclass TrainingArguments(transformers.TrainingArguments):\n    cache_dir: Optional[str] = field(default=None)\n    optim: str = field(default=\"adamw_torch\")\n    remove_unused_columns: bool = field(default=False)\n    freeze_mm_mlp_adapter: bool = field(default=False)\n    mpt_attn_impl: Optional[str] = field(default=\"triton\")\n    model_max_length: int = field(\n        default=512,\n        metadata={\n            \"help\":\n            \"Maximum sequence length. Sequences will be right padded (and possibly truncated).\"\n        },\n    )\n    double_quant: bool = field(\n        default=True,\n        metadata={\"help\": \"Compress the quantization statistics through double quantization.\"}\n    )\n    quant_type: str = field(\n        default=\"nf4\",\n        metadata={\"help\": \"Quantization data type to use. 
Should be one of `fp4` or `nf4`.\"}\n    )\n    bits: int = field(\n        default=16,\n        metadata={\"help\": \"How many bits to use.\"}\n    )\n    lora_enable: bool = False\n    new_tokens: bool = True\n    lora_r: int = 64\n    lora_alpha: int = 16\n    lora_dropout: float = 0.05\n    lora_weight_path: str = \"\"\n    lora_bias: str = \"none\"\n    dbg: bool = False\n    load_optimizer_states: bool = True\n    load_lr_scheduler_states: bool = True\n\n\ndef maybe_zero_3(param, ignore_status=False, name=None):\n    from deepspeed import zero\n    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus\n    if hasattr(param, \"ds_id\"):\n        if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:\n            if not ignore_status:\n                logging.warning(f\"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}\")\n        with zero.GatheredParameters([param]):\n            param = param.data.detach().cpu().clone()\n    else:\n        param = param.detach().cpu().clone()\n    return param\n\n\n# Borrowed from peft.utils.get_peft_model_state_dict\ndef get_peft_state_maybe_zero_3(named_params, bias):\n    if bias == \"none\":\n        to_return = {k: t for k, t in named_params if \"lora_\" in k}\n    elif bias == \"all\":\n        to_return = {k: t for k, t in named_params if \"lora_\" in k or \"bias\" in k}\n    elif bias == \"lora_only\":\n        to_return = {}\n        maybe_lora_bias = {}\n        lora_bias_names = set()\n        for k, t in named_params:\n            if \"lora_\" in k:\n                to_return[k] = t\n                bias_name = k.split(\"lora_\")[0] + \"bias\"\n                lora_bias_names.add(bias_name)\n            elif \"bias\" in k:\n                maybe_lora_bias[k] = t\n        for k, t in maybe_lora_bias:\n            if bias_name in lora_bias_names:\n                to_return[bias_name] = t\n    else:\n        raise NotImplementedError\n    to_return = {k: maybe_zero_3(v, 
name=k) for k, v in to_return.items()}\n    return to_return\n\n\ndef get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):\n    to_return = {k: t for k, t in named_params if \"lora_\" not in k}\n    if require_grad_only:\n        to_return = {k: t for k, t in to_return.items() if t.requires_grad}\n    to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}\n    return to_return\n\n\ndef get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):\n    to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}\n    to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}\n    return to_return\n\n\ndef find_all_linear_names(model):\n    cls = torch.nn.Linear\n    lora_module_names = set()\n    for name, module in model.named_modules():\n        if isinstance(module, cls):\n            names = name.split('.')\n            lora_module_names.add(names[0] if len(names) == 1 else names[-1])\n\n\n    if 'lm_head' in lora_module_names: # needed for 16-bit\n        lora_module_names.remove('lm_head')\n    return list(lora_module_names)\n\n\ndef safe_save_model_for_hf_trainer(trainer: transformers.Trainer,\n                                   output_dir: str):\n    \"\"\"Collects the state dict and dump to disk.\"\"\"\n\n    if trainer.deepspeed:\n        torch.cuda.synchronize()\n        trainer.save_model(output_dir)\n        return\n\n    state_dict = trainer.model.state_dict()\n    if trainer.args.should_save:\n        cpu_state_dict = {\n            key: value.cpu()\n            for key, value in state_dict.items()\n        }\n        del state_dict\n        trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa\n\n\ndef smart_tokenizer_and_embedding_resize(\n    special_tokens_dict: Dict,\n    tokenizer: transformers.PreTrainedTokenizer,\n    model: transformers.PreTrainedModel,\n):\n    \"\"\"Resize tokenizer and embedding.\n\n    Note: 
This is the unoptimized version that may make your embedding size not be divisible by 64.\n    \"\"\"\n    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)\n    model.resize_token_embeddings(len(tokenizer))\n\n    if num_new_tokens > 0:\n        input_embeddings = model.get_input_embeddings().weight.data\n        output_embeddings = model.get_output_embeddings().weight.data\n\n        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(\n            dim=0, keepdim=True)\n        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(\n            dim=0, keepdim=True)\n\n        input_embeddings[-num_new_tokens:] = input_embeddings_avg\n        output_embeddings[-num_new_tokens:] = output_embeddings_avg\n\n\ndef _tokenize_fn(strings: Sequence[str],\n                 tokenizer: transformers.PreTrainedTokenizer) -> Dict:\n    \"\"\"Tokenize a list of strings.\"\"\"\n    tokenized_list = [\n        tokenizer(\n            text,\n            return_tensors=\"pt\",\n            padding=\"longest\",\n            max_length=tokenizer.model_max_length,\n            truncation=True,\n        ) for text in strings\n    ]\n    input_ids = labels = [\n        tokenized.input_ids[0] for tokenized in tokenized_list\n    ]\n    input_ids_lens = labels_lens = [\n        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()\n        for tokenized in tokenized_list\n    ]\n    return dict(\n        input_ids=input_ids,\n        labels=labels,\n        input_ids_lens=input_ids_lens,\n        labels_lens=labels_lens,\n    )\n\n\ndef _mask_targets(target, tokenized_lens, speakers):\n    # cur_idx = 0\n    cur_idx = tokenized_lens[0]\n    tokenized_lens = tokenized_lens[1:]\n    target[:cur_idx] = IGNORE_INDEX\n    for tokenized_len, speaker in zip(tokenized_lens, speakers):\n        if speaker == \"human\":\n            target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX\n        cur_idx += tokenized_len\n\n\ndef 
_add_speaker_and_signal(header, source, get_conversation=True):\n    \"\"\"Add speaker and start/end signal on each round.\"\"\"\n    BEGIN_SIGNAL = \"### \"\n    END_SIGNAL = \"\\n\"\n    conversation = header\n    for sentence in source:\n        from_str = sentence[\"from\"]\n        if from_str.lower() == \"human\":\n            from_str = conversation_lib.default_conversation.roles[0]\n        elif from_str.lower() == \"gpt\":\n            from_str = conversation_lib.default_conversation.roles[1]\n        else:\n            from_str = 'unknown'\n        sentence[\"value\"] = (BEGIN_SIGNAL + from_str + \": \" +\n                             sentence[\"value\"] + END_SIGNAL)\n        if get_conversation:\n            conversation += sentence[\"value\"]\n    conversation += BEGIN_SIGNAL\n    return conversation\n\n\ndef preprocess_multimodal(\n    sources: Sequence[str],\n    data_args: DataArguments\n) -> Dict:\n    is_multimodal = data_args.is_multimodal\n    if not is_multimodal:\n        return sources\n\n    for source in sources:\n        for sentence in source:\n            if DEFAULT_IMAGE_TOKEN in sentence['value']:\n                sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip()\n                sentence['value'] = DEFAULT_IMAGE_TOKEN + '\\n' + sentence['value']\n                sentence['value'] = sentence['value'].strip()\n                if \"mmtag\" in conversation_lib.default_conversation.version:\n                    sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '<Image>' + DEFAULT_IMAGE_TOKEN + '</Image>')\n            replace_token = DEFAULT_IMAGE_TOKEN\n            if data_args.mm_use_im_start_end:\n                replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN\n            sentence[\"value\"] = sentence[\"value\"].replace(DEFAULT_IMAGE_TOKEN, replace_token)\n\n    return sources\n\n\ndef preprocess_llama_2(\n    sources,\n    tokenizer: 
transformers.PreTrainedTokenizer,\n    has_image: bool = False\n) -> Dict:\n    conv = conversation_lib.default_conversation.copy()\n    roles = {\"human\": conv.roles[0], \"gpt\": conv.roles[1]}\n\n    # Apply prompt templates\n    conversations = []\n    for i, source in enumerate(sources):\n        if roles[source[0][\"from\"]] != conv.roles[0]:\n            # Skip the first one if it is not from human\n            source = source[1:]\n\n        conv.messages = []\n        for j, sentence in enumerate(source):\n            role = roles[sentence[\"from\"]]\n            assert role == conv.roles[j % 2], f\"{i}\"\n            conv.append_message(role, sentence[\"value\"])\n        conversations.append(conv.get_prompt())\n\n    # Tokenize conversations\n\n    if has_image:\n        input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)\n    else:\n        input_ids = tokenizer(\n            conversations,\n            return_tensors=\"pt\",\n            padding=\"longest\",\n            max_length=tokenizer.model_max_length,\n            truncation=True,\n        ).input_ids\n\n    targets = input_ids.clone()\n\n    assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2\n\n    # Mask targets\n    sep = \"[/INST] \"\n    for conversation, target in zip(conversations, targets):\n        total_len = int(target.ne(tokenizer.pad_token_id).sum())\n\n        rounds = conversation.split(conv.sep2)\n        cur_len = 1\n        target[:cur_len] = IGNORE_INDEX\n        for i, rou in enumerate(rounds):\n            if rou == \"\":\n                break\n\n            parts = rou.split(sep)\n            if len(parts) != 2:\n                break\n            parts[0] += sep\n\n            if has_image:\n                round_len = len(tokenizer_image_token(rou, tokenizer))\n                instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2\n            else:\n                
round_len = len(tokenizer(rou).input_ids)\n                instruction_len = len(tokenizer(parts[0]).input_ids) - 2\n\n            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX\n\n            cur_len += round_len\n        target[cur_len:] = IGNORE_INDEX\n\n        if cur_len < tokenizer.model_max_length:\n            if cur_len != total_len:\n                target[:] = IGNORE_INDEX\n                print(\n                    f\"WARNING: tokenization mismatch: {cur_len} vs. {total_len}.\"\n                    f\" (ignored)\"\n                )\n\n    return dict(\n        input_ids=input_ids,\n        labels=targets,\n    )\n\n\ndef preprocess_v1(\n    sources,\n    tokenizer: transformers.PreTrainedTokenizer,\n    has_image: bool = False\n) -> Dict:\n    conv = conversation_lib.default_conversation.copy()\n    roles = {\"human\": conv.roles[0], \"gpt\": conv.roles[1]}\n\n    # Apply prompt templates\n    conversations = []\n    for i, source in enumerate(sources):\n        if roles[source[0][\"from\"]] != conv.roles[0]:\n            # Skip the first one if it is not from human\n            source = source[1:]\n\n        conv.messages = []\n        for j, sentence in enumerate(source):\n            role = roles[sentence[\"from\"]]\n            assert role == conv.roles[j % 2], f\"{i}\"\n            conv.append_message(role, sentence[\"value\"])\n        conversations.append(conv.get_prompt().replace(\"<obj>\", \"% \"))\n\n    # Tokenize conversations\n\n    if has_image:\n        input_ids = torch.stack([tokenizer_image_token_inter(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)\n    else:\n        input_ids = tokenizer(\n            conversations,\n            return_tensors=\"pt\",\n            padding=\"longest\",\n            max_length=tokenizer.model_max_length,\n            truncation=True,\n        ).input_ids\n\n    targets = input_ids.clone()\n\n    assert conv.sep_style == conversation_lib.SeparatorStyle.TWO\n\n 
   # Mask targets\n    sep = conv.sep + conv.roles[1] + \": \"\n    for conversation, target in zip(conversations, targets):\n        total_len = int(target.ne(tokenizer.pad_token_id).sum())\n\n        rounds = conversation.split(conv.sep2)\n        cur_len = 1\n        target[:cur_len] = IGNORE_INDEX\n        for i, rou in enumerate(rounds):\n            if rou == \"\":\n                break\n\n            parts = rou.split(sep)\n            if len(parts) != 2:\n                break\n            parts[0] += sep\n\n            if has_image:\n                round_len = len(tokenizer_image_token(rou, tokenizer))\n                instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2\n            else:\n                round_len = len(tokenizer(rou).input_ids)\n                instruction_len = len(tokenizer(parts[0]).input_ids) - 2\n\n            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX\n\n            cur_len += round_len\n        target[cur_len:] = IGNORE_INDEX\n\n        if cur_len < tokenizer.model_max_length:\n            if cur_len != total_len:\n                target[:] = IGNORE_INDEX\n                print(\n                    f\"WARNING: tokenization mismatch: {cur_len} vs. 
{total_len}.\"\n                    f\" (ignored)\"\n                )\n\n    return dict(\n        input_ids=input_ids,\n        labels=targets,\n    )\n\n\ndef preprocess_mpt(\n    sources,\n    tokenizer: transformers.PreTrainedTokenizer,\n) -> Dict:\n    conv = conversation_lib.default_conversation.copy()\n    roles = {\"human\": conv.roles[0], \"gpt\": conv.roles[1]}\n\n    # Apply prompt templates\n    conversations = []\n    for i, source in enumerate(sources):\n        if roles[source[0][\"from\"]] != conv.roles[0]:\n            # Skip the first one if it is not from human\n            source = source[1:]\n\n        conv.messages = []\n        for j, sentence in enumerate(source):\n            role = roles[sentence[\"from\"]]\n            assert role == conv.roles[j % 2], f\"{i}\"\n            conv.append_message(role, sentence[\"value\"])\n        conversations.append(conv.get_prompt())\n\n    # Tokenize conversations\n    input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)\n    targets = input_ids.clone()\n    assert conv.sep_style == conversation_lib.SeparatorStyle.MPT\n\n    # Mask targets\n    sep = conv.sep + conv.roles[1]\n    for conversation, target in zip(conversations, targets):\n        total_len = int(target.ne(tokenizer.pad_token_id).sum())\n\n        rounds = conversation.split(conv.sep)\n        re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt\n        for conv_idx in range(3, len(rounds), 2):\n            re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx+2]))    # user + gpt\n        cur_len = 0\n        target[:cur_len] = IGNORE_INDEX\n        for i, rou in enumerate(re_rounds):\n            if rou == \"\":\n                break\n\n            parts = rou.split(sep)\n            if len(parts) != 2:\n                break\n            parts[0] += sep\n            round_len = len(tokenizer_image_token(rou, tokenizer)) + 
len(tokenizer_image_token(conv.sep, tokenizer))\n            instruction_len = len(tokenizer_image_token(parts[0], tokenizer))\n            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX\n\n            cur_len += round_len\n        target[cur_len:] = IGNORE_INDEX\n\n        if cur_len < tokenizer.model_max_length:\n            if cur_len != total_len:\n                target[:] = IGNORE_INDEX\n                print(\n                    f\"WARNING: tokenization mismatch: {cur_len} vs. {total_len}.\"\n                    f\" (ignored)\"\n                )\n\n    return dict(\n        input_ids=input_ids,\n        labels=targets,\n    )\n\n\ndef preprocess_plain(\n    sources: Sequence[str],\n    tokenizer: transformers.PreTrainedTokenizer,\n) -> Dict:\n    # add end signal and concatenate together\n    conversations = []\n    for source in sources:\n        assert len(source) == 2\n        assert DEFAULT_IMAGE_TOKEN in source[0]['value']\n        source[0]['value'] = DEFAULT_IMAGE_TOKEN\n        conversation = source[0]['value'] + source[1]['value'] + conversation_lib.default_conversation.sep\n        conversations.append(conversation)\n    # tokenize conversations\n    input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]\n    targets = copy.deepcopy(input_ids)\n    for target, source in zip(targets, sources):\n        tokenized_len = len(tokenizer_image_token(source[0]['value'], tokenizer))\n        target[:tokenized_len] = IGNORE_INDEX\n\n    return dict(input_ids=input_ids, labels=targets)\n\n\ndef preprocess(\n    sources: Sequence[str],\n    tokenizer: transformers.PreTrainedTokenizer,\n    has_image: bool = False\n) -> Dict:\n    \"\"\"\n    Given a list of sources, each is a conversation list. This transform:\n    1. Add signal '### ' at the beginning each sentence, with end signal '\\n';\n    2. Concatenate conversations together;\n    3. Tokenize the concatenated conversation;\n    4. 
Make a deepcopy as the target. Mask human words with IGNORE_INDEX.\n    \"\"\"\n    if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN:\n        return preprocess_plain(sources, tokenizer)\n    if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.LLAMA_2:\n        return preprocess_llama_2(sources, tokenizer, has_image=has_image)\n    if conversation_lib.default_conversation.version.startswith(\"v1\"):\n        return preprocess_v1(sources, tokenizer, has_image=has_image)\n    if conversation_lib.default_conversation.version == \"mpt\":\n        return preprocess_mpt(sources, tokenizer)\n    # add end signal and concatenate together\n    conversations = []\n    for source in sources:\n        header = f\"{conversation_lib.default_conversation.system}\\n\\n\"\n        conversation = _add_speaker_and_signal(header, source)\n        conversations.append(conversation)\n    # tokenize conversations\n    def get_tokenize_len(prompts):\n        return [len(tokenizer_image_token(prompt, tokenizer)) for prompt in prompts]\n\n    if has_image:\n        input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]\n    else:\n        conversations_tokenized = _tokenize_fn(conversations, tokenizer)\n        input_ids = conversations_tokenized[\"input_ids\"]\n\n    targets = copy.deepcopy(input_ids)\n    for target, source in zip(targets, sources):\n        if has_image:\n            tokenized_lens = get_tokenize_len([header] + [s[\"value\"] for s in source])\n        else:\n            tokenized_lens = _tokenize_fn([header] + [s[\"value\"] for s in source], tokenizer)[\"input_ids_lens\"]\n        speakers = [sentence[\"from\"] for sentence in source]\n        _mask_targets(target, tokenized_lens, speakers)\n\n    return dict(input_ids=input_ids, labels=targets)\n\n\nclass LazySupervisedDataset(Dataset):\n    \"\"\"Dataset for supervised fine-tuning.\"\"\"\n\n    
def __init__(self, data_path: str,\n                 tokenizer: transformers.PreTrainedTokenizer,\n                 data_args: DataArguments):\n        super(LazySupervisedDataset, self).__init__()\n        list_data_dict = json.load(open(data_path, \"r\"))\n\n        rank0_print(\"Formatting inputs...Skip in lazy mode\")\n        self.tokenizer = tokenizer\n        self.list_data_dict = list_data_dict\n        self.data_args = data_args\n\n    def __len__(self):\n        return len(self.list_data_dict)\n\n    def __getitem__(self, i) -> Dict[str, torch.Tensor]:\n        try:\n            sources = self.list_data_dict[i]\n            if isinstance(i, int):\n                sources = [sources]\n            assert len(sources) == 1, \"Don't know why it is wrapped to a list\"  # FIXME\n            if 'image' in sources[0]:\n                image_file = self.list_data_dict[i]['image']\n                image_folder = self.data_args.image_folder\n                processor = self.data_args.image_processor\n                image = Image.open(os.path.join(image_folder, image_file)).convert('RGB')\n                if self.data_args.image_aspect_ratio == 'pad':\n                    def expand2square(pil_img, background_color):\n                        width, height = pil_img.size\n                        if width == height:\n                            return pil_img\n                        elif width > height:\n                            result = Image.new(pil_img.mode, (width, width), background_color)\n                            result.paste(pil_img, (0, (width - height) // 2))\n                            return result\n                        else:\n                            result = Image.new(pil_img.mode, (height, height), background_color)\n                            result.paste(pil_img, ((height - width) // 2, 0))\n                            return result\n                    image = expand2square(image, tuple(int(x*255) for x in processor.image_mean))\n      
              image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]\n                else:\n                    image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]\n                sources = preprocess_multimodal(\n                    copy.deepcopy([e[\"conversations\"] for e in sources]),\n                    self.data_args)\n            else:\n                sources = copy.deepcopy([e[\"conversations\"] for e in sources])\n            data_dict = preprocess(\n                sources,\n                self.tokenizer,\n                has_image=('image' in self.list_data_dict[i]))\n            if isinstance(i, int):\n                data_dict = dict(input_ids=data_dict[\"input_ids\"][0],\n                                 labels=data_dict[\"labels\"][0])\n\n            # image exist in the data\n            if 'image' in self.list_data_dict[i]:\n                data_dict['image'] = image\n            elif self.data_args.is_multimodal:\n                # image does not exist in the data, but the model is multimodal\n                crop_size = self.data_args.image_processor.crop_size\n                data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width'])\n            return data_dict\n        except Exception:\n            print(self.list_data_dict[i], \"failed\")\n            return self.__getitem__(i + 1)\n\n\n@dataclass\nclass DataCollatorForSupervisedDataset(object):\n    \"\"\"Collate examples for supervised fine-tuning.\"\"\"\n\n    tokenizer: transformers.PreTrainedTokenizer\n\n    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:\n        input_ids, labels = tuple([instance[key] for instance in instances]\n                                  for key in (\"input_ids\", \"labels\"))\n        input_ids = torch.nn.utils.rnn.pad_sequence(\n            input_ids,\n            batch_first=True,\n            padding_value=self.tokenizer.pad_token_id)\n        labels = 
torch.nn.utils.rnn.pad_sequence(labels,\n                                                 batch_first=True,\n                                                 padding_value=IGNORE_INDEX)\n        input_ids = input_ids[:, :self.tokenizer.model_max_length]\n        labels = labels[:, :self.tokenizer.model_max_length]\n        batch = dict(\n            input_ids=input_ids,\n            labels=labels,\n            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),\n        )\n\n        if 'image' in instances[0]:\n            images = [instance['image'] for instance in instances]\n            if all(x is not None and x.shape == images[0].shape for x in images):\n                batch['images'] = torch.stack(images)\n            else:\n                batch['images'] = images\n\n        return batch\n\n@dataclass\nclass DataCollatorForSupervisedDatasetEmpty(object):\n    \"\"\"Collate examples for supervised fine-tuning.\"\"\"\n\n    tokenizer: transformers.PreTrainedTokenizer\n\n    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:\n        return instances\n        # input_ids, labels = tuple([instance[key] for instance in instances]\n        #                           for key in (\"input_ids\", \"labels\"))\n        # input_ids = torch.nn.utils.rnn.pad_sequence(\n        #     input_ids,\n        #     batch_first=True,\n        #     padding_value=self.tokenizer.pad_token_id)\n        # labels = torch.nn.utils.rnn.pad_sequence(labels,\n        #                                          batch_first=True,\n        #                                          padding_value=IGNORE_INDEX)\n        # input_ids = input_ids[:, :self.tokenizer.model_max_length]\n        # labels = labels[:, :self.tokenizer.model_max_length]\n        # batch = dict(\n        #     input_ids=input_ids,\n        #     labels=labels,\n        #     attention_mask=input_ids.ne(self.tokenizer.pad_token_id),\n        # )\n        #\n        # if 'image' in 
instances[0]:\n        #     images = [instance['image'] for instance in instances]\n        #     if all(x is not None and x.shape == images[0].shape for x in images):\n        #         batch['images'] = torch.stack(images)\n        #     else:\n        #         batch['images'] = images\n        #\n        # return batch\n\ndef make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,\n                                data_args) -> Dict:\n    \"\"\"Make dataset and collator for supervised fine-tuning.\"\"\"\n    train_dataset = LazySupervisedDataset(tokenizer=tokenizer,\n                                data_path=data_args.data_path,\n                                data_args=data_args)\n    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)\n    # data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)\n    return dict(train_dataset=train_dataset,\n                eval_dataset=None,\n                data_collator=data_collator)\n\nfrom detectron2.config import LazyConfig, instantiate\n\ndef setup(args):\n    \"\"\"\n    Create configs and perform basic setups.\n    \"\"\"\n    cfg1 = LazyConfig.load(args.config_file_gd)\n    cfg2 = LazyConfig.load(args.config_file_it)\n    # import pdb;pdb.set_trace()\n    opt1,opt2=args.opt.split(';')\n    opt1=opt1.split(',')\n    opt2=opt2.split(',')\n    cfg1 = LazyConfig.apply_overrides(cfg1, opt1)\n    cfg2 = LazyConfig.apply_overrides(cfg2, opt2)\n    # cfg.freeze()\n    # default_setup(cfg, args)\n    # setup_logger(output=cfg.OUTPUT_DIR, distributed_rank=comm.get_rank(), name=\"maskdino\")\n    return cfg1,cfg2\n\ndef train():\n    global local_rank\n\n    parser = transformers.HfArgumentParser(\n        (ModelArguments, DataArguments, TrainingArguments))\n    model_args, data_args, training_args = parser.parse_args_into_dataclasses()\n    local_rank = training_args.local_rank\n    compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if 
training_args.bf16 else torch.float32))\n    cfg,cfg2=setup(model_args)\n    bnb_model_from_pretrained_args = {}\n    if training_args.bits in [4, 8]:\n        from transformers import BitsAndBytesConfig\n        bnb_model_from_pretrained_args.update(dict(\n            device_map={\"\": training_args.device},\n            load_in_4bit=training_args.bits == 4,\n            load_in_8bit=training_args.bits == 8,\n            quantization_config=BitsAndBytesConfig(\n                load_in_4bit=training_args.bits == 4,\n                load_in_8bit=training_args.bits == 8,\n                llm_int8_threshold=6.0,\n                llm_int8_has_fp16_weight=False,\n                bnb_4bit_compute_dtype=compute_dtype,\n                bnb_4bit_use_double_quant=training_args.double_quant,\n                bnb_4bit_quant_type=training_args.quant_type # {'fp4', 'nf4'}\n            )\n        ))\n\n    if model_args.vision_tower is not None:\n        if 'mpt' in model_args.model_name_or_path:\n            config = transformers.AutoConfig.from_pretrained(model_args.model_name_or_path,cache_dir=\"/comp_robot/zhanghao/.cache/hugging_face/\", trust_remote_code=True)\n            config.attn_config['attn_impl'] = training_args.mpt_attn_impl\n            model = LlavaMPTForCausalLM.from_pretrained(\n                model_args.model_name_or_path,\n                config=config,\n                cache_dir=\"/comp_robot/zhanghao/.cache/hugging_face/\",\n                **bnb_model_from_pretrained_args\n            )\n        else:\n            model = LlavaLlamaForCausalLM_joint_2st_it_only_ref_instr.from_pretrained(\n                model_args.model_name_or_path,\n                cache_dir=\"/comp_robot/zhanghao/.cache/hugging_face/\",\n                **bnb_model_from_pretrained_args\n            )\n    else:\n        model = transformers.LlamaForCausalLM.from_pretrained(\n            model_args.model_name_or_path,\n            
cache_dir=\"/comp_robot/zhanghao/.cache/hugging_face/\",\n            **bnb_model_from_pretrained_args\n        )\n    model.config.use_cache = False\n\n    if model_args.freeze_backbone:\n        model.model.requires_grad_(False)\n\n    if training_args.bits in [4, 8]:\n        from peft import prepare_model_for_kbit_training\n        model.config.torch_dtype=(torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))\n        model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)\n\n    if training_args.gradient_checkpointing:\n        if hasattr(model, \"enable_input_require_grads\"):\n            model.enable_input_require_grads()\n        else:\n            def make_inputs_require_grad(module, input, output):\n                output.requires_grad_(True)\n            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)\n\n    if training_args.lora_enable:\n        from peft import LoraConfig, get_peft_model\n        lora_config = LoraConfig(\n            r=training_args.lora_r,\n            lora_alpha=training_args.lora_alpha,\n            target_modules=find_all_linear_names(model),\n            lora_dropout=training_args.lora_dropout,\n            bias=training_args.lora_bias,\n            task_type=\"CAUSAL_LM\",\n        )\n        if training_args.bits == 16:\n            if training_args.bf16:\n                model.to(torch.bfloat16)\n            if training_args.fp16:\n                model.to(torch.float16)\n        rank0_print(\"Adding LoRA adapters...\")\n        model = get_peft_model(model, lora_config)\n\n    if 'mpt' in model_args.model_name_or_path:\n        tokenizer = transformers.AutoTokenizer.from_pretrained(\n            model_args.model_name_or_path,cache_dir=\"/comp_robot/zhanghao/.cache/hugging_face/\",\n            model_max_length=training_args.model_max_length,\n            padding_side=\"right\"\n        )\n    
else:\n        tokenizer = transformers.AutoTokenizer.from_pretrained(\n            model_args.model_name_or_path,\n            cache_dir=\"/comp_robot/zhanghao/.cache/hugging_face/\",\n            model_max_length=training_args.model_max_length,\n            padding_side=\"right\",\n            use_fast=False,\n        )\n\n    if model_args.version == \"v0\":\n        if tokenizer.pad_token is None:\n            smart_tokenizer_and_embedding_resize(\n                special_tokens_dict=dict(pad_token=\"[PAD]\"),\n                tokenizer=tokenizer,\n                model=model,\n            )\n    elif model_args.version == \"v0.5\":\n        tokenizer.pad_token = tokenizer.unk_token\n    else:\n        tokenizer.pad_token = tokenizer.unk_token\n        if model_args.version in conversation_lib.conv_templates:\n            conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version]\n        else:\n            conversation_lib.default_conversation = conversation_lib.conv_templates[\"vicuna_v1\"]\n\n    if model_args.vision_tower is not None:\n        model.get_model().initialize_vision_modules(\n            model_args=model_args,\n            fsdp=training_args.fsdp\n        )\n        \n        vision_tower = model.get_vision_tower()\n        vision_tower.to(dtype=torch.float16, device=training_args.device)\n\n        data_args.image_processor = vision_tower.image_processor\n        data_args.is_multimodal = True\n\n        model.config.image_aspect_ratio = data_args.image_aspect_ratio\n        model.config.image_grid_pinpoints = data_args.image_grid_pinpoints\n\n        model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter\n        model.config.tune_prompt_adapter = training_args.tune_prompt_adapter = model_args.tune_prompt_adapter\n\n        if model_args.tune_mm_mlp_adapter or training_args.dbg:\n            model.requires_grad_(False)\n            for p in 
model.get_model().mm_projector.parameters():\n                p.requires_grad = True\n\n\n        model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter\n        if training_args.freeze_mm_mlp_adapter:\n            for p in model.get_model().mm_projector.parameters():\n                p.requires_grad = False\n\n        if training_args.bits in [4, 8]:\n            model.get_model().mm_projector.to(dtype=compute_dtype, device=training_args.device)\n\n        model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = model_args.mm_use_im_start_end\n        training_args.use_im_start_end = model_args.mm_use_im_start_end\n        model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token\n        model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer)\n    cfg.MODEL.DIM_PROJ=model.get_model().config.hidden_size\n    model.initialize_seg_modules(\n        cfg=cfg\n    )\n    if model_args.tune_prompt_adapter:\n        model.requires_grad_(False)\n        model.freeze_seg_modules()\n    model.initialize_interactive_modules(cfg=cfg2,model_args=model_args)\n\n    if training_args.bits in [4, 8]:\n        from peft.tuners.lora import LoraLayer\n        for name, module in model.named_modules():\n            if isinstance(module, LoraLayer):\n                if training_args.bf16:\n                    module = module.to(torch.bfloat16)\n            if 'norm' in name:\n                module = module.to(torch.float32)\n            if 'lm_head' in name or 'embed_tokens' in name:\n                if hasattr(module, 'weight'):\n                    if training_args.bf16 and module.weight.dtype == torch.float32:\n                        module = module.to(torch.bfloat16)\n\n    data_module = make_supervised_data_module(tokenizer=tokenizer,\n                                              data_args=data_args)\n    print(model)\n    if model_args.load_model:\n        loaded_dict = dict()\n        if \"stage1\" in 
model_args.whole_model:\n            old_emb_in=model.get_input_embeddings().weight.clone()\n            old_emb_out=model.get_output_embeddings().weight.clone()\n        for model_file in os.listdir(model_args.whole_model):\n            if model_file.endswith('.bin') and model_file.startswith('pytorch_model'):\n                loaded_dict.update(torch.load(os.path.join(model_args.whole_model, model_file), map_location='cpu'))\n\n        model.load_state_dict(loaded_dict, strict=False)\n        if \"stage1\" in model_args.whole_model:\n            with torch.no_grad():\n                model.get_input_embeddings().weight[:-3]=old_emb_in[:-3]\n                model.get_output_embeddings().weight[:-3]=old_emb_out[:-3]\n        print(loaded_dict.keys())\n    training_args.train_interactive = True\n    trainer = LLaVATrainer(model=model,\n                    tokenizer=tokenizer,\n                    args=training_args,cfg=cfg,data_loader_args=(tokenizer, data_args,preprocess),\n                    **data_module)\n\n    if list(pathlib.Path(training_args.output_dir).glob(\"checkpoint-*\")):\n        trainer.train(resume_from_checkpoint=True)\n    else:\n        trainer.train()\n    trainer.save_state()\n\n    model.config.use_cache = True\n\n    if training_args.lora_enable:\n        state_dict = get_peft_state_maybe_zero_3(\n            model.named_parameters(), training_args.lora_bias\n        )\n        non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(\n            model.named_parameters()\n        )\n        if training_args.local_rank == 0 or training_args.local_rank == -1:\n            model.config.save_pretrained(training_args.output_dir)\n            model.save_pretrained(training_args.output_dir, state_dict=state_dict)\n            torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin'))\n    else:\n        safe_save_model_for_hf_trainer(trainer=trainer,\n                                       
output_dir=training_args.output_dir)\n\n\nif __name__ == \"__main__\":\n    train()\n"
  },
  {
    "path": "llava/train/train_mem.py",
    "content": "# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:\n# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:\n# Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.\n\n# Need to call this before importing transformers.\nfrom llava.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn\n\nreplace_llama_attn_with_flash_attn()\n\nfrom llava.train.train import train\n\nif __name__ == \"__main__\":\n    train()\n"
  },
  {
    "path": "llava/utils.py",
    "content": "import datetime\nimport logging\nimport logging.handlers\nimport os\nimport sys\n\nimport requests\n\nfrom llava.constants import LOGDIR\n\nserver_error_msg = \"**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**\"\nmoderation_msg = \"YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN.\"\n\nhandler = None\n\n\ndef build_logger(logger_name, logger_filename):\n    global handler\n\n    formatter = logging.Formatter(\n        fmt=\"%(asctime)s | %(levelname)s | %(name)s | %(message)s\",\n        datefmt=\"%Y-%m-%d %H:%M:%S\",\n    )\n\n    # Set the format of root handlers\n    if not logging.getLogger().handlers:\n        logging.basicConfig(level=logging.INFO)\n    logging.getLogger().handlers[0].setFormatter(formatter)\n\n    # Redirect stdout and stderr to loggers\n    stdout_logger = logging.getLogger(\"stdout\")\n    stdout_logger.setLevel(logging.INFO)\n    sl = StreamToLogger(stdout_logger, logging.INFO)\n    sys.stdout = sl\n\n    stderr_logger = logging.getLogger(\"stderr\")\n    stderr_logger.setLevel(logging.ERROR)\n    sl = StreamToLogger(stderr_logger, logging.ERROR)\n    sys.stderr = sl\n\n    # Get logger\n    logger = logging.getLogger(logger_name)\n    logger.setLevel(logging.INFO)\n\n    # Add a file handler for all loggers\n    if handler is None:\n        os.makedirs(LOGDIR, exist_ok=True)\n        filename = os.path.join(LOGDIR, logger_filename)\n        handler = logging.handlers.TimedRotatingFileHandler(\n            filename, when='D', utc=True)\n        handler.setFormatter(formatter)\n\n        for name, item in logging.root.manager.loggerDict.items():\n            if isinstance(item, logging.Logger):\n                item.addHandler(handler)\n\n    return logger\n\n\nclass StreamToLogger(object):\n    \"\"\"\n    Fake file-like stream object that redirects writes to a logger instance.\n    \"\"\"\n    def __init__(self, logger, log_level=logging.INFO):\n        self.terminal = 
sys.stdout\n        self.logger = logger\n        self.log_level = log_level\n        self.linebuf = ''\n\n    def __getattr__(self, attr):\n        return getattr(self.terminal, attr)\n\n    def write(self, buf):\n        temp_linebuf = self.linebuf + buf\n        self.linebuf = ''\n        for line in temp_linebuf.splitlines(True):\n            # From the io.TextIOWrapper docs:\n            #   On output, if newline is None, any '\\n' characters written\n            #   are translated to the system default line separator.\n            # By default sys.stdout.write() expects '\\n' newlines and then\n            # translates them so this is still cross platform.\n            if line[-1] == '\\n':\n                self.logger.log(self.log_level, line.rstrip())\n            else:\n                self.linebuf += line\n\n    def flush(self):\n        if self.linebuf != '':\n            self.logger.log(self.log_level, self.linebuf.rstrip())\n        self.linebuf = ''\n\n\ndef disable_torch_init():\n    \"\"\"\n    Disable the redundant torch default initialization to accelerate model creation.\n    \"\"\"\n    import torch\n    setattr(torch.nn.Linear, \"reset_parameters\", lambda self: None)\n    setattr(torch.nn.LayerNorm, \"reset_parameters\", lambda self: None)\n\n\ndef violates_moderation(text):\n    \"\"\"\n    Check whether the text violates OpenAI moderation API.\n    \"\"\"\n    url = \"https://api.openai.com/v1/moderations\"\n    headers = {\"Content-Type\": \"application/json\",\n               \"Authorization\": \"Bearer \" + os.environ[\"OPENAI_API_KEY\"]}\n    text = text.replace(\"\\n\", \"\")\n    data = \"{\" + '\"input\": ' + f'\"{text}\"' + \"}\"\n    data = data.encode(\"utf-8\")\n    try:\n        ret = requests.post(url, headers=headers, data=data, timeout=5)\n        flagged = ret.json()[\"results\"][0][\"flagged\"]\n    except requests.exceptions.RequestException as e:\n        flagged = False\n    except KeyError as e:\n        flagged = 
False\n\n    return flagged\n\n\ndef pretty_print_semaphore(semaphore):\n    if semaphore is None:\n        return \"None\"\n    return f\"Semaphore(value={semaphore._value}, locked={semaphore.locked()})\"\n"
  },
  {
    "path": "pyproject.toml",
    "content": "[build-system]\nrequires = [\"setuptools>=61.0\"]\nbuild-backend = \"setuptools.build_meta\"\n\n[project]\nname = \"llava\"\nversion = \"1.0.1\"\ndescription = \"Towards GPT-4 like large language and visual assistant.\"\nreadme = \"README.md\"\nrequires-python = \">=3.8\"\nclassifiers = [\n    \"Programming Language :: Python :: 3\",\n    \"License :: OSI Approved :: Apache Software License\",\n]\ndependencies = [\n    \"einops\", \"fastapi\", \"gradio==3.39.0\", \"markdown2[all]\", \"numpy\",\n    \"requests\", \"sentencepiece\", \"tokenizers>=0.12.1\",\n    \"torch\", \"torchvision\", \"uvicorn\", \"wandb\",\n    \"shortuuid\", \"httpx==0.24.0\",\n    \"deepspeed==0.9.5\",\n    \"peft==0.4.0\",\n    \"transformers==4.31.0\",\n    \"accelerate==0.21.0\",\n    \"bitsandbytes==0.41.0\",\n    \"scikit-learn==1.2.2\",\n    \"sentencepiece==0.1.99\",\n    \"einops==0.6.1\", \"einops-exts==0.0.4\", \"timm==0.6.13\",\n    \"gradio_client==0.7.0\"\n]\n\n[project.urls]\n\"Homepage\" = \"https://llava-vl.github.io\"\n\"Bug Tracker\" = \"https://github.com/haotian-liu/LLaVA/issues\"\n\n[tool.setuptools.packages.find]\nexclude = [\"assets*\", \"benchmark*\", \"docs\", \"dist*\", \"playground*\", \"scripts*\", \"tests*\"]\n\n[tool.wheel]\nexclude = [\"assets*\", \"benchmark*\", \"docs\", \"dist*\", \"playground*\", \"scripts*\", \"tests*\"]\n"
  },
  {
    "path": "scripts/convert_sqa_to_llava.py",
    "content": "import json\nimport os\nimport fire\nimport re\nfrom convert_sqa_to_llava_base_prompt import build_prompt_chatbot\n\n\ndef convert_to_llava(base_dir, split, prompt_format=\"QCM-LEPA\"):\n    split_indices = json.load(open(os.path.join(base_dir, \"pid_splits.json\")))[split]\n    problems = json.load(open(os.path.join(base_dir, \"problems.json\")))\n\n    split_problems = build_prompt_chatbot(\n        problems, split_indices, prompt_format,\n        use_caption=False, is_test=False)\n\n    target_format = []\n    for prob_id, (input, output) in split_problems.items():\n        if input.startswith('Question: '):\n            input = input.replace('Question: ', '')\n        if output.startswith('Answer: '):\n            output = output.replace('Answer: ', '')\n\n        raw_prob_data = problems[prob_id]\n        if raw_prob_data['image'] is None:\n            target_format.append({\n                \"id\": prob_id,\n                \"conversations\": [\n                    {'from': 'human', 'value': f\"{input}\"},\n                    {'from': 'gpt', 'value': f\"{output}\"},\n                ],\n            })\n\n        else:\n            target_format.append({\n                \"id\": prob_id,\n                \"image\": os.path.join(prob_id, raw_prob_data['image']),\n                \"conversations\": [\n                    {'from': 'human', 'value': f\"{input}\\n<image>\"},\n                    {'from': 'gpt', 'value': f\"{output}\"},\n                ],\n            })\n\n    print(f'Number of samples: {len(target_format)}')\n\n    with open(os.path.join(base_dir, f\"llava_{split}_{prompt_format}.json\"), \"w\") as f:\n        json.dump(target_format, f, indent=2)\n\n\ndef convert_to_jsonl(base_dir, split, prompt_format=\"QCM-LEPA\"):\n    split_indices = json.load(open(os.path.join(base_dir, \"pid_splits.json\")))[split]\n    problems = json.load(open(os.path.join(base_dir, \"problems.json\")))\n\n    split_problems = build_prompt_chatbot(\n     
   problems, split_indices, prompt_format,\n        use_caption=False, is_test=False)\n\n    writer = open(os.path.join(base_dir, f\"scienceqa_{split}_{prompt_format}.jsonl\"), \"w\")\n    for prob_id, (input, output) in split_problems.items():\n        if input.startswith('Question: '):\n            input = input.replace('Question: ', '')\n        if output.startswith('Answer: '):\n            output = output.replace('Answer: ', '')\n\n        raw_prob_data = problems[prob_id]\n        if raw_prob_data['image'] is None:\n            data = {\n                \"id\": prob_id,\n                \"instruction\": f\"{input}\",\n                \"output\": f\"{output}\",\n            }\n\n        else:\n            data = {\n                \"id\": prob_id,\n                \"image\": os.path.join(prob_id, raw_prob_data['image']),\n                \"instruction\": f\"{input}\\n<image>\",\n                \"output\": f\"{output}\",\n            }\n        writer.write(json.dumps(data) + '\\n')\n    writer.close()\n\n\ndef main(task, **kwargs):\n    globals()[task](**kwargs)\n\n\nif __name__ == \"__main__\":\n    fire.Fire(main)\n"
  },
  {
    "path": "scripts/convert_sqa_to_llava_base_prompt.py",
    "content": "def get_question_text(problem):\n    question = problem['question']\n    return question\n\n\ndef get_context_text(problem, use_caption):\n    txt_context = problem['hint']\n    img_context = problem['caption'] if use_caption else \"\"\n    context = \" \".join([txt_context, img_context]).strip()\n    if context == \"\":\n        context = \"N/A\"\n    return context\n\n\ndef get_choice_text(probelm, options):\n    choices = probelm['choices']\n    choice_list = []\n    for i, c in enumerate(choices):\n        choice_list.append(\"({}) {}\".format(options[i], c))\n    choice_txt = \" \".join(choice_list)\n    #print(choice_txt)\n    return choice_txt\n\n\ndef get_answer(problem, options):\n    return options[problem['answer']]\n\n\ndef get_lecture_text(problem):\n    # \\\\n: GPT-3 can generate the lecture with more tokens.\n    lecture = problem['lecture'].replace(\"\\n\", \"\\\\n\")\n    return lecture\n\n\ndef get_solution_text(problem):\n    # \\\\n: GPT-3 can generate the solution with more tokens\n    solution = problem['solution'].replace(\"\\n\", \"\\\\n\")\n    return solution\n\n\ndef create_one_example_chatbot(format, question, context, choice, answer, lecture, solution, test_example=True):\n\n    input_format, output_format = format.split(\"-\")\n\n    ## Inputs\n    if input_format == \"CQM\":\n        input = f\"Context: {context}\\nQuestion: {question}\\nOptions: {choice}\\n\"\n    elif input_format == \"QCM\":\n        input = f\"Question: {question}\\nContext: {context}\\nOptions: {choice}\\n\"\n    # upper bound experiment\n    elif input_format == \"QCML\":\n        input = f\"Question: {question}\\nContext: {context}\\nOptions: {choice}\\nBECAUSE: {lecture}\\n\"\n    elif input_format == \"QCME\":\n        input = f\"Question: {question}\\nContext: {context}\\nOptions: {choice}\\nBECAUSE: {solution}\\n\"\n    elif input_format == \"QCMLE\":\n        input = f\"Question: {question}\\nContext: {context}\\nOptions: 
{choice}\\nBECAUSE: {lecture} {solution}\\n\"\n\n    elif input_format == \"QCLM\":\n        input = f\"Question: {question}\\nContext: {context}\\nBECAUSE: {lecture}\\nOptions: {choice}\\n\"\n    elif input_format == \"QCEM\":\n        input = f\"Question: {question}\\nContext: {context}\\nBECAUSE: {solution}\\nOptions: {choice}\\n\"\n    elif input_format == \"QCLEM\":\n        input = f\"Question: {question}\\nContext: {context}\\nBECAUSE: {lecture} {solution}\\nOptions: {choice}\\n\"\n\n    # Outputs\n    if test_example:\n        output = \"Answer:\"\n    elif output_format == 'A':\n        output = f\"Answer: The answer is {answer}.\"\n\n    elif output_format == 'AL':\n        output = f\"Answer: The answer is {answer}. BECAUSE: {solution}\"\n    elif output_format == 'AE':\n        output = f\"Answer: The answer is {answer}. BECAUSE: {lecture}\"\n    elif output_format == 'ALE':\n        output = f\"Answer: The answer is {answer}. BECAUSE: {lecture} {solution}\"\n    elif output_format == 'AEL':\n        output = f\"Answer: The answer is {answer}. 
BECAUSE: {solution} {lecture}\"\n\n    elif output_format == 'LA':\n        output = f\"Answer: {lecture} The answer is {answer}.\"\n    elif output_format == 'EA':\n        output = f\"Answer: {solution} The answer is {answer}.\"\n    elif output_format == 'LEA':\n        output = f\"Answer: {lecture} {solution} The answer is {answer}.\"\n    elif output_format == 'ELA':\n        output = f\"Answer: {solution} {lecture} The answer is {answer}.\"\n    elif output_format == 'LEPA':\n        output = ''\n        if len(lecture.strip()) > 0:\n            output += f\"LECTURE: {lecture}\\n\"\n        if len(solution.strip()) > 0:\n            output += f\"SOLUTION: {solution}\\n\"\n        output += '###\\n'\n        output += f\"ANSWER: {answer}.\"\n\n    input = input.replace(\"  \", \" \").strip()\n    output = output.replace(\"  \", \" \").strip()\n    if input.endswith(\"BECAUSE:\"):\n        input = input.replace(\"BECAUSE:\", \"\").strip()\n    if output.endswith(\"BECAUSE:\"):\n        output = output.replace(\"BECAUSE:\", \"\").strip()\n    return input, output\n\n\ndef create_one_example(format, question, context, choice, answer, lecture, solution, test_example=True):\n\n    input_format, output_format = format.split(\"-\")\n\n    ## Inputs\n    if input_format == \"CQM\":\n        input = f\"Context: {context}\\nQuestion: {question}\\nOptions: {choice}\\n\"\n    elif input_format == \"QCM\":\n        input = f\"Question: {question}\\nContext: {context}\\nOptions: {choice}\\n\"\n    # upper bound experiment\n    elif input_format == \"QCML\":\n        input = f\"Question: {question}\\nContext: {context}\\nOptions: {choice}\\nBECAUSE: {lecture}\\n\"\n    elif input_format == \"QCME\":\n        input = f\"Question: {question}\\nContext: {context}\\nOptions: {choice}\\nBECAUSE: {solution}\\n\"\n    elif input_format == \"QCMLE\":\n        input = f\"Question: {question}\\nContext: {context}\\nOptions: {choice}\\nBECAUSE: {lecture} {solution}\\n\"\n\n    elif 
input_format == \"QCLM\":\n        input = f\"Question: {question}\\nContext: {context}\\nBECAUSE: {lecture}\\nOptions: {choice}\\n\"\n    elif input_format == \"QCEM\":\n        input = f\"Question: {question}\\nContext: {context}\\nBECAUSE: {solution}\\nOptions: {choice}\\n\"\n    elif input_format == \"QCLEM\":\n        input = f\"Question: {question}\\nContext: {context}\\nBECAUSE: {lecture} {solution}\\nOptions: {choice}\\n\"\n\n    # Outputs\n    if test_example:\n        output = \"Answer:\"\n    elif output_format == 'A':\n        output = f\"Answer: The answer is {answer}.\"\n\n    elif output_format == 'AL':\n        output = f\"Answer: The answer is {answer}. BECAUSE: {solution}\"\n    elif output_format == 'AE':\n        output = f\"Answer: The answer is {answer}. BECAUSE: {lecture}\"\n    elif output_format == 'ALE':\n        output = f\"Answer: The answer is {answer}. BECAUSE: {lecture} {solution}\"\n    elif output_format == 'AEL':\n        output = f\"Answer: The answer is {answer}. 
BECAUSE: {solution} {lecture}\"\n\n    elif output_format == 'LA':\n        output = f\"Answer: {lecture} The answer is {answer}.\"\n    elif output_format == 'EA':\n        output = f\"Answer: {solution} The answer is {answer}.\"\n    elif output_format == 'LEA':\n        output = f\"Answer: {lecture} {solution} The answer is {answer}.\"\n    elif output_format == 'ELA':\n        output = f\"Answer: {solution} {lecture} The answer is {answer}.\"\n\n    text = input + output\n    text = text.replace(\"  \", \" \").strip()\n    if text.endswith(\"BECAUSE:\"):\n        text = text.replace(\"BECAUSE:\", \"\").strip()\n    return text\n\n\n\ndef create_one_example_gpt4(format, question, context, choice, answer, lecture, solution, test_example=True):\n\n    input_format, output_format = format.split(\"-\")\n\n    ## Inputs\n    if input_format == \"CQM\":\n        input = f\"Context: {context}\\nQuestion: {question}\\nOptions: {choice}\\n\"\n    elif input_format == \"QCM\":\n        input = f\"Question: {question}\\nContext: {context}\\nOptions: {choice}\\n\"\n    # upper bound experiment\n    elif input_format == \"QCML\":\n        input = f\"Question: {question}\\nContext: {context}\\nOptions: {choice}\\nBECAUSE: {lecture}\\n\"\n    elif input_format == \"QCME\":\n        input = f\"Question: {question}\\nContext: {context}\\nOptions: {choice}\\nBECAUSE: {solution}\\n\"\n    elif input_format == \"QCMLE\":\n        input = f\"Question: {question}\\nContext: {context}\\nOptions: {choice}\\nBECAUSE: {lecture} {solution}\\n\"\n\n    elif input_format == \"QCLM\":\n        input = f\"Question: {question}\\nContext: {context}\\nBECAUSE: {lecture}\\nOptions: {choice}\\n\"\n    elif input_format == \"QCEM\":\n        input = f\"Question: {question}\\nContext: {context}\\nBECAUSE: {solution}\\nOptions: {choice}\\n\"\n    elif input_format == \"QCLEM\":\n        input = f\"Question: {question}\\nContext: {context}\\nBECAUSE: {lecture} {solution}\\nOptions: {choice}\\n\"\n\n   
 # Outputs\n    if test_example:\n        output = \"Answer:\"\n    elif output_format == 'A':\n        output = f\"Answer: The answer is {answer}.\"\n\n    elif output_format == 'AL':\n        output = f\"Answer: The answer is {answer}. BECAUSE: {solution}\"\n    elif output_format == 'AE':\n        output = f\"Answer: The answer is {answer}. BECAUSE: {lecture}\"\n    elif output_format == 'ALE':\n        output = f\"Answer: The answer is {answer}. BECAUSE: {lecture} {solution}\"\n    elif output_format == 'AEL':\n        output = f\"Answer: The answer is {answer}. BECAUSE: {solution} {lecture}\"\n\n    elif output_format == 'LA':\n        output = f\"Answer: {lecture} The answer is {answer}.\"\n    elif output_format == 'EA':\n        output = f\"Answer: {solution} The answer is {answer}.\"\n    elif output_format == 'LEA':\n        output = f\"Answer: {lecture} {solution} The answer is {answer}.\"\n    elif output_format == 'ELA':\n        output = f\"Answer: {solution} {lecture} The answer is {answer}.\"\n\n    input = input.replace(\"  \", \" \").strip()\n    output = output.replace(\"  \", \" \").strip()\n    if output.endswith(\"BECAUSE:\"):\n        output = output.replace(\"BECAUSE:\", \"\").strip()\n\n    user_prompt = {\"role\": \"user\", \"content\": f\"Can you explain {input}?\"}\n    assistant_prompt = {\"role\": \"assistant\", \"content\": f\"{output}\"}\n\n    return user_prompt, assistant_prompt\n\n\ndef build_prompt_chatbot(problems, shot_qids, prompt_format, use_caption=False, options=[\"A\", \"B\", \"C\", \"D\", \"E\"], is_test=False):\n    examples = {}\n\n    for qid in shot_qids:\n        question = get_question_text(problems[qid])\n        context = get_context_text(problems[qid], use_caption)\n        choice = get_choice_text(problems[qid], options)\n        answer = get_answer(problems[qid], options)\n        lecture = get_lecture_text(problems[qid]).replace('\\\\n', '\\n')\n        solution = 
get_solution_text(problems[qid]).replace('\\\\n', '\\n')\n\n        train_example = create_one_example_chatbot(prompt_format,\n                                           question,\n                                           context,\n                                           choice,\n                                           answer,\n                                           lecture,\n                                           solution,\n                                           test_example=is_test)\n        examples[qid] = train_example\n    return examples\n\n\ndef build_prompt(problems, shot_qids, test_qid, args):\n\n    examples = []\n\n    # n-shot training examples\n    for qid in shot_qids:\n        question = get_question_text(problems[qid])\n        context = get_context_text(problems[qid], args.use_caption)\n        choice = get_choice_text(problems[qid], args.options)\n        answer = get_answer(problems[qid], args.options)\n        lecture = get_lecture_text(problems[qid])\n        solution = get_solution_text(problems[qid])\n\n        train_example = create_one_example(args.prompt_format,\n                                           question,\n                                           context,\n                                           choice,\n                                           answer,\n                                           lecture,\n                                           solution,\n                                           test_example=False)\n        examples.append(train_example)\n\n    # test example\n    question = get_question_text(problems[test_qid])\n    context = get_context_text(problems[test_qid], args.use_caption)\n    choice = get_choice_text(problems[test_qid], args.options)\n    answer = get_answer(problems[test_qid], args.options)\n    lecture = get_lecture_text(problems[test_qid])\n    solution = get_solution_text(problems[test_qid])\n\n    test_example = create_one_example(args.prompt_format,\n                
                      question,\n                                      context,\n                                      choice,\n                                      answer,\n                                      lecture,\n                                      solution,\n                                      test_example=True)\n    examples.append(test_example)\n\n    # create the prompt input\n    prompt_input = '\\n\\n'.join(examples)\n\n    return prompt_input\n\n\ndef build_prompt_gpt4(problems, shot_qids, test_qid, args):\n\n    prompt_array = [{\"role\": \"system\", \"content\": \"You are a helpful assistant.\"}]\n\n    # n-shot training examples\n    for qid in shot_qids:\n        question = get_question_text(problems[qid])\n        context = get_context_text(problems[qid], args.use_caption)\n        choice = get_choice_text(problems[qid], args.options)\n        answer = get_answer(problems[qid], args.options)\n        lecture = get_lecture_text(problems[qid])\n        solution = get_solution_text(problems[qid])\n\n        user_prompt, assistant_prompt = create_one_example_gpt4(args.prompt_format,\n                                           question,\n                                           context,\n                                           choice,\n                                           answer,\n                                           lecture,\n                                           solution,\n                                           test_example=False)\n        prompt_array.append(user_prompt)\n        prompt_array.append(assistant_prompt)\n\n    # test example\n    question = get_question_text(problems[test_qid])\n    context = get_context_text(problems[test_qid], args.use_caption)\n    choice = get_choice_text(problems[test_qid], args.options)\n    answer = get_answer(problems[test_qid], args.options)\n    lecture = get_lecture_text(problems[test_qid])\n    solution = get_solution_text(problems[test_qid])\n\n    user_prompt, 
assistant_prompt = create_one_example_gpt4(args.prompt_format,\n                                      question,\n                                      context,\n                                      choice,\n                                      answer,\n                                      lecture,\n                                      solution,\n                                      test_example=True)\n    prompt_array.append(user_prompt)\n    prompt_array.append(assistant_prompt)\n\n    return prompt_array"
  },
  {
    "path": "scripts/finetune.sh",
    "content": "# Uncomment and set the following variables correspondingly to run this script:\n\n################## VICUNA ##################\nPROMPT_VERSION=v1\n# MODEL_VERSION=\"vicuna-v1-3-7b\"\n################## VICUNA ##################\n\n################## LLaMA-2 ##################\n# PROMPT_VERSION=\"llava_llama_2\"\n# MODEL_VERSION=\"llama-2-7b-chat\"\n################## LLaMA-2 ##################\nout_dir=output/llava_grounding_stage2\nload=output/llava_grounding_stage1\nmkdir -p $out_dir\necho $out_dir/log\nexport DATASET=datasets/\n\nnum_gpu=8\nbs=$(( 8 * $num_gpu ))\ndeepspeed llava/train/train_joint_2st.py \\\n    --deepspeed scripts/zero2.json \\\n    --model_name_or_path ckpts/vicuna/vicuna-7b-v1.3/ \\\n    --whole_model $load \\\n    --load_model True \\\n    --version $PROMPT_VERSION \\\n    --data_path datasets/llava/annotations/llava_instruct_150k.json \\\n    --image_folder datasets/coco/train2017/ \\\n    --vision_tower openai/clip-vit-large-patch14 \\\n    --pretrain_mm_mlp_adapter output/llava_stage1/mm_projector.bin \\\n    --mm_vision_select_layer -2 \\\n    --mm_use_im_start_end False \\\n    --mm_use_im_patch_token False \\\n    --bf16 True \\\n    --output_dir $out_dir \\\n    --num_train_epochs 1 \\\n    --per_device_train_batch_size 8 \\\n    --per_device_eval_batch_size 4 \\\n    --gradient_accumulation_steps 1 \\\n    --evaluation_strategy \"no\" \\\n    --save_strategy \"steps\" \\\n    --save_steps 1000 \\\n    --save_total_limit 10 \\\n    --learning_rate 2e-5 \\\n    --weight_decay 0. 
\\\n    --warmup_ratio 0.03 \\\n    --lr_scheduler_type \"cosine\" \\\n    --logging_steps 1 \\\n    --tf32 True \\\n    --model_max_length 2400 \\\n    --gradient_checkpointing True \\\n    --dataloader_num_workers 4 \\\n    --lazy_preprocess True \\\n    --report_to wandb \\\n    --max_steps 10000 \\\n    --config_file \\\n    configs/openseed/openseed_swint_lang_joint_2st.yaml \\\n    --opt \\\n    MODEL.DECODER.WEIGHT_MULTIPLIER=0.1,MODEL.DECODER.COST_CLASS_WEIGHT=4.0,flickr.TRAIN.BATCH_SIZE_TOTAL=6,coco_instruct.TEST.BATCH_SIZE_TOTAL=${bs},coco_instruct.TRAIN.BATCH_SIZE_TOTAL=${bs},MODEL.WEIGHTS=ckpts/openseed_o365.pt \\\n    >> $out_dir/log 2>&1\n"
  },
  {
    "path": "scripts/finetune_visual_prompt.sh",
    "content": "# Uncomment and set the following variables correspondingly to run this script:\n\n################## VICUNA ##################\nPROMPT_VERSION=v1\n# MODEL_VERSION=\"vicuna-v1-3-7b\"\n################## VICUNA ##################\n\n################## LLaMA-2 ##################\n# PROMPT_VERSION=\"llava_llama_2\"\n# MODEL_VERSION=\"llama-2-7b-chat\"\n################## LLaMA-2 ##################\nout_dir=output/llava_stage2_visual_prompt\nload=output/llava_grounding_stage2/\nmkdir -p $out_dir\necho $out_dir/log\nexport DATASET=datasets/\n\nnum_gpu=8\nbs=$(( 8 * $num_gpu ))\ndeepspeed llava/train/train_joint_2st_interactive_refcoco_coco_instruction.py \\\n    --deepspeed scripts/zero2.json \\\n    --model_name_or_path ckpts/vicuna/vicuna-7b-v1.3/ \\\n    --whole_model $load \\\n    --load_model True \\\n    --version $PROMPT_VERSION \\\n    --data_path datasets/llava/annotations/llava_instruct_150k.json \\\n    --image_folder datasets/coco/train2017/ \\\n    --vision_tower openai/clip-vit-large-patch14 \\\n    --pretrain_mm_mlp_adapter output/llava_stage1/mm_projector.bin \\\n    --mm_vision_select_layer -2 \\\n    --mm_use_im_start_end False \\\n    --tune_prompt_adapter True \\\n    --mm_use_im_patch_token False \\\n    --bf16 True \\\n    --output_dir $out_dir \\\n    --num_train_epochs 1 \\\n    --per_device_train_batch_size 2 \\\n    --per_device_eval_batch_size 4 \\\n    --gradient_accumulation_steps 1 \\\n    --evaluation_strategy \"no\" \\\n    --save_strategy \"steps\" \\\n    --save_steps 1000 \\\n    --save_total_limit 10 \\\n    --learning_rate 2e-5 \\\n    --weight_decay 0. 
\\\n    --warmup_ratio 0.03 \\\n    --lr_scheduler_type \"cosine\" \\\n    --logging_steps 1 \\\n    --tf32 True \\\n    --model_max_length 2400 \\\n    --gradient_checkpointing True \\\n    --dataloader_num_workers 4 \\\n    --lazy_preprocess True \\\n    --report_to wandb \\\n    --max_steps 20000 \\\n    --config_file_gd \\\n    configs/openseed/openseed_swint_lang_joint_2st_visual_prompt.yaml \\\n    --config_file_it \\\n    configs/semsam/visual_prompt_encoder.yaml \\\n    --opt \\\n    \"detach_seg=True,MODEL.DECODER.WEIGHT_MULTIPLIER=0.1,MODEL.DECODER.COST_CLASS_WEIGHT=4.0,flickr.TEST.BATCH_SIZE_TOTAL=${bs},flickr.TRAIN.BATCH_SIZE_TOTAL=${bs},coco_interactive.TRAIN.BATCH_SIZE_TOTAL=${bs},coco_instruct.TRAIN.BATCH_SIZE_TOTAL=${bs},MODEL.WEIGHTS=ckpts/openseed_o365.pt;MODEL.DECODER.WEIGHT_MULTIPLIER=0.2,coco_interactive.TEST.BATCH_SIZE_TOTAL=${bs},coco_interactive.TRAIN.BATCH_SIZE_TOTAL=${bs},MODEL.WEIGHTS=ckpts/visual_prompt_enc.pth\" \\\n    >> $out_dir/log 2>&1"
  },
  {
    "path": "scripts/merge_lora_weights.py",
    "content": "import argparse\nfrom llava.model.builder import load_pretrained_model\nfrom llava.mm_utils import get_model_name_from_path\n\n\ndef merge_lora(args):\n    model_name = get_model_name_from_path(args.model_path)\n    tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, device_map='cpu')\n\n    model.save_pretrained(args.save_model_path)\n    tokenizer.save_pretrained(args.save_model_path)\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--model-path\", type=str, required=True)\n    parser.add_argument(\"--model-base\", type=str, required=True)\n    parser.add_argument(\"--save-model-path\", type=str, required=True)\n\n    args = parser.parse_args()\n\n    merge_lora(args)\n"
  },
  {
    "path": "scripts/pretrain_joint.sh",
    "content": "# Uncomment and set the following variables correspondingly to run this script:\n\n# MODEL_VERSION=vicuna-v1-3-7b\n# MODEL_VERSION=llama-2-7b-chat\n\n########### DO NOT CHANGE ###########\n########### USE THIS FOR BOTH ###########\nPROMPT_VERSION=v1\n########### DO NOT CHANGE ###########\nout_dir=output/llava_grounding_stage1\nmkdir -p $out_dir\necho $out_dir/log\nexport DATASET=datasets/\n\nn_gpu=4\n\ndeepspeed --include=localhost:1,2,3,7 llava/train/train_joint_1st.py \\\n    --deepspeed scripts/zero2.json \\\n    --model_name_or_path ckpts/vicuna/vicuna-7b-v1.3/ \\\n    --version $PROMPT_VERSION \\\n    --data_path datasets/llava/annotations/cap600k_brackets_all.json \\\n    --image_folder datasets/ConceptualCaptionsFiltered/ \\\n    --vision_tower openai/clip-vit-large-patch14 \\\n    --pretrain_mm_mlp_adapter output/llava_stage1/mm_projector.bin \\\n    --tune_mm_mlp_adapter True \\\n    --mm_vision_select_layer -2 \\\n    --mm_use_im_start_end False \\\n    --mm_use_im_patch_token False \\\n    --bf16 True \\\n    --output_dir $out_dir \\\n    --max_steps 30000 \\\n    --num_train_epochs 1 \\\n    --per_device_train_batch_size 8 \\\n    --per_device_eval_batch_size 4 \\\n    --gradient_accumulation_steps 1 \\\n    --evaluation_strategy \"no\" \\\n    --save_strategy \"steps\" \\\n    --save_steps 1000 \\\n    --save_total_limit 100 \\\n    --learning_rate 1e-4 \\\n    --weight_decay 0. \\\n    --warmup_ratio 0.03 \\\n    --lr_scheduler_type \"cosine\" \\\n    --logging_steps 1 \\\n    --tf32 True \\\n    --model_max_length 2048 \\\n    --gradient_checkpointing True \\\n    --dataloader_num_workers 4 \\\n    --lazy_preprocess True \\\n    --report_to wandb \\\n    --config_file \\\n    configs/openseed/openseed_swint_lang_joint.yaml \\\n    --opt \\\n    flickr.TRAIN.BATCH_SIZE_TOTAL=8,COCO.TRAIN.BATCH_SIZE_TOTAL=24,MODEL.WEIGHTS=ckpts/openseed_o365.pt \\\n    >> $out_dir/log 2>&1\n"
  },
  {
    "path": "utils/Config.py",
    "content": "from fvcore.common.config import CfgNode as _CfgNode\n\nclass CfgNode(_CfgNode):\n    \"\"\"\n    The same as `fvcore.common.config.CfgNode`, but different in:\n\n    1. Use unsafe yaml loading by default.\n       Note that this may lead to arbitrary code execution: you must not\n       load a config file from untrusted sources before manually inspecting\n       the content of the file.\n    2. Support config versioning.\n       When attempting to merge an old config, it will convert the old config automatically.\n\n    .. automethod:: clone\n    .. automethod:: freeze\n    .. automethod:: defrost\n    .. automethod:: is_frozen\n    .. automethod:: load_yaml_with_base\n    .. automethod:: merge_from_list\n    .. automethod:: merge_from_other_cfg\n    \"\"\"\n\n    def merge_from_dict(self, dict):\n        pass\n    \nnode = CfgNode()"
  },
  {
    "path": "utils/__init__.py",
    "content": ""
  },
  {
    "path": "utils/arguments.py",
    "content": "import yaml\nimport json\nimport argparse\nimport logging\n\nlogger = logging.getLogger(__name__)\n\n\ndef load_config_dict_to_opt(opt, config_dict):\n    \"\"\"\n    Load the key, value pairs from config_dict to opt, overriding existing values in opt\n    if there is any.\n    \"\"\"\n    if not isinstance(config_dict, dict):\n        raise TypeError(\"Config must be a Python dictionary\")\n    for k, v in config_dict.items():\n        k_parts = k.split('.')\n        pointer = opt\n        for k_part in k_parts[:-1]:\n            if k_part not in pointer:\n                pointer[k_part] = {}\n            pointer = pointer[k_part]\n            assert isinstance(pointer, dict), \"Overriding key needs to be inside a Python dict.\"\n        ori_value = pointer.get(k_parts[-1])\n        pointer[k_parts[-1]] = v\n        if ori_value:\n            logger.warning(f\"Overrided {k} from {ori_value} to {pointer[k_parts[-1]]}\")\n\n\ndef load_opt_from_config_files(conf_files):\n    \"\"\"\n    Load opt from the config files, settings in later files can override those in previous files.\n\n    Args:\n        conf_files (list): a list of config file paths\n\n    Returns:\n        dict: a dictionary of opt settings\n    \"\"\"\n    opt = {}\n    for conf_file in conf_files:\n        with open(conf_file, encoding='utf-8') as f:\n            config_dict = yaml.safe_load(f)\n\n        load_config_dict_to_opt(opt, config_dict)\n\n    return opt\n\n\ndef load_opt_command(args):\n    parser = argparse.ArgumentParser(description='Pretrain or fine-tune models for NLP tasks.')\n    parser.add_argument('command', help='Command: train/evaluate/train-and-evaluate')\n    parser.add_argument('--conf_files', nargs='+', required=True, help='Path(s) to the config file(s).')\n    parser.add_argument('--user_dir', help='Path to the user defined module for tasks (models, criteria), optimizers, and lr schedulers.')\n    parser.add_argument('--config_overrides', nargs='*', 
help='Override parameters on config with a json style string, e.g. {\"<PARAM_NAME_1>\": <PARAM_VALUE_1>, \"<PARAM_GROUP_2>.<PARAM_SUBGROUP_2>.<PARAM_2>\": <PARAM_VALUE_2>}. A key with \".\" updates the object in the corresponding nested dict. Remember to escape \" in command line.')\n    parser.add_argument('--overrides', help='arguments that used to override the config file in cmdline', nargs=argparse.REMAINDER)\n\n    cmdline_args = parser.parse_args() if not args else parser.parse_args(args)\n\n    opt = load_opt_from_config_files(cmdline_args.conf_files)\n\n    if cmdline_args.config_overrides:\n        config_overrides_string = ' '.join(cmdline_args.config_overrides)\n        logger.warning(f\"Command line config overrides: {config_overrides_string}\")\n        config_dict = json.loads(config_overrides_string)\n        load_config_dict_to_opt(opt, config_dict)\n\n    if cmdline_args.overrides:\n        assert len(cmdline_args.overrides) % 2 == 0, \"overrides arguments is not paired, required: key value\"\n        keys = [cmdline_args.overrides[idx*2] for idx in range(len(cmdline_args.overrides)//2)]\n        vals = [cmdline_args.overrides[idx*2+1] for idx in range(len(cmdline_args.overrides)//2)]\n        vals = [val.replace('false', '').replace('False','') if len(val.replace(' ', '')) == 5 else val for val in vals]\n\n        types = []\n        for key in keys:\n            key = key.split('.')\n            ele = opt.copy()\n            while len(key) > 0:\n                ele = ele[key.pop(0)]\n            types.append(type(ele))\n        \n        config_dict = {x:z(y) for x,y,z in zip(keys, vals, types)}\n        load_config_dict_to_opt(opt, config_dict)\n\n    # combine cmdline_args into opt dictionary\n    for key, val in cmdline_args.__dict__.items():\n        if val is not None:\n            opt[key] = val\n\n    return opt, cmdline_args\n\n\ndef save_opt_to_json(opt, conf_file):\n    with open(conf_file, 'w', encoding='utf-8') as f:\n        
json.dump(opt, f, indent=4)\n\n\ndef save_opt_to_yaml(opt, conf_file):\n    with open(conf_file, 'w', encoding='utf-8') as f:\n        yaml.dump(opt, f)\n"
  },
  {
    "path": "utils/constants.py",
    "content": "IMAGENET_CLASSES = [\"tench\", \"goldfish\", \"great white shark\", \"tiger shark\", \"hammerhead shark\", \"electric ray\", \"stingray\", \"rooster\", \"hen\", \"ostrich\", \"brambling\", \"goldfinch\", \"house finch\", \"junco\", \"indigo bunting\", \"American robin\", \"bulbul\", \"jay\", \"magpie\", \"chickadee\", \"American dipper\", \"kite (bird of prey)\", \"bald eagle\", \"vulture\", \"great grey owl\", \"fire salamander\", \"smooth newt\", \"newt\", \"spotted salamander\", \"axolotl\", \"American bullfrog\", \"tree frog\", \"tailed frog\", \"loggerhead sea turtle\", \"leatherback sea turtle\", \"mud turtle\", \"terrapin\", \"box turtle\", \"banded gecko\", \"green iguana\", \"Carolina anole\", \"desert grassland whiptail lizard\", \"agama\", \"frilled-necked lizard\", \"alligator lizard\", \"Gila monster\", \"European green lizard\", \"chameleon\", \"Komodo dragon\", \"Nile crocodile\", \"American alligator\", \"triceratops\", \"worm snake\", \"ring-necked snake\", \"eastern hog-nosed snake\", \"smooth green snake\", \"kingsnake\", \"garter snake\", \"water snake\", \"vine snake\", \"night snake\", \"boa constrictor\", \"African rock python\", \"Indian cobra\", \"green mamba\", \"sea snake\", \"Saharan horned viper\", \"eastern diamondback rattlesnake\", \"sidewinder rattlesnake\", \"trilobite\", \"harvestman\", \"scorpion\", \"yellow garden spider\", \"barn spider\", \"European garden spider\", \"southern black widow\", \"tarantula\", \"wolf spider\", \"tick\", \"centipede\", \"black grouse\", \"ptarmigan\", \"ruffed grouse\", \"prairie grouse\", \"peafowl\", \"quail\", \"partridge\", \"african grey parrot\", \"macaw\", \"sulphur-crested cockatoo\", \"lorikeet\", \"coucal\", \"bee eater\", \"hornbill\", \"hummingbird\", \"jacamar\", \"toucan\", \"duck\", \"red-breasted merganser\", \"goose\", \"black swan\", \"tusker\", \"echidna\", \"platypus\", \"wallaby\", \"koala\", \"wombat\", \"jellyfish\", \"sea anemone\", \"brain coral\", 
\"flatworm\", \"nematode\", \"conch\", \"snail\", \"slug\", \"sea slug\", \"chiton\", \"chambered nautilus\", \"Dungeness crab\", \"rock crab\", \"fiddler crab\", \"red king crab\", \"American lobster\", \"spiny lobster\", \"crayfish\", \"hermit crab\", \"isopod\", \"white stork\", \"black stork\", \"spoonbill\", \"flamingo\", \"little blue heron\", \"great egret\", \"bittern bird\", \"crane bird\", \"limpkin\", \"common gallinule\", \"American coot\", \"bustard\", \"ruddy turnstone\", \"dunlin\", \"common redshank\", \"dowitcher\", \"oystercatcher\", \"pelican\", \"king penguin\", \"albatross\", \"grey whale\", \"killer whale\", \"dugong\", \"sea lion\", \"Chihuahua\", \"Japanese Chin\", \"Maltese\", \"Pekingese\", \"Shih Tzu\", \"King Charles Spaniel\", \"Papillon\", \"toy terrier\", \"Rhodesian Ridgeback\", \"Afghan Hound\", \"Basset Hound\", \"Beagle\", \"Bloodhound\", \"Bluetick Coonhound\", \"Black and Tan Coonhound\", \"Treeing Walker Coonhound\", \"English foxhound\", \"Redbone Coonhound\", \"borzoi\", \"Irish Wolfhound\", \"Italian Greyhound\", \"Whippet\", \"Ibizan Hound\", \"Norwegian Elkhound\", \"Otterhound\", \"Saluki\", \"Scottish Deerhound\", \"Weimaraner\", \"Staffordshire Bull Terrier\", \"American Staffordshire Terrier\", \"Bedlington Terrier\", \"Border Terrier\", \"Kerry Blue Terrier\", \"Irish Terrier\", \"Norfolk Terrier\", \"Norwich Terrier\", \"Yorkshire Terrier\", \"Wire Fox Terrier\", \"Lakeland Terrier\", \"Sealyham Terrier\", \"Airedale Terrier\", \"Cairn Terrier\", \"Australian Terrier\", \"Dandie Dinmont Terrier\", \"Boston Terrier\", \"Miniature Schnauzer\", \"Giant Schnauzer\", \"Standard Schnauzer\", \"Scottish Terrier\", \"Tibetan Terrier\", \"Australian Silky Terrier\", \"Soft-coated Wheaten Terrier\", \"West Highland White Terrier\", \"Lhasa Apso\", \"Flat-Coated Retriever\", \"Curly-coated Retriever\", \"Golden Retriever\", \"Labrador Retriever\", \"Chesapeake Bay Retriever\", \"German Shorthaired Pointer\", \"Vizsla\", 
\"English Setter\", \"Irish Setter\", \"Gordon Setter\", \"Brittany dog\", \"Clumber Spaniel\", \"English Springer Spaniel\", \"Welsh Springer Spaniel\", \"Cocker Spaniel\", \"Sussex Spaniel\", \"Irish Water Spaniel\", \"Kuvasz\", \"Schipperke\", \"Groenendael dog\", \"Malinois\", \"Briard\", \"Australian Kelpie\", \"Komondor\", \"Old English Sheepdog\", \"Shetland Sheepdog\", \"collie\", \"Border Collie\", \"Bouvier des Flandres dog\", \"Rottweiler\", \"German Shepherd Dog\", \"Dobermann\", \"Miniature Pinscher\", \"Greater Swiss Mountain Dog\", \"Bernese Mountain Dog\", \"Appenzeller Sennenhund\", \"Entlebucher Sennenhund\", \"Boxer\", \"Bullmastiff\", \"Tibetan Mastiff\", \"French Bulldog\", \"Great Dane\", \"St. Bernard\", \"husky\", \"Alaskan Malamute\", \"Siberian Husky\", \"Dalmatian\", \"Affenpinscher\", \"Basenji\", \"pug\", \"Leonberger\", \"Newfoundland dog\", \"Great Pyrenees dog\", \"Samoyed\", \"Pomeranian\", \"Chow Chow\", \"Keeshond\", \"brussels griffon\", \"Pembroke Welsh Corgi\", \"Cardigan Welsh Corgi\", \"Toy Poodle\", \"Miniature Poodle\", \"Standard Poodle\", \"Mexican hairless dog (xoloitzcuintli)\", \"grey wolf\", \"Alaskan tundra wolf\", \"red wolf or maned wolf\", \"coyote\", \"dingo\", \"dhole\", \"African wild dog\", \"hyena\", \"red fox\", \"kit fox\", \"Arctic fox\", \"grey fox\", \"tabby cat\", \"tiger cat\", \"Persian cat\", \"Siamese cat\", \"Egyptian Mau\", \"cougar\", \"lynx\", \"leopard\", \"snow leopard\", \"jaguar\", \"lion\", \"tiger\", \"cheetah\", \"brown bear\", \"American black bear\", \"polar bear\", \"sloth bear\", \"mongoose\", \"meerkat\", \"tiger beetle\", \"ladybug\", \"ground beetle\", \"longhorn beetle\", \"leaf beetle\", \"dung beetle\", \"rhinoceros beetle\", \"weevil\", \"fly\", \"bee\", \"ant\", \"grasshopper\", \"cricket insect\", \"stick insect\", \"cockroach\", \"praying mantis\", \"cicada\", \"leafhopper\", \"lacewing\", \"dragonfly\", \"damselfly\", \"red admiral butterfly\", \"ringlet butterfly\", 
\"monarch butterfly\", \"small white butterfly\", \"sulphur butterfly\", \"gossamer-winged butterfly\", \"starfish\", \"sea urchin\", \"sea cucumber\", \"cottontail rabbit\", \"hare\", \"Angora rabbit\", \"hamster\", \"porcupine\", \"fox squirrel\", \"marmot\", \"beaver\", \"guinea pig\", \"common sorrel horse\", \"zebra\", \"pig\", \"wild boar\", \"warthog\", \"hippopotamus\", \"ox\", \"water buffalo\", \"bison\", \"ram (adult male sheep)\", \"bighorn sheep\", \"Alpine ibex\", \"hartebeest\", \"impala (antelope)\", \"gazelle\", \"arabian camel\", \"llama\", \"weasel\", \"mink\", \"European polecat\", \"black-footed ferret\", \"otter\", \"skunk\", \"badger\", \"armadillo\", \"three-toed sloth\", \"orangutan\", \"gorilla\", \"chimpanzee\", \"gibbon\", \"siamang\", \"guenon\", \"patas monkey\", \"baboon\", \"macaque\", \"langur\", \"black-and-white colobus\", \"proboscis monkey\", \"marmoset\", \"white-headed capuchin\", \"howler monkey\", \"titi monkey\", \"Geoffroy's spider monkey\", \"common squirrel monkey\", \"ring-tailed lemur\", \"indri\", \"Asian elephant\", \"African bush elephant\", \"red panda\", \"giant panda\", \"snoek fish\", \"eel\", \"silver salmon\", \"rock beauty fish\", \"clownfish\", \"sturgeon\", \"gar fish\", \"lionfish\", \"pufferfish\", \"abacus\", \"abaya\", \"academic gown\", \"accordion\", \"acoustic guitar\", \"aircraft carrier\", \"airliner\", \"airship\", \"altar\", \"ambulance\", \"amphibious vehicle\", \"analog clock\", \"apiary\", \"apron\", \"trash can\", \"assault rifle\", \"backpack\", \"bakery\", \"balance beam\", \"balloon\", \"ballpoint pen\", \"Band-Aid\", \"banjo\", \"baluster / handrail\", \"barbell\", \"barber chair\", \"barbershop\", \"barn\", \"barometer\", \"barrel\", \"wheelbarrow\", \"baseball\", \"basketball\", \"bassinet\", \"bassoon\", \"swimming cap\", \"bath towel\", \"bathtub\", \"station wagon\", \"lighthouse\", \"beaker\", \"military hat (bearskin or shako)\", \"beer bottle\", \"beer glass\", \"bell tower\", 
\"baby bib\", \"tandem bicycle\", \"bikini\", \"ring binder\", \"binoculars\", \"birdhouse\", \"boathouse\", \"bobsleigh\", \"bolo tie\", \"poke bonnet\", \"bookcase\", \"bookstore\", \"bottle cap\", \"hunting bow\", \"bow tie\", \"brass memorial plaque\", \"bra\", \"breakwater\", \"breastplate\", \"broom\", \"bucket\", \"buckle\", \"bulletproof vest\", \"high-speed train\", \"butcher shop\", \"taxicab\", \"cauldron\", \"candle\", \"cannon\", \"canoe\", \"can opener\", \"cardigan\", \"car mirror\", \"carousel\", \"tool kit\", \"cardboard box / carton\", \"car wheel\", \"automated teller machine\", \"cassette\", \"cassette player\", \"castle\", \"catamaran\", \"CD player\", \"cello\", \"mobile phone\", \"chain\", \"chain-link fence\", \"chain mail\", \"chainsaw\", \"storage chest\", \"chiffonier\", \"bell or wind chime\", \"china cabinet\", \"Christmas stocking\", \"church\", \"movie theater\", \"cleaver\", \"cliff dwelling\", \"cloak\", \"clogs\", \"cocktail shaker\", \"coffee mug\", \"coffeemaker\", \"spiral or coil\", \"combination lock\", \"computer keyboard\", \"candy store\", \"container ship\", \"convertible\", \"corkscrew\", \"cornet\", \"cowboy boot\", \"cowboy hat\", \"cradle\", \"construction crane\", \"crash helmet\", \"crate\", \"infant bed\", \"Crock Pot\", \"croquet ball\", \"crutch\", \"cuirass\", \"dam\", \"desk\", \"desktop computer\", \"rotary dial telephone\", \"diaper\", \"digital clock\", \"digital watch\", \"dining table\", \"dishcloth\", \"dishwasher\", \"disc brake\", \"dock\", \"dog sled\", \"dome\", \"doormat\", \"drilling rig\", \"drum\", \"drumstick\", \"dumbbell\", \"Dutch oven\", \"electric fan\", \"electric guitar\", \"electric locomotive\", \"entertainment center\", \"envelope\", \"espresso machine\", \"face powder\", \"feather boa\", \"filing cabinet\", \"fireboat\", \"fire truck\", \"fire screen\", \"flagpole\", \"flute\", \"folding chair\", \"football helmet\", \"forklift\", \"fountain\", \"fountain pen\", \"four-poster bed\", 
\"freight car\", \"French horn\", \"frying pan\", \"fur coat\", \"garbage truck\", \"gas mask or respirator\", \"gas pump\", \"goblet\", \"go-kart\", \"golf ball\", \"golf cart\", \"gondola\", \"gong\", \"gown\", \"grand piano\", \"greenhouse\", \"radiator grille\", \"grocery store\", \"guillotine\", \"hair clip\", \"hair spray\", \"half-track\", \"hammer\", \"hamper\", \"hair dryer\", \"hand-held computer\", \"handkerchief\", \"hard disk drive\", \"harmonica\", \"harp\", \"combine harvester\", \"hatchet\", \"holster\", \"home theater\", \"honeycomb\", \"hook\", \"hoop skirt\", \"gymnastic horizontal bar\", \"horse-drawn vehicle\", \"hourglass\", \"iPod\", \"clothes iron\", \"carved pumpkin\", \"jeans\", \"jeep\", \"T-shirt\", \"jigsaw puzzle\", \"rickshaw\", \"joystick\", \"kimono\", \"knee pad\", \"knot\", \"lab coat\", \"ladle\", \"lampshade\", \"laptop computer\", \"lawn mower\", \"lens cap\", \"letter opener\", \"library\", \"lifeboat\", \"lighter\", \"limousine\", \"ocean liner\", \"lipstick\", \"slip-on shoe\", \"lotion\", \"music speaker\", \"loupe magnifying glass\", \"sawmill\", \"magnetic compass\", \"messenger bag\", \"mailbox\", \"tights\", \"one-piece bathing suit\", \"manhole cover\", \"maraca\", \"marimba\", \"mask\", \"matchstick\", \"maypole\", \"maze\", \"measuring cup\", \"medicine cabinet\", \"megalith\", \"microphone\", \"microwave oven\", \"military uniform\", \"milk can\", \"minibus\", \"miniskirt\", \"minivan\", \"missile\", \"mitten\", \"mixing bowl\", \"mobile home\", \"ford model t\", \"modem\", \"monastery\", \"monitor\", \"moped\", \"mortar and pestle\", \"graduation cap\", \"mosque\", \"mosquito net\", \"vespa\", \"mountain bike\", \"tent\", \"computer mouse\", \"mousetrap\", \"moving van\", \"muzzle\", \"metal nail\", \"neck brace\", \"necklace\", \"baby pacifier\", \"notebook computer\", \"obelisk\", \"oboe\", \"ocarina\", \"odometer\", \"oil filter\", \"pipe organ\", \"oscilloscope\", \"overskirt\", \"bullock cart\", \"oxygen 
mask\", \"product packet / packaging\", \"paddle\", \"paddle wheel\", \"padlock\", \"paintbrush\", \"pajamas\", \"palace\", \"pan flute\", \"paper towel\", \"parachute\", \"parallel bars\", \"park bench\", \"parking meter\", \"railroad car\", \"patio\", \"payphone\", \"pedestal\", \"pencil case\", \"pencil sharpener\", \"perfume\", \"Petri dish\", \"photocopier\", \"plectrum\", \"Pickelhaube\", \"picket fence\", \"pickup truck\", \"pier\", \"piggy bank\", \"pill bottle\", \"pillow\", \"ping-pong ball\", \"pinwheel\", \"pirate ship\", \"drink pitcher\", \"block plane\", \"planetarium\", \"plastic bag\", \"plate rack\", \"farm plow\", \"plunger\", \"Polaroid camera\", \"pole\", \"police van\", \"poncho\", \"pool table\", \"soda bottle\", \"plant pot\", \"potter's wheel\", \"power drill\", \"prayer rug\", \"printer\", \"prison\", \"projectile\", \"projector\", \"hockey puck\", \"punching bag\", \"purse\", \"quill\", \"quilt\", \"race car\", \"racket\", \"radiator\", \"radio\", \"radio telescope\", \"rain barrel\", \"recreational vehicle\", \"fishing casting reel\", \"reflex camera\", \"refrigerator\", \"remote control\", \"restaurant\", \"revolver\", \"rifle\", \"rocking chair\", \"rotisserie\", \"eraser\", \"rugby ball\", \"ruler measuring stick\", \"sneaker\", \"safe\", \"safety pin\", \"salt shaker\", \"sandal\", \"sarong\", \"saxophone\", \"scabbard\", \"weighing scale\", \"school bus\", \"schooner\", \"scoreboard\", \"CRT monitor\", \"screw\", \"screwdriver\", \"seat belt\", \"sewing machine\", \"shield\", \"shoe store\", \"shoji screen / room divider\", \"shopping basket\", \"shopping cart\", \"shovel\", \"shower cap\", \"shower curtain\", \"ski\", \"balaclava ski mask\", \"sleeping bag\", \"slide rule\", \"sliding door\", \"slot machine\", \"snorkel\", \"snowmobile\", \"snowplow\", \"soap dispenser\", \"soccer ball\", \"sock\", \"solar thermal collector\", \"sombrero\", \"soup bowl\", \"keyboard space bar\", \"space heater\", \"space shuttle\", \"spatula\", 
\"motorboat\", \"spider web\", \"spindle\", \"sports car\", \"spotlight\", \"stage\", \"steam locomotive\", \"through arch bridge\", \"steel drum\", \"stethoscope\", \"scarf\", \"stone wall\", \"stopwatch\", \"stove\", \"strainer\", \"tram\", \"stretcher\", \"couch\", \"stupa\", \"submarine\", \"suit\", \"sundial\", \"sunglasses\", \"dark glasses\", \"sunscreen\", \"suspension bridge\", \"mop\", \"sweatshirt\", \"swim trunks / shorts\", \"swing\", \"electrical switch\", \"syringe\", \"table lamp\", \"tank\", \"tape player\", \"teapot\", \"teddy bear\", \"television\", \"tennis ball\", \"thatched roof\", \"front curtain\", \"thimble\", \"threshing machine\", \"throne\", \"tile roof\", \"toaster\", \"tobacco shop\", \"toilet seat\", \"torch\", \"totem pole\", \"tow truck\", \"toy store\", \"tractor\", \"semi-trailer truck\", \"tray\", \"trench coat\", \"tricycle\", \"trimaran\", \"tripod\", \"triumphal arch\", \"trolleybus\", \"trombone\", \"hot tub\", \"turnstile\", \"typewriter keyboard\", \"umbrella\", \"unicycle\", \"upright piano\", \"vacuum cleaner\", \"vase\", \"vaulted or arched ceiling\", \"velvet fabric\", \"vending machine\", \"vestment\", \"viaduct\", \"violin\", \"volleyball\", \"waffle iron\", \"wall clock\", \"wallet\", \"wardrobe\", \"military aircraft\", \"sink\", \"washing machine\", \"water bottle\", \"water jug\", \"water tower\", \"whiskey jug\", \"whistle\", \"hair wig\", \"window screen\", \"window shade\", \"Windsor tie\", \"wine bottle\", \"airplane wing\", \"wok\", \"wooden spoon\", \"wool\", \"split-rail fence\", \"shipwreck\", \"sailboat\", \"yurt\", \"website\", \"comic book\", \"crossword\", \"traffic or street sign\", \"traffic light\", \"dust jacket\", \"menu\", \"plate\", \"guacamole\", \"consomme\", \"hot pot\", \"trifle\", \"ice cream\", \"popsicle\", \"baguette\", \"bagel\", \"pretzel\", \"cheeseburger\", \"hot dog\", \"mashed potatoes\", \"cabbage\", \"broccoli\", \"cauliflower\", \"zucchini\", \"spaghetti squash\", \"acorn 
squash\", \"butternut squash\", \"cucumber\", \"artichoke\", \"bell pepper\", \"cardoon\", \"mushroom\", \"Granny Smith apple\", \"strawberry\", \"orange\", \"lemon\", \"fig\", \"pineapple\", \"banana\", \"jackfruit\", \"cherimoya (custard apple)\", \"pomegranate\", \"hay\", \"carbonara\", \"chocolate syrup\", \"dough\", \"meatloaf\", \"pizza\", \"pot pie\", \"burrito\", \"red wine\", \"espresso\", \"tea cup\", \"eggnog\", \"mountain\", \"bubble\", \"cliff\", \"coral reef\", \"geyser\", \"lakeshore\", \"promontory\", \"sandbar\", \"beach\", \"valley\", \"volcano\", \"baseball player\", \"bridegroom\", \"scuba diver\", \"rapeseed\", \"daisy\", \"yellow lady's slipper\", \"corn\", \"acorn\", \"rose hip\", \"horse chestnut seed\", \"coral fungus\", \"agaric\", \"gyromitra\", \"stinkhorn mushroom\", \"earth star fungus\", \"hen of the woods mushroom\", \"bolete\", \"corn cob\", \"toilet paper\"]\n\nIMAGENET_FOLDER_NAMES = ['n01440764', 'n01443537', 'n01484850', 'n01491361', 'n01494475', 'n01496331', 'n01498041', 'n01514668', 'n01514859', 'n01518878', 'n01530575', 'n01531178', 'n01532829', 'n01534433', 'n01537544', 'n01558993', 'n01560419', 'n01580077', 'n01582220', 'n01592084', 'n01601694', 'n01608432', 'n01614925', 'n01616318', 'n01622779', 'n01629819', 'n01630670', 'n01631663', 'n01632458', 'n01632777', 'n01641577', 'n01644373', 'n01644900', 'n01664065', 'n01665541', 'n01667114', 'n01667778', 'n01669191', 'n01675722', 'n01677366', 'n01682714', 'n01685808', 'n01687978', 'n01688243', 'n01689811', 'n01692333', 'n01693334', 'n01694178', 'n01695060', 'n01697457', 'n01698640', 'n01704323', 'n01728572', 'n01728920', 'n01729322', 'n01729977', 'n01734418', 'n01735189', 'n01737021', 'n01739381', 'n01740131', 'n01742172', 'n01744401', 'n01748264', 'n01749939', 'n01751748', 'n01753488', 'n01755581', 'n01756291', 'n01768244', 'n01770081', 'n01770393', 'n01773157', 'n01773549', 'n01773797', 'n01774384', 'n01774750', 'n01775062', 'n01776313', 'n01784675', 'n01795545', 'n01796340', 
'n01797886', 'n01798484', 'n01806143', 'n01806567', 'n01807496', 'n01817953', 'n01818515', 'n01819313', 'n01820546', 'n01824575', 'n01828970', 'n01829413', 'n01833805', 'n01843065', 'n01843383', 'n01847000', 'n01855032', 'n01855672', 'n01860187', 'n01871265', 'n01872401', 'n01873310', 'n01877812', 'n01882714', 'n01883070', 'n01910747', 'n01914609', 'n01917289', 'n01924916', 'n01930112', 'n01943899', 'n01944390', 'n01945685', 'n01950731', 'n01955084', 'n01968897', 'n01978287', 'n01978455', 'n01980166', 'n01981276', 'n01983481', 'n01984695', 'n01985128', 'n01986214', 'n01990800', 'n02002556', 'n02002724', 'n02006656', 'n02007558', 'n02009229', 'n02009912', 'n02011460', 'n02012849', 'n02013706', 'n02017213', 'n02018207', 'n02018795', 'n02025239', 'n02027492', 'n02028035', 'n02033041', 'n02037110', 'n02051845', 'n02056570', 'n02058221', 'n02066245', 'n02071294', 'n02074367', 'n02077923', 'n02085620', 'n02085782', 'n02085936', 'n02086079', 'n02086240', 'n02086646', 'n02086910', 'n02087046', 'n02087394', 'n02088094', 'n02088238', 'n02088364', 'n02088466', 'n02088632', 'n02089078', 'n02089867', 'n02089973', 'n02090379', 'n02090622', 'n02090721', 'n02091032', 'n02091134', 'n02091244', 'n02091467', 'n02091635', 'n02091831', 'n02092002', 'n02092339', 'n02093256', 'n02093428', 'n02093647', 'n02093754', 'n02093859', 'n02093991', 'n02094114', 'n02094258', 'n02094433', 'n02095314', 'n02095570', 'n02095889', 'n02096051', 'n02096177', 'n02096294', 'n02096437', 'n02096585', 'n02097047', 'n02097130', 'n02097209', 'n02097298', 'n02097474', 'n02097658', 'n02098105', 'n02098286', 'n02098413', 'n02099267', 'n02099429', 'n02099601', 'n02099712', 'n02099849', 'n02100236', 'n02100583', 'n02100735', 'n02100877', 'n02101006', 'n02101388', 'n02101556', 'n02102040', 'n02102177', 'n02102318', 'n02102480', 'n02102973', 'n02104029', 'n02104365', 'n02105056', 'n02105162', 'n02105251', 'n02105412', 'n02105505', 'n02105641', 'n02105855', 'n02106030', 'n02106166', 'n02106382', 'n02106550', 
'n02106662', 'n02107142', 'n02107312', 'n02107574', 'n02107683', 'n02107908', 'n02108000', 'n02108089', 'n02108422', 'n02108551', 'n02108915', 'n02109047', 'n02109525', 'n02109961', 'n02110063', 'n02110185', 'n02110341', 'n02110627', 'n02110806', 'n02110958', 'n02111129', 'n02111277', 'n02111500', 'n02111889', 'n02112018', 'n02112137', 'n02112350', 'n02112706', 'n02113023', 'n02113186', 'n02113624', 'n02113712', 'n02113799', 'n02113978', 'n02114367', 'n02114548', 'n02114712', 'n02114855', 'n02115641', 'n02115913', 'n02116738', 'n02117135', 'n02119022', 'n02119789', 'n02120079', 'n02120505', 'n02123045', 'n02123159', 'n02123394', 'n02123597', 'n02124075', 'n02125311', 'n02127052', 'n02128385', 'n02128757', 'n02128925', 'n02129165', 'n02129604', 'n02130308', 'n02132136', 'n02133161', 'n02134084', 'n02134418', 'n02137549', 'n02138441', 'n02165105', 'n02165456', 'n02167151', 'n02168699', 'n02169497', 'n02172182', 'n02174001', 'n02177972', 'n02190166', 'n02206856', 'n02219486', 'n02226429', 'n02229544', 'n02231487', 'n02233338', 'n02236044', 'n02256656', 'n02259212', 'n02264363', 'n02268443', 'n02268853', 'n02276258', 'n02277742', 'n02279972', 'n02280649', 'n02281406', 'n02281787', 'n02317335', 'n02319095', 'n02321529', 'n02325366', 'n02326432', 'n02328150', 'n02342885', 'n02346627', 'n02356798', 'n02361337', 'n02363005', 'n02364673', 'n02389026', 'n02391049', 'n02395406', 'n02396427', 'n02397096', 'n02398521', 'n02403003', 'n02408429', 'n02410509', 'n02412080', 'n02415577', 'n02417914', 'n02422106', 'n02422699', 'n02423022', 'n02437312', 'n02437616', 'n02441942', 'n02442845', 'n02443114', 'n02443484', 'n02444819', 'n02445715', 'n02447366', 'n02454379', 'n02457408', 'n02480495', 'n02480855', 'n02481823', 'n02483362', 'n02483708', 'n02484975', 'n02486261', 'n02486410', 'n02487347', 'n02488291', 'n02488702', 'n02489166', 'n02490219', 'n02492035', 'n02492660', 'n02493509', 'n02493793', 'n02494079', 'n02497673', 'n02500267', 'n02504013', 'n02504458', 'n02509815', 
'n02510455', 'n02514041', 'n02526121', 'n02536864', 'n02606052', 'n02607072', 'n02640242', 'n02641379', 'n02643566', 'n02655020', 'n02666196', 'n02667093', 'n02669723', 'n02672831', 'n02676566', 'n02687172', 'n02690373', 'n02692877', 'n02699494', 'n02701002', 'n02704792', 'n02708093', 'n02727426', 'n02730930', 'n02747177', 'n02749479', 'n02769748', 'n02776631', 'n02777292', 'n02782093', 'n02783161', 'n02786058', 'n02787622', 'n02788148', 'n02790996', 'n02791124', 'n02791270', 'n02793495', 'n02794156', 'n02795169', 'n02797295', 'n02799071', 'n02802426', 'n02804414', 'n02804610', 'n02807133', 'n02808304', 'n02808440', 'n02814533', 'n02814860', 'n02815834', 'n02817516', 'n02823428', 'n02823750', 'n02825657', 'n02834397', 'n02835271', 'n02837789', 'n02840245', 'n02841315', 'n02843684', 'n02859443', 'n02860847', 'n02865351', 'n02869837', 'n02870880', 'n02871525', 'n02877765', 'n02879718', 'n02883205', 'n02892201', 'n02892767', 'n02894605', 'n02895154', 'n02906734', 'n02909870', 'n02910353', 'n02916936', 'n02917067', 'n02927161', 'n02930766', 'n02939185', 'n02948072', 'n02950826', 'n02951358', 'n02951585', 'n02963159', 'n02965783', 'n02966193', 'n02966687', 'n02971356', 'n02974003', 'n02977058', 'n02978881', 'n02979186', 'n02980441', 'n02981792', 'n02988304', 'n02992211', 'n02992529', 'n02999410', 'n03000134', 'n03000247', 'n03000684', 'n03014705', 'n03016953', 'n03017168', 'n03018349', 'n03026506', 'n03028079', 'n03032252', 'n03041632', 'n03042490', 'n03045698', 'n03047690', 'n03062245', 'n03063599', 'n03063689', 'n03065424', 'n03075370', 'n03085013', 'n03089624', 'n03095699', 'n03100240', 'n03109150', 'n03110669', 'n03124043', 'n03124170', 'n03125729', 'n03126707', 'n03127747', 'n03127925', 'n03131574', 'n03133878', 'n03134739', 'n03141823', 'n03146219', 'n03160309', 'n03179701', 'n03180011', 'n03187595', 'n03188531', 'n03196217', 'n03197337', 'n03201208', 'n03207743', 'n03207941', 'n03208938', 'n03216828', 'n03218198', 'n03220513', 'n03223299', 'n03240683', 
'n03249569', 'n03250847', 'n03255030', 'n03259280', 'n03271574', 'n03272010', 'n03272562', 'n03290653', 'n03291819', 'n03297495', 'n03314780', 'n03325584', 'n03337140', 'n03344393', 'n03345487', 'n03347037', 'n03355925', 'n03372029', 'n03376595', 'n03379051', 'n03384352', 'n03388043', 'n03388183', 'n03388549', 'n03393912', 'n03394916', 'n03400231', 'n03404251', 'n03417042', 'n03424325', 'n03425413', 'n03443371', 'n03444034', 'n03445777', 'n03445924', 'n03447447', 'n03447721', 'n03450230', 'n03452741', 'n03457902', 'n03459775', 'n03461385', 'n03467068', 'n03476684', 'n03476991', 'n03478589', 'n03481172', 'n03482405', 'n03483316', 'n03485407', 'n03485794', 'n03492542', 'n03494278', 'n03495258', 'n03496892', 'n03498962', 'n03527444', 'n03529860', 'n03530642', 'n03532672', 'n03534580', 'n03535780', 'n03538406', 'n03544143', 'n03584254', 'n03584829', 'n03590841', 'n03594734', 'n03594945', 'n03595614', 'n03598930', 'n03599486', 'n03602883', 'n03617480', 'n03623198', 'n03627232', 'n03630383', 'n03633091', 'n03637318', 'n03642806', 'n03649909', 'n03657121', 'n03658185', 'n03661043', 'n03662601', 'n03666591', 'n03670208', 'n03673027', 'n03676483', 'n03680355', 'n03690938', 'n03691459', 'n03692522', 'n03697007', 'n03706229', 'n03709823', 'n03710193', 'n03710637', 'n03710721', 'n03717622', 'n03720891', 'n03721384', 'n03724870', 'n03729826', 'n03733131', 'n03733281', 'n03733805', 'n03742115', 'n03743016', 'n03759954', 'n03761084', 'n03763968', 'n03764736', 'n03769881', 'n03770439', 'n03770679', 'n03773504', 'n03775071', 'n03775546', 'n03776460', 'n03777568', 'n03777754', 'n03781244', 'n03782006', 'n03785016', 'n03786901', 'n03787032', 'n03788195', 'n03788365', 'n03791053', 'n03792782', 'n03792972', 'n03793489', 'n03794056', 'n03796401', 'n03803284', 'n03804744', 'n03814639', 'n03814906', 'n03825788', 'n03832673', 'n03837869', 'n03838899', 'n03840681', 'n03841143', 'n03843555', 'n03854065', 'n03857828', 'n03866082', 'n03868242', 'n03868863', 'n03871628', 'n03873416', 
'n03874293', 'n03874599', 'n03876231', 'n03877472', 'n03877845', 'n03884397', 'n03887697', 'n03888257', 'n03888605', 'n03891251', 'n03891332', 'n03895866', 'n03899768', 'n03902125', 'n03903868', 'n03908618', 'n03908714', 'n03916031', 'n03920288', 'n03924679', 'n03929660', 'n03929855', 'n03930313', 'n03930630', 'n03933933', 'n03935335', 'n03937543', 'n03938244', 'n03942813', 'n03944341', 'n03947888', 'n03950228', 'n03954731', 'n03956157', 'n03958227', 'n03961711', 'n03967562', 'n03970156', 'n03976467', 'n03976657', 'n03977966', 'n03980874', 'n03982430', 'n03983396', 'n03991062', 'n03992509', 'n03995372', 'n03998194', 'n04004767', 'n04005630', 'n04008634', 'n04009552', 'n04019541', 'n04023962', 'n04026417', 'n04033901', 'n04033995', 'n04037443', 'n04039381', 'n04040759', 'n04041544', 'n04044716', 'n04049303', 'n04065272', 'n04067472', 'n04069434', 'n04070727', 'n04074963', 'n04081281', 'n04086273', 'n04090263', 'n04099969', 'n04111531', 'n04116512', 'n04118538', 'n04118776', 'n04120489', 'n04125021', 'n04127249', 'n04131690', 'n04133789', 'n04136333', 'n04141076', 'n04141327', 'n04141975', 'n04146614', 'n04147183', 'n04149813', 'n04152593', 'n04153751', 'n04154565', 'n04162706', 'n04179913', 'n04192698', 'n04200800', 'n04201297', 'n04204238', 'n04204347', 'n04208210', 'n04209133', 'n04209239', 'n04228054', 'n04229816', 'n04235860', 'n04238763', 'n04239074', 'n04243546', 'n04251144', 'n04252077', 'n04252225', 'n04254120', 'n04254680', 'n04254777', 'n04258138', 'n04259630', 'n04263257', 'n04264628', 'n04265275', 'n04266014', 'n04270147', 'n04273569', 'n04275548', 'n04277352', 'n04285008', 'n04286575', 'n04296562', 'n04310018', 'n04311004', 'n04311174', 'n04317175', 'n04325704', 'n04326547', 'n04328186', 'n04330267', 'n04332243', 'n04335435', 'n04336792', 'n04344873', 'n04346328', 'n04347754', 'n04350905', 'n04355338', 'n04355933', 'n04356056', 'n04357314', 'n04366367', 'n04367480', 'n04370456', 'n04371430', 'n04371774', 'n04372370', 'n04376876', 'n04380533', 
'n04389033', 'n04392985', 'n04398044', 'n04399382', 'n04404412', 'n04409515', 'n04417672', 'n04418357', 'n04423845', 'n04428191', 'n04429376', 'n04435653', 'n04442312', 'n04443257', 'n04447861', 'n04456115', 'n04458633', 'n04461696', 'n04462240', 'n04465501', 'n04467665', 'n04476259', 'n04479046', 'n04482393', 'n04483307', 'n04485082', 'n04486054', 'n04487081', 'n04487394', 'n04493381', 'n04501370', 'n04505470', 'n04507155', 'n04509417', 'n04515003', 'n04517823', 'n04522168', 'n04523525', 'n04525038', 'n04525305', 'n04532106', 'n04532670', 'n04536866', 'n04540053', 'n04542943', 'n04548280', 'n04548362', 'n04550184', 'n04552348', 'n04553703', 'n04554684', 'n04557648', 'n04560804', 'n04562935', 'n04579145', 'n04579432', 'n04584207', 'n04589890', 'n04590129', 'n04591157', 'n04591713', 'n04592741', 'n04596742', 'n04597913', 'n04599235', 'n04604644', 'n04606251', 'n04612504', 'n04613696', 'n06359193', 'n06596364', 'n06785654', 'n06794110', 'n06874185', 'n07248320', 'n07565083', 'n07579787', 'n07583066', 'n07584110', 'n07590611', 'n07613480', 'n07614500', 'n07615774', 'n07684084', 'n07693725', 'n07695742', 'n07697313', 'n07697537', 'n07711569', 'n07714571', 'n07714990', 'n07715103', 'n07716358', 'n07716906', 'n07717410', 'n07717556', 'n07718472', 'n07718747', 'n07720875', 'n07730033', 'n07734744', 'n07742313', 'n07745940', 'n07747607', 'n07749582', 'n07753113', 'n07753275', 'n07753592', 'n07754684', 'n07760859', 'n07768694', 'n07802026', 'n07831146', 'n07836838', 'n07860988', 'n07871810', 'n07873807', 'n07875152', 'n07880968', 'n07892512', 'n07920052', 'n07930864', 'n07932039', 'n09193705', 'n09229709', 'n09246464', 'n09256479', 'n09288635', 'n09332890', 'n09399592', 'n09421951', 'n09428293', 'n09468604', 'n09472597', 'n09835506', 'n10148035', 'n10565667', 'n11879895', 'n11939491', 'n12057211', 'n12144580', 'n12267677', 'n12620546', 'n12768682', 'n12985857', 'n12998815', 'n13037406', 'n13040303', 'n13044778', 'n13052670', 'n13054560', 'n13133613', 
'n15075141']\n\nIMAGENETS_919_FOLDER_NAMES = ['n01440764', 'n01443537', 'n01484850', 'n01491361', 'n01494475', 'n01496331', 'n01498041', 'n01514668', 'n01514859', 'n01518878', 'n01530575', 'n01531178', 'n01532829', 'n01534433', 'n01537544', 'n01558993', 'n01560419', 'n01580077', 'n01582220', 'n01592084', 'n01601694', 'n01608432', 'n01614925', 'n01616318', 'n01622779', 'n01629819', 'n01630670', 'n01631663', 'n01632458', 'n01632777', 'n01641577', 'n01644373', 'n01644900', 'n01664065', 'n01665541', 'n01667114', 'n01667778', 'n01669191', 'n01675722', 'n01677366', 'n01682714', 'n01685808', 'n01687978', 'n01688243', 'n01689811', 'n01692333', 'n01693334', 'n01694178', 'n01695060', 'n01697457', 'n01698640', 'n01704323', 'n01728572', 'n01728920', 'n01729322', 'n01729977', 'n01734418', 'n01735189', 'n01737021', 'n01739381', 'n01740131', 'n01742172', 'n01744401', 'n01748264', 'n01749939', 'n01751748', 'n01753488', 'n01755581', 'n01756291', 'n01768244', 'n01770081', 'n01770393', 'n01773157', 'n01773549', 'n01773797', 'n01774384', 'n01774750', 'n01775062', 'n01776313', 'n01784675', 'n01795545', 'n01796340', 'n01797886', 'n01798484', 'n01806143', 'n01806567', 'n01807496', 'n01817953', 'n01818515', 'n01819313', 'n01820546', 'n01824575', 'n01828970', 'n01829413', 'n01833805', 'n01843065', 'n01843383', 'n01847000', 'n01855032', 'n01855672', 'n01860187', 'n01871265', 'n01872401', 'n01873310', 'n01877812', 'n01882714', 'n01883070', 'n01910747', 'n01914609', 'n01917289', 'n01924916', 'n01930112', 'n01943899', 'n01944390', 'n01945685', 'n01950731', 'n01955084', 'n01968897', 'n01978287', 'n01978455', 'n01980166', 'n01981276', 'n01983481', 'n01984695', 'n01985128', 'n01986214', 'n01990800', 'n02002556', 'n02002724', 'n02006656', 'n02007558', 'n02009229', 'n02009912', 'n02011460', 'n02012849', 'n02013706', 'n02017213', 'n02018207', 'n02018795', 'n02025239', 'n02027492', 'n02028035', 'n02033041', 'n02037110', 'n02051845', 'n02056570', 'n02058221', 'n02066245', 'n02071294', 'n02074367', 
'n02077923', 'n02085620', 'n02085782', 'n02085936', 'n02086079', 'n02086240', 'n02086646', 'n02086910', 'n02087046', 'n02087394', 'n02088094', 'n02088238', 'n02088364', 'n02088466', 'n02088632', 'n02089078', 'n02089867', 'n02089973', 'n02090379', 'n02090622', 'n02090721', 'n02091032', 'n02091134', 'n02091244', 'n02091467', 'n02091635', 'n02091831', 'n02092002', 'n02092339', 'n02093256', 'n02093428', 'n02093647', 'n02093754', 'n02093859', 'n02093991', 'n02094114', 'n02094258', 'n02094433', 'n02095314', 'n02095570', 'n02095889', 'n02096051', 'n02096177', 'n02096294', 'n02096437', 'n02096585', 'n02097047', 'n02097130', 'n02097209', 'n02097298', 'n02097474', 'n02097658', 'n02098105', 'n02098286', 'n02098413', 'n02099267', 'n02099429', 'n02099601', 'n02099712', 'n02099849', 'n02100236', 'n02100583', 'n02100735', 'n02100877', 'n02101006', 'n02101388', 'n02101556', 'n02102040', 'n02102177', 'n02102318', 'n02102480', 'n02102973', 'n02104029', 'n02104365', 'n02105056', 'n02105162', 'n02105251', 'n02105412', 'n02105505', 'n02105641', 'n02105855', 'n02106030', 'n02106166', 'n02106382', 'n02106550', 'n02106662', 'n02107142', 'n02107312', 'n02107574', 'n02107683', 'n02107908', 'n02108000', 'n02108089', 'n02108422', 'n02108551', 'n02108915', 'n02109047', 'n02109525', 'n02109961', 'n02110063', 'n02110185', 'n02110341', 'n02110627', 'n02110806', 'n02110958', 'n02111129', 'n02111277', 'n02111500', 'n02111889', 'n02112018', 'n02112137', 'n02112350', 'n02112706', 'n02113023', 'n02113186', 'n02113624', 'n02113712', 'n02113799', 'n02113978', 'n02114367', 'n02114548', 'n02114712', 'n02114855', 'n02115641', 'n02115913', 'n02116738', 'n02117135', 'n02119022', 'n02119789', 'n02120079', 'n02120505', 'n02123045', 'n02123159', 'n02123394', 'n02123597', 'n02124075', 'n02125311', 'n02127052', 'n02128385', 'n02128757', 'n02128925', 'n02129165', 'n02129604', 'n02130308', 'n02132136', 'n02133161', 'n02134084', 'n02134418', 'n02137549', 'n02138441', 'n02165105', 'n02165456', 'n02167151', 
'n02168699', 'n02169497', 'n02172182', 'n02174001', 'n02177972', 'n02190166', 'n02206856', 'n02219486', 'n02226429', 'n02229544', 'n02231487', 'n02233338', 'n02236044', 'n02256656', 'n02259212', 'n02264363', 'n02268443', 'n02268853', 'n02276258', 'n02277742', 'n02279972', 'n02280649', 'n02281406', 'n02281787', 'n02317335', 'n02319095', 'n02321529', 'n02325366', 'n02326432', 'n02328150', 'n02342885', 'n02346627', 'n02356798', 'n02361337', 'n02363005', 'n02364673', 'n02389026', 'n02391049', 'n02395406', 'n02396427', 'n02397096', 'n02398521', 'n02403003', 'n02408429', 'n02410509', 'n02412080', 'n02415577', 'n02417914', 'n02422106', 'n02422699', 'n02423022', 'n02437312', 'n02437616', 'n02441942', 'n02442845', 'n02443114', 'n02443484', 'n02444819', 'n02445715', 'n02447366', 'n02454379', 'n02457408', 'n02480495', 'n02480855', 'n02481823', 'n02483362', 'n02483708', 'n02484975', 'n02486261', 'n02486410', 'n02487347', 'n02488291', 'n02488702', 'n02489166', 'n02490219', 'n02492035', 'n02492660', 'n02493509', 'n02493793', 'n02494079', 'n02497673', 'n02500267', 'n02504013', 'n02504458', 'n02509815', 'n02510455', 'n02514041', 'n02526121', 'n02536864', 'n02606052', 'n02607072', 'n02640242', 'n02641379', 'n02643566', 'n02655020', 'n02666196', 'n02667093', 'n02669723', 'n02672831', 'n02676566', 'n02687172', 'n02690373', 'n02692877', 'n02701002', 'n02704792', 'n02708093', 'n02727426', 'n02730930', 'n02747177', 'n02749479', 'n02769748', 'n02782093', 'n02783161', 'n02786058', 'n02787622', 'n02790996', 'n02791124', 'n02793495', 'n02794156', 'n02795169', 'n02797295', 'n02799071', 'n02802426', 'n02804414', 'n02804610', 'n02808304', 'n02808440', 'n02814533', 'n02814860', 'n02815834', 'n02817516', 'n02823428', 'n02823750', 'n02834397', 'n02835271', 'n02840245', 'n02841315', 'n02843684', 'n02859443', 'n02860847', 'n02865351', 'n02869837', 'n02870880', 'n02879718', 'n02883205', 'n02892201', 'n02892767', 'n02895154', 'n02906734', 'n02909870', 'n02910353', 'n02916936', 'n02917067', 
'n02930766', 'n02939185', 'n02948072', 'n02950826', 'n02951358', 'n02951585', 'n02963159', 'n02965783', 'n02966687', 'n02971356', 'n02978881', 'n02979186', 'n02980441', 'n02981792', 'n02992211', 'n02992529', 'n02999410', 'n03000134', 'n03000684', 'n03014705', 'n03016953', 'n03017168', 'n03018349', 'n03026506', 'n03028079', 'n03041632', 'n03045698', 'n03047690', 'n03062245', 'n03063599', 'n03063689', 'n03075370', 'n03095699', 'n03100240', 'n03109150', 'n03110669', 'n03124043', 'n03124170', 'n03125729', 'n03126707', 'n03127747', 'n03127925', 'n03131574', 'n03133878', 'n03134739', 'n03141823', 'n03146219', 'n03179701', 'n03187595', 'n03188531', 'n03196217', 'n03197337', 'n03201208', 'n03207743', 'n03207941', 'n03223299', 'n03240683', 'n03249569', 'n03250847', 'n03255030', 'n03259280', 'n03271574', 'n03272010', 'n03272562', 'n03291819', 'n03297495', 'n03314780', 'n03325584', 'n03337140', 'n03344393', 'n03345487', 'n03347037', 'n03355925', 'n03372029', 'n03376595', 'n03379051', 'n03384352', 'n03388183', 'n03388549', 'n03393912', 'n03394916', 'n03400231', 'n03404251', 'n03417042', 'n03424325', 'n03425413', 'n03443371', 'n03444034', 'n03445777', 'n03445924', 'n03447447', 'n03447721', 'n03450230', 'n03452741', 'n03467068', 'n03476684', 'n03476991', 'n03478589', 'n03481172', 'n03482405', 'n03483316', 'n03485407', 'n03485794', 'n03492542', 'n03494278', 'n03495258', 'n03496892', 'n03498962', 'n03527444', 'n03530642', 'n03532672', 'n03534580', 'n03535780', 'n03538406', 'n03544143', 'n03584254', 'n03584829', 'n03590841', 'n03594734', 'n03594945', 'n03595614', 'n03598930', 'n03599486', 'n03602883', 'n03617480', 'n03623198', 'n03627232', 'n03630383', 'n03633091', 'n03649909', 'n03657121', 'n03658185', 'n03662601', 'n03666591', 'n03670208', 'n03673027', 'n03676483', 'n03680355', 'n03690938', 'n03691459', 'n03692522', 'n03706229', 'n03709823', 'n03710193', 'n03717622', 'n03720891', 'n03721384', 'n03724870', 'n03729826', 'n03733131', 'n03733805', 'n03742115', 'n03743016', 
'n03759954', 'n03761084', 'n03763968', 'n03764736', 'n03769881', 'n03770439', 'n03770679', 'n03773504', 'n03775071', 'n03775546', 'n03776460', 'n03777568', 'n03777754', 'n03781244', 'n03782006', 'n03785016', 'n03786901', 'n03787032', 'n03788195', 'n03788365', 'n03791053', 'n03792782', 'n03792972', 'n03793489', 'n03794056', 'n03796401', 'n03803284', 'n03804744', 'n03814639', 'n03814906', 'n03825788', 'n03832673', 'n03837869', 'n03838899', 'n03840681', 'n03841143', 'n03843555', 'n03854065', 'n03857828', 'n03866082', 'n03868242', 'n03868863', 'n03873416', 'n03874293', 'n03874599', 'n03876231', 'n03877472', 'n03877845', 'n03884397', 'n03888257', 'n03888605', 'n03891251', 'n03891332', 'n03895866', 'n03902125', 'n03903868', 'n03908618', 'n03908714', 'n03916031', 'n03920288', 'n03924679', 'n03929660', 'n03929855', 'n03930313', 'n03930630', 'n03933933', 'n03935335', 'n03937543', 'n03938244', 'n03942813', 'n03944341', 'n03947888', 'n03950228', 'n03954731', 'n03956157', 'n03958227', 'n03961711', 'n03970156', 'n03976467', 'n03976657', 'n03977966', 'n03980874', 'n03982430', 'n03983396', 'n03991062', 'n03992509', 'n03995372', 'n03998194', 'n04004767', 'n04005630', 'n04009552', 'n04019541', 'n04023962', 'n04026417', 'n04033901', 'n04033995', 'n04037443', 'n04039381', 'n04040759', 'n04041544', 'n04044716', 'n04049303', 'n04065272', 'n04067472', 'n04069434', 'n04070727', 'n04074963', 'n04086273', 'n04090263', 'n04099969', 'n04116512', 'n04118538', 'n04118776', 'n04120489', 'n04125021', 'n04127249', 'n04131690', 'n04133789', 'n04136333', 'n04141076', 'n04141327', 'n04141975', 'n04146614', 'n04147183', 'n04149813', 'n04153751', 'n04154565', 'n04162706', 'n04179913', 'n04192698', 'n04201297', 'n04204238', 'n04204347', 'n04208210', 'n04209133', 'n04209239', 'n04228054', 'n04229816', 'n04235860', 'n04238763', 'n04243546', 'n04252077', 'n04252225', 'n04254120', 'n04254680', 'n04254777', 'n04258138', 'n04259630', 'n04263257', 'n04265275', 'n04266014', 'n04270147', 'n04273569', 
'n04275548', 'n04277352', 'n04285008', 'n04286575', 'n04310018', 'n04311004', 'n04311174', 'n04317175', 'n04325704', 'n04326547', 'n04328186', 'n04330267', 'n04332243', 'n04335435', 'n04336792', 'n04344873', 'n04346328', 'n04347754', 'n04350905', 'n04355338', 'n04355933', 'n04366367', 'n04367480', 'n04370456', 'n04371430', 'n04371774', 'n04376876', 'n04380533', 'n04389033', 'n04398044', 'n04399382', 'n04404412', 'n04409515', 'n04417672', 'n04418357', 'n04423845', 'n04428191', 'n04429376', 'n04435653', 'n04442312', 'n04447861', 'n04456115', 'n04458633', 'n04461696', 'n04465501', 'n04467665', 'n04476259', 'n04479046', 'n04482393', 'n04483307', 'n04485082', 'n04486054', 'n04487081', 'n04487394', 'n04505470', 'n04507155', 'n04509417', 'n04515003', 'n04517823', 'n04522168', 'n04525038', 'n04525305', 'n04532106', 'n04532670', 'n04536866', 'n04540053', 'n04542943', 'n04548280', 'n04548362', 'n04550184', 'n04552348', 'n04553703', 'n04554684', 'n04557648', 'n04560804', 'n04562935', 'n04579145', 'n04579432', 'n04584207', 'n04589890', 'n04590129', 'n04591157', 'n04591713', 'n04596742', 'n04597913', 'n04604644', 'n04606251', 'n04612504', 'n04613696', 'n06596364', 'n06794110', 'n06874185', 'n07565083', 'n07579787', 'n07583066', 'n07584110', 'n07613480', 'n07614500', 'n07615774', 'n07684084', 'n07693725', 'n07695742', 'n07697313', 'n07697537', 'n07711569', 'n07714571', 'n07714990', 'n07715103', 'n07716358', 'n07716906', 'n07717410', 'n07717556', 'n07718472', 'n07718747', 'n07720875', 'n07730033', 'n07734744', 'n07742313', 'n07745940', 'n07747607', 'n07749582', 'n07753113', 'n07753275', 'n07753592', 'n07754684', 'n07760859', 'n07768694', 'n07802026', 'n07831146', 'n07860988', 'n07871810', 'n07873807', 'n07875152', 'n07880968', 'n07930864', 'n07932039', 'n09229709', 'n09246464', 'n09256479', 'n09835506', 'n10565667', 'n11879895', 'n11939491', 'n12057211', 'n12144580', 'n12267677', 'n12620546', 'n12768682', 'n12985857', 'n12998815', 'n13037406', 'n13040303', 'n13044778', 
'n13052670', 'n13054560', 'n13133613', 'n15075141']\n\nIMAGENETS_FOLDER_NAMES = IMAGENETS_919_FOLDER_NAMES\n\nIMAGENETS_300_FOLDER_NAMES = ['n01440764', 'n01443537', 'n01491361', 'n01494475', 'n01496331', 'n01518878', 'n01531178', 'n01532829', 'n01537544', 'n01608432', 'n01630670', 'n01632777', 'n01644373', 'n01644900', 'n01667114', 'n01675722', 'n01682714', 'n01685808', 'n01694178', 'n01695060', 'n01698640', 'n01704323', 'n01728572', 'n01728920', 'n01734418', 'n01744401', 'n01753488', 'n01770081', 'n01770393', 'n01773797', 'n01776313', 'n01817953', 'n01820546', 'n01855032', 'n01877812', 'n01882714', 'n01910747', 'n01914609', 'n01943899', 'n01980166', 'n01983481', 'n01984695', 'n01990800', 'n02011460', 'n02012849', 'n02013706', 'n02018795', 'n02058221', 'n02087046', 'n02088094', 'n02088632', 'n02090622', 'n02090721', 'n02091134', 'n02091244', 'n02093256', 'n02093754', 'n02094433', 'n02095570', 'n02097130', 'n02097209', 'n02097298', 'n02098413', 'n02100735', 'n02101556', 'n02102040', 'n02102177', 'n02104029', 'n02105412', 'n02107142', 'n02107312', 'n02110063', 'n02110958', 'n02111129', 'n02111500', 'n02111889', 'n02112706', 'n02113186', 'n02114855', 'n02119022', 'n02119789', 'n02120505', 'n02123394', 'n02123597', 'n02125311', 'n02127052', 'n02129604', 'n02133161', 'n02134418', 'n02165456', 'n02169497', 'n02177972', 'n02206856', 'n02256656', 'n02259212', 'n02268853', 'n02277742', 'n02280649', 'n02281406', 'n02321529', 'n02325366', 'n02326432', 'n02342885', 'n02396427', 'n02398521', 'n02415577', 'n02417914', 'n02447366', 'n02457408', 'n02480495', 'n02483362', 'n02488702', 'n02493793', 'n02494079', 'n02497673', 'n02504013', 'n02504458', 'n02510455', 'n02514041', 'n02526121', 'n02536864', 'n02669723', 'n02672831', 'n02690373', 'n02701002', 'n02708093', 'n02747177', 'n02769748', 'n02782093', 'n02783161', 'n02790996', 'n02793495', 'n02804610', 'n02808304', 'n02814533', 'n02835271', 'n02841315', 'n02859443', 'n02869837', 'n02870880', 'n02879718', 'n02892201', 'n02895154', 
'n02917067', 'n02950826', 'n02951585', 'n02966687', 'n02978881', 'n02992529', 'n03000684', 'n03014705', 'n03018349', 'n03047690', 'n03075370', 'n03095699', 'n03109150', 'n03127925', 'n03133878', 'n03197337', 'n03201208', 'n03207941', 'n03223299', 'n03259280', 'n03271574', 'n03272562', 'n03291819', 'n03337140', 'n03376595', 'n03379051', 'n03393912', 'n03394916', 'n03404251', 'n03417042', 'n03443371', 'n03445777', 'n03452741', 'n03478589', 'n03482405', 'n03492542', 'n03494278', 'n03496892', 'n03532672', 'n03535780', 'n03538406', 'n03584829', 'n03590841', 'n03630383', 'n03633091', 'n03658185', 'n03673027', 'n03710193', 'n03743016', 'n03763968', 'n03764736', 'n03775546', 'n03781244', 'n03786901', 'n03788365', 'n03791053', 'n03792782', 'n03792972', 'n03794056', 'n03814906', 'n03825788', 'n03840681', 'n03874599', 'n03877845', 'n03888605', 'n03891251', 'n03903868', 'n03908714', 'n03929855', 'n03938244', 'n03956157', 'n03958227', 'n03976467', 'n03976657', 'n03991062', 'n04026417', 'n04033995', 'n04040759', 'n04041544', 'n04044716', 'n04049303', 'n04069434', 'n04070727', 'n04090263', 'n04099969', 'n04116512', 'n04118776', 'n04120489', 'n04179913', 'n04192698', 'n04201297', 'n04228054', 'n04229816', 'n04243546', 'n04254120', 'n04254680', 'n04254777', 'n04263257', 'n04265275', 'n04275548', 'n04277352', 'n04285008', 'n04311004', 'n04317175', 'n04335435', 'n04347754', 'n04371430', 'n04376876', 'n04380533', 'n04389033', 'n04399382', 'n04404412', 'n04429376', 'n04435653', 'n04447861', 'n04479046', 'n04483307', 'n04505470', 'n04507155', 'n04522168', 'n04540053', 'n04550184', 'n04552348', 'n04554684', 'n04557648', 'n04562935', 'n04579145', 'n04584207', 'n04591713', 'n04596742', 'n04606251', 'n04612504', 'n04613696', 'n06794110', 'n06874185', 'n07584110', 'n07614500', 'n07693725', 'n07697313', 'n07697537', 'n07711569', 'n07716906', 'n07720875', 'n07730033', 'n07742313', 'n07745940', 'n07749582', 'n07831146', 'n07880968', 'n07930864', 'n09256479', 'n12057211', 'n12768682', 
'n12998815', 'n13037406', 'n13044778', 'n13054560']\n\nIMAGENETS_50_FOLDER_NAMES = ['n01443537', 'n01491361', 'n01531178', 'n01644373', 'n02104029', 'n02119022', 'n02123597', 'n02133161', 'n02165456', 'n02281406', 'n02325366', 'n02342885', 'n02396427', 'n02483362', 'n02504458', 'n02510455', 'n02690373', 'n02747177', 'n02783161', 'n02814533', 'n02859443', 'n02917067', 'n02992529', 'n03014705', 'n03047690', 'n03095699', 'n03197337', 'n03201208', 'n03445777', 'n03452741', 'n03584829', 'n03630383', 'n03775546', 'n03791053', 'n03874599', 'n03891251', 'n04026417', 'n04335435', 'n04380533', 'n04404412', 'n04447861', 'n04507155', 'n04522168', 'n04557648', 'n04562935', 'n04612504', 'n06794110', 'n07749582', 'n07831146', 'n12998815']\n\n\nIMAGENET_DEFAULT_TEMPLATES = [\n    '{}.',\n    'a bad photo of a {}.',\n    'a photo of many {}.',\n    'a sculpture of a {}.',\n    'a photo of the hard to see {}.',\n    'a low resolution photo of the {}.',\n    'a rendering of a {}.',\n    'graffiti of a {}.',\n    'a bad photo of the {}.',\n    'a cropped photo of the {}.',\n    'a tattoo of a {}.',\n    'the embroidered {}.',\n    'a photo of a hard to see {}.',\n    'a bright photo of a {}.',\n    'a photo of a clean {}.',\n    'a photo of a dirty {}.',\n    'a dark photo of the {}.',\n    'a drawing of a {}.',\n    'a photo of my {}.',\n    'the plastic {}.',\n    'a photo of the cool {}.',\n    'a close-up photo of a {}.',\n    'a black and white photo of the {}.',\n    'a painting of the {}.',\n    'a painting of a {}.',\n    'a pixelated photo of the {}.',\n    'a sculpture of the {}.',\n    'a bright photo of the {}.',\n    'a cropped photo of a {}.',\n    'a plastic {}.',\n    'a photo of the dirty {}.',\n    'a jpeg corrupted photo of a {}.',\n    'a blurry photo of the {}.',\n    'a photo of the {}.',\n    'a good photo of the {}.',\n    'a rendering of the {}.',\n    'a {} in a video game.',\n    'a photo of one {}.',\n    'a doodle of a {}.',\n    'a close-up photo of the 
{}.',\n    'a photo of a {}.',\n    'the origami {}.',\n    'the {} in a video game.',\n    'a sketch of a {}.',\n    'a doodle of the {}.',\n    'a origami {}.',\n    'a low resolution photo of a {}.',\n    'the toy {}.',\n    'a rendition of the {}.',\n    'a photo of the clean {}.',\n    'a photo of a large {}.',\n    'a rendition of a {}.',\n    'a photo of a nice {}.',\n    'a photo of a weird {}.',\n    'a blurry photo of a {}.',\n    'a cartoon {}.',\n    'art of a {}.',\n    'a sketch of the {}.',\n    'a embroidered {}.',\n    'a pixelated photo of a {}.',\n    'itap of the {}.',\n    'a jpeg corrupted photo of the {}.',\n    'a good photo of a {}.',\n    'a plushie {}.',\n    'a photo of the nice {}.',\n    'a photo of the small {}.',\n    'a photo of the weird {}.',\n    'the cartoon {}.',\n    'art of the {}.',\n    'a drawing of the {}.',\n    'a photo of the large {}.',\n    'a black and white photo of a {}.',\n    'the plushie {}.',\n    'a dark photo of a {}.',\n    'itap of a {}.',\n    'graffiti of the {}.',\n    'a toy {}.',\n    'itap of my {}.',\n    'a photo of a cool {}.',\n    'a photo of a small {}.',\n    'a tattoo of the {}.',\n]\n\nIMAGENET_SIMPLE_TEMPLATES = [\n    'a photo of {}.',\n]\nCOCO_INSTANCE_CLASSES = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 
'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush']\nCOCO_PANOPTIC_CLASSES = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush', 'banner', 'blanket', 'bridge', 'cardboard', 'counter', 'curtain', 'door-stuff', 'floor-wood', 'flower', 'fruit', 'gravel', 'house', 'light', 'mirror-stuff', 'net', 'pillow', 'platform', 'playingfield', 'railroad', 'river', 'road', 'roof', 'sand', 'sea', 'shelf', 'snow', 'stairs', 'tent', 'towel', 'wall-brick', 'wall-stone', 'wall-tile', 'wall-wood', 'water-other', 'window-blind', 'window-other', 'tree-merged', 'fence-merged', 'ceiling-merged', 'sky-other-merged', 'cabinet-merged', 'table-merged', 'floor-other-merged', 'pavement-merged', 'mountain-merged', 'grass-merged', 'dirt-merged', 'paper-merged', 'food-other-merged', 'building-other-merged', 'rock-merged', 'wall-other-merged', 'rug-merged']\n\nCOCO_IMAGENET_INDEX_PAIR = [[0, 591], [1, 444], [2, 479], [3, 671], [4, 405], [5, 779], [6, 466], [7, 717], [8, 472], [9, 920], [10, 686], [11, 920], [12, 704], [13, 708], [14, 134], [15, 282], [16, 215], [17, 291], [18, 349], [19, 341], [20, 366], [21, 294], [22, 340], [23, 286], 
[24, 414], [25, 879], [26, 636], [27, 906], [28, 519], [29, 162], [30, 795], [31, 671], [32, 852], [33, 405], [34, 805], [35, 560], [36, 671], [37, 671], [38, 752], [39, 907], [40, 907], [41, 787], [42, 792], [43, 792], [44, 792], [45, 538], [46, 954], [47, 948], [48, 697], [49, 950], [50, 937], [51, 936], [52, 934], [53, 963], [54, 931], [55, 415], [56, 423], [57, 831], [58, 738], [59, 520], [60, 532], [61, 896], [62, 851], [63, 620], [64, 674], [65, 761], [66, 508], [67, 487], [68, 447], [69, 827], [70, 859], [71, 896], [72, 760], [73, 454], [74, 892], [75, 883], [76, 591], [77, 850], [78, 584], [79, 696], [80, 879], [81, 672], [82, 839], [83, 519], [84, 896], [85, 854], [86, 799], [87, 858], [88, 947], [89, 953], [90, 792], [91, 425], [92, 862], [93, 475], [94, 916], [95, 721], [96, 538], [97, 703], [98, 705], [99, 979], [100, 888], [101, 538], [102, 774], [103, 390], [104, 519], [105, 803], [106, 708], [107, 672], [108, 879], [109, 825], [110, 825], [111, 858], [112, 825], [113, 898], [114, 904], [115, 904], [116, 947], [117, 862], [118, 858], [119, 862], [120, 648], [121, 736], [122, 904], [123, 708], [124, 970], [125, 936], [126, 792], [127, 549], [128, 712], [129, 647], [130, 972], [131, 904], [132, 824]]\n\nPASCAL_CLASSES = [\n    \"aeroplane\", \"bicycle\", \"bird\", \"boat\", \"bottle\", \"bus\", \"car\", \"cat\",\n    \"chair\", \"cow\", \"diningtable\", \"dog\", \"horse\", \"motorbike\", \"person\",\n    \"pottedplant\", \"sheep\", \"sofa\", \"train\", \"tvmonitor\"\n]\n\n# PASCAL_CLASSES = [\n#     \"airplane\", \"bicycle\", \"bird\", \"boat\", \"bottle\", \"bus\", \"car\", \"cat\",\n#     \"chair\", \"cow\", \"dining table\", \"dog\", \"horse\", \"motorcycle\", \"person\",\n#     \"potted plant\", \"sheep\", \"couch\", \"train\", \"tv\"\n# ]\n\nPASCAL_LABELS = [\n                [0, 0, 0],\n                [128, 0, 0],\n                [0, 128, 0],\n                [128, 128, 0],\n                [0, 0, 128],\n                [128, 0, 128],\n          
      [0, 128, 128],\n                [128, 128, 128],\n                [64, 0, 0],\n                [192, 0, 0],\n                [64, 128, 0],\n                [192, 128, 0],\n                [64, 0, 128],\n                [192, 0, 128],\n                [64, 128, 128],\n                [192, 128, 128],\n                [0, 64, 0],\n                [128, 64, 0],\n                [0, 192, 0],\n                [128, 192, 0],\n                [0, 64, 128],\n            ]\n\n# ADE_PANOPTIC_CLASSES = ['wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed', 'window ', 'grass', 'cabinet', 'sidewalk', 'person', 'ground', 'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car', 'water', 'picture', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug', 'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'closet', 'light', 'tub', 'rail', 'cushion', 'pedestal', 'box', 'column', 'signboard', 'dresser', 'counter', 'sand', 'sink', 'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path', 'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door', 'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table', 'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove', 'palm tree', 'kitchen island', 'computer', 'swivel chair', 'boat', 'pub', 'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower', 'ceiling light', 'awning', 'street light', 'booth', 'tv', 'airplane', 'dirt road', 'clothes', 'pole', 'land', 'bannister', 'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van', 'ship', 'fountain', 'transporter', 'canopy', 'washer', 'plaything', 'pool', 'stool', 'cylinder', 'basket', 'falls', 'tent', 'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'stair', 'storage tank', 'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake', 'dishwasher', 'screen', 'blanket', 'sculpture', 'exhaust hood', 'sconce', 'vase', 'traffic light', 'tray', 'trash can', 'fan', 'pier', 'crt screen', 'plate', 'tvmonitor', 
'bulletin board', 'shower', 'heater', 'drinking glass', 'clock', 'flag']\n\n# ADE_PANOPTIC_CLASSES = ['wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed', 'window ', 'grass', 'cabinet', 'sidewalk', 'person', 'earth', 'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car', 'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug', 'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'closet', 'lamp', 'tub', 'rail', 'cushion', 'base', 'box', 'column', 'signboard', 'chest of drawers', 'counter', 'sand', 'sink', 'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path', 'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door', 'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table', 'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove', 'palm tree', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar', 'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower', 'chandelier', 'awning', 'street lamp', 'booth', 'tv', 'airplane', 'dirt road', 'clothes', 'pole', 'land', 'bannister', 'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van', 'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything', 'pool', 'stool', 'barrel', 'basket', 'falls', 'tent', 'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank', 'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake', 'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce', 'vase', 'traffic light', 'tray', 'trash can', 'fan', 'pier', 'crt screen', 'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass', 'clock', 'flag']\n\n\nADE_PANOPTIC_CLASSES = ['wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed', 'window', 'grass', 'cabinet', 'sidewalk', 'person', 'earth', 'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car', 'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug', 'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe', 
'lamp', 'tub', 'rail', 'cushion', 'base', 'box', 'column', 'signboard', 'chest of drawers', 'counter', 'sand', 'sink', 'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path', 'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door', 'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table', 'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove', 'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar', 'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower', 'chandelier', 'awning', 'street lamp', 'booth', 'tv', 'plane', 'dirt track', 'clothes', 'pole', 'land', 'bannister', 'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van', 'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything', 'pool', 'stool', 'barrel', 'basket', 'falls', 'tent', 'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank', 'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake', 'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce', 'vase', 'traffic light', 'tray', 'trash can', 'fan', 'pier', 'crt screen', 'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass', 'clock', 'flag']\n\n\nCOCO_INTER_ADE = ['wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed', 'window', 'grass', 'cabinet', 'pavement', 'person', 'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car', 'water', 'shelf', 'house', 'sea', 'mirror', 'rug', 'fence', 'rock', 'stone', 'sign', 'counter', 'sand', 'sink', 'refrigerator', 'stairs', 'table', 'pillow', 'door', 'river', 'bridge', 'blind', 'table', 'toilet', 'flower', 'book', 'bench', 'tree', 'chair', 'boat', 'bus', 'towel', 'light', 'truck', 'tv', 'dirt', 'bottle', 'counter', 'tent', 'oven', 'ball', 'food', 'microwave', 'bicycle', 'blanket', 'vase', 'traffic', 'light', 'glass', 'clock']\n\nPASCAL_CONTEXT_459 = ['accordion', 'aeroplane', 'air conditioner', 'antenna', 'artillery', 'ashtray', 'atrium', 'baby carriage', 'bag', 'ball', 
'balloon', 'bamboo weaving', 'barrel', 'baseball bat', 'basket', 'basketball backboard', 'bathtub', 'bed', 'bedclothes', 'beer', 'bell', 'bench', 'bicycle', 'binoculars', 'bird', 'bird cage', 'bird feeder', 'bird nest', 'blackboard', 'board', 'boat', 'bone', 'book', 'bottle', 'bottle opener', 'bowl', 'box', 'bracelet', 'brick', 'bridge', 'broom', 'brush', 'bucket', 'building', 'bus', 'cabinet', 'cabinet door', 'cage', 'cake', 'calculator', 'calendar', 'camel', 'camera', 'camera lens', 'can', 'candle', 'candle holder', 'cap', 'car', 'card', 'cart', 'case', 'casette recorder', 'cash register', 'cat', 'cd', 'cd player', 'ceiling', 'cell phone', 'cello', 'chain', 'chair', 'chessboard', 'chicken', 'chopstick', 'clip', 'clippers', 'clock', 'closet', 'cloth', 'clothes tree', 'coffee', 'coffee machine', 'comb', 'computer', 'concrete', 'cone', 'container', 'control booth', 'controller', 'cooker', 'copying machine', 'coral', 'cork', 'corkscrew', 'counter', 'court', 'cow', 'crabstick', 'crane', 'crate', 'cross', 'crutch', 'cup', 'curtain', 'cushion', 'cutting board', 'dais', 'disc', 'disc case', 'dishwasher', 'dock', 'dog', 'dolphin', 'door', 'drainer', 'dray', 'drink dispenser', 'drinking machine', 'drop', 'drug', 'drum', 'drum kit', 'duck', 'dumbbell', 'earphone', 'earrings', 'egg', 'electric fan', 'electric iron', 'electric pot', 'electric saw', 'electronic keyboard', 'engine', 'envelope', 'equipment', 'escalator', 'exhibition booth', 'extinguisher', 'eyeglass', 'fan', 'faucet', 'fax machine', 'fence', 'ferris wheel', 'fire extinguisher', 'fire hydrant', 'fire place', 'fish', 'fish tank', 'fishbowl', 'fishing net', 'fishing pole', 'flag', 'flagstaff', 'flame', 'flashlight', 'floor', 'flower', 'fly', 'foam', 'food', 'footbridge', 'forceps', 'fork', 'forklift', 'fountain', 'fox', 'frame', 'fridge', 'frog', 'fruit', 'funnel', 'furnace', 'game controller', 'game machine', 'gas cylinder', 'gas hood', 'gas stove', 'gift box', 'glass', 'glass marble', 'globe', 'glove', 'goal', 
'grandstand', 'grass', 'gravestone', 'ground', 'guardrail', 'guitar', 'gun', 'hammer', 'hand cart', 'handle', 'handrail', 'hanger', 'hard disk drive', 'hat', 'hay', 'headphone', 'heater', 'helicopter', 'helmet', 'holder', 'hook', 'horse', 'horse-drawn carriage', 'hot-air balloon', 'hydrovalve', 'ice', 'inflator pump', 'ipod', 'iron', 'ironing board', 'jar', 'kart', 'kettle', 'key', 'keyboard', 'kitchen range', 'kite', 'knife', 'knife block', 'ladder', 'ladder truck', 'ladle', 'laptop', 'leaves', 'lid', 'life buoy', 'light', 'light bulb', 'lighter', 'line', 'lion', 'lobster', 'lock', 'machine', 'mailbox', 'mannequin', 'map', 'mask', 'mat', 'match book', 'mattress', 'menu', 'metal', 'meter box', 'microphone', 'microwave', 'mirror', 'missile', 'model', 'money', 'monkey', 'mop', 'motorbike', 'mountain', 'mouse', 'mouse pad', 'musical instrument', 'napkin', 'net', 'newspaper', 'oar', 'ornament', 'outlet', 'oven', 'oxygen bottle', 'pack', 'pan', 'paper', 'paper box', 'paper cutter', 'parachute', 'parasol', 'parterre', 'patio', 'pelage', 'pen', 'pen container', 'pencil', 'person', 'photo', 'piano', 'picture', 'pig', 'pillar', 'pillow', 'pipe', 'pitcher', 'plant', 'plastic', 'plate', 'platform', 'player', 'playground', 'pliers', 'plume', 'poker', 'poker chip', 'pole', 'pool table', 'postcard', 'poster', 'pot', 'pottedplant', 'printer', 'projector', 'pumpkin', 'rabbit', 'racket', 'radiator', 'radio', 'rail', 'rake', 'ramp', 'range hood', 'receiver', 'recorder', 'recreational machines', 'remote control', 'road', 'robot', 'rock', 'rocket', 'rocking horse', 'rope', 'rug', 'ruler', 'runway', 'saddle', 'sand', 'saw', 'scale', 'scanner', 'scissors', 'scoop', 'screen', 'screwdriver', 'sculpture', 'scythe', 'sewer', 'sewing machine', 'shed', 'sheep', 'shell', 'shelves', 'shoe', 'shopping cart', 'shovel', 'sidecar', 'sidewalk', 'sign', 'signal light', 'sink', 'skateboard', 'ski', 'sky', 'sled', 'slippers', 'smoke', 'snail', 'snake', 'snow', 'snowmobiles', 'sofa', 'spanner', 
'spatula', 'speaker', 'speed bump', 'spice container', 'spoon', 'sprayer', 'squirrel', 'stage', 'stair', 'stapler', 'stick', 'sticky note', 'stone', 'stool', 'stove', 'straw', 'stretcher', 'sun', 'sunglass', 'sunshade', 'surveillance camera', 'swan', 'sweeper', 'swim ring', 'swimming pool', 'swing', 'switch', 'table', 'tableware', 'tank', 'tap', 'tape', 'tarp', 'telephone', 'telephone booth', 'tent', 'tire', 'toaster', 'toilet', 'tong', 'tool', 'toothbrush', 'towel', 'toy', 'toy car', 'track', 'train', 'trampoline', 'trash bin', 'tray', 'tree', 'tricycle', 'tripod', 'trophy', 'truck', 'tube', 'turtle', 'tvmonitor', 'tweezers', 'typewriter', 'umbrella', 'unknown', 'vacuum cleaner', 'vending machine', 'video camera', 'video game console', 'video player', 'video tape', 'violin', 'wakeboard', 'wall', 'wallet', 'wardrobe', 'washing machine', 'watch', 'water', 'water dispenser', 'water pipe', 'water skate board', 'watermelon', 'whale', 'wharf', 'wheel', 'wheelchair', 'window', 'window blinds', 'wineglass', 'wire', 'wood', 'wool']\n\nPASCAL_CONTEXT_33 = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor', 'sky', 'grass', 'ground', 'road', 'building', 'tree', 'water', 'mountain', 'wall', 'floor', 'track', 'keyboard', 'ceiling']\n\nPASCAL_CONTEXT_59 = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'dining table', 'dog', 'horse', 'motorbike', 'person', 'potted plant', 'sheep', 'sofa', 'train', 'tv monitor', 'bag', 'bed', 'bench', 'book', 'building', 'cabinet', 'ceiling', 'cloth', 'computer', 'cup', 'door', 'fence', 'floor', 'flower', 'food', 'grass', 'ground', 'keyboard', 'light', 'mountain', 'mouse', 'curtain', 'platform', 'sign', 'plate', 'road', 'rock', 'shelves', 'side walk', 'sky', 'snow', 'bed clothes', 'track', 'tree', 'truck', 'wall', 'water', 'window', 'wood']\n\nSUN_RGBD_37 = ['wall', 
'floor', 'cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window', 'bookshelf', 'picture', 'counter', 'blinds', 'desk', 'shelves', 'curtain', 'dresser', 'pillow', 'mirror', 'floor mat', 'clothes', 'ceiling', 'books', 'refridgerator', 'television', 'paper', 'towel', 'shower curtain', 'box', 'whiteboard', 'person', 'night stand', 'toilet', 'sink', 'lamp', 'bathtub', 'bag']\n\nSCAN_37 = ['wall', 'floor', 'cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window', 'bookshelf', 'picture', 'counter', 'blinds', 'desk', 'shelves', 'curtain', 'dresser', 'pillow', 'mirror', 'floor mat', 'clothes', 'ceiling', 'books', 'refridgerator', 'television', 'paper', 'towel', 'shower curtain', 'box', 'whiteboard', 'person', 'night stand', 'toilet', 'sink', 'lamp', 'bathtub', 'bag']\n\nSCAN_40 = ['wall', 'floor', 'cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window', 'bookshelf', 'picture', 'counter', 'blinds', 'desk', 'shelves', 'curtain', 'dresser', 'pillow', 'mirror', 'floor mat', 'clothes', 'ceiling', 'books', 'refridgerator', 'television', 'paper', 'towel', 'shower curtain', 'box', 'whiteboard', 'person', 'night stand', 'toilet', 'sink', 'lamp', 'bathtub', 'bag', 'otherstructure', 'otherfurniture', 'otherprop']\n\nSCAN_20 = [\"wall\", \"floor\", \"cabinet\", \"bed\", \"chair\", \"sofa\", \"table\", \"door\", \"window\", \"bookshelf\", \"picture\", \"counter\", \"desk\", \"curtain\", \"refrigerator\", \"shower curtain\", \"toilet\", \"sink\", \"bathtub\", \"otherfurniture\"]\n\nCITYSCAPES = ['road', 'sidewalk', 'building', 'wall', 'fence', 'pole', 'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle', 'bicycle']\n\nCITYSCAPES_THING = [\"person\", \"rider\", \"car\", \"truck\", \"bus\", \"train\", \"motorcycle\", \"bicycle\"]\n\nBDD_SEM = [\"road\", \"sidewalk\", \"building\", \"wall\", \"fence\", \"pole\", \"traffic light\", \"traffic sign\", \"vegetation\", \"terrain\", \"sky\", \"person\", 
\"rider\", \"car\", \"truck\", \"bus\", \"train\", \"motorcycle\", \"bicycle\"]\n\nBDD_PANO = ['dynamic', 'ego vehicle', 'ground', 'static', 'parking', 'rail track', 'road', 'sidewalk', 'bridge', 'building', 'fence', 'garage', 'guard rail', 'tunnel', 'wall', 'banner', 'billboard', 'lane divider', 'parking sign', 'pole', 'polegroup', 'street light', 'traffic cone', 'traffic device', 'traffic light', 'traffic sign', 'traffic sign frame', 'terrain', 'vegetation', 'sky', 'person', 'rider', 'bicycle', 'bus', 'car', 'caravan', 'motorcycle', 'trailer', 'train', 'truck']\n\nOBJECT365 = ['Person', 'Sneakers', 'Chair', 'Other Shoes', 'Hat', 'Car', 'Lamp', 'Glasses', 'Bottle', 'Desk', 'Cup', 'Street Lights', 'Cabinet/shelf', 'Handbag/Satchel', 'Bracelet', 'Plate', 'Picture/Frame', 'Helmet', 'Book', 'Gloves', 'Storage box', 'Boat', 'Leather Shoes', 'Flower', 'Bench', 'Potted Plant', 'Bowl/Basin', 'Flag', 'Pillow', 'Boots', 'Vase', 'Microphone', 'Necklace', 'Ring', 'SUV', 'Wine Glass', 'Belt', 'Moniter/TV', 'Backpack', 'Umbrella', 'Traffic Light', 'Speaker', 'Watch', 'Tie', 'Trash bin Can', 'Slippers', 'Bicycle', 'Stool', 'Barrel/bucket', 'Van', 'Couch', 'Sandals', 'Bakset', 'Drum', 'Pen/Pencil', 'Bus', 'Wild Bird', 'High Heels', 'Motorcycle', 'Guitar', 'Carpet', 'Cell Phone', 'Bread', 'Camera', 'Canned', 'Truck', 'Traffic cone', 'Cymbal', 'Lifesaver', 'Towel', 'Stuffed Toy', 'Candle', 'Sailboat', 'Laptop', 'Awning', 'Bed', 'Faucet', 'Tent', 'Horse', 'Mirror', 'Power outlet', 'Sink', 'Apple', 'Air Conditioner', 'Knife', 'Hockey Stick', 'Paddle', 'Pickup Truck', 'Fork', 'Traffic Sign', 'Ballon', 'Tripod', 'Dog', 'Spoon', 'Clock', 'Pot', 'Cow', 'Cake', 'Dinning Table', 'Sheep', 'Hanger', 'Blackboard/Whiteboard', 'Napkin', 'Other Fish', 'Orange/Tangerine', 'Toiletry', 'Keyboard', 'Tomato', 'Lantern', 'Machinery Vehicle', 'Fan', 'Green Vegetables', 'Banana', 'Baseball Glove', 'Airplane', 'Mouse', 'Train', 'Pumpkin', 'Soccer', 'Skiboard', 'Luggage', 'Nightstand', 'Tea pot', 
'Telephone', 'Trolley', 'Head Phone', 'Sports Car', 'Stop Sign', 'Dessert', 'Scooter', 'Stroller', 'Crane', 'Remote', 'Refrigerator', 'Oven', 'Lemon', 'Duck', 'Baseball Bat', 'Surveillance Camera', 'Cat', 'Jug', 'Broccoli', 'Piano', 'Pizza', 'Elephant', 'Skateboard', 'Surfboard', 'Gun', 'Skating and Skiing shoes', 'Gas stove', 'Donut', 'Bow Tie', 'Carrot', 'Toilet', 'Kite', 'Strawberry', 'Other Balls', 'Shovel', 'Pepper', 'Computer Box', 'Toilet Paper', 'Cleaning Products', 'Chopsticks', 'Microwave', 'Pigeon', 'Baseball', 'Cutting/chopping Board', 'Coffee Table', 'Side Table', 'Scissors', 'Marker', 'Pie', 'Ladder', 'Snowboard', 'Cookies', 'Radiator', 'Fire Hydrant', 'Basketball', 'Zebra', 'Grape', 'Giraffe', 'Potato', 'Sausage', 'Tricycle', 'Violin', 'Egg', 'Fire Extinguisher', 'Candy', 'Fire Truck', 'Billards', 'Converter', 'Bathtub', 'Wheelchair', 'Golf Club', 'Briefcase', 'Cucumber', 'Cigar/Cigarette ', 'Paint Brush', 'Pear', 'Heavy Truck', 'Hamburger', 'Extractor', 'Extention Cord', 'Tong', 'Tennis Racket', 'Folder', 'American Football', 'earphone', 'Mask', 'Kettle', 'Tennis', 'Ship', 'Swing', 'Coffee Machine', 'Slide', 'Carriage', 'Onion', 'Green beans', 'Projector', 'Frisbee', 'Washing Machine/Drying Machine', 'Chicken', 'Printer', 'Watermelon', 'Saxophone', 'Tissue', 'Toothbrush', 'Ice cream', 'Hotair ballon', 'Cello', 'French Fries', 'Scale', 'Trophy', 'Cabbage', 'Hot dog', 'Blender', 'Peach', 'Rice', 'Wallet/Purse', 'Volleyball', 'Deer', 'Goose', 'Tape', 'Tablet', 'Cosmetics', 'Trumpet', 'Pineapple', 'Golf Ball', 'Ambulance', 'Parking meter', 'Mango', 'Key', 'Hurdle', 'Fishing Rod', 'Medal', 'Flute', 'Brush', 'Penguin', 'Megaphone', 'Corn', 'Lettuce', 'Garlic', 'Swan', 'Helicopter', 'Green Onion', 'Sandwich', 'Nuts', 'Speed Limit Sign', 'Induction Cooker', 'Broom', 'Trombone', 'Plum', 'Rickshaw', 'Goldfish', 'Kiwi fruit', 'Router/modem', 'Poker Card', 'Toaster', 'Shrimp', 'Sushi', 'Cheese', 'Notepaper', 'Cherry', 'Pliers', 'CD', 'Pasta', 'Hammer', 'Cue', 
'Avocado', 'Hamimelon', 'Flask', 'Mushroon', 'Screwdriver', 'Soap', 'Recorder', 'Bear', 'Eggplant', 'Board Eraser', 'Coconut', 'Tape Measur/ Ruler', 'Pig', 'Showerhead', 'Globe', 'Chips', 'Steak', 'Crosswalk Sign', 'Stapler', 'Campel', 'Formula 1 ', 'Pomegranate', 'Dishwasher', 'Crab', 'Hoverboard', 'Meat ball', 'Rice Cooker', 'Tuba', 'Calculator', 'Papaya', 'Antelope', 'Parrot', 'Seal', 'Buttefly', 'Dumbbell', 'Donkey', 'Lion', 'Urinal', 'Dolphin', 'Electric Drill', 'Hair Dryer', 'Egg tart', 'Jellyfish', 'Treadmill', 'Lighter', 'Grapefruit', 'Game board', 'Mop', 'Radish', 'Baozi', 'Target', 'French', 'Spring Rolls', 'Monkey', 'Rabbit', 'Pencil Case', 'Yak', 'Red Cabbage', 'Binoculars', 'Asparagus', 'Barbell', 'Scallop', 'Noddles', 'Comb', 'Dumpling', 'Oyster', 'Table Teniis paddle', 'Cosmetics Brush/Eyeliner Pencil', 'Chainsaw', 'Eraser', 'Lobster', 'Durian', 'Okra', 'Lipstick', 'Cosmetics Mirror', 'Curling', 'Table Tennis ']\n\n\nOPENIMAGE = ['Tortoise', 'Container', 'Magpie', 'Sea turtle', 'Football', 'Ambulance', 'Ladder', 'Toothbrush', 'Syringe', 'Sink', 'Toy', 'Organ (Musical Instrument)', 'Cassette deck', 'Apple', 'Human eye', 'Cosmetics', 'Paddle', 'Snowman', 'Beer', 'Chopsticks', 'Human beard', 'Bird', 'Parking meter', 'Traffic light', 'Croissant', 'Cucumber', 'Radish', 'Towel', 'Doll', 'Skull', 'Washing machine', 'Glove', 'Tick', 'Belt', 'Sunglasses', 'Banjo', 'Cart', 'Ball', 'Backpack', 'Bicycle', 'Home appliance', 'Centipede', 'Boat', 'Surfboard', 'Boot', 'Headphones', 'Hot dog', 'Shorts', 'Fast food', 'Bus', 'Boy', 'Screwdriver', 'Bicycle wheel', 'Barge', 'Laptop', 'Miniskirt', 'Drill (Tool)', 'Dress', 'Bear', 'Waffle', 'Pancake', 'Brown bear', 'Woodpecker', 'Blue jay', 'Pretzel', 'Bagel', 'Tower', 'Teapot', 'Person', 'Bow and arrow', 'Swimwear', 'Beehive', 'Brassiere', 'Bee', 'Bat (Animal)', 'Starfish', 'Popcorn', 'Burrito', 'Chainsaw', 'Balloon', 'Wrench', 'Tent', 'Vehicle registration plate', 'Lantern', 'Toaster', 'Flashlight', 'Billboard', 'Tiara', 
'Limousine', 'Necklace', 'Carnivore', 'Scissors', 'Stairs', 'Computer keyboard', 'Printer', 'Traffic sign', 'Chair', 'Shirt', 'Poster', 'Cheese', 'Sock', 'Fire hydrant', 'Land vehicle', 'Earrings', 'Tie', 'Watercraft', 'Cabinetry', 'Suitcase', 'Muffin', 'Bidet', 'Snack', 'Snowmobile', 'Clock', 'Medical equipment', 'Cattle', 'Cello', 'Jet ski', 'Camel', 'Coat', 'Suit', 'Desk', 'Cat', 'Bronze sculpture', 'Juice', 'Gondola', 'Beetle', 'Cannon', 'Computer mouse', 'Cookie', 'Office building', 'Fountain', 'Coin', 'Calculator', 'Cocktail', 'Computer monitor', 'Box', 'Stapler', 'Christmas tree', 'Cowboy hat', 'Hiking equipment', 'Studio couch', 'Drum', 'Dessert', 'Wine rack', 'Drink', 'Zucchini', 'Ladle', 'Human mouth', 'Dairy Product', 'Dice', 'Oven', 'Dinosaur', 'Ratchet (Device)', 'Couch', 'Cricket ball', 'Winter melon', 'Spatula', 'Whiteboard', 'Pencil sharpener', 'Door', 'Hat', 'Shower', 'Eraser', 'Fedora', 'Guacamole', 'Dagger', 'Scarf', 'Dolphin', 'Sombrero', 'Tin can', 'Mug', 'Tap', 'Harbor seal', 'Stretcher', 'Can opener', 'Goggles', 'Human body', 'Roller skates', 'Coffee cup', 'Cutting board', 'Blender', 'Plumbing fixture', 'Stop sign', 'Office supplies', 'Volleyball (Ball)', 'Vase', 'Slow cooker', 'Wardrobe', 'Coffee', 'Whisk', 'Paper towel', 'Personal care', 'Food', 'Sun hat', 'Tree house', 'Flying disc', 'Skirt', 'Gas stove', 'Salt and pepper shakers', 'Mechanical fan', 'Face powder', 'Fax', 'Fruit', 'French fries', 'Nightstand', 'Barrel', 'Kite', 'Tart', 'Treadmill', 'Fox', 'Flag', 'French horn', 'Window blind', 'Human foot', 'Golf cart', 'Jacket', 'Egg (Food)', 'Street light', 'Guitar', 'Pillow', 'Human leg', 'Isopod', 'Grape', 'Human ear', 'Power plugs and sockets', 'Panda', 'Giraffe', 'Woman', 'Door handle', 'Rhinoceros', 'Bathtub', 'Goldfish', 'Houseplant', 'Goat', 'Baseball bat', 'Baseball glove', 'Mixing bowl', 'Marine invertebrates', 'Kitchen utensil', 'Light switch', 'House', 'Horse', 'Stationary bicycle', 'Hammer', 'Ceiling fan', 'Sofa bed', 
'Adhesive tape', 'Harp', 'Sandal', 'Bicycle helmet', 'Saucer', 'Harpsichord', 'Human hair', 'Heater', 'Harmonica', 'Hamster', 'Curtain', 'Bed', 'Kettle', 'Fireplace', 'Scale', 'Drinking straw', 'Insect', 'Hair dryer', 'Kitchenware', 'Indoor rower', 'Invertebrate', 'Food processor', 'Bookcase', 'Refrigerator', 'Wood-burning stove', 'Punching bag', 'Common fig', 'Cocktail shaker', 'Jaguar (Animal)', 'Golf ball', 'Fashion accessory', 'Alarm clock', 'Filing cabinet', 'Artichoke', 'Table', 'Tableware', 'Kangaroo', 'Koala', 'Knife', 'Bottle', 'Bottle opener', 'Lynx', 'Lavender (Plant)', 'Lighthouse', 'Dumbbell', 'Human head', 'Bowl', 'Humidifier', 'Porch', 'Lizard', 'Billiard table', 'Mammal', 'Mouse', 'Motorcycle', 'Musical instrument', 'Swim cap', 'Frying pan', 'Snowplow', 'Bathroom cabinet', 'Missile', 'Bust', 'Man', 'Waffle iron', 'Milk', 'Ring binder', 'Plate', 'Mobile phone', 'Baked goods', 'Mushroom', 'Crutch', 'Pitcher (Container)', 'Mirror', 'Personal flotation device', 'Table tennis racket', 'Pencil case', 'Musical keyboard', 'Scoreboard', 'Briefcase', 'Kitchen knife', 'Nail (Construction)', 'Tennis ball', 'Plastic bag', 'Oboe', 'Chest of drawers', 'Ostrich', 'Piano', 'Girl', 'Plant', 'Potato', 'Hair spray', 'Sports equipment', 'Pasta', 'Penguin', 'Pumpkin', 'Pear', 'Infant bed', 'Polar bear', 'Mixer', 'Cupboard', 'Jacuzzi', 'Pizza', 'Digital clock', 'Pig', 'Reptile', 'Rifle', 'Lipstick', 'Skateboard', 'Raven', 'High heels', 'Red panda', 'Rose', 'Rabbit', 'Sculpture', 'Saxophone', 'Shotgun', 'Seafood', 'Submarine sandwich', 'Snowboard', 'Sword', 'Picture frame', 'Sushi', 'Loveseat', 'Ski', 'Squirrel', 'Tripod', 'Stethoscope', 'Submarine', 'Scorpion', 'Segway', 'Training bench', 'Snake', 'Coffee table', 'Skyscraper', 'Sheep', 'Television', 'Trombone', 'Tea', 'Tank', 'Taco', 'Telephone', 'Torch', 'Tiger', 'Strawberry', 'Trumpet', 'Tree', 'Tomato', 'Train', 'Tool', 'Picnic basket', 'Cooking spray', 'Trousers', 'Bowling equipment', 'Football helmet', 'Truck', 
'Measuring cup', 'Coffeemaker', 'Violin', 'Vehicle', 'Handbag', 'Paper cutter', 'Wine', 'Weapon', 'Wheel', 'Worm', 'Wok', 'Whale', 'Zebra', 'Auto part', 'Jug', 'Pizza cutter', 'Cream', 'Monkey', 'Lion', 'Bread', 'Platter', 'Chicken', 'Eagle', 'Helicopter', 'Owl', 'Duck', 'Turtle', 'Hippopotamus', 'Crocodile', 'Toilet', 'Toilet paper', 'Squid', 'Clothing', 'Footwear', 'Lemon', 'Spider', 'Deer', 'Frog', 'Banana', 'Rocket', 'Wine glass', 'Countertop', 'Tablet computer', 'Waste container', 'Swimming pool', 'Dog', 'Book', 'Elephant', 'Shark', 'Candle', 'Leopard', 'Axe', 'Hand dryer', 'Soap dispenser', 'Porcupine', 'Flower', 'Canary', 'Cheetah', 'Palm tree', 'Hamburger', 'Maple', 'Building', 'Fish', 'Lobster', 'Garden Asparagus', 'Furniture', 'Hedgehog', 'Airplane', 'Spoon', 'Otter', 'Bull', 'Oyster', 'Horizontal bar', 'Convenience store', 'Bomb', 'Bench', 'Ice cream', 'Caterpillar', 'Butterfly', 'Parachute', 'Orange', 'Antelope', 'Beaker', 'Moths and butterflies', 'Window', 'Closet', 'Castle', 'Jellyfish', 'Goose', 'Mule', 'Swan', 'Peach', 'Coconut', 'Seat belt', 'Raccoon', 'Chisel', 'Fork', 'Lamp', 'Camera', 'Squash (Plant)', 'Racket', 'Human face', 'Human arm', 'Vegetable', 'Diaper', 'Unicycle', 'Falcon', 'Chime', 'Snail', 'Shellfish', 'Cabbage', 'Carrot', 'Mango', 'Jeans', 'Flowerpot', 'Pineapple', 'Drawer', 'Stool', 'Envelope', 'Cake', 'Dragonfly', 'Common sunflower', 'Microwave oven', 'Honeycomb', 'Marine mammal', 'Sea lion', 'Ladybug', 'Shelf', 'Watch', 'Candy', 'Salad', 'Parrot', 'Handgun', 'Sparrow', 'Van', 'Grinder', 'Spice rack', 'Light bulb', 'Corded phone', 'Sports uniform', 'Tennis racket', 'Wall clock', 'Serving tray', 'Kitchen & dining room table', 'Dog bed', 'Cake stand', 'Cat furniture', 'Bathroom accessory', 'Facial tissue holder', 'Pressure cooker', 'Kitchen appliance', 'Tire', 'Ruler', 'Luggage and bags', 'Microphone', 'Broccoli', 'Umbrella', 'Pastry', 'Grapefruit', 'Band-aid', 'Animal', 'Bell pepper', 'Turkey', 'Lily', 'Pomegranate', 'Doughnut', 
'Glasses', 'Human nose', 'Pen', 'Ant', 'Car', 'Aircraft', 'Human hand', 'Skunk', 'Teddy bear', 'Watermelon', 'Cantaloupe', 'Dishwasher', 'Flute', 'Balance beam', 'Sandwich', 'Shrimp', 'Sewing machine', 'Binoculars', 'Rays and skates', 'Ipod', 'Accordion', 'Willow', 'Crab', 'Crown', 'Seahorse', 'Perfume', 'Alpaca', 'Taxi', 'Canoe', 'Remote control', 'Wheelchair', 'Rugby ball', 'Armadillo', 'Maracas', 'Helmet']\n\nADE20K_847 = ['wall', 'building', 'sky', 'tree', 'road', 'floor', 'ceiling', 'bed', 'sidewalk', 'earth', 'cabinet', 'person', 'grass', 'windowpane', 'car', 'mountain', 'plant', 'table', 'chair', 'curtain', 'door', 'sofa', 'sea', 'painting', 'water', 'mirror', 'house', 'rug', 'shelf', 'armchair', 'fence', 'field', 'lamp', 'rock', 'seat', 'river', 'desk', 'bathtub', 'railing', 'signboard', 'cushion', 'path', 'work surface', 'stairs', 'column', 'sink', 'wardrobe', 'snow', 'refrigerator', 'base', 'bridge', 'blind', 'runway', 'cliff', 'sand', 'fireplace', 'pillow', 'screen door', 'toilet', 'skyscraper', 'grandstand', 'box', 'pool table', 'palm', 'double door', 'coffee table', 'counter', 'countertop', 'chest of drawers', 'kitchen island', 'boat', 'waterfall', 'stove', 'flower', 'bookcase', 'controls', 'book', 'stairway', 'streetlight', 'computer', 'bus', 'swivel chair', 'light', 'bench', 'case', 'towel', 'fountain', 'embankment', 'television receiver', 'van', 'hill', 'awning', 'poster', 'truck', 'airplane', 'pole', 'tower', 'court', 'ball', 'aircraft carrier', 'buffet', 'hovel', 'apparel', 'minibike', 'animal', 'chandelier', 'step', 'booth', 'bicycle', 'doorframe', 'sconce', 'pond', 'trade name', 'bannister', 'bag', 'traffic light', 'gazebo', 'escalator', 'land', 'board', 'arcade machine', 'eiderdown', 'bar', 'stall', 'playground', 'ship', 'ottoman', 'ashcan', 'bottle', 'cradle', 'pot', 'conveyer belt', 'train', 'stool', 'lake', 'tank', 'ice', 'basket', 'manhole', 'tent', 'canopy', 'microwave', 'barrel', 'dirt track', 'beam', 'dishwasher', 'plate', 'screen', 
'ruins', 'washer', 'blanket', 'plaything', 'food', 'screen', 'oven', 'stage', 'beacon', 'umbrella', 'sculpture', 'aqueduct', 'container', 'scaffolding', 'hood', 'curb', 'roller coaster', 'horse', 'catwalk', 'glass', 'vase', 'central reservation', 'carousel', 'radiator', 'closet', 'machine', 'pier', 'fan', 'inflatable bounce game', 'pitch', 'paper', 'arcade', 'hot tub', 'helicopter', 'tray', 'partition', 'vineyard', 'bowl', 'bullring', 'flag', 'pot', 'footbridge', 'shower', 'bag', 'bulletin board', 'confessional booth', 'trunk', 'forest', 'elevator door', 'laptop', 'instrument panel', 'bucket', 'tapestry', 'platform', 'jacket', 'gate', 'monitor', 'telephone booth', 'spotlight', 'ring', 'control panel', 'blackboard', 'air conditioner', 'chest', 'clock', 'sand dune', 'pipe', 'vault', 'table football', 'cannon', 'swimming pool', 'fluorescent', 'statue', 'loudspeaker', 'exhibitor', 'ladder', 'carport', 'dam', 'pulpit', 'skylight', 'water tower', 'grill', 'display board', 'pane', 'rubbish', 'ice rink', 'fruit', 'patio', 'vending machine', 'telephone', 'net', 'backpack', 'jar', 'track', 'magazine', 'shutter', 'roof', 'banner', 'landfill', 'post', 'altarpiece', 'hat', 'arch', 'table game', 'bag', 'document', 'dome', 'pier', 'shanties', 'forecourt', 'crane', 'dog', 'piano', 'drawing', 'cabin', 'ad', 'amphitheater', 'monument', 'henhouse', 'cockpit', 'heater', 'windmill', 'pool', 'elevator', 'decoration', 'labyrinth', 'text', 'printer', 'mezzanine', 'mattress', 'straw', 'stalls', 'patio', 'billboard', 'bus stop', 'trouser', 'console table', 'rack', 'notebook', 'shrine', 'pantry', 'cart', 'steam shovel', 'porch', 'postbox', 'figurine', 'recycling bin', 'folding screen', 'telescope', 'deck chair', 'kennel', 'coffee maker', 'altar', 'fish', 'easel', 'artificial golf green', 'iceberg', 'candlestick', 'shower stall', 'television stand', 'wall socket', 'skeleton', 'grand piano', 'candy', 'grille door', 'pedestal', 'jersey', 'shoe', 'gravestone', 'shanty', 'structure', 'rocking 
chair', 'bird', 'place mat', 'tomb', 'big top', 'gas pump', 'lockers', 'cage', 'finger', 'bleachers', 'ferris wheel', 'hairdresser chair', 'mat', 'stands', 'aquarium', 'streetcar', 'napkin', 'dummy', 'booklet', 'sand trap', 'shop', 'table cloth', 'service station', 'coffin', 'drawer', 'cages', 'slot machine', 'balcony', 'volleyball court', 'table tennis', 'control table', 'shirt', 'merchandise', 'railway', 'parterre', 'chimney', 'can', 'tanks', 'fabric', 'alga', 'system', 'map', 'greenhouse', 'mug', 'barbecue', 'trailer', 'toilet tissue', 'organ', 'dishrag', 'island', 'keyboard', 'trench', 'basket', 'steering wheel', 'pitcher', 'goal', 'bread', 'beds', 'wood', 'file cabinet', 'newspaper', 'motorboat', 'rope', 'guitar', 'rubble', 'scarf', 'barrels', 'cap', 'leaves', 'control tower', 'dashboard', 'bandstand', 'lectern', 'switch', 'baseboard', 'shower room', 'smoke', 'faucet', 'bulldozer', 'saucepan', 'shops', 'meter', 'crevasse', 'gear', 'candelabrum', 'sofa bed', 'tunnel', 'pallet', 'wire', 'kettle', 'bidet', 'baby buggy', 'music stand', 'pipe', 'cup', 'parking meter', 'ice hockey rink', 'shelter', 'weeds', 'temple', 'patty', 'ski slope', 'panel', 'wallet', 'wheel', 'towel rack', 'roundabout', 'canister', 'rod', 'soap dispenser', 'bell', 'canvas', 'box office', 'teacup', 'trellis', 'workbench', 'valley', 'toaster', 'knife', 'podium', 'ramp', 'tumble dryer', 'fireplug', 'gym shoe', 'lab bench', 'equipment', 'rocky formation', 'plastic', 'calendar', 'caravan', 'check-in-desk', 'ticket counter', 'brush', 'mill', 'covered bridge', 'bowling alley', 'hanger', 'excavator', 'trestle', 'revolving door', 'blast furnace', 'scale', 'projector', 'soap', 'locker', 'tractor', 'stretcher', 'frame', 'grating', 'alembic', 'candle', 'barrier', 'cardboard', 'cave', 'puddle', 'tarp', 'price tag', 'watchtower', 'meters', 'light bulb', 'tracks', 'hair dryer', 'skirt', 'viaduct', 'paper towel', 'coat', 'sheet', 'fire extinguisher', 'water wheel', 'pottery', 'magazine rack', 'teapot', 
'microphone', 'support', 'forklift', 'canyon', 'cash register', 'leaf', 'remote control', 'soap dish', 'windshield', 'cat', 'cue', 'vent', 'videos', 'shovel', 'eaves', 'antenna', 'shipyard', 'hen', 'traffic cone', 'washing machines', 'truck crane', 'cds', 'niche', 'scoreboard', 'briefcase', 'boot', 'sweater', 'hay', 'pack', 'bottle rack', 'glacier', 'pergola', 'building materials', 'television camera', 'first floor', 'rifle', 'tennis table', 'stadium', 'safety belt', 'cover', 'dish rack', 'synthesizer', 'pumpkin', 'gutter', 'fruit stand', 'ice floe', 'handle', 'wheelchair', 'mousepad', 'diploma', 'fairground ride', 'radio', 'hotplate', 'junk', 'wheelbarrow', 'stream', 'toll plaza', 'punching bag', 'trough', 'throne', 'chair desk', 'weighbridge', 'extractor fan', 'hanging clothes', 'dish', 'alarm clock', 'ski lift', 'chain', 'garage', 'mechanical shovel', 'wine rack', 'tramway', 'treadmill', 'menu', 'block', 'well', 'witness stand', 'branch', 'duck', 'casserole', 'frying pan', 'desk organizer', 'mast', 'spectacles', 'service elevator', 'dollhouse', 'hammock', 'clothes hanging', 'photocopier', 'notepad', 'golf cart', 'footpath', 'cross', 'baptismal font', 'boiler', 'skip', 'rotisserie', 'tables', 'water mill', 'helmet', 'cover curtain', 'brick', 'table runner', 'ashtray', 'street box', 'stick', 'hangers', 'cells', 'urinal', 'centerpiece', 'portable fridge', 'dvds', 'golf club', 'skirting board', 'water cooler', 'clipboard', 'camera', 'pigeonhole', 'chips', 'food processor', 'post box', 'lid', 'drum', 'blender', 'cave entrance', 'dental chair', 'obelisk', 'canoe', 'mobile', 'monitors', 'pool ball', 'cue rack', 'baggage carts', 'shore', 'fork', 'paper filer', 'bicycle rack', 'coat rack', 'garland', 'sports bag', 'fish tank', 'towel dispenser', 'carriage', 'brochure', 'plaque', 'stringer', 'iron', 'spoon', 'flag pole', 'toilet brush', 'book stand', 'water faucet', 'ticket office', 'broom', 'dvd', 'ice bucket', 'carapace', 'tureen', 'folders', 'chess', 'root', 'sewing 
machine', 'model', 'pen', 'violin', 'sweatshirt', 'recycling materials', 'mitten', 'chopping board', 'mask', 'log', 'mouse', 'grill', 'hole', 'target', 'trash bag', 'chalk', 'sticks', 'balloon', 'score', 'hair spray', 'roll', 'runner', 'engine', 'inflatable glove', 'games', 'pallets', 'baskets', 'coop', 'dvd player', 'rocking horse', 'buckets', 'bread rolls', 'shawl', 'watering can', 'spotlights', 'post-it', 'bowls', 'security camera', 'runner cloth', 'lock', 'alarm', 'side', 'roulette', 'bone', 'cutlery', 'pool balls', 'wheels', 'spice rack', 'plant pots', 'towel ring', 'bread box', 'video', 'funfair', 'breads', 'tripod', 'ironing board', 'skimmer', 'hollow', 'scratching post', 'tricycle', 'file box', 'mountain pass', 'tombstones', 'cooker', 'card game', 'golf bag', 'towel paper', 'chaise lounge', 'sun', 'toilet paper holder', 'rake', 'key', 'umbrella stand', 'dartboard', 'transformer', 'fireplace utensils', 'sweatshirts', 'cellular telephone', 'tallboy', 'stapler', 'sauna', 'test tube', 'palette', 'shopping carts', 'tools', 'push button', 'star', 'roof rack', 'barbed wire', 'spray', 'ear', 'sponge', 'racket', 'tins', 'eyeglasses', 'file', 'scarfs', 'sugar bowl', 'flip flop', 'headstones', 'laptop bag', 'leash', 'climbing frame', 'suit hanger', 'floor spotlight', 'plate rack', 'sewer', 'hard drive', 'sprinkler', 'tools box', 'necklace', 'bulbs', 'steel industry', 'club', 'jack', 'door bars', 'control panel', 'hairbrush', 'napkin holder', 'office', 'smoke detector', 'utensils', 'apron', 'scissors', 'terminal', 'grinder', 'entry phone', 'newspaper stand', 'pepper shaker', 'onions', 'central processing unit', 'tape', 'bat', 'coaster', 'calculator', 'potatoes', 'luggage rack', 'salt', 'street number', 'viewpoint', 'sword', 'cd', 'rowing machine', 'plug', 'andiron', 'pepper', 'tongs', 'bonfire', 'dog dish', 'belt', 'dumbbells', 'videocassette recorder', 'hook', 'envelopes', 'shower faucet', 'watch', 'padlock', 'swimming pool ladder', 'spanners', 'gravy boat', 'notice 
board', 'trash bags', 'fire alarm', 'ladle', 'stethoscope', 'rocket', 'funnel', 'bowling pins', 'valve', 'thermometer', 'cups', 'spice jar', 'night light', 'soaps', 'games table', 'slotted spoon', 'reel', 'scourer', 'sleeping robe', 'desk mat', 'dumbbell', 'hammer', 'tie', 'typewriter', 'shaker', 'cheese dish', 'sea star', 'racquet', 'butane gas cylinder', 'paper weight', 'shaving brush', 'sunglasses', 'gear shift', 'towel rail', 'adding machine']\n\nLVIS_CATEGORIES = ['aerosol_can', 'air_conditioner', 'airplane', 'alarm_clock', 'alcohol', 'alligator', 'almond', 'ambulance', 'amplifier', 'anklet', 'antenna', 'apple', 'applesauce', 'apricot', 'apron', 'aquarium', 'arctic_(type_of_shoe)', 'armband', 'armchair', 'armoire', 'armor', 'artichoke', 'trash_can', 'ashtray', 'asparagus', 'atomizer', 'avocado', 'award', 'awning', 'ax', 'baboon', 'baby_buggy', 'basketball_backboard', 'backpack', 'handbag', 'suitcase', 'bagel', 'bagpipe', 'baguet', 'bait', 'ball', 'ballet_skirt', 'balloon', 'bamboo', 'banana', 'Band_Aid', 'bandage', 'bandanna', 'banjo', 'banner', 'barbell', 'barge', 'barrel', 'barrette', 'barrow', 'baseball_base', 'baseball', 'baseball_bat', 'baseball_cap', 'baseball_glove', 'basket', 'basketball', 'bass_horn', 'bat_(animal)', 'bath_mat', 'bath_towel', 'bathrobe', 'bathtub', 'batter_(food)', 'battery', 'beachball', 'bead', 'bean_curd', 'beanbag', 'beanie', 'bear', 'bed', 'bedpan', 'bedspread', 'cow', 'beef_(food)', 'beeper', 'beer_bottle', 'beer_can', 'beetle', 'bell', 'bell_pepper', 'belt', 'belt_buckle', 'bench', 'beret', 'bib', 'Bible', 'bicycle', 'visor', 'billboard', 'binder', 'binoculars', 'bird', 'birdfeeder', 'birdbath', 'birdcage', 'birdhouse', 'birthday_cake', 'birthday_card', 'pirate_flag', 'black_sheep', 'blackberry', 'blackboard', 'blanket', 'blazer', 'blender', 'blimp', 'blinker', 'blouse', 'blueberry', 'gameboard', 'boat', 'bob', 'bobbin', 'bobby_pin', 'boiled_egg', 'bolo_tie', 'deadbolt', 'bolt', 'bonnet', 'book', 'bookcase', 'booklet', 
'bookmark', 'boom_microphone', 'boot', 'bottle', 'bottle_opener', 'bouquet', 'bow_(weapon)', 'bow_(decorative_ribbons)', 'bow-tie', 'bowl', 'pipe_bowl', 'bowler_hat', 'bowling_ball', 'box', 'boxing_glove', 'suspenders', 'bracelet', 'brass_plaque', 'brassiere', 'bread-bin', 'bread', 'breechcloth', 'bridal_gown', 'briefcase', 'broccoli', 'broach', 'broom', 'brownie', 'brussels_sprouts', 'bubble_gum', 'bucket', 'horse_buggy', 'horned_cow', 'bulldog', 'bulldozer', 'bullet_train', 'bulletin_board', 'bulletproof_vest', 'bullhorn', 'bun', 'bunk_bed', 'buoy', 'burrito', 'bus_(vehicle)', 'business_card', 'butter', 'butterfly', 'button', 'cab_(taxi)', 'cabana', 'cabin_car', 'cabinet', 'locker', 'cake', 'calculator', 'calendar', 'calf', 'camcorder', 'camel', 'camera', 'camera_lens', 'camper_(vehicle)', 'can', 'can_opener', 'candle', 'candle_holder', 'candy_bar', 'candy_cane', 'walking_cane', 'canister', 'canoe', 'cantaloup', 'canteen', 'cap_(headwear)', 'bottle_cap', 'cape', 'cappuccino', 'car_(automobile)', 'railcar_(part_of_a_train)', 'elevator_car', 'car_battery', 'identity_card', 'card', 'cardigan', 'cargo_ship', 'carnation', 'horse_carriage', 'carrot', 'tote_bag', 'cart', 'carton', 'cash_register', 'casserole', 'cassette', 'cast', 'cat', 'cauliflower', 'cayenne_(spice)', 'CD_player', 'celery', 'cellular_telephone', 'chain_mail', 'chair', 'chaise_longue', 'chalice', 'chandelier', 'chap', 'checkbook', 'checkerboard', 'cherry', 'chessboard', 'chicken_(animal)', 'chickpea', 'chili_(vegetable)', 'chime', 'chinaware', 'crisp_(potato_chip)', 'poker_chip', 'chocolate_bar', 'chocolate_cake', 'chocolate_milk', 'chocolate_mousse', 'choker', 'chopping_board', 'chopstick', 'Christmas_tree', 'slide', 'cider', 'cigar_box', 'cigarette', 'cigarette_case', 'cistern', 'clarinet', 'clasp', 'cleansing_agent', 'cleat_(for_securing_rope)', 'clementine', 'clip', 'clipboard', 'clippers_(for_plants)', 'cloak', 'clock', 'clock_tower', 'clothes_hamper', 'clothespin', 'clutch_bag', 'coaster', 
'coat', 'coat_hanger', 'coatrack', 'cock', 'cockroach', 'cocoa_(beverage)', 'coconut', 'coffee_maker', 'coffee_table', 'coffeepot', 'coil', 'coin', 'colander', 'coleslaw', 'coloring_material', 'combination_lock', 'pacifier', 'comic_book', 'compass', 'computer_keyboard', 'condiment', 'cone', 'control', 'convertible_(automobile)', 'sofa_bed', 'cooker', 'cookie', 'cooking_utensil', 'cooler_(for_food)', 'cork_(bottle_plug)', 'corkboard', 'corkscrew', 'edible_corn', 'cornbread', 'cornet', 'cornice', 'cornmeal', 'corset', 'costume', 'cougar', 'coverall', 'cowbell', 'cowboy_hat', 'crab_(animal)', 'crabmeat', 'cracker', 'crape', 'crate', 'crayon', 'cream_pitcher', 'crescent_roll', 'crib', 'crock_pot', 'crossbar', 'crouton', 'crow', 'crowbar', 'crown', 'crucifix', 'cruise_ship', 'police_cruiser', 'crumb', 'crutch', 'cub_(animal)', 'cube', 'cucumber', 'cufflink', 'cup', 'trophy_cup', 'cupboard', 'cupcake', 'hair_curler', 'curling_iron', 'curtain', 'cushion', 'cylinder', 'cymbal', 'dagger', 'dalmatian', 'dartboard', 'date_(fruit)', 'deck_chair', 'deer', 'dental_floss', 'desk', 'detergent', 'diaper', 'diary', 'die', 'dinghy', 'dining_table', 'tux', 'dish', 'dish_antenna', 'dishrag', 'dishtowel', 'dishwasher', 'dishwasher_detergent', 'dispenser', 'diving_board', 'Dixie_cup', 'dog', 'dog_collar', 'doll', 'dollar', 'dollhouse', 'dolphin', 'domestic_ass', 'doorknob', 'doormat', 'doughnut', 'dove', 'dragonfly', 'drawer', 'underdrawers', 'dress', 'dress_hat', 'dress_suit', 'dresser', 'drill', 'drone', 'dropper', 'drum_(musical_instrument)', 'drumstick', 'duck', 'duckling', 'duct_tape', 'duffel_bag', 'dumbbell', 'dumpster', 'dustpan', 'eagle', 'earphone', 'earplug', 'earring', 'easel', 'eclair', 'eel', 'egg', 'egg_roll', 'egg_yolk', 'eggbeater', 'eggplant', 'electric_chair', 'refrigerator', 'elephant', 'elk', 'envelope', 'eraser', 'escargot', 'eyepatch', 'falcon', 'fan', 'faucet', 'fedora', 'ferret', 'Ferris_wheel', 'ferry', 'fig_(fruit)', 'fighter_jet', 'figurine', 'file_cabinet', 
'file_(tool)', 'fire_alarm', 'fire_engine', 'fire_extinguisher', 'fire_hose', 'fireplace', 'fireplug', 'first-aid_kit', 'fish', 'fish_(food)', 'fishbowl', 'fishing_rod', 'flag', 'flagpole', 'flamingo', 'flannel', 'flap', 'flash', 'flashlight', 'fleece', 'flip-flop_(sandal)', 'flipper_(footwear)', 'flower_arrangement', 'flute_glass', 'foal', 'folding_chair', 'food_processor', 'football_(American)', 'football_helmet', 'footstool', 'fork', 'forklift', 'freight_car', 'French_toast', 'freshener', 'frisbee', 'frog', 'fruit_juice', 'frying_pan', 'fudge', 'funnel', 'futon', 'gag', 'garbage', 'garbage_truck', 'garden_hose', 'gargle', 'gargoyle', 'garlic', 'gasmask', 'gazelle', 'gelatin', 'gemstone', 'generator', 'giant_panda', 'gift_wrap', 'ginger', 'giraffe', 'cincture', 'glass_(drink_container)', 'globe', 'glove', 'goat', 'goggles', 'goldfish', 'golf_club', 'golfcart', 'gondola_(boat)', 'goose', 'gorilla', 'gourd', 'grape', 'grater', 'gravestone', 'gravy_boat', 'green_bean', 'green_onion', 'griddle', 'grill', 'grits', 'grizzly', 'grocery_bag', 'guitar', 'gull', 'gun', 'hairbrush', 'hairnet', 'hairpin', 'halter_top', 'ham', 'hamburger', 'hammer', 'hammock', 'hamper', 'hamster', 'hair_dryer', 'hand_glass', 'hand_towel', 'handcart', 'handcuff', 'handkerchief', 'handle', 'handsaw', 'hardback_book', 'harmonium', 'hat', 'hatbox', 'veil', 'headband', 'headboard', 'headlight', 'headscarf', 'headset', 'headstall_(for_horses)', 'heart', 'heater', 'helicopter', 'helmet', 'heron', 'highchair', 'hinge', 'hippopotamus', 'hockey_stick', 'hog', 'home_plate_(baseball)', 'honey', 'fume_hood', 'hook', 'hookah', 'hornet', 'horse', 'hose', 'hot-air_balloon', 'hotplate', 'hot_sauce', 'hourglass', 'houseboat', 'hummingbird', 'hummus', 'polar_bear', 'icecream', 'popsicle', 'ice_maker', 'ice_pack', 'ice_skate', 'igniter', 'inhaler', 'iPod', 'iron_(for_clothing)', 'ironing_board', 'jacket', 'jam', 'jar', 'jean', 'jeep', 'jelly_bean', 'jersey', 'jet_plane', 'jewel', 'jewelry', 'joystick', 
'jumpsuit', 'kayak', 'keg', 'kennel', 'kettle', 'key', 'keycard', 'kilt', 'kimono', 'kitchen_sink', 'kitchen_table', 'kite', 'kitten', 'kiwi_fruit', 'knee_pad', 'knife', 'knitting_needle', 'knob', 'knocker_(on_a_door)', 'koala', 'lab_coat', 'ladder', 'ladle', 'ladybug', 'lamb_(animal)', 'lamb-chop', 'lamp', 'lamppost', 'lampshade', 'lantern', 'lanyard', 'laptop_computer', 'lasagna', 'latch', 'lawn_mower', 'leather', 'legging_(clothing)', 'Lego', 'legume', 'lemon', 'lemonade', 'lettuce', 'license_plate', 'life_buoy', 'life_jacket', 'lightbulb', 'lightning_rod', 'lime', 'limousine', 'lion', 'lip_balm', 'liquor', 'lizard', 'log', 'lollipop', 'speaker_(stero_equipment)', 'loveseat', 'machine_gun', 'magazine', 'magnet', 'mail_slot', 'mailbox_(at_home)', 'mallard', 'mallet', 'mammoth', 'manatee', 'mandarin_orange', 'manger', 'manhole', 'map', 'marker', 'martini', 'mascot', 'mashed_potato', 'masher', 'mask', 'mast', 'mat_(gym_equipment)', 'matchbox', 'mattress', 'measuring_cup', 'measuring_stick', 'meatball', 'medicine', 'melon', 'microphone', 'microscope', 'microwave_oven', 'milestone', 'milk', 'milk_can', 'milkshake', 'minivan', 'mint_candy', 'mirror', 'mitten', 'mixer_(kitchen_tool)', 'money', 'monitor_(computer_equipment) computer_monitor', 'monkey', 'motor', 'motor_scooter', 'motor_vehicle', 'motorcycle', 'mound_(baseball)', 'mouse_(computer_equipment)', 'mousepad', 'muffin', 'mug', 'mushroom', 'music_stool', 'musical_instrument', 'nailfile', 'napkin', 'neckerchief', 'necklace', 'necktie', 'needle', 'nest', 'newspaper', 'newsstand', 'nightshirt', 'nosebag_(for_animals)', 'noseband_(for_animals)', 'notebook', 'notepad', 'nut', 'nutcracker', 'oar', 'octopus_(food)', 'octopus_(animal)', 'oil_lamp', 'olive_oil', 'omelet', 'onion', 'orange_(fruit)', 'orange_juice', 'ostrich', 'ottoman', 'oven', 'overalls_(clothing)', 'owl', 'packet', 'inkpad', 'pad', 'paddle', 'padlock', 'paintbrush', 'painting', 'pajamas', 'palette', 'pan_(for_cooking)', 'pan_(metal_container)', 
'pancake', 'pantyhose', 'papaya', 'paper_plate', 'paper_towel', 'paperback_book', 'paperweight', 'parachute', 'parakeet', 'parasail_(sports)', 'parasol', 'parchment', 'parka', 'parking_meter', 'parrot', 'passenger_car_(part_of_a_train)', 'passenger_ship', 'passport', 'pastry', 'patty_(food)', 'pea_(food)', 'peach', 'peanut_butter', 'pear', 'peeler_(tool_for_fruit_and_vegetables)', 'wooden_leg', 'pegboard', 'pelican', 'pen', 'pencil', 'pencil_box', 'pencil_sharpener', 'pendulum', 'penguin', 'pennant', 'penny_(coin)', 'pepper', 'pepper_mill', 'perfume', 'persimmon', 'person', 'pet', 'pew_(church_bench)', 'phonebook', 'phonograph_record', 'piano', 'pickle', 'pickup_truck', 'pie', 'pigeon', 'piggy_bank', 'pillow', 'pin_(non_jewelry)', 'pineapple', 'pinecone', 'ping-pong_ball', 'pinwheel', 'tobacco_pipe', 'pipe', 'pistol', 'pita_(bread)', 'pitcher_(vessel_for_liquid)', 'pitchfork', 'pizza', 'place_mat', 'plate', 'platter', 'playpen', 'pliers', 'plow_(farm_equipment)', 'plume', 'pocket_watch', 'pocketknife', 'poker_(fire_stirring_tool)', 'pole', 'polo_shirt', 'poncho', 'pony', 'pool_table', 'pop_(soda)', 'postbox_(public)', 'postcard', 'poster', 'pot', 'flowerpot', 'potato', 'potholder', 'pottery', 'pouch', 'power_shovel', 'prawn', 'pretzel', 'printer', 'projectile_(weapon)', 'projector', 'propeller', 'prune', 'pudding', 'puffer_(fish)', 'puffin', 'pug-dog', 'pumpkin', 'puncher', 'puppet', 'puppy', 'quesadilla', 'quiche', 'quilt', 'rabbit', 'race_car', 'racket', 'radar', 'radiator', 'radio_receiver', 'radish', 'raft', 'rag_doll', 'raincoat', 'ram_(animal)', 'raspberry', 'rat', 'razorblade', 'reamer_(juicer)', 'rearview_mirror', 'receipt', 'recliner', 'record_player', 'reflector', 'remote_control', 'rhinoceros', 'rib_(food)', 'rifle', 'ring', 'river_boat', 'road_map', 'robe', 'rocking_chair', 'rodent', 'roller_skate', 'Rollerblade', 'rolling_pin', 'root_beer', 'router_(computer_equipment)', 'rubber_band', 'runner_(carpet)', 'plastic_bag', 'saddle_(on_an_animal)', 
'saddle_blanket', 'saddlebag', 'safety_pin', 'sail', 'salad', 'salad_plate', 'salami', 'salmon_(fish)', 'salmon_(food)', 'salsa', 'saltshaker', 'sandal_(type_of_shoe)', 'sandwich', 'satchel', 'saucepan', 'saucer', 'sausage', 'sawhorse', 'saxophone', 'scale_(measuring_instrument)', 'scarecrow', 'scarf', 'school_bus', 'scissors', 'scoreboard', 'scraper', 'screwdriver', 'scrubbing_brush', 'sculpture', 'seabird', 'seahorse', 'seaplane', 'seashell', 'sewing_machine', 'shaker', 'shampoo', 'shark', 'sharpener', 'Sharpie', 'shaver_(electric)', 'shaving_cream', 'shawl', 'shears', 'sheep', 'shepherd_dog', 'sherbert', 'shield', 'shirt', 'shoe', 'shopping_bag', 'shopping_cart', 'short_pants', 'shot_glass', 'shoulder_bag', 'shovel', 'shower_head', 'shower_cap', 'shower_curtain', 'shredder_(for_paper)', 'signboard', 'silo', 'sink', 'skateboard', 'skewer', 'ski', 'ski_boot', 'ski_parka', 'ski_pole', 'skirt', 'skullcap', 'sled', 'sleeping_bag', 'sling_(bandage)', 'slipper_(footwear)', 'smoothie', 'snake', 'snowboard', 'snowman', 'snowmobile', 'soap', 'soccer_ball', 'sock', 'sofa', 'softball', 'solar_array', 'sombrero', 'soup', 'soup_bowl', 'soupspoon', 'sour_cream', 'soya_milk', 'space_shuttle', 'sparkler_(fireworks)', 'spatula', 'spear', 'spectacles', 'spice_rack', 'spider', 'crawfish', 'sponge', 'spoon', 'sportswear', 'spotlight', 'squid_(food)', 'squirrel', 'stagecoach', 'stapler_(stapling_machine)', 'starfish', 'statue_(sculpture)', 'steak_(food)', 'steak_knife', 'steering_wheel', 'stepladder', 'step_stool', 'stereo_(sound_system)', 'stew', 'stirrer', 'stirrup', 'stool', 'stop_sign', 'brake_light', 'stove', 'strainer', 'strap', 'straw_(for_drinking)', 'strawberry', 'street_sign', 'streetlight', 'string_cheese', 'stylus', 'subwoofer', 'sugar_bowl', 'sugarcane_(plant)', 'suit_(clothing)', 'sunflower', 'sunglasses', 'sunhat', 'surfboard', 'sushi', 'mop', 'sweat_pants', 'sweatband', 'sweater', 'sweatshirt', 'sweet_potato', 'swimsuit', 'sword', 'syringe', 'Tabasco_sauce', 
'table-tennis_table', 'table', 'table_lamp', 'tablecloth', 'tachometer', 'taco', 'tag', 'taillight', 'tambourine', 'army_tank', 'tank_(storage_vessel)', 'tank_top_(clothing)', 'tape_(sticky_cloth_or_paper)', 'tape_measure', 'tapestry', 'tarp', 'tartan', 'tassel', 'tea_bag', 'teacup', 'teakettle', 'teapot', 'teddy_bear', 'telephone', 'telephone_booth', 'telephone_pole', 'telephoto_lens', 'television_camera', 'television_set', 'tennis_ball', 'tennis_racket', 'tequila', 'thermometer', 'thermos_bottle', 'thermostat', 'thimble', 'thread', 'thumbtack', 'tiara', 'tiger', 'tights_(clothing)', 'timer', 'tinfoil', 'tinsel', 'tissue_paper', 'toast_(food)', 'toaster', 'toaster_oven', 'toilet', 'toilet_tissue', 'tomato', 'tongs', 'toolbox', 'toothbrush', 'toothpaste', 'toothpick', 'cover', 'tortilla', 'tow_truck', 'towel', 'towel_rack', 'toy', 'tractor_(farm_equipment)', 'traffic_light', 'dirt_bike', 'trailer_truck', 'train_(railroad_vehicle)', 'trampoline', 'tray', 'trench_coat', 'triangle_(musical_instrument)', 'tricycle', 'tripod', 'trousers', 'truck', 'truffle_(chocolate)', 'trunk', 'vat', 'turban', 'turkey_(food)', 'turnip', 'turtle', 'turtleneck_(clothing)', 'typewriter', 'umbrella', 'underwear', 'unicycle', 'urinal', 'urn', 'vacuum_cleaner', 'vase', 'vending_machine', 'vent', 'vest', 'videotape', 'vinegar', 'violin', 'vodka', 'volleyball', 'vulture', 'waffle', 'waffle_iron', 'wagon', 'wagon_wheel', 'walking_stick', 'wall_clock', 'wall_socket', 'wallet', 'walrus', 'wardrobe', 'washbasin', 'automatic_washer', 'watch', 'water_bottle', 'water_cooler', 'water_faucet', 'water_heater', 'water_jug', 'water_gun', 'water_scooter', 'water_ski', 'water_tower', 'watering_can', 'watermelon', 'weathervane', 'webcam', 'wedding_cake', 'wedding_ring', 'wet_suit', 'wheel', 'wheelchair', 'whipped_cream', 'whistle', 'wig', 'wind_chime', 'windmill', 'window_box_(for_plants)', 'windshield_wiper', 'windsock', 'wine_bottle', 'wine_bucket', 'wineglass', 'blinder_(for_horses)', 'wok', 'wolf', 
'wooden_spoon', 'wreath', 'wrench', 'wristband', 'wristlet', 'yacht', 'yogurt', 'yoke_(animal_equipment)', 'zebra', 'zucchini']\n\n"
  },
  {
    "path": "utils/constants_ori.py",
    "content": "COCO_PANOPTIC_CLASSES = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush', 'banner', 'blanket', 'bridge', 'cardboard', 'counter', 'curtain', 'door-stuff', 'floor-wood', 'flower', 'fruit', 'gravel', 'house', 'light', 'mirror-stuff', 'net', 'pillow', 'platform', 'playingfield', 'railroad', 'river', 'road', 'roof', 'sand', 'sea', 'shelf', 'snow', 'stairs', 'tent', 'towel', 'wall-brick', 'wall-stone', 'wall-tile', 'wall-wood', 'water-other', 'window-blind', 'window-other', 'tree-merged', 'fence-merged', 'ceiling-merged', 'sky-other-merged', 'cabinet-merged', 'table-merged', 'floor-other-merged', 'pavement-merged', 'mountain-merged', 'grass-merged', 'dirt-merged', 'paper-merged', 'food-other-merged', 'building-other-merged', 'rock-merged', 'wall-other-merged', 'rug-merged']\n\nADE_PANOPTIC_CLASSES = ['wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed', 'window', 'grass', 'cabinet', 'sidewalk', 'person', 'earth', 'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car', 'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug', 'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe', 'lamp', 'tub', 'rail', 
'cushion', 'base', 'box', 'column', 'signboard', 'chest of drawers', 'counter', 'sand', 'sink', 'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path', 'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door', 'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table', 'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove', 'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar', 'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower', 'chandelier', 'awning', 'street lamp', 'booth', 'tv', 'airplane', 'dirt track', 'clothes', 'pole', 'land', 'bannister', 'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van', 'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything', 'pool', 'stool', 'barrel', 'basket', 'falls', 'tent', 'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank', 'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake', 'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce', 'vase', 'traffic light', 'tray', 'trash can', 'fan', 'pier', 'crt screen', 'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass', 'clock', 'flag']\n\nADE20K_847 = ['wall', 'building', 'sky', 'tree', 'road', 'floor', 'ceiling', 'bed', 'sidewalk', 'earth', 'cabinet', 'person', 'grass', 'windowpane', 'car', 'mountain', 'plant', 'table', 'chair', 'curtain', 'door', 'sofa', 'sea', 'painting', 'water', 'mirror', 'house', 'rug', 'shelf', 'armchair', 'fence', 'field', 'lamp', 'rock', 'seat', 'river', 'desk', 'bathtub', 'railing', 'signboard', 'cushion', 'path', 'work surface', 'stairs', 'column', 'sink', 'wardrobe', 'snow', 'refrigerator', 'base', 'bridge', 'blind', 'runway', 'cliff', 'sand', 'fireplace', 'pillow', 'screen door', 'toilet', 'skyscraper', 'grandstand', 'box', 'pool table', 'palm', 'double door', 'coffee table', 'counter', 'countertop', 'chest of drawers', 'kitchen island', 'boat', 'waterfall', 'stove', 'flower', 'bookcase', 'controls', 'book', 
'stairway', 'streetlight', 'computer', 'bus', 'swivel chair', 'light', 'bench', 'case', 'towel', 'fountain', 'embankment', 'television receiver', 'van', 'hill', 'awning', 'poster', 'truck', 'airplane', 'pole', 'tower', 'court', 'ball', 'aircraft carrier', 'buffet', 'hovel', 'apparel', 'minibike', 'animal', 'chandelier', 'step', 'booth', 'bicycle', 'doorframe', 'sconce', 'pond', 'trade name', 'bannister', 'bag', 'traffic light', 'gazebo', 'escalator', 'land', 'board', 'arcade machine', 'eiderdown', 'bar', 'stall', 'playground', 'ship', 'ottoman', 'ashcan', 'bottle', 'cradle', 'pot', 'conveyer belt', 'train', 'stool', 'lake', 'tank', 'ice', 'basket', 'manhole', 'tent', 'canopy', 'microwave', 'barrel', 'dirt track', 'beam', 'dishwasher', 'plate', 'screen', 'ruins', 'washer', 'blanket', 'plaything', 'food', 'screen', 'oven', 'stage', 'beacon', 'umbrella', 'sculpture', 'aqueduct', 'container', 'scaffolding', 'hood', 'curb', 'roller coaster', 'horse', 'catwalk', 'glass', 'vase', 'central reservation', 'carousel', 'radiator', 'closet', 'machine', 'pier', 'fan', 'inflatable bounce game', 'pitch', 'paper', 'arcade', 'hot tub', 'helicopter', 'tray', 'partition', 'vineyard', 'bowl', 'bullring', 'flag', 'pot', 'footbridge', 'shower', 'bag', 'bulletin board', 'confessional booth', 'trunk', 'forest', 'elevator door', 'laptop', 'instrument panel', 'bucket', 'tapestry', 'platform', 'jacket', 'gate', 'monitor', 'telephone booth', 'spotlight', 'ring', 'control panel', 'blackboard', 'air conditioner', 'chest', 'clock', 'sand dune', 'pipe', 'vault', 'table football', 'cannon', 'swimming pool', 'fluorescent', 'statue', 'loudspeaker', 'exhibitor', 'ladder', 'carport', 'dam', 'pulpit', 'skylight', 'water tower', 'grill', 'display board', 'pane', 'rubbish', 'ice rink', 'fruit', 'patio', 'vending machine', 'telephone', 'net', 'backpack', 'jar', 'track', 'magazine', 'shutter', 'roof', 'banner', 'landfill', 'post', 'altarpiece', 'hat', 'arch', 'table game', 'bag', 'document', 'dome', 'pier', 
'shanties', 'forecourt', 'crane', 'dog', 'piano', 'drawing', 'cabin', 'ad', 'amphitheater', 'monument', 'henhouse', 'cockpit', 'heater', 'windmill', 'pool', 'elevator', 'decoration', 'labyrinth', 'text', 'printer', 'mezzanine', 'mattress', 'straw', 'stalls', 'patio', 'billboard', 'bus stop', 'trouser', 'console table', 'rack', 'notebook', 'shrine', 'pantry', 'cart', 'steam shovel', 'porch', 'postbox', 'figurine', 'recycling bin', 'folding screen', 'telescope', 'deck chair', 'kennel', 'coffee maker', 'altar', 'fish', 'easel', 'artificial golf green', 'iceberg', 'candlestick', 'shower stall', 'television stand', 'wall socket', 'skeleton', 'grand piano', 'candy', 'grille door', 'pedestal', 'jersey', 'shoe', 'gravestone', 'shanty', 'structure', 'rocking chair', 'bird', 'place mat', 'tomb', 'big top', 'gas pump', 'lockers', 'cage', 'finger', 'bleachers', 'ferris wheel', 'hairdresser chair', 'mat', 'stands', 'aquarium', 'streetcar', 'napkin', 'dummy', 'booklet', 'sand trap', 'shop', 'table cloth', 'service station', 'coffin', 'drawer', 'cages', 'slot machine', 'balcony', 'volleyball court', 'table tennis', 'control table', 'shirt', 'merchandise', 'railway', 'parterre', 'chimney', 'can', 'tanks', 'fabric', 'alga', 'system', 'map', 'greenhouse', 'mug', 'barbecue', 'trailer', 'toilet tissue', 'organ', 'dishrag', 'island', 'keyboard', 'trench', 'basket', 'steering wheel', 'pitcher', 'goal', 'bread', 'beds', 'wood', 'file cabinet', 'newspaper', 'motorboat', 'rope', 'guitar', 'rubble', 'scarf', 'barrels', 'cap', 'leaves', 'control tower', 'dashboard', 'bandstand', 'lectern', 'switch', 'baseboard', 'shower room', 'smoke', 'faucet', 'bulldozer', 'saucepan', 'shops', 'meter', 'crevasse', 'gear', 'candelabrum', 'sofa bed', 'tunnel', 'pallet', 'wire', 'kettle', 'bidet', 'baby buggy', 'music stand', 'pipe', 'cup', 'parking meter', 'ice hockey rink', 'shelter', 'weeds', 'temple', 'patty', 'ski slope', 'panel', 'wallet', 'wheel', 'towel rack', 'roundabout', 'canister', 'rod', 'soap 
dispenser', 'bell', 'canvas', 'box office', 'teacup', 'trellis', 'workbench', 'valley', 'toaster', 'knife', 'podium', 'ramp', 'tumble dryer', 'fireplug', 'gym shoe', 'lab bench', 'equipment', 'rocky formation', 'plastic', 'calendar', 'caravan', 'check-in-desk', 'ticket counter', 'brush', 'mill', 'covered bridge', 'bowling alley', 'hanger', 'excavator', 'trestle', 'revolving door', 'blast furnace', 'scale', 'projector', 'soap', 'locker', 'tractor', 'stretcher', 'frame', 'grating', 'alembic', 'candle', 'barrier', 'cardboard', 'cave', 'puddle', 'tarp', 'price tag', 'watchtower', 'meters', 'light bulb', 'tracks', 'hair dryer', 'skirt', 'viaduct', 'paper towel', 'coat', 'sheet', 'fire extinguisher', 'water wheel', 'pottery', 'magazine rack', 'teapot', 'microphone', 'support', 'forklift', 'canyon', 'cash register', 'leaf', 'remote control', 'soap dish', 'windshield', 'cat', 'cue', 'vent', 'videos', 'shovel', 'eaves', 'antenna', 'shipyard', 'hen', 'traffic cone', 'washing machines', 'truck crane', 'cds', 'niche', 'scoreboard', 'briefcase', 'boot', 'sweater', 'hay', 'pack', 'bottle rack', 'glacier', 'pergola', 'building materials', 'television camera', 'first floor', 'rifle', 'tennis table', 'stadium', 'safety belt', 'cover', 'dish rack', 'synthesizer', 'pumpkin', 'gutter', 'fruit stand', 'ice floe', 'handle', 'wheelchair', 'mousepad', 'diploma', 'fairground ride', 'radio', 'hotplate', 'junk', 'wheelbarrow', 'stream', 'toll plaza', 'punching bag', 'trough', 'throne', 'chair desk', 'weighbridge', 'extractor fan', 'hanging clothes', 'dish', 'alarm clock', 'ski lift', 'chain', 'garage', 'mechanical shovel', 'wine rack', 'tramway', 'treadmill', 'menu', 'block', 'well', 'witness stand', 'branch', 'duck', 'casserole', 'frying pan', 'desk organizer', 'mast', 'spectacles', 'service elevator', 'dollhouse', 'hammock', 'clothes hanging', 'photocopier', 'notepad', 'golf cart', 'footpath', 'cross', 'baptismal font', 'boiler', 'skip', 'rotisserie', 'tables', 'water mill', 'helmet', 
'cover curtain', 'brick', 'table runner', 'ashtray', 'street box', 'stick', 'hangers', 'cells', 'urinal', 'centerpiece', 'portable fridge', 'dvds', 'golf club', 'skirting board', 'water cooler', 'clipboard', 'camera', 'pigeonhole', 'chips', 'food processor', 'post box', 'lid', 'drum', 'blender', 'cave entrance', 'dental chair', 'obelisk', 'canoe', 'mobile', 'monitors', 'pool ball', 'cue rack', 'baggage carts', 'shore', 'fork', 'paper filer', 'bicycle rack', 'coat rack', 'garland', 'sports bag', 'fish tank', 'towel dispenser', 'carriage', 'brochure', 'plaque', 'stringer', 'iron', 'spoon', 'flag pole', 'toilet brush', 'book stand', 'water faucet', 'ticket office', 'broom', 'dvd', 'ice bucket', 'carapace', 'tureen', 'folders', 'chess', 'root', 'sewing machine', 'model', 'pen', 'violin', 'sweatshirt', 'recycling materials', 'mitten', 'chopping board', 'mask', 'log', 'mouse', 'grill', 'hole', 'target', 'trash bag', 'chalk', 'sticks', 'balloon', 'score', 'hair spray', 'roll', 'runner', 'engine', 'inflatable glove', 'games', 'pallets', 'baskets', 'coop', 'dvd player', 'rocking horse', 'buckets', 'bread rolls', 'shawl', 'watering can', 'spotlights', 'post-it', 'bowls', 'security camera', 'runner cloth', 'lock', 'alarm', 'side', 'roulette', 'bone', 'cutlery', 'pool balls', 'wheels', 'spice rack', 'plant pots', 'towel ring', 'bread box', 'video', 'funfair', 'breads', 'tripod', 'ironing board', 'skimmer', 'hollow', 'scratching post', 'tricycle', 'file box', 'mountain pass', 'tombstones', 'cooker', 'card game', 'golf bag', 'towel paper', 'chaise lounge', 'sun', 'toilet paper holder', 'rake', 'key', 'umbrella stand', 'dartboard', 'transformer', 'fireplace utensils', 'sweatshirts', 'cellular telephone', 'tallboy', 'stapler', 'sauna', 'test tube', 'palette', 'shopping carts', 'tools', 'push button', 'star', 'roof rack', 'barbed wire', 'spray', 'ear', 'sponge', 'racket', 'tins', 'eyeglasses', 'file', 'scarfs', 'sugar bowl', 'flip flop', 'headstones', 'laptop bag', 'leash', 
'climbing frame', 'suit hanger', 'floor spotlight', 'plate rack', 'sewer', 'hard drive', 'sprinkler', 'tools box', 'necklace', 'bulbs', 'steel industry', 'club', 'jack', 'door bars', 'control panel', 'hairbrush', 'napkin holder', 'office', 'smoke detector', 'utensils', 'apron', 'scissors', 'terminal', 'grinder', 'entry phone', 'newspaper stand', 'pepper shaker', 'onions', 'central processing unit', 'tape', 'bat', 'coaster', 'calculator', 'potatoes', 'luggage rack', 'salt', 'street number', 'viewpoint', 'sword', 'cd', 'rowing machine', 'plug', 'andiron', 'pepper', 'tongs', 'bonfire', 'dog dish', 'belt', 'dumbbells', 'videocassette recorder', 'hook', 'envelopes', 'shower faucet', 'watch', 'padlock', 'swimming pool ladder', 'spanners', 'gravy boat', 'notice board', 'trash bags', 'fire alarm', 'ladle', 'stethoscope', 'rocket', 'funnel', 'bowling pins', 'valve', 'thermometer', 'cups', 'spice jar', 'night light', 'soaps', 'games table', 'slotted spoon', 'reel', 'scourer', 'sleeping robe', 'desk mat', 'dumbbell', 'hammer', 'tie', 'typewriter', 'shaker', 'cheese dish', 'sea star', 'racquet', 'butane gas cylinder', 'paper weight', 'shaving brush', 'sunglasses', 'gear shift', 'towel rail', 'adding machine']\n\nSUN_RGBD_37 = ['wall', 'floor', 'cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window', 'bookshelf', 'picture', 'counter', 'blinds', 'desk', 'shelves', 'curtain', 'dresser', 'pillow', 'mirror', 'floor mat', 'clothes', 'ceiling', 'books', 'refridgerator', 'television', 'paper', 'towel', 'shower curtain', 'box', 'whiteboard', 'person', 'night stand', 'toilet', 'sink', 'lamp', 'bathtub', 'bag']\n\nSCAN_37 = ['wall', 'floor', 'cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window', 'bookshelf', 'picture', 'counter', 'blinds', 'desk', 'shelves', 'curtain', 'dresser', 'pillow', 'mirror', 'floor mat', 'clothes', 'ceiling', 'books', 'refridgerator', 'television', 'paper', 'towel', 'shower curtain', 'box', 'whiteboard', 'person', 'night stand', 'toilet', 'sink', 
'lamp', 'bathtub', 'bag']\nSCAN_40 = ['wall', 'floor', 'cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window', 'bookshelf', 'picture', 'counter', 'blinds', 'desk', 'shelves', 'curtain', 'dresser', 'pillow', 'mirror', 'floor mat', 'clothes', 'ceiling', 'books', 'refridgerator', 'television', 'paper', 'towel', 'shower curtain', 'box', 'whiteboard', 'person', 'night stand', 'toilet', 'sink', 'lamp', 'bathtub', 'bag', 'otherstructure', 'otherfurniture', 'otherprop']\nSCAN_20 = [\"wall\", \"floor\", \"cabinet\", \"bed\", \"chair\", \"sofa\", \"table\", \"door\", \"window\", \"bookshelf\", \"picture\", \"counter\", \"desk\", \"curtain\", \"refrigerator\", \"shower curtain\", \"toilet\", \"sink\", \"bathtub\", \"otherfurniture\"]\n\nCITYSCAPES = ['road', 'sidewalk', 'building', 'wall', 'fence', 'pole', 'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle', 'bicycle']\nCITYSCAPES_THING = [\"person\", \"rider\", \"car\", \"truck\", \"bus\", \"train\", \"motorcycle\", \"bicycle\"]\n\nBDD_SEM = [\"road\", \"sidewalk\", \"building\", \"wall\", \"fence\", \"pole\", \"traffic light\", \"traffic sign\", \"vegetation\", \"terrain\", \"sky\", \"person\", \"rider\", \"car\", \"truck\", \"bus\", \"train\", \"motorcycle\", \"bicycle\"]\nBDD_PANO = ['dynamic', 'ego vehicle', 'ground', 'static', 'parking', 'rail track', 'road', 'sidewalk', 'bridge', 'building', 'fence', 'garage', 'guard rail', 'tunnel', 'wall', 'banner', 'billboard', 'lane divider', 'parking sign', 'pole', 'polegroup', 'street light', 'traffic cone', 'traffic device', 'traffic light', 'traffic sign', 'traffic sign frame', 'terrain', 'vegetation', 'sky', 'person', 'rider', 'bicycle', 'bus', 'car', 'caravan', 'motorcycle', 'trailer', 'train', 'truck']\n\nIMAGENET_CLASSES = [\"tench\", \"goldfish\", \"great white shark\", \"tiger shark\", \"hammerhead shark\", \"electric ray\", \"stingray\", \"rooster\", \"hen\", \"ostrich\", \"brambling\", 
\"goldfinch\", \"house finch\", \"junco\", \"indigo bunting\", \"American robin\", \"bulbul\", \"jay\", \"magpie\", \"chickadee\", \"American dipper\", \"kite (bird of prey)\", \"bald eagle\", \"vulture\", \"great grey owl\", \"fire salamander\", \"smooth newt\", \"newt\", \"spotted salamander\", \"axolotl\", \"American bullfrog\", \"tree frog\", \"tailed frog\", \"loggerhead sea turtle\", \"leatherback sea turtle\", \"mud turtle\", \"terrapin\", \"box turtle\", \"banded gecko\", \"green iguana\", \"Carolina anole\", \"desert grassland whiptail lizard\", \"agama\", \"frilled-necked lizard\", \"alligator lizard\", \"Gila monster\", \"European green lizard\", \"chameleon\", \"Komodo dragon\", \"Nile crocodile\", \"American alligator\", \"triceratops\", \"worm snake\", \"ring-necked snake\", \"eastern hog-nosed snake\", \"smooth green snake\", \"kingsnake\", \"garter snake\", \"water snake\", \"vine snake\", \"night snake\", \"boa constrictor\", \"African rock python\", \"Indian cobra\", \"green mamba\", \"sea snake\", \"Saharan horned viper\", \"eastern diamondback rattlesnake\", \"sidewinder rattlesnake\", \"trilobite\", \"harvestman\", \"scorpion\", \"yellow garden spider\", \"barn spider\", \"European garden spider\", \"southern black widow\", \"tarantula\", \"wolf spider\", \"tick\", \"centipede\", \"black grouse\", \"ptarmigan\", \"ruffed grouse\", \"prairie grouse\", \"peafowl\", \"quail\", \"partridge\", \"african grey parrot\", \"macaw\", \"sulphur-crested cockatoo\", \"lorikeet\", \"coucal\", \"bee eater\", \"hornbill\", \"hummingbird\", \"jacamar\", \"toucan\", \"duck\", \"red-breasted merganser\", \"goose\", \"black swan\", \"tusker\", \"echidna\", \"platypus\", \"wallaby\", \"koala\", \"wombat\", \"jellyfish\", \"sea anemone\", \"brain coral\", \"flatworm\", \"nematode\", \"conch\", \"snail\", \"slug\", \"sea slug\", \"chiton\", \"chambered nautilus\", \"Dungeness crab\", \"rock crab\", \"fiddler crab\", \"red king crab\", \"American lobster\", \"spiny 
lobster\", \"crayfish\", \"hermit crab\", \"isopod\", \"white stork\", \"black stork\", \"spoonbill\", \"flamingo\", \"little blue heron\", \"great egret\", \"bittern bird\", \"crane bird\", \"limpkin\", \"common gallinule\", \"American coot\", \"bustard\", \"ruddy turnstone\", \"dunlin\", \"common redshank\", \"dowitcher\", \"oystercatcher\", \"pelican\", \"king penguin\", \"albatross\", \"grey whale\", \"killer whale\", \"dugong\", \"sea lion\", \"Chihuahua\", \"Japanese Chin\", \"Maltese\", \"Pekingese\", \"Shih Tzu\", \"King Charles Spaniel\", \"Papillon\", \"toy terrier\", \"Rhodesian Ridgeback\", \"Afghan Hound\", \"Basset Hound\", \"Beagle\", \"Bloodhound\", \"Bluetick Coonhound\", \"Black and Tan Coonhound\", \"Treeing Walker Coonhound\", \"English foxhound\", \"Redbone Coonhound\", \"borzoi\", \"Irish Wolfhound\", \"Italian Greyhound\", \"Whippet\", \"Ibizan Hound\", \"Norwegian Elkhound\", \"Otterhound\", \"Saluki\", \"Scottish Deerhound\", \"Weimaraner\", \"Staffordshire Bull Terrier\", \"American Staffordshire Terrier\", \"Bedlington Terrier\", \"Border Terrier\", \"Kerry Blue Terrier\", \"Irish Terrier\", \"Norfolk Terrier\", \"Norwich Terrier\", \"Yorkshire Terrier\", \"Wire Fox Terrier\", \"Lakeland Terrier\", \"Sealyham Terrier\", \"Airedale Terrier\", \"Cairn Terrier\", \"Australian Terrier\", \"Dandie Dinmont Terrier\", \"Boston Terrier\", \"Miniature Schnauzer\", \"Giant Schnauzer\", \"Standard Schnauzer\", \"Scottish Terrier\", \"Tibetan Terrier\", \"Australian Silky Terrier\", \"Soft-coated Wheaten Terrier\", \"West Highland White Terrier\", \"Lhasa Apso\", \"Flat-Coated Retriever\", \"Curly-coated Retriever\", \"Golden Retriever\", \"Labrador Retriever\", \"Chesapeake Bay Retriever\", \"German Shorthaired Pointer\", \"Vizsla\", \"English Setter\", \"Irish Setter\", \"Gordon Setter\", \"Brittany dog\", \"Clumber Spaniel\", \"English Springer Spaniel\", \"Welsh Springer Spaniel\", \"Cocker Spaniel\", \"Sussex Spaniel\", \"Irish Water Spaniel\", 
\"Kuvasz\", \"Schipperke\", \"Groenendael dog\", \"Malinois\", \"Briard\", \"Australian Kelpie\", \"Komondor\", \"Old English Sheepdog\", \"Shetland Sheepdog\", \"collie\", \"Border Collie\", \"Bouvier des Flandres dog\", \"Rottweiler\", \"German Shepherd Dog\", \"Dobermann\", \"Miniature Pinscher\", \"Greater Swiss Mountain Dog\", \"Bernese Mountain Dog\", \"Appenzeller Sennenhund\", \"Entlebucher Sennenhund\", \"Boxer\", \"Bullmastiff\", \"Tibetan Mastiff\", \"French Bulldog\", \"Great Dane\", \"St. Bernard\", \"husky\", \"Alaskan Malamute\", \"Siberian Husky\", \"Dalmatian\", \"Affenpinscher\", \"Basenji\", \"pug\", \"Leonberger\", \"Newfoundland dog\", \"Great Pyrenees dog\", \"Samoyed\", \"Pomeranian\", \"Chow Chow\", \"Keeshond\", \"brussels griffon\", \"Pembroke Welsh Corgi\", \"Cardigan Welsh Corgi\", \"Toy Poodle\", \"Miniature Poodle\", \"Standard Poodle\", \"Mexican hairless dog (xoloitzcuintli)\", \"grey wolf\", \"Alaskan tundra wolf\", \"red wolf or maned wolf\", \"coyote\", \"dingo\", \"dhole\", \"African wild dog\", \"hyena\", \"red fox\", \"kit fox\", \"Arctic fox\", \"grey fox\", \"tabby cat\", \"tiger cat\", \"Persian cat\", \"Siamese cat\", \"Egyptian Mau\", \"cougar\", \"lynx\", \"leopard\", \"snow leopard\", \"jaguar\", \"lion\", \"tiger\", \"cheetah\", \"brown bear\", \"American black bear\", \"polar bear\", \"sloth bear\", \"mongoose\", \"meerkat\", \"tiger beetle\", \"ladybug\", \"ground beetle\", \"longhorn beetle\", \"leaf beetle\", \"dung beetle\", \"rhinoceros beetle\", \"weevil\", \"fly\", \"bee\", \"ant\", \"grasshopper\", \"cricket insect\", \"stick insect\", \"cockroach\", \"praying mantis\", \"cicada\", \"leafhopper\", \"lacewing\", \"dragonfly\", \"damselfly\", \"red admiral butterfly\", \"ringlet butterfly\", \"monarch butterfly\", \"small white butterfly\", \"sulphur butterfly\", \"gossamer-winged butterfly\", \"starfish\", \"sea urchin\", \"sea cucumber\", \"cottontail rabbit\", \"hare\", \"Angora rabbit\", \"hamster\", 
\"porcupine\", \"fox squirrel\", \"marmot\", \"beaver\", \"guinea pig\", \"common sorrel horse\", \"zebra\", \"pig\", \"wild boar\", \"warthog\", \"hippopotamus\", \"ox\", \"water buffalo\", \"bison\", \"ram (adult male sheep)\", \"bighorn sheep\", \"Alpine ibex\", \"hartebeest\", \"impala (antelope)\", \"gazelle\", \"arabian camel\", \"llama\", \"weasel\", \"mink\", \"European polecat\", \"black-footed ferret\", \"otter\", \"skunk\", \"badger\", \"armadillo\", \"three-toed sloth\", \"orangutan\", \"gorilla\", \"chimpanzee\", \"gibbon\", \"siamang\", \"guenon\", \"patas monkey\", \"baboon\", \"macaque\", \"langur\", \"black-and-white colobus\", \"proboscis monkey\", \"marmoset\", \"white-headed capuchin\", \"howler monkey\", \"titi monkey\", \"Geoffroy's spider monkey\", \"common squirrel monkey\", \"ring-tailed lemur\", \"indri\", \"Asian elephant\", \"African bush elephant\", \"red panda\", \"giant panda\", \"snoek fish\", \"eel\", \"silver salmon\", \"rock beauty fish\", \"clownfish\", \"sturgeon\", \"gar fish\", \"lionfish\", \"pufferfish\", \"abacus\", \"abaya\", \"academic gown\", \"accordion\", \"acoustic guitar\", \"aircraft carrier\", \"airliner\", \"airship\", \"altar\", \"ambulance\", \"amphibious vehicle\", \"analog clock\", \"apiary\", \"apron\", \"trash can\", \"assault rifle\", \"backpack\", \"bakery\", \"balance beam\", \"balloon\", \"ballpoint pen\", \"Band-Aid\", \"banjo\", \"baluster / handrail\", \"barbell\", \"barber chair\", \"barbershop\", \"barn\", \"barometer\", \"barrel\", \"wheelbarrow\", \"baseball\", \"basketball\", \"bassinet\", \"bassoon\", \"swimming cap\", \"bath towel\", \"bathtub\", \"station wagon\", \"lighthouse\", \"beaker\", \"military hat (bearskin or shako)\", \"beer bottle\", \"beer glass\", \"bell tower\", \"baby bib\", \"tandem bicycle\", \"bikini\", \"ring binder\", \"binoculars\", \"birdhouse\", \"boathouse\", \"bobsleigh\", \"bolo tie\", \"poke bonnet\", \"bookcase\", \"bookstore\", \"bottle cap\", \"hunting bow\", 
\"bow tie\", \"brass memorial plaque\", \"bra\", \"breakwater\", \"breastplate\", \"broom\", \"bucket\", \"buckle\", \"bulletproof vest\", \"high-speed train\", \"butcher shop\", \"taxicab\", \"cauldron\", \"candle\", \"cannon\", \"canoe\", \"can opener\", \"cardigan\", \"car mirror\", \"carousel\", \"tool kit\", \"cardboard box / carton\", \"car wheel\", \"automated teller machine\", \"cassette\", \"cassette player\", \"castle\", \"catamaran\", \"CD player\", \"cello\", \"mobile phone\", \"chain\", \"chain-link fence\", \"chain mail\", \"chainsaw\", \"storage chest\", \"chiffonier\", \"bell or wind chime\", \"china cabinet\", \"Christmas stocking\", \"church\", \"movie theater\", \"cleaver\", \"cliff dwelling\", \"cloak\", \"clogs\", \"cocktail shaker\", \"coffee mug\", \"coffeemaker\", \"spiral or coil\", \"combination lock\", \"computer keyboard\", \"candy store\", \"container ship\", \"convertible\", \"corkscrew\", \"cornet\", \"cowboy boot\", \"cowboy hat\", \"cradle\", \"construction crane\", \"crash helmet\", \"crate\", \"infant bed\", \"Crock Pot\", \"croquet ball\", \"crutch\", \"cuirass\", \"dam\", \"desk\", \"desktop computer\", \"rotary dial telephone\", \"diaper\", \"digital clock\", \"digital watch\", \"dining table\", \"dishcloth\", \"dishwasher\", \"disc brake\", \"dock\", \"dog sled\", \"dome\", \"doormat\", \"drilling rig\", \"drum\", \"drumstick\", \"dumbbell\", \"Dutch oven\", \"electric fan\", \"electric guitar\", \"electric locomotive\", \"entertainment center\", \"envelope\", \"espresso machine\", \"face powder\", \"feather boa\", \"filing cabinet\", \"fireboat\", \"fire truck\", \"fire screen\", \"flagpole\", \"flute\", \"folding chair\", \"football helmet\", \"forklift\", \"fountain\", \"fountain pen\", \"four-poster bed\", \"freight car\", \"French horn\", \"frying pan\", \"fur coat\", \"garbage truck\", \"gas mask or respirator\", \"gas pump\", \"goblet\", \"go-kart\", \"golf ball\", \"golf cart\", \"gondola\", \"gong\", \"gown\", \"grand 
piano\", \"greenhouse\", \"radiator grille\", \"grocery store\", \"guillotine\", \"hair clip\", \"hair spray\", \"half-track\", \"hammer\", \"hamper\", \"hair dryer\", \"hand-held computer\", \"handkerchief\", \"hard disk drive\", \"harmonica\", \"harp\", \"combine harvester\", \"hatchet\", \"holster\", \"home theater\", \"honeycomb\", \"hook\", \"hoop skirt\", \"gymnastic horizontal bar\", \"horse-drawn vehicle\", \"hourglass\", \"iPod\", \"clothes iron\", \"carved pumpkin\", \"jeans\", \"jeep\", \"T-shirt\", \"jigsaw puzzle\", \"rickshaw\", \"joystick\", \"kimono\", \"knee pad\", \"knot\", \"lab coat\", \"ladle\", \"lampshade\", \"laptop computer\", \"lawn mower\", \"lens cap\", \"letter opener\", \"library\", \"lifeboat\", \"lighter\", \"limousine\", \"ocean liner\", \"lipstick\", \"slip-on shoe\", \"lotion\", \"music speaker\", \"loupe magnifying glass\", \"sawmill\", \"magnetic compass\", \"messenger bag\", \"mailbox\", \"tights\", \"one-piece bathing suit\", \"manhole cover\", \"maraca\", \"marimba\", \"mask\", \"matchstick\", \"maypole\", \"maze\", \"measuring cup\", \"medicine cabinet\", \"megalith\", \"microphone\", \"microwave oven\", \"military uniform\", \"milk can\", \"minibus\", \"miniskirt\", \"minivan\", \"missile\", \"mitten\", \"mixing bowl\", \"mobile home\", \"ford model t\", \"modem\", \"monastery\", \"monitor\", \"moped\", \"mortar and pestle\", \"graduation cap\", \"mosque\", \"mosquito net\", \"vespa\", \"mountain bike\", \"tent\", \"computer mouse\", \"mousetrap\", \"moving van\", \"muzzle\", \"metal nail\", \"neck brace\", \"necklace\", \"baby pacifier\", \"notebook computer\", \"obelisk\", \"oboe\", \"ocarina\", \"odometer\", \"oil filter\", \"pipe organ\", \"oscilloscope\", \"overskirt\", \"bullock cart\", \"oxygen mask\", \"product packet / packaging\", \"paddle\", \"paddle wheel\", \"padlock\", \"paintbrush\", \"pajamas\", \"palace\", \"pan flute\", \"paper towel\", \"parachute\", \"parallel bars\", \"park bench\", \"parking meter\", 
\"railroad car\", \"patio\", \"payphone\", \"pedestal\", \"pencil case\", \"pencil sharpener\", \"perfume\", \"Petri dish\", \"photocopier\", \"plectrum\", \"Pickelhaube\", \"picket fence\", \"pickup truck\", \"pier\", \"piggy bank\", \"pill bottle\", \"pillow\", \"ping-pong ball\", \"pinwheel\", \"pirate ship\", \"drink pitcher\", \"block plane\", \"planetarium\", \"plastic bag\", \"plate rack\", \"farm plow\", \"plunger\", \"Polaroid camera\", \"pole\", \"police van\", \"poncho\", \"pool table\", \"soda bottle\", \"plant pot\", \"potter's wheel\", \"power drill\", \"prayer rug\", \"printer\", \"prison\", \"projectile\", \"projector\", \"hockey puck\", \"punching bag\", \"purse\", \"quill\", \"quilt\", \"race car\", \"racket\", \"radiator\", \"radio\", \"radio telescope\", \"rain barrel\", \"recreational vehicle\", \"fishing casting reel\", \"reflex camera\", \"refrigerator\", \"remote control\", \"restaurant\", \"revolver\", \"rifle\", \"rocking chair\", \"rotisserie\", \"eraser\", \"rugby ball\", \"ruler measuring stick\", \"sneaker\", \"safe\", \"safety pin\", \"salt shaker\", \"sandal\", \"sarong\", \"saxophone\", \"scabbard\", \"weighing scale\", \"school bus\", \"schooner\", \"scoreboard\", \"CRT monitor\", \"screw\", \"screwdriver\", \"seat belt\", \"sewing machine\", \"shield\", \"shoe store\", \"shoji screen / room divider\", \"shopping basket\", \"shopping cart\", \"shovel\", \"shower cap\", \"shower curtain\", \"ski\", \"balaclava ski mask\", \"sleeping bag\", \"slide rule\", \"sliding door\", \"slot machine\", \"snorkel\", \"snowmobile\", \"snowplow\", \"soap dispenser\", \"soccer ball\", \"sock\", \"solar thermal collector\", \"sombrero\", \"soup bowl\", \"keyboard space bar\", \"space heater\", \"space shuttle\", \"spatula\", \"motorboat\", \"spider web\", \"spindle\", \"sports car\", \"spotlight\", \"stage\", \"steam locomotive\", \"through arch bridge\", \"steel drum\", \"stethoscope\", \"scarf\", \"stone wall\", \"stopwatch\", \"stove\", 
\"strainer\", \"tram\", \"stretcher\", \"couch\", \"stupa\", \"submarine\", \"suit\", \"sundial\", \"sunglasses\", \"dark glasses\", \"sunscreen\", \"suspension bridge\", \"mop\", \"sweatshirt\", \"swim trunks / shorts\", \"swing\", \"electrical switch\", \"syringe\", \"table lamp\", \"tank\", \"tape player\", \"teapot\", \"teddy bear\", \"television\", \"tennis ball\", \"thatched roof\", \"front curtain\", \"thimble\", \"threshing machine\", \"throne\", \"tile roof\", \"toaster\", \"tobacco shop\", \"toilet seat\", \"torch\", \"totem pole\", \"tow truck\", \"toy store\", \"tractor\", \"semi-trailer truck\", \"tray\", \"trench coat\", \"tricycle\", \"trimaran\", \"tripod\", \"triumphal arch\", \"trolleybus\", \"trombone\", \"hot tub\", \"turnstile\", \"typewriter keyboard\", \"umbrella\", \"unicycle\", \"upright piano\", \"vacuum cleaner\", \"vase\", \"vaulted or arched ceiling\", \"velvet fabric\", \"vending machine\", \"vestment\", \"viaduct\", \"violin\", \"volleyball\", \"waffle iron\", \"wall clock\", \"wallet\", \"wardrobe\", \"military aircraft\", \"sink\", \"washing machine\", \"water bottle\", \"water jug\", \"water tower\", \"whiskey jug\", \"whistle\", \"hair wig\", \"window screen\", \"window shade\", \"Windsor tie\", \"wine bottle\", \"airplane wing\", \"wok\", \"wooden spoon\", \"wool\", \"split-rail fence\", \"shipwreck\", \"sailboat\", \"yurt\", \"website\", \"comic book\", \"crossword\", \"traffic or street sign\", \"traffic light\", \"dust jacket\", \"menu\", \"plate\", \"guacamole\", \"consomme\", \"hot pot\", \"trifle\", \"ice cream\", \"popsicle\", \"baguette\", \"bagel\", \"pretzel\", \"cheeseburger\", \"hot dog\", \"mashed potatoes\", \"cabbage\", \"broccoli\", \"cauliflower\", \"zucchini\", \"spaghetti squash\", \"acorn squash\", \"butternut squash\", \"cucumber\", \"artichoke\", \"bell pepper\", \"cardoon\", \"mushroom\", \"Granny Smith apple\", \"strawberry\", \"orange\", \"lemon\", \"fig\", \"pineapple\", \"banana\", \"jackfruit\", 
\"cherimoya (custard apple)\", \"pomegranate\", \"hay\", \"carbonara\", \"chocolate syrup\", \"dough\", \"meatloaf\", \"pizza\", \"pot pie\", \"burrito\", \"red wine\", \"espresso\", \"tea cup\", \"eggnog\", \"mountain\", \"bubble\", \"cliff\", \"coral reef\", \"geyser\", \"lakeshore\", \"promontory\", \"sandbar\", \"beach\", \"valley\", \"volcano\", \"baseball player\", \"bridegroom\", \"scuba diver\", \"rapeseed\", \"daisy\", \"yellow lady's slipper\", \"corn\", \"acorn\", \"rose hip\", \"horse chestnut seed\", \"coral fungus\", \"agaric\", \"gyromitra\", \"stinkhorn mushroom\", \"earth star fungus\", \"hen of the woods mushroom\", \"bolete\", \"corn cob\", \"toilet paper\"]\nIMAGENET_FOLDER_NAMES = ['n01440764', 'n01443537', 'n01484850', 'n01491361', 'n01494475', 'n01496331', 'n01498041', 'n01514668', 'n01514859', 'n01518878', 'n01530575', 'n01531178', 'n01532829', 'n01534433', 'n01537544', 'n01558993', 'n01560419', 'n01580077', 'n01582220', 'n01592084', 'n01601694', 'n01608432', 'n01614925', 'n01616318', 'n01622779', 'n01629819', 'n01630670', 'n01631663', 'n01632458', 'n01632777', 'n01641577', 'n01644373', 'n01644900', 'n01664065', 'n01665541', 'n01667114', 'n01667778', 'n01669191', 'n01675722', 'n01677366', 'n01682714', 'n01685808', 'n01687978', 'n01688243', 'n01689811', 'n01692333', 'n01693334', 'n01694178', 'n01695060', 'n01697457', 'n01698640', 'n01704323', 'n01728572', 'n01728920', 'n01729322', 'n01729977', 'n01734418', 'n01735189', 'n01737021', 'n01739381', 'n01740131', 'n01742172', 'n01744401', 'n01748264', 'n01749939', 'n01751748', 'n01753488', 'n01755581', 'n01756291', 'n01768244', 'n01770081', 'n01770393', 'n01773157', 'n01773549', 'n01773797', 'n01774384', 'n01774750', 'n01775062', 'n01776313', 'n01784675', 'n01795545', 'n01796340', 'n01797886', 'n01798484', 'n01806143', 'n01806567', 'n01807496', 'n01817953', 'n01818515', 'n01819313', 'n01820546', 'n01824575', 'n01828970', 'n01829413', 'n01833805', 'n01843065', 'n01843383', 'n01847000', 'n01855032', 
'n01855672', 'n01860187', 'n01871265', 'n01872401', 'n01873310', 'n01877812', 'n01882714', 'n01883070', 'n01910747', 'n01914609', 'n01917289', 'n01924916', 'n01930112', 'n01943899', 'n01944390', 'n01945685', 'n01950731', 'n01955084', 'n01968897', 'n01978287', 'n01978455', 'n01980166', 'n01981276', 'n01983481', 'n01984695', 'n01985128', 'n01986214', 'n01990800', 'n02002556', 'n02002724', 'n02006656', 'n02007558', 'n02009229', 'n02009912', 'n02011460', 'n02012849', 'n02013706', 'n02017213', 'n02018207', 'n02018795', 'n02025239', 'n02027492', 'n02028035', 'n02033041', 'n02037110', 'n02051845', 'n02056570', 'n02058221', 'n02066245', 'n02071294', 'n02074367', 'n02077923', 'n02085620', 'n02085782', 'n02085936', 'n02086079', 'n02086240', 'n02086646', 'n02086910', 'n02087046', 'n02087394', 'n02088094', 'n02088238', 'n02088364', 'n02088466', 'n02088632', 'n02089078', 'n02089867', 'n02089973', 'n02090379', 'n02090622', 'n02090721', 'n02091032', 'n02091134', 'n02091244', 'n02091467', 'n02091635', 'n02091831', 'n02092002', 'n02092339', 'n02093256', 'n02093428', 'n02093647', 'n02093754', 'n02093859', 'n02093991', 'n02094114', 'n02094258', 'n02094433', 'n02095314', 'n02095570', 'n02095889', 'n02096051', 'n02096177', 'n02096294', 'n02096437', 'n02096585', 'n02097047', 'n02097130', 'n02097209', 'n02097298', 'n02097474', 'n02097658', 'n02098105', 'n02098286', 'n02098413', 'n02099267', 'n02099429', 'n02099601', 'n02099712', 'n02099849', 'n02100236', 'n02100583', 'n02100735', 'n02100877', 'n02101006', 'n02101388', 'n02101556', 'n02102040', 'n02102177', 'n02102318', 'n02102480', 'n02102973', 'n02104029', 'n02104365', 'n02105056', 'n02105162', 'n02105251', 'n02105412', 'n02105505', 'n02105641', 'n02105855', 'n02106030', 'n02106166', 'n02106382', 'n02106550', 'n02106662', 'n02107142', 'n02107312', 'n02107574', 'n02107683', 'n02107908', 'n02108000', 'n02108089', 'n02108422', 'n02108551', 'n02108915', 'n02109047', 'n02109525', 'n02109961', 'n02110063', 'n02110185', 'n02110341', 
'n02110627', 'n02110806', 'n02110958', 'n02111129', 'n02111277', 'n02111500', 'n02111889', 'n02112018', 'n02112137', 'n02112350', 'n02112706', 'n02113023', 'n02113186', 'n02113624', 'n02113712', 'n02113799', 'n02113978', 'n02114367', 'n02114548', 'n02114712', 'n02114855', 'n02115641', 'n02115913', 'n02116738', 'n02117135', 'n02119022', 'n02119789', 'n02120079', 'n02120505', 'n02123045', 'n02123159', 'n02123394', 'n02123597', 'n02124075', 'n02125311', 'n02127052', 'n02128385', 'n02128757', 'n02128925', 'n02129165', 'n02129604', 'n02130308', 'n02132136', 'n02133161', 'n02134084', 'n02134418', 'n02137549', 'n02138441', 'n02165105', 'n02165456', 'n02167151', 'n02168699', 'n02169497', 'n02172182', 'n02174001', 'n02177972', 'n02190166', 'n02206856', 'n02219486', 'n02226429', 'n02229544', 'n02231487', 'n02233338', 'n02236044', 'n02256656', 'n02259212', 'n02264363', 'n02268443', 'n02268853', 'n02276258', 'n02277742', 'n02279972', 'n02280649', 'n02281406', 'n02281787', 'n02317335', 'n02319095', 'n02321529', 'n02325366', 'n02326432', 'n02328150', 'n02342885', 'n02346627', 'n02356798', 'n02361337', 'n02363005', 'n02364673', 'n02389026', 'n02391049', 'n02395406', 'n02396427', 'n02397096', 'n02398521', 'n02403003', 'n02408429', 'n02410509', 'n02412080', 'n02415577', 'n02417914', 'n02422106', 'n02422699', 'n02423022', 'n02437312', 'n02437616', 'n02441942', 'n02442845', 'n02443114', 'n02443484', 'n02444819', 'n02445715', 'n02447366', 'n02454379', 'n02457408', 'n02480495', 'n02480855', 'n02481823', 'n02483362', 'n02483708', 'n02484975', 'n02486261', 'n02486410', 'n02487347', 'n02488291', 'n02488702', 'n02489166', 'n02490219', 'n02492035', 'n02492660', 'n02493509', 'n02493793', 'n02494079', 'n02497673', 'n02500267', 'n02504013', 'n02504458', 'n02509815', 'n02510455', 'n02514041', 'n02526121', 'n02536864', 'n02606052', 'n02607072', 'n02640242', 'n02641379', 'n02643566', 'n02655020', 'n02666196', 'n02667093', 'n02669723', 'n02672831', 'n02676566', 'n02687172', 'n02690373', 
'n02692877', 'n02699494', 'n02701002', 'n02704792', 'n02708093', 'n02727426', 'n02730930', 'n02747177', 'n02749479', 'n02769748', 'n02776631', 'n02777292', 'n02782093', 'n02783161', 'n02786058', 'n02787622', 'n02788148', 'n02790996', 'n02791124', 'n02791270', 'n02793495', 'n02794156', 'n02795169', 'n02797295', 'n02799071', 'n02802426', 'n02804414', 'n02804610', 'n02807133', 'n02808304', 'n02808440', 'n02814533', 'n02814860', 'n02815834', 'n02817516', 'n02823428', 'n02823750', 'n02825657', 'n02834397', 'n02835271', 'n02837789', 'n02840245', 'n02841315', 'n02843684', 'n02859443', 'n02860847', 'n02865351', 'n02869837', 'n02870880', 'n02871525', 'n02877765', 'n02879718', 'n02883205', 'n02892201', 'n02892767', 'n02894605', 'n02895154', 'n02906734', 'n02909870', 'n02910353', 'n02916936', 'n02917067', 'n02927161', 'n02930766', 'n02939185', 'n02948072', 'n02950826', 'n02951358', 'n02951585', 'n02963159', 'n02965783', 'n02966193', 'n02966687', 'n02971356', 'n02974003', 'n02977058', 'n02978881', 'n02979186', 'n02980441', 'n02981792', 'n02988304', 'n02992211', 'n02992529', 'n02999410', 'n03000134', 'n03000247', 'n03000684', 'n03014705', 'n03016953', 'n03017168', 'n03018349', 'n03026506', 'n03028079', 'n03032252', 'n03041632', 'n03042490', 'n03045698', 'n03047690', 'n03062245', 'n03063599', 'n03063689', 'n03065424', 'n03075370', 'n03085013', 'n03089624', 'n03095699', 'n03100240', 'n03109150', 'n03110669', 'n03124043', 'n03124170', 'n03125729', 'n03126707', 'n03127747', 'n03127925', 'n03131574', 'n03133878', 'n03134739', 'n03141823', 'n03146219', 'n03160309', 'n03179701', 'n03180011', 'n03187595', 'n03188531', 'n03196217', 'n03197337', 'n03201208', 'n03207743', 'n03207941', 'n03208938', 'n03216828', 'n03218198', 'n03220513', 'n03223299', 'n03240683', 'n03249569', 'n03250847', 'n03255030', 'n03259280', 'n03271574', 'n03272010', 'n03272562', 'n03290653', 'n03291819', 'n03297495', 'n03314780', 'n03325584', 'n03337140', 'n03344393', 'n03345487', 'n03347037', 'n03355925', 
'n03372029', 'n03376595', 'n03379051', 'n03384352', 'n03388043', 'n03388183', 'n03388549', 'n03393912', 'n03394916', 'n03400231', 'n03404251', 'n03417042', 'n03424325', 'n03425413', 'n03443371', 'n03444034', 'n03445777', 'n03445924', 'n03447447', 'n03447721', 'n03450230', 'n03452741', 'n03457902', 'n03459775', 'n03461385', 'n03467068', 'n03476684', 'n03476991', 'n03478589', 'n03481172', 'n03482405', 'n03483316', 'n03485407', 'n03485794', 'n03492542', 'n03494278', 'n03495258', 'n03496892', 'n03498962', 'n03527444', 'n03529860', 'n03530642', 'n03532672', 'n03534580', 'n03535780', 'n03538406', 'n03544143', 'n03584254', 'n03584829', 'n03590841', 'n03594734', 'n03594945', 'n03595614', 'n03598930', 'n03599486', 'n03602883', 'n03617480', 'n03623198', 'n03627232', 'n03630383', 'n03633091', 'n03637318', 'n03642806', 'n03649909', 'n03657121', 'n03658185', 'n03661043', 'n03662601', 'n03666591', 'n03670208', 'n03673027', 'n03676483', 'n03680355', 'n03690938', 'n03691459', 'n03692522', 'n03697007', 'n03706229', 'n03709823', 'n03710193', 'n03710637', 'n03710721', 'n03717622', 'n03720891', 'n03721384', 'n03724870', 'n03729826', 'n03733131', 'n03733281', 'n03733805', 'n03742115', 'n03743016', 'n03759954', 'n03761084', 'n03763968', 'n03764736', 'n03769881', 'n03770439', 'n03770679', 'n03773504', 'n03775071', 'n03775546', 'n03776460', 'n03777568', 'n03777754', 'n03781244', 'n03782006', 'n03785016', 'n03786901', 'n03787032', 'n03788195', 'n03788365', 'n03791053', 'n03792782', 'n03792972', 'n03793489', 'n03794056', 'n03796401', 'n03803284', 'n03804744', 'n03814639', 'n03814906', 'n03825788', 'n03832673', 'n03837869', 'n03838899', 'n03840681', 'n03841143', 'n03843555', 'n03854065', 'n03857828', 'n03866082', 'n03868242', 'n03868863', 'n03871628', 'n03873416', 'n03874293', 'n03874599', 'n03876231', 'n03877472', 'n03877845', 'n03884397', 'n03887697', 'n03888257', 'n03888605', 'n03891251', 'n03891332', 'n03895866', 'n03899768', 'n03902125', 'n03903868', 'n03908618', 'n03908714', 
'n03916031', 'n03920288', 'n03924679', 'n03929660', 'n03929855', 'n03930313', 'n03930630', 'n03933933', 'n03935335', 'n03937543', 'n03938244', 'n03942813', 'n03944341', 'n03947888', 'n03950228', 'n03954731', 'n03956157', 'n03958227', 'n03961711', 'n03967562', 'n03970156', 'n03976467', 'n03976657', 'n03977966', 'n03980874', 'n03982430', 'n03983396', 'n03991062', 'n03992509', 'n03995372', 'n03998194', 'n04004767', 'n04005630', 'n04008634', 'n04009552', 'n04019541', 'n04023962', 'n04026417', 'n04033901', 'n04033995', 'n04037443', 'n04039381', 'n04040759', 'n04041544', 'n04044716', 'n04049303', 'n04065272', 'n04067472', 'n04069434', 'n04070727', 'n04074963', 'n04081281', 'n04086273', 'n04090263', 'n04099969', 'n04111531', 'n04116512', 'n04118538', 'n04118776', 'n04120489', 'n04125021', 'n04127249', 'n04131690', 'n04133789', 'n04136333', 'n04141076', 'n04141327', 'n04141975', 'n04146614', 'n04147183', 'n04149813', 'n04152593', 'n04153751', 'n04154565', 'n04162706', 'n04179913', 'n04192698', 'n04200800', 'n04201297', 'n04204238', 'n04204347', 'n04208210', 'n04209133', 'n04209239', 'n04228054', 'n04229816', 'n04235860', 'n04238763', 'n04239074', 'n04243546', 'n04251144', 'n04252077', 'n04252225', 'n04254120', 'n04254680', 'n04254777', 'n04258138', 'n04259630', 'n04263257', 'n04264628', 'n04265275', 'n04266014', 'n04270147', 'n04273569', 'n04275548', 'n04277352', 'n04285008', 'n04286575', 'n04296562', 'n04310018', 'n04311004', 'n04311174', 'n04317175', 'n04325704', 'n04326547', 'n04328186', 'n04330267', 'n04332243', 'n04335435', 'n04336792', 'n04344873', 'n04346328', 'n04347754', 'n04350905', 'n04355338', 'n04355933', 'n04356056', 'n04357314', 'n04366367', 'n04367480', 'n04370456', 'n04371430', 'n04371774', 'n04372370', 'n04376876', 'n04380533', 'n04389033', 'n04392985', 'n04398044', 'n04399382', 'n04404412', 'n04409515', 'n04417672', 'n04418357', 'n04423845', 'n04428191', 'n04429376', 'n04435653', 'n04442312', 'n04443257', 'n04447861', 'n04456115', 'n04458633', 
'n04461696', 'n04462240', 'n04465501', 'n04467665', 'n04476259', 'n04479046', 'n04482393', 'n04483307', 'n04485082', 'n04486054', 'n04487081', 'n04487394', 'n04493381', 'n04501370', 'n04505470', 'n04507155', 'n04509417', 'n04515003', 'n04517823', 'n04522168', 'n04523525', 'n04525038', 'n04525305', 'n04532106', 'n04532670', 'n04536866', 'n04540053', 'n04542943', 'n04548280', 'n04548362', 'n04550184', 'n04552348', 'n04553703', 'n04554684', 'n04557648', 'n04560804', 'n04562935', 'n04579145', 'n04579432', 'n04584207', 'n04589890', 'n04590129', 'n04591157', 'n04591713', 'n04592741', 'n04596742', 'n04597913', 'n04599235', 'n04604644', 'n04606251', 'n04612504', 'n04613696', 'n06359193', 'n06596364', 'n06785654', 'n06794110', 'n06874185', 'n07248320', 'n07565083', 'n07579787', 'n07583066', 'n07584110', 'n07590611', 'n07613480', 'n07614500', 'n07615774', 'n07684084', 'n07693725', 'n07695742', 'n07697313', 'n07697537', 'n07711569', 'n07714571', 'n07714990', 'n07715103', 'n07716358', 'n07716906', 'n07717410', 'n07717556', 'n07718472', 'n07718747', 'n07720875', 'n07730033', 'n07734744', 'n07742313', 'n07745940', 'n07747607', 'n07749582', 'n07753113', 'n07753275', 'n07753592', 'n07754684', 'n07760859', 'n07768694', 'n07802026', 'n07831146', 'n07836838', 'n07860988', 'n07871810', 'n07873807', 'n07875152', 'n07880968', 'n07892512', 'n07920052', 'n07930864', 'n07932039', 'n09193705', 'n09229709', 'n09246464', 'n09256479', 'n09288635', 'n09332890', 'n09399592', 'n09421951', 'n09428293', 'n09468604', 'n09472597', 'n09835506', 'n10148035', 'n10565667', 'n11879895', 'n11939491', 'n12057211', 'n12144580', 'n12267677', 'n12620546', 'n12768682', 'n12985857', 'n12998815', 'n13037406', 'n13040303', 'n13044778', 'n13052670', 'n13054560', 'n13133613', 'n15075141']\nIMAGENET_DEFAULT_TEMPLATES = [\n    '{}.',\n    'a bad photo of a {}.',\n    'a photo of many {}.',\n    'a sculpture of a {}.',\n    'a photo of the hard to see {}.',\n    'a low resolution photo of the {}.',\n    'a rendering of 
a {}.',\n    'graffiti of a {}.',\n    'a bad photo of the {}.',\n    'a cropped photo of the {}.',\n    'a tattoo of a {}.',\n    'the embroidered {}.',\n    'a photo of a hard to see {}.',\n    'a bright photo of a {}.',\n    'a photo of a clean {}.',\n    'a photo of a dirty {}.',\n    'a dark photo of the {}.',\n    'a drawing of a {}.',\n    'a photo of my {}.',\n    'the plastic {}.',\n    'a photo of the cool {}.',\n    'a close-up photo of a {}.',\n    'a black and white photo of the {}.',\n    'a painting of the {}.',\n    'a painting of a {}.',\n    'a pixelated photo of the {}.',\n    'a sculpture of the {}.',\n    'a bright photo of the {}.',\n    'a cropped photo of a {}.',\n    'a plastic {}.',\n    'a photo of the dirty {}.',\n    'a jpeg corrupted photo of a {}.',\n    'a blurry photo of the {}.',\n    'a photo of the {}.',\n    'a good photo of the {}.',\n    'a rendering of the {}.',\n    'a {} in a video game.',\n    'a photo of one {}.',\n    'a doodle of a {}.',\n    'a close-up photo of the {}.',\n    'a photo of a {}.',\n    'the origami {}.',\n    'the {} in a video game.',\n    'a sketch of a {}.',\n    'a doodle of the {}.',\n    'a origami {}.',\n    'a low resolution photo of a {}.',\n    'the toy {}.',\n    'a rendition of the {}.',\n    'a photo of the clean {}.',\n    'a photo of a large {}.',\n    'a rendition of a {}.',\n    'a photo of a nice {}.',\n    'a photo of a weird {}.',\n    'a blurry photo of a {}.',\n    'a cartoon {}.',\n    'art of a {}.',\n    'a sketch of the {}.',\n    'a embroidered {}.',\n    'a pixelated photo of a {}.',\n    'itap of the {}.',\n    'a jpeg corrupted photo of the {}.',\n    'a good photo of a {}.',\n    'a plushie {}.',\n    'a photo of the nice {}.',\n    'a photo of the small {}.',\n    'a photo of the weird {}.',\n    'the cartoon {}.',\n    'art of the {}.',\n    'a drawing of the {}.',\n    'a photo of the large {}.',\n    'a black and white photo of a {}.',\n    'the plushie {}.',\n    'a 
dark photo of a {}.',\n    'itap of a {}.',\n    'graffiti of the {}.',\n    'a toy {}.',\n    'itap of my {}.',\n    'a photo of a cool {}.',\n    'a photo of a small {}.',\n    'a tattoo of the {}.',\n]\nIMAGENET_SIMPLE_TEMPLATES = [\n    'a photo of {}.',\n]"
  },
  {
    "path": "utils/dist.py",
    "content": "import functools\nimport io\nimport os\nimport random \nimport subprocess\nimport time\nfrom collections import OrderedDict, defaultdict, deque\nimport datetime\nimport pickle\nfrom typing import Optional, List\n\nimport json, time\nimport numpy as np\nimport torch\nimport torch.distributed as dist\nfrom torch import Tensor\n\nimport colorsys\ndef init_distributed_mode(args):\n    if 'WORLD_SIZE' in os.environ and os.environ['WORLD_SIZE'] != '': # 'RANK' in os.environ and \n        args.rank = int(os.environ[\"RANK\"])\n        args.world_size = int(os.environ['WORLD_SIZE'])\n        args.gpu = args.local_rank = int(os.environ['LOCAL_RANK'])\n\n        # launch by torch.distributed.launch\n        # Single node\n        #   python -m torch.distributed.launch --nproc_per_node=8 main.py --world-size 1 --rank 0 ...\n        # Multi nodes\n        #   python -m torch.distributed.launch --nproc_per_node=8 main.py --world-size 2 --rank 0 --dist-url 'tcp://IP_OF_NODE0:FREEPORT' ...\n        #   python -m torch.distributed.launch --nproc_per_node=8 main.py --world-size 2 --rank 1 --dist-url 'tcp://IP_OF_NODE0:FREEPORT' ...\n        # args.rank = int(os.environ.get('OMPI_COMM_WORLD_RANK'))        \n        # local_world_size = int(os.environ['GPU_PER_NODE_COUNT'])\n        # args.world_size = args.world_size * local_world_size\n        # args.gpu = args.local_rank = int(os.environ['LOCAL_RANK'])\n        # args.rank = args.rank * local_world_size + args.local_rank\n        print('world size: {}, rank: {}, local rank: {}'.format(args.world_size, args.rank, args.local_rank))\n        print(json.dumps(dict(os.environ), indent=2))\n    elif 'SLURM_PROCID' in os.environ:\n        args.rank = int(os.environ['SLURM_PROCID'])\n        args.gpu = args.local_rank = int(os.environ['SLURM_LOCALID'])\n        args.world_size = int(os.environ['SLURM_NPROCS'])\n\n        if os.environ.get('HAND_DEFINE_DIST_URL', 0) == '1':\n            pass\n        else:\n            
import util.hostlist as uh\n            nodenames = uh.parse_nodelist(os.environ['SLURM_JOB_NODELIST'])\n            gpu_ids = [int(node[3:]) for node in nodenames]\n            fixid = int(os.environ.get('FIX_DISTRIBUTED_PORT_NUMBER', 0))\n            # fixid += random.randint(0, 300)\n            port = str(3137 + int(min(gpu_ids)) + fixid)\n            args.dist_url = \"tcp://{ip}:{port}\".format(ip=uh.nodename_to_ip(nodenames[0]), port=port)\n\n        print('world size: {}, world rank: {}, local rank: {}, device_count: {}'.format(args.world_size, args.rank, args.local_rank, torch.cuda.device_count()))\n\n\n    else:\n        print('Not using distributed mode')\n        args.distributed = False\n        args.world_size = 1\n        args.rank = 0\n        args.local_rank = 0\n        return\n\n    print(\"world_size:{} rank:{} local_rank:{}\".format(args.world_size, args.rank, args.local_rank))\n    args.distributed = True\n    torch.cuda.set_device(args.local_rank)\n    args.dist_backend = 'nccl'\n    print('| distributed init (rank {}): {}'.format(args.rank, args.dist_url), flush=True)\n\n    torch.distributed.init_process_group(\n        backend=args.dist_backend, \n        world_size=args.world_size, \n        rank=args.rank,\n        init_method=args.dist_url,\n    )\n\n    print(\"Before torch.distributed.barrier()\")\n    torch.distributed.barrier()\n    print(\"End torch.distributed.barrier()\")"
  },
  {
    "path": "utils/distributed.py",
    "content": "# import os\n# import time\n# import torch\n# import pickle\n# import subprocess\n#\n# from mpi4py import MPI\n# import torch.distributed as dist\n#\n#\n# def apply_distributed(opt):\n#     if opt['rank'] == 0:\n#         hostname_cmd = [\"hostname -I\"]\n#         result = subprocess.check_output(hostname_cmd, shell=True)\n#         master_address = result.decode('utf-8').split()[0]\n#         master_port = opt['PORT']\n#     else:\n#         master_address = None\n#         master_port = None\n#\n#     master_address = MPI.COMM_WORLD.bcast(master_address, root=0)\n#     master_port = MPI.COMM_WORLD.bcast(master_port, root=0)\n#\n#     if torch.distributed.is_available() and opt['world_size'] > 1:\n#         init_method_url = 'tcp://{}:{}'.format(master_address, master_port)\n#         backend = 'nccl'\n#         world_size = opt['world_size']\n#         rank = opt['rank']\n#         torch.distributed.init_process_group(backend=backend,\n#                                              init_method=init_method_url,\n#                                              world_size=world_size,\n#                                              rank=rank)\n#\n# def init_distributed(opt):\n#     opt['CUDA'] = opt.get('CUDA', True) and torch.cuda.is_available()\n#     if 'OMPI_COMM_WORLD_SIZE' not in os.environ:\n#         # application was started without MPI\n#         # default to single node with single process\n#         opt['env_info'] = 'no MPI'\n#         opt['world_size'] = 1\n#         opt['local_size'] = 1\n#         opt['rank'] = 0\n#         opt['local_rank'] = 0\n#         opt['master_address'] = '127.0.0.1'\n#         opt['master_port'] = '8673'\n#     else:\n#         # application was started with MPI\n#         # get MPI parameters\n#         opt['world_size'] = int(os.environ['OMPI_COMM_WORLD_SIZE'])\n#         opt['local_size'] = int(os.environ['OMPI_COMM_WORLD_LOCAL_SIZE'])\n#         opt['rank'] = int(os.environ['OMPI_COMM_WORLD_RANK'])\n#      
   opt['local_rank'] = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])\n#\n#     # set up device\n#     if not opt['CUDA']:\n#         assert opt['world_size'] == 1, 'multi-GPU training without CUDA is not supported since we use NCCL as communication backend'\n#         opt['device'] = torch.device(\"cpu\")\n#     else:\n#         torch.cuda.set_device(opt['local_rank'])\n#         opt['device'] = torch.device(\"cuda\", opt['local_rank'])\n#\n#     apply_distributed(opt)\n#     return opt\n#\n# def is_main_process():\n#     rank = 0\n#     if 'OMPI_COMM_WORLD_SIZE' in os.environ:\n#         rank = int(os.environ['OMPI_COMM_WORLD_RANK'])\n#\n#     return rank == 0\n#\n# def get_world_size():\n#     if not dist.is_available():\n#         return 1\n#     if not dist.is_initialized():\n#         return 1\n#     return dist.get_world_size()\n#\n# def get_rank():\n#     if not dist.is_available():\n#         return 0\n#     if not dist.is_initialized():\n#         return 0\n#     return dist.get_rank()\n#\n#\n# def synchronize():\n#     \"\"\"\n#     Helper function to synchronize (barrier) among all processes when\n#     using distributed training\n#     \"\"\"\n#     if not dist.is_available():\n#         return\n#     if not dist.is_initialized():\n#         return\n#     world_size = dist.get_world_size()\n#     rank = dist.get_rank()\n#     if world_size == 1:\n#         return\n#\n#     def _send_and_wait(r):\n#         if rank == r:\n#             tensor = torch.tensor(0, device=\"cuda\")\n#         else:\n#             tensor = torch.tensor(1, device=\"cuda\")\n#         dist.broadcast(tensor, r)\n#         while tensor.item() == 1:\n#             time.sleep(1)\n#\n#     _send_and_wait(0)\n#     # now sync on the main process\n#     _send_and_wait(1)"
  },
  {
    "path": "utils/misc.py",
    "content": "# --------------------------------------------------------\n# X-Decoder -- Generalized Decoding for Pixel, Image, and Language\n# Copyright (c) 2022 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# Written by Xueyan Zou (xueyan@cs.wisc.edu)\n# --------------------------------------------------------\nimport math\n\n\n# HACK for evalution \ndef hook_metadata(metadata, name):\n    if name == 'cityscapes_fine_sem_seg_val':\n        metadata.__setattr__(\"keep_sem_bgd\", False)\n    return metadata\n\ndef hook_opt(model, name):\n    if name in ['cityscapes_fine_panoptic_val', 'ade20k_panoptic_val', 'bdd10k_40_panoptic_val', 'cityscapes_fine_panoptic_val', 'scannet_21_panoptic_val']:\n        model.model.object_mask_threshold = 0.4\n    else:\n        model.model.object_mask_threshold = 0.8\n\n# HACK for evalution \ndef hook_switcher(model, name):\n    mappings = {}\n    if name in ['cityscapes_fine_sem_seg_val', 'scannet_21_val_seg', 'scannet_38_val_seg', 'scannet_41_val_seg', 'sunrgbd_37_val_seg', 'bdd10k_val_sem_seg', 'ade20k_full_sem_seg_val']:\n        mappings = {'SEMANTIC_ON': True, 'INSTANCE_ON': False, 'PANOPTIC_ON': False}\n    elif name in ['cityscapes_fine_instance_seg_val'] or 'seginw' in name:\n        mappings = {'SEMANTIC_ON': False, 'INSTANCE_ON': True, 'PANOPTIC_ON': False}\n    elif name in ['cityscapes_fine_panoptic_val', 'scannet_21_panoptic_val', 'bdd10k_40_panoptic_val']:\n        mappings = {'SEMANTIC_ON': True, 'INSTANCE_ON': False, 'PANOPTIC_ON': True}\n    elif name in ['coco_2017_val_panoptic_with_sem_seg', 'ade20k_panoptic_val', 'coco_2017_test-dev']:\n        mappings = {'SEMANTIC_ON': True, 'INSTANCE_ON': True, 'PANOPTIC_ON': True}\n    else:\n        if name not in [\"vlp_val\", \"vlp_captioning_val\", \"vlp_val2017\", \"vlp_captioning_val2017\", \"imagenet_val\", \"refcocog_val_google\", \"phrasecut_val\", \"phrasecut_test\", \"refcocop_val_unc\", \"refcoco_val_unc\", \"refcocog_val_umd\"]:\n       
     assert False, \"dataset switcher is not defined\"\n    for key, value in mappings.items():\n        if key == 'SEMANTIC_ON':\n            model.model.semantic_on = value\n        if key == 'INSTANCE_ON':\n            model.model.instance_on = value\n        if key == 'PANOPTIC_ON':\n            model.model.panoptic_on = value\n\nclass AverageMeter(object):\n    \"\"\"Computes and stores the average and current value.\"\"\"\n    def __init__(self):\n        self.reset()\n\n    def reset(self):\n        self.val = 0\n        self.avg = 0\n        self.sum = 0\n        self.count = 0\n\n    def update(self, val, n=1, decay=0):\n        self.val = val\n        if decay:\n            alpha = math.exp(-n / decay)  # exponential decay over 100 updates\n            self.sum = alpha * self.sum + (1 - alpha) * val * n\n            self.count = alpha * self.count + (1 - alpha) * n\n        else:\n            self.sum += val * n\n            self.count += n\n        self.avg = self.sum / self.count\n"
  },
  {
    "path": "utils/model.py",
    "content": "import logging\nimport os\nimport time\nimport pickle\nimport torch\n# from utils.distributed import is_main_process\n\nlogger = logging.getLogger(__name__)\n\n\nNORM_MODULES = [\n    torch.nn.BatchNorm1d,\n    torch.nn.BatchNorm2d,\n    torch.nn.BatchNorm3d,\n    torch.nn.SyncBatchNorm,\n    # NaiveSyncBatchNorm inherits from BatchNorm2d\n    torch.nn.GroupNorm,\n    torch.nn.InstanceNorm1d,\n    torch.nn.InstanceNorm2d,\n    torch.nn.InstanceNorm3d,\n    torch.nn.LayerNorm,\n    torch.nn.LocalResponseNorm,\n]\n\ndef register_norm_module(cls):\n    NORM_MODULES.append(cls)\n    return cls\n\ndef align_and_update_state_dicts(model_state_dict, ckpt_state_dict):\n    model_keys = sorted(model_state_dict.keys())\n    ckpt_keys = sorted(ckpt_state_dict.keys())\n    result_dicts = {}\n    matched_log = []\n    unmatched_log = []\n    unloaded_log = []\n    for model_key in model_keys:\n        model_weight = model_state_dict[model_key]\n        if model_key in ckpt_keys:\n            ckpt_weight = ckpt_state_dict[model_key]\n            if model_weight.shape == ckpt_weight.shape:\n                result_dicts[model_key] = ckpt_weight\n                ckpt_keys.pop(ckpt_keys.index(model_key))\n                matched_log.append(\"Loaded {}, Model Shape: {} <-> Ckpt Shape: {}\".format(model_key, model_weight.shape, ckpt_weight.shape))\n            else:\n                unmatched_log.append(\"*UNMATCHED* {}, Model Shape: {} <-> Ckpt Shape: {}\".format(model_key, model_weight.shape, ckpt_weight.shape))\n        else:\n            unloaded_log.append(\"*UNLOADED* {}, Model Shape: {}\".format(model_key, model_weight.shape))\n            \n    # if is_main_process():\n    #     for info in matched_log:\n    #         logger.info(info)\n    #     for info in unloaded_log:\n    #         logger.warning(info)\n    #     for key in ckpt_keys:\n    #         logger.warning(\"$UNUSED$ {}, Ckpt Shape: {}\".format(key, ckpt_state_dict[key].shape))\n    #     for info 
in unmatched_log:\n    #         logger.warning(info)\n    return result_dicts"
  },
  {
    "path": "utils/nms.py",
    "content": "import torch\n\n\ndef matrix_nms(seg_masks, cate_labels, cate_scores, kernel='gaussian', sigma=2.0, sum_masks=None,thresh=0.7):\n    n_samples = len(cate_labels)\n    if n_samples == 0:\n        return []\n    if sum_masks is None:\n        sum_masks = seg_masks.sum((1, 2)).float()\n    seg_masks = seg_masks.reshape(n_samples, -1).float()\n    # inter.\n    inter_matrix = torch.mm(seg_masks.float(), seg_masks.float().transpose(1, 0))\n    # union.\n    sum_masks_x = sum_masks.expand(n_samples, n_samples)\n    # iou.\n    iou_matrix = (inter_matrix / (sum_masks_x + sum_masks_x.transpose(1, 0) - inter_matrix+1e-5)).triu(diagonal=1)\n    result_idx=[]\n    for i in range(len(iou_matrix)):\n        if max(iou_matrix[:,i])<thresh:\n            result_idx.append(i)\n        else:\n            iou_matrix[:, i]=0.0\n            iou_matrix[i, :]=0.0\n    return result_idx\n\ndef matrix_nms_merge(seg_masks, cate_labels, cate_scores, kernel='gaussian', sigma=2.0, sum_masks=None, thresh=0.7,num_gt=0):\n    n_samples = len(cate_labels)\n    if n_samples == 0:\n        return []\n    if sum_masks is None:\n        sum_masks = seg_masks.sum((1, 2)).float()\n    seg_masks = seg_masks.reshape(n_samples, -1).float()\n    # inter.\n    inter_matrix = torch.mm(seg_masks.float(), seg_masks.float().transpose(1, 0))\n    # union.\n    sum_masks_x = sum_masks.expand(n_samples, n_samples)\n    # iou.\n    iou_matrix = (inter_matrix / (sum_masks_x + sum_masks_x.transpose(1, 0) - inter_matrix + 1e-5)).triu(diagonal=1)\n    result_idx = []\n    gt_idx=[]\n    k=0\n    idx_map=dict()\n    for i in range(len(iou_matrix)):\n        if max(iou_matrix[:, i]) < thresh:\n            result_idx.append(i)\n            idx_map[i]=k\n\n            if i >= len(iou_matrix) - num_gt:\n                gt_idx.append(k)\n            k += 1\n\n        else:\n            if i >= len(iou_matrix) - num_gt:\n                gt_idx.append(idx_map[int(iou_matrix[:, i].max(0)[1])])\n            
iou_matrix[:, i]=0.0\n            iou_matrix[i, :]=0.0\n\n    return result_idx,gt_idx\n\n\ndef multiclass_nms(multi_bboxes,\n                   multi_scores,\n                   score_thr,\n                   nms_cfg,\n                   max_num=-1,\n                   score_factors=None):\n    \"\"\"NMS for multi-class bboxes.\n\n    Args:\n        multi_bboxes (Tensor): shape (n, #class*4) or (n, 4)\n        multi_scores (Tensor): shape (n, #class), where the 0th column\n            contains scores of the background class, but this will be ignored.\n        score_thr (float): bbox threshold, bboxes with scores lower than it\n            will not be considered.\n        nms_thr (float): NMS IoU threshold\n        max_num (int): if there are more than max_num bboxes after NMS,\n            only top max_num will be kept.\n        score_factors (Tensor): The factors multiplied to scores before\n            applying NMS\n\n    Returns:\n        tuple: (bboxes, labels), tensors of shape (k, 5) and (k, 1). 
Labels\n            are 0-based.\n    \"\"\"\n    num_classes = multi_scores.shape[1]\n    bboxes, labels = [], []\n    nms_cfg_ = nms_cfg.copy()\n    nms_type = nms_cfg_.pop('type', 'nms')\n    nms_op = getattr(nms_wrapper, nms_type)\n    for i in range(1, num_classes):\n        cls_inds = multi_scores[:, i] > score_thr\n        if not cls_inds.any():\n            continue\n        # get bboxes and scores of this class\n        if multi_bboxes.shape[1] == 4:\n            _bboxes = multi_bboxes[cls_inds, :]\n        else:\n            _bboxes = multi_bboxes[cls_inds, i * 4:(i + 1) * 4]\n        _scores = multi_scores[cls_inds, i]\n        if score_factors is not None:\n            _scores *= score_factors[cls_inds]\n        cls_dets = torch.cat([_bboxes, _scores[:, None]], dim=1)\n        cls_dets, _ = nms_op(cls_dets, **nms_cfg_)\n        cls_labels = multi_bboxes.new_full((cls_dets.shape[0], ),\n                                           i - 1,\n                                           dtype=torch.long)\n        bboxes.append(cls_dets)\n        labels.append(cls_labels)\n    if bboxes:\n        bboxes = torch.cat(bboxes)\n        labels = torch.cat(labels)\n        if bboxes.shape[0] > max_num:\n            _, inds = bboxes[:, -1].sort(descending=True)\n            inds = inds[:max_num]\n            bboxes = bboxes[inds]\n            labels = labels[inds]\n    else:\n        bboxes = multi_bboxes.new_zeros((0, 5))\n        labels = multi_bboxes.new_zeros((0, ), dtype=torch.long)\n\n    return bboxes, labels"
  },
  {
    "path": "utils/prompt_engineering.py",
    "content": "import numpy as np\n\n\ndef get_prompt_templates():\n    prompt_templates = [\n        '{}.',\n        'a photo of a {}.',\n        'a bad photo of a {}.',\n        'a photo of many {}.',\n        'a sculpture of a {}.',\n        'a photo of the hard to see {}.',\n        'a low resolution photo of the {}.',\n        'a rendering of a {}.',\n        'graffiti of a {}.',\n        'a bad photo of the {}.',\n        'a cropped photo of the {}.',\n        'a tattoo of a {}.',\n        'the embroidered {}.',\n        'a photo of a hard to see {}.',\n        'a bright photo of a {}.',\n        'a photo of a clean {}.',\n        'a photo of a dirty {}.',\n        'a dark photo of the {}.',\n        'a drawing of a {}.',\n        'a photo of my {}.',\n        'the plastic {}.',\n        'a photo of the cool {}.',\n        'a close-up photo of a {}.',\n        'a black and white photo of the {}.',\n        'a painting of the {}.',\n        'a painting of a {}.',\n        'a pixelated photo of the {}.',\n        'a sculpture of the {}.',\n        'a bright photo of the {}.',\n        'a cropped photo of a {}.',\n        'a plastic {}.',\n        'a photo of the dirty {}.',\n        'a jpeg corrupted photo of a {}.',\n        'a blurry photo of the {}.',\n        'a photo of the {}.',\n        'a good photo of the {}.',\n        'a rendering of the {}.',\n        'a {} in a video game.',\n        'a photo of one {}.',\n        'a doodle of a {}.',\n        'a close-up photo of the {}.',\n        'the origami {}.',\n        'the {} in a video game.',\n        'a sketch of a {}.',\n        'a doodle of the {}.',\n        'a origami {}.',\n        'a low resolution photo of a {}.',\n        'the toy {}.',\n        'a rendition of the {}.',\n        'a photo of the clean {}.',\n        'a photo of a large {}.',\n        'a rendition of a {}.',\n        'a photo of a nice {}.',\n        'a photo of a weird {}.',\n        'a blurry photo of a {}.',\n        'a 
cartoon {}.',\n        'art of a {}.',\n        'a sketch of the {}.',\n        'a embroidered {}.',\n        'a pixelated photo of a {}.',\n        'itap of the {}.',\n        'a jpeg corrupted photo of the {}.',\n        'a good photo of a {}.',\n        'a plushie {}.',\n        'a photo of the nice {}.',\n        'a photo of the small {}.',\n        'a photo of the weird {}.',\n        'the cartoon {}.',\n        'art of the {}.',\n        'a drawing of the {}.',\n        'a photo of the large {}.',\n        'a black and white photo of a {}.',\n        'the plushie {}.',\n        'a dark photo of a {}.',\n        'itap of a {}.',\n        'graffiti of the {}.',\n        'a toy {}.',\n        'itap of my {}.',\n        'a photo of a cool {}.',\n        'a photo of a small {}.',\n        'a tattoo of the {}.',\n    ]\n    return prompt_templates\n\ndef prompt_engineering(classnames, topk=1, suffix='.'):\n    prompt_templates = get_prompt_templates()\n    temp_idx = np.random.randint(min(len(prompt_templates), topk))\n\n    if isinstance(classnames, list):\n        classname = random.choice(classnames)\n    else:\n        classname = classnames\n\n    return prompt_templates[temp_idx].replace('.', suffix).format(classname.replace(',', '').replace('+', ' '))"
  },
  {
    "path": "utils/utils.py",
    "content": "import torch\nimport numpy as np\n\ndef slprint(x, name='x'):\n    if isinstance(x, (torch.Tensor, np.ndarray)):\n        print(f'{name}.shape:', x.shape)\n    elif isinstance(x, (tuple, list)):\n        print('type x:', type(x))\n        for i in range(min(10, len(x))):\n            slprint(x[i], f'{name}[{i}]')\n    elif isinstance(x, dict):\n        for k,v in x.items():\n            slprint(v, f'{name}[{k}]')\n    else:\n        print(f'{name}.type:', type(x))"
  },
  {
    "path": "utils/visualizer.py",
    "content": "# Copyright (c) Facebook, Inc. and its affiliates.\nimport colorsys\nimport logging\nimport math\nimport numpy as np\nfrom enum import Enum, unique\nimport cv2\nimport matplotlib as mpl\nimport matplotlib.colors as mplc\nimport matplotlib.figure as mplfigure\nimport pycocotools.mask as mask_util\nimport torch\nfrom matplotlib.backends.backend_agg import FigureCanvasAgg\nfrom PIL import Image\n\nfrom detectron2.data import MetadataCatalog\nfrom detectron2.structures import BitMasks, Boxes, BoxMode, Keypoints, PolygonMasks, RotatedBoxes\nfrom detectron2.utils.file_io import PathManager\n\nfrom detectron2.utils.colormap import random_color\n\nlogger = logging.getLogger(__name__)\n\n__all__ = [\"ColorMode\", \"VisImage\", \"Visualizer\"]\n\n\n_SMALL_OBJECT_AREA_THRESH = 1000\n_LARGE_MASK_AREA_THRESH = 120000\n_OFF_WHITE = (1.0, 1.0, 240.0 / 255)\n_BLACK = (0, 0, 0)\n_RED = (1.0, 0, 0)\n\n_KEYPOINT_THRESHOLD = 0.05\n\n\n@unique\nclass ColorMode(Enum):\n    \"\"\"\n    Enum of different color modes to use for instance visualizations.\n    \"\"\"\n\n    IMAGE = 0\n    \"\"\"\n    Picks a random color for every instance and overlay segmentations with low opacity.\n    \"\"\"\n    SEGMENTATION = 1\n    \"\"\"\n    Let instances of the same category have similar colors\n    (from metadata.thing_colors), and overlay them with\n    high opacity. 
This provides more attention on the quality of segmentation.\n    \"\"\"\n    IMAGE_BW = 2\n    \"\"\"\n    Same as IMAGE, but convert all areas without masks to gray-scale.\n    Only available for drawing per-instance mask predictions.\n    \"\"\"\n\n\nclass GenericMask:\n    \"\"\"\n    Attribute:\n        polygons (list[ndarray]): list[ndarray]: polygons for this mask.\n            Each ndarray has format [x, y, x, y, ...]\n        mask (ndarray): a binary mask\n    \"\"\"\n\n    def __init__(self, mask_or_polygons, height, width):\n        self._mask = self._polygons = self._has_holes = None\n        self.height = height\n        self.width = width\n\n        m = mask_or_polygons\n        if isinstance(m, dict):\n            # RLEs\n            assert \"counts\" in m and \"size\" in m\n            if isinstance(m[\"counts\"], list):  # uncompressed RLEs\n                h, w = m[\"size\"]\n                assert h == height and w == width\n                m = mask_util.frPyObjects(m, h, w)\n            self._mask = mask_util.decode(m)[:, :]\n            return\n\n        if isinstance(m, list):  # list[ndarray]\n            self._polygons = [np.asarray(x).reshape(-1) for x in m]\n            return\n\n        if isinstance(m, np.ndarray):  # assumed to be a binary mask\n            assert m.shape[1] != 2, m.shape\n            assert m.shape == (\n                height,\n                width,\n            ), f\"mask shape: {m.shape}, target dims: {height}, {width}\"\n            self._mask = m.astype(\"uint8\")\n            return\n\n        raise ValueError(\"GenericMask cannot handle object {} of type '{}'\".format(m, type(m)))\n\n    @property\n    def mask(self):\n        if self._mask is None:\n            self._mask = self.polygons_to_mask(self._polygons)\n        return self._mask\n\n    @property\n    def polygons(self):\n        if self._polygons is None:\n            self._polygons, self._has_holes = self.mask_to_polygons(self._mask)\n        return 
self._polygons\n\n    @property\n    def has_holes(self):\n        if self._has_holes is None:\n            if self._mask is not None:\n                self._polygons, self._has_holes = self.mask_to_polygons(self._mask)\n            else:\n                self._has_holes = False  # if original format is polygon, does not have holes\n        return self._has_holes\n\n    def mask_to_polygons(self, mask):\n        # cv2.RETR_CCOMP flag retrieves all the contours and arranges them to a 2-level\n        # hierarchy. External contours (boundary) of the object are placed in hierarchy-1.\n        # Internal contours (holes) are placed in hierarchy-2.\n        # cv2.CHAIN_APPROX_NONE flag gets vertices of polygons from contours.\n        mask = np.ascontiguousarray(mask)  # some versions of cv2 does not support incontiguous arr\n        res = cv2.findContours(mask.astype(\"uint8\"), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)\n        hierarchy = res[-1]\n        if hierarchy is None:  # empty mask\n            return [], False\n        has_holes = (hierarchy.reshape(-1, 4)[:, 3] >= 0).sum() > 0\n        res = res[-2]\n        res = [x.flatten() for x in res]\n        # These coordinates from OpenCV are integers in range [0, W-1 or H-1].\n        # We add 0.5 to turn them into real-value coordinate space. 
A better solution\n        # would be to first +0.5 and then dilate the returned polygon by 0.5.\n        res = [x + 0.5 for x in res if len(x) >= 6]\n        return res, has_holes\n\n    def polygons_to_mask(self, polygons):\n        rle = mask_util.frPyObjects(polygons, self.height, self.width)\n        rle = mask_util.merge(rle)\n        return mask_util.decode(rle)[:, :]\n\n    def area(self):\n        return self.mask.sum()\n\n    def bbox(self):\n        p = mask_util.frPyObjects(self.polygons, self.height, self.width)\n        p = mask_util.merge(p)\n        bbox = mask_util.toBbox(p)\n        bbox[2] += bbox[0]\n        bbox[3] += bbox[1]\n        return bbox\n\n\nclass _PanopticPrediction:\n    \"\"\"\n    Unify different panoptic annotation/prediction formats\n    \"\"\"\n\n    def __init__(self, panoptic_seg, segments_info, metadata=None):\n        if segments_info is None:\n            assert metadata is not None\n            # If \"segments_info\" is None, we assume \"panoptic_img\" is a\n            # H*W int32 image storing the panoptic_id in the format of\n            # category_id * label_divisor + instance_id. 
We reserve -1 for\n            # VOID label.\n            label_divisor = metadata.label_divisor\n            segments_info = []\n            for panoptic_label in np.unique(panoptic_seg.numpy()):\n                if panoptic_label == -1:\n                    # VOID region.\n                    continue\n                pred_class = panoptic_label // label_divisor\n                isthing = pred_class in metadata.thing_dataset_id_to_contiguous_id.values()\n                segments_info.append(\n                    {\n                        \"id\": int(panoptic_label),\n                        \"category_id\": int(pred_class),\n                        \"isthing\": bool(isthing),\n                    }\n                )\n        del metadata\n\n        self._seg = panoptic_seg\n\n        self._sinfo = {s[\"id\"]: s for s in segments_info}  # seg id -> seg info\n        segment_ids, areas = torch.unique(panoptic_seg, sorted=True, return_counts=True)\n        areas = areas.numpy()\n        sorted_idxs = np.argsort(-areas)\n        self._seg_ids, self._seg_areas = segment_ids[sorted_idxs], areas[sorted_idxs]\n        self._seg_ids = self._seg_ids.tolist()\n        for sid, area in zip(self._seg_ids, self._seg_areas):\n            if sid in self._sinfo:\n                self._sinfo[sid][\"area\"] = float(area)\n\n    def non_empty_mask(self):\n        \"\"\"\n        Returns:\n            (H, W) array, a mask for all pixels that have a prediction\n        \"\"\"\n        empty_ids = []\n        for id in self._seg_ids:\n            if id not in self._sinfo:\n                empty_ids.append(id)\n        if len(empty_ids) == 0:\n            return np.zeros(self._seg.shape, dtype=np.uint8)\n        assert (\n            len(empty_ids) == 1\n        ), \">1 ids corresponds to no labels. 
This is currently not supported\"\n        return (self._seg != empty_ids[0]).numpy().astype(np.bool)\n\n    def semantic_masks(self):\n        for sid in self._seg_ids:\n            sinfo = self._sinfo.get(sid)\n            if sinfo is None or sinfo[\"isthing\"]:\n                # Some pixels (e.g. id 0 in PanopticFPN) have no instance or semantic predictions.\n                continue\n            yield (self._seg == sid).numpy().astype(np.bool), sinfo\n\n    def instance_masks(self):\n        for sid in self._seg_ids:\n            sinfo = self._sinfo.get(sid)\n            if sinfo is None or not sinfo[\"isthing\"]:\n                continue\n            mask = (self._seg == sid).numpy().astype(np.bool)\n            if mask.sum() > 0:\n                yield mask, sinfo\n\n\ndef _create_text_labels(classes, scores, class_names, is_crowd=None):\n    \"\"\"\n    Args:\n        classes (list[int] or None):\n        scores (list[float] or None):\n        class_names (list[str] or None):\n        is_crowd (list[bool] or None):\n\n    Returns:\n        list[str] or None\n    \"\"\"\n    labels = None\n    if classes is not None:\n        if class_names is not None and len(class_names) > 0:\n            labels = [class_names[i] for i in classes]\n        else:\n            labels = [str(i) for i in classes]\n    if scores is not None:\n        if labels is None:\n            labels = [\"{:.0f}%\".format(s * 100) for s in scores]\n        else:\n            labels = [\"{} {:.0f}%\".format(l, s * 100) for l, s in zip(labels, scores)]\n    if labels is not None and is_crowd is not None:\n        labels = [l + (\"|crowd\" if crowd else \"\") for l, crowd in zip(labels, is_crowd)]\n    return labels\n\n\nclass VisImage:\n    def __init__(self, img, scale=1.0):\n        \"\"\"\n        Args:\n            img (ndarray): an RGB image of shape (H, W, 3) in range [0, 255].\n            scale (float): scale the input image\n        \"\"\"\n        self.img = img\n        
self.scale = scale\n        self.width, self.height = img.shape[1], img.shape[0]\n        self._setup_figure(img)\n\n    def _setup_figure(self, img):\n        \"\"\"\n        Args:\n            Same as in :meth:`__init__()`.\n\n        Returns:\n            fig (matplotlib.pyplot.figure): top level container for all the image plot elements.\n            ax (matplotlib.pyplot.Axes): contains figure elements and sets the coordinate system.\n        \"\"\"\n        fig = mplfigure.Figure(frameon=False)\n        self.dpi = fig.get_dpi()\n        # add a small 1e-2 to avoid precision lost due to matplotlib's truncation\n        # (https://github.com/matplotlib/matplotlib/issues/15363)\n        fig.set_size_inches(\n            (self.width * self.scale + 1e-2) / self.dpi,\n            (self.height * self.scale + 1e-2) / self.dpi,\n        )\n        self.canvas = FigureCanvasAgg(fig)\n        # self.canvas = mpl.backends.backend_cairo.FigureCanvasCairo(fig)\n        ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])\n        ax.axis(\"off\")\n        self.fig = fig\n        self.ax = ax\n        self.reset_image(img)\n\n    def reset_image(self, img):\n        \"\"\"\n        Args:\n            img: same as in __init__\n        \"\"\"\n        img = img.astype(\"uint8\")\n        self.ax.imshow(img, extent=(0, self.width, self.height, 0), interpolation=\"nearest\")\n\n    def save(self, filepath):\n        \"\"\"\n        Args:\n            filepath (str): a string that contains the absolute path, including the file name, where\n                the visualized image will be saved.\n        \"\"\"\n        self.fig.savefig(filepath)\n\n    def get_image(self):\n        \"\"\"\n        Returns:\n            ndarray:\n                the visualized image of shape (H, W, 3) (RGB) in uint8 type.\n                The shape is scaled w.r.t the input image using the given `scale` argument.\n        \"\"\"\n        canvas = self.canvas\n        s, (width, height) = 
canvas.print_to_buffer()\n        # buf = io.BytesIO()  # works for cairo backend\n        # canvas.print_rgba(buf)\n        # width, height = self.width, self.height\n        # s = buf.getvalue()\n\n        buffer = np.frombuffer(s, dtype=\"uint8\")\n\n        img_rgba = buffer.reshape(height, width, 4)\n        rgb, alpha = np.split(img_rgba, [3], axis=2)\n        return rgb.astype(\"uint8\")\n\n\nclass Visualizer:\n    \"\"\"\n    Visualizer that draws data about detection/segmentation on images.\n\n    It contains methods like `draw_{text,box,circle,line,binary_mask,polygon}`\n    that draw primitive objects to images, as well as high-level wrappers like\n    `draw_{instance_predictions,sem_seg,panoptic_seg_predictions,dataset_dict}`\n    that draw composite data in some pre-defined style.\n\n    Note that the exact visualization style for the high-level wrappers are subject to change.\n    Style such as color, opacity, label contents, visibility of labels, or even the visibility\n    of objects themselves (e.g. when the object is too small) may change according\n    to different heuristics, as long as the results still look visually reasonable.\n\n    To obtain a consistent style, you can implement custom drawing functions with the\n    abovementioned primitive methods instead. If you need more customized visualization\n    styles, you can process the data yourself following their format documented in\n    tutorials (:doc:`/tutorials/models`, :doc:`/tutorials/datasets`). This class does not\n    intend to satisfy everyone's preference on drawing styles.\n\n    This visualizer focuses on high rendering quality rather than performance. 
It is not\n    designed to be used for real-time applications.\n    \"\"\"\n\n    # TODO implement a fast, rasterized version using OpenCV\n\n    def __init__(self, img_rgb, metadata=None, scale=1.0, instance_mode=ColorMode.IMAGE):\n        \"\"\"\n        Args:\n            img_rgb: a numpy array of shape (H, W, C), where H and W correspond to\n                the height and width of the image respectively. C is the number of\n                color channels. The image is required to be in RGB format since that\n                is a requirement of the Matplotlib library. The image is also expected\n                to be in the range [0, 255].\n            metadata (Metadata): dataset metadata (e.g. class names and colors)\n            instance_mode (ColorMode): defines one of the pre-defined style for drawing\n                instances on an image.\n        \"\"\"\n        self.img = np.asarray(img_rgb).clip(0, 255).astype(np.uint8)\n        if metadata is None:\n            metadata = MetadataCatalog.get(\"__nonexist__\")\n        self.metadata = metadata\n        self.output = VisImage(self.img, scale=scale)\n        self.cpu_device = torch.device(\"cpu\")\n\n        # too small texts are useless, therefore clamp to 9\n        self._default_font_size = max(\n            np.sqrt(self.output.height * self.output.width) // 90, 10 // scale\n        )\n        self._default_font_size = 18\n        self._instance_mode = instance_mode\n        self.keypoint_threshold = _KEYPOINT_THRESHOLD\n\n    def draw_instance_predictions(self, predictions):\n        \"\"\"\n        Draw instance-level prediction results on an image.\n\n        Args:\n            predictions (Instances): the output of an instance detection/segmentation\n                model. 
Following fields will be used to draw:\n                \"pred_boxes\", \"pred_classes\", \"scores\", \"pred_masks\" (or \"pred_masks_rle\").\n\n        Returns:\n            output (VisImage): image object with visualizations.\n        \"\"\"\n        boxes = predictions.pred_boxes if predictions.has(\"pred_boxes\") else None\n        scores = predictions.scores if predictions.has(\"scores\") else None\n        classes = predictions.pred_classes.tolist() if predictions.has(\"pred_classes\") else None\n        labels = _create_text_labels(classes, scores, self.metadata.get(\"thing_classes\", None))\n        keypoints = predictions.pred_keypoints if predictions.has(\"pred_keypoints\") else None\n\n        keep = (scores > 0.5).cpu()\n        boxes = boxes[keep]\n        scores = scores[keep]\n        classes = np.array(classes)\n        classes = classes[np.array(keep)]\n        labels = np.array(labels)\n        labels = labels[np.array(keep)]\n\n        if predictions.has(\"pred_masks\"):\n            masks = np.asarray(predictions.pred_masks)\n            masks = masks[np.array(keep)]\n            masks = [GenericMask(x, self.output.height, self.output.width) for x in masks]\n        else:\n            masks = None\n\n        if self._instance_mode == ColorMode.SEGMENTATION and self.metadata.get(\"thing_colors\"):\n        # if self.metadata.get(\"thing_colors\"):\n            colors = [\n                self._jitter([x / 255 for x in self.metadata.thing_colors[c]]) for c in classes\n            ]\n            alpha = 0.4\n        else:\n            colors = None\n            alpha = 0.4\n\n        if self._instance_mode == ColorMode.IMAGE_BW:\n            self.output.reset_image(\n                self._create_grayscale_image(\n                    (predictions.pred_masks.any(dim=0) > 0).numpy()\n                    if predictions.has(\"pred_masks\")\n                    else None\n                )\n            )\n            alpha = 0.3\n        \n        
self.overlay_instances(\n            masks=masks,\n            boxes=boxes,\n            labels=labels,\n            keypoints=keypoints,\n            assigned_colors=colors,\n            alpha=alpha,\n        )\n        return self.output\n\n    def draw_sem_seg(self, sem_seg, area_threshold=None, alpha=0.7):\n        \"\"\"\n        Draw semantic segmentation predictions/labels.\n\n        Args:\n            sem_seg (Tensor or ndarray): the segmentation of shape (H, W).\n                Each value is the integer label of the pixel.\n            area_threshold (int): segments with less than `area_threshold` are not drawn.\n            alpha (float): the larger it is, the more opaque the segmentations are.\n\n        Returns:\n            output (VisImage): image object with visualizations.\n        \"\"\"\n        if isinstance(sem_seg, torch.Tensor):\n            sem_seg = sem_seg.numpy()\n        labels, areas = np.unique(sem_seg, return_counts=True)\n        sorted_idxs = np.argsort(-areas).tolist()\n        labels = labels[sorted_idxs]\n        for label in filter(lambda l: l < len(self.metadata.stuff_classes), labels):\n            try:\n                mask_color = [x / 255 for x in self.metadata.stuff_colors[label]]\n            except (AttributeError, IndexError):\n                mask_color = None\n\n            binary_mask = (sem_seg == label).astype(np.uint8)\n            text = self.metadata.stuff_classes[label]\n            self.draw_binary_mask(\n                binary_mask,\n                color=mask_color,\n                edge_color=_OFF_WHITE,\n                text=text,\n                alpha=alpha,\n                area_threshold=area_threshold,\n            )\n        return self.output\n\n    def draw_panoptic_seg(self, panoptic_seg, segments_info, area_threshold=None, alpha=0.7):\n        \"\"\"\n        Draw panoptic prediction annotations or results.\n\n        Args:\n            panoptic_seg (Tensor): of shape (height, width) where the 
values are ids for each\n                segment.\n            segments_info (list[dict] or None): Describe each segment in `panoptic_seg`.\n                If it is a ``list[dict]``, each dict contains keys \"id\", \"category_id\".\n                If None, category id of each pixel is computed by\n                ``pixel // metadata.label_divisor``.\n            area_threshold (int): stuff segments with less than `area_threshold` are not drawn.\n\n        Returns:\n            output (VisImage): image object with visualizations.\n        \"\"\"\n        pred = _PanopticPrediction(panoptic_seg, segments_info, self.metadata)\n\n        if self._instance_mode == ColorMode.IMAGE_BW:\n            self.output.reset_image(self._create_grayscale_image(pred.non_empty_mask()))\n\n        # draw mask for all semantic segments first i.e. \"stuff\"\n        for mask, sinfo in pred.semantic_masks():\n            category_idx = sinfo[\"category_id\"]\n            try:\n                mask_color = [x / 255 for x in self.metadata.stuff_colors[category_idx]]\n            except AttributeError:\n                mask_color = None\n\n            text = self.metadata.stuff_classes[category_idx]\n            self.draw_binary_mask(\n                mask,\n                color=mask_color,\n                edge_color=_OFF_WHITE,\n                text=text,\n                alpha=alpha,\n                area_threshold=area_threshold,\n            )\n\n        # draw mask for all instances second\n        all_instances = list(pred.instance_masks())\n        if len(all_instances) == 0:\n            return self.output\n        masks, sinfo = list(zip(*all_instances))\n        category_ids = [x[\"category_id\"] for x in sinfo]\n\n        try:\n            scores = [x[\"score\"] for x in sinfo]\n        except KeyError:\n            scores = None\n        labels = _create_text_labels(\n            category_ids, scores, self.metadata.thing_classes, [x.get(\"iscrowd\", 0) for x in sinfo]\n      
  )\n\n        try:\n            colors = [\n                self._jitter([x / 255 for x in self.metadata.thing_colors[c]]) for c in category_ids\n            ]\n        except AttributeError:\n            colors = None\n        self.overlay_instances(masks=masks, labels=labels, assigned_colors=colors, alpha=alpha)\n\n        return self.output\n\n    draw_panoptic_seg_predictions = draw_panoptic_seg  # backward compatibility\n\n    def draw_dataset_dict(self, dic):\n        \"\"\"\n        Draw annotations/segmentaions in Detectron2 Dataset format.\n\n        Args:\n            dic (dict): annotation/segmentation data of one image, in Detectron2 Dataset format.\n\n        Returns:\n            output (VisImage): image object with visualizations.\n        \"\"\"\n        annos = dic.get(\"annotations\", None)\n        if annos:\n            if \"segmentation\" in annos[0]:\n                masks = [x[\"segmentation\"] for x in annos]\n            else:\n                masks = None\n            if \"keypoints\" in annos[0]:\n                keypts = [x[\"keypoints\"] for x in annos]\n                keypts = np.array(keypts).reshape(len(annos), -1, 3)\n            else:\n                keypts = None\n\n            boxes = [\n                BoxMode.convert(x[\"bbox\"], x[\"bbox_mode\"], BoxMode.XYXY_ABS)\n                if len(x[\"bbox\"]) == 4\n                else x[\"bbox\"]\n                for x in annos\n            ]\n\n            colors = None\n            category_ids = [x[\"category_id\"] for x in annos]\n            if self._instance_mode == ColorMode.SEGMENTATION and self.metadata.get(\"thing_colors\"):\n                colors = [\n                    self._jitter([x / 255 for x in self.metadata.thing_colors[c]])\n                    for c in category_ids\n                ]\n            names = self.metadata.get(\"thing_classes\", None)\n            labels = _create_text_labels(\n                category_ids,\n                scores=None,\n           
     class_names=names,\n                is_crowd=[x.get(\"iscrowd\", 0) for x in annos],\n            )\n            self.overlay_instances(\n                labels=labels, boxes=boxes, masks=masks, keypoints=keypts, assigned_colors=colors\n            )\n\n        sem_seg = dic.get(\"sem_seg\", None)\n        if sem_seg is None and \"sem_seg_file_name\" in dic:\n            with PathManager.open(dic[\"sem_seg_file_name\"], \"rb\") as f:\n                sem_seg = Image.open(f)\n                sem_seg = np.asarray(sem_seg, dtype=\"uint8\")\n        if sem_seg is not None:\n            self.draw_sem_seg(sem_seg, area_threshold=0, alpha=0.4)\n\n        pan_seg = dic.get(\"pan_seg\", None)\n        if pan_seg is None and \"pan_seg_file_name\" in dic:\n            with PathManager.open(dic[\"pan_seg_file_name\"], \"rb\") as f:\n                pan_seg = Image.open(f)\n                pan_seg = np.asarray(pan_seg)\n                from panopticapi.utils import rgb2id\n\n                pan_seg = rgb2id(pan_seg)\n        if pan_seg is not None:\n            segments_info = dic[\"segments_info\"]\n            pan_seg = torch.tensor(pan_seg)\n            self.draw_panoptic_seg(pan_seg, segments_info, area_threshold=0, alpha=0.7)\n        return self.output\n\n    def overlay_instances(\n        self,\n        *,\n        boxes=None,\n        labels=None,\n        masks=None,\n        keypoints=None,\n        assigned_colors=None,\n        alpha=0.5,\n    ):\n        \"\"\"\n        Args:\n            boxes (Boxes, RotatedBoxes or ndarray): either a :class:`Boxes`,\n                or an Nx4 numpy array of XYXY_ABS format for the N objects in a single image,\n                or a :class:`RotatedBoxes`,\n                or an Nx5 numpy array of (x_center, y_center, width, height, angle_degrees) format\n                for the N objects in a single image,\n            labels (list[str]): the text to be displayed for each instance.\n            masks (masks-like object): 
Supported types are:\n\n                * :class:`detectron2.structures.PolygonMasks`,\n                  :class:`detectron2.structures.BitMasks`.\n                * list[list[ndarray]]: contains the segmentation masks for all objects in one image.\n                  The first level of the list corresponds to individual instances. The second\n                  level to all the polygon that compose the instance, and the third level\n                  to the polygon coordinates. The third level should have the format of\n                  [x0, y0, x1, y1, ..., xn, yn] (n >= 3).\n                * list[ndarray]: each ndarray is a binary mask of shape (H, W).\n                * list[dict]: each dict is a COCO-style RLE.\n            keypoints (Keypoint or array like): an array-like object of shape (N, K, 3),\n                where the N is the number of instances and K is the number of keypoints.\n                The last dimension corresponds to (x, y, visibility or score).\n            assigned_colors (list[matplotlib.colors]): a list of colors, where each color\n                corresponds to each mask or box in the image. 
Refer to 'matplotlib.colors'\n                for full list of formats that the colors are accepted in.\n        Returns:\n            output (VisImage): image object with visualizations.\n        \"\"\"\n        num_instances = 0\n        if boxes is not None:\n            boxes = self._convert_boxes(boxes)\n            num_instances = len(boxes)\n        if masks is not None:\n            masks = self._convert_masks(masks)\n            if num_instances:\n                assert len(masks) == num_instances\n            else:\n                num_instances = len(masks)\n        if keypoints is not None:\n            if num_instances:\n                assert len(keypoints) == num_instances\n            else:\n                num_instances = len(keypoints)\n            keypoints = self._convert_keypoints(keypoints)\n        if labels is not None:\n            assert len(labels) == num_instances\n        if assigned_colors is None:\n            assigned_colors = [random_color(rgb=True, maximum=1) for _ in range(num_instances)]\n        if num_instances == 0:\n            return self.output\n        if boxes is not None and boxes.shape[1] == 5:\n            return self.overlay_rotated_instances(\n                boxes=boxes, labels=labels, assigned_colors=assigned_colors\n            )\n\n        # Display in largest to smallest order to reduce occlusion.\n        areas = None\n        if boxes is not None:\n            areas = np.prod(boxes[:, 2:] - boxes[:, :2], axis=1)\n        elif masks is not None:\n            areas = np.asarray([x.area() for x in masks])\n\n        if areas is not None:\n            sorted_idxs = np.argsort(-areas).tolist()\n            # Re-order overlapped instances in descending order.\n            boxes = boxes[sorted_idxs] if boxes is not None else None\n            labels = [labels[k] for k in sorted_idxs] if labels is not None else None\n            masks = [masks[idx] for idx in sorted_idxs] if masks is not None else None\n            
assigned_colors = [assigned_colors[idx] for idx in sorted_idxs]\n            keypoints = keypoints[sorted_idxs] if keypoints is not None else None\n\n        for i in range(num_instances):\n            color = assigned_colors[i]\n            if boxes is not None:\n                self.draw_box(boxes[i], edge_color=color)\n\n            if masks is not None:\n                for segment in masks[i].polygons:\n                    self.draw_polygon(segment.reshape(-1, 2), color, alpha=alpha)\n\n            if labels is not None:\n                # first get a box\n                if boxes is not None:\n                    x0, y0, x1, y1 = boxes[i]\n                    text_pos = (x0, y0)  # if drawing boxes, put text on the box corner.\n                    horiz_align = \"left\"\n                elif masks is not None:\n                    # skip small mask without polygon\n                    if len(masks[i].polygons) == 0:\n                        continue\n\n                    x0, y0, x1, y1 = masks[i].bbox()\n\n                    # draw text in the center (defined by median) when box is not drawn\n                    # median is less sensitive to outliers.\n                    text_pos = np.median(masks[i].mask.nonzero(), axis=1)[::-1]\n                    horiz_align = \"center\"\n                else:\n                    continue  # drawing the box confidence for keypoints isn't very useful.\n                # for small objects, draw text at the side to avoid occlusion\n                instance_area = (y1 - y0) * (x1 - x0)\n                if (\n                    instance_area < _SMALL_OBJECT_AREA_THRESH * self.output.scale\n                    or y1 - y0 < 40 * self.output.scale\n                ):\n                    if y1 >= self.output.height - 5:\n                        text_pos = (x1, y0)\n                    else:\n                        text_pos = (x0, y1)\n\n                height_ratio = (y1 - y0) / np.sqrt(self.output.height * 
self.output.width)\n                lighter_color = self._change_color_brightness(color, brightness_factor=0.7)\n                font_size = (\n                    np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2)\n                    * 0.5\n                    * self._default_font_size\n                )\n                self.draw_text(\n                    labels[i],\n                    text_pos,\n                    color=lighter_color,\n                    horizontal_alignment=horiz_align,\n                    font_size=font_size,\n                )\n\n        # draw keypoints\n        if keypoints is not None:\n            for keypoints_per_instance in keypoints:\n                self.draw_and_connect_keypoints(keypoints_per_instance)\n\n        return self.output\n\n    def overlay_rotated_instances(self, boxes=None, labels=None, assigned_colors=None):\n        \"\"\"\n        Args:\n            boxes (ndarray): an Nx5 numpy array of\n                (x_center, y_center, width, height, angle_degrees) format\n                for the N objects in a single image.\n            labels (list[str]): the text to be displayed for each instance.\n            assigned_colors (list[matplotlib.colors]): a list of colors, where each color\n                corresponds to each mask or box in the image. 
Refer to 'matplotlib.colors'\n                for full list of formats that the colors are accepted in.\n\n        Returns:\n            output (VisImage): image object with visualizations.\n        \"\"\"\n        num_instances = len(boxes)\n\n        if assigned_colors is None:\n            assigned_colors = [random_color(rgb=True, maximum=1) for _ in range(num_instances)]\n        if num_instances == 0:\n            return self.output\n\n        # Display in largest to smallest order to reduce occlusion.\n        if boxes is not None:\n            areas = boxes[:, 2] * boxes[:, 3]\n\n        sorted_idxs = np.argsort(-areas).tolist()\n        # Re-order overlapped instances in descending order.\n        boxes = boxes[sorted_idxs]\n        labels = [labels[k] for k in sorted_idxs] if labels is not None else None\n        colors = [assigned_colors[idx] for idx in sorted_idxs]\n\n        for i in range(num_instances):\n            self.draw_rotated_box_with_label(\n                boxes[i], edge_color=colors[i], label=labels[i] if labels is not None else None\n            )\n\n        return self.output\n\n    def draw_and_connect_keypoints(self, keypoints):\n        \"\"\"\n        Draws keypoints of an instance and follows the rules for keypoint connections\n        to draw lines between appropriate keypoints. 
This follows color heuristics for\n        line color.\n\n        Args:\n            keypoints (Tensor): a tensor of shape (K, 3), where K is the number of keypoints\n                and the last dimension corresponds to (x, y, probability).\n\n        Returns:\n            output (VisImage): image object with visualizations.\n        \"\"\"\n        visible = {}\n        keypoint_names = self.metadata.get(\"keypoint_names\")\n        for idx, keypoint in enumerate(keypoints):\n\n            # draw keypoint\n            x, y, prob = keypoint\n            if prob > self.keypoint_threshold:\n                self.draw_circle((x, y), color=_RED)\n                if keypoint_names:\n                    keypoint_name = keypoint_names[idx]\n                    visible[keypoint_name] = (x, y)\n\n        if self.metadata.get(\"keypoint_connection_rules\"):\n            for kp0, kp1, color in self.metadata.keypoint_connection_rules:\n                if kp0 in visible and kp1 in visible:\n                    x0, y0 = visible[kp0]\n                    x1, y1 = visible[kp1]\n                    color = tuple(x / 255.0 for x in color)\n                    self.draw_line([x0, x1], [y0, y1], color=color)\n\n        # draw lines from nose to mid-shoulder and mid-shoulder to mid-hip\n        # Note that this strategy is specific to person keypoints.\n        # For other keypoints, it should just do nothing\n        try:\n            ls_x, ls_y = visible[\"left_shoulder\"]\n            rs_x, rs_y = visible[\"right_shoulder\"]\n            mid_shoulder_x, mid_shoulder_y = (ls_x + rs_x) / 2, (ls_y + rs_y) / 2\n        except KeyError:\n            pass\n        else:\n            # draw line from nose to mid-shoulder\n            nose_x, nose_y = visible.get(\"nose\", (None, None))\n            if nose_x is not None:\n                self.draw_line([nose_x, mid_shoulder_x], [nose_y, mid_shoulder_y], color=_RED)\n\n            try:\n                # draw line from mid-shoulder to 
mid-hip\n                lh_x, lh_y = visible[\"left_hip\"]\n                rh_x, rh_y = visible[\"right_hip\"]\n            except KeyError:\n                pass\n            else:\n                mid_hip_x, mid_hip_y = (lh_x + rh_x) / 2, (lh_y + rh_y) / 2\n                self.draw_line([mid_hip_x, mid_shoulder_x], [mid_hip_y, mid_shoulder_y], color=_RED)\n        return self.output\n\n    \"\"\"\n    Primitive drawing functions:\n    \"\"\"\n\n    def draw_text(\n        self,\n        text,\n        position,\n        *,\n        font_size=None,\n        color=\"g\",\n        horizontal_alignment=\"center\",\n        rotation=0,\n    ):\n        \"\"\"\n        Args:\n            text (str): class label\n            position (tuple): a tuple of the x and y coordinates to place text on image.\n            font_size (int, optional): font of the text. If not provided, a font size\n                proportional to the image width is calculated and used.\n            color: color of the text. 
Refer to `matplotlib.colors` for full list\n                of formats that are accepted.\n            horizontal_alignment (str): see `matplotlib.text.Text`\n            rotation: rotation angle in degrees CCW\n\n        Returns:\n            output (VisImage): image object with text drawn.\n        \"\"\"\n        if not font_size:\n            font_size = self._default_font_size\n\n        # since the text background is dark, we don't want the text to be dark\n        color = np.maximum(list(mplc.to_rgb(color)), 0.2)\n        color[np.argmax(color)] = max(0.8, np.max(color))\n\n        x, y = position\n        self.output.ax.text(\n            x,\n            y,\n            text,\n            size=font_size * self.output.scale,\n            family=\"sans-serif\",\n            bbox={\"facecolor\": \"black\", \"alpha\": 0.8, \"pad\": 0.7, \"edgecolor\": \"none\"},\n            verticalalignment=\"top\",\n            horizontalalignment=horizontal_alignment,\n            color=color,\n            zorder=10,\n            rotation=rotation,\n        )\n        return self.output\n\n    def draw_box(self, box_coord, alpha=0.5, edge_color=\"g\", line_style=\"-\"):\n        \"\"\"\n        Args:\n            box_coord (tuple): a tuple containing x0, y0, x1, y1 coordinates, where x0 and y0\n                are the coordinates of the image's top left corner. x1 and y1 are the\n                coordinates of the image's bottom right corner.\n            alpha (float): blending efficient. Smaller values lead to more transparent masks.\n            edge_color: color of the outline of the box. 
Refer to `matplotlib.colors`\n                for full list of formats that are accepted.\n            line_style (string): the string to use to create the outline of the boxes.\n\n        Returns:\n            output (VisImage): image object with box drawn.\n        \"\"\"\n        x0, y0, x1, y1 = box_coord\n        width = x1 - x0\n        height = y1 - y0\n\n        linewidth = max(self._default_font_size / 4, 1)\n\n        self.output.ax.add_patch(\n            mpl.patches.Rectangle(\n                (x0, y0),\n                width,\n                height,\n                fill=False,\n                edgecolor=edge_color,\n                linewidth=linewidth * self.output.scale,\n                alpha=alpha,\n                linestyle=line_style,\n            )\n        )\n        return self.output\n\n    def draw_rotated_box_with_label(\n        self, rotated_box, alpha=0.5, edge_color=\"g\", line_style=\"-\", label=None\n    ):\n        \"\"\"\n        Draw a rotated box with label on its top-left corner.\n\n        Args:\n            rotated_box (tuple): a tuple containing (cnt_x, cnt_y, w, h, angle),\n                where cnt_x and cnt_y are the center coordinates of the box.\n                w and h are the width and height of the box. angle represents how\n                many degrees the box is rotated CCW with regard to the 0-degree box.\n            alpha (float): blending efficient. Smaller values lead to more transparent masks.\n            edge_color: color of the outline of the box. Refer to `matplotlib.colors`\n                for full list of formats that are accepted.\n            line_style (string): the string to use to create the outline of the boxes.\n            label (string): label for rotated box. 
It will not be rendered when set to None.\n\n        Returns:\n            output (VisImage): image object with box drawn.\n        \"\"\"\n        cnt_x, cnt_y, w, h, angle = rotated_box\n        area = w * h\n        # use thinner lines when the box is small\n        linewidth = self._default_font_size / (\n            6 if area < _SMALL_OBJECT_AREA_THRESH * self.output.scale else 3\n        )\n\n        theta = angle * math.pi / 180.0\n        c = math.cos(theta)\n        s = math.sin(theta)\n        rect = [(-w / 2, h / 2), (-w / 2, -h / 2), (w / 2, -h / 2), (w / 2, h / 2)]\n        # x: left->right ; y: top->down\n        rotated_rect = [(s * yy + c * xx + cnt_x, c * yy - s * xx + cnt_y) for (xx, yy) in rect]\n        for k in range(4):\n            j = (k + 1) % 4\n            self.draw_line(\n                [rotated_rect[k][0], rotated_rect[j][0]],\n                [rotated_rect[k][1], rotated_rect[j][1]],\n                color=edge_color,\n                linestyle=\"--\" if k == 1 else line_style,\n                linewidth=linewidth,\n            )\n\n        if label is not None:\n            text_pos = rotated_rect[1]  # topleft corner\n\n            height_ratio = h / np.sqrt(self.output.height * self.output.width)\n            label_color = self._change_color_brightness(edge_color, brightness_factor=0.7)\n            font_size = (\n                np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2) * 0.5 * self._default_font_size\n            )\n            self.draw_text(label, text_pos, color=label_color, font_size=font_size, rotation=angle)\n\n        return self.output\n\n    def draw_circle(self, circle_coord, color, radius=3):\n        \"\"\"\n        Args:\n            circle_coord (list(int) or tuple(int)): contains the x and y coordinates\n                of the center of the circle.\n            color: color of the polygon. 
Refer to `matplotlib.colors` for a full list of\n                formats that are accepted.\n            radius (int): radius of the circle.\n\n        Returns:\n            output (VisImage): image object with box drawn.\n        \"\"\"\n        x, y = circle_coord\n        self.output.ax.add_patch(\n            mpl.patches.Circle(circle_coord, radius=radius, fill=True, color=color)\n        )\n        return self.output\n\n    def draw_line(self, x_data, y_data, color, linestyle=\"-\", linewidth=None):\n        \"\"\"\n        Args:\n            x_data (list[int]): a list containing x values of all the points being drawn.\n                Length of list should match the length of y_data.\n            y_data (list[int]): a list containing y values of all the points being drawn.\n                Length of list should match the length of x_data.\n            color: color of the line. Refer to `matplotlib.colors` for a full list of\n                formats that are accepted.\n            linestyle: style of the line. Refer to `matplotlib.lines.Line2D`\n                for a full list of formats that are accepted.\n            linewidth (float or None): width of the line. 
When it's None,\n                a default value will be computed and used.\n\n        Returns:\n            output (VisImage): image object with line drawn.\n        \"\"\"\n        if linewidth is None:\n            linewidth = self._default_font_size / 3\n        linewidth = max(linewidth, 1)\n        self.output.ax.add_line(\n            mpl.lines.Line2D(\n                x_data,\n                y_data,\n                linewidth=linewidth * self.output.scale,\n                color=color,\n                linestyle=linestyle,\n            )\n        )\n        return self.output\n\n    def draw_binary_mask(\n        self, binary_mask, color=None, *, edge_color=None, text=None, alpha=0.7, area_threshold=10\n    ):\n        \"\"\"\n        Args:\n            binary_mask (ndarray): numpy array of shape (H, W), where H is the image height and\n                W is the image width. Each value in the array is either a 0 or 1 value of uint8\n                type.\n            color: color of the mask. Refer to `matplotlib.colors` for a full list of\n                formats that are accepted. If None, will pick a random color.\n            edge_color: color of the polygon edges. Refer to `matplotlib.colors` for a\n                full list of formats that are accepted.\n            text (str): if None, will be drawn on the object\n            alpha (float): blending efficient. 
Smaller values lead to more transparent masks.\n            area_threshold (float): a connected component smaller than this area will not be shown.\n\n        Returns:\n            output (VisImage): image object with mask drawn.\n        \"\"\"\n        if color is None:\n            color = random_color(rgb=True, maximum=1)\n        color = mplc.to_rgb(color)\n\n        has_valid_segment = False\n        binary_mask = binary_mask.astype(\"uint8\")  # opencv needs uint8\n        mask = GenericMask(binary_mask, self.output.height, self.output.width)\n        shape2d = (binary_mask.shape[0], binary_mask.shape[1])\n\n        if not mask.has_holes:\n            # draw polygons for regular masks\n            for segment in mask.polygons:\n                area = mask_util.area(mask_util.frPyObjects([segment], shape2d[0], shape2d[1]))\n                if area < (area_threshold or 0):\n                    continue\n                has_valid_segment = True\n                segment = segment.reshape(-1, 2)\n                self.draw_polygon(segment, color=color, edge_color=edge_color, alpha=alpha)\n        else:\n            # TODO: Use Path/PathPatch to draw vector graphics:\n            # https://stackoverflow.com/questions/8919719/how-to-plot-a-complex-polygon\n            rgba = np.zeros(shape2d + (4,), dtype=\"float32\")\n            rgba[:, :, :3] = color\n            rgba[:, :, 3] = (mask.mask == 1).astype(\"float32\") * alpha\n            has_valid_segment = True\n            self.output.ax.imshow(rgba, extent=(0, self.output.width, self.output.height, 0))\n\n        if text is not None and has_valid_segment:\n            lighter_color = self._change_color_brightness(color, brightness_factor=0.7)\n            self._draw_text_in_mask(binary_mask, text, lighter_color)\n        return self.output\n\n    def draw_soft_mask(self, soft_mask, color=None, *, text=None, alpha=0.5):\n        \"\"\"\n        Args:\n            soft_mask (ndarray): float array of shape (H, W), 
each value in [0, 1].\n            color: color of the mask. Refer to `matplotlib.colors` for a full list of\n                formats that are accepted. If None, will pick a random color.\n            text (str): if None, will be drawn on the object\n            alpha (float): blending efficient. Smaller values lead to more transparent masks.\n\n        Returns:\n            output (VisImage): image object with mask drawn.\n        \"\"\"\n        if color is None:\n            color = random_color(rgb=True, maximum=1)\n        color = mplc.to_rgb(color)\n\n        shape2d = (soft_mask.shape[0], soft_mask.shape[1])\n        rgba = np.zeros(shape2d + (4,), dtype=\"float32\")\n        rgba[:, :, :3] = color\n        rgba[:, :, 3] = soft_mask * alpha\n        self.output.ax.imshow(rgba, extent=(0, self.output.width, self.output.height, 0))\n\n        if text is not None:\n            lighter_color = self._change_color_brightness(color, brightness_factor=0.7)\n            binary_mask = (soft_mask > 0.5).astype(\"uint8\")\n            self._draw_text_in_mask(binary_mask, text, lighter_color)\n        return self.output\n\n    def draw_polygon(self, segment, color, edge_color=None, alpha=0.5):\n        \"\"\"\n        Args:\n            segment: numpy array of shape Nx2, containing all the points in the polygon.\n            color: color of the polygon. Refer to `matplotlib.colors` for a full list of\n                formats that are accepted.\n            edge_color: color of the polygon edges. Refer to `matplotlib.colors` for a\n                full list of formats that are accepted. If not provided, a darker shade\n                of the polygon color will be used instead.\n            alpha (float): blending efficient. 
Smaller values lead to more transparent masks.\n\n        Returns:\n            output (VisImage): image object with polygon drawn.\n        \"\"\"\n        if edge_color is None:\n            # make edge color darker than the polygon color\n            if alpha > 0.8:\n                edge_color = self._change_color_brightness(color, brightness_factor=-0.7)\n            else:\n                edge_color = color\n        edge_color = mplc.to_rgb(edge_color) + (1,)\n\n        polygon = mpl.patches.Polygon(\n            segment,\n            fill=True,\n            facecolor=mplc.to_rgb(color) + (alpha,),\n            edgecolor=edge_color,\n            linewidth=max(self._default_font_size // 15 * self.output.scale, 1),\n        )\n        self.output.ax.add_patch(polygon)\n        return self.output\n\n    \"\"\"\n    Internal methods:\n    \"\"\"\n\n    def _jitter(self, color):\n        \"\"\"\n        Randomly modifies given color to produce a slightly different color than the color given.\n\n        Args:\n            color (tuple[double]): a tuple of 3 elements, containing the RGB values of the color\n                picked. The values in the list are in the [0.0, 1.0] range.\n\n        Returns:\n            jittered_color (tuple[double]): a tuple of 3 elements, containing the RGB values of the\n                color after being jittered. 
The values in the list are in the [0.0, 1.0] range.\n        \"\"\"\n        color = mplc.to_rgb(color)\n        # np.random.seed(0)\n        vec = np.random.rand(3)\n        # better to do it in another color space\n        vec = vec / np.linalg.norm(vec) * 0.5\n        res = np.clip(vec + color, 0, 1)\n        return tuple(res)\n\n    def _create_grayscale_image(self, mask=None):\n        \"\"\"\n        Create a grayscale version of the original image.\n        The colors in masked area, if given, will be kept.\n        \"\"\"\n        img_bw = self.img.astype(\"f4\").mean(axis=2)\n        img_bw = np.stack([img_bw] * 3, axis=2)\n        if mask is not None:\n            img_bw[mask] = self.img[mask]\n        return img_bw\n\n    def _change_color_brightness(self, color, brightness_factor):\n        \"\"\"\n        Depending on the brightness_factor, gives a lighter or darker color i.e. a color with\n        less or more saturation than the original color.\n\n        Args:\n            color: color of the polygon. Refer to `matplotlib.colors` for a full list of\n                formats that are accepted.\n            brightness_factor (float): a value in [-1.0, 1.0] range. A lightness factor of\n                0 will correspond to no change, a factor in [-1.0, 0) range will result in\n                a darker color and a factor in (0, 1.0] range will result in a lighter color.\n\n        Returns:\n            modified_color (tuple[double]): a tuple containing the RGB values of the\n                modified color. 
Each value in the tuple is in the [0.0, 1.0] range.\n        \"\"\"\n        assert brightness_factor >= -1.0 and brightness_factor <= 1.0\n        color = mplc.to_rgb(color)\n        polygon_color = colorsys.rgb_to_hls(*mplc.to_rgb(color))\n        modified_lightness = polygon_color[1] + (brightness_factor * polygon_color[1])\n        modified_lightness = 0.0 if modified_lightness < 0.0 else modified_lightness\n        modified_lightness = 1.0 if modified_lightness > 1.0 else modified_lightness\n        modified_color = colorsys.hls_to_rgb(polygon_color[0], modified_lightness, polygon_color[2])\n        return modified_color\n\n    def _convert_boxes(self, boxes):\n        \"\"\"\n        Convert different format of boxes to an NxB array, where B = 4 or 5 is the box dimension.\n        \"\"\"\n        if isinstance(boxes, Boxes) or isinstance(boxes, RotatedBoxes):\n            return boxes.tensor.detach().numpy()\n        else:\n            return np.asarray(boxes)\n\n    def _convert_masks(self, masks_or_polygons):\n        \"\"\"\n        Convert different format of masks or polygons to a tuple of masks and polygons.\n\n        Returns:\n            list[GenericMask]:\n        \"\"\"\n\n        m = masks_or_polygons\n        if isinstance(m, PolygonMasks):\n            m = m.polygons\n        if isinstance(m, BitMasks):\n            m = m.tensor.numpy()\n        if isinstance(m, torch.Tensor):\n            m = m.numpy()\n        ret = []\n        for x in m:\n            if isinstance(x, GenericMask):\n                ret.append(x)\n            else:\n                ret.append(GenericMask(x, self.output.height, self.output.width))\n        return ret\n\n    def _draw_text_in_mask(self, binary_mask, text, color):\n        \"\"\"\n        Find proper places to draw text given a binary mask.\n        \"\"\"\n        # TODO sometimes drawn on wrong objects. 
the heuristics here can improve.\n        _num_cc, cc_labels, stats, centroids = cv2.connectedComponentsWithStats(binary_mask, 8)\n        if stats[1:, -1].size == 0:\n            return\n        largest_component_id = np.argmax(stats[1:, -1]) + 1\n\n        # draw text on the largest component, as well as other very large components.\n        for cid in range(1, _num_cc):\n            if cid == largest_component_id or stats[cid, -1] > _LARGE_MASK_AREA_THRESH:\n                # median is more stable than centroid\n                # center = centroids[largest_component_id]\n                center = np.median((cc_labels == cid).nonzero(), axis=1)[::-1]\n                self.draw_text(text, center, color=color)\n\n    def _convert_keypoints(self, keypoints):\n        if isinstance(keypoints, Keypoints):\n            keypoints = keypoints.tensor\n        keypoints = np.asarray(keypoints)\n        return keypoints\n\n    def get_output(self):\n        \"\"\"\n        Returns:\n            output (VisImage): the image output containing the visualizations added\n            to the image.\n        \"\"\"\n        return self.output"
  }
]