Repository: UX-Decoder/LLaVA-Grounding Branch: main Commit: 668b7cc3d536 Files: 257 Total size: 2.6 MB Directory structure: gitextract_hcvjbjkn/ ├── .gitignore ├── LICENSE ├── README.md ├── configs/ │ ├── openseed/ │ │ ├── openseed_swint_lang_joint.yaml │ │ ├── openseed_swint_lang_joint_2st.yaml │ │ └── openseed_swint_lang_joint_2st_visual_prompt.yaml │ └── semsam/ │ └── visual_prompt_encoder.yaml ├── datasets_os/ │ ├── __init__.py │ ├── build.py │ ├── custom_dataset_dataloader.py │ ├── dataset_mappers/ │ │ ├── __init__.py │ │ ├── coco_instance_new_baseline_dataset_mapper.py │ │ ├── coco_instruct_grounding_dataset_interactive_mapper.py │ │ ├── coco_instruct_grounding_dataset_mapper.py │ │ ├── coco_interactive_panoptic_new_baseline_dataset_mapper.py │ │ ├── coco_panoptic_interactive_dataset_mapper.py │ │ ├── coco_panoptic_new_baseline_dataset_mapper.py │ │ ├── flickr_instance_new_baseline_dataset_mapper.py │ │ ├── flickr_instance_new_baseline_dataset_mapper_.py │ │ ├── flickr_instance_new_baseline_dataset_mapper_end.py │ │ ├── flickr_new_baseline_dataset_mapper.py │ │ ├── inference_mapper_with_gt.py │ │ ├── llava_dataset_mapper.py │ │ ├── refcoco_dataset_mapper.py │ │ └── vg_instance_new_baseline_dataset_mapper.py │ ├── refer.py │ ├── registration/ │ │ ├── __init__.py │ │ ├── register_coco_instruct_grounding_dataset.py │ │ ├── register_coco_panoptic_annos_grounding_interactive.py │ │ ├── register_flickr_dataset.py │ │ └── register_vg_dataset.py │ └── semseg_loader.py ├── docs/ │ └── MODEL_ZOO.md ├── gradio_demo/ │ ├── LLaVA_G_Demo.py │ └── __init__.py ├── llava/ │ ├── __init__.py │ ├── constants.py │ ├── conversation.py │ ├── eval/ │ │ ├── LLaVA_G_Eval.py │ │ ├── eval_gpt_review.py │ │ ├── eval_gpt_review_bench.py │ │ ├── eval_gpt_review_visual.py │ │ ├── eval_gpt_review_visual2.py │ │ ├── eval_science_qa.py │ │ ├── eval_science_qa_gpt4.py │ │ ├── eval_science_qa_gpt4_requery.py │ │ ├── generate_webpage_data_from_table.py │ │ ├── llava_mapper.py │ │ ├── model_qa.py 
│ │ ├── model_vqa.py │ │ ├── model_vqa_science.py │ │ ├── qa_baseline_gpt35.py │ │ ├── run_llava.py │ │ ├── summarize_gpt_review.py │ │ └── webpage/ │ │ ├── index.html │ │ ├── script.js │ │ └── styles.css │ ├── mm_utils.py │ ├── model/ │ │ ├── __init__.py │ │ ├── apply_delta.py │ │ ├── builder.py │ │ ├── consolidate.py │ │ ├── language_model/ │ │ │ ├── llava_llama.py │ │ │ ├── llava_llama_gd.py │ │ │ ├── llava_mpt.py │ │ │ └── mpt/ │ │ │ ├── adapt_tokenizer.py │ │ │ ├── attention.py │ │ │ ├── blocks.py │ │ │ ├── configuration_mpt.py │ │ │ ├── custom_embedding.py │ │ │ ├── flash_attn_triton.py │ │ │ ├── hf_prefixlm_converter.py │ │ │ ├── meta_init_context.py │ │ │ ├── modeling_mpt.py │ │ │ ├── norm.py │ │ │ └── param_init_fns.py │ │ ├── llava_arch.py │ │ ├── make_delta.py │ │ ├── multimodal_encoder/ │ │ │ ├── builder.py │ │ │ └── clip_encoder.py │ │ ├── openseed/ │ │ │ ├── BaseModel.py │ │ │ ├── __init__.py │ │ │ ├── architectures/ │ │ │ │ ├── __init__.py │ │ │ │ ├── build.py │ │ │ │ ├── openseed_model.py │ │ │ │ ├── openseed_model_decouple_train.py │ │ │ │ └── registry.py │ │ │ ├── backbone/ │ │ │ │ ├── __init__.py │ │ │ │ ├── backbone.py │ │ │ │ ├── build.py │ │ │ │ ├── focal.py │ │ │ │ ├── focal_dw.py │ │ │ │ ├── registry.py │ │ │ │ └── swin.py │ │ │ ├── body/ │ │ │ │ ├── __init__.py │ │ │ │ ├── build.py │ │ │ │ ├── decoder/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── build.py │ │ │ │ │ ├── modules.py │ │ │ │ │ ├── openseed_decoder.py │ │ │ │ │ ├── openseed_decoder_decouple.py │ │ │ │ │ ├── registry.py │ │ │ │ │ └── utils/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── dino_decoder.py │ │ │ │ │ └── utils.py │ │ │ │ ├── encoder/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── build.py │ │ │ │ │ ├── encoder_deform.py │ │ │ │ │ ├── ops/ │ │ │ │ │ │ ├── functions/ │ │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ │ └── ms_deform_attn_func.py │ │ │ │ │ │ ├── make.sh │ │ │ │ │ │ ├── modules/ │ │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ │ └── ms_deform_attn.py │ │ │ │ │ │ ├── setup.py │ │ │ │ │ │ 
├── src/ │ │ │ │ │ │ │ ├── cpu/ │ │ │ │ │ │ │ │ ├── ms_deform_attn_cpu.cpp │ │ │ │ │ │ │ │ └── ms_deform_attn_cpu.h │ │ │ │ │ │ │ ├── cuda/ │ │ │ │ │ │ │ │ ├── ms_deform_attn_cuda.cu │ │ │ │ │ │ │ │ ├── ms_deform_attn_cuda.h │ │ │ │ │ │ │ │ └── ms_deform_im2col_cuda.cuh │ │ │ │ │ │ │ ├── ms_deform_attn.h │ │ │ │ │ │ │ └── vision.cpp │ │ │ │ │ │ └── test.py │ │ │ │ │ ├── registry.py │ │ │ │ │ └── transformer_encoder_fpn.py │ │ │ │ ├── openseed_head.py │ │ │ │ ├── registry.py │ │ │ │ └── transformer_blocks.py │ │ │ ├── language/ │ │ │ │ ├── LangEncoder/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── build.py │ │ │ │ │ ├── registry.py │ │ │ │ │ └── transformer.py │ │ │ │ ├── __init__.py │ │ │ │ ├── build.py │ │ │ │ ├── encoder.py │ │ │ │ ├── registry.py │ │ │ │ └── vlpencoder.py │ │ │ ├── modules/ │ │ │ │ ├── __init__.py │ │ │ │ ├── attention.py │ │ │ │ ├── criterion.py │ │ │ │ ├── matcher.py │ │ │ │ ├── point_features.py │ │ │ │ ├── position_encoding.py │ │ │ │ └── postprocessing.py │ │ │ └── utils/ │ │ │ ├── __init__.py │ │ │ ├── box_ops.py │ │ │ ├── config.py │ │ │ └── misc.py │ │ ├── semsam/ │ │ │ ├── BaseModel.py │ │ │ ├── __init__.py │ │ │ ├── architectures/ │ │ │ │ ├── __init__.py │ │ │ │ ├── build.py │ │ │ │ ├── idino_model_partwhole_all_llm_ref_feats_all_det_pretrainv1.py │ │ │ │ └── registry.py │ │ │ ├── backbone/ │ │ │ │ ├── __init__.py │ │ │ │ ├── backbone.py │ │ │ │ ├── build.py │ │ │ │ ├── focal.py │ │ │ │ ├── focal_dw.py │ │ │ │ ├── registry.py │ │ │ │ ├── swin.py │ │ │ │ └── swin_new.py │ │ │ ├── body/ │ │ │ │ ├── __init__.py │ │ │ │ ├── build.py │ │ │ │ ├── decoder/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── build.py │ │ │ │ │ ├── idino_decoder_no_iou_token_partwhole_all_llm.py │ │ │ │ │ ├── modules.py │ │ │ │ │ ├── registry.py │ │ │ │ │ └── utils/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── dino_decoder.py │ │ │ │ │ └── utils.py │ │ │ │ ├── encoder/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── build.py │ │ │ │ │ ├── encoder_deform.py │ │ │ │ │ ├── ops/ │ │ │ │ │ │ ├── 
functions/ │ │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ │ └── ms_deform_attn_func.py │ │ │ │ │ │ ├── make.sh │ │ │ │ │ │ ├── modules/ │ │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ │ └── ms_deform_attn.py │ │ │ │ │ │ ├── setup.py │ │ │ │ │ │ ├── src/ │ │ │ │ │ │ │ ├── cpu/ │ │ │ │ │ │ │ │ ├── ms_deform_attn_cpu.cpp │ │ │ │ │ │ │ │ └── ms_deform_attn_cpu.h │ │ │ │ │ │ │ ├── cuda/ │ │ │ │ │ │ │ │ ├── ms_deform_attn_cuda.cu │ │ │ │ │ │ │ │ ├── ms_deform_attn_cuda.h │ │ │ │ │ │ │ │ └── ms_deform_im2col_cuda.cuh │ │ │ │ │ │ │ ├── ms_deform_attn.h │ │ │ │ │ │ │ └── vision.cpp │ │ │ │ │ │ └── test.py │ │ │ │ │ ├── registry.py │ │ │ │ │ └── transformer_encoder_fpn.py │ │ │ │ ├── openseed_head.py │ │ │ │ ├── registry.py │ │ │ │ └── transformer_blocks.py │ │ │ ├── language/ │ │ │ │ ├── LangEncoder/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── build.py │ │ │ │ │ ├── registry.py │ │ │ │ │ └── transformer.py │ │ │ │ ├── __init__.py │ │ │ │ ├── build.py │ │ │ │ ├── encoder.py │ │ │ │ ├── fixencoder.py │ │ │ │ ├── llama_encoder.py │ │ │ │ ├── loss.py │ │ │ │ ├── misc.py │ │ │ │ ├── modeling_llama_os.py │ │ │ │ ├── registry.py │ │ │ │ └── vlpencoder.py │ │ │ ├── modules/ │ │ │ │ ├── __init__.py │ │ │ │ ├── attention.py │ │ │ │ ├── criterion_id_llm.py │ │ │ │ ├── hooks.py │ │ │ │ ├── matcher.py │ │ │ │ ├── point_features.py │ │ │ │ ├── position_encoding.py │ │ │ │ └── postprocessing.py │ │ │ └── utils/ │ │ │ ├── __init__.py │ │ │ ├── box_ops.py │ │ │ ├── config.py │ │ │ └── misc.py │ │ └── utils.py │ ├── serve/ │ │ ├── __init__.py │ │ ├── cli.py │ │ ├── controller.py │ │ ├── gradio_web_server.py │ │ ├── register_worker.py │ │ └── test_message.py │ ├── train/ │ │ ├── llama_flash_attn_monkey_patch.py │ │ ├── llava_trainer.py │ │ ├── llava_trainer_gd.py │ │ ├── llava_trainer_joint_train.py │ │ ├── train.py │ │ ├── train_grounding_1st.py │ │ ├── train_joint_1st.py │ │ ├── train_joint_2st.py │ │ ├── train_joint_2st_interactive_refcoco_coco_instruction.py │ │ └── train_mem.py │ └── utils.py ├── 
pyproject.toml ├── scripts/ │ ├── convert_sqa_to_llava.py │ ├── convert_sqa_to_llava_base_prompt.py │ ├── finetune.sh │ ├── finetune_visual_prompt.sh │ ├── merge_lora_weights.py │ └── pretrain_joint.sh └── utils/ ├── Config.py ├── __init__.py ├── arguments.py ├── constants.py ├── constants_ori.py ├── dist.py ├── distributed.py ├── misc.py ├── model.py ├── nms.py ├── prompt_engineering.py ├── utils.py └── visualizer.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ *.err *.out *.pyc wandb /data_preparation/vis_results/ /data_preparation/vis_results_new/ /LLAVA_Stage1_Pretrained/ /work_dirs/ /llava.egg-info/ /data_preparation/data/ /vis_results/ model_worker* /playground/ *.jsonl *.pth gradio_demo/tmp_files llava_bench_results symmary_results eval_gpt4 vis_results_pdf_precision vis_results_pdf_recall output/ datasets/ output datasets *.log *.json __pycache__/ */__pycache__ */*/__pycache__ */*/*/__pycache__ */*/*/*/__pycache__ gradio_demo/examples/*.mp4 ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. 
For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. 
For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. 
If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. 
You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. 
Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: README.md ================================================ 🌋 LLaVA-Grounding: Grounded Visual Chat with Large Multimodal Models ======== [[Project Page](https://llava-vl.github.io/llava-grounding)] [[Arxiv](https://arxiv.org/abs/2312.02949)] [[Demo](https://llava-grounding.deepdataspace.com/ )] [[Model Zoo](https://github.com/UX-Decoder/LLaVA-Grounding/blob/main/docs/MODEL_ZOO.md)] ## :fire: News [2024/1/14] Our training code is released. [2023/12/6] Our paper is available on arXiv. ## Contents - [🌋 LLaVA-Grounding: Grounded Visual Chat with Large Multimodal Models](#-llava-grounding-grounded-visual-chat-with-large-multimodal-models) - [:fire: News](#fire-news) - [Contents](#contents) - [Install](#install) - [LLaVA-Grounding Weights](#llava-grounding-weights) - [Demo](#demo) - [Training data](#training-data) - [Flickr30k](#flickr30k) - [COCO](#coco) - [LLaVA](#llava) - [Training](#training) - [Citation](#citation) ### Install 1. Clone this repository and navigate to the LLaVA-Grounding folder: ```shell git clone https://github.com/UX-Decoder/LLaVA-Grounding.git cd LLaVA-Grounding ``` 2. Install required packages: ``` conda create -n llava python=3.10 -y conda activate llava pip install --upgrade pip # enable PEP 660 support pip install -e . ``` 3.
Install additional packages for training: ``` pip install -e ".[train]" pip install flash-attn --no-build-isolation ``` 4. Install packages necessary for [OpenSeeD](https://github.com/IDEA-Research/OpenSeeD) and [Semantic-SAM](https://github.com/UX-Decoder/Semantic-SAM). ### LLaVA-Grounding Weights Please check out our [Model Zoo](https://github.com/UX-Decoder/LLaVA-Grounding/blob/main/docs/MODEL_ZOO.md) for all public LLaVA-Grounding checkpoints, and the instructions on how to use the weights. ### Demo After downloading the model weights, run the following commands to launch the demo on your own machine. ```shell CUDA_VISIBLE_DEVICES=0 python gradio_demo/LLaVA_G_Demo.py --path_vision_cfg path_to_vision_cfg --path_inter_cfg path_to_inter_cfg --model_path path_to_ckpt_dir # for example, after downloading weights into checkpoints/llava_grounding CUDA_VISIBLE_DEVICES=0 python gradio_demo/LLaVA_G_Demo.py --path_vision_cfg configs/openseed/openseed_swint_lang_joint_2st_visual_prompt.yaml --path_inter_cfg configs/semsam/visual_prompt_encoder.yaml --model_path checkpoints/llava_grounding ``` Please refer to our [Online Demo](https://llava-grounding.deepdataspace.com/) for a more detailed user guide.
### Training data ```text data ├── flickr30k_entities │ ├── train/ │ ├── val/ │ ├── annotations │ ├──final_flickr_separateGT_train.json │ ├──final_flickr_separateGT_val.json ├── coco │ ├── train2014/ │ ├── train2017/ │ ├── panoptic_train2017/ │ ├── panoptic_semseg_train2017/ │ ├── annotations │ │ ├──instances_train2017.json │ │ ├──instances_train2017_gvc.json │ │ ├──grounded_visual_chat_data.json │ │ ├──instances_train2014_filter.json │ │ ├──panoptic_train2017_filter.json │ │ ├──grounding_train2017.json ├── llava │ ├── annotations │ ├── cap600k_brackets_all.json │ ├── llava_instruct_150k.json │ ├── llava_instruct_150k_visual_prompt.json ``` #### Flickr30k Please refer to [MDETR's pre-processed flickr30k data](https://github.com/ashkamath/mdetr/blob/main/.github/flickr.md). #### COCO Please download the COCO train2014 and train2017 images, as well as the panoptic segmentation and semantic segmentation data. Other annotations can be downloaded [here](https://github.com/UX-Decoder/LLaVA-Grounding/releases/tag/train_data). #### LLaVA The processed annotations can be downloaded [here](https://github.com/UX-Decoder/LLaVA-Grounding/releases/tag/train_data).
### Training Stage 1 ```shell bash scripts/pretrain_joint.sh ``` Stage 2 ```shell bash scripts/finetune.sh ``` Stage 3 ```shell bash scripts/finetune_visual_prompt.sh ``` ### Citation If you find LLaVA-Grounding useful for your research and applications, please cite using this BibTeX: ```bibtex @misc{zhang2023llavagrounding, title={LLaVA-Grounding: Grounded Visual Chat with Large Multimodal Models}, author={Hao Zhang and Hongyang Li and Feng Li and Tianhe Ren and Xueyan Zou and Shilong Liu and Shijia Huang and Jianfeng Gao and Lei Zhang and Chunyuan Li and Jianwei Yang}, year={2023}, booktitle={arXiv} } @misc{liu2023llava, title={Visual Instruction Tuning}, author={Liu, Haotian and Li, Chunyuan and Wu, Qingyang and Lee, Yong Jae}, publisher={arXiv:2304.08485}, year={2023} } ``` ================================================ FILE: configs/openseed/openseed_swint_lang_joint.yaml ================================================ # -------------------------------------------------------- # X-Decoder -- Generalized Decoding for Pixel, Image, and Language # Copyright (c) 2022 Microsoft # Licensed under The MIT License [see LICENSE for details] # Written by Xueyan Zou (xueyan@cs.wisc.edu) # -------------------------------------------------------- ################## # Task settings ################## WEIGHT: '' PORT: 53711 VERBOSE: true #OUTPUT_DIR: '../../data/output/test' inference_only: true OUTPUT_DIR: '../../data/output/test' clip: true # misc LOADER: JOINT: True KEY_DATASET: 'flickr' # model MODEL: NAME: openseed_model HEAD: openseed_head MASK_ON: false KEYPOINT_ON: false LOAD_PROPOSALS: false DIM_PROJ: 4096 BACKBONE_DIM: 768 BACKGROUND: False WEIGHTS: '' TEXT: ARCH: encoder NAME: transformer TOKENIZER: clip CONTEXT_LENGTH: 18 # 18 WIDTH: 512 HEADS: 8 LAYERS: 12 AUTOGRESSIVE: True BACKBONE: NAME: swin PRETRAINED: 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth' LOAD_PRETRAINED: true SWIN: PRETRAIN_IMG_SIZE: 224
PATCH_SIZE: 4 EMBED_DIM: 96 DEPTHS: [ 2, 2, 6, 2 ] NUM_HEADS: [ 3, 6, 12, 24 ] WINDOW_SIZE: 7 MLP_RATIO: 4.0 QKV_BIAS: true QK_SCALE: ~ DROP_RATE: 0.0 ATTN_DROP_RATE: 0.0 DROP_PATH_RATE: 0.3 APE: false PATCH_NORM: true USE_CHECKPOINT: false OUT_FEATURES: [ 'res2', 'res3', 'res4', 'res5' ] ENCODER: NAME: encoder_deform IGNORE_VALUE: 255 NUM_CLASSES: 133 LOSS_WEIGHT: 1.0 CONVS_DIM: 256 MASK_DIM: 256 NORM: "GN" IN_FEATURES: [ "res2", "res3", "res4", "res5" ] DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES: [ "res3", "res4", "res5" ] COMMON_STRIDE: 4 TRANSFORMER_ENC_LAYERS: 6 TOTAL_NUM_FEATURE_LEVELS: 4 NUM_FEATURE_LEVELS: 3 FEATURE_ORDER: "low2high" DECODER: NAME: openseed_decoder TRANSFORMER_IN_FEATURE: "multi_scale_pixel_decoder" MASK: True BOX: True GROUNDING: ENABLED: False MAX_LEN: 5 TEXT_WEIGHT: 2.0 CLASS_WEIGHT: 0.5 CAPTION: ENABLED: False PHRASE_PROB: 0.0 SIM_THRES: 0.95 CAPTIONING: ENABLED: False STEP: 50 RETRIEVAL: ENABLED: False DIM_IMG: 768 ENSEMBLE: True OPENIMAGE: ENABLED: False NEGATIVE_SAMPLES: 5 GROUNDING: ENABLED: False MAX_LEN: 5 DEEP_SUPERVISION: True NO_OBJECT_WEIGHT: 0.1 CLASS_WEIGHT: 4.0 MASK_WEIGHT: 5.0 DICE_WEIGHT: 5.0 BOX_WEIGHT: 5.0 GIOU_WEIGHT: 2.0 COST_CLASS_WEIGHT: 4.0 COST_DICE_WEIGHT: 5.0 COST_MASK_WEIGHT: 5.0 COST_BOX_WEIGHT: 5.0 COST_GIOU_WEIGHT: 2.0 HIDDEN_DIM: 256 NUM_OBJECT_QUERIES: 300 NHEADS: 8 DROPOUT: 0.0 DIM_FEEDFORWARD: 2048 ENC_LAYERS: 0 PRE_NORM: False ENFORCE_INPUT_PROJ: False SIZE_DIVISIBILITY: 32 DEC_LAYERS: 9 # 9 decoder layers, add one for the loss on learnable query TRAIN_NUM_POINTS: 12544 OVERSAMPLE_RATIO: 3.0 IMPORTANCE_SAMPLE_RATIO: 0.75 TWO_STAGE: True INITIALIZE_BOX_TYPE: 'no' DN: seg DN_NOISE_SCALE: 0.4 DN_NUM: 100 INITIAL_PRED: True LEARN_TGT: False TOTAL_NUM_FEATURE_LEVELS: 4 SEMANTIC_CE_LOSS: False PANO_BOX_LOSS: False COCO: True O365: False TEST: SEMANTIC_ON: True INSTANCE_ON: True PANOPTIC_ON: True OVERLAP_THRESHOLD: 0.8 OBJECT_MASK_THRESHOLD: 0.25 SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE: false 
TEST_FOUCUS_ON_BOX: False PANO_TRANSFORM_EVAL: True PANO_TEMPERATURE: 0.06 TEST: EVAL_PERIOD: 500000 PRECISE_BN: NUM_ITER: 1 ENABLED: False AUG: ENABLED: False SAM: INPUT: MIN_SIZE_TEST: 800 MAX_SIZE_TEST: 1333 IMAGE_SIZE: 1024 MIN_SCALE: 0.99 MAX_SCALE: 1.01 DATASET_MAPPER_NAME: "sam" IGNORE_VALUE: 255 COLOR_AUG_SSD: False SIZE_DIVISIBILITY: 32 RANDOM_FLIP: "horizontal" MASK_FORMAT: "polygon" FORMAT: "RGB" CROP: ENABLED: True DATASET: DATASET: 'sam' TEST: DETECTIONS_PER_IMAGE: 100 NAME: coco_eval IOU_TYPE: ['bbox', 'segm'] USE_MULTISCALE: false BATCH_SIZE_TOTAL: 8 MODEL_FILE: '' AUG: ENABLED: False TRAIN: BATCH_SIZE_TOTAL: 1 BATCH_SIZE_PER_GPU: 1 SHUFFLE: true DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 4 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True COCO: INPUT: MIN_SIZE_TEST: 800 MAX_SIZE_TEST: 1333 IMAGE_SIZE: 1024 MIN_SCALE: 1.0 MAX_SCALE: 1.0 DATASET_MAPPER_NAME: "coco_ref_panoptic_lsj" IGNORE_VALUE: 255 COLOR_AUG_SSD: False SIZE_DIVISIBILITY: 32 RANDOM_FLIP: "horizontal" MASK_FORMAT: "polygon" FORMAT: "RGB" CROP: ENABLED: True DATASET: DATASET: 'coco' TEST: DETECTIONS_PER_IMAGE: 100 NAME: coco_eval IOU_TYPE: ['bbox', 'segm'] USE_MULTISCALE: false BATCH_SIZE_TOTAL: 1 MODEL_FILE: '' AUG: ENABLED: False TRAIN: BATCH_SIZE_TOTAL: 1 BATCH_SIZE_PER_GPU: 1 SHUFFLE: true DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 16 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True VLP: INPUT: IMAGE_SIZE: 224 DATASET_MAPPER_NAME: "vlpretrain" IGNORE_VALUE: 255 COLOR_AUG_SSD: False SIZE_DIVISIBILITY: 32 MASK_FORMAT: "polygon" FORMAT: "RGB" CROP: ENABLED: True TRAIN: BATCH_SIZE_TOTAL: 2 BATCH_SIZE_PER_GPU: 2 TEST: BATCH_SIZE_TOTAL: 256 DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 16 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True INPUT: PIXEL_MEAN: [123.675, 116.280, 103.530] PIXEL_STD: [58.395, 57.120, 57.375] DATASETS: TRAIN: 
["flickr_train","coco_2017_train_panoptic_ref_full_with_sem_seg_caption_grounding"] TEST: ["flickr_val"] CLASS_CONCAT: false SIZE_DIVISIBILITY: 32 PROPOSAL_FILES_TRAIN: [] DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 16 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True # Detectron2 training config for optimizer and lr scheduler SOLVER: BASE_LR_END: 0.0 MOMENTUM: 0.9 NESTEROV: False CHECKPOINT_PERIOD: 5000 IMS_PER_BATCH: 1 REFERENCE_WORLD_SIZE: 0 BIAS_LR_FACTOR: 1.0 WEIGHT_DECAY_BIAS: None # original BASE_LR: 0.0001 STEPS: [327778, 355092] MAX_ITER: 368750 GAMMA: 0.1 WARMUP_FACTOR: 1.0 WARMUP_ITERS: 10 WARMUP_METHOD: "linear" WEIGHT_DECAY: 0.05 OPTIMIZER: "ADAMW" LR_SCHEDULER_NAME: "WarmupMultiStepLR" LR_MULTIPLIER: backbone: 0.1 lang_encoder: 0.1 WEIGHT_DECAY_NORM: 0.0 WEIGHT_DECAY_EMBED: 0.0 CLIP_GRADIENTS: ENABLED: True CLIP_TYPE: "full_model" CLIP_VALUE: 0.01 NORM_TYPE: 2.0 AMP: ENABLED: True # Evaluation Dataset ADE20K: INPUT: MIN_SIZE_TRAIN: [320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 1024, 1088, 1152, 1216, 1280] MIN_SIZE_TRAIN_SAMPLING: "choice" MIN_SIZE_TEST: 640 MAX_SIZE_TRAIN: 2560 MAX_SIZE_TEST: 2560 MASK_FORMAT: "polygon" CROP: ENABLED: True TYPE: "absolute" SIZE: [640, 640] SINGLE_CATEGORY_MAX_AREA: 1.0 IGNORE_VALUE: 255 COLOR_AUG_SSD: True SIZE_DIVISIBILITY: 640 # used in dataset mapper DATASET_MAPPER_NAME: "mask_former_panoptic" FORMAT: "RGB" DATASET: DATASET: 'ade' TRAIN: ASPECT_RATIO_GROUPING: true BATCH_SIZE_TOTAL: 16 BATCH_SIZE_PER_GPU: 2 SHUFFLE: true TEST: DETECTIONS_PER_IMAGE: 100 NAME: coco_eval IOU_TYPE: ['bbox', 'segm'] USE_MULTISCALE: false BATCH_SIZE_TOTAL: 8 MODEL_FILE: '' AUG: ENABLED: False DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 16 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True REF: INPUT: PIXEL_MEAN: [123.675, 116.280, 103.530] PIXEL_STD: [58.395, 57.120, 57.375] MIN_SIZE_TEST: 512 MAX_SIZE_TEST: 1024 FORMAT: "RGB" 
DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 16 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: False TEST: BATCH_SIZE_TOTAL: 8 SUN: INPUT: PIXEL_MEAN: [123.675, 116.280, 103.530] PIXEL_STD: [58.395, 57.120, 57.375] MIN_SIZE_TEST: 512 MAX_SIZE_TEST: 1024 DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 16 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: False TEST: BATCH_SIZE_TOTAL: 8 SCAN: INPUT: PIXEL_MEAN: [123.675, 116.280, 103.530] PIXEL_STD: [58.395, 57.120, 57.375] MIN_SIZE_TEST: 512 MAX_SIZE_TEST: 1024 DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 16 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: False TEST: BATCH_SIZE_TOTAL: 8 BDD: INPUT: PIXEL_MEAN: [123.675, 116.280, 103.530] PIXEL_STD: [58.395, 57.120, 57.375] MIN_SIZE_TEST: 800 MAX_SIZE_TEST: 1333 DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 16 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: False TEST: BATCH_SIZE_TOTAL: 8 CITY: INPUT: MIN_SIZE_TRAIN: [ 512, 614, 716, 819, 921, 1024, 1126, 1228, 1331, 1433, 1536, 1638, 1740, 1843, 1945, 2048 ] MIN_SIZE_TRAIN_SAMPLING: "choice" MIN_SIZE_TEST: 1024 MAX_SIZE_TRAIN: 4096 MAX_SIZE_TEST: 2048 CROP: ENABLED: True TYPE: "absolute" SIZE: [ 512, 1024 ] SINGLE_CATEGORY_MAX_AREA: 1.0 IGNORE_VALUE: 255 COLOR_AUG_SSD: True SIZE_DIVISIBILITY: -1 FORMAT: "RGB" DATASET_MAPPER_NAME: "mask_former_panoptic" MASK_FORMAT: "polygon" TEST: EVAL_PERIOD: 5000 BATCH_SIZE_TOTAL: 1 AUG: ENABLED: False MIN_SIZES: [ 512, 768, 1024, 1280, 1536, 1792 ] MAX_SIZE: 4096 FLIP: True DATALOADER: FILTER_EMPTY_ANNOTATIONS: True NUM_WORKERS: 16 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True TRAIN: ASPECT_RATIO_GROUPING: true BATCH_SIZE_TOTAL: 2 BATCH_SIZE_PER_GPU: 2 SHUFFLE: true PSACAL_PART: INPUT: MIN_SIZE_TEST: 800 MAX_SIZE_TEST: 1333 IMAGE_SIZE: 1024 MIN_SCALE: 0.1 MAX_SCALE: 2.0 
DATASET_MAPPER_NAME: "pascal_part_lsj" IGNORE_VALUE: 255 COLOR_AUG_SSD: False SIZE_DIVISIBILITY: 32 RANDOM_FLIP: "horizontal" MASK_FORMAT: "polygon" FORMAT: "RGB" CROP: ENABLED: True MODEL: MASK_ON: True KEYPOINT_ON: False LOAD_PROPOSALS: False # DATASET: # DATASET: 'coco' TEST: DETECTIONS_PER_IMAGE: 100 NAME: coco_eval IOU_TYPE: ['bbox', 'segm'] USE_MULTISCALE: false BATCH_SIZE_TOTAL: 8 MODEL_FILE: '' AUG: ENABLED: False TRAIN: BATCH_SIZE_TOTAL: 1 BATCH_SIZE_PER_GPU: 1 SHUFFLE: true DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 16 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True llava: INPUT: MIN_SIZE_TEST: 800 MAX_SIZE_TEST: 1333 IMAGE_SIZE: 1024 MIN_SCALE: 1.0 MAX_SCALE: 1.0 DATASET_MAPPER_NAME: "llava" IGNORE_VALUE: 255 COLOR_AUG_SSD: False SIZE_DIVISIBILITY: 32 RANDOM_FLIP: "horizontal" MASK_FORMAT: "polygon" FORMAT: "RGB" CROP: ENABLED: True TEST: DETECTIONS_PER_IMAGE: 100 NAME: coco_eval IOU_TYPE: ['bbox', 'segm'] USE_MULTISCALE: false BATCH_SIZE_TOTAL: 1 MODEL_FILE: '' AUG: ENABLED: False TRAIN: BATCH_SIZE_TOTAL: 1 BATCH_SIZE_PER_GPU: 1 SHUFFLE: true DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 16 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True flickr: INPUT: MIN_SIZE_TEST: 800 MAX_SIZE_TEST: 1333 IMAGE_SIZE: 1024 MIN_SCALE: 1.0 MAX_SCALE: 1.0 DATASET_MAPPER_NAME: "flickr" IGNORE_VALUE: 255 COLOR_AUG_SSD: False SIZE_DIVISIBILITY: 32 RANDOM_FLIP: "horizontal" MASK_FORMAT: "polygon" FORMAT: "RGB" CROP: ENABLED: True TEST: DETECTIONS_PER_IMAGE: 100 NAME: coco_eval IOU_TYPE: ['bbox', 'segm'] USE_MULTISCALE: false BATCH_SIZE_TOTAL: 1 MODEL_FILE: '' AUG: ENABLED: False TRAIN: BATCH_SIZE_TOTAL: 1 BATCH_SIZE_PER_GPU: 1 SHUFFLE: true DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 16 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True vg: INPUT: MIN_SIZE_TEST: 800 MAX_SIZE_TEST: 1333 IMAGE_SIZE: 1024 MIN_SCALE: 1.0 MAX_SCALE: 1.0 
DATASET_MAPPER_NAME: "vg" IGNORE_VALUE: 255 COLOR_AUG_SSD: False SIZE_DIVISIBILITY: 32 RANDOM_FLIP: "horizontal" MASK_FORMAT: "polygon" FORMAT: "RGB" CROP: ENABLED: True TEST: DETECTIONS_PER_IMAGE: 100 NAME: coco_eval IOU_TYPE: ['bbox', 'segm'] USE_MULTISCALE: false BATCH_SIZE_TOTAL: 1 MODEL_FILE: '' AUG: ENABLED: False TRAIN: BATCH_SIZE_TOTAL: 1 BATCH_SIZE_PER_GPU: 1 SHUFFLE: true DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 16 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True ================================================ FILE: configs/openseed/openseed_swint_lang_joint_2st.yaml ================================================ # -------------------------------------------------------- # X-Decoder -- Generalized Decoding for Pixel, Image, and Language # Copyright (c) 2022 Microsoft # Licensed under The MIT License [see LICENSE for details] # Written by Xueyan Zou (xueyan@cs.wisc.edu) # -------------------------------------------------------- ################## # Task settings ################## WEIGHT: '' PORT: 53711 detach_seg: False VERBOSE: true #OUTPUT_DIR: '../../data/output/test' inference_only: true OUTPUT_DIR: '../../data/output/test' clip: true # misc LOADER: JOINT: True KEY_DATASET: 'flickr' # model MODEL: NAME: openseed_model HEAD: openseed_head MASK_ON: false KEYPOINT_ON: false LOAD_PROPOSALS: false DIM_PROJ: 4096 BACKBONE_DIM: 768 BACKGROUND: False WEIGHTS: '' TEXT: ARCH: encoder NAME: transformer TOKENIZER: clip CONTEXT_LENGTH: 18 # 18 WIDTH: 512 HEADS: 8 LAYERS: 12 AUTOGRESSIVE: True BACKBONE: NAME: swin PRETRAINED: 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth' LOAD_PRETRAINED: true SWIN: PRETRAIN_IMG_SIZE: 224 PATCH_SIZE: 4 EMBED_DIM: 96 DEPTHS: [ 2, 2, 6, 2 ] NUM_HEADS: [ 3, 6, 12, 24 ] WINDOW_SIZE: 7 MLP_RATIO: 4.0 QKV_BIAS: true QK_SCALE: ~ DROP_RATE: 0.0 ATTN_DROP_RATE: 0.0 DROP_PATH_RATE: 0.3 APE: false PATCH_NORM: true USE_CHECKPOINT: false 
OUT_FEATURES: [ 'res2', 'res3', 'res4', 'res5' ] ENCODER: NAME: encoder_deform IGNORE_VALUE: 255 NUM_CLASSES: 133 LOSS_WEIGHT: 1.0 CONVS_DIM: 256 MASK_DIM: 256 NORM: "GN" IN_FEATURES: [ "res2", "res3", "res4", "res5" ] DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES: [ "res3", "res4", "res5" ] COMMON_STRIDE: 4 TRANSFORMER_ENC_LAYERS: 6 TOTAL_NUM_FEATURE_LEVELS: 4 NUM_FEATURE_LEVELS: 3 FEATURE_ORDER: "low2high" DECODER: NAME: openseed_decoder TRANSFORMER_IN_FEATURE: "multi_scale_pixel_decoder" MASK: True BOX: True COCO_ONLY: True GROUNDING: ENABLED: False MAX_LEN: 5 TEXT_WEIGHT: 2.0 CLASS_WEIGHT: 0.5 CAPTION: ENABLED: False PHRASE_PROB: 0.0 SIM_THRES: 0.95 CAPTIONING: ENABLED: False STEP: 50 RETRIEVAL: ENABLED: False DIM_IMG: 768 ENSEMBLE: True OPENIMAGE: ENABLED: False NEGATIVE_SAMPLES: 5 GROUNDING: ENABLED: False MAX_LEN: 5 DEEP_SUPERVISION: True NO_OBJECT_WEIGHT: 0.1 CLASS_WEIGHT: 4.0 MASK_WEIGHT: 5.0 DICE_WEIGHT: 5.0 BOX_WEIGHT: 5.0 GIOU_WEIGHT: 2.0 LLM_WEIGHT: 1.0 WEIGHT_MULTIPLIER: 1.0 COST_CLASS_WEIGHT: 4.0 COST_DICE_WEIGHT: 5.0 COST_MASK_WEIGHT: 5.0 COST_BOX_WEIGHT: 5.0 COST_GIOU_WEIGHT: 2.0 HIDDEN_DIM: 256 NUM_OBJECT_QUERIES: 300 NHEADS: 8 DROPOUT: 0.0 DIM_FEEDFORWARD: 2048 ENC_LAYERS: 0 PRE_NORM: False ENFORCE_INPUT_PROJ: False SIZE_DIVISIBILITY: 32 DEC_LAYERS: 9 # 9 decoder layers, add one for the loss on learnable query TRAIN_NUM_POINTS: 12544 OVERSAMPLE_RATIO: 3.0 IMPORTANCE_SAMPLE_RATIO: 0.75 TWO_STAGE: True INITIALIZE_BOX_TYPE: 'no' DN: seg DN_NOISE_SCALE: 0.4 DN_NUM: 100 INITIAL_PRED: True LEARN_TGT: False TOTAL_NUM_FEATURE_LEVELS: 4 SEMANTIC_CE_LOSS: False PANO_BOX_LOSS: False COCO: True O365: False TEST: SEMANTIC_ON: True INSTANCE_ON: True PANOPTIC_ON: True OVERLAP_THRESHOLD: 0.8 OBJECT_MASK_THRESHOLD: 0.25 SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE: false TEST_FOUCUS_ON_BOX: False PANO_TRANSFORM_EVAL: True PANO_TEMPERATURE: 0.06 TEST: EVAL_PERIOD: 500000 PRECISE_BN: NUM_ITER: 1 ENABLED: False AUG: ENABLED: False SAM: INPUT: MIN_SIZE_TEST: 800 
MAX_SIZE_TEST: 1333 IMAGE_SIZE: 1024 MIN_SCALE: 0.99 MAX_SCALE: 1.01 DATASET_MAPPER_NAME: "sam" IGNORE_VALUE: 255 COLOR_AUG_SSD: False SIZE_DIVISIBILITY: 32 RANDOM_FLIP: "horizontal" MASK_FORMAT: "polygon" FORMAT: "RGB" CROP: ENABLED: True DATASET: DATASET: 'sam' TEST: DETECTIONS_PER_IMAGE: 100 NAME: coco_eval IOU_TYPE: ['bbox', 'segm'] USE_MULTISCALE: false BATCH_SIZE_TOTAL: 8 MODEL_FILE: '' AUG: ENABLED: False TRAIN: BATCH_SIZE_TOTAL: 1 BATCH_SIZE_PER_GPU: 1 SHUFFLE: true DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 4 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True COCO: INPUT: MIN_SIZE_TEST: 800 MAX_SIZE_TEST: 1333 IMAGE_SIZE: 1024 MIN_SCALE: 1.0 MAX_SCALE: 1.0 DATASET_MAPPER_NAME: "coco_ref_panoptic_lsj" IGNORE_VALUE: 255 COLOR_AUG_SSD: False SIZE_DIVISIBILITY: 32 RANDOM_FLIP: "horizontal" MASK_FORMAT: "polygon" FORMAT: "RGB" CROP: ENABLED: True DATASET: DATASET: 'coco' TEST: DETECTIONS_PER_IMAGE: 100 NAME: coco_eval IOU_TYPE: ['bbox', 'segm'] USE_MULTISCALE: false BATCH_SIZE_TOTAL: 1 MODEL_FILE: '' AUG: ENABLED: False TRAIN: BATCH_SIZE_TOTAL: 1 BATCH_SIZE_PER_GPU: 1 SHUFFLE: true DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 4 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True VLP: INPUT: IMAGE_SIZE: 224 DATASET_MAPPER_NAME: "vlpretrain" IGNORE_VALUE: 255 COLOR_AUG_SSD: False SIZE_DIVISIBILITY: 32 MASK_FORMAT: "polygon" FORMAT: "RGB" CROP: ENABLED: True TRAIN: BATCH_SIZE_TOTAL: 2 BATCH_SIZE_PER_GPU: 2 TEST: BATCH_SIZE_TOTAL: 256 DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 4 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True INPUT: PIXEL_MEAN: [123.675, 116.280, 103.530] PIXEL_STD: [58.395, 57.120, 57.375] DATASETS: TRAIN: ["coco_instruct_train_v3","flickr_train"] DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 4 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True # Detectron2 training 
config for optimizer and lr scheduler SOLVER: BASE_LR_END: 0.0 MOMENTUM: 0.9 NESTEROV: False CHECKPOINT_PERIOD: 5000 IMS_PER_BATCH: 1 REFERENCE_WORLD_SIZE: 0 BIAS_LR_FACTOR: 1.0 WEIGHT_DECAY_BIAS: None # original BASE_LR: 0.0001 STEPS: [327778, 355092] MAX_ITER: 368750 GAMMA: 0.1 WARMUP_FACTOR: 1.0 WARMUP_ITERS: 10 WARMUP_METHOD: "linear" WEIGHT_DECAY: 0.05 OPTIMIZER: "ADAMW" LR_SCHEDULER_NAME: "WarmupMultiStepLR" LR_MULTIPLIER: backbone: 0.1 lang_encoder: 0.1 WEIGHT_DECAY_NORM: 0.0 WEIGHT_DECAY_EMBED: 0.0 CLIP_GRADIENTS: ENABLED: True CLIP_TYPE: "full_model" CLIP_VALUE: 0.01 NORM_TYPE: 2.0 AMP: ENABLED: True # Evaluation Dataset ADE20K: INPUT: MIN_SIZE_TRAIN: [320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 1024, 1088, 1152, 1216, 1280] MIN_SIZE_TRAIN_SAMPLING: "choice" MIN_SIZE_TEST: 640 MAX_SIZE_TRAIN: 2560 MAX_SIZE_TEST: 2560 MASK_FORMAT: "polygon" CROP: ENABLED: True TYPE: "absolute" SIZE: [640, 640] SINGLE_CATEGORY_MAX_AREA: 1.0 IGNORE_VALUE: 255 COLOR_AUG_SSD: True SIZE_DIVISIBILITY: 640 # used in dataset mapper DATASET_MAPPER_NAME: "mask_former_panoptic" FORMAT: "RGB" DATASET: DATASET: 'ade' TRAIN: ASPECT_RATIO_GROUPING: true BATCH_SIZE_TOTAL: 16 BATCH_SIZE_PER_GPU: 2 SHUFFLE: true TEST: DETECTIONS_PER_IMAGE: 100 NAME: coco_eval IOU_TYPE: ['bbox', 'segm'] USE_MULTISCALE: false BATCH_SIZE_TOTAL: 8 MODEL_FILE: '' AUG: ENABLED: False DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 4 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True REF: INPUT: PIXEL_MEAN: [123.675, 116.280, 103.530] PIXEL_STD: [58.395, 57.120, 57.375] MIN_SIZE_TEST: 512 MAX_SIZE_TEST: 1024 FORMAT: "RGB" DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 4 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: False TEST: BATCH_SIZE_TOTAL: 8 SUN: INPUT: PIXEL_MEAN: [123.675, 116.280, 103.530] PIXEL_STD: [58.395, 57.120, 57.375] MIN_SIZE_TEST: 512 MAX_SIZE_TEST: 1024 DATALOADER: FILTER_EMPTY_ANNOTATIONS: False 
NUM_WORKERS: 4 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: False TEST: BATCH_SIZE_TOTAL: 8 SCAN: INPUT: PIXEL_MEAN: [123.675, 116.280, 103.530] PIXEL_STD: [58.395, 57.120, 57.375] MIN_SIZE_TEST: 512 MAX_SIZE_TEST: 1024 DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 4 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: False TEST: BATCH_SIZE_TOTAL: 8 BDD: INPUT: PIXEL_MEAN: [123.675, 116.280, 103.530] PIXEL_STD: [58.395, 57.120, 57.375] MIN_SIZE_TEST: 800 MAX_SIZE_TEST: 1333 DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 4 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: False TEST: BATCH_SIZE_TOTAL: 8 CITY: INPUT: MIN_SIZE_TRAIN: [ 512, 614, 716, 819, 921, 1024, 1126, 1228, 1331, 1433, 1536, 1638, 1740, 1843, 1945, 2048 ] MIN_SIZE_TRAIN_SAMPLING: "choice" MIN_SIZE_TEST: 1024 MAX_SIZE_TRAIN: 4096 MAX_SIZE_TEST: 2048 CROP: ENABLED: True TYPE: "absolute" SIZE: [ 512, 1024 ] SINGLE_CATEGORY_MAX_AREA: 1.0 IGNORE_VALUE: 255 COLOR_AUG_SSD: True SIZE_DIVISIBILITY: -1 FORMAT: "RGB" DATASET_MAPPER_NAME: "mask_former_panoptic" MASK_FORMAT: "polygon" TEST: EVAL_PERIOD: 5000 BATCH_SIZE_TOTAL: 1 AUG: ENABLED: False MIN_SIZES: [ 512, 768, 1024, 1280, 1536, 1792 ] MAX_SIZE: 4096 FLIP: True DATALOADER: FILTER_EMPTY_ANNOTATIONS: True NUM_WORKERS: 4 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True TRAIN: ASPECT_RATIO_GROUPING: true BATCH_SIZE_TOTAL: 2 BATCH_SIZE_PER_GPU: 2 SHUFFLE: true PSACAL_PART: INPUT: MIN_SIZE_TEST: 800 MAX_SIZE_TEST: 1333 IMAGE_SIZE: 1024 MIN_SCALE: 0.1 MAX_SCALE: 2.0 DATASET_MAPPER_NAME: "pascal_part_lsj" IGNORE_VALUE: 255 COLOR_AUG_SSD: False SIZE_DIVISIBILITY: 32 RANDOM_FLIP: "horizontal" MASK_FORMAT: "polygon" FORMAT: "RGB" CROP: ENABLED: True MODEL: MASK_ON: True KEYPOINT_ON: False LOAD_PROPOSALS: False # DATASET: # DATASET: 'coco' TEST: DETECTIONS_PER_IMAGE: 100 NAME: coco_eval IOU_TYPE: ['bbox', 'segm'] 
USE_MULTISCALE: false BATCH_SIZE_TOTAL: 8 MODEL_FILE: '' AUG: ENABLED: False TRAIN: BATCH_SIZE_TOTAL: 1 BATCH_SIZE_PER_GPU: 1 SHUFFLE: true DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 4 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True llava: INPUT: MIN_SIZE_TEST: 800 MAX_SIZE_TEST: 1333 IMAGE_SIZE: 1024 MIN_SCALE: 1.0 MAX_SCALE: 1.0 DATASET_MAPPER_NAME: "llava" IGNORE_VALUE: 255 COLOR_AUG_SSD: False SIZE_DIVISIBILITY: 32 RANDOM_FLIP: "horizontal" MASK_FORMAT: "polygon" FORMAT: "RGB" CROP: ENABLED: True TEST: DETECTIONS_PER_IMAGE: 100 NAME: coco_eval IOU_TYPE: ['bbox', 'segm'] USE_MULTISCALE: false BATCH_SIZE_TOTAL: 1 MODEL_FILE: '' AUG: ENABLED: False TRAIN: BATCH_SIZE_TOTAL: 1 BATCH_SIZE_PER_GPU: 1 SHUFFLE: true DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 4 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True flickr: INPUT: MIN_SIZE_TEST: 800 MAX_SIZE_TEST: 1333 IMAGE_SIZE: 1024 MIN_SCALE: 1.0 MAX_SCALE: 1.0 DATASET_MAPPER_NAME: "flickr" IGNORE_VALUE: 255 COLOR_AUG_SSD: False SIZE_DIVISIBILITY: 32 RANDOM_FLIP: "horizontal" MASK_FORMAT: "polygon" FORMAT: "RGB" CROP: ENABLED: True TEST: DETECTIONS_PER_IMAGE: 100 NAME: coco_eval IOU_TYPE: ['bbox', 'segm'] USE_MULTISCALE: false BATCH_SIZE_TOTAL: 1 MODEL_FILE: '' AUG: ENABLED: False TRAIN: BATCH_SIZE_TOTAL: 1 BATCH_SIZE_PER_GPU: 1 SHUFFLE: true DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 4 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True coco_instruct: INPUT: MIN_SIZE_TEST: 800 MAX_SIZE_TEST: 1333 IMAGE_SIZE: 1024 MIN_SCALE: 1.0 MAX_SCALE: 1.0 DATASET_MAPPER_NAME: "coco_instruct" IGNORE_VALUE: 255 COLOR_AUG_SSD: False SIZE_DIVISIBILITY: 32 RANDOM_FLIP: "horizontal" MASK_FORMAT: "polygon" FORMAT: "RGB" CROP: ENABLED: True TEST: DETECTIONS_PER_IMAGE: 100 NAME: coco_eval IOU_TYPE: ['bbox', 'segm'] USE_MULTISCALE: false BATCH_SIZE_TOTAL: 1 MODEL_FILE: '' AUG: ENABLED: False TRAIN: 
BATCH_SIZE_TOTAL: 1 BATCH_SIZE_PER_GPU: 1 SHUFFLE: true DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 4 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True vg: INPUT: MIN_SIZE_TEST: 800 MAX_SIZE_TEST: 1333 IMAGE_SIZE: 1024 MIN_SCALE: 1.0 MAX_SCALE: 1.0 DATASET_MAPPER_NAME: "vg" IGNORE_VALUE: 255 COLOR_AUG_SSD: False SIZE_DIVISIBILITY: 32 RANDOM_FLIP: "horizontal" MASK_FORMAT: "polygon" FORMAT: "RGB" CROP: ENABLED: True TEST: DETECTIONS_PER_IMAGE: 100 NAME: coco_eval IOU_TYPE: ['bbox', 'segm'] USE_MULTISCALE: false BATCH_SIZE_TOTAL: 1 MODEL_FILE: '' AUG: ENABLED: False TRAIN: BATCH_SIZE_TOTAL: 1 BATCH_SIZE_PER_GPU: 1 SHUFFLE: true DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 4 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True ================================================ FILE: configs/openseed/openseed_swint_lang_joint_2st_visual_prompt.yaml ================================================ # -------------------------------------------------------- # X-Decoder -- Generalized Decoding for Pixel, Image, and Language # Copyright (c) 2022 Microsoft # Licensed under The MIT License [see LICENSE for details] # Written by Xueyan Zou (xueyan@cs.wisc.edu) # -------------------------------------------------------- ################## # Task settings ################## WEIGHT: '' PORT: 53711 detach_seg: False VERBOSE: true #OUTPUT_DIR: '../../data/output/test' inference_only: true OUTPUT_DIR: '../../data/output/test' clip: true # misc LOADER: JOINT: True KEY_DATASET: 'flickr' # model MODEL: NAME: openseed_model HEAD: openseed_head MASK_ON: false KEYPOINT_ON: false LOAD_PROPOSALS: false DIM_PROJ: 4096 BACKBONE_DIM: 768 BACKGROUND: False WEIGHTS: '' TEXT: ARCH: encoder NAME: transformer TOKENIZER: clip CONTEXT_LENGTH: 18 # 18 WIDTH: 512 HEADS: 8 LAYERS: 12 AUTOGRESSIVE: True BACKBONE: NAME: swin PRETRAINED: 
'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth' LOAD_PRETRAINED: true SWIN: PRETRAIN_IMG_SIZE: 224 PATCH_SIZE: 4 EMBED_DIM: 96 DEPTHS: [ 2, 2, 6, 2 ] NUM_HEADS: [ 3, 6, 12, 24 ] WINDOW_SIZE: 7 MLP_RATIO: 4.0 QKV_BIAS: true QK_SCALE: ~ DROP_RATE: 0.0 ATTN_DROP_RATE: 0.0 DROP_PATH_RATE: 0.3 APE: false PATCH_NORM: true USE_CHECKPOINT: false OUT_FEATURES: [ 'res2', 'res3', 'res4', 'res5' ] ENCODER: NAME: encoder_deform IGNORE_VALUE: 255 NUM_CLASSES: 133 LOSS_WEIGHT: 1.0 CONVS_DIM: 256 MASK_DIM: 256 NORM: "GN" IN_FEATURES: [ "res2", "res3", "res4", "res5" ] DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES: [ "res3", "res4", "res5" ] COMMON_STRIDE: 4 TRANSFORMER_ENC_LAYERS: 6 TOTAL_NUM_FEATURE_LEVELS: 4 NUM_FEATURE_LEVELS: 3 FEATURE_ORDER: "low2high" DECODER: NAME: openseed_decoder TRANSFORMER_IN_FEATURE: "multi_scale_pixel_decoder" MASK: True BOX: True COCO_ONLY: True GROUNDING: ENABLED: False MAX_LEN: 5 TEXT_WEIGHT: 2.0 CLASS_WEIGHT: 0.5 CAPTION: ENABLED: False PHRASE_PROB: 0.0 SIM_THRES: 0.95 CAPTIONING: ENABLED: False STEP: 50 RETRIEVAL: ENABLED: False DIM_IMG: 768 ENSEMBLE: True OPENIMAGE: ENABLED: False NEGATIVE_SAMPLES: 5 GROUNDING: ENABLED: False MAX_LEN: 5 DEEP_SUPERVISION: True NO_OBJECT_WEIGHT: 0.1 CLASS_WEIGHT: 4.0 MASK_WEIGHT: 5.0 DICE_WEIGHT: 5.0 BOX_WEIGHT: 5.0 GIOU_WEIGHT: 2.0 LLM_WEIGHT: 1.0 WEIGHT_MULTIPLIER: 1.0 COST_CLASS_WEIGHT: 4.0 COST_DICE_WEIGHT: 5.0 COST_MASK_WEIGHT: 5.0 COST_BOX_WEIGHT: 5.0 COST_GIOU_WEIGHT: 2.0 HIDDEN_DIM: 256 NUM_OBJECT_QUERIES: 300 NHEADS: 8 DROPOUT: 0.0 DIM_FEEDFORWARD: 2048 ENC_LAYERS: 0 PRE_NORM: False ENFORCE_INPUT_PROJ: False SIZE_DIVISIBILITY: 32 DEC_LAYERS: 9 # 9 decoder layers, add one for the loss on learnable query TRAIN_NUM_POINTS: 12544 OVERSAMPLE_RATIO: 3.0 IMPORTANCE_SAMPLE_RATIO: 0.75 TWO_STAGE: True INITIALIZE_BOX_TYPE: 'no' DN: seg DN_NOISE_SCALE: 0.4 DN_NUM: 100 INITIAL_PRED: True LEARN_TGT: False TOTAL_NUM_FEATURE_LEVELS: 4 SEMANTIC_CE_LOSS: False 
PANO_BOX_LOSS: False COCO: True O365: False TEST: SEMANTIC_ON: True INSTANCE_ON: True PANOPTIC_ON: True OVERLAP_THRESHOLD: 0.8 OBJECT_MASK_THRESHOLD: 0.25 SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE: false TEST_FOUCUS_ON_BOX: False PANO_TRANSFORM_EVAL: True PANO_TEMPERATURE: 0.06 TEST: EVAL_PERIOD: 500000 PRECISE_BN: NUM_ITER: 1 ENABLED: False AUG: ENABLED: False SAM: INPUT: MIN_SIZE_TEST: 800 MAX_SIZE_TEST: 1333 IMAGE_SIZE: 1024 MIN_SCALE: 0.99 MAX_SCALE: 1.01 DATASET_MAPPER_NAME: "sam" IGNORE_VALUE: 255 COLOR_AUG_SSD: False SIZE_DIVISIBILITY: 32 RANDOM_FLIP: "horizontal" MASK_FORMAT: "polygon" FORMAT: "RGB" CROP: ENABLED: True DATASET: DATASET: 'sam' TEST: DETECTIONS_PER_IMAGE: 100 NAME: coco_eval IOU_TYPE: ['bbox', 'segm'] USE_MULTISCALE: false BATCH_SIZE_TOTAL: 8 MODEL_FILE: '' AUG: ENABLED: False TRAIN: BATCH_SIZE_TOTAL: 1 BATCH_SIZE_PER_GPU: 1 SHUFFLE: true DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 4 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True COCO: INPUT: MIN_SIZE_TEST: 800 MAX_SIZE_TEST: 1333 IMAGE_SIZE: 1024 MIN_SCALE: 1.0 MAX_SCALE: 1.0 DATASET_MAPPER_NAME: "coco_ref_panoptic_lsj" IGNORE_VALUE: 255 COLOR_AUG_SSD: False SIZE_DIVISIBILITY: 32 RANDOM_FLIP: "horizontal" MASK_FORMAT: "polygon" FORMAT: "RGB" CROP: ENABLED: True DATASET: DATASET: 'coco' TEST: DETECTIONS_PER_IMAGE: 100 NAME: coco_eval IOU_TYPE: ['bbox', 'segm'] USE_MULTISCALE: false BATCH_SIZE_TOTAL: 1 MODEL_FILE: '' AUG: ENABLED: False TRAIN: BATCH_SIZE_TOTAL: 1 BATCH_SIZE_PER_GPU: 1 SHUFFLE: true DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 4 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True VLP: INPUT: IMAGE_SIZE: 224 DATASET_MAPPER_NAME: "vlpretrain" IGNORE_VALUE: 255 COLOR_AUG_SSD: False SIZE_DIVISIBILITY: 32 MASK_FORMAT: "polygon" FORMAT: "RGB" CROP: ENABLED: True TRAIN: BATCH_SIZE_TOTAL: 2 BATCH_SIZE_PER_GPU: 2 TEST: BATCH_SIZE_TOTAL: 256 DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 4 
LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True INPUT: PIXEL_MEAN: [123.675, 116.280, 103.530] PIXEL_STD: [58.395, 57.120, 57.375] DATASETS: TRAIN: ["coco_interactive_refcoco","coco_interactive","flickr_train"] DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 4 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True # Detectron2 training config for optimizer and lr scheduler SOLVER: BASE_LR_END: 0.0 MOMENTUM: 0.9 NESTEROV: False CHECKPOINT_PERIOD: 5000 IMS_PER_BATCH: 1 REFERENCE_WORLD_SIZE: 0 BIAS_LR_FACTOR: 1.0 WEIGHT_DECAY_BIAS: None # original BASE_LR: 0.0001 STEPS: [327778, 355092] MAX_ITER: 368750 GAMMA: 0.1 WARMUP_FACTOR: 1.0 WARMUP_ITERS: 10 WARMUP_METHOD: "linear" WEIGHT_DECAY: 0.05 OPTIMIZER: "ADAMW" LR_SCHEDULER_NAME: "WarmupMultiStepLR" LR_MULTIPLIER: backbone: 0.1 lang_encoder: 0.1 WEIGHT_DECAY_NORM: 0.0 WEIGHT_DECAY_EMBED: 0.0 CLIP_GRADIENTS: ENABLED: True CLIP_TYPE: "full_model" CLIP_VALUE: 0.01 NORM_TYPE: 2.0 AMP: ENABLED: True # Evaluation Dataset ADE20K: INPUT: MIN_SIZE_TRAIN: [320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 1024, 1088, 1152, 1216, 1280] MIN_SIZE_TRAIN_SAMPLING: "choice" MIN_SIZE_TEST: 640 MAX_SIZE_TRAIN: 2560 MAX_SIZE_TEST: 2560 MASK_FORMAT: "polygon" CROP: ENABLED: True TYPE: "absolute" SIZE: [640, 640] SINGLE_CATEGORY_MAX_AREA: 1.0 IGNORE_VALUE: 255 COLOR_AUG_SSD: True SIZE_DIVISIBILITY: 640 # used in dataset mapper DATASET_MAPPER_NAME: "mask_former_panoptic" FORMAT: "RGB" DATASET: DATASET: 'ade' TRAIN: ASPECT_RATIO_GROUPING: true BATCH_SIZE_TOTAL: 16 BATCH_SIZE_PER_GPU: 2 SHUFFLE: true TEST: DETECTIONS_PER_IMAGE: 100 NAME: coco_eval IOU_TYPE: ['bbox', 'segm'] USE_MULTISCALE: false BATCH_SIZE_TOTAL: 8 MODEL_FILE: '' AUG: ENABLED: False DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 4 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True REF: INPUT: PIXEL_MEAN: [123.675, 116.280, 103.530] PIXEL_STD: [58.395, 57.120, 
57.375] MIN_SIZE_TEST: 512 MAX_SIZE_TEST: 1024 FORMAT: "RGB" DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 4 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: False TEST: BATCH_SIZE_TOTAL: 8 SUN: INPUT: PIXEL_MEAN: [123.675, 116.280, 103.530] PIXEL_STD: [58.395, 57.120, 57.375] MIN_SIZE_TEST: 512 MAX_SIZE_TEST: 1024 DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 4 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: False TEST: BATCH_SIZE_TOTAL: 8 SCAN: INPUT: PIXEL_MEAN: [123.675, 116.280, 103.530] PIXEL_STD: [58.395, 57.120, 57.375] MIN_SIZE_TEST: 512 MAX_SIZE_TEST: 1024 DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 4 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: False TEST: BATCH_SIZE_TOTAL: 8 BDD: INPUT: PIXEL_MEAN: [123.675, 116.280, 103.530] PIXEL_STD: [58.395, 57.120, 57.375] MIN_SIZE_TEST: 800 MAX_SIZE_TEST: 1333 DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 4 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: False TEST: BATCH_SIZE_TOTAL: 8 CITY: INPUT: MIN_SIZE_TRAIN: [ 512, 614, 716, 819, 921, 1024, 1126, 1228, 1331, 1433, 1536, 1638, 1740, 1843, 1945, 2048 ] MIN_SIZE_TRAIN_SAMPLING: "choice" MIN_SIZE_TEST: 1024 MAX_SIZE_TRAIN: 4096 MAX_SIZE_TEST: 2048 CROP: ENABLED: True TYPE: "absolute" SIZE: [ 512, 1024 ] SINGLE_CATEGORY_MAX_AREA: 1.0 IGNORE_VALUE: 255 COLOR_AUG_SSD: True SIZE_DIVISIBILITY: -1 FORMAT: "RGB" DATASET_MAPPER_NAME: "mask_former_panoptic" MASK_FORMAT: "polygon" TEST: EVAL_PERIOD: 5000 BATCH_SIZE_TOTAL: 1 AUG: ENABLED: False MIN_SIZES: [ 512, 768, 1024, 1280, 1536, 1792 ] MAX_SIZE: 4096 FLIP: True DATALOADER: FILTER_EMPTY_ANNOTATIONS: True NUM_WORKERS: 4 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True TRAIN: ASPECT_RATIO_GROUPING: true BATCH_SIZE_TOTAL: 2 BATCH_SIZE_PER_GPU: 2 SHUFFLE: true PSACAL_PART: INPUT: MIN_SIZE_TEST: 800 MAX_SIZE_TEST: 1333 
IMAGE_SIZE: 1024 MIN_SCALE: 0.1 MAX_SCALE: 2.0 DATASET_MAPPER_NAME: "pascal_part_lsj" IGNORE_VALUE: 255 COLOR_AUG_SSD: False SIZE_DIVISIBILITY: 32 RANDOM_FLIP: "horizontal" MASK_FORMAT: "polygon" FORMAT: "RGB" CROP: ENABLED: True MODEL: MASK_ON: True KEYPOINT_ON: False LOAD_PROPOSALS: False # DATASET: # DATASET: 'coco' TEST: DETECTIONS_PER_IMAGE: 100 NAME: coco_eval IOU_TYPE: ['bbox', 'segm'] USE_MULTISCALE: false BATCH_SIZE_TOTAL: 8 MODEL_FILE: '' AUG: ENABLED: False TRAIN: BATCH_SIZE_TOTAL: 1 BATCH_SIZE_PER_GPU: 1 SHUFFLE: true DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 4 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True llava: INPUT: MIN_SIZE_TEST: 800 MAX_SIZE_TEST: 1333 IMAGE_SIZE: 1024 MIN_SCALE: 1.0 MAX_SCALE: 1.0 DATASET_MAPPER_NAME: "llava" IGNORE_VALUE: 255 COLOR_AUG_SSD: False SIZE_DIVISIBILITY: 32 RANDOM_FLIP: "horizontal" MASK_FORMAT: "polygon" FORMAT: "RGB" CROP: ENABLED: True TEST: DETECTIONS_PER_IMAGE: 100 NAME: coco_eval IOU_TYPE: ['bbox', 'segm'] USE_MULTISCALE: false BATCH_SIZE_TOTAL: 1 MODEL_FILE: '' AUG: ENABLED: False TRAIN: BATCH_SIZE_TOTAL: 1 BATCH_SIZE_PER_GPU: 1 SHUFFLE: true DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 4 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True flickr: INPUT: MIN_SIZE_TEST: 800 MAX_SIZE_TEST: 1333 IMAGE_SIZE: 1024 MIN_SCALE: 1.0 MAX_SCALE: 1.0 DATASET_MAPPER_NAME: "flickr" IGNORE_VALUE: 255 COLOR_AUG_SSD: False SIZE_DIVISIBILITY: 32 RANDOM_FLIP: "horizontal" MASK_FORMAT: "polygon" FORMAT: "RGB" CROP: ENABLED: True TEST: DETECTIONS_PER_IMAGE: 100 NAME: coco_eval IOU_TYPE: ['bbox', 'segm'] USE_MULTISCALE: false BATCH_SIZE_TOTAL: 1 MODEL_FILE: '' AUG: ENABLED: False TRAIN: BATCH_SIZE_TOTAL: 1 BATCH_SIZE_PER_GPU: 1 SHUFFLE: true DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 4 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True coco_instruct: INPUT: MIN_SIZE_TEST: 800 MAX_SIZE_TEST: 
1333 IMAGE_SIZE: 1024 MIN_SCALE: 1.0 MAX_SCALE: 1.0 DATASET_MAPPER_NAME: "coco_instruct" IGNORE_VALUE: 255 COLOR_AUG_SSD: False SIZE_DIVISIBILITY: 32 RANDOM_FLIP: "horizontal" MASK_FORMAT: "polygon" FORMAT: "RGB" CROP: ENABLED: True TEST: DETECTIONS_PER_IMAGE: 100 NAME: coco_eval IOU_TYPE: ['bbox', 'segm'] USE_MULTISCALE: false BATCH_SIZE_TOTAL: 1 MODEL_FILE: '' AUG: ENABLED: False TRAIN: BATCH_SIZE_TOTAL: 1 BATCH_SIZE_PER_GPU: 1 SHUFFLE: true DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 4 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True coco_interactive: INPUT: MIN_SIZE_TEST: 800 MAX_SIZE_TEST: 1333 IMAGE_SIZE: 1024 MIN_SCALE: 1.0 MAX_SCALE: 1.0 DATASET_MAPPER_NAME: "coco_interactive" IGNORE_VALUE: 255 COLOR_AUG_SSD: False SIZE_DIVISIBILITY: 32 RANDOM_FLIP: "horizontal" MASK_FORMAT: "polygon" FORMAT: "RGB" CROP: ENABLED: True TEST: DETECTIONS_PER_IMAGE: 100 NAME: coco_eval IOU_TYPE: ['bbox', 'segm'] USE_MULTISCALE: false BATCH_SIZE_TOTAL: 1 MODEL_FILE: '' AUG: ENABLED: False TRAIN: BATCH_SIZE_TOTAL: 1 BATCH_SIZE_PER_GPU: 1 SHUFFLE: true DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 4 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True vg: INPUT: MIN_SIZE_TEST: 800 MAX_SIZE_TEST: 1333 IMAGE_SIZE: 1024 MIN_SCALE: 1.0 MAX_SCALE: 1.0 DATASET_MAPPER_NAME: "vg" IGNORE_VALUE: 255 COLOR_AUG_SSD: False SIZE_DIVISIBILITY: 32 RANDOM_FLIP: "horizontal" MASK_FORMAT: "polygon" FORMAT: "RGB" CROP: ENABLED: True TEST: DETECTIONS_PER_IMAGE: 100 NAME: coco_eval IOU_TYPE: ['bbox', 'segm'] USE_MULTISCALE: false BATCH_SIZE_TOTAL: 1 MODEL_FILE: '' AUG: ENABLED: False TRAIN: BATCH_SIZE_TOTAL: 1 BATCH_SIZE_PER_GPU: 1 SHUFFLE: true DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 4 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True ================================================ FILE: configs/semsam/visual_prompt_encoder.yaml 
================================================ # -------------------------------------------------------- # X-Decoder -- Generalized Decoding for Pixel, Image, and Language # Copyright (c) 2022 Microsoft # Licensed under The MIT License [see LICENSE for details] # Written by Xueyan Zou (xueyan@cs.wisc.edu) # -------------------------------------------------------- ################## # Task settings ################## WEIGHT: '' PORT: 53711 VERBOSE: true #OUTPUT_DIR: '../../data/output/test' inference_only: true OUTPUT_DIR: '../../data/output/test' # misc LOADER: JOINT: True KEY_DATASET: 'coco' # model MODEL: NAME: idino_model_partwhole_all_llm_ref_feats_all_det_pretrainv1 HEAD: openseed_head MASK_ON: false KEYPOINT_ON: false LOAD_PROPOSALS: false DIM_PROJ: 512 BACKBONE_DIM: 768 BACKGROUND: False WEIGHTS: None LLAMA: model_name_or_path: '/comp_robot/liushilong/data/LLAVA/LLAVA_7b' cache_dir: None model_max_length: 2048 hidden_size: 4096 tune_mm_mlp_adapter: True im_width: 16 load_fp16: False lora_r: 0 lora_alpha: 16 lora_dropout: 0.05 TEXT: ARCH: llama_encoder BACKBONE: NAME: swin PRETRAINED: 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth' LOAD_PRETRAINED: true SWIN: PRETRAIN_IMG_SIZE: 224 PATCH_SIZE: 4 EMBED_DIM: 96 DEPTHS: [ 2, 2, 6, 2 ] NUM_HEADS: [ 3, 6, 12, 24 ] WINDOW_SIZE: 7 MLP_RATIO: 4.0 QKV_BIAS: true QK_SCALE: ~ DROP_RATE: 0.0 ATTN_DROP_RATE: 0.0 DROP_PATH_RATE: 0.3 APE: false PATCH_NORM: true USE_CHECKPOINT: false OUT_FEATURES: [ 'res2', 'res3', 'res4', 'res5' ] ENCODER: NAME: encoder_deform IGNORE_VALUE: 255 NUM_CLASSES: 1 LOSS_WEIGHT: 1.0 CONVS_DIM: 256 MASK_DIM: 256 NORM: "GN" IN_FEATURES: [ "res2", "res3", "res4", "res5" ] DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES: [ "res3", "res4", "res5" ] COMMON_STRIDE: 4 TRANSFORMER_ENC_LAYERS: 6 TOTAL_NUM_FEATURE_LEVELS: 4 NUM_FEATURE_LEVELS: 3 FEATURE_ORDER: "low2high" DECODER: NAME: idino_decoder_no_iou_token_partwhole_all_llm TRANSFORMER_IN_FEATURE: 
"multi_scale_pixel_decoder" MASK: True BOX: True PART: True pretrain: True match_loss: True GROUNDING: ENABLED: True MAX_LEN: 5 TEXT_WEIGHT: 2.0 CLASS_WEIGHT: 0.5 CAPTION: ENABLED: True PHRASE_PROB: 0.0 SIM_THRES: 0.95 CAPTIONING: ENABLED: True STEP: 50 RETRIEVAL: ENABLED: True DIM_IMG: 768 ENSEMBLE: True OPENIMAGE: ENABLED: False NEGATIVE_SAMPLES: 5 GROUNDING: ENABLED: False MAX_LEN: 5 DEEP_SUPERVISION: True NO_OBJECT_WEIGHT: 0.1 CLASS_WEIGHT: 4.0 MASK_WEIGHT: 5.0 DICE_WEIGHT: 5.0 BOX_WEIGHT: 5.0 GIOU_WEIGHT: 2.0 IOU_WEIGHT: 1.0 LLAMA_WEIGHT: 5.0 llama_det_weight: 2.0 llama_ref_weight: 1.0 llama_region_cap_weight: 1.0 llama_img_cap_weight: 1.0 llama_gd_weight: 20.0 llama_gd_text_weight: 2.0 REFER_WEIGHT: 5.0 COST_CLASS_WEIGHT: 4.0 COST_DICE_WEIGHT: 5.0 COST_MASK_WEIGHT: 5.0 COST_BOX_WEIGHT: 5.0 COST_GIOU_WEIGHT: 2.0 HIDDEN_DIM: 256 NUM_OBJECT_QUERIES: 0 NHEADS: 8 DROPOUT: 0.0 DIM_FEEDFORWARD: 2048 ENC_LAYERS: 0 PRE_NORM: False ENFORCE_INPUT_PROJ: False SIZE_DIVISIBILITY: 32 DEC_LAYERS: 9 # 9 decoder layers, add one for the loss on learnable query TRAIN_NUM_POINTS: 12544 OVERSAMPLE_RATIO: 3.0 IMPORTANCE_SAMPLE_RATIO: 0.75 TWO_STAGE: False INITIALIZE_BOX_TYPE: 'no' DN: seg DN_NOISE_SCALE: 0.4 DN_NUM: 100 INITIAL_PRED: False LEARN_TGT: False TOTAL_NUM_FEATURE_LEVELS: 4 SEMANTIC_CE_LOSS: False PANO_BOX_LOSS: False COCO: True O365: False SAM: True PASCAL: True RE_POINT: True NUM_INTERACTIVE_TOKENS: 3 TEST: SEMANTIC_ON: True INSTANCE_ON: True PANOPTIC_ON: True OVERLAP_THRESHOLD: 0.8 OBJECT_MASK_THRESHOLD: 0.25 SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE: false TEST_FOUCUS_ON_BOX: False PANO_TRANSFORM_EVAL: True PANO_TEMPERATURE: 0.06 TEST: EVAL_PERIOD: 500000 PRECISE_BN: NUM_ITER: 1 ENABLED: False AUG: ENABLED: False SAM: INPUT: MIN_SIZE_TEST: 800 MAX_SIZE_TEST: 1333 IMAGE_SIZE: 1024 MIN_SCALE: 0.99 MAX_SCALE: 1.01 DATASET_MAPPER_NAME: "sam" IGNORE_VALUE: 255 COLOR_AUG_SSD: False SIZE_DIVISIBILITY: 32 RANDOM_FLIP: "horizontal" MASK_FORMAT: "polygon" FORMAT: "RGB" CROP: 
ENABLED: True DATASET: DATASET: 'sam' TEST: DETECTIONS_PER_IMAGE: 100 NAME: coco_eval IOU_TYPE: ['bbox', 'segm'] USE_MULTISCALE: false BATCH_SIZE_TOTAL: 8 MODEL_FILE: '' AUG: ENABLED: False TRAIN: BATCH_SIZE_TOTAL: 1 BATCH_SIZE_PER_GPU: 1 SHUFFLE: true DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 4 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True COCO: INPUT: MIN_SIZE_TEST: 800 MAX_SIZE_TEST: 1333 IMAGE_SIZE: 1024 MIN_SCALE: 1.0 MAX_SCALE: 1.0 DATASET_MAPPER_NAME: "coco_interactive_panoptic_lsj" IGNORE_VALUE: 255 COLOR_AUG_SSD: False SIZE_DIVISIBILITY: 32 RANDOM_FLIP: "horizontal" MASK_FORMAT: "polygon" FORMAT: "RGB" CROP: ENABLED: True DATASET: DATASET: 'coco' TEST: DETECTIONS_PER_IMAGE: 100 NAME: coco_eval IOU_TYPE: ['bbox', 'segm'] USE_MULTISCALE: false BATCH_SIZE_TOTAL: 1 MODEL_FILE: '' AUG: ENABLED: False TRAIN: BATCH_SIZE_TOTAL: 1 BATCH_SIZE_PER_GPU: 1 SHUFFLE: true DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 2 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True VLP: INPUT: IMAGE_SIZE: 224 DATASET_MAPPER_NAME: "vlpretrain" IGNORE_VALUE: 255 COLOR_AUG_SSD: False SIZE_DIVISIBILITY: 32 MASK_FORMAT: "polygon" FORMAT: "RGB" CROP: ENABLED: True TRAIN: BATCH_SIZE_TOTAL: 2 BATCH_SIZE_PER_GPU: 2 TEST: BATCH_SIZE_TOTAL: 256 DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 16 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True INPUT: PIXEL_MEAN: [123.675, 116.280, 103.530] PIXEL_STD: [58.395, 57.120, 57.375] DATASETS: TRAIN: ["coco_2017_train_panoptic_filtrefgumdval_with_sem_seg_caption_grounding","mapillary_vistas_panoptic_train","ade20k_panoptic_train","sam_train","pascal_part_train","paco_train","partimagenet_train"]#,"sam_train","pascal_part_train"]#,"paco_train","partimagenet_train"] DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 16 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True # Detectron2
training config for optimizer and lr scheduler SOLVER: BASE_LR_END: 0.0 MOMENTUM: 0.9 NESTEROV: False CHECKPOINT_PERIOD: 5000 IMS_PER_BATCH: 1 REFERENCE_WORLD_SIZE: 0 BIAS_LR_FACTOR: 1.0 WEIGHT_DECAY_BIAS: None # original BASE_LR: 0.0001 STEPS: [327778, 355092] MAX_ITER: 368750 GAMMA: 0.1 WARMUP_FACTOR: 1.0 WARMUP_ITERS: 10 WARMUP_METHOD: "linear" WEIGHT_DECAY: 0.05 OPTIMIZER: "ADAMW" LR_SCHEDULER_NAME: "WarmupMultiStepLR" LR_MULTIPLIER: backbone: 0.1 lang_encoder: 0.1 WEIGHT_DECAY_NORM: 0.0 WEIGHT_DECAY_EMBED: 0.0 CLIP_GRADIENTS: ENABLED: True CLIP_TYPE: "full_model" CLIP_VALUE: 0.01 NORM_TYPE: 2.0 AMP: ENABLED: True # Evaluation Dataset ADE20K: INPUT: MIN_SIZE_TEST: 800 MAX_SIZE_TEST: 1333 IMAGE_SIZE: 1024 MIN_SCALE: 1.0 MAX_SCALE: 1.0 DATASET_MAPPER_NAME: "coco_interactive_panoptic_lsj" IGNORE_VALUE: 255 COLOR_AUG_SSD: False SIZE_DIVISIBILITY: 32 RANDOM_FLIP: "horizontal" MASK_FORMAT: "polygon" FORMAT: "RGB" CROP: ENABLED: True DATASET: DATASET: 'ade' TRAIN: ASPECT_RATIO_GROUPING: true BATCH_SIZE_TOTAL: 1 BATCH_SIZE_PER_GPU: 1 SHUFFLE: true TEST: DETECTIONS_PER_IMAGE: 100 NAME: coco_eval IOU_TYPE: ['bbox', 'segm'] USE_MULTISCALE: false BATCH_SIZE_TOTAL: 8 MODEL_FILE: '' AUG: ENABLED: False DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 8 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 8 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True REF: INPUT: PIXEL_MEAN: [123.675, 116.280, 103.530] PIXEL_STD: [58.395, 57.120, 57.375] MIN_SIZE_TEST: 512 MAX_SIZE_TEST: 1024 FORMAT: "RGB" DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 0 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: False TEST: BATCH_SIZE_TOTAL: 8 SUN: INPUT: PIXEL_MEAN: [123.675, 116.280, 103.530] PIXEL_STD: [58.395, 57.120, 57.375] MIN_SIZE_TEST: 512 MAX_SIZE_TEST: 1024 DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 0 LOAD_PROPOSALS: 
False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: False TEST: BATCH_SIZE_TOTAL: 8 SCAN: INPUT: PIXEL_MEAN: [123.675, 116.280, 103.530] PIXEL_STD: [58.395, 57.120, 57.375] MIN_SIZE_TEST: 512 MAX_SIZE_TEST: 1024 DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 0 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: False TEST: BATCH_SIZE_TOTAL: 8 BDD: INPUT: PIXEL_MEAN: [123.675, 116.280, 103.530] PIXEL_STD: [58.395, 57.120, 57.375] MIN_SIZE_TEST: 800 MAX_SIZE_TEST: 1333 DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 0 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: False TEST: BATCH_SIZE_TOTAL: 8 CITY: INPUT: MIN_SIZE_TRAIN: [ 512, 614, 716, 819, 921, 1024, 1126, 1228, 1331, 1433, 1536, 1638, 1740, 1843, 1945, 2048 ] MIN_SIZE_TRAIN_SAMPLING: "choice" MIN_SIZE_TEST: 1024 MAX_SIZE_TRAIN: 4096 MAX_SIZE_TEST: 2048 CROP: ENABLED: True TYPE: "absolute" SIZE: [ 512, 1024 ] SINGLE_CATEGORY_MAX_AREA: 1.0 IGNORE_VALUE: 255 COLOR_AUG_SSD: True SIZE_DIVISIBILITY: -1 FORMAT: "RGB" DATASET_MAPPER_NAME: "mask_former_panoptic" MASK_FORMAT: "polygon" TEST: EVAL_PERIOD: 5000 BATCH_SIZE_TOTAL: 1 AUG: ENABLED: False MIN_SIZES: [ 512, 768, 1024, 1280, 1536, 1792 ] MAX_SIZE: 4096 FLIP: True DATALOADER: FILTER_EMPTY_ANNOTATIONS: True NUM_WORKERS: 2 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True TRAIN: ASPECT_RATIO_GROUPING: true BATCH_SIZE_TOTAL: 2 BATCH_SIZE_PER_GPU: 2 SHUFFLE: true PSACAL_PART: INPUT: MIN_SIZE_TEST: 800 MAX_SIZE_TEST: 1333 IMAGE_SIZE: 1024 MIN_SCALE: 1.0 MAX_SCALE: 1.0 DATASET_MAPPER_NAME: "pascal_part_lsj" IGNORE_VALUE: 255 COLOR_AUG_SSD: False SIZE_DIVISIBILITY: 32 RANDOM_FLIP: "horizontal" MASK_FORMAT: "polygon" FORMAT: "RGB" CROP: ENABLED: True MODEL: MASK_ON: True KEYPOINT_ON: False LOAD_PROPOSALS: False # DATASET: # DATASET: 'coco' TEST: DETECTIONS_PER_IMAGE: 100 NAME: coco_eval IOU_TYPE: ['bbox', 'segm'] USE_MULTISCALE: false BATCH_SIZE_TOTAL: 8 
MODEL_FILE: '' AUG: ENABLED: False TRAIN: BATCH_SIZE_TOTAL: 1 BATCH_SIZE_PER_GPU: 1 SHUFFLE: true DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 2 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True llava: INPUT: MIN_SIZE_TEST: 800 MAX_SIZE_TEST: 1333 IMAGE_SIZE: 1024 MIN_SCALE: 1.0 MAX_SCALE: 1.0 DATASET_MAPPER_NAME: "llava" IGNORE_VALUE: 255 COLOR_AUG_SSD: False SIZE_DIVISIBILITY: 32 RANDOM_FLIP: "horizontal" MASK_FORMAT: "polygon" FORMAT: "RGB" CROP: ENABLED: True TEST: DETECTIONS_PER_IMAGE: 100 NAME: coco_eval IOU_TYPE: ['bbox', 'segm'] USE_MULTISCALE: false BATCH_SIZE_TOTAL: 1 MODEL_FILE: '' AUG: ENABLED: False TRAIN: BATCH_SIZE_TOTAL: 1 BATCH_SIZE_PER_GPU: 1 SHUFFLE: true DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 2 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True flickr: INPUT: MIN_SIZE_TEST: 800 MAX_SIZE_TEST: 1333 IMAGE_SIZE: 1024 MIN_SCALE: 1.0 MAX_SCALE: 1.0 DATASET_MAPPER_NAME: "flickr" IGNORE_VALUE: 255 COLOR_AUG_SSD: False SIZE_DIVISIBILITY: 32 RANDOM_FLIP: "horizontal" MASK_FORMAT: "polygon" FORMAT: "RGB" CROP: ENABLED: True TEST: DETECTIONS_PER_IMAGE: 100 NAME: coco_eval IOU_TYPE: ['bbox', 'segm'] USE_MULTISCALE: false BATCH_SIZE_TOTAL: 1 MODEL_FILE: '' AUG: ENABLED: False TRAIN: BATCH_SIZE_TOTAL: 1 BATCH_SIZE_PER_GPU: 1 SHUFFLE: true DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 2 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True part: INPUT: MIN_SIZE_TEST: 800 MAX_SIZE_TEST: 1333 IMAGE_SIZE: 1024 MIN_SCALE: 1.0 MAX_SCALE: 1.0 DATASET_MAPPER_NAME: "part" IGNORE_VALUE: 255 COLOR_AUG_SSD: False SIZE_DIVISIBILITY: 32 RANDOM_FLIP: "horizontal" MASK_FORMAT: "polygon" FORMAT: "RGB" CROP: ENABLED: True TEST: DETECTIONS_PER_IMAGE: 100 NAME: coco_eval IOU_TYPE: ['bbox', 'segm'] USE_MULTISCALE: false BATCH_SIZE_TOTAL: 1 MODEL_FILE: '' AUG: ENABLED: False TRAIN: BATCH_SIZE_TOTAL: 1 BATCH_SIZE_PER_GPU: 1 SHUFFLE: true DATALOADER: 
FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 2 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True vg: INPUT: MIN_SIZE_TEST: 800 MAX_SIZE_TEST: 1333 IMAGE_SIZE: 1024 MIN_SCALE: 1.0 MAX_SCALE: 1.0 DATASET_MAPPER_NAME: "vg" IGNORE_VALUE: 255 COLOR_AUG_SSD: False SIZE_DIVISIBILITY: 32 RANDOM_FLIP: "horizontal" MASK_FORMAT: "polygon" FORMAT: "RGB" CROP: ENABLED: True TEST: DETECTIONS_PER_IMAGE: 100 NAME: coco_eval IOU_TYPE: ['bbox', 'segm'] USE_MULTISCALE: false BATCH_SIZE_TOTAL: 1 MODEL_FILE: '' AUG: ENABLED: False TRAIN: BATCH_SIZE_TOTAL: 1 BATCH_SIZE_PER_GPU: 1 SHUFFLE: true DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 2 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True ================================================ FILE: datasets_os/__init__.py ================================================ from . import registration from .build import * ================================================ FILE: datasets_os/build.py ================================================ # Copyright (c) Facebook, Inc. and its affiliates. 
import os import itertools import logging import copy from typing import Any, Callable, Dict, List, Optional, Union import torch import torch.utils.data import torch.utils.data as torchdata import detectron2.utils.comm as comm from detectron2.data.build import ( build_batch_data_loader, load_proposals_into_dataset, trivial_batch_collator, ) from detectron2.data import MetadataCatalog from detectron2.data.catalog import DatasetCatalog from detectron2.data.common import DatasetFromList, MapDataset from detectron2.data.dataset_mapper import DatasetMapper from detectron2.data.samplers import InferenceSampler, TrainingSampler from fvcore.common.config import CfgNode from omegaconf import DictConfig, OmegaConf from .dataset_mappers import ( COCOPanopticInteractiveDatasetMapper, FlickrNewBaselineDatasetMapper, VGNewBaselineDatasetMapper, COCOInstructGroundingDatasetMapper, COCOInterGroundingDatasetMapper, ) from .custom_dataset_dataloader import build_custom_test_loader from llava.model.openseed.utils import configurable from detectron2.utils.comm import get_world_size, is_main_process from typing import Any, Dict, List, Set class JointLoader(torchdata.IterableDataset): def __init__(self, loaders, key_dataset): dataset_names = [] for key, loader in loaders.items(): name = "{}".format(key.split('_')[0]) # name = "{}".format(key) setattr(self, name, loader) dataset_names += [name] self.dataset_names = dataset_names self.key_dataset = key_dataset def __iter__(self): for batch in zip(*[getattr(self, name) for name in self.dataset_names]): yield {key: batch[i] for i, key in enumerate(self.dataset_names)} def __len__(self): return len(getattr(self, self.key_dataset)) def filter_images_with_only_crowd_annotations(dataset_dicts, dataset_names): """ Filter out images with none annotations or only crowd annotations (i.e., images without non-crowd annotations). A common training-time preprocessing on COCO dataset. 
Args: dataset_dicts (list[dict]): annotations in Detectron2 Dataset format. Returns: list[dict]: the same format, but filtered. """ num_before = len(dataset_dicts) def valid(anns): for ann in anns: if isinstance(ann, list): for instance in ann: if instance.get("iscrowd", 0) == 0: return True else: if ann.get("iscrowd", 0) == 0: return True return False dataset_dicts = [x for x in dataset_dicts if valid(x["annotations"])] num_after = len(dataset_dicts) logger = logging.getLogger(__name__) logger.info( "Removed {} images with no usable annotations. {} images left.".format( num_before - num_after, num_after ) ) return dataset_dicts def get_detection_dataset_dicts( dataset_names, filter_empty=True, proposal_files=None ): """ Load and prepare dataset dicts for instance detection/segmentation and semantic segmentation. Args: dataset_names (str or list[str]): a dataset name or a list of dataset names filter_empty (bool): whether to filter out images without instance annotations proposal_files (list[str]): if given, a list of object proposal files that match each dataset in `dataset_names`. Returns: list[dict]: a list of dicts following the standard dataset dict format. 
""" if isinstance(dataset_names, str): dataset_names = [dataset_names] assert len(dataset_names) dataset_dicts = [DatasetCatalog.get(dataset_name) for dataset_name in dataset_names] for dataset_name, dicts in zip(dataset_names, dataset_dicts): assert len(dicts), "Dataset '{}' is empty!".format(dataset_name) if proposal_files is not None: assert len(dataset_names) == len(proposal_files) # load precomputed proposals from proposal files dataset_dicts = [ load_proposals_into_dataset(dataset_i_dicts, proposal_file) for dataset_i_dicts, proposal_file in zip(dataset_dicts, proposal_files) ] dataset_dicts = list(itertools.chain.from_iterable(dataset_dicts)) has_instances = "annotations" in dataset_dicts[0] if filter_empty and has_instances: dataset_dicts = filter_images_with_only_crowd_annotations(dataset_dicts, dataset_names) assert len(dataset_dicts), "No valid data found in {}.".format(",".join(dataset_names)) return dataset_dicts def _test_loader_from_config(cfg, dataset_name, mapper=None): """ Uses the given `dataset_name` argument (instead of the names in cfg), because the standard practice is to evaluate each test set individually (not combining them). 
""" if isinstance(dataset_name, str): dataset_name = [dataset_name] dataset = get_detection_dataset_dicts( dataset_name, filter_empty=False, proposal_files=None, ) # import ipdb;ipdb.set_trace() if mapper is None: if isinstance(cfg, (DictConfig)): cfg = OmegaConf.to_container(copy.deepcopy(cfg)) mapper_cfg = CfgNode({'INPUT': cfg['INPUT'], 'MODEL': cfg['MODEL'], 'DATASETS': cfg['DATASETS']}) mapper = DatasetMapper(mapper_cfg, False) assert cfg['TEST']['BATCH_SIZE_TOTAL'] % get_world_size() == 0, "Evaluation total batchsize is not divisible by gpu number" batch_size = cfg['TEST']['BATCH_SIZE_TOTAL'] // get_world_size() return { "dataset": dataset, "mapper": mapper, "num_workers": cfg['DATALOADER']['NUM_WORKERS'], "sampler": InferenceSampler(len(dataset)), "batch_size": batch_size, } @configurable(from_config=_test_loader_from_config) def build_detection_test_loader( dataset: Union[List[Any], torchdata.Dataset], *, mapper: Callable[[Dict[str, Any]], Any], sampler: Optional[torchdata.Sampler] = None, batch_size: int = 1, num_workers: int = 0, collate_fn: Optional[Callable[[List[Any]], Any]] = None, ) -> torchdata.DataLoader: """ Similar to `build_detection_train_loader`, with default batch size = 1, and sampler = :class:`InferenceSampler`. This sampler coordinates all workers to produce the exact set of all samples. Args: dataset: a list of dataset dicts, or a pytorch dataset (either map-style or iterable). They can be obtained by using :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`. mapper: a callable which takes a sample (dict) from dataset and returns the format to be consumed by the model. When using cfg, the default choice is ``DatasetMapper(cfg, is_train=False)``. sampler: a sampler that produces indices to be applied on ``dataset``. Default to :class:`InferenceSampler`, which splits the dataset across all workers. Sampler must be None if `dataset` is iterable. batch_size: the batch size of the data loader to be created. 
Default to 1 image per worker since this is the standard when reporting inference time in papers. num_workers: number of parallel data loading workers collate_fn: same as the argument of `torch.utils.data.DataLoader`. Defaults to do no collation and return a list of data. Returns: DataLoader: a torch DataLoader, that loads the given detection dataset, with test-time transformation and batching. Examples: :: data_loader = build_detection_test_loader( DatasetRegistry.get("my_test"), mapper=DatasetMapper(...)) # or, instantiate with a CfgNode: data_loader = build_detection_test_loader(cfg, "my_test") """ if isinstance(dataset, list): dataset = DatasetFromList(dataset, copy=False) if mapper is not None: dataset = MapDataset(dataset, mapper) if isinstance(dataset, torchdata.IterableDataset): assert sampler is None, "sampler must be None if dataset is IterableDataset" else: if sampler is None: sampler = InferenceSampler(len(dataset)) return torchdata.DataLoader( dataset, batch_size=batch_size, sampler=sampler, drop_last=False, num_workers=num_workers, collate_fn=trivial_batch_collator if collate_fn is None else collate_fn, ) def _train_loader_from_config(cfg, dataset_name, mapper, *, dataset=None, sampler=None): cfg_datasets = cfg['DATASETS'] cfg_dataloader = cfg['DATALOADER'] if dataset is None: dataset = get_detection_dataset_dicts( dataset_name, filter_empty=cfg_dataloader['FILTER_EMPTY_ANNOTATIONS'], proposal_files=cfg_datasets['PROPOSAL_FILES_TRAIN'] if cfg_dataloader['LOAD_PROPOSALS'] else None, ) if mapper is None: mapper = DatasetMapper(cfg, True) if sampler is None: sampler_name = cfg_dataloader['SAMPLER_TRAIN'] logger = logging.getLogger(__name__) logger.info("Using training sampler {}".format(sampler_name)) sampler = TrainingSampler(len(dataset)) return { "dataset": dataset, "sampler": sampler, "mapper": mapper, "total_batch_size": cfg['TRAIN']['BATCH_SIZE_TOTAL'], "aspect_ratio_grouping": cfg_dataloader['ASPECT_RATIO_GROUPING'], "num_workers": 
cfg_dataloader['NUM_WORKERS'], } @configurable(from_config=_train_loader_from_config) def build_detection_train_loader( dataset, *, mapper, sampler=None, total_batch_size, aspect_ratio_grouping=True, num_workers=0 ): """ Build a dataloader for object detection with some default features. This interface is experimental. Args: dataset (list or torch.utils.data.Dataset): a list of dataset dicts, or a map-style pytorch dataset. They can be obtained by using :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`. mapper (callable): a callable which takes a sample (dict) from dataset and returns the format to be consumed by the model. When using cfg, the default choice is ``DatasetMapper(cfg, is_train=True)``. sampler (torch.utils.data.sampler.Sampler or None): a sampler that produces indices to be applied on ``dataset``. Default to :class:`TrainingSampler`, which coordinates a random shuffle sequence across all workers. total_batch_size (int): total batch size across all workers. Batching simply puts data into a list. aspect_ratio_grouping (bool): whether to group images with similar aspect ratio for efficiency. When enabled, it requires each element in dataset be a dict with keys "width" and "height". num_workers (int): number of parallel data loading workers Returns: torch.utils.data.DataLoader: a dataloader. Each output from it is a ``list[mapped_element]`` of length ``total_batch_size / num_workers``, where ``mapped_element`` is produced by the ``mapper``. 
""" if isinstance(dataset, list): dataset = DatasetFromList(dataset, copy=False) if mapper is not None: dataset = MapDataset(dataset, mapper) if sampler is None: sampler = TrainingSampler(len(dataset)) assert isinstance(sampler, torch.utils.data.sampler.Sampler) return build_batch_data_loader( dataset, sampler, total_batch_size, aspect_ratio_grouping=aspect_ratio_grouping, num_workers=num_workers, ) def get_config_from_name(cfg, dataset_name): # adjust config according to dataset if 'sam' in dataset_name: cfg.update(cfg['SAM']) return cfg elif 'flickr' in dataset_name: cfg.update(cfg['flickr']) return cfg elif 'coco_instruct' in dataset_name: cfg.update(cfg['coco_instruct']) return cfg elif 'coco_interactive' in dataset_name: cfg.update(cfg['coco_interactive']) return cfg elif 'lisa' in dataset_name: cfg.update(cfg['LISA_REF']) return cfg elif 'llava' in dataset_name: cfg.update(cfg['llava']) return cfg elif 'vg' in dataset_name: cfg.update(cfg['vg']) return cfg elif 'part' in dataset_name and 'pascal_part' not in dataset_name and 'partimagenet' not in dataset_name: cfg.update(cfg['part']) return cfg elif 'pascal' in dataset_name or 'paco' in dataset_name or 'partimagenet' in dataset_name : cfg.update(cfg['PSACAL_PART']) return cfg elif 'coco' in dataset_name and 'refonly' in dataset_name: # if 'COCO' in cfg.keys(): cfg.update(cfg['COCO_REF']) return cfg elif 'coco' in dataset_name: if 'COCO' in cfg.keys(): cfg.update(cfg['COCO']) return cfg elif "mapillary" in dataset_name: if 'MAPILLARY' in cfg.keys(): cfg.update(cfg['MAPILLARY']) return cfg elif 'ade' in dataset_name: if 'ADE20K' in cfg.keys(): cfg.update(cfg['ADE20K']) return cfg elif 'imagenet' in dataset_name: if 'IMAGENET' in cfg.keys(): cfg.update(cfg['IMAGENET']) return cfg elif 'vlp' in dataset_name: cfg.update(cfg['VLP']) return cfg elif 'sun' in dataset_name: cfg.update(cfg['SUN']) return cfg elif 'object365' in dataset_name: cfg.update(cfg['OBJECT365']) return cfg elif 'scan' in dataset_name: 
cfg.update(cfg['SCAN']) return cfg elif 'cityscape' in dataset_name: cfg.update(cfg['CITY']) return cfg elif 'bdd' in dataset_name: cfg.update(cfg['BDD']) return cfg else: assert False, "dataset not support." def build_train_dataloader(cfg,tokenizer=None,data_args=None,preprocess=None,llava_cap_loader=None ): dataset_names = cfg['DATASETS']['TRAIN'] loaders = {} cfg = copy.deepcopy(cfg) for dataset_name in dataset_names: cfg = get_config_from_name(cfg, dataset_name) mapper_name = cfg['INPUT']['DATASET_MAPPER_NAME'] if mapper_name =="flickr": mapper=FlickrNewBaselineDatasetMapper(cfg,True,tokenizer=tokenizer,data_args=data_args,preprocess=preprocess) loaders['flickr'] = build_detection_train_loader(cfg, dataset_name=dataset_name, mapper=mapper) elif mapper_name =="coco_instruct": mapper=COCOInstructGroundingDatasetMapper(cfg,True,tokenizer=tokenizer,data_args=data_args,preprocess=preprocess) loaders['coco_instruct'] = build_detection_train_loader(cfg, dataset_name=dataset_name, mapper=mapper) elif mapper_name =="coco_interactive": if "refcoco" in dataset_name: refcoco=True else: refcoco=False mapper=COCOInterGroundingDatasetMapper(cfg,True,tokenizer=tokenizer,data_args=data_args,preprocess=preprocess,refcoco=refcoco) if refcoco: loaders['interactiveref'] = build_detection_train_loader(cfg, dataset_name=dataset_name, mapper=mapper) else: loaders['interactive'] = build_detection_train_loader(cfg, dataset_name=dataset_name, mapper=mapper) elif mapper_name =="vg": mapper=VGNewBaselineDatasetMapper(cfg,True,tokenizer=tokenizer,data_args=data_args,preprocess=preprocess) loaders['vg'] = build_detection_train_loader(cfg, dataset_name=dataset_name, mapper=mapper) elif mapper_name == "coco_ref_panoptic_lsj": mapper = COCOPanopticInteractiveDatasetMapper(cfg, cfg.get('Train',True),tokenizer=tokenizer,data_args=data_args,preprocess=preprocess) loaders['refcoco'] = build_detection_train_loader(cfg, dataset_name=dataset_name, mapper=mapper) else: mapper = None 
loaders[dataset_name] = build_detection_train_loader(cfg, dataset_name=dataset_name, mapper=mapper) if llava_cap_loader is not None: loaders['llava_cap'] = llava_cap_loader if len(loaders) == 1 and not cfg['LOADER'].get('JOINT', False): for k, v in loaders.items(): print("number of iterations per epoch: ", v, len(loaders[k])) return list(loaders.values())[0] else: return JointLoader(loaders, key_dataset=cfg['LOADER'].get('KEY_DATASET', 'coco')) ================================================ FILE: datasets_os/custom_dataset_dataloader.py ================================================ # Copyright (c) Facebook, Inc. and its affiliates. # Part of the code is from https://github.com/xingyizhou/UniDet/blob/master/projects/UniDet/unidet/data/multi_dataset_dataloader.py (Apache-2.0 License) import copy import logging import numpy as np import operator from typing import Any, Callable, Dict, List, Optional, Union import torch import torch.utils.data as torchdata import json from detectron2.utils.comm import get_world_size from detectron2.utils.logger import _log_api_usage, log_first_n from detectron2.config import configurable from detectron2.data import samplers from torch.utils.data.sampler import BatchSampler, Sampler from detectron2.data.common import DatasetFromList, MapDataset from detectron2.data.dataset_mapper import DatasetMapper from detectron2.data.build import get_detection_dataset_dicts, build_batch_data_loader from detectron2.data.samplers import TrainingSampler, RepeatFactorTrainingSampler, InferenceSampler from detectron2.data.build import worker_init_reset_seed, print_instances_class_histogram from detectron2.data.build import filter_images_with_only_crowd_annotations from detectron2.data.build import filter_images_with_few_keypoints from detectron2.data.build import check_metadata_consistency from detectron2.data.catalog import MetadataCatalog, DatasetCatalog from detectron2.utils import comm import itertools import math from collections import 
defaultdict from typing import Optional logger = logging.getLogger('detectron2.vlpart.data.custom_dataset_dataloader') def _custom_test_loader_from_config(cfg, dataset_name, mapper=None): if isinstance(dataset_name, str): dataset_name = [dataset_name] dataset = get_detection_dataset_dicts( dataset_name, filter_empty=False, proposal_files=[ cfg.DATASETS.PROPOSAL_FILES_TEST[list(cfg.DATASETS.TEST).index(x)] for x in dataset_name ] if cfg.MODEL.LOAD_PROPOSALS_TEST else None, ) if mapper is None: mapper = DatasetMapper(cfg, False) return { "dataset": dataset, "mapper": mapper, "num_workers": cfg.DATALOADER.NUM_WORKERS, "sampler": InferenceSampler(len(dataset)) if not isinstance(dataset, torchdata.IterableDataset) else None, } @configurable(from_config=_custom_test_loader_from_config) def build_custom_test_loader( dataset: Union[List[Any], torchdata.Dataset], *, mapper: Callable[[Dict[str, Any]], Any], sampler: Optional[torchdata.Sampler] = None, batch_size: int = 1, num_workers: int = 0, collate_fn: Optional[Callable[[List[Any]], Any]] = None, ) -> torchdata.DataLoader: if isinstance(dataset, list): dataset = DatasetFromList(dataset, copy=False) if mapper is not None: dataset = MapDataset(dataset, mapper) if isinstance(dataset, torchdata.IterableDataset): assert sampler is None, "sampler must be None if dataset is IterableDataset" else: if sampler is None: sampler = InferenceSampler(len(dataset)) return torchdata.DataLoader( dataset, batch_size=batch_size, sampler=sampler, drop_last=False, num_workers=num_workers, collate_fn=trivial_batch_collator if collate_fn is None else collate_fn, ) def trivial_batch_collator(batch): return batch def _custom_train_loader_from_config(cfg, mapper=None, *, dataset=None, sampler=None): sampler_name = cfg.DATALOADER.SAMPLER_TRAIN if 'MultiDataset' in sampler_name: dataset_dicts = get_detection_dataset_dicts_with_source( cfg.DATASETS.TRAIN, filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS, 
min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE if cfg.MODEL.KEYPOINT_ON else 0, proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None, ) else: dataset_dicts = get_detection_dataset_dicts( cfg.DATASETS.TRAIN, filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS, min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE if cfg.MODEL.KEYPOINT_ON else 0, proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None, ) if mapper is None: mapper = DatasetMapper(cfg, True) if sampler is not None: pass elif sampler_name == "TrainingSampler": sampler = TrainingSampler(len(dataset)) elif sampler_name == "MultiDatasetSampler": sampler = MultiDatasetSampler( dataset_dicts, dataset_ratio = cfg.DATALOADER.DATASET_RATIO, use_rfs = cfg.DATALOADER.USE_RFS, dataset_ann = cfg.DATALOADER.DATASET_ANN, repeat_threshold = cfg.DATALOADER.REPEAT_THRESHOLD, ) elif sampler_name == "RepeatFactorTrainingSampler": repeat_factors = RepeatFactorTrainingSampler.repeat_factors_from_category_frequency( dataset_dicts, cfg.DATALOADER.REPEAT_THRESHOLD ) sampler = RepeatFactorTrainingSampler(repeat_factors) else: raise ValueError("Unknown training sampler: {}".format(sampler_name)) return { "dataset": dataset_dicts, "sampler": sampler, "mapper": mapper, "total_batch_size": cfg.SOLVER.IMS_PER_BATCH, "aspect_ratio_grouping": cfg.DATALOADER.ASPECT_RATIO_GROUPING, "num_workers": cfg.DATALOADER.NUM_WORKERS, 'multi_dataset_grouping': cfg.DATALOADER.MULTI_DATASET_GROUPING, 'use_diff_bs_size': cfg.DATALOADER.USE_DIFF_BS_SIZE, 'dataset_bs': cfg.DATALOADER.DATASET_BS, 'num_datasets': len(cfg.DATASETS.TRAIN) } @configurable(from_config=_custom_train_loader_from_config) def build_custom_train_loader( dataset, *, mapper, sampler, total_batch_size=16, aspect_ratio_grouping=True, num_workers=0, num_datasets=1, multi_dataset_grouping=False, use_diff_bs_size=False, dataset_bs=[] ): """ Modified from 
detectron2.data.build.build_custom_train_loader, but supports different samplers """ if isinstance(dataset, list): dataset = DatasetFromList(dataset, copy=False) if mapper is not None: dataset = MapDataset(dataset, mapper) if sampler is None: sampler = TrainingSampler(len(dataset)) assert isinstance(sampler, torch.utils.data.sampler.Sampler) if multi_dataset_grouping: return build_multi_dataset_batch_data_loader( use_diff_bs_size, dataset_bs, dataset, sampler, total_batch_size, num_datasets=num_datasets, num_workers=num_workers, ) else: return build_batch_data_loader( dataset, sampler, total_batch_size, aspect_ratio_grouping=aspect_ratio_grouping, num_workers=num_workers, ) def build_multi_dataset_batch_data_loader( use_diff_bs_size, dataset_bs, dataset, sampler, total_batch_size, num_datasets, num_workers=0 ): """ """ world_size = get_world_size() assert ( total_batch_size > 0 and total_batch_size % world_size == 0 ), "Total batch size ({}) must be divisible by the number of gpus ({}).".format( total_batch_size, world_size ) batch_size = total_batch_size // world_size data_loader = torch.utils.data.DataLoader( dataset, sampler=sampler, num_workers=num_workers, batch_sampler=None, collate_fn=operator.itemgetter(0), # don't batch, but yield individual elements worker_init_fn=worker_init_reset_seed, ) # yield individual mapped dict if use_diff_bs_size: return DIFFMDAspectRatioGroupedDataset( data_loader, dataset_bs, num_datasets) else: return MDAspectRatioGroupedDataset( data_loader, batch_size, num_datasets) def get_detection_dataset_dicts_with_source( dataset_names, filter_empty=True, min_keypoints=0, proposal_files=None ): assert len(dataset_names) dataset_dicts = [DatasetCatalog.get(dataset_name) for dataset_name in dataset_names] for dataset_name, dicts in zip(dataset_names, dataset_dicts): assert len(dicts), "Dataset '{}' is empty!".format(dataset_name) for source_id, (dataset_name, dicts) in \ enumerate(zip(dataset_names, dataset_dicts)): assert len(dicts), 
"Dataset '{}' is empty!".format(dataset_name) for d in dicts: d['dataset_source'] = source_id if "annotations" in dicts[0]: try: class_names = MetadataCatalog.get(dataset_name).thing_classes check_metadata_consistency("thing_classes", dataset_name) print_instances_class_histogram(dicts, class_names) except AttributeError: # class names are not available for this dataset pass assert proposal_files is None dataset_dicts = list(itertools.chain.from_iterable(dataset_dicts)) has_instances = "annotations" in dataset_dicts[0] if filter_empty and has_instances: dataset_dicts = filter_images_with_only_crowd_annotations(dataset_dicts) if min_keypoints > 0 and has_instances: dataset_dicts = filter_images_with_few_keypoints(dataset_dicts, min_keypoints) return dataset_dicts class MultiDatasetSampler(Sampler): def __init__( self, dataset_dicts, dataset_ratio, use_rfs, dataset_ann, repeat_threshold=0.001, seed: Optional[int] = None, ): """ """ sizes = [0 for _ in range(len(dataset_ratio))] for d in dataset_dicts: sizes[d['dataset_source']] += 1 logger.info('dataset sizes {}'.format(sizes)) self.sizes = sizes assert len(dataset_ratio) == len(sizes), \ 'length of dataset ratio {} should be equal to number if dataset {}'.format( len(dataset_ratio), len(sizes) ) if seed is None: seed = comm.shared_random_seed() self._seed = int(seed) self._rank = comm.get_rank() self._world_size = comm.get_world_size() self.dataset_ids = torch.tensor( [d['dataset_source'] for d in dataset_dicts], dtype=torch.long) dataset_weight = [torch.ones(s) * max(sizes) / s * r / sum(dataset_ratio) \ for i, (r, s) in enumerate(zip(dataset_ratio, sizes))] dataset_weight = torch.cat(dataset_weight) rfs_factors = [] st = 0 for i, s in enumerate(sizes): if use_rfs[i]: if dataset_ann[i] == 'box': rfs_func = RepeatFactorTrainingSampler.repeat_factors_from_category_frequency else: rfs_func = repeat_factors_from_tag_frequency rfs_factor = rfs_func( dataset_dicts[st: st + s], repeat_thresh=repeat_threshold) rfs_factor = 
rfs_factor * (s / rfs_factor.sum()) else: rfs_factor = torch.ones(s) rfs_factors.append(rfs_factor) st = st + s rfs_factors = torch.cat(rfs_factors) self.weights = dataset_weight * rfs_factors self.sample_epoch_size = len(self.weights) def __iter__(self): start = self._rank yield from itertools.islice( self._infinite_indices(), start, None, self._world_size) def _infinite_indices(self): g = torch.Generator() g.manual_seed(self._seed) while True: ids = torch.multinomial( self.weights, self.sample_epoch_size, generator=g, replacement=True) nums = [(self.dataset_ids[ids] == i).sum().int().item() \ for i in range(len(self.sizes))] yield from ids class MDAspectRatioGroupedDataset(torch.utils.data.IterableDataset): def __init__(self, dataset, batch_size, num_datasets): """ """ self.dataset = dataset self.batch_size = batch_size self._buckets = [[] for _ in range(2 * num_datasets)] def __iter__(self): for d in self.dataset: w, h = d["width"], d["height"] aspect_ratio_bucket_id = 0 if w > h else 1 bucket_id = d['dataset_source'] * 2 + aspect_ratio_bucket_id bucket = self._buckets[bucket_id] bucket.append(d) if len(bucket) == self.batch_size: yield bucket[:] del bucket[:] class DIFFMDAspectRatioGroupedDataset(torch.utils.data.IterableDataset): def __init__(self, dataset, batch_sizes, num_datasets): """ """ self.dataset = dataset self.batch_sizes = batch_sizes self._buckets = [[] for _ in range(2 * num_datasets)] def __iter__(self): for d in self.dataset: w, h = d["width"], d["height"] aspect_ratio_bucket_id = 0 if w > h else 1 bucket_id = d['dataset_source'] * 2 + aspect_ratio_bucket_id bucket = self._buckets[bucket_id] bucket.append(d) if len(bucket) == self.batch_sizes[d['dataset_source']]: yield bucket[:] del bucket[:] def repeat_factors_from_tag_frequency(dataset_dicts, repeat_thresh): """ """ category_freq = defaultdict(int) for dataset_dict in dataset_dicts: cat_ids = dataset_dict['pos_category_ids'] for cat_id in cat_ids: category_freq[cat_id] += 1 num_images = 
len(dataset_dicts) for k, v in category_freq.items(): category_freq[k] = v / num_images category_rep = { cat_id: max(1.0, math.sqrt(repeat_thresh / cat_freq)) for cat_id, cat_freq in category_freq.items() } rep_factors = [] for dataset_dict in dataset_dicts: cat_ids = dataset_dict['pos_category_ids'] rep_factor = max({category_rep[cat_id] for cat_id in cat_ids}, default=1.0) rep_factors.append(rep_factor) return torch.tensor(rep_factors, dtype=torch.float32) ================================================ FILE: datasets_os/dataset_mappers/__init__.py ================================================ from .coco_panoptic_interactive_dataset_mapper import COCOPanopticInteractiveDatasetMapper from .flickr_instance_new_baseline_dataset_mapper import COCOInstanceNewBaselineDatasetMapper as FlickrNewBaselineDatasetMapper from .coco_instruct_grounding_dataset_mapper import COCOInstanceNewBaselineDatasetMapper as COCOInstructGroundingDatasetMapper from .coco_instruct_grounding_dataset_interactive_mapper import COCOInstanceNewBaselineDatasetMapper as COCOInterGroundingDatasetMapper from .vg_instance_new_baseline_dataset_mapper import COCOInstanceNewBaselineDatasetMapper as VGNewBaselineDatasetMapper ================================================ FILE: datasets_os/dataset_mappers/coco_instance_new_baseline_dataset_mapper.py ================================================ # Copyright (c) Facebook, Inc. and its affiliates. 
# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/d2/detr/dataset_mapper.py import copy import logging import numpy as np import torch from detectron2.data import detection_utils as utils from detectron2.data import transforms as T from detectron2.data.transforms import TransformGen from detectron2.structures import BitMasks, Instances from pycocotools import mask as coco_mask from llava.model.openseed.utils import configurable __all__ = ["COCOInstanceNewBaselineDatasetMapper"] def convert_coco_poly_to_mask(segmentations, height, width): masks = [] for polygons in segmentations: rles = coco_mask.frPyObjects(polygons, height, width) mask = coco_mask.decode(rles) if len(mask.shape) < 3: mask = mask[..., None] mask = torch.as_tensor(mask, dtype=torch.uint8) mask = mask.any(dim=2) masks.append(mask) if masks: masks = torch.stack(masks, dim=0) else: masks = torch.zeros((0, height, width), dtype=torch.uint8) return masks def build_transform_gen(cfg, is_train): """ Create a list of default :class:`Augmentation` from config. Now it includes resizing and flipping. Returns: list[Augmentation] """ assert is_train, "Only support training augmentation" cfg_input = cfg['INPUT'] image_size = cfg_input['IMAGE_SIZE'] min_scale = cfg_input['MIN_SCALE'] max_scale = cfg_input['MAX_SCALE'] augmentation = [] if cfg_input['RANDOM_FLIP'] != "none": augmentation.append( T.RandomFlip( horizontal=cfg_input['RANDOM_FLIP'] == "horizontal", vertical=cfg_input['RANDOM_FLIP'] == "vertical", ) ) augmentation.extend([ T.ResizeScale( min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size ), T.FixedSizeCrop(crop_size=(image_size, image_size)), ]) return augmentation # This is specifically designed for the COCO dataset. class COCOInstanceNewBaselineDatasetMapper: """ A callable which takes a dataset dict in Detectron2 Dataset format, and map it into a format used by MaskFormer. 
This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation. The callable currently does the following: 1. Read the image from "file_name" 2. Applies geometric transforms to the image and annotation 3. Find and applies suitable cropping to the image and annotation 4. Prepare image and annotation to Tensors """ @configurable def __init__( self, is_train=True, *, tfm_gens, image_format, ): """ NOTE: this interface is experimental. Args: is_train: for training or inference augmentations: a list of augmentations or deterministic transforms to apply tfm_gens: data augmentation image_format: an image format supported by :func:`detection_utils.read_image`. """ self.tfm_gens = tfm_gens logging.getLogger(__name__).info( "[COCOInstanceNewBaselineDatasetMapper] Full TransformGens used in training: {}".format(str(self.tfm_gens)) ) self.img_format = image_format self.is_train = is_train @classmethod def from_config(cls, cfg, is_train=True): # Build augmentation tfm_gens = build_transform_gen(cfg, is_train) ret = { "is_train": is_train, "tfm_gens": tfm_gens, "image_format": cfg['INPUT']['FORMAT'], } return ret def __call__(self, dataset_dict): """ Args: dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format. 
Returns: dict: a format that builtin models in detectron2 accept """ dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below image = utils.read_image(dataset_dict["file_name"], format=self.img_format) utils.check_image_size(dataset_dict, image) # TODO: get padding mask # by feeding a "segmentation mask" to the same transforms padding_mask = np.ones(image.shape[:2]) image, transforms = T.apply_transform_gens(self.tfm_gens, image) # the crop transformation has default padding value 0 for segmentation padding_mask = transforms.apply_segmentation(padding_mask) padding_mask = ~ padding_mask.astype(bool) image_shape = image.shape[:2] # h, w # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory, # but not efficient on large generic data structures due to the use of pickle & mp.Queue. # Therefore it's important to use torch.Tensor. dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1))) dataset_dict["padding_mask"] = torch.as_tensor(np.ascontiguousarray(padding_mask)) if not self.is_train: # USER: Modify this if you want to keep them for some reason. dataset_dict.pop("annotations", None) return dataset_dict if "annotations" in dataset_dict: # USER: Modify this if you want to keep them for some reason. for anno in dataset_dict["annotations"]: # Let's always keep mask # if not self.mask_on: # anno.pop("segmentation", None) anno.pop("keypoints", None) # USER: Implement additional transformations if you have other types of data annos = [ utils.transform_instance_annotations(obj, transforms, image_shape) for obj in dataset_dict.pop("annotations") if obj.get("iscrowd", 0) == 0 ] # NOTE: does not support BitMask due to augmentation # Current BitMask cannot handle empty objects instances = utils.annotations_to_instances(annos, image_shape) # After transforms such as cropping are applied, the bounding box may no longer # tightly bound the object. 
As an example, imagine a triangle object # [(0,0), (2,0), (0,2)] cropped by a box [(1,0),(2,2)] (XYXY format). The tight # bounding box of the cropped triangle should be [(1,0),(2,1)], which is not equal to # the intersection of original bounding box and the cropping box. instances.gt_boxes = instances.gt_masks.get_bounding_boxes() # Need to filter empty instances first (due to augmentation) instances = utils.filter_empty_instances(instances) # Generate masks from polygon h, w = instances.image_size # image_size_xyxy = torch.as_tensor([w, h, w, h], dtype=torch.float) if hasattr(instances, 'gt_masks'): gt_masks = instances.gt_masks gt_masks = convert_coco_poly_to_mask(gt_masks.polygons, h, w) instances.gt_masks = gt_masks dataset_dict["instances"] = instances return dataset_dict ================================================ FILE: datasets_os/dataset_mappers/coco_instruct_grounding_dataset_interactive_mapper.py ================================================ # Copyright (c) Facebook, Inc. and its affiliates. 
# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/d2/detr/dataset_mapper.py import copy import logging import random import numpy as np import torch import PIL.Image as Image from detectron2.data import detection_utils as utils from detectron2.data import transforms as T from detectron2.data.transforms import TransformGen from detectron2.structures import BitMasks, Instances from pycocotools import mask as coco_mask from llava.model.openseed.utils import configurable from detectron2.structures import Boxes, BoxMode, PolygonMasks, RotatedBoxes from llava import conversation as conversation_lib from llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN # from llava.train.train_hao_seg_flickr import ,preprocess __all__ = ["COCOInstanceNewBaselineDatasetMapper"] def convert_coco_poly_to_mask(segmentations, height, width): masks = [] for polygons in segmentations: rles = coco_mask.frPyObjects(polygons, height, width) mask = coco_mask.decode(rles) if len(mask.shape) < 3: mask = mask[..., None] mask = torch.as_tensor(mask, dtype=torch.uint8) mask = mask.any(dim=2) masks.append(mask) if masks: masks = torch.stack(masks, dim=0) else: masks = torch.zeros((0, height, width), dtype=torch.uint8) return masks def preprocess_multimodal( sources, data_args ): is_multimodal = data_args.is_multimodal if not is_multimodal: return sources for source in sources: for sentence in source: if DEFAULT_IMAGE_TOKEN in sentence['value']: sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip() sentence['value'] = DEFAULT_IMAGE_TOKEN + '\n' + sentence['value'] sentence['value'] = sentence['value'].strip() if "mmtag" in conversation_lib.default_conversation.version: sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '' + DEFAULT_IMAGE_TOKEN + '') replace_token = DEFAULT_IMAGE_TOKEN if data_args.mm_use_im_start_end: replace_token = DEFAULT_IM_START_TOKEN + 
replace_token + DEFAULT_IM_END_TOKEN sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token) return sources def build_transform_gen(cfg, is_train): """ Create a list of default :class:`Augmentation` from config. Now it includes resizing and flipping. Returns: list[Augmentation] """ if is_train: cfg_input = cfg['INPUT'] image_size = cfg_input['IMAGE_SIZE'] min_scale = cfg_input['MIN_SCALE'] max_scale = cfg_input['MAX_SCALE'] augmentation = [] # if cfg_input['RANDOM_FLIP'] != "none": # augmentation.append( # T.RandomFlip( # horizontal=cfg_input['RANDOM_FLIP'] == "horizontal", # vertical=cfg_input['RANDOM_FLIP'] == "vertical", # ) # ) augmentation.extend([ T.ResizeScale( min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size ), T.FixedSizeCrop(crop_size=(image_size, image_size)), ]) else: cfg_input = cfg['INPUT'] image_size = cfg_input['IMAGE_SIZE'] min_scale = cfg_input['MIN_SCALE'] max_scale = cfg_input['MAX_SCALE'] augmentation = [] # if cfg_input['RANDOM_FLIP'] != "none": # augmentation.append( # T.RandomFlip( # horizontal=cfg_input['RANDOM_FLIP'] == "horizontal", # vertical=cfg_input['RANDOM_FLIP'] == "vertical", # ) # ) augmentation.extend([ T.ResizeScale( min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size ), T.FixedSizeCrop(crop_size=(image_size, image_size)), ]) return augmentation # This is specifically designed for the COCO dataset. class COCOInstanceNewBaselineDatasetMapper: """ A callable which takes a dataset dict in Detectron2 Dataset format, and map it into a format used by MaskFormer. This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation. The callable currently does the following: 1. Read the image from "file_name" 2. Applies geometric transforms to the image and annotation 3. Find and applies suitable cropping to the image and annotation 4. 
Prepare image and annotation to Tensors """ @configurable def __init__( self, is_train=True, *, tfm_gens, image_format, tokenizer, data_args, preprocess, refcoco=None, max_sampled=5, ): """ NOTE: this interface is experimental. Args: is_train: for training or inference augmentations: a list of augmentations or deterministic transforms to apply tfm_gens: data augmentation image_format: an image format supported by :func:`detection_utils.read_image`. """ self.tfm_gens = tfm_gens logging.getLogger(__name__).info( "[COCOInstanceNewBaselineDatasetMapper] Full TransformGens used in training: {}".format(str(self.tfm_gens)) ) self.img_format = image_format self.is_train = is_train self.tokenizer = tokenizer self.processor = data_args.image_processor self.data_args = data_args self.preprocess = preprocess self.refcoco=refcoco self.max_sampled=max_sampled @classmethod def from_config(cls, cfg, is_train=True,tokenizer=None,data_args=None,preprocess=None,refcoco=None): # Build augmentation tfm_gens = build_transform_gen(cfg, is_train) ret = { "is_train": is_train, "tfm_gens": tfm_gens, "image_format": cfg['INPUT']['FORMAT'], "tokenizer": tokenizer, "data_args": data_args, "preprocess": preprocess, "refcoco":refcoco, } return ret def __call__(self, dataset_dict): """ Args: dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format. 
Returns: dict: a format that builtin models in detectron2 accept """ dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below image = utils.read_image(dataset_dict["file_name"], format=self.img_format) utils.check_image_size(dataset_dict, image) #########llava image processing if self.data_args.image_aspect_ratio == 'pad': def expand2square(pil_img, background_color): width, height = pil_img.size if width == height: return pil_img elif width > height: result = Image.new(pil_img.mode, (width, width), background_color) result.paste(pil_img, (0, (width - height) // 2)) return result else: result = Image.new(pil_img.mode, (height, height), background_color) result.paste(pil_img, ((height - width) // 2, 0)) return result image_clip = expand2square(image, tuple(int(x * 255) for x in self.processor.image_mean)) image_clip = self.processor.preprocess(image_clip, return_tensors='pt')['pixel_values'][0] else: image_clip = self.processor.preprocess(image, return_tensors='pt')['pixel_values'][0] dataset_dict["image_clip"] = image_clip ################## # TODO: get padding mask # by feeding a "segmentation mask" to the same transforms padding_mask = np.ones(image.shape[:2]) image, transforms = T.apply_transform_gens(self.tfm_gens, image) dataset_dict["image_ori"]=image # the crop transformation has default padding value 0 for segmentation padding_mask = transforms.apply_segmentation(padding_mask) padding_mask = ~ padding_mask.astype(bool) image_shape = image.shape[:2] # h, w dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1))) dataset_dict["padding_mask"] = torch.as_tensor(np.ascontiguousarray(padding_mask)) num_conversations = len(dataset_dict['conversations']) if self.refcoco: max_sampled=min(self.max_sampled,num_conversations) sample_num=random.randint(1,max_sampled) sampled_convs=random.sample(dataset_dict['conversations'], k=sample_num) grounding_list=[] selected_conversation=[] 
sampled_convs[0][0][0]['value']='\n'+sampled_convs[0][0][0]['value'] for conv,gd in sampled_convs: grounding_list.extend(gd) conv[1]['value']=random.choice(conv[1]['value']) selected_conversation.extend(conv) else: rd = np.random.choice(num_conversations) selected_conversation, grounding_list = dataset_dict['conversations'][rd] dataset_dict['conversation'] = [selected_conversation] sources = preprocess_multimodal( copy.deepcopy(dataset_dict['conversation']), self.data_args) data_dict_conversation = self.preprocess( sources, self.tokenizer, has_image=True) data_dict_conversation = dict(input_ids=data_dict_conversation["input_ids"][0], labels=data_dict_conversation["labels"][0]) dataset_dict.update(data_dict_conversation) dataset_dict['tokenizer'] = self.tokenizer # num_segs = sum([conv['value'].count('') for conv in selected_conversation]) # grounding_list= assert "grounding_info" in dataset_dict and len(dataset_dict['grounding_info'])>0 anno_id2id=dict() for id,obj in enumerate(dataset_dict['grounding_info']): obj["bbox_mode"] = BoxMode.XYWH_ABS anno_id2id[obj['id']]=id # id2class=[[] for _ in range(len(dataset_dict['grounding_info']))] annos = [ utils.transform_instance_annotations(obj, transforms, image_shape) for obj in dataset_dict["grounding_info"] ] # assert "segmentation" in annos[0] instances = utils.annotations_to_instances(annos, image_shape,mask_format="bitmask") h, w = instances.image_size # image_size_xyxy = torch.as_tensor([w, h, w, h], dtype=torch.float) if hasattr(instances, 'gt_masks'): gt_masks = instances.gt_masks # gt_masks = convert_coco_poly_to_mask(gt_masks.polygons, h, w) instances.gt_masks = gt_masks.tensor num_objs=(data_dict_conversation['input_ids']==1273).sum() grounding_list=[gd for gd in grounding_list if gd is not None] merged_grounding_list=sum(grounding_list,[]) # assert num_objs==len(merged_grounding_list) if num_objslen(merged_grounding_list): 
merged_grounding_list=merged_grounding_list+[merged_grounding_list[-1]]*(num_objs-len(merged_grounding_list)) merged_grounding_list=[anno_id2id[annid] for annid in merged_grounding_list] dataset_dict['grounding_index']=merged_grounding_list dataset_dict["instances"] = instances # if grounding_list is None: # dataset_dict['grounding']=False # grounding_mask=[False for _ in range(num_segs)] # dataset_dict['grounding_mask']=grounding_mask # else: # grounding_mask=[True if g is not None else False for g in grounding_list] # dataset_dict['grounding_mask']=grounding_mask # new_grounding_list=[g for g in grounding_list if g is not None] # if sum(grounding_mask)==0: # dataset_dict['grounding']=False # else: # dataset_dict['grounding']=True # if dataset_dict['grounding']: # # assert num_segs == len(grounding_list) # for grounding_id,grounding in enumerate(new_grounding_list): # if grounding is not None: # for annid in grounding: # id2class[anno_id2id[annid]].append(grounding_id) # # instances.gt_classes=id2class # dataset_dict["instances"] = instances # else: # dataset_dict['grounding'] = False # grounding_mask = [False for _ in range(num_segs)] # dataset_dict['grounding_mask'] = grounding_mask return dataset_dict ================================================ FILE: datasets_os/dataset_mappers/coco_instruct_grounding_dataset_mapper.py ================================================ # Copyright (c) Facebook, Inc. and its affiliates. 
# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/d2/detr/dataset_mapper.py import copy import logging import numpy as np import torch import PIL.Image as Image from detectron2.data import detection_utils as utils from detectron2.data import transforms as T from detectron2.data.transforms import TransformGen from detectron2.structures import BitMasks, Instances from pycocotools import mask as coco_mask from llava.model.openseed.utils import configurable from detectron2.structures import Boxes, BoxMode, PolygonMasks, RotatedBoxes from llava import conversation as conversation_lib from llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN # from llava.train.train_hao_seg_flickr import ,preprocess __all__ = ["COCOInstanceNewBaselineDatasetMapper"] suffix=[ "Please also provide the boxes and masks for the noun phrases in the response." , "Kindly ensure that the response includes the relevant boxes and masks for each noun phrase." , "Additionally, include the boxes and masks that match each noun phrase in the response." , "Please provide the boxes and masks that correspond to every noun phrase in your response." , "It’s important to have the boxes and masks that align with each noun phrase in the response." , "Make sure to include the appropriate boxes and masks for each noun phrase in your response." , "In your response, include the boxes and masks that pertain to each noun phrase." , "Also, supply the boxes and masks that are linked to each noun phrase in the response." , "Additionally, please furnish the boxes and masks that correspond to each noun phrase in the response." , "Don’t forget to provide the boxes and masks associated with each noun phrase in your response." 
, "Ensure that each noun phrase in the response has its respective boxes and masks.", ] def convert_coco_poly_to_mask(segmentations, height, width): masks = [] for polygons in segmentations: rles = coco_mask.frPyObjects(polygons, height, width) mask = coco_mask.decode(rles) if len(mask.shape) < 3: mask = mask[..., None] mask = torch.as_tensor(mask, dtype=torch.uint8) mask = mask.any(dim=2) masks.append(mask) if masks: masks = torch.stack(masks, dim=0) else: masks = torch.zeros((0, height, width), dtype=torch.uint8) return masks def preprocess_multimodal( sources, data_args ): is_multimodal = data_args.is_multimodal if not is_multimodal: return sources for source in sources: for sentence in source: if DEFAULT_IMAGE_TOKEN in sentence['value']: sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip() sentence['value'] = DEFAULT_IMAGE_TOKEN + '\n' + sentence['value'] sentence['value'] = sentence['value'].strip() if "mmtag" in conversation_lib.default_conversation.version: sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '' + DEFAULT_IMAGE_TOKEN + '') replace_token = DEFAULT_IMAGE_TOKEN if data_args.mm_use_im_start_end: replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token) return sources def build_transform_gen(cfg, is_train): """ Create a list of default :class:`Augmentation` from config. Now it includes resizing and flipping. 
Returns: list[Augmentation] """ if is_train: cfg_input = cfg['INPUT'] image_size = cfg_input['IMAGE_SIZE'] min_scale = cfg_input['MIN_SCALE'] max_scale = cfg_input['MAX_SCALE'] augmentation = [] # if cfg_input['RANDOM_FLIP'] != "none": # augmentation.append( # T.RandomFlip( # horizontal=cfg_input['RANDOM_FLIP'] == "horizontal", # vertical=cfg_input['RANDOM_FLIP'] == "vertical", # ) # ) augmentation.extend([ T.ResizeScale( min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size ), T.FixedSizeCrop(crop_size=(image_size, image_size)), ]) else: cfg_input = cfg['INPUT'] image_size = cfg_input['IMAGE_SIZE'] min_scale = cfg_input['MIN_SCALE'] max_scale = cfg_input['MAX_SCALE'] augmentation = [] # if cfg_input['RANDOM_FLIP'] != "none": # augmentation.append( # T.RandomFlip( # horizontal=cfg_input['RANDOM_FLIP'] == "horizontal", # vertical=cfg_input['RANDOM_FLIP'] == "vertical", # ) # ) augmentation.extend([ T.ResizeScale( min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size ), T.FixedSizeCrop(crop_size=(image_size, image_size)), ]) return augmentation # This is specifically designed for the COCO dataset. class COCOInstanceNewBaselineDatasetMapper: """ A callable which takes a dataset dict in Detectron2 Dataset format, and map it into a format used by MaskFormer. This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation. The callable currently does the following: 1. Read the image from "file_name" 2. Applies geometric transforms to the image and annotation 3. Find and applies suitable cropping to the image and annotation 4. Prepare image and annotation to Tensors """ @configurable def __init__( self, is_train=True, *, tfm_gens, image_format, tokenizer, data_args, preprocess, replace_suffix=False, ): """ NOTE: this interface is experimental. 
Args: is_train: for training or inference augmentations: a list of augmentations or deterministic transforms to apply tfm_gens: data augmentation image_format: an image format supported by :func:`detection_utils.read_image`. """ self.tfm_gens = tfm_gens logging.getLogger(__name__).info( "[COCOInstanceNewBaselineDatasetMapper] Full TransformGens used in training: {}".format(str(self.tfm_gens)) ) self.img_format = image_format self.is_train = is_train self.tokenizer = tokenizer self.processor = data_args.image_processor self.data_args = data_args self.preprocess = preprocess self.replace_suffix=replace_suffix @classmethod def from_config(cls, cfg, is_train=True,tokenizer=None,data_args=None,preprocess=None): # Build augmentation tfm_gens = build_transform_gen(cfg, is_train) ret = { "is_train": is_train, "tfm_gens": tfm_gens, "image_format": cfg['INPUT']['FORMAT'], "tokenizer": tokenizer, "data_args": data_args, "preprocess": preprocess, "replace_suffix": cfg['MODEL'].get('REPLACE_SUFFIX', False), } return ret def __call__(self, dataset_dict): """ Args: dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format. 
Returns: dict: a format that builtin models in detectron2 accept """ dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below image = utils.read_image(dataset_dict["file_name"], format=self.img_format) utils.check_image_size(dataset_dict, image) #########llava image processing if self.data_args.image_aspect_ratio == 'pad': def expand2square(pil_img, background_color): width, height = pil_img.size if width == height: return pil_img elif width > height: result = Image.new(pil_img.mode, (width, width), background_color) result.paste(pil_img, (0, (width - height) // 2)) return result else: result = Image.new(pil_img.mode, (height, height), background_color) result.paste(pil_img, ((height - width) // 2, 0)) return result image_clip = expand2square(image, tuple(int(x * 255) for x in self.processor.image_mean)) image_clip = self.processor.preprocess(image_clip, return_tensors='pt')['pixel_values'][0] else: image_clip = self.processor.preprocess(image, return_tensors='pt')['pixel_values'][0] dataset_dict["image_clip"] = image_clip ################## # TODO: get padding mask # by feeding a "segmentation mask" to the same transforms padding_mask = np.ones(image.shape[:2]) image, transforms = T.apply_transform_gens(self.tfm_gens, image) dataset_dict["image_ori"]=image # the crop transformation has default padding value 0 for segmentation padding_mask = transforms.apply_segmentation(padding_mask) padding_mask = ~ padding_mask.astype(bool) image_shape = image.shape[:2] # h, w dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1))) dataset_dict["padding_mask"] = torch.as_tensor(np.ascontiguousarray(padding_mask)) num_conversations = len(dataset_dict['conversations']) rd = np.random.choice(num_conversations) selected_conversation, grounding_list = dataset_dict['conversations'][rd] dataset_dict['conversation'] = [selected_conversation] sources = preprocess_multimodal( copy.deepcopy(dataset_dict['conversation']), self.data_args) 
if self.replace_suffix: for conv in sources[0]: sf=np.random.choice(suffix) if conv['from'] == 'human': conv['value'] = conv['value'].replace('(with grounding)', sf, 1) data_dict_conversation = self.preprocess( sources, self.tokenizer, has_image=True) data_dict_conversation = dict(input_ids=data_dict_conversation["input_ids"][0], labels=data_dict_conversation["labels"][0]) dataset_dict.update(data_dict_conversation) dataset_dict['tokenizer'] = self.tokenizer num_segs = sum([conv['value'].count('') for conv in selected_conversation]) # grounding_list= if "grounding_info" in dataset_dict and len(dataset_dict['grounding_info'])>0: anno_id2id=dict() for id,obj in enumerate(dataset_dict['grounding_info']): obj["bbox_mode"] = BoxMode.XYWH_ABS anno_id2id[obj['id']]=id id2class=[[] for _ in range(len(dataset_dict['grounding_info']))] annos = [ utils.transform_instance_annotations(obj, transforms, image_shape) for obj in dataset_dict["grounding_info"] ] # assert "segmentation" in annos[0] instances = utils.annotations_to_instances(annos, image_shape,mask_format="bitmask") h, w = instances.image_size # image_size_xyxy = torch.as_tensor([w, h, w, h], dtype=torch.float) if hasattr(instances, 'gt_masks'): gt_masks = instances.gt_masks # gt_masks = convert_coco_poly_to_mask(gt_masks.polygons, h, w) instances.gt_masks = gt_masks.tensor if grounding_list is None: dataset_dict['grounding']=False grounding_mask=[False for _ in range(num_segs)] dataset_dict['grounding_mask']=grounding_mask else: grounding_mask=[True if g is not None else False for g in grounding_list] dataset_dict['grounding_mask']=grounding_mask new_grounding_list=[g for g in grounding_list if g is not None] if sum(grounding_mask)==0: dataset_dict['grounding']=False else: dataset_dict['grounding']=True if dataset_dict['grounding']: # assert num_segs == len(grounding_list) for grounding_id,grounding in enumerate(new_grounding_list): if grounding is not None: for annid in grounding: 
id2class[anno_id2id[annid]].append(grounding_id) instances.gt_classes=id2class dataset_dict["instances"] = instances else: dataset_dict['grounding'] = False grounding_mask = [False for _ in range(num_segs)] dataset_dict['grounding_mask'] = grounding_mask return dataset_dict ================================================ FILE: datasets_os/dataset_mappers/coco_interactive_panoptic_new_baseline_dataset_mapper.py ================================================ # ------------------------------------------------------------------------ # Copyright (c) 2022 IDEA. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------ # Modified from Mask2Former https://github.com/facebookresearch/Mask2Former by Feng Li. import copy import logging import numpy as np import torch from detectron2.config import configurable from detectron2.data import detection_utils as utils from detectron2.data import transforms as T from detectron2.data.transforms import TransformGen from detectron2.structures import BitMasks, Boxes, Instances __all__ = ["COCOInteractivePanopticNewBaselineDatasetMapper"] def filter_empty_instances_by_box( instances, by_box=True, by_mask=False, box_threshold=1e-5, return_mask=False ): assert by_box or by_mask r = [] if by_box: r.append(instances.gt_boxes.nonempty(threshold=box_threshold)) if instances.has("gt_masks") and by_mask: r.append(instances.gt_masks.nonempty()) # TODO: can also filter visible keypoints if not r: return instances m = r[0] for x in r[1:]: m = m & x if return_mask: return instances[m], m return instances[m] def build_transform_gen(cfg, is_train): """ Create a list of default :class:`Augmentation` from config. Now it includes resizing and flipping. 
Returns:
        list[Augmentation]: flip (only when ``cfg.INPUT.RANDOM_FLIP != "none"``),
        then random resize-scale, then a fixed-size ``IMAGE_SIZE`` square crop.
    """
    # Only the training pipeline is supported; is_train=False trips this assert.
    assert is_train, "Only support training augmentation"
    image_size = cfg.INPUT.IMAGE_SIZE
    min_scale = cfg.INPUT.MIN_SCALE
    max_scale = cfg.INPUT.MAX_SCALE

    augmentation = []

    if cfg.INPUT.RANDOM_FLIP != "none":
        augmentation.append(
            T.RandomFlip(
                horizontal=cfg.INPUT.RANDOM_FLIP == "horizontal",
                vertical=cfg.INPUT.RANDOM_FLIP == "vertical",
            )
        )

    augmentation.extend([
        T.ResizeScale(
            min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size
        ),
        T.FixedSizeCrop(crop_size=(image_size, image_size)),
    ])

    return augmentation


# This is specifically designed for the COCO dataset.
class COCOInteractivePanopticNewBaselineDatasetMapper:
    """
    Callable that maps one Detectron2-format COCO panoptic record into model inputs.

    For each record it:
    1. reads the image from "file_name";
    2. applies ``tfm_gens`` (resize-scale + fixed-size crop, optional flip) to the image;
    3. applies the same transforms to the panoptic PNG in "pan_seg_file_name" and
       converts every non-crowd segment in "segments_info" into a binary mask;
    4. stores the image tensor and an ``Instances`` (gt_classes / gt_masks / gt_boxes).
    """

    @configurable
    def __init__(
        self,
        is_train=True,
        *,
        tfm_gens,
        image_format,
    ):
        """
        NOTE: this interface is experimental.
        Args:
            is_train: for training or inference (stored on ``self.is_train``; not read by ``__call__``)
            tfm_gens: list of detectron2 augmentations applied to image and panoptic PNG
            image_format: an image format supported by :func:`detection_utils.read_image`.
        """
        self.tfm_gens = tfm_gens
        # NOTE(review): the log tag names COCOPanopticNewBaselineDatasetMapper, not this class.
        logging.getLogger(__name__).info(
            "[COCOPanopticNewBaselineDatasetMapper] Full TransformGens used in training: {}".format(
                str(self.tfm_gens)
            )
        )

        self.img_format = image_format
        self.is_train = is_train

    @classmethod
    def from_config(cls, cfg, is_train=True):
        """Build ``__init__`` kwargs from an attribute-style config (``cfg.INPUT.*``)."""
        # Build augmentation
        tfm_gens = build_transform_gen(cfg, is_train)

        ret = {
            "is_train": is_train,
            "tfm_gens": tfm_gens,
            "image_format": cfg.INPUT.FORMAT,
        }
        return ret

    def __call__(self, dataset_dict):
        """
        Args:
            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
                Reads "file_name" and, when present, "pan_seg_file_name" / "segments_info".

        Returns:
            dict: a deep copy of ``dataset_dict`` with "image_ori" (augmented HWC array),
            "image" (CHW tensor) and, when a panoptic PNG is given, "instances" added.
            "pan_seg_file_name" is popped from the copy.
        """
        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
        image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
        utils.check_image_size(dataset_dict, image)

        image, transforms = T.apply_transform_gens(self.tfm_gens, image)
        image_shape = image.shape[:2]  # h, w
        dataset_dict["image_ori"]=image
        # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
        # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
        # Therefore it's important to use torch.Tensor.
        dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))

        # if not self.is_train:
        #     # USER: Modify this if you want to keep them for some reason.
        #     dataset_dict.pop("annotations", None)
        #     return dataset_dict

        if "pan_seg_file_name" in dataset_dict:
            pan_seg_gt = utils.read_image(dataset_dict.pop("pan_seg_file_name"), "RGB")
            segments_info = dataset_dict["segments_info"]

            # apply the same transformation to panoptic segmentation
            pan_seg_gt = transforms.apply_segmentation(pan_seg_gt)

            from panopticapi.utils import rgb2id

            # Collapse the RGB-encoded PNG into a single segment-id map.
            pan_seg_gt = rgb2id(pan_seg_gt)

            instances = Instances(image_shape)
            classes = []
            masks = []
            for segment_info in segments_info:
                class_id = segment_info["category_id"]
                # Crowd segments are dropped entirely.
                if not segment_info["iscrowd"]:
                    classes.append(class_id)
                    masks.append(pan_seg_gt == segment_info["id"])

            classes = np.array(classes)
            instances.gt_classes = torch.tensor(classes, dtype=torch.int64)
            if len(masks) == 0:
                # Some image does not have annotation (all ignored)
                instances.gt_masks = torch.zeros((0, pan_seg_gt.shape[-2], pan_seg_gt.shape[-1]))
                instances.gt_boxes = Boxes(torch.zeros((0, 4)))
            else:
                masks = BitMasks(
                    torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks])
                )
                # Store the raw mask tensor (not BitMasks); boxes are the tight mask extents.
                instances.gt_masks = masks.tensor
                instances.gt_boxes = masks.get_bounding_boxes()

            # NOTE(review): filter_empty_instances_by_box is defined outside this view; presumably
            # it matches the same-named helper in coco_panoptic_interactive_dataset_mapper.py
            # (drops instances whose gt_boxes are empty) — confirm.
            dataset_dict["instances"] = filter_empty_instances_by_box(instances)

        return dataset_dict
================================================
FILE: datasets_os/dataset_mappers/coco_panoptic_interactive_dataset_mapper.py
================================================
# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/d2/detr/dataset_mapper.py import copy import logging import random import numpy as np import torch import PIL.Image as Image from detectron2.data import detection_utils as utils from detectron2.data import transforms as T from detectron2.data.transforms import TransformGen from detectron2.structures import BitMasks, Boxes, Instances, BoxMode from detectron2.structures.boxes import pairwise_iou from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES from detectron2.data import MetadataCatalog from pycocotools import mask as coco_mask from utils.prompt_engineering import prompt_engineering, get_prompt_templates from llava.model.openseed.utils import configurable # from ..shapes.sampler import build_shape_sampler __all__ = ["COCOPanopticInteractiveDatasetMapper"] def filter_empty_instances_by_box( instances, by_box=True, by_mask=False, box_threshold=1e-5, return_mask=False ): assert by_box or by_mask r = [] if by_box: r.append(instances.gt_boxes.nonempty(threshold=box_threshold)) if instances.has("gt_masks") and by_mask: r.append(instances.gt_masks.nonempty()) # TODO: can also filter visible keypoints if not r: return instances m = r[0] for x in r[1:]: m = m & x if return_mask: return instances[m], m return instances[m] def build_transform_gen(cfg, is_train): """ Create a list of default :class:`Augmentation` from config. Now it includes resizing and flipping. 
Returns: list[Augmentation] """ # assert is_train, "Only support training augmentation" cfg_input = cfg['INPUT'] image_size = cfg_input['IMAGE_SIZE'] min_scale = cfg_input['MIN_SCALE'] max_scale = cfg_input['MAX_SCALE'] augmentation = [] # if cfg_input['RANDOM_FLIP'] != "none": # augmentation.append( # T.RandomFlip( # horizontal=cfg_input['RANDOM_FLIP'] == "horizontal", # vertical=cfg_input['RANDOM_FLIP'] == "vertical", # ) # ) augmentation.extend([ T.ResizeScale( min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size ), T.FixedSizeCrop(crop_size=(image_size, image_size)), ]) return augmentation def convert_coco_poly_to_mask(segmentations, height, width): masks = [] for polygons in segmentations: rles = coco_mask.frPyObjects(polygons, height, width) mask = coco_mask.decode(rles) if len(mask.shape) < 3: mask = mask[..., None] mask = torch.as_tensor(mask, dtype=torch.uint8) mask = mask.any(dim=2) masks.append(mask) if masks: masks = torch.stack(masks, dim=0) else: masks = torch.zeros((0, height, width), dtype=torch.uint8) return masks # This is specifically designed for the COCO dataset. class COCOPanopticInteractiveDatasetMapper: """ A callable which takes a dataset dict in Detectron2 Dataset format, and map it into a format used by MaskFormer. This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation. The callable currently does the following: 1. Read the image from "file_name" 2. Applies geometric transforms to the image and annotation 3. Find and applies suitable cropping to the image and annotation 4. Prepare image and annotation to Tensors """ @configurable def __init__( self, is_train=True, *, tfm_gens, image_format, caption_thres, # lvis, # lvis_thres, max_grounding_num, tokenizer, data_args, preprocess, # shape_sampler, ): """ NOTE: this interface is experimental. 
Args: is_train: for training or inference augmentations: a list of augmentations or deterministic transforms to apply crop_gen: crop augmentation tfm_gens: data augmentation image_format: an image format supported by :func:`detection_utils.read_image`. """ self.tfm_gens = tfm_gens logging.getLogger(__name__).info( "[COCOPanopticNewBaselineDatasetMapper] Full TransformGens used in training: {}".format( str(self.tfm_gens) ) ) self.img_format = image_format self.is_train = is_train self.caption_thres = caption_thres self.grounding = True # self.lvis = lvis # self.lvis_thres = lvis_thres self.max_grounding_num = max_grounding_num self.caption_similarity = torch.load(MetadataCatalog.get('logistic').get('caption_similarity_pth')) self.tokenizer = tokenizer self.processor = data_args.image_processor self.data_args = data_args self.preprocess = preprocess # self.shape_sampler = shape_sampler @classmethod def from_config(cls, cfg, is_train=True,tokenizer=None,data_args=None,preprocess=None): # Build augmentation tfm_gens = build_transform_gen(cfg, is_train) # shape_sampler = build_shape_sampler(cfg) ret = { "is_train": is_train, "tfm_gens": tfm_gens, "image_format": cfg['INPUT']['FORMAT'], "caption_thres": cfg['MODEL']['DECODER']['CAPTION']['SIM_THRES'], "max_grounding_num": cfg['MODEL']['DECODER']['GROUNDING']['MAX_LEN'], "tokenizer": tokenizer, "data_args": data_args, "preprocess": preprocess, # "shape_sampler": shape_sampler, } return ret def __call__(self, dataset_dict): """ Args: dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format. 
Returns: dict: a format that builtin models in detectron2 accept """ dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below image = utils.read_image(dataset_dict["file_name"], format=self.img_format) utils.check_image_size(dataset_dict, image) #########llava image processing if self.data_args.image_aspect_ratio == 'pad': def expand2square(pil_img, background_color): width, height = pil_img.size if width == height: return pil_img elif width > height: result = Image.new(pil_img.mode, (width, width), background_color) result.paste(pil_img, (0, (width - height) // 2)) return result else: result = Image.new(pil_img.mode, (height, height), background_color) result.paste(pil_img, ((height - width) // 2, 0)) return result image_clip = expand2square(image, tuple(int(x * 255) for x in self.processor.image_mean)) image_clip = self.processor.preprocess(image_clip, return_tensors='pt')['pixel_values'][0] else: image_clip = self.processor.preprocess(image, return_tensors='pt')['pixel_values'][0] dataset_dict["image_clip"] = image_clip ################## image, transforms = T.apply_transform_gens(self.tfm_gens, image) image_shape = image.shape[:2] # h, w dataset_dict["image_ori"]=image # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory, # but not efficient on large generic data structures due to the use of pickle & mp.Queue. # Therefore it's important to use torch.Tensor. 
dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1))) if "pan_seg_file_name" in dataset_dict: pan_seg_gt = utils.read_image(dataset_dict.pop("pan_seg_file_name"), "RGB") segments_info = dataset_dict["segments_info"] # apply the same transformation to panoptic segmentation pan_seg_gt = transforms.apply_segmentation(pan_seg_gt) from panopticapi.utils import rgb2id pan_seg_gt = rgb2id(pan_seg_gt) instances = Instances(image_shape) classes = [] masks = [] for segment_info in segments_info: class_id = segment_info["category_id"] if not segment_info["iscrowd"]: classes.append(class_id) masks.append(pan_seg_gt == segment_info["id"]) # is_things = [COCO_CATEGORIES[idx]['isthing'] for idx in classes] classes = np.array(classes) # is_things = np.array(is_things) instances.gt_classes = torch.tensor(classes, dtype=torch.int64) # instances.is_things = torch.tensor(is_things, dtype=torch.int64) if len(masks) == 0: # Some image does not have annotation (all ignored) masks = BitMasks(torch.zeros((0, pan_seg_gt.shape[-2], pan_seg_gt.shape[-1]))) instances.gt_masks = masks instances.gt_boxes = Boxes(torch.zeros((0, 4))) else: masks = BitMasks( torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks]) ) instances.gt_masks = masks instances.gt_boxes = masks.get_bounding_boxes() if self.grounding: grounding_anno = dataset_dict['grounding_info'] if self.is_train: grounding_len = random.randint(1, self.max_grounding_num - 1) else: grounding_len = 1 if len(grounding_anno) > 0: masks_grd = [] texts_grd = [] mode = 'text' random.shuffle(grounding_anno) for ann in grounding_anno: rle = coco_mask.frPyObjects( ann['segmentation'], dataset_dict['height'], dataset_dict['width']) m = coco_mask.decode(rle) # sometimes there are multiple binary map (corresponding to multiple segs) m = np.sum(m, axis=2)>0 m = m.astype(np.uint8) # convert to np.uint8 m = transforms.apply_segmentation(m[:, :, None])[:, :, 0]==1 masks_grd += [m] # random select a 
sentence of a single annotation. rand_index = random.randint(0, len(ann['sentences']) - 1) texts_grd += [ann['sentences'][rand_index]['raw'].lower()] max_len = min(grounding_len, len(texts_grd)) indices = np.random.permutation(max_len) texts_grd = list(np.array(texts_grd)[indices]) masks_grd = torch.tensor(np.stack(masks_grd)[indices]) hash_grd = np.array([hash(txt) for txt in texts_grd]) gt_classes = list(range(len(texts_grd))) gt_classes = [[lb] for lb in gt_classes] label_set=texts_grd else: assert self.is_train masks_grd = instances.gt_masks.tensor mode = 'class' assert len(masks_grd) > 0 texts_grd = np.array([COCO_CATEGORIES[idx]['name'] for idx in classes]) hash_grd = np.array([hash(txt) for txt in texts_grd]) unique_hash_grd = np.unique(hash_grd) np.random.shuffle(unique_hash_grd) max_len = min(grounding_len,len(unique_hash_grd)) indices = np.random.permutation(max_len) selected_unique_hash_grd = unique_hash_grd[indices] selected_mask = np.in1d(hash_grd, selected_unique_hash_grd) texts_grd = texts_grd[selected_mask] hash_grd = hash_grd[selected_mask] masks_grd = masks_grd[selected_mask] texts_grd = [ text.replace('-other', '').replace('-merged', '').replace('-stuff', '') for text in texts_grd] label_set=list(set(texts_grd)) gt_classes=[[label_set.index(lb)] for lb in texts_grd] instances_gd = Instances(image_shape) instances_gd.gt_masks = BitMasks(masks_grd) instances_gd.gt_boxes = BitMasks(masks_grd).get_bounding_boxes() instances_gd.gt_masks=instances_gd.gt_masks.tensor instances_gd.gt_classes=gt_classes dataset_dict["instances"] = instances_gd conversations=[] for i in range(len(label_set)): if i==0: question={'from': 'human', 'value': f"\n Please detect the object according to the text {label_set[i]} (referring)."} else: question={'from': 'human', 'value': f"Please detect the object according to the text {label_set[i]} (referring)."} answer={'from': 'gpt', 'value': ' .'} conversations.append(question) conversations.append(answer) 
dataset_dict['conversation'] = [conversations] data_dict_conversation = self.preprocess( dataset_dict['conversation'], self.tokenizer, has_image=True) data_dict_conversation = dict(input_ids=data_dict_conversation["input_ids"][0], labels=data_dict_conversation["labels"][0]) dataset_dict.update(data_dict_conversation) dataset_dict['tokenizer']=self.tokenizer return dataset_dict ================================================ FILE: datasets_os/dataset_mappers/coco_panoptic_new_baseline_dataset_mapper.py ================================================ # ------------------------------------------------------------------------ # Copyright (c) 2022 IDEA. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------ # Modified from Mask2Former https://github.com/facebookresearch/Mask2Former by Feng Li. import copy import logging import numpy as np import torch from detectron2.config import configurable from detectron2.data import detection_utils as utils from detectron2.data import transforms as T from detectron2.data.transforms import TransformGen from detectron2.structures import BitMasks, Boxes, Instances __all__ = ["COCOPanopticNewBaselineDatasetMapper"] def build_transform_gen(cfg, is_train): """ Create a list of default :class:`Augmentation` from config. Now it includes resizing and flipping. 
Returns: list[Augmentation] """ # assert is_train, "Only support training augmentation" image_size = cfg.INPUT.IMAGE_SIZE min_scale = cfg.INPUT.MIN_SCALE max_scale = cfg.INPUT.MAX_SCALE augmentation = [] # if cfg.INPUT.RANDOM_FLIP != "none": # augmentation.append( # T.RandomFlip( # horizontal=cfg.INPUT.RANDOM_FLIP == "horizontal", # vertical=cfg.INPUT.RANDOM_FLIP == "vertical", # ) # ) augmentation.extend([ T.ResizeScale( min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size ), T.FixedSizeCrop(crop_size=(image_size, image_size)), ]) return augmentation # This is specifically designed for the COCO dataset. class COCOPanopticNewBaselineDatasetMapper: """ A callable which takes a dataset dict in Detectron2 Dataset format, and map it into a format used by MaskFormer. This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation. The callable currently does the following: 1. Read the image from "file_name" 2. Applies geometric transforms to the image and annotation 3. Find and applies suitable cropping to the image and annotation 4. Prepare image and annotation to Tensors """ @configurable def __init__( self, is_train=True, *, tfm_gens, image_format, ): """ NOTE: this interface is experimental. Args: is_train: for training or inference augmentations: a list of augmentations or deterministic transforms to apply crop_gen: crop augmentation tfm_gens: data augmentation image_format: an image format supported by :func:`detection_utils.read_image`. 
""" self.tfm_gens = tfm_gens logging.getLogger(__name__).info( "[COCOPanopticNewBaselineDatasetMapper] Full TransformGens used in training: {}".format( str(self.tfm_gens) ) ) self.img_format = image_format self.is_train = is_train @classmethod def from_config(cls, cfg, is_train=True): # Build augmentation tfm_gens = build_transform_gen(cfg, is_train) ret = { "is_train": is_train, "tfm_gens": tfm_gens, "image_format": cfg.INPUT.FORMAT, } return ret def __call__(self, dataset_dict): """ Args: dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format. Returns: dict: a format that builtin models in detectron2 accept """ dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below image = utils.read_image(dataset_dict["file_name"], format=self.img_format) utils.check_image_size(dataset_dict, image) image, transforms = T.apply_transform_gens(self.tfm_gens, image) image_shape = image.shape[:2] # h, w # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory, # but not efficient on large generic data structures due to the use of pickle & mp.Queue. # Therefore it's important to use torch.Tensor. dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1))) # if not self.is_train: # # USER: Modify this if you want to keep them for some reason. 
# dataset_dict.pop("annotations", None) # return dataset_dict if "pan_seg_file_name" in dataset_dict: pan_seg_gt = utils.read_image(dataset_dict.pop("pan_seg_file_name"), "RGB") segments_info = dataset_dict["segments_info"] # apply the same transformation to panoptic segmentation pan_seg_gt = transforms.apply_segmentation(pan_seg_gt) from panopticapi.utils import rgb2id pan_seg_gt = rgb2id(pan_seg_gt) instances = Instances(image_shape) classes = [] masks = [] for segment_info in segments_info: class_id = segment_info["category_id"] if not segment_info["iscrowd"]: classes.append(class_id) masks.append(pan_seg_gt == segment_info["id"]) classes = np.array(classes) instances.gt_classes = torch.tensor(classes, dtype=torch.int64) if len(masks) == 0: # Some image does not have annotation (all ignored) instances.gt_masks = torch.zeros((0, pan_seg_gt.shape[-2], pan_seg_gt.shape[-1])) instances.gt_boxes = Boxes(torch.zeros((0, 4))) else: masks = BitMasks( torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks]) ) instances.gt_masks = masks.tensor instances.gt_boxes = masks.get_bounding_boxes() dataset_dict["instances"] = instances return dataset_dict ================================================ FILE: datasets_os/dataset_mappers/flickr_instance_new_baseline_dataset_mapper.py ================================================ # Copyright (c) Facebook, Inc. and its affiliates. 
# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/d2/detr/dataset_mapper.py import copy import logging import numpy as np import torch import PIL.Image as Image from detectron2.data import detection_utils as utils from detectron2.data import transforms as T from detectron2.data.transforms import TransformGen from detectron2.structures import BitMasks, Instances from pycocotools import mask as coco_mask from llava.model.openseed.utils import configurable from detectron2.structures import Boxes, BoxMode, PolygonMasks, RotatedBoxes # from llava.train.train_hao_seg_flickr import ,preprocess __all__ = ["COCOInstanceNewBaselineDatasetMapper"] def convert_coco_poly_to_mask(segmentations, height, width): masks = [] for polygons in segmentations: rles = coco_mask.frPyObjects(polygons, height, width) mask = coco_mask.decode(rles) if len(mask.shape) < 3: mask = mask[..., None] mask = torch.as_tensor(mask, dtype=torch.uint8) mask = mask.any(dim=2) masks.append(mask) if masks: masks = torch.stack(masks, dim=0) else: masks = torch.zeros((0, height, width), dtype=torch.uint8) return masks def build_transform_gen(cfg, is_train): """ Create a list of default :class:`Augmentation` from config. Now it includes resizing and flipping. 
Returns: list[Augmentation] """ if is_train: cfg_input = cfg['INPUT'] image_size = cfg_input['IMAGE_SIZE'] min_scale = cfg_input['MIN_SCALE'] max_scale = cfg_input['MAX_SCALE'] augmentation = [] # if cfg_input['RANDOM_FLIP'] != "none": # augmentation.append( # T.RandomFlip( # horizontal=cfg_input['RANDOM_FLIP'] == "horizontal", # vertical=cfg_input['RANDOM_FLIP'] == "vertical", # ) # ) augmentation.extend([ T.ResizeScale( min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size ), T.FixedSizeCrop(crop_size=(image_size, image_size)), ]) else: cfg_input = cfg['INPUT'] image_size = cfg_input['IMAGE_SIZE'] min_scale = cfg_input['MIN_SCALE'] max_scale = cfg_input['MAX_SCALE'] augmentation = [] # if cfg_input['RANDOM_FLIP'] != "none": # augmentation.append( # T.RandomFlip( # horizontal=cfg_input['RANDOM_FLIP'] == "horizontal", # vertical=cfg_input['RANDOM_FLIP'] == "vertical", # ) # ) augmentation.extend([ T.ResizeScale( min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size ), T.FixedSizeCrop(crop_size=(image_size, image_size)), ]) return augmentation # This is specifically designed for the COCO dataset. class COCOInstanceNewBaselineDatasetMapper: """ A callable which takes a dataset dict in Detectron2 Dataset format, and map it into a format used by MaskFormer. This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation. The callable currently does the following: 1. Read the image from "file_name" 2. Applies geometric transforms to the image and annotation 3. Find and applies suitable cropping to the image and annotation 4. Prepare image and annotation to Tensors """ @configurable def __init__( self, is_train=True, *, tfm_gens, image_format, tokenizer, data_args, preprocess, gd_mode="inter", ): """ NOTE: this interface is experimental. 
Args: is_train: for training or inference augmentations: a list of augmentations or deterministic transforms to apply tfm_gens: data augmentation image_format: an image format supported by :func:`detection_utils.read_image`. """ self.tfm_gens = tfm_gens logging.getLogger(__name__).info( "[COCOInstanceNewBaselineDatasetMapper] Full TransformGens used in training: {}".format(str(self.tfm_gens)) ) self.img_format = image_format self.is_train = is_train self.tokenizer = tokenizer self.processor = data_args.image_processor self.data_args = data_args self.preprocess = preprocess self.gd_mode= gd_mode @classmethod def from_config(cls, cfg, is_train=True,tokenizer=None,data_args=None,preprocess=None): # Build augmentation tfm_gens = build_transform_gen(cfg, is_train) ret = { "is_train": is_train, "tfm_gens": tfm_gens, "image_format": cfg['INPUT']['FORMAT'], "tokenizer": tokenizer, "data_args": data_args, "preprocess": preprocess, "gd_mode": cfg.flickr.get("gd_mode", "inter"), } return ret def __call__(self, dataset_dict): """ Args: dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format. 
Returns: dict: a format that builtin models in detectron2 accept """ dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below image = utils.read_image(dataset_dict["file_name"], format=self.img_format) utils.check_image_size(dataset_dict, image) #########llava image processing if self.data_args.image_aspect_ratio == 'pad': def expand2square(pil_img, background_color): width, height = pil_img.size if width == height: return pil_img elif width > height: result = Image.new(pil_img.mode, (width, width), background_color) result.paste(pil_img, (0, (width - height) // 2)) return result else: result = Image.new(pil_img.mode, (height, height), background_color) result.paste(pil_img, ((height - width) // 2, 0)) return result image_clip = expand2square(image, tuple(int(x * 255) for x in self.processor.image_mean)) image_clip = self.processor.preprocess(image_clip, return_tensors='pt')['pixel_values'][0] else: image_clip = self.processor.preprocess(image, return_tensors='pt')['pixel_values'][0] dataset_dict["image_clip"] = image_clip ################## # TODO: get padding mask # by feeding a "segmentation mask" to the same transforms padding_mask = np.ones(image.shape[:2]) image, transforms = T.apply_transform_gens(self.tfm_gens, image) dataset_dict["image_ori"]=image # the crop transformation has default padding value 0 for segmentation padding_mask = transforms.apply_segmentation(padding_mask) padding_mask = ~ padding_mask.astype(bool) image_shape = image.shape[:2] # h, w # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory, # but not efficient on large generic data structures due to the use of pickle & mp.Queue. # Therefore it's important to use torch.Tensor. dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1))) dataset_dict["padding_mask"] = torch.as_tensor(np.ascontiguousarray(padding_mask)) # if not self.is_train: # # USER: Modify this if you want to keep them for some reason. 
# dataset_dict.pop("annotations", None) # return dataset_dict if "grounding_info" in dataset_dict: for obj in dataset_dict['grounding_info']: obj["bbox_mode"] = BoxMode.XYWH_ABS obj['tokens']=dataset_dict['caption'][obj['tokens_positive'][0][0]:obj['tokens_positive'][0][1]] # USER: Implement additional transformations if you have other types of data annos = [ utils.transform_instance_annotations(obj, transforms, image_shape) for obj in dataset_dict["grounding_info"] ] # NOTE: does not support BitMask due to augmentation # Current BitMask cannot handle empty objects assert len(annos)>0 assert "segmentation" in annos[0] instances = utils.annotations_to_instances(annos, image_shape,mask_format="bitmask") # After transforms such as cropping are applied, the bounding box may no longer # tightly bound the object. As an example, imagine a triangle object # [(0,0), (2,0), (0,2)] cropped by a box [(1,0),(2,2)] (XYXY format). The tight # bounding box of the cropped triangle should be [(1,0),(2,1)], which is not equal to # the intersection of original bounding box and the cropping box. 
# instances.gt_boxes = instances.gt_masks.get_bounding_boxes() # Need to filter empty instances first (due to augmentation) # instances = utils.filter_empty_instances(instances) # Generate masks from polygon h, w = instances.image_size # image_size_xyxy = torch.as_tensor([w, h, w, h], dtype=torch.float) if hasattr(instances, 'gt_masks'): gt_masks = instances.gt_masks # gt_masks = convert_coco_poly_to_mask(gt_masks.polygons, h, w) instances.gt_masks = gt_masks.tensor span_set = dict() end_dict = dict() gt_classes= [] for i, info in enumerate(dataset_dict['grounding_info']): gt_classes.append([]) # if len(info['tokens_positive'])>1: # print("multi class") for j in range(len(info['tokens_positive'])): if info['tokens_positive'][j][0] in span_set: span_set[info['tokens_positive'][j][0]].append(i) else: span_set[info['tokens_positive'][j][0]] = [i] if info['tokens_positive'][j][0] in end_dict: assert end_dict[info['tokens_positive'][j][0]] == info['tokens_positive'][j][1] else: end_dict[info['tokens_positive'][j][0]] = info['tokens_positive'][j][1] gt_classes[-1].append(info['tokens_positive'][j][0]) end_dict = sorted(end_dict.items()) start2id = dict() for i, (s, e) in enumerate(end_dict): start2id[s] = i gt_classes= [[start2id[s] for s in gt_class] for gt_class in gt_classes] instances.gt_classes = gt_classes dataset_dict["instances"] = instances # span_list = sorted(span_set.items()) # for k, v in span_set: # for i in range(len(v)): # v[i] = positive_new_ids[v[i]] cap_pieces = [] last_e = 0 for s, e in end_dict: cap_pieces.append(dataset_dict['caption'][last_e:s]) cap_pieces.append(dataset_dict['caption'][s:e]) last_e = e cap_pieces.append(dataset_dict['caption'][last_e:]) new_cap = [] if 'end' in self.gd_mode: k=1 for i, piece in enumerate(cap_pieces): if i % 2 == 1: if self.gd_mode == 'end': piece = '' + piece + '' else: assert self.gd_mode == 'end_num' piece = f' {k} ' + piece + '' k+=1 new_cap.append(piece) new_cap = "".join(new_cap) tail = [f'{i + 1}: ' for i in 
range(new_cap.count(""))] tail = '; '.join(tail) new_cap += f' {tail}.' else: for i, piece in enumerate(cap_pieces): if i % 2 == 1: piece = '' + piece + '' new_cap.append(piece) new_cap = "".join(new_cap) # gt_ids = [] # for s, e in end_dict: # if len(span_set[s]) > 1: # return dataset_dict # gt_ids.append(span_set[s][0] + 1) # ground_annos = dict() # ground_annos['gt_ids'] = gt_ids # ground_annos['gt_anno_ids'] = [dataset_dict['grounding_info'][gt_id_ - 1]['id'] for gt_id_ in gt_ids] # ground_annos['caption'] = new_cap question={'from': 'human', 'value': "\nPresent a compact description of the photo's key features.\n(with grounding)"} answer={'from': 'gpt', 'value': new_cap} dataset_dict['conversation'] = [[question, answer]] # sources = preprocess_multimodal( # copy.deepcopy(dataset_dict['conversation']), # self.data_args) data_dict_conversation = self.preprocess( dataset_dict['conversation'], self.tokenizer, has_image=True) data_dict_conversation = dict(input_ids=data_dict_conversation["input_ids"][0], labels=data_dict_conversation["labels"][0]) dataset_dict.update(data_dict_conversation) dataset_dict['tokenizer']=self.tokenizer return dataset_dict ================================================ FILE: datasets_os/dataset_mappers/flickr_instance_new_baseline_dataset_mapper_.py ================================================ # Copyright (c) Facebook, Inc. and its affiliates. 
# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/d2/detr/dataset_mapper.py import copy import logging import numpy as np import torch import PIL.Image as Image from detectron2.data import detection_utils as utils from detectron2.data import transforms as T from detectron2.data.transforms import TransformGen from detectron2.structures import BitMasks, Instances from pycocotools import mask as coco_mask from llava.model.openseed.utils import configurable from detectron2.structures import Boxes, BoxMode, PolygonMasks, RotatedBoxes from llava import conversation as conversation_lib from llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN # from llava.train.train_hao_seg_flickr import ,preprocess __all__ = ["COCOInstanceNewBaselineDatasetMapper"] def convert_coco_poly_to_mask(segmentations, height, width): masks = [] for polygons in segmentations: rles = coco_mask.frPyObjects(polygons, height, width) mask = coco_mask.decode(rles) if len(mask.shape) < 3: mask = mask[..., None] mask = torch.as_tensor(mask, dtype=torch.uint8) mask = mask.any(dim=2) masks.append(mask) if masks: masks = torch.stack(masks, dim=0) else: masks = torch.zeros((0, height, width), dtype=torch.uint8) return masks def preprocess_multimodal( sources, data_args ): is_multimodal = data_args.is_multimodal if not is_multimodal: return sources for source in sources: for sentence in source: if DEFAULT_IMAGE_TOKEN in sentence['value']: sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip() sentence['value'] = DEFAULT_IMAGE_TOKEN + '\n' + sentence['value'] sentence['value'] = sentence['value'].strip() if "mmtag" in conversation_lib.default_conversation.version: sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '' + DEFAULT_IMAGE_TOKEN + '') replace_token = DEFAULT_IMAGE_TOKEN if data_args.mm_use_im_start_end: replace_token = DEFAULT_IM_START_TOKEN + replace_token + 
DEFAULT_IM_END_TOKEN sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token) return sources def build_transform_gen(cfg, is_train): """ Create a list of default :class:`Augmentation` from config. Now it includes resizing and flipping. Returns: list[Augmentation] """ if is_train: cfg_input = cfg['INPUT'] image_size = cfg_input['IMAGE_SIZE'] min_scale = cfg_input['MIN_SCALE'] max_scale = cfg_input['MAX_SCALE'] augmentation = [] # if cfg_input['RANDOM_FLIP'] != "none": # augmentation.append( # T.RandomFlip( # horizontal=cfg_input['RANDOM_FLIP'] == "horizontal", # vertical=cfg_input['RANDOM_FLIP'] == "vertical", # ) # ) augmentation.extend([ T.ResizeScale( min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size ), T.FixedSizeCrop(crop_size=(image_size, image_size)), ]) else: cfg_input = cfg['INPUT'] image_size = cfg_input['IMAGE_SIZE'] min_scale = cfg_input['MIN_SCALE'] max_scale = cfg_input['MAX_SCALE'] augmentation = [] # if cfg_input['RANDOM_FLIP'] != "none": # augmentation.append( # T.RandomFlip( # horizontal=cfg_input['RANDOM_FLIP'] == "horizontal", # vertical=cfg_input['RANDOM_FLIP'] == "vertical", # ) # ) augmentation.extend([ T.ResizeScale( min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size ), T.FixedSizeCrop(crop_size=(image_size, image_size)), ]) return augmentation # This is specifically designed for the COCO dataset. class COCOInstanceNewBaselineDatasetMapper: """ A callable which takes a dataset dict in Detectron2 Dataset format, and map it into a format used by MaskFormer. This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation. The callable currently does the following: 1. Read the image from "file_name" 2. Applies geometric transforms to the image and annotation 3. Find and applies suitable cropping to the image and annotation 4. 
Prepare image and annotation to Tensors """ @configurable def __init__( self, is_train=True, *, tfm_gens, image_format, tokenizer, data_args, preprocess, ): """ NOTE: this interface is experimental. Args: is_train: for training or inference augmentations: a list of augmentations or deterministic transforms to apply tfm_gens: data augmentation image_format: an image format supported by :func:`detection_utils.read_image`. """ self.tfm_gens = tfm_gens logging.getLogger(__name__).info( "[COCOInstanceNewBaselineDatasetMapper] Full TransformGens used in training: {}".format(str(self.tfm_gens)) ) self.img_format = image_format self.is_train = is_train self.tokenizer = tokenizer self.processor = data_args.image_processor self.data_args = data_args self.preprocess = preprocess @classmethod def from_config(cls, cfg, is_train=True,tokenizer=None,data_args=None,preprocess=None): # Build augmentation tfm_gens = build_transform_gen(cfg, is_train) ret = { "is_train": is_train, "tfm_gens": tfm_gens, "image_format": cfg['INPUT']['FORMAT'], "tokenizer": tokenizer, "data_args": data_args, "preprocess": preprocess, } return ret def __call__(self, dataset_dict): """ Args: dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format. 
Returns: dict: a format that builtin models in detectron2 accept """ dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below image = utils.read_image(dataset_dict["file_name"], format=self.img_format) utils.check_image_size(dataset_dict, image) #########llava image processing if self.data_args.image_aspect_ratio == 'pad': def expand2square(pil_img, background_color): width, height = pil_img.size if width == height: return pil_img elif width > height: result = Image.new(pil_img.mode, (width, width), background_color) result.paste(pil_img, (0, (width - height) // 2)) return result else: result = Image.new(pil_img.mode, (height, height), background_color) result.paste(pil_img, ((height - width) // 2, 0)) return result image_clip = expand2square(image, tuple(int(x * 255) for x in self.processor.image_mean)) image_clip = self.processor.preprocess(image_clip, return_tensors='pt')['pixel_values'][0] else: image_clip = self.processor.preprocess(image, return_tensors='pt')['pixel_values'][0] dataset_dict["image_clip"] = image_clip ################## # TODO: get padding mask # by feeding a "segmentation mask" to the same transforms padding_mask = np.ones(image.shape[:2]) image, transforms = T.apply_transform_gens(self.tfm_gens, image) dataset_dict["image_ori"]=image # the crop transformation has default padding value 0 for segmentation padding_mask = transforms.apply_segmentation(padding_mask) padding_mask = ~ padding_mask.astype(bool) image_shape = image.shape[:2] # h, w dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1))) dataset_dict["padding_mask"] = torch.as_tensor(np.ascontiguousarray(padding_mask)) if "grounding_info" in dataset_dict: anno_id2id=dict() for id,obj in enumerate(dataset_dict['grounding_info']): obj["bbox_mode"] = BoxMode.XYWH_ABS anno_id2id[obj['id']]=id id2class=[[] for _ in range(len(dataset_dict['grounding_info']))] annos = [ utils.transform_instance_annotations(obj, transforms, image_shape) 
for obj in dataset_dict["grounding_info"] ] # NOTE: does not support BitMask due to augmentation # Current BitMask cannot handle empty objects assert len(annos)>0 assert "segmentation" in annos[0] instances = utils.annotations_to_instances(annos, image_shape,mask_format="bitmask") h, w = instances.image_size # image_size_xyxy = torch.as_tensor([w, h, w, h], dtype=torch.float) if hasattr(instances, 'gt_masks'): gt_masks = instances.gt_masks # gt_masks = convert_coco_poly_to_mask(gt_masks.polygons, h, w) instances.gt_masks = gt_masks.tensor num_conversations = len(dataset_dict['conversations']) rd = np.random.choice(num_conversations) selected_conversation, grounding_list=dataset_dict['conversations'][rd] if grounding_list is None: dataset_dict['grounding']=False else: non_none=[1 for g in grounding_list if g is not None] if len(non_none)==0: dataset_dict['grounding']=False else: dataset_dict['grounding']=True if dataset_dict['grounding']: num_segs = sum([conv['value'].count('') for conv in selected_conversation]) assert num_segs == len(grounding_list) for grounding_id,grounding in enumerate(grounding_list): if grounding is not None: for annid in grounding: id2class[anno_id2id[annid]].append(grounding_id) instances.gt_classes=id2class dataset_dict["instances"] = instances dataset_dict['conversation'] = [selected_conversation] sources = preprocess_multimodal( copy.deepcopy(dataset_dict['conversation']), self.data_args) data_dict_conversation = self.preprocess( sources, self.tokenizer, has_image=True) data_dict_conversation = dict(input_ids=data_dict_conversation["input_ids"][0], labels=data_dict_conversation["labels"][0]) dataset_dict.update(data_dict_conversation) dataset_dict['tokenizer']=self.tokenizer return dataset_dict ================================================ FILE: datasets_os/dataset_mappers/flickr_instance_new_baseline_dataset_mapper_end.py ================================================ # Copyright (c) Facebook, Inc. and its affiliates. 
# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/d2/detr/dataset_mapper.py import copy import logging import numpy as np import torch import PIL.Image as Image from detectron2.data import detection_utils as utils from detectron2.data import transforms as T from detectron2.data.transforms import TransformGen from detectron2.structures import BitMasks, Instances from pycocotools import mask as coco_mask from llava.model.openseed.utils import configurable from detectron2.structures import Boxes, BoxMode, PolygonMasks, RotatedBoxes # from llava.train.train_hao_seg_flickr import ,preprocess __all__ = ["COCOInstanceNewBaselineDatasetMapper"] def convert_coco_poly_to_mask(segmentations, height, width): masks = [] for polygons in segmentations: rles = coco_mask.frPyObjects(polygons, height, width) mask = coco_mask.decode(rles) if len(mask.shape) < 3: mask = mask[..., None] mask = torch.as_tensor(mask, dtype=torch.uint8) mask = mask.any(dim=2) masks.append(mask) if masks: masks = torch.stack(masks, dim=0) else: masks = torch.zeros((0, height, width), dtype=torch.uint8) return masks def build_transform_gen(cfg, is_train): """ Create a list of default :class:`Augmentation` from config. Now it includes resizing and flipping. 
Returns: list[Augmentation] """ if is_train: cfg_input = cfg['INPUT'] image_size = cfg_input['IMAGE_SIZE'] min_scale = cfg_input['MIN_SCALE'] max_scale = cfg_input['MAX_SCALE'] augmentation = [] # if cfg_input['RANDOM_FLIP'] != "none": # augmentation.append( # T.RandomFlip( # horizontal=cfg_input['RANDOM_FLIP'] == "horizontal", # vertical=cfg_input['RANDOM_FLIP'] == "vertical", # ) # ) augmentation.extend([ T.ResizeScale( min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size ), T.FixedSizeCrop(crop_size=(image_size, image_size)), ]) else: cfg_input = cfg['INPUT'] image_size = cfg_input['IMAGE_SIZE'] min_scale = cfg_input['MIN_SCALE'] max_scale = cfg_input['MAX_SCALE'] augmentation = [] # if cfg_input['RANDOM_FLIP'] != "none": # augmentation.append( # T.RandomFlip( # horizontal=cfg_input['RANDOM_FLIP'] == "horizontal", # vertical=cfg_input['RANDOM_FLIP'] == "vertical", # ) # ) augmentation.extend([ T.ResizeScale( min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size ), T.FixedSizeCrop(crop_size=(image_size, image_size)), ]) return augmentation # This is specifically designed for the COCO dataset. class COCOInstanceNewBaselineDatasetMapper: """ A callable which takes a dataset dict in Detectron2 Dataset format, and map it into a format used by MaskFormer. This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation. The callable currently does the following: 1. Read the image from "file_name" 2. Applies geometric transforms to the image and annotation 3. Find and applies suitable cropping to the image and annotation 4. Prepare image and annotation to Tensors """ @configurable def __init__( self, is_train=True, *, tfm_gens, image_format, tokenizer, data_args, preprocess, ): """ NOTE: this interface is experimental. 
Args: is_train: for training or inference augmentations: a list of augmentations or deterministic transforms to apply tfm_gens: data augmentation image_format: an image format supported by :func:`detection_utils.read_image`. """ self.tfm_gens = tfm_gens logging.getLogger(__name__).info( "[COCOInstanceNewBaselineDatasetMapper] Full TransformGens used in training: {}".format(str(self.tfm_gens)) ) self.img_format = image_format self.is_train = is_train self.tokenizer = tokenizer self.processor = data_args.image_processor self.data_args = data_args self.preprocess = preprocess @classmethod def from_config(cls, cfg, is_train=True,tokenizer=None,data_args=None,preprocess=None): # Build augmentation tfm_gens = build_transform_gen(cfg, is_train) ret = { "is_train": is_train, "tfm_gens": tfm_gens, "image_format": cfg['INPUT']['FORMAT'], "tokenizer": tokenizer, "data_args": data_args, "preprocess": preprocess, } return ret def __call__(self, dataset_dict): """ Args: dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format. 
Returns: dict: a format that builtin models in detectron2 accept """ dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below image = utils.read_image(dataset_dict["file_name"], format=self.img_format) utils.check_image_size(dataset_dict, image) #########llava image processing if self.data_args.image_aspect_ratio == 'pad': def expand2square(pil_img, background_color): width, height = pil_img.size if width == height: return pil_img elif width > height: result = Image.new(pil_img.mode, (width, width), background_color) result.paste(pil_img, (0, (width - height) // 2)) return result else: result = Image.new(pil_img.mode, (height, height), background_color) result.paste(pil_img, ((height - width) // 2, 0)) return result image_clip = expand2square(image, tuple(int(x * 255) for x in self.processor.image_mean)) image_clip = self.processor.preprocess(image_clip, return_tensors='pt')['pixel_values'][0] else: image_clip = self.processor.preprocess(image, return_tensors='pt')['pixel_values'][0] dataset_dict["image_clip"] = image_clip ################## # TODO: get padding mask # by feeding a "segmentation mask" to the same transforms padding_mask = np.ones(image.shape[:2]) image, transforms = T.apply_transform_gens(self.tfm_gens, image) dataset_dict["image_ori"]=image # the crop transformation has default padding value 0 for segmentation padding_mask = transforms.apply_segmentation(padding_mask) padding_mask = ~ padding_mask.astype(bool) image_shape = image.shape[:2] # h, w # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory, # but not efficient on large generic data structures due to the use of pickle & mp.Queue. # Therefore it's important to use torch.Tensor. dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1))) dataset_dict["padding_mask"] = torch.as_tensor(np.ascontiguousarray(padding_mask)) # if not self.is_train: # # USER: Modify this if you want to keep them for some reason. 
# dataset_dict.pop("annotations", None) # return dataset_dict if "grounding_info" in dataset_dict: for obj in dataset_dict['grounding_info']: obj["bbox_mode"] = BoxMode.XYWH_ABS obj['tokens']=dataset_dict['caption'][obj['tokens_positive'][0][0]:obj['tokens_positive'][0][1]] # USER: Implement additional transformations if you have other types of data annos = [ utils.transform_instance_annotations(obj, transforms, image_shape) for obj in dataset_dict["grounding_info"] ] # NOTE: does not support BitMask due to augmentation # Current BitMask cannot handle empty objects assert len(annos)>0 assert "segmentation" in annos[0] instances = utils.annotations_to_instances(annos, image_shape,mask_format="bitmask") # After transforms such as cropping are applied, the bounding box may no longer # tightly bound the object. As an example, imagine a triangle object # [(0,0), (2,0), (0,2)] cropped by a box [(1,0),(2,2)] (XYXY format). The tight # bounding box of the cropped triangle should be [(1,0),(2,1)], which is not equal to # the intersection of original bounding box and the cropping box. 
# instances.gt_boxes = instances.gt_masks.get_bounding_boxes() # Need to filter empty instances first (due to augmentation) # instances = utils.filter_empty_instances(instances) # Generate masks from polygon h, w = instances.image_size # image_size_xyxy = torch.as_tensor([w, h, w, h], dtype=torch.float) if hasattr(instances, 'gt_masks'): gt_masks = instances.gt_masks # gt_masks = convert_coco_poly_to_mask(gt_masks.polygons, h, w) instances.gt_masks = gt_masks.tensor span_set = dict() end_dict = dict() gt_classes= [] for i, info in enumerate(dataset_dict['grounding_info']): gt_classes.append([]) # if len(info['tokens_positive'])>1: # print("multi class") for j in range(len(info['tokens_positive'])): if info['tokens_positive'][j][0] in span_set: span_set[info['tokens_positive'][j][0]].append(i) else: span_set[info['tokens_positive'][j][0]] = [i] if info['tokens_positive'][j][0] in end_dict: assert end_dict[info['tokens_positive'][j][0]] == info['tokens_positive'][j][1] else: end_dict[info['tokens_positive'][j][0]] = info['tokens_positive'][j][1] gt_classes[-1].append(info['tokens_positive'][j][0]) end_dict = sorted(end_dict.items()) start2id = dict() for i, (s, e) in enumerate(end_dict): start2id[s] = i gt_classes= [[start2id[s] for s in gt_class] for gt_class in gt_classes] instances.gt_classes = gt_classes dataset_dict["instances"] = instances # span_list = sorted(span_set.items()) # for k, v in span_set: # for i in range(len(v)): # v[i] = positive_new_ids[v[i]] cap_pieces = [] last_e = 0 for s, e in end_dict: cap_pieces.append(dataset_dict['caption'][last_e:s]) cap_pieces.append(dataset_dict['caption'][s:e]) last_e = e cap_pieces.append(dataset_dict['caption'][last_e:]) new_cap = [] for i, piece in enumerate(cap_pieces): if i % 2 == 1: piece = '' + piece + '' new_cap.append(piece) new_cap = "".join(new_cap) tail = [f'{i + 1}: ' for i in range(new_cap.count(""))] tail = '; '.join(tail) new_cap += f' {tail}.' 
# gt_ids = [] # for s, e in end_dict: # if len(span_set[s]) > 1: # return dataset_dict # gt_ids.append(span_set[s][0] + 1) # ground_annos = dict() # ground_annos['gt_ids'] = gt_ids # ground_annos['gt_anno_ids'] = [dataset_dict['grounding_info'][gt_id_ - 1]['id'] for gt_id_ in gt_ids] # ground_annos['caption'] = new_cap question={'from': 'human', 'value': "\nPresent a compact description of the photo's key features.\n(with grounding)"} answer={'from': 'gpt', 'value': new_cap} dataset_dict['conversation'] = [[question, answer]] # sources = preprocess_multimodal( # copy.deepcopy(dataset_dict['conversation']), # self.data_args) data_dict_conversation = self.preprocess( dataset_dict['conversation'], self.tokenizer, has_image=True) data_dict_conversation = dict(input_ids=data_dict_conversation["input_ids"][0], labels=data_dict_conversation["labels"][0]) dataset_dict.update(data_dict_conversation) dataset_dict['tokenizer']=self.tokenizer return dataset_dict ================================================ FILE: datasets_os/dataset_mappers/flickr_new_baseline_dataset_mapper.py ================================================ # ------------------------------------------------------------------------ # Copyright (c) 2022 IDEA. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------ # Modified from Mask2Former https://github.com/facebookresearch/Mask2Former by Feng Li. 
import copy
import logging

import numpy as np
import torch

from detectron2.config import configurable
from detectron2.data import detection_utils as utils
from detectron2.data import transforms as T
from detectron2.data.transforms import TransformGen
from detectron2.structures import BitMasks, Boxes, Instances

__all__ = ["COCOInteractivePanopticNewBaselineDatasetMapper"]


def filter_empty_instances_by_box(
    instances, by_box=True, by_mask=False, box_threshold=1e-5, return_mask=False
):
    """Drop instances whose box (and optionally mask) is empty.

    Args:
        instances: an ``Instances`` with ``gt_boxes`` (and ``gt_masks`` if ``by_mask``).
        by_box: keep only instances whose ``gt_boxes.nonempty(threshold=box_threshold)``.
        by_mask: additionally require ``gt_masks.nonempty()`` when ``gt_masks`` exists.
            NOTE(review): in this file ``gt_masks`` is stored as a raw tensor
            (``masks.tensor``), which has no ``nonempty()``; by_mask=True would raise
            here. The only caller below uses the defaults.
        box_threshold: passed to ``Boxes.nonempty``.
        return_mask: also return the boolean keep-mask.

    Returns:
        The filtered ``Instances``, or ``(instances, mask)`` when ``return_mask``.
    """
    assert by_box or by_mask
    r = []
    if by_box:
        r.append(instances.gt_boxes.nonempty(threshold=box_threshold))
    if instances.has("gt_masks") and by_mask:
        r.append(instances.gt_masks.nonempty())

    # TODO: can also filter visible keypoints

    if not r:
        return instances
    # Combine all criteria with logical AND.
    m = r[0]
    for x in r[1:]:
        m = m & x
    if return_mask:
        return instances[m], m
    return instances[m]


def build_transform_gen(cfg, is_train):
    """
    Create a list of default :class:`Augmentation` from config.
    It includes optional random flipping, a random resize-scale and a fixed-size crop.

    Only the training pipeline is supported (asserts ``is_train``). The config is
    read with attribute access (``cfg.INPUT.*``).

    Returns:
        list[Augmentation]
    """
    assert is_train, "Only support training augmentation"
    image_size = cfg.INPUT.IMAGE_SIZE
    min_scale = cfg.INPUT.MIN_SCALE
    max_scale = cfg.INPUT.MAX_SCALE

    augmentation = []

    if cfg.INPUT.RANDOM_FLIP != "none":
        augmentation.append(
            T.RandomFlip(
                horizontal=cfg.INPUT.RANDOM_FLIP == "horizontal",
                vertical=cfg.INPUT.RANDOM_FLIP == "vertical",
            )
        )

    augmentation.extend([
        T.ResizeScale(
            min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size
        ),
        T.FixedSizeCrop(crop_size=(image_size, image_size)),
    ])

    return augmentation


# This is specifically designed for the COCO dataset.
class COCOInteractivePanopticNewBaselineDatasetMapper:
    """
    A callable which takes a dataset dict in Detectron2 Dataset format,
    and map it into a format used by MaskFormer.

    This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation.

    The callable currently does the following:

    1. Read the image from "file_name"
    2. Applies geometric transforms to the image and annotation
    3. Find and applies suitable cropping to the image and annotation
    4. Prepare image and annotation to Tensors

    In training mode, non-crowd segments of the panoptic PNG ("pan_seg_file_name")
    become ``Instances`` with ``gt_classes`` (category ids), ``gt_masks`` (a raw
    tensor, not ``BitMasks``) and ``gt_boxes`` derived from the masks.
    """

    @configurable
    def __init__(
        self,
        is_train=True,
        *,
        tfm_gens,
        image_format,
    ):
        """
        NOTE: this interface is experimental.
        Args:
            is_train: for training or inference
            tfm_gens: data augmentation (list of transform generators)
            image_format: an image format supported by :func:`detection_utils.read_image`.
        """
        self.tfm_gens = tfm_gens
        # NOTE(review): the log tag says COCOPanopticNewBaselineDatasetMapper although the
        # class is COCOInteractivePanopticNewBaselineDatasetMapper.
        logging.getLogger(__name__).info(
            "[COCOPanopticNewBaselineDatasetMapper] Full TransformGens used in training: {}".format(
                str(self.tfm_gens)
            )
        )

        self.img_format = image_format
        self.is_train = is_train

    @classmethod
    def from_config(cls, cfg, is_train=True):
        # Build augmentation
        tfm_gens = build_transform_gen(cfg, is_train)

        ret = {
            "is_train": is_train,
            "tfm_gens": tfm_gens,
            "image_format": cfg.INPUT.FORMAT,
        }
        return ret

    def __call__(self, dataset_dict):
        """
        Args:
            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.

        Returns:
            dict: a format that builtin models in detectron2 accept
        """
        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
        image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
        utils.check_image_size(dataset_dict, image)

        image, transforms = T.apply_transform_gens(self.tfm_gens, image)
        image_shape = image.shape[:2]  # h, w
        dataset_dict["image_ori"]=image

        # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
        # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
        # Therefore it's important to use torch.Tensor.
        dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))

        if not self.is_train:
            # USER: Modify this if you want to keep them for some reason.
            dataset_dict.pop("annotations", None)
            return dataset_dict

        if "pan_seg_file_name" in dataset_dict:
            pan_seg_gt = utils.read_image(dataset_dict.pop("pan_seg_file_name"), "RGB")
            segments_info = dataset_dict["segments_info"]

            # apply the same transformation to panoptic segmentation
            pan_seg_gt = transforms.apply_segmentation(pan_seg_gt)

            from panopticapi.utils import rgb2id

            # Convert the RGB-encoded panoptic PNG to per-pixel segment ids.
            pan_seg_gt = rgb2id(pan_seg_gt)

            instances = Instances(image_shape)
            classes = []
            masks = []
            for segment_info in segments_info:
                class_id = segment_info["category_id"]
                # Crowd segments are skipped entirely.
                if not segment_info["iscrowd"]:
                    classes.append(class_id)
                    masks.append(pan_seg_gt == segment_info["id"])

            classes = np.array(classes)
            instances.gt_classes = torch.tensor(classes, dtype=torch.int64)
            if len(masks) == 0:
                # Some image does not have annotation (all ignored)
                instances.gt_masks = torch.zeros((0, pan_seg_gt.shape[-2], pan_seg_gt.shape[-1]))
                instances.gt_boxes = Boxes(torch.zeros((0, 4)))
            else:
                masks = BitMasks(
                    torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks])
                )
                # Stored as a raw tensor; boxes are the tight boxes of the masks.
                instances.gt_masks = masks.tensor
                instances.gt_boxes = masks.get_bounding_boxes()

            dataset_dict["instances"] = filter_empty_instances_by_box(instances)

        return dataset_dict
================================================
FILE: datasets_os/dataset_mappers/inference_mapper_with_gt.py
================================================
# Copyright (c) Facebook, Inc. and its affiliates.
import copy
import logging
import numpy as np
from typing import List, Optional, Union
import torch

from detectron2.config import configurable

from detectron2.data import detection_utils as utils
from detectron2.data import transforms as T
from detectron2.structures import BitMasks, Boxes, Instances

"""
This file contains the default mapping that's applied to "dataset dicts".
"""

__all__ = ["CoCoInferenceDatasetMapper"]


class CoCoInferenceDatasetMapper:
    """
    A callable which takes a dataset dict in Detectron2 Dataset format,
    and map it into a format used by the model.

    This is the default callable to be used to map your dataset dict into training data.
    You may need to follow it to implement your own one for customized logic,
    such as a different way to read or transform images.
    See :doc:`/tutorials/data_loading` for details.

    The callable currently does the following:

    1. Read the image from "file_name"
    2. Applies cropping/geometric transforms to the image and annotations
    3. Prepare data and annotations to Tensor and :class:`Instances`

    Unlike detectron2's default mapper, the early return for ``not is_train`` is
    commented out, so ground truth (panoptic segments and "annotations") is kept
    in inference mode too. If both "pan_seg_file_name" and "annotations" are
    present, the instances built from "annotations" are written last and replace
    the panoptic ones.
    """

    @configurable
    def __init__(
        self,
        is_train: bool,
        *,
        augmentations: List[Union[T.Augmentation, T.Transform]],
        image_format: str,
        use_instance_mask: bool = False,
        use_keypoint: bool = False,
        instance_mask_format: str = "polygon",
        keypoint_hflip_indices: Optional[np.ndarray] = None,
        precomputed_proposal_topk: Optional[int] = None,
        recompute_boxes: bool = False,
    ):
        """
        NOTE: this interface is experimental.

        Args:
            is_train: whether it's used in training or inference
            augmentations: a list of augmentations or deterministic transforms to apply
            image_format: an image format supported by :func:`detection_utils.read_image`.
            use_instance_mask: whether to process instance segmentation annotations, if available
            use_keypoint: whether to process keypoint annotations if available
            instance_mask_format: one of "polygon" or "bitmask". Process instance segmentation
                masks into this format.
            keypoint_hflip_indices: see :func:`detection_utils.create_keypoint_hflip_indices`
            precomputed_proposal_topk: if given, will load pre-computed
                proposals from dataset_dict and keep the top k proposals for each image.
            recompute_boxes: whether to overwrite bounding box annotations
                by computing tight bounding boxes from instance mask annotations.
        """
        if recompute_boxes:
            assert use_instance_mask, "recompute_boxes requires instance masks"
        # fmt: off
        self.is_train               = is_train
        self.augmentations          = T.AugmentationList(augmentations)
        self.image_format           = image_format
        self.use_instance_mask      = use_instance_mask
        self.instance_mask_format   = instance_mask_format
        self.use_keypoint           = use_keypoint
        self.keypoint_hflip_indices = keypoint_hflip_indices
        self.proposal_topk          = precomputed_proposal_topk
        self.recompute_boxes        = recompute_boxes
        # fmt: on
        logger = logging.getLogger(__name__)
        mode = "training" if is_train else "inference"
        logger.info(f"[DatasetMapper] Augmentations used in {mode}: {augmentations}")

    @classmethod
    def from_config(cls, cfg, is_train: bool = True):
        augs = utils.build_augmentation(cfg, is_train)
        if cfg.INPUT.CROP.ENABLED and is_train:
            # Cropping changes masks, so boxes are recomputed from masks when masks are used.
            augs.insert(0, T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE))
            recompute_boxes = cfg.MODEL.MASK_ON
        else:
            recompute_boxes = False

        ret = {
            "is_train": is_train,
            "augmentations": augs,
            "image_format": cfg.INPUT.FORMAT,
            "use_instance_mask": cfg.MODEL.MASK_ON,
            "instance_mask_format": cfg.INPUT.MASK_FORMAT,
            "use_keypoint": cfg.MODEL.KEYPOINT_ON,
            "recompute_boxes": recompute_boxes,
        }

        if cfg.MODEL.KEYPOINT_ON:
            ret["keypoint_hflip_indices"] = utils.create_keypoint_hflip_indices(cfg.DATASETS.TRAIN)

        if cfg.MODEL.LOAD_PROPOSALS:
            ret["precomputed_proposal_topk"] = (
                cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TRAIN
                if is_train
                else cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TEST
            )
        return ret

    def _transform_annotations(self, dataset_dict, transforms, image_shape):
        """Transform dataset_dict["annotations"] into dataset_dict["instances"].

        Pops "annotations", drops crowd objects, applies ``transforms`` and stores
        the result (after ``utils.filter_empty_instances``) as "instances".
        """
        # USER: Modify this if you want to keep them for some reason.
        for anno in dataset_dict["annotations"]:
            if not self.use_instance_mask:
                anno.pop("segmentation", None)
            if not self.use_keypoint:
                anno.pop("keypoints", None)

        # USER: Implement additional transformations if you have other types of data
        annos = [
            utils.transform_instance_annotations(
                obj, transforms, image_shape, keypoint_hflip_indices=self.keypoint_hflip_indices
            )
            for obj in dataset_dict.pop("annotations")
            if obj.get("iscrowd", 0) == 0
        ]
        instances = utils.annotations_to_instances(
            annos, image_shape, mask_format=self.instance_mask_format
        )

        # After transforms such as cropping are applied, the bounding box may no longer
        # tightly bound the object. As an example, imagine a triangle object
        # [(0,0), (2,0), (0,2)] cropped by a box [(1,0),(2,2)] (XYXY format). The tight
        # bounding box of the cropped triangle should be [(1,0),(2,1)], which is not equal to
        # the intersection of original bounding box and the cropping box.
        if self.recompute_boxes:
            instances.gt_boxes = instances.gt_masks.get_bounding_boxes()
        dataset_dict["instances"] = utils.filter_empty_instances(instances)

    def __call__(self, dataset_dict):
        """
        Args:
            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.

        Returns:
            dict: a format that builtin models in detectron2 accept
        """
        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
        # USER: Write your own image loading if it's not from a file
        image = utils.read_image(dataset_dict["file_name"], format=self.image_format)
        utils.check_image_size(dataset_dict, image)

        # USER: Remove if you don't do semantic/panoptic segmentation.
        if "sem_seg_file_name" in dataset_dict:
            sem_seg_gt = utils.read_image(dataset_dict.pop("sem_seg_file_name"), "L").squeeze(2)
        else:
            sem_seg_gt = None

        aug_input = T.AugInput(image, sem_seg=sem_seg_gt)
        transforms = self.augmentations(aug_input)
        image, sem_seg_gt = aug_input.image, aug_input.sem_seg

        image_shape = image.shape[:2]  # h, w
        # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
        # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
        # Therefore it's important to use torch.Tensor.
        dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
        if sem_seg_gt is not None:
            dataset_dict["sem_seg"] = torch.as_tensor(sem_seg_gt.astype("long"))

        # USER: Remove if you don't use pre-computed proposals.
        # Most users would not need this feature.
        if self.proposal_topk is not None:
            utils.transform_proposals(
                dataset_dict, image_shape, transforms, proposal_topk=self.proposal_topk
            )

        # NOTE: the detectron2 early return for inference is disabled on purpose so
        # that ground truth is produced in inference mode as well.
        # if not self.is_train:
        #     # USER: Modify this if you want to keep them for some reason.
        #     dataset_dict.pop("annotations", None)
        #     dataset_dict.pop("sem_seg_file_name", None)
        #     return dataset_dict

        if "pan_seg_file_name" in dataset_dict:
            pan_seg_gt = utils.read_image(dataset_dict.pop("pan_seg_file_name"), "RGB")
            segments_info = dataset_dict["segments_info"]

            # apply the same transformation to panoptic segmentation
            pan_seg_gt = transforms.apply_segmentation(pan_seg_gt)

            from panopticapi.utils import rgb2id

            # Convert the RGB-encoded panoptic PNG to per-pixel segment ids.
            pan_seg_gt = rgb2id(pan_seg_gt)

            instances = Instances(image_shape)
            classes = []
            masks = []
            for segment_info in segments_info:
                class_id = segment_info["category_id"]
                # Crowd segments are skipped entirely.
                if not segment_info["iscrowd"]:
                    classes.append(class_id)
                    masks.append(pan_seg_gt == segment_info["id"])

            classes = np.array(classes)
            instances.gt_classes = torch.tensor(classes, dtype=torch.int64)
            if len(masks) == 0:
                # Some image does not have annotation (all ignored)
                instances.gt_masks = torch.zeros((0, pan_seg_gt.shape[-2], pan_seg_gt.shape[-1]))
                instances.gt_boxes = Boxes(torch.zeros((0, 4)))
            else:
                masks = BitMasks(
                    torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks])
                )
                instances.gt_masks = masks.tensor
                instances.gt_boxes = masks.get_bounding_boxes()
            # Unlike the training mapper, empty instances are not filtered here.
            dataset_dict["instances"] = instances
            # dataset_dict["instances"] = filter_empty_instances_by_box(instances)

        # Runs after the panoptic branch, so it overwrites "instances" when both are present.
        if "annotations" in dataset_dict:
            self._transform_annotations(dataset_dict, transforms, image_shape)

        return dataset_dict
================================================
FILE: datasets_os/dataset_mappers/llava_dataset_mapper.py
================================================
# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/d2/detr/dataset_mapper.py import copy import logging import numpy as np import torch from detectron2.data import detection_utils as utils from detectron2.data import transforms as T from detectron2.data.transforms import TransformGen from detectron2.structures import BitMasks, Instances from pycocotools import mask as coco_mask from llava.model.openseed.utils import configurable from detectron2.structures import Boxes, BoxMode, PolygonMasks, RotatedBoxes __all__ = ["COCOInstanceNewBaselineDatasetMapper"] def convert_coco_poly_to_mask(segmentations, height, width): masks = [] for polygons in segmentations: rles = coco_mask.frPyObjects(polygons, height, width) mask = coco_mask.decode(rles) if len(mask.shape) < 3: mask = mask[..., None] mask = torch.as_tensor(mask, dtype=torch.uint8) mask = mask.any(dim=2) masks.append(mask) if masks: masks = torch.stack(masks, dim=0) else: masks = torch.zeros((0, height, width), dtype=torch.uint8) return masks def build_transform_gen(cfg, is_train): """ Create a list of default :class:`Augmentation` from config. Now it includes resizing and flipping. 
Returns: list[Augmentation] """ if is_train: cfg_input = cfg['INPUT'] image_size = cfg_input['IMAGE_SIZE'] min_scale = cfg_input['MIN_SCALE'] max_scale = cfg_input['MAX_SCALE'] augmentation = [] # if cfg_input['RANDOM_FLIP'] != "none": # augmentation.append( # T.RandomFlip( # horizontal=cfg_input['RANDOM_FLIP'] == "horizontal", # vertical=cfg_input['RANDOM_FLIP'] == "vertical", # ) # ) augmentation.extend([ T.ResizeScale( min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size ), T.FixedSizeCrop(crop_size=(image_size, image_size)), ]) else: cfg_input = cfg['INPUT'] image_size = cfg_input['IMAGE_SIZE'] min_scale = cfg_input['MIN_SCALE'] max_scale = cfg_input['MAX_SCALE'] augmentation = [] # if cfg_input['RANDOM_FLIP'] != "none": # augmentation.append( # T.RandomFlip( # horizontal=cfg_input['RANDOM_FLIP'] == "horizontal", # vertical=cfg_input['RANDOM_FLIP'] == "vertical", # ) # ) augmentation.extend([ T.ResizeScale( min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size ), T.FixedSizeCrop(crop_size=(image_size, image_size)), ]) return augmentation # This is specifically designed for the COCO dataset. class COCOInstanceNewBaselineDatasetMapper: """ A callable which takes a dataset dict in Detectron2 Dataset format, and map it into a format used by MaskFormer. This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation. The callable currently does the following: 1. Read the image from "file_name" 2. Applies geometric transforms to the image and annotation 3. Find and applies suitable cropping to the image and annotation 4. Prepare image and annotation to Tensors """ @configurable def __init__( self, is_train=True, *, tfm_gens, image_format, ): """ NOTE: this interface is experimental. 
Args: is_train: for training or inference augmentations: a list of augmentations or deterministic transforms to apply tfm_gens: data augmentation image_format: an image format supported by :func:`detection_utils.read_image`. """ self.tfm_gens = tfm_gens logging.getLogger(__name__).info( "[COCOInstanceNewBaselineDatasetMapper] Full TransformGens used in training: {}".format(str(self.tfm_gens)) ) self.img_format = image_format self.is_train = is_train @classmethod def from_config(cls, cfg, is_train=True): # Build augmentation tfm_gens = build_transform_gen(cfg, is_train) ret = { "is_train": is_train, "tfm_gens": tfm_gens, "image_format": cfg['INPUT']['FORMAT'], } return ret def __call__(self, dataset_dict): """ Args: dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format. Returns: dict: a format that builtin models in detectron2 accept """ dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below image = utils.read_image(dataset_dict["file_name"], format=self.img_format) utils.check_image_size(dataset_dict, image) # TODO: get padding mask # by feeding a "segmentation mask" to the same transforms padding_mask = np.ones(image.shape[:2]) image, transforms = T.apply_transform_gens(self.tfm_gens, image) # the crop transformation has default padding value 0 for segmentation padding_mask = transforms.apply_segmentation(padding_mask) padding_mask = ~ padding_mask.astype(bool) dataset_dict["image_ori"]=image image_shape = image.shape[:2] # h, w # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory, # but not efficient on large generic data structures due to the use of pickle & mp.Queue. # Therefore it's important to use torch.Tensor. dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1))) dataset_dict["padding_mask"] = torch.as_tensor(np.ascontiguousarray(padding_mask)) # if not self.is_train: # # USER: Modify this if you want to keep them for some reason. 
# dataset_dict.pop("annotations", None) # return dataset_dict if "grounding_info" in dataset_dict: for obj in dataset_dict['grounding_info']: obj["bbox_mode"] = BoxMode.XYWH_ABS obj['tokens']=dataset_dict['caption'][obj['tokens_positive'][0][0]:obj['tokens_positive'][0][1]] # USER: Implement additional transformations if you have other types of data annos = [ utils.transform_instance_annotations(obj, transforms, image_shape) for obj in dataset_dict["grounding_info"] ] # NOTE: does not support BitMask due to augmentation # Current BitMask cannot handle empty objects assert len(annos)>0 assert "segmentation" in annos[0] instances = utils.annotations_to_instances(annos, image_shape,mask_format="bitmask") # After transforms such as cropping are applied, the bounding box may no longer # tightly bound the object. As an example, imagine a triangle object # [(0,0), (2,0), (0,2)] cropped by a box [(1,0),(2,2)] (XYXY format). The tight # bounding box of the cropped triangle should be [(1,0),(2,1)], which is not equal to # the intersection of original bounding box and the cropping box. 
instances.gt_boxes = instances.gt_masks.get_bounding_boxes() # Need to filter empty instances first (due to augmentation) instances = utils.filter_empty_instances(instances) # Generate masks from polygon h, w = instances.image_size # image_size_xyxy = torch.as_tensor([w, h, w, h], dtype=torch.float) if hasattr(instances, 'gt_masks'): gt_masks = instances.gt_masks # gt_masks = convert_coco_poly_to_mask(gt_masks.polygons, h, w) instances.gt_masks = gt_masks.tensor dataset_dict["instances"] = instances return dataset_dict ================================================ FILE: datasets_os/dataset_mappers/refcoco_dataset_mapper.py ================================================ # -------------------------------------------------------- # X-Decoder -- Generalized Decoding for Pixel, Image, and Language # Copyright (c) 2022 Microsoft # Licensed under The MIT License [see LICENSE for details] # Modified by Xueyan Zou (xueyan@cs.wisc.edu) # -------------------------------------------------------- # Copyright (c) Facebook, Inc. and its affiliates. import copy import random import scipy.io import numpy as np import torch from PIL import Image from torchvision import transforms from pycocotools import mask from detectron2.data import detection_utils as utils from detectron2.data import transforms as T from llava.model.openseed.utils import configurable __all__ = ["RefCOCODatasetMapper"] def build_transform_gen(cfg, is_train): """ Create a list of default :class:`Augmentation` from config. Now it includes resizing and flipping. 
Returns: list[Augmentation] """ assert is_train, "Only support training augmentation" cfg_input = cfg['INPUT'] image_size = cfg_input['IMAGE_SIZE'] min_scale = cfg_input['MIN_SCALE'] max_scale = cfg_input['MAX_SCALE'] augmentation = [] if cfg_input['RANDOM_FLIP'] != "none": augmentation.append( T.RandomFlip( horizontal=cfg_input['RANDOM_FLIP'] == "horizontal", vertical=cfg_input['RANDOM_FLIP'] == "vertical", ) ) augmentation.extend([ T.ResizeScale( min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size ), T.FixedSizeCrop(crop_size=(image_size, image_size)), ]) return augmentation # This is specifically designed for the COCO dataset. class RefCOCODatasetMapper: """ A callable which takes a dataset dict in Detectron2 Dataset format, and map it into a format used by MaskFormer. This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation. The callable currently does the following: 1. Read the image from "file_name" 2. Applies geometric transforms to the image and annotation 3. Find and applies suitable cropping to the image and annotation 4. Prepare image and annotation to Tensors """ @configurable def __init__( self, is_train=True, tfm_gens=None, image_format=None, min_size_test=None, max_size_test=None, mean=None, std=None, ): """ NOTE: this interface is experimental. Args: is_train: for training or inference augmentations: a list of augmentations or deterministic transforms to apply tfm_gens: data augmentation image_format: an image format supported by :func:`detection_utils.read_image`. 
""" self.tfm_gens = tfm_gens self.img_format = image_format self.is_train = is_train self.min_size_test = min_size_test self.max_size_test = max_size_test self.pixel_mean = torch.tensor(mean)[:,None,None] self.pixel_std = torch.tensor(std)[:,None,None] t = [] t.append(transforms.Resize(self.min_size_test, interpolation=Image.BICUBIC)) self.transform = transforms.Compose(t) @classmethod def from_config(cls, cfg, is_train=True): # Build augmentation if is_train: tfm_gens = build_transform_gen(cfg, is_train) else: tfm_gens = None ret = { "is_train": is_train, "tfm_gens": tfm_gens, "image_format": cfg['INPUT'].get('FORMAT', 'RGB'), "min_size_test": cfg['INPUT']['MIN_SIZE_TEST'], "max_size_test": cfg['INPUT']['MAX_SIZE_TEST'], "mean": cfg['INPUT']['PIXEL_MEAN'], "std": cfg['INPUT']['PIXEL_STD'], } return ret def __call__(self, dataset_dict): """ Args: dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format. Returns: dict: a format that builtin models in detectron2 accept """ dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below file_name = dataset_dict['file_name'] if self.is_train == False: image = Image.open(file_name).convert('RGB') dataset_dict['width'] = image.size[0] dataset_dict['height'] = image.size[1] image = self.transform(image) image = torch.from_numpy(np.asarray(image).copy()) dataset_dict["image_ori"] = image image = image.permute(2,0,1) dataset_dict['image'] = image grounding_anno = dataset_dict['grounding_info'] assert len(grounding_anno) > 0 masks_grd = [] texts_grd = [] boxes_grd = [] for ann in grounding_anno: rle = mask.frPyObjects( ann['segmentation'], dataset_dict['height'], dataset_dict['width']) m = mask.decode(rle) # sometimes there are multiple binary map (corresponding to multiple segs) m = np.sum(m, axis=2) m = m.astype(np.uint8) # convert to np.uint8 masks_grd += [m] texts_grd.append([x['raw'].lower() for x in ann['sentences']]) boxes_grd.append(ann['bbox']) # xywh masks_grd = 
torch.from_numpy(np.stack(masks_grd)) boxes_grd = torch.tensor(boxes_grd) groundings = {'masks': masks_grd, 'texts': texts_grd, 'boxes': boxes_grd} dataset_dict["groundings"] = groundings else: image = utils.read_image(dataset_dict["file_name"], format=self.img_format) utils.check_image_size(dataset_dict, image) image, transforms = T.apply_transform_gens(self.tfm_gens, image) dataset_dict["image_ori"] = image image_shape = image.shape[:2] # h, w dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1))) grounding_anno = dataset_dict['grounding_info'] assert len(grounding_anno) > 0 masks_grd = [] texts_grd = [] boxes_grd = [] hash_grd = [] for ann in grounding_anno: rle = mask.frPyObjects( ann['segmentation'], dataset_dict['height'], dataset_dict['width']) m = mask.decode(rle) # sometimes there are multiple binary map (corresponding to multiple segs) m = np.sum(m, axis=2) m = m.astype(np.uint8) # convert to np.uint8 m = transforms.apply_segmentation(m[:,:,None])[:,:,0] masks_grd += [m] rand_id = random.randint(0, len(ann['sentences'])-1) texts_grd.append(ann['sentences'][rand_id]['raw'].lower()) hash_grd.append(hash(ann['sentences'][rand_id]['raw'].lower())) masks_grd = torch.from_numpy(np.stack(masks_grd)) boxes_grd = torch.tensor(boxes_grd) groundings = {'masks': masks_grd, 'texts': texts_grd, 'hash': hash_grd, 'mode': 'text'} dataset_dict["groundings"] = groundings return dataset_dict ================================================ FILE: datasets_os/dataset_mappers/vg_instance_new_baseline_dataset_mapper.py ================================================ # Copyright (c) Facebook, Inc. and its affiliates. 
# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/d2/detr/dataset_mapper.py import copy import logging import numpy as np import torch from detectron2.data import detection_utils as utils from detectron2.data import transforms as T from detectron2.data.transforms import TransformGen from detectron2.structures import BitMasks, Instances from pycocotools import mask as coco_mask from llava.model.openseed.utils import configurable from detectron2.structures import Boxes, BoxMode, PolygonMasks, RotatedBoxes __all__ = ["COCOInstanceNewBaselineDatasetMapper"] def convert_coco_poly_to_mask(segmentations, height, width): masks = [] for polygons in segmentations: rles = coco_mask.frPyObjects(polygons, height, width) mask = coco_mask.decode(rles) if len(mask.shape) < 3: mask = mask[..., None] mask = torch.as_tensor(mask, dtype=torch.uint8) mask = mask.any(dim=2) masks.append(mask) if masks: masks = torch.stack(masks, dim=0) else: masks = torch.zeros((0, height, width), dtype=torch.uint8) return masks def build_transform_gen(cfg, is_train): """ Create a list of default :class:`Augmentation` from config. Now it includes resizing and flipping. 
Returns: list[Augmentation] """ if is_train: cfg_input = cfg['INPUT'] image_size = cfg_input['IMAGE_SIZE'] min_scale = cfg_input['MIN_SCALE'] max_scale = cfg_input['MAX_SCALE'] augmentation = [] # if cfg_input['RANDOM_FLIP'] != "none": # augmentation.append( # T.RandomFlip( # horizontal=cfg_input['RANDOM_FLIP'] == "horizontal", # vertical=cfg_input['RANDOM_FLIP'] == "vertical", # ) # ) augmentation.extend([ T.ResizeScale( min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size ), T.FixedSizeCrop(crop_size=(image_size, image_size)), ]) else: cfg_input = cfg['INPUT'] image_size = cfg_input['IMAGE_SIZE'] min_scale = cfg_input['MIN_SCALE'] max_scale = cfg_input['MAX_SCALE'] augmentation = [] # if cfg_input['RANDOM_FLIP'] != "none": # augmentation.append( # T.RandomFlip( # horizontal=cfg_input['RANDOM_FLIP'] == "horizontal", # vertical=cfg_input['RANDOM_FLIP'] == "vertical", # ) # ) augmentation.extend([ T.ResizeScale( min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size ), T.FixedSizeCrop(crop_size=(image_size, image_size)), ]) return augmentation # This is specifically designed for the COCO dataset. class COCOInstanceNewBaselineDatasetMapper: """ A callable which takes a dataset dict in Detectron2 Dataset format, and map it into a format used by MaskFormer. This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation. The callable currently does the following: 1. Read the image from "file_name" 2. Applies geometric transforms to the image and annotation 3. Find and applies suitable cropping to the image and annotation 4. Prepare image and annotation to Tensors """ @configurable def __init__( self, is_train=True, *, tfm_gens, image_format, max_grounding_num, tokenizer, data_args, preprocess, ): """ NOTE: this interface is experimental. 
Args: is_train: for training or inference augmentations: a list of augmentations or deterministic transforms to apply tfm_gens: data augmentation image_format: an image format supported by :func:`detection_utils.read_image`. """ self.tfm_gens = tfm_gens logging.getLogger(__name__).info( "[COCOInstanceNewBaselineDatasetMapper] Full TransformGens used in training: {}".format(str(self.tfm_gens)) ) self.img_format = image_format self.is_train = is_train self.max_grounding_num = max_grounding_num self.tokenizer = tokenizer self.processor = data_args.image_processor self.data_args = data_args self.preprocess = preprocess @classmethod def from_config(cls, cfg, is_train=True,tokenizer=None,data_args=None,preprocess=None): # Build augmentation tfm_gens = build_transform_gen(cfg, is_train) ret = { "is_train": is_train, "tfm_gens": tfm_gens, "image_format": cfg['INPUT']['FORMAT'], "max_grounding_num": cfg['MODEL']['DECODER']['GROUNDING']['MAX_LEN'], "tokenizer": tokenizer, "data_args": data_args, "preprocess": preprocess, } return ret def __call__(self, dataset_dict): """ Args: dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format. 
Returns: dict: a format that builtin models in detectron2 accept """ dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below image = utils.read_image(dataset_dict["file_name"], format=self.img_format) utils.check_image_size(dataset_dict, image) #########llava image processing if self.data_args.image_aspect_ratio == 'pad': def expand2square(pil_img, background_color): width, height = pil_img.size if width == height: return pil_img elif width > height: result = Image.new(pil_img.mode, (width, width), background_color) result.paste(pil_img, (0, (width - height) // 2)) return result else: result = Image.new(pil_img.mode, (height, height), background_color) result.paste(pil_img, ((height - width) // 2, 0)) return result image_clip = expand2square(image, tuple(int(x * 255) for x in self.processor.image_mean)) image_clip = self.processor.preprocess(image_clip, return_tensors='pt')['pixel_values'][0] else: image_clip = self.processor.preprocess(image, return_tensors='pt')['pixel_values'][0] dataset_dict["image_clip"] = image_clip ################## # TODO: get padding mask # by feeding a "segmentation mask" to the same transforms padding_mask = np.ones(image.shape[:2]) image, transforms = T.apply_transform_gens(self.tfm_gens, image) # the crop transformation has default padding value 0 for segmentation padding_mask = transforms.apply_segmentation(padding_mask) padding_mask = ~ padding_mask.astype(bool) dataset_dict["image_ori"]=image image_shape = image.shape[:2] # h, w # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory, # but not efficient on large generic data structures due to the use of pickle & mp.Queue. # Therefore it's important to use torch.Tensor. dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1))) dataset_dict["padding_mask"] = torch.as_tensor(np.ascontiguousarray(padding_mask)) # if not self.is_train: # # USER: Modify this if you want to keep them for some reason. 
# dataset_dict.pop("annotations", None) # return dataset_dict assert "annotations" in dataset_dict for obj in dataset_dict['annotations']: obj["bbox_mode"] = BoxMode.XYWH_ABS # USER: Implement additional transformations if you have other types of data annos = [ utils.transform_instance_annotations(obj, transforms, image_shape) for obj in dataset_dict["annotations"] ] # NOTE: does not support BitMask due to augmentation # Current BitMask cannot handle empty objects assert len(annos)>0 # assert "segmentation" in annos[0] instances = utils.annotations_to_instances(annos, image_shape,mask_format="bitmask") instances.captions=[ann['caption'] for ann in dataset_dict["annotations"]] # After transforms such as cropping are applied, the bounding box may no longer # tightly bound the object. As an example, imagine a triangle object # [(0,0), (2,0), (0,2)] cropped by a box [(1,0),(2,2)] (XYXY format). The tight # bounding box of the cropped triangle should be [(1,0),(2,1)], which is not equal to # the intersection of original bounding box and the cropping box. 
# instances.gt_boxes = instances.gt_masks.get_bounding_boxes() # Need to filter empty instances first (due to augmentation) # instances = utils.filter_empty_instances(instances) # Generate masks from polygon h, w = instances.image_size # image_size_xyxy = torch.as_tensor([w, h, w, h], dtype=torch.float) if hasattr(instances, 'gt_masks'): gt_masks = instances.gt_masks # gt_masks = convert_coco_poly_to_mask(gt_masks.polygons, h, w) instances.gt_masks = gt_masks.tensor # dataset_dict["instances"] = instances num_instances = len(instances) indices = list(range(num_instances)) import random if self.is_train: grounding_len = random.randint(1, self.max_grounding_num - 1) else: grounding_len = 1 random.shuffle(indices) indices = indices[:grounding_len] texts_grd = [instances.captions[i] for i in indices] gt_classes = list(range(len(texts_grd))) gt_classes = [[lb] for lb in gt_classes] label_set = texts_grd grounding_instances = Instances(image_size=(h, w)) grounding_instances.gt_boxes = instances.gt_boxes[indices] grounding_instances.gt_classes = gt_classes dataset_dict["instances"]=grounding_instances conversations=[] for i in range(len(label_set)): if i==0: question={'from': 'human', 'value': f"\n Please detect the object according to the text {label_set[i]} (referring)."} else: question={'from': 'human', 'value': f"Please detect the object according to the text {label_set[i]} (referring)."} answer={'from': 'gpt', 'value': ' .'} conversations.append(question) conversations.append(answer) dataset_dict['conversation'] = [conversations] data_dict_conversation = self.preprocess( dataset_dict['conversation'], self.tokenizer, has_image=True) data_dict_conversation = dict(input_ids=data_dict_conversation["input_ids"][0], labels=data_dict_conversation["labels"][0]) dataset_dict.update(data_dict_conversation) dataset_dict['tokenizer']=self.tokenizer return dataset_dict ================================================ FILE: datasets_os/refer.py 
================================================ __author__ = 'licheng' """ This interface provides access to four datasets: 1) refclef 2) refcoco 3) refcoco+ 4) refcocog split by unc and google The following API functions are defined: REFER - REFER api class getRefIds - get ref ids that satisfy given filter conditions. getAnnIds - get ann ids that satisfy given filter conditions. getImgIds - get image ids that satisfy given filter conditions. getCatIds - get category ids that satisfy given filter conditions. loadRefs - load refs with the specified ref ids. loadAnns - load anns with the specified ann ids. loadImgs - load images with the specified image ids. loadCats - load category names with the specified category ids. getRefBox - get ref's bounding box [x, y, w, h] given the ref_id showRef - show image, segmentation or box of the referred object with the ref getMask - get mask and area of the referred object given ref showMask - show mask of the referred object given ref """ from doctest import REPORT_ONLY_FIRST_FAILURE import sys import os.path as osp import json import pickle import time import itertools import skimage.io as io import matplotlib.pyplot as plt from matplotlib.collections import PatchCollection from matplotlib.patches import Polygon, Rectangle from pprint import pprint import numpy as np from pycocotools import mask # import cv2 # from skimage.measure import label, regionprops class REFER: def __init__(self, data_root, dataset='refcoco', splitBy='unc'): # provide data_root folder which contains refclef, refcoco, refcoco+ and refcocog # also provide dataset name and splitBy information # e.g., dataset = 'refcoco', splitBy = 'unc' print('loading dataset {} into memory...'.format(dataset)) self.ROOT_DIR = osp.abspath(osp.dirname(__file__)) self.DATA_DIR = osp.join(data_root, dataset) if dataset in ['refcoco', 'refcoco+', 'refcocog']: self.IMAGE_DIR = osp.join(data_root, 'images/mscoco/images/train2014') elif dataset == 'refclef': self.IMAGE_DIR = 
osp.join(data_root, 'images/saiapr_tc-12') else: print('No refer dataset is called [{}]'.format(dataset)) sys.exit() # load refs from data/dataset/refs(dataset).json tic = time.time() ref_file = osp.join(self.DATA_DIR, 'refs('+splitBy+').p') self.data = {} self.data['dataset'] = dataset self.data['refs'] = pickle.load(open(ref_file, 'rb')) # load annotations from data/dataset/instances.json instances_file = osp.join(self.DATA_DIR, 'instances.json') instances = json.load(open(instances_file, 'r')) self.data['images'] = instances['images'] self.data['annotations'] = instances['annotations'] self.data['categories'] = instances['categories'] # create index self.createIndex() print('DONE (t=%.2fs)'.format(time.time()-tic)) def createIndex(self): # create sets of mapping # 1) Refs: {ref_id: ref} # 2) Anns: {ann_id: ann} # 3) Imgs: {image_id: image} # 4) Cats: {category_id: category_name} # 5) Sents: {sent_id: sent} # 6) imgToRefs: {image_id: refs} # 7) imgToAnns: {image_id: anns} # 8) refToAnn: {ref_id: ann} # 9) annToRef: {ann_id: ref} # 10) catToRefs: {category_id: refs} # 11) sentToRef: {sent_id: ref} # 12) sentToTokens: {sent_id: tokens} print('creating index...') # fetch info from instances Anns, Imgs, Cats, imgToAnns = {}, {}, {}, {} for ann in self.data['annotations']: Anns[ann['id']] = ann imgToAnns[ann['image_id']] = imgToAnns.get( ann['image_id'], []) + [ann] for img in self.data['images']: Imgs[img['id']] = img for cat in self.data['categories']: Cats[cat['id']] = cat['name'] # fetch info from refs Refs, imgToRefs, refToAnn, annToRef, catToRefs = {}, {}, {}, {}, {} Sents, sentToRef, sentToTokens = {}, {}, {} for ref in self.data['refs']: # ids ref_id = ref['ref_id'] ann_id = ref['ann_id'] category_id = ref['category_id'] image_id = ref['image_id'] # add mapping related to ref Refs[ref_id] = ref imgToRefs[image_id] = imgToRefs.get(image_id, []) + [ref] catToRefs[category_id] = catToRefs.get(category_id, []) + [ref] refToAnn[ref_id] = Anns[ann_id] 
annToRef[ann_id] = ref # add mapping of sent for sent in ref['sentences']: Sents[sent['sent_id']] = sent sentToRef[sent['sent_id']] = ref sentToTokens[sent['sent_id']] = sent['tokens'] # create class members self.Refs = Refs self.Anns = Anns self.Imgs = Imgs self.Cats = Cats self.Sents = Sents self.imgToRefs = imgToRefs self.imgToAnns = imgToAnns self.refToAnn = refToAnn self.annToRef = annToRef self.catToRefs = catToRefs self.sentToRef = sentToRef self.sentToTokens = sentToTokens print('index created.') def getRefIds(self, image_ids=[], cat_ids=[], ref_ids=[], split=''): image_ids = image_ids if type(image_ids) == list else [image_ids] cat_ids = cat_ids if type(cat_ids) == list else [cat_ids] ref_ids = ref_ids if type(ref_ids) == list else [ref_ids] if len(image_ids) == len(cat_ids) == len(ref_ids) == len(split) == 0: refs = self.data['refs'] else: if not len(image_ids) == 0: refs = [self.imgToRefs[image_id] for image_id in image_ids] else: refs = self.data['refs'] if not len(cat_ids) == 0: refs = [ref for ref in refs if ref['category_id'] in cat_ids] if not len(ref_ids) == 0: refs = [ref for ref in refs if ref['ref_id'] in ref_ids] if not len(split) == 0: if split in ['testA', 'testB', 'testC']: # we also consider testAB, testBC, ... refs = [ref for ref in refs if split[-1] in ref['split']] elif split in ['testAB', 'testBC', 'testAC']: # rarely used I guess... 
refs = [ref for ref in refs if ref['split'] == split] elif split == 'test': refs = [ref for ref in refs if 'test' in ref['split']] elif split == 'train' or split == 'val': refs = [ref for ref in refs if ref['split'] == split] else: print('No such split [{}]'.format(split)) sys.exit() ref_ids = [ref['ref_id'] for ref in refs] return ref_ids def getAnnIds(self, image_ids=[], cat_ids=[], ref_ids=[]): image_ids = image_ids if type(image_ids) == list else [image_ids] cat_ids = cat_ids if type(cat_ids) == list else [cat_ids] ref_ids = ref_ids if type(ref_ids) == list else [ref_ids] if len(image_ids) == len(cat_ids) == len(ref_ids) == 0: ann_ids = [ann['id'] for ann in self.data['annotations']] else: if not len(image_ids) == 0: lists = [self.imgToAnns[image_id] for image_id in image_ids if image_id in self.imgToAnns] # list of [anns] anns = list(itertools.chain.from_iterable(lists)) else: anns = self.data['annotations'] if not len(cat_ids) == 0: anns = [ann for ann in anns if ann['category_id'] in cat_ids] ann_ids = [ann['id'] for ann in anns] if not len(ref_ids) == 0: ids = set(ann_ids).intersection( set([self.Refs[ref_id]['ann_id'] for ref_id in ref_ids])) return ann_ids def getImgIds(self, ref_ids=[]): ref_ids = ref_ids if type(ref_ids) == list else [ref_ids] if not len(ref_ids) == 0: image_ids = list(set([self.Refs[ref_id]['image_id'] for ref_id in ref_ids])) else: image_ids = self.Imgs.keys() return image_ids def getCatIds(self): return self.Cats.keys() def loadRefs(self, ref_ids=[]): if type(ref_ids) == list: return [self.Refs[ref_id] for ref_id in ref_ids] elif type(ref_ids) == int: return [self.Refs[ref_ids]] def loadAnns(self, ann_ids=[]): if type(ann_ids) == list: return [self.Anns[ann_id] for ann_id in ann_ids] elif type(ann_ids) == int or type(ann_ids) == unicode: return [self.Anns[ann_ids]] def loadImgs(self, image_ids=[]): if type(image_ids) == list: return [self.Imgs[image_id] for image_id in image_ids] elif type(image_ids) == int: return 
[self.Imgs[image_ids]] def loadCats(self, cat_ids=[]): if type(cat_ids) == list: return [self.Cats[cat_id] for cat_id in cat_ids] elif type(cat_ids) == int: return [self.Cats[cat_ids]] def getRefBox(self, ref_id): ref = self.Refs[ref_id] ann = self.refToAnn[ref_id] return ann['bbox'] # [x, y, w, h] def showRef(self, ref, seg_box='seg'): ax = plt.gca() # show image image = self.Imgs[ref['image_id']] I = io.imread(osp.join(self.IMAGE_DIR, image['file_name'])) ax.imshow(I) # show refer expression for sid, sent in enumerate(ref['sentences']): print('{}. {}'.format(sid+1, sent['sent'])) # show segmentations if seg_box == 'seg': ann_id = ref['ann_id'] ann = self.Anns[ann_id] polygons = [] color = [] c = 'none' if type(ann['segmentation'][0]) == list: # polygon used for refcoco* for seg in ann['segmentation']: poly = np.array(seg).reshape((len(seg)/2, 2)) polygons.append(Polygon(poly, True, alpha=0.4)) color.append(c) p = PatchCollection(polygons, facecolors=color, edgecolors=( 1, 1, 0, 0), linewidths=3, alpha=1) ax.add_collection(p) # thick yellow polygon p = PatchCollection(polygons, facecolors=color, edgecolors=( 1, 0, 0, 0), linewidths=1, alpha=1) ax.add_collection(p) # thin red polygon else: # mask used for refclef rle = ann['segmentation'] m = mask.decode(rle) img = np.ones((m.shape[0], m.shape[1], 3)) color_mask = np.array([2.0, 166.0, 101.0])/255 for i in range(3): img[:, :, i] = color_mask[i] ax.imshow(np.dstack((img, m*0.5))) # show bounding-box elif seg_box == 'box': ann_id = ref['ann_id'] ann = self.Anns[ann_id] bbox = self.getRefBox(ref['ref_id']) box_plot = Rectangle( (bbox[0], bbox[1]), bbox[2], bbox[3], fill=False, edgecolor='green', linewidth=3) ax.add_patch(box_plot) def getMask(self, ref): # return mask, area and mask-center ann = self.refToAnn[ref['ref_id']] image = self.Imgs[ref['image_id']] if type(ann['segmentation'][0]) == list: # polygon rle = mask.frPyObjects( ann['segmentation'], image['height'], image['width']) else: rle = ann['segmentation'] m 
= mask.decode(rle) # sometimes there are multiple binary map (corresponding to multiple segs) m = np.sum(m, axis=2) m = m.astype(np.uint8) # convert to np.uint8 # compute area area = sum(mask.area(rle)) # should be close to ann['area'] return {'mask': m, 'area': area} # # position # position_x = np.mean(np.where(m==1)[1]) # [1] means columns (matlab style) -> x (c style) # position_y = np.mean(np.where(m==1)[0]) # [0] means rows (matlab style) -> y (c style) # # mass position (if there were multiple regions, we use the largest one.) # label_m = label(m, connectivity=m.ndim) # regions = regionprops(label_m) # if len(regions) > 0: # largest_id = np.argmax(np.array([props.filled_area for props in regions])) # largest_props = regions[largest_id] # mass_y, mass_x = largest_props.centroid # else: # mass_x, mass_y = position_x, position_y # # if centroid is not in mask, we find the closest point to it from mask # if m[mass_y, mass_x] != 1: # print 'Finding closes mask point ...' # kernel = np.ones((10, 10),np.uint8) # me = cv2.erode(m, kernel, iterations = 1) # points = zip(np.where(me == 1)[0].tolist(), np.where(me == 1)[1].tolist()) # row, col style # points = np.array(points) # dist = np.sum((points - (mass_y, mass_x))**2, axis=1) # id = np.argsort(dist)[0] # mass_y, mass_x = points[id] # # return # return {'mask': m, 'area': area, 'position_x': position_x, 'position_y': position_y, 'mass_x': mass_x, 'mass_y': mass_y} # # show image and mask # I = io.imread(osp.join(self.IMAGE_DIR, image['file_name'])) # plt.figure() # plt.imshow(I) # ax = plt.gca() # img = np.ones( (m.shape[0], m.shape[1], 3) ) # color_mask = np.array([2.0,166.0,101.0])/255 # for i in range(3): # img[:,:,i] = color_mask[i] # ax.imshow(np.dstack( (img, m*0.5) )) # plt.show() def showMask(self, ref): M = self.getMask(ref) msk = M['mask'] ax = plt.gca() ax.imshow(msk) if __name__ == '__main__': refer = REFER(data_root='/home/xueyanz/code/dataset/refcocoseg', dataset='refcocog', splitBy='google') ref_ids 
= refer.getRefIds() print(len(ref_ids)) print(len(refer.Imgs)) print(len(refer.imgToRefs)) ref_ids = refer.getRefIds(split='train') print('There are {} training referred objects.' % len(ref_ids)) for ref_id in ref_ids: ref = refer.loadRefs(ref_id)[0] if len(ref['sentences']) < 2: continue pprint(ref) print('The label is {}.'.format(refer.Cats[ref['category_id']])) # plt.figure() # refer.showRef(ref, seg_box='box') # plt.show() # plt.figure() # refer.showMask(ref) # plt.show() ================================================ FILE: datasets_os/registration/__init__.py ================================================ # Copyright (c) Facebook, Inc. and its affiliates. from . import ( register_coco_panoptic_annos_grounding_interactive, register_coco_instruct_grounding_dataset, register_flickr_dataset, # register_vg_dataset, ) ================================================ FILE: datasets_os/registration/register_coco_instruct_grounding_dataset.py ================================================ # Copyright (c) Facebook, Inc. and its affiliates. 
# -------------------------------------------------------- # X-Decoder -- Generalized Decoding for Pixel, Image, and Language # Copyright (c) 2022 Microsoft # Licensed under The MIT License [see LICENSE for details] # Modified by Xueyan Zou (xueyan@cs.wisc.edu) # -------------------------------------------------------- import json import os import collections from detectron2.data import DatasetCatalog, MetadataCatalog from detectron2.data.datasets import load_sem_seg from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES from detectron2.utils.file_io import PathManager import pycocotools.mask as mask_util _PREDEFINED_SPLITS_COCO_PANOPTIC_CAPTION = { "coco_instruct_train_v3": ( "coco/train2014", # image_root "coco/annotations/instances_train2017_gvc.json", # annot_root "llava/annotations/grounded_visual_chat_data.json", ), "coco_interactive": ( "coco/train2014", # image_root "coco/annotations/instances_train2014_filter.json", # annot_root "llava/annotations/llava_instruct_150k_visual_prompt.json", ), "coco_interactive_refcoco": ( "coco/train2017", # image_root "coco/annotations/instances_train2017_refcoco.json", # annot_root "coco/annotations/grounding_train2017_instruct.json", ), } def get_metadata(): meta = {} return meta def load_coco_json(image_root, annot_json,conversation, metadata): """ Args: image_dir (str): path to the raw dataset. e.g., "~/coco/train2017". gt_dir (str): path to the raw annotations. e.g., "~/coco/panoptic_train2017". json_file (str): path to the json file. e.g., "~/coco/annotations/panoptic_train2017.json". Returns: list[dict]: a list of dicts in Detectron2 standard format. 
(See `Using Custom Datasets `_ ) """ with PathManager.open(annot_json) as f: json_info = json.load(f) # build dictionary for grounding grd_dict = collections.defaultdict(list) imgid2image = {} for image in json_info["images"]: image_id = image["id"] imgid2image[image_id] = image for grd_ann in json_info['annotations']: image_id = int(grd_ann["image_id"]) segm = grd_ann.get("segmentation", None) if segm: # either list[list[float]] or dict(RLE) if isinstance(segm, dict): if isinstance(segm["counts"], list): # convert to compressed RLE segm = mask_util.frPyObjects(segm, *segm["size"]) grd_ann["segmentation"] = segm grd_dict[image_id].append(grd_ann) conv_dict = collections.defaultdict(list) with open(conversation) as f: data = json.load(f) for d in data: image_id = int(d['id']) if 'gd_ls' not in d: d['gd_ls']=None if 'q_gd_ls' in d: conv_dict[image_id].append((d['conversations'],d['q_gd_ls'])) else: conv_dict[image_id].append((d['conversations'], d['gd_ls'])) ret = [] for d in data: image_id = int(d['id']) image= imgid2image[image_id] image_file = os.path.join(image_root, image['file_name']) grounding_anno = grd_dict[image_id] if image_id in conv_dict and len(conv_dict[image_id])>0: ret.append( { "file_name": image_file, "image_id": image_id, "grounding_info": grounding_anno, "conversations": conv_dict[image_id], } ) assert len(ret), f"No images found in {image_root}!" 
assert PathManager.isfile(ret[0]["file_name"]), ret[0]["file_name"] return ret def register_coco( name, metadata, image_root, annot_json,conversation): DatasetCatalog.register( name, lambda: load_coco_json(image_root, annot_json,conversation, metadata), ) MetadataCatalog.get(name).set( image_root=image_root, json_file=annot_json, evaluator_type="grounding_refcoco", ignore_label=255, label_divisor=1000, **metadata, ) def register_all_coco(root): for ( prefix, (image_root, annot_root,conversation_path), ) in _PREDEFINED_SPLITS_COCO_PANOPTIC_CAPTION.items(): register_coco( prefix, get_metadata(), os.path.join(root, image_root), os.path.join(root, annot_root), conversation_path, ) _root = os.getenv("DATASET", "datasets") register_all_coco(_root) ================================================ FILE: datasets_os/registration/register_coco_panoptic_annos_grounding_interactive.py ================================================ # Copyright (c) Facebook, Inc. and its affiliates. import json import os import collections from detectron2.data import DatasetCatalog, MetadataCatalog from detectron2.data.datasets import load_sem_seg from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES from detectron2.utils.file_io import PathManager _PREDEFINED_SPLITS_COCO_PANOPTIC_CAPTION = { "coco_2017_train_panoptic_ref_full": ( # This is the original panoptic annotation directory "coco/panoptic_train2017", "coco/annotations/panoptic_train2017_filter.json", "coco/panoptic_semseg_train2017", "coco/annotations/grounding_train2017.json", ), } def get_metadata(): meta = {} # The following metadata maps contiguous id from [0, #thing categories + # #stuff categories) to their names and colors. We have to replica of the # same name and color under "thing_*" and "stuff_*" because the current # visualization function in D2 handles thing and class classes differently # due to some heuristic used in Panoptic FPN. We keep the same naming to # enable reusing existing visualization functions. 
thing_classes = [k["name"] for k in COCO_CATEGORIES if k["isthing"] == 1] thing_colors = [k["color"] for k in COCO_CATEGORIES if k["isthing"] == 1] stuff_classes = [k["name"] for k in COCO_CATEGORIES] stuff_colors = [k["color"] for k in COCO_CATEGORIES] meta["thing_classes"] = thing_classes meta["thing_colors"] = thing_colors meta["stuff_classes"] = stuff_classes meta["stuff_colors"] = stuff_colors # Convert category id for training: # category id: like semantic segmentation, it is the class id for each # pixel. Since there are some classes not used in evaluation, the category # id is not always contiguous and thus we have two set of category ids: # - original category id: category id in the original dataset, mainly # used for evaluation. # - contiguous category id: [0, #classes), in order to train the linear # softmax classifier. thing_dataset_id_to_contiguous_id = {} stuff_dataset_id_to_contiguous_id = {} for i, cat in enumerate(COCO_CATEGORIES): if cat["isthing"]: thing_dataset_id_to_contiguous_id[cat["id"]] = i # else: # stuff_dataset_id_to_contiguous_id[cat["id"]] = i # in order to use sem_seg evaluator stuff_dataset_id_to_contiguous_id[cat["id"]] = i meta["thing_dataset_id_to_contiguous_id"] = thing_dataset_id_to_contiguous_id meta["stuff_dataset_id_to_contiguous_id"] = stuff_dataset_id_to_contiguous_id return meta def load_coco_panoptic_json(json_file, image_dir, gt_dir, semseg_dir, grounding_file, meta): """ Args: image_dir (str): path to the raw dataset. e.g., "~/coco/train2017". gt_dir (str): path to the raw annotations. e.g., "~/coco/panoptic_train2017". json_file (str): path to the json file. e.g., "~/coco/annotations/panoptic_train2017.json". Returns: list[dict]: a list of dicts in Detectron2 standard format. 
(See `Using Custom Datasets `_ ) """ def _convert_category_id(segment_info, meta): if segment_info["category_id"] in meta["thing_dataset_id_to_contiguous_id"]: segment_info["category_id"] = meta["thing_dataset_id_to_contiguous_id"][ segment_info["category_id"] ] segment_info["isthing"] = True else: segment_info["category_id"] = meta["stuff_dataset_id_to_contiguous_id"][ segment_info["category_id"] ] segment_info["isthing"] = False return segment_info with PathManager.open(json_file) as f: json_info = json.load(f) with PathManager.open(grounding_file) as f: grounding_info = json.load(f) # build dictionary for grounding grd_dict = collections.defaultdict(list) for grd_ann in grounding_info['annotations']: image_id = int(grd_ann["image_id"]) grd_dict[image_id].append(grd_ann) ret = [] for ann in json_info["annotations"]: image_id = int(ann["image_id"]) # TODO: currently we assume image and label has the same filename but # different extension, and images have extension ".jpg" for COCO. Need # to make image extension a user-provided argument if we extend this # function to support other COCO-like datasets. image_file = os.path.join(image_dir, os.path.splitext(ann["file_name"])[0] + ".jpg") label_file = os.path.join(gt_dir, ann["file_name"]) sem_label_file = os.path.join(semseg_dir, ann["file_name"]) segments_info = [_convert_category_id(x, meta) for x in ann["segments_info"]] grounding_anno = grd_dict[image_id] if image_id in grd_dict else [] ret.append( { "file_name": image_file, "image_id": image_id, "grounding_info": grounding_anno, "pan_seg_file_name": label_file, "sem_seg_file_name": sem_label_file, "segments_info": segments_info, } ) assert len(ret), f"No images found in {image_dir}!" 
assert PathManager.isfile(ret[0]["file_name"]), ret[0]["file_name"] assert PathManager.isfile(ret[0]["pan_seg_file_name"]), ret[0]["pan_seg_file_name"] assert PathManager.isfile(ret[0]["sem_seg_file_name"]), ret[0]["sem_seg_file_name"] return ret def register_coco_panoptic_annos_caption_grounding_sem_seg( name, metadata, image_root, panoptic_root, panoptic_json, sem_seg_root, grounding_root, instances_json ): panoptic_name = '_'.join(name.split('_')[0:4]) delattr(MetadataCatalog.get(panoptic_name), "thing_classes") delattr(MetadataCatalog.get(panoptic_name), "thing_colors") MetadataCatalog.get(panoptic_name).set( thing_classes=metadata["thing_classes"], thing_colors=metadata["thing_colors"], # thing_dataset_id_to_contiguous_id=metadata["thing_dataset_id_to_contiguous_id"], ) # the name is "coco_2017_train_panoptic_with_sem_seg" and "coco_2017_val_panoptic_with_sem_seg" semantic_name = name + "_with_sem_seg_caption_grounding" DatasetCatalog.register( semantic_name, lambda: load_coco_panoptic_json(panoptic_json, image_root, panoptic_root, sem_seg_root, grounding_root, metadata), ) MetadataCatalog.get(semantic_name).set( sem_seg_root=sem_seg_root, panoptic_root=panoptic_root, image_root=image_root, panoptic_json=panoptic_json, json_file=instances_json, evaluator_type="coco_panoptic_seg_interactive", ignore_label=255, label_divisor=1000, **metadata, ) def register_all_coco_panoptic_annos_caption_grounding_sem_seg(root): for ( prefix, (panoptic_root, panoptic_json, semantic_root, grounding_root), ) in _PREDEFINED_SPLITS_COCO_PANOPTIC_CAPTION.items(): prefix_instances = '_'.join(prefix.split('_')[0:3]) instances_meta = MetadataCatalog.get(prefix_instances) image_root, instances_json = instances_meta.image_root, instances_meta.json_file # image_root = image_root.replace('datasets', root) register_coco_panoptic_annos_caption_grounding_sem_seg( prefix, get_metadata(), image_root, os.path.join(root, panoptic_root), os.path.join(root, panoptic_json), os.path.join(root, 
semantic_root), os.path.join(root, grounding_root), os.path.join(root, instances_json), ) _root = os.getenv("DATASET", "datasets") register_all_coco_panoptic_annos_caption_grounding_sem_seg(_root) ================================================ FILE: datasets_os/registration/register_flickr_dataset.py ================================================ # Copyright (c) Facebook, Inc. and its affiliates. # -------------------------------------------------------- # X-Decoder -- Generalized Decoding for Pixel, Image, and Language # Copyright (c) 2022 Microsoft # Licensed under The MIT License [see LICENSE for details] # Modified by Xueyan Zou (xueyan@cs.wisc.edu) # -------------------------------------------------------- import json import os import collections from detectron2.data import DatasetCatalog, MetadataCatalog from detectron2.utils.file_io import PathManager _PREDEFINED_SPLITS = { "flickr_val": ( "flickr30k_entities/val", # image_root "final_flickr_separateGT_val.json", # # anno_path ), "flickr_train": ( "flickr30k_entities/train", # image_root "final_flickr_separateGT_train.json", # # anno_path ), } def get_metadata(): meta = {} return meta def load_flickr_json(image_root, annot_json, metadata): with PathManager.open(annot_json) as f: json_info = json.load(f) # build dictionary for grounding grd_dict = collections.defaultdict(list) # for grd_ann in json_info['annotations']: # image_id = int(grd_ann["image_id"]) # grd_dict[image_id].append(grd_ann) for grd_ann in json_info['annotations']: image_id = int(grd_ann["image_id"]) grd_dict[image_id].append(grd_ann) ret = [] for image in json_info["images"]: image_id = int(image["id"]) caption=image['caption'] image_file = os.path.join(image_root, image['file_name']) grounding_anno = grd_dict[image_id] ret.append( { "file_name": image_file, "image_id": image_id, "grounding_info": grounding_anno, "caption": caption, } ) assert len(ret), f"No images found in {image_root}!" 
assert PathManager.isfile(ret[0]["file_name"]), ret[0]["file_name"] return ret def register_flickr( name, metadata, image_root, annot_json): DatasetCatalog.register( name, lambda: load_flickr_json(image_root, annot_json, metadata), ) MetadataCatalog.get(name).set( image_root=image_root, json_file=annot_json, evaluator_type="grounding_refcoco", ignore_label=255, label_divisor=1000, **metadata, ) def register_all_flickr(root,anno_root): for ( prefix, (image_root, anno_path), ) in _PREDEFINED_SPLITS.items(): register_flickr( prefix, get_metadata(), os.path.join(root, image_root), os.path.join(root,anno_root, anno_path), ) _root = os.getenv("DATASET", "datasets") ann_root = os.getenv("Flickr", "flickr30k_entities/annotations") register_all_flickr(_root,ann_root) ================================================ FILE: datasets_os/registration/register_vg_dataset.py ================================================ # Copyright (c) Facebook, Inc. and its affiliates. # -------------------------------------------------------- # X-Decoder -- Generalized Decoding for Pixel, Image, and Language # Copyright (c) 2022 Microsoft # Licensed under The MIT License [see LICENSE for details] # Modified by Xueyan Zou (xueyan@cs.wisc.edu) # -------------------------------------------------------- import json import os import collections from detectron2.data import DatasetCatalog, MetadataCatalog from detectron2.utils.file_io import PathManager _PREDEFINED_SPLITS = { "vg_train": ( "vg/images/", # image_root "train.json", # anno_path ), } def get_metadata(): meta = {} return meta def load_vg_json(image_root, annot_json, metadata): with PathManager.open(annot_json) as f: json_info = json.load(f) # build dictionary for grounding grd_dict = collections.defaultdict(list) for grd_ann in json_info['annotations']: image_id = int(grd_ann["image_id"]) grd_dict[image_id].append(grd_ann) ret = [] for image in json_info["images"]: image_id = int(image["id"]) image_file = os.path.join(image_root, 
image['file_name']) grounding_anno = grd_dict[image_id] ret.append( { "file_name": image_file, "image_id": image_id, "annotations": grounding_anno, } ) assert len(ret), f"No images found in {image_root}!" assert PathManager.isfile(ret[0]["file_name"]), ret[0]["file_name"] return ret def register_vg( name, metadata, image_root, annot_json): DatasetCatalog.register( name, lambda: load_vg_json(image_root, annot_json, metadata), ) MetadataCatalog.get(name).set( image_root=image_root, json_file=annot_json, evaluator_type="grounding_refcoco", ignore_label=255, label_divisor=1000, **metadata, ) def register_all_vg(root,anno_root): for ( prefix, (image_root, anno_path), ) in _PREDEFINED_SPLITS.items(): register_vg( prefix, get_metadata(), os.path.join(root, image_root), os.path.join(root,anno_root, anno_path), ) _root = os.getenv("DATASET", "datasets") anno_root = os.getenv("VG", "vg/annotations/") register_all_vg(_root,anno_root) ================================================ FILE: datasets_os/semseg_loader.py ================================================ from PIL import Image import scipy.io import numpy as np def load_semseg(filename, loader_type): if loader_type == 'PIL': semseg = np.array(Image.open(filename), dtype=np.int) elif loader_type == 'MAT': semseg = scipy.io.loadmat(filename)['LabelMap'] return semseg ================================================ FILE: docs/MODEL_ZOO.md ================================================ # LLaVA-Grounding Checkpoints We will continuously update the model zoo. 
| Model Name | LLM version | Model Config | Weights | |------------|:---------------:|:-------------:|:-----------:| | LLaVA_Grounding_v0_7b | vicuna-v0-7b | [[grounding-module-cfg](https://github.com/UX-Decoder/LLaVA-Grounding/blob/main/configs/openseed/openseed_swint_lang_joint_2st_visual_prompt.yaml), [visual-prompt-module-cfg](https://github.com/UX-Decoder/LLaVA-Grounding/blob/main/configs/semsam/visual_prompt_encoder.yaml)](0.3B in total) | [HuggingFace](https://huggingface.co/Haozhangcx/llava_grounding_gd_vp) | ================================================ FILE: gradio_demo/LLaVA_G_Demo.py ================================================ import gradio as gr import os import cv2 import torch import numpy as np from llava.eval.LLaVA_G_Eval import Evaluator_MM_Inter from llava import conversation as conversation_lib from llava.mm_utils import tokenizer_image_token from llava.constants import DEFAULT_IMAGE_TOKEN def get_image_name(dir_save="./gradio_demo/tmp_files", prefix="click_img_"): import os files = os.listdir(dir_save) file_orders = [int(file.split(".")[0][len(prefix):]) for file in files if file.endswith(".jpg") and file.startswith(prefix)] if len(file_orders) == 0: return os.path.join(dir_save, prefix + "0.jpg") else: return os.path.join(dir_save, prefix + str(max(file_orders) + 1) + ".jpg") def preprocess_multi_conv( sources, tokenizer, has_image = False ): conv = conversation_lib.default_conversation.copy() roles = {"human": conv.roles[0], "gpt": conv.roles[1]} conv.messages = [] for i, source in enumerate(sources): if roles[source[0]["from"]] != conv.roles[0]: # Skip the first one if it is not from human source = source[1:] for j, sentence in enumerate(source): role = roles[sentence["from"]] assert role == conv.roles[j % 2], f"{i}" conv.append_message(role, sentence["value"]) conv_prompt = conv.get_prompt() conv_prompt = "ASSISTANT: ".join(conv_prompt.split("ASSISTANT: ")[:-1]) + "ASSISTANT:" conv_prompt = conv_prompt.replace("", "") conversations 
= [conv_prompt] print("Input Prompt: ", conv_prompt) # Tokenize conversations if has_image: input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0) else: input_ids = tokenizer( conversations, return_tensors="pt", padding="longest", max_length=tokenizer.model_max_length, truncation=True, ).input_ids targets = input_ids.clone() assert conv.sep_style == conversation_lib.SeparatorStyle.TWO return dict( input_ids=input_ids, labels=targets, ) def filter_empty_box_mask(text, boxes_image, masks_image): def extract_text(sentence): # Use regular expression to find and extract the text and number import re pattern = r"| " cleaned_text = re.sub(pattern, '', sentence) return cleaned_text if len(boxes_image) == 0: return text, boxes_image, masks_image else: sub_texts = text.split(" ") sub_texts_filtered = [] boxes_image_filtered = [] masks_image_filtered = [] for box_per_gd, mask_per_gd, text_per_gd in zip(boxes_image, masks_image, sub_texts): text_per_gd += " " ind_nonempty_box = torch.where(box_per_gd.abs().sum(dim=1)>0) if len(ind_nonempty_box[0]) < box_per_gd.shape[0]: # empty box encountered if len(ind_nonempty_box[0]) == 0: text_per_gd = " " + " ".join(extract_text(text_per_gd).split()) sub_texts_filtered.append(text_per_gd) # box is desperated continue else: box_per_gd = box_per_gd[ind_nonempty_box] mask_per_gd = mask_per_gd[ind_nonempty_box] boxes_image_filtered.append(box_per_gd) masks_image_filtered.append(mask_per_gd) sub_texts_filtered.append(text_per_gd) else: boxes_image_filtered.append(box_per_gd) masks_image_filtered.append(mask_per_gd) sub_texts_filtered.append(text_per_gd) sub_texts_filtered.append(sub_texts[-1]) text_filtered = "".join(sub_texts_filtered) return text_filtered, boxes_image_filtered, masks_image_filtered class InferenceDemo(object): def __init__(self, model_path, path_vision_cfg, path_inter_cfg, ) -> None: self.model_backend = Evaluator_MM_Inter( model_path=model_path, 
path_vision_model_cfg=path_vision_cfg, path_inter_model_cfg =path_inter_cfg, ) self.model_backend.data_mapper.preprocess = preprocess_multi_conv def hitory2datadict(self, history, text): def filter_valid_conversations(history): def delete_color(text): import re pattern = re.compile(r'(.*?)', re.DOTALL) clean_text = pattern.sub(r'\1', text) return clean_text valid_conversations = history[3:] valid_conversations = [aa for aa in valid_conversations if not (None in aa)] valid_conversations = [[delete_color(aa[0]), delete_color(aa[1])] for aa in valid_conversations] return valid_conversations valid_conversations = filter_valid_conversations(history) dataset_dict = { "file_name": history[1][0][0], "image_id": 0, "question_id": 0, } dataset_dict['conversations'] = [] for valid_conv in valid_conversations: conv = [ { "from": "human", "value": valid_conv[0] }, { "from": "gpt", "value": valid_conv[1] } ] dataset_dict['conversations'].append([conv, None]) conv = [ { "from": "human", "value": text }, { "from": "gpt", "value": "Placeholder." } ] dataset_dict['conversations'].append([conv, None]) dataset_dict['conversations'][0][0][0]["value"] = DEFAULT_IMAGE_TOKEN + " " + dataset_dict['conversations'][0][0][0]["value"] return dataset_dict def inference(self, data_dict): # TODO: Implement data_mapper. 
data_dict = self.model_backend.data_mapper(data_dict)[0] # device = self.model_backend.model.device for key, value in data_dict.items(): if isinstance(value, torch.Tensor): data_dict[key] = value.to(device) response_text, response_boxes, response_mask, mask_inter = self.model_backend.evaluate_sample([data_dict]) # response_text, response_boxes, response_mask = filter_empty_box_mask(response_text, response_boxes, response_mask) return response_text, response_boxes, response_mask, mask_inter def generate_distinct_colors(count): import colorsys import random random.seed(0) hues = [i/count for i in range(count)] random.shuffle(hues) colors = [] for hue in hues: rgb = colorsys.hsv_to_rgb(hue, 1, 1) rgb = tuple(int(val * 255) for val in rgb) colors.append(rgb) return colors def add_text(history, text, image, threshold_slider, temporature_slider, interaction_selector): # add a text to history stream. and leave the response as None for you to fill in bot. def response2stream(response, question): return [[question, response]] def post_process_text_response(text): def find_start_idxes(sentence, word): window_size = len(word) start_indexes = [] assert len(sentence) > window_size if sentence == window_size: return [0] for start_index in range(len(sentence) - window_size): if sentence[start_index: start_index + window_size] == word: start_indexes.append(start_index) return start_indexes def add_color_to_text(obj_id, text): color = colors[obj_id] text = f"{text}" return text def format_sentence(splitted_sentence): joint_sentence = " ".join(splitted_sentence) return joint_sentence def extract_text(sentence): import re pattern = r"|" cleaned_text = re.sub(pattern, '', sentence) return cleaned_text text_pure = "" seg_start_index = find_start_idxes(text, "") if len(seg_start_index) > 0: count_obj = 0 subtexts = text.split(" ") for subtext in subtexts: if "" in subtext: start_idx = find_start_idxes(subtext, "")[0] text_pure = format_sentence([text_pure, 
format_sentence(subtext[:start_idx].split())]) text_ = extract_text(subtext[start_idx:]) text_pure += add_color_to_text(count_obj, text_) count_obj += 1 else: text_pure = format_sentence([text_pure, format_sentence(subtext.split())]) else: text_pure = text return text_pure def post_process_gd_response(path_ori_image, gd_results_per_image): def unresize_box(box, width, height): ratio = min(width, height) / max(width, height) if width > height: # then the height dimension is padded, the y coordinates should be divided by ratio box[:, 1] = box[:, 1] / ratio box[:, 3] = box[:, 3] / ratio elif width < height: # then the height dimension is padded, the y coordinates should be divided by ratio box[:, 0] = box[:, 0] / ratio box[:, 2] = box[:, 2] / ratio return box image = cv2.imread(path_ori_image) height, width = image.shape[:2] gd_results_per_image = [unresize_box(aa.detach().cpu(), width, height) for aa in gd_results_per_image] for gd_id, gd_result in enumerate(gd_results_per_image): bboxes = gd_result.cpu().tolist() for bbox in bboxes: bbox = [int(bbox[0]*width), int(bbox[1]*height), int(bbox[2]*width), int(bbox[3]*height)] cv2.rectangle(image, (bbox[0], bbox[1]), (bbox[2], bbox[3]), colors[gd_id][::-1], 2) path_save = get_image_name(prefix="grounding_img_") cv2.imwrite(path_save, image) return (path_save, ) def post_process_masks(path_ori_image, mask_inter, path_gd_image, masks_gd, loc_inter, inter_type): def unresize_mask(mask, width, height): import torch.nn.functional as F if width >= height: # then the height dimension is padded, the y coordinates should be divided by ratio mask = F.interpolate((mask[None, ...]).float(), size=[width, width], mode="nearest")[0] mask = mask[:, :height] elif width < height: # then the height dimension is padded, the y coordinates should be divided by ratio mask = F.interpolate((mask[None, ...]).float(), size=[height, height], mode="nearest")[0] mask = mask[:, :, :width] return mask def unnormalize_inter(mask, loc): height, width, _ = 
mask.shape loc_x_mean, loc_y_mean, loc_w, loc_h = loc if height >= width: loc_x_mean = loc_x_mean / (width/height) loc_w = loc_w / (width/height) else: loc_y_mean = loc_y_mean / (height/width) loc_h = loc_h / (height/width) return [loc_x_mean-loc_w/2, loc_y_mean-loc_h/2, loc_x_mean+loc_w/2, loc_y_mean+loc_h/2] image = cv2.imread(path_ori_image) gd_image = cv2.imread(path_gd_image) returns = [] if not (mask_inter is None): mask_ = (mask_inter[0][..., None] > 0).float().cpu().numpy() mask_ = cv2.resize(mask_, (max(image.shape[0], image.shape[1]), max(image.shape[0], image.shape[1]))) mask_ = mask_[:image.shape[0], :image.shape[1]] mask_ = mask_[..., None] * np.array([155, 155, 155])[None, None, :] image = (image * 0.5 + mask_ * 0.5).astype(np.uint8) if inter_type.lower() == "box": loc_inter_unnormalized = [unnormalize_inter(mask_, loc_inter[0])] thickness = max(int(max(image.shape) / 1000 * 5), 1) print(thickness) cv2.rectangle(image, (int(loc_inter_unnormalized[0][0]*image.shape[1]), int(loc_inter_unnormalized[0][1]*image.shape[0])), (int((loc_inter_unnormalized[0][2])*image.shape[1]), int((loc_inter_unnormalized[0][3])*image.shape[0])), (255, 255, 255), thickness) elif inter_type.lower() == "click": loc_inter_unnormalized = [unnormalize_inter(mask_, loc_inter[0])] thickness = max(int(max(image.shape) / 1000 * 10), 1) print(thickness) cv2.circle(image, (int(((loc_inter_unnormalized[0][0] + loc_inter_unnormalized[0][2])/2)*image.shape[1]), int(((loc_inter_unnormalized[0][1] + loc_inter_unnormalized[0][3])/2)*image.shape[0])), thickness, (255, 255, 255), thickness=-1) path_save = get_image_name(prefix="seg_inter_") cv2.imwrite(path_save, image) returns.append((path_save, )) else: returns.append(None) if not (masks_gd is None): height, width = image.shape[:2] masks_gd = [unresize_mask(aa, width, height) for aa in masks_gd] colored_mask = torch.zeros((3, height, width), dtype=torch.long) for gd_id, gd_mask_result in enumerate(masks_gd): gd_mask_result = 
gd_mask_result.sum(dim=0, keepdim=True) colored_mask[:, gd_mask_result[0] > 0.5] = torch.tensor(colors[gd_id][::-1])[:, None] gd_image = (gd_image * 0.6 + colored_mask.permute(1,2,0).numpy() * 0.4).astype(np.uint8) path_save_gd = get_image_name(prefix="seg_gd_") cv2.imwrite(path_save_gd, gd_image) returns.append((path_save_gd, )) else: returns.append(None) return returns def mask2point(mask, inter_type): height, width = mask.shape[:2] ys, xs = np.where(mask[..., 0] == 255) if inter_type.lower() == "click": loc_x = xs.mean() loc_y = ys.mean() loc_x = loc_x / width loc_y = loc_y / height if height >= width: loc_x = loc_x * (width/height) else: loc_y = loc_y * (height/width) return torch.tensor([[loc_x, loc_y, 0.006, 0.006]]) elif inter_type.lower() == "box": loc_x_min = xs.min() / width loc_x_max = xs.max() / width loc_y_min = ys.min() / height loc_y_max = ys.max() / height if height >= width: loc_x_min = loc_x_min * (width/height) loc_x_max = loc_x_max * (width/height) else: loc_y_min = loc_y_min * (height/width) loc_y_max = loc_y_max * (height/width) width = loc_x_max - loc_x_min height = loc_y_max - loc_y_min return torch.tensor([[(loc_x_min + loc_x_max)/2, (loc_y_min + loc_y_max)/2, width, height]]) if len(history) < 3: response_text = "Please upload an image first." 
else: loc_inter = mask2point(image["mask"], interaction_selector) is_interactive = torch.isnan(loc_inter[0]).sum() == 0 if is_interactive: input_data_dict = our_chatbot.hitory2datadict(history, text) input_data_dict["points"] = loc_inter input_data_dict["mode_inter"] = interaction_selector else: input_data_dict = our_chatbot.hitory2datadict(history, text) input_data_dict["points"] = None input_data_dict["mode_inter"] = None input_data_dict["matching_threshold"] = threshold_slider input_data_dict["temporature"] = temporature_slider response_text, response_gd, response_mask, mask_inter = our_chatbot.inference(input_data_dict) response_msks = post_process_masks(history[1][0][0], mask_inter, history[1][0][0], response_mask, loc_inter, interaction_selector) if "" in response_text: response_gd = post_process_gd_response(response_msks[1][0], response_gd) response_msks[1] = list(response_msks[1]) response_msks[1][0] = response_gd[0] response_msks[1] = tuple(response_msks[1]) response_text = post_process_text_response(response_text) history += response2stream(response_text, text) for response_msk in response_msks: if not (response_msk is None): history += response2stream(response_msk, None) return history, None def add_image(history, image): print("LOG. Add Image Function is called.") path_input_img = get_image_name(prefix="tmp_input_img_") cv2.imwrite(path_input_img, image["image"][..., ::-1]) if len(history) > 0: history = [(None, "A new image recieved, I will clear the history conversations.")] else: history = [(None, None)] # just to align with the above one, to determin where the image_path is. history = history + [((path_input_img, ), None)] history = history + [(None, "Let't talk about this image!")] return history def add_interaction_click(history, image, interaction_selector): print("LOG. 
Add Interaction Function is called.") if interaction_selector.lower() == "box": history = history + [(None, "A more detailed box is specified, lets further talk about the region inside the box.")] elif interaction_selector.lower() == "click": history = history + [(None, "A more detailed click is specified, lets further talk about the region around the click.")] mask = image["mask"][..., :3] * np.array([234, 176, 113]) image_rgb = image["image"][..., ::-1] image_clicked = (image_rgb * 0.6 + mask * 0.4).astype(np.uint8) path_save = get_image_name(prefix="click_img_") cv2.imwrite(path_save, image_clicked) return history def bot(history): yield history def clear_history(history, txt, img): return None, None, None def clear_response(history): for index_conv in range(1, len(history)): # loop until get a text response from our model. conv = history[-index_conv] if not (conv[0] is None): break question = history[-index_conv][0] history = history[:-index_conv] return history, question def upvote_one(history): print("TODO: Implement upvote_one function.") pass def downvote_one(history): print("TODO: Implement downvote_one function.") pass def flag_one(history): print("TODO: Implement flag_one function.") pass #? defined here for later renderring. txt = gr.Textbox( scale=4, show_label=False, placeholder="Enter text and press enter, or upload an image. Append '(with grounding)' if you want to do grounding.", container=False, ) with gr.Blocks() as demo: # Informations title_markdown = (""" # LLaVA-Grounding: Grounded Visual Chat with Large Multimodal Models [[Project Page](https://llava-vl.github.io/llava-grounding)] [[Arxiv](https://arxiv.org/abs/2312.02949)] [[Demo](http://llava-grounding.xyzou.net:6084)] [[Model](https://huggingface.co/Haozhangcx/llava_grounding_gd_vp)] """) tips_markdown = (""" **Tips for better results** 1. Adjust 'Threshold' according to the results or change your expression and click 'Regenerate' may help get better results. 2. 
Set temporature to 0.0 get reproducible results. """) tos_markdown = (""" **Terms of use:** By using this service, users are required to agree to the following terms: The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research. Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator. For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality. """) learn_more_markdown = (""" **License:** The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation. 
""") models = [ "LLaVA-Grounding-7B", ] interactions = [ "Click", "Box" ] cur_dir = os.path.dirname(os.path.abspath(__file__)) gr.Markdown(title_markdown) with gr.Row(): with gr.Column(min_width=300, scale=0.4): model_selector = gr.Dropdown( choices=models, value=models[0] if len(models) > 0 else "", interactive=True, show_label=False, container=False) img = gr.Image( type="numpy", # label="Image", height=220, tool="sketch", interactive=True ) img_upload_btn = gr.Button("Submit Image") with gr.Row(): inter_upload_btn = gr.Button("Submit Interaction") interaction_selector = gr.Dropdown( choices=interactions, value=interactions[0] if len(interactions) > 0 else "", interactive=True, show_label=False, container=False, ) with gr.Row(): temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, interactive=True, label="Temperature") threshold = gr.Slider(minimum=0.0, maximum=1.0, value=0.05, step=0.1, interactive=True, label="Threshold") # possibly a bug in gradio: https://github.com/gradio-app/gradio/issues/3623 gr.Examples(examples=[ [f"{cur_dir}/examples/meeting.jpg", "Describe the scene in detail. (with grounding)"], [f"{cur_dir}/examples/pizza.jpg", "Describe the scene in detail. (with grounding)"], ], inputs=[img, txt], label="Grounded Description Examples: ") gr.Examples(examples=[ [f"{cur_dir}/examples/cow_motor.jpg", "Where is the object % and what is it doing?"], [f"{cur_dir}/examples/dog_sleep.jpg", "What is the object % doing and why?"], ], inputs=[img, txt], label="Visual Prompt Examples (Please draw clicks or boxes on the woman and dog for the two examples, respectively.): ") with gr.Column(): chatbot = gr.Chatbot( [], elem_id="chatbot", bubble_full_width=False, height=598 # avatar_images=(None, (os.path.join(os.path.dirname(__file__), "avatar.png"))), ) with gr.Row(): with gr.Column(scale=8): txt.render() with gr.Column(scale=1, min_width=60): submit_btn = gr.Button(value="Send") #TODO: Enable these buttons. 
with gr.Row(): upvote_btn = gr.Button(value="👍 Upvote", interactive=True) downvote_btn = gr.Button(value="👎 Downvote", interactive=True) flag_btn = gr.Button(value="⚠️ Flag", interactive=True) #stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=True) regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=True) clear_btn = gr.Button(value="🗑️ Clear history", interactive=True) gr.Markdown(tips_markdown) gr.Markdown(tos_markdown) gr.Markdown(learn_more_markdown) if os.path.isfile("gradio_demo/examples/demo_grounding.mp4"): # only online demo gr.Markdown("-----------------------------------") gr.Markdown("## User's Guidance") with gr.Row(): with gr.Column(): gr.Markdown("### Grounded Visual Chat") gr.Video(value="gradio_demo/examples/demo_grounding.mp4") with gr.Column(): gr.Markdown("### Visual Prompt (Click)") gr.Video(value="gradio_demo/examples/demo_inter_click.mp4") with gr.Column(): gr.Markdown("### Visual Prompt (Box)") gr.Video(value="gradio_demo/examples/demo_inter_box.mp4") txt.submit(add_text, [chatbot, txt, img, threshold, temperature, interaction_selector], [chatbot, txt], queue=False).then( bot, chatbot, chatbot, api_name="bot_text_response" ) submit_btn.click(fn=add_text, inputs=[chatbot, txt, img, threshold, temperature, interaction_selector], outputs=[chatbot, txt]).then( bot, chatbot, chatbot, api_name="submit_text" ) img_upload_btn.click(fn=add_image, inputs=[chatbot, img], outputs=[chatbot], api_name="upload_image") inter_upload_btn.click(fn=add_interaction_click, inputs=[chatbot, img, interaction_selector], outputs=[chatbot], api_name="upload_inter") # buttons clear_btn.click(fn=clear_history, inputs=[chatbot, txt, img], outputs=[chatbot, txt, img], api_name="clear_all") regenerate_btn.click(fn=clear_response, inputs=[chatbot], outputs=[chatbot, txt], api_name="clear_last_response").then( add_text, [chatbot, txt, img, threshold, temperature, interaction_selector], [chatbot, txt], queue=False).then( bot, chatbot, chatbot, 
api_name="regenerate_response" ) upvote_btn.click(fn=upvote_one, inputs=[], outputs=[], api_name="upvote_one") downvote_btn.click(fn=downvote_one, inputs=[], outputs=[], api_name="downvote_one") flag_btn.click(fn=flag_one, inputs=[], outputs=[], api_name="flag_one") demo.queue() if __name__ == "__main__": import argparse argparser = argparse.ArgumentParser() argparser.add_argument("--server_name", default="0.0.0.0", type=str) argparser.add_argument("--port", default=12124, type=str) argparser.add_argument("--model_path", default="", type=str) argparser.add_argument("--path_vision_cfg", default="configs/openseed/openseed_swint_lang_joint_2st_v2_data_end_with_interaction.yaml", type=str) argparser.add_argument("--path_inter_cfg", default="configs/semsam/idino_swint_1_part_data_llm_ref_feat_all_16_det_pretrainv1.yaml", type=str) args = argparser.parse_args() model_path = args.model_path colors = generate_distinct_colors(20) if not os.path.exists("./gradio_demo/tmp_files"): os.makedirs("./gradio_demo/tmp_files") our_chatbot = InferenceDemo(args.model_path, args.path_vision_cfg, args.path_inter_cfg) demo.launch(server_name=args.server_name, server_port=int(args.port)) ================================================ FILE: gradio_demo/__init__.py ================================================ ================================================ FILE: llava/__init__.py ================================================ from .model import LlavaLlamaForCausalLM ================================================ FILE: llava/constants.py ================================================ CONTROLLER_HEART_BEAT_EXPIRATION = 30 WORKER_HEART_BEAT_INTERVAL = 15 LOGDIR = "." 
# Model Constants IGNORE_INDEX = -100 IMAGE_TOKEN_INDEX = -200 DEFAULT_IMAGE_TOKEN = "" DEFAULT_IMAGE_PATCH_TOKEN = "" DEFAULT_IM_START_TOKEN = "" DEFAULT_IM_END_TOKEN = "" ================================================ FILE: llava/conversation.py ================================================ import dataclasses from enum import auto, Enum from typing import List, Tuple class SeparatorStyle(Enum): """Different separator style.""" SINGLE = auto() TWO = auto() MPT = auto() PLAIN = auto() LLAMA_2 = auto() @dataclasses.dataclass class Conversation: """A class that keeps all conversation history.""" system: str roles: List[str] messages: List[List[str]] offset: int sep_style: SeparatorStyle = SeparatorStyle.SINGLE sep: str = "###" sep2: str = None version: str = "Unknown" skip_next: bool = False def get_prompt(self): messages = self.messages if len(messages) > 0 and type(messages[0][1]) is tuple: messages = self.messages.copy() init_role, init_msg = messages[0].copy() init_msg = init_msg[0].replace("", "").strip() if 'mmtag' in self.version: messages[0] = (init_role, init_msg) messages.insert(0, (self.roles[0], "")) messages.insert(1, (self.roles[1], "Received.")) else: messages[0] = (init_role, "\n" + init_msg) if self.sep_style == SeparatorStyle.SINGLE: ret = self.system + self.sep for role, message in messages: if message: if type(message) is tuple: message, _, _ = message ret += role + ": " + message + self.sep else: ret += role + ":" elif self.sep_style == SeparatorStyle.TWO: seps = [self.sep, self.sep2] ret = self.system + seps[0] for i, (role, message) in enumerate(messages): if message: if type(message) is tuple: message, _, _ = message ret += role + ": " + message + seps[i % 2] else: ret += role + ":" elif self.sep_style == SeparatorStyle.MPT: ret = self.system + self.sep for role, message in messages: if message: if type(message) is tuple: message, _, _ = message ret += role + message + self.sep else: ret += role elif self.sep_style == 
SeparatorStyle.LLAMA_2: wrap_sys = lambda msg: f"<>\n{msg}\n<>\n\n" wrap_inst = lambda msg: f"[INST] {msg} [/INST]" ret = "" for i, (role, message) in enumerate(messages): if i == 0: assert message, "first message should not be none" assert role == self.roles[0], "first message should come from user" if message: if type(message) is tuple: message, _, _ = message if i == 0: message = wrap_sys(self.system) + message if i % 2 == 0: message = wrap_inst(message) ret += self.sep + message else: ret += " " + message + " " + self.sep2 else: ret += "" ret = ret.lstrip(self.sep) elif self.sep_style == SeparatorStyle.PLAIN: seps = [self.sep, self.sep2] ret = self.system for i, (role, message) in enumerate(messages): if message: if type(message) is tuple: message, _, _ = message ret += message + seps[i % 2] else: ret += "" else: raise ValueError(f"Invalid style: {self.sep_style}") return ret def append_message(self, role, message): self.messages.append([role, message]) def get_images(self, return_pil=False): images = [] for i, (role, msg) in enumerate(self.messages[self.offset:]): if i % 2 == 0: if type(msg) is tuple: import base64 from io import BytesIO from PIL import Image msg, image, image_process_mode = msg if image_process_mode == "Pad": def expand2square(pil_img, background_color=(122, 116, 104)): width, height = pil_img.size if width == height: return pil_img elif width > height: result = Image.new(pil_img.mode, (width, width), background_color) result.paste(pil_img, (0, (width - height) // 2)) return result else: result = Image.new(pil_img.mode, (height, height), background_color) result.paste(pil_img, ((height - width) // 2, 0)) return result image = expand2square(image) elif image_process_mode == "Crop": pass elif image_process_mode == "Resize": image = image.resize((336, 336)) else: raise ValueError(f"Invalid image_process_mode: {image_process_mode}") max_hw, min_hw = max(image.size), min(image.size) aspect_ratio = max_hw / min_hw max_len, min_len = 800, 400 
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw)) longest_edge = int(shortest_edge * aspect_ratio) W, H = image.size if H > W: H, W = longest_edge, shortest_edge else: H, W = shortest_edge, longest_edge image = image.resize((W, H)) if return_pil: images.append(image) else: buffered = BytesIO() image.save(buffered, format="PNG") img_b64_str = base64.b64encode(buffered.getvalue()).decode() images.append(img_b64_str) return images def to_gradio_chatbot(self): ret = [] for i, (role, msg) in enumerate(self.messages[self.offset:]): if i % 2 == 0: if type(msg) is tuple: import base64 from io import BytesIO msg, image, image_process_mode = msg max_hw, min_hw = max(image.size), min(image.size) aspect_ratio = max_hw / min_hw max_len, min_len = 800, 400 shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw)) longest_edge = int(shortest_edge * aspect_ratio) W, H = image.size if H > W: H, W = longest_edge, shortest_edge else: H, W = shortest_edge, longest_edge image = image.resize((W, H)) buffered = BytesIO() image.save(buffered, format="JPEG") img_b64_str = base64.b64encode(buffered.getvalue()).decode() img_str = f'' ret.append([img_str, None]) msg = msg.replace('', '').strip() if len(msg) > 0: ret.append([msg, None]) else: ret.append([msg, None]) else: ret[-1][-1] = msg return ret def copy(self): return Conversation( system=self.system, roles=self.roles, messages=[[x, y] for x, y in self.messages], offset=self.offset, sep_style=self.sep_style, sep=self.sep, sep2=self.sep2, version=self.version) def dict(self): if len(self.get_images()) > 0: return { "system": self.system, "roles": self.roles, "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages], "offset": self.offset, "sep": self.sep, "sep2": self.sep2, } return { "system": self.system, "roles": self.roles, "messages": self.messages, "offset": self.offset, "sep": self.sep, "sep2": self.sep2, } conv_vicuna_v0 = Conversation( system="A chat between a curious human and an 
artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the human's questions.", roles=("Human", "Assistant"), messages=( ("Human", "What are the key differences between renewable and non-renewable energy sources?"), ("Assistant", "Renewable energy sources are those that can be replenished naturally in a relatively " "short amount of time, such as solar, wind, hydro, geothermal, and biomass. " "Non-renewable energy sources, on the other hand, are finite and will eventually be " "depleted, such as coal, oil, and natural gas. Here are some key differences between " "renewable and non-renewable energy sources:\n" "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable " "energy sources are finite and will eventually run out.\n" "2. Environmental impact: Renewable energy sources have a much lower environmental impact " "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, " "and other negative effects.\n" "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically " "have lower operational costs than non-renewable sources.\n" "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote " "locations than non-renewable sources.\n" "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different " "situations and needs, while non-renewable sources are more rigid and inflexible.\n" "6. Sustainability: Renewable energy sources are more sustainable over the long term, while " "non-renewable sources are not, and their depletion can lead to economic and social instability.\n") ), offset=2, sep_style=SeparatorStyle.SINGLE, sep="###", ) conv_vicuna_v1 = Conversation( system="A chat between a curious user and an artificial intelligence assistant. 
" "The assistant gives helpful, detailed, and polite answers to the user's questions.", roles=("USER", "ASSISTANT"), version="v1", messages=(), offset=0, sep_style=SeparatorStyle.TWO, sep=" ", sep2="", ) conv_llama_2 = Conversation( system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""", roles=("USER", "ASSISTANT"), version="llama_v2", messages=(), offset=0, sep_style=SeparatorStyle.LLAMA_2, sep="", sep2="", ) conv_llava_llama_2 = Conversation( system="You are a helpful language and vision assistant. " "You are able to understand the visual content that the user provides, " "and assist the user with a variety of tasks using natural language.", roles=("USER", "ASSISTANT"), version="llama_v2", messages=(), offset=0, sep_style=SeparatorStyle.LLAMA_2, sep="", sep2="", ) conv_mpt = Conversation( system="""<|im_start|>system A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""", roles=("<|im_start|>user\n", "<|im_start|>assistant\n"), version="mpt", messages=(), offset=0, sep_style=SeparatorStyle.MPT, sep="<|im_end|>", ) conv_llava_plain = Conversation( system="", roles=("", ""), messages=( ), offset=0, sep_style=SeparatorStyle.PLAIN, sep="\n", ) conv_llava_v0 = Conversation( system="A chat between a curious human and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the human's questions.", roles=("Human", "Assistant"), messages=( ("Human", "Hi!"), ("Assistant", "Hi there! 
How can I help you today?") ), offset=2, sep_style=SeparatorStyle.SINGLE, sep="###", ) conv_llava_v0_mmtag = Conversation( system="A chat between a curious user and an artificial intelligence assistant. " "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language." "The visual content will be provided with the following format: visual content.", roles=("Human", "Assistant"), messages=( ), offset=0, sep_style=SeparatorStyle.SINGLE, sep="###", version="v0_mmtag", ) conv_llava_v1 = Conversation( system="A chat between a curious human and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the human's questions.", roles=("USER", "ASSISTANT"), version="v1", messages=(), offset=0, sep_style=SeparatorStyle.TWO, sep=" ", sep2="", ) conv_llava_v1_mmtag = Conversation( system="A chat between a curious user and an artificial intelligence assistant. " "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language." 
"The visual content will be provided with the following format: visual content.", roles=("USER", "ASSISTANT"), messages=(), offset=0, sep_style=SeparatorStyle.TWO, sep=" ", sep2="", version="v1_mmtag", ) default_conversation = conv_vicuna_v0 conv_templates = { "default": conv_vicuna_v0, "v0": conv_vicuna_v0, "v1": conv_vicuna_v1, "vicuna_v1": conv_vicuna_v1, "llama_2": conv_llama_2, "plain": conv_llava_plain, "v0_plain": conv_llava_plain, "llava_v0": conv_llava_v0, "v0_mmtag": conv_llava_v0_mmtag, "llava_v1": conv_llava_v1, "v1_mmtag": conv_llava_v1_mmtag, "llava_llama_2": conv_llava_llama_2, "mpt": conv_mpt, } if __name__ == "__main__": print(default_conversation.get_prompt()) ================================================ FILE: llava/eval/LLaVA_G_Eval.py ================================================ import os import cv2 import json import torch import collections import transformers import numpy as np from llava.model import * from typing import Dict from llava import conversation as conversation_lib from tqdm import tqdm from detectron2.utils.file_io import PathManager from llava.mm_utils import tokenizer_image_token from llava.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN from llava.eval.llava_mapper import COCOInstanceNewBaselineDatasetMapper as LLAVAInstanceNewBaselineDatasetMapper grounding_start="" grounding_end="" SEG_TOKEN="" BOX_TOKEN="#B#" MARKER_TOKEN="#M#" def load_jsonl_file(path_jsonl): import jsonlines data = [] with jsonlines.open(path_jsonl, "r") as reader: for obj in reader: data.append(obj) return data def save_jsonl_file(data, path_save): import jsonlines with jsonlines.open(path_save, "w") as writer: for item in data: writer.write(item) def load_benchmark(image_root, path_benchmark): data = load_jsonl_file(path_benchmark) ret = [] for d in data: image_name = d["image"] image_id = int(image_name.split(".")[0]) image_file = os.path.join(image_root, "COCO_val2014_" + image_name) # conv = d["conversations"] conv = [ { "from": "human", 
"value": d["text"] }, { "from": "gpt", "value": "Placeholder." } ] conv[0]["value"] = DEFAULT_IMAGE_TOKEN + " " + conv[0]["value"] + " (with grounding)" ret.append( { "file_name": image_file, "image_id": image_id, # "grounding_info": None, "conversations": [[conv, None]], "question_id": d["question_id"] } ) assert len(ret), f"No images found in {image_root}!" assert PathManager.isfile(ret[0]["file_name"]), ret[0]["file_name"] return ret def preprocess_v1( sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False ) -> Dict: conv = conversation_lib.default_conversation.copy() roles = {"human": conv.roles[0], "gpt": conv.roles[1]} # Apply prompt templates conversations = [] for i, source in enumerate(sources): if roles[source[0]["from"]] != conv.roles[0]: # Skip the first one if it is not from human source = source[1:] conv.messages = [] for j, sentence in enumerate(source): role = roles[sentence["from"]] assert role == conv.roles[j % 2], f"{i}" conv.append_message(role, sentence["value"]) conv_prompt = conv.get_prompt() conv_prompt = conv_prompt.split("ASSISTANT: ")[0] + "ASSISTANT:" conversations.append(conv_prompt) # Tokenize conversations if has_image: input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0) else: input_ids = tokenizer( conversations, return_tensors="pt", padding="longest", max_length=tokenizer.model_max_length, truncation=True, ).input_ids targets = input_ids.clone() assert conv.sep_style == conversation_lib.SeparatorStyle.TWO # Mask targets sep = conv.sep + conv.roles[1] + ": " for conversation, target in zip(conversations, targets): total_len = int(target.ne(tokenizer.pad_token_id).sum()) rounds = conversation.split(conv.sep2) cur_len = 1 target[:cur_len] = IGNORE_INDEX for i, rou in enumerate(rounds): if rou == "": break parts = rou.split(sep) if len(parts) != 2: break parts[0] += sep if has_image: round_len = len(tokenizer_image_token(rou, tokenizer)) 
instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2 else: round_len = len(tokenizer(rou).input_ids) instruction_len = len(tokenizer(parts[0]).input_ids) - 2 target[cur_len : cur_len + instruction_len] = IGNORE_INDEX cur_len += round_len target[cur_len:] = IGNORE_INDEX if cur_len < tokenizer.model_max_length: if cur_len != total_len: target[:] = IGNORE_INDEX print( f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." f" (ignored)" ) return dict( input_ids=input_ids, labels=targets, ) class Evaluator_MM: def __init__(self, model_path, path_vision_model_cfg=None): model_paths = model_path.split("/") if model_paths[-1].startswith('checkpoint-'): self.model_name = model_paths[-2] + "_" + model_paths[-1] else: self.model_name = model_paths[-1] print("1. Constructing model...") self.tokenizer, self.model, self.image_processor, self.context_len = self.construct_model( model_path=model_path, model_name=self.model_name, ) print(" Continue...") self.construct_vision_model(path_vision_model_cfg) print("Done.") self.image_processor=self.model.get_vision_tower().image_processor print("2. 
Loading Parameters...") self.load_parameters(model_path) print("Done.") self.model.eval() conversation_lib.default_conversation = conversation_lib.conv_templates["v1"] self.data_mapper = LLAVAInstanceNewBaselineDatasetMapper(self.cfg_vision_model, False, tokenizer=self.tokenizer, image_processor=self.image_processor, preprocess=preprocess_v1) def construct_model(self, model_path, model_base=None, model_name=None, load_8bit=False, load_4bit=False, device_map="auto"): import os import shutil from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN kwargs = {"device_map": device_map} if load_8bit: kwargs['load_in_8bit'] = True elif load_4bit: kwargs['load_in_4bit'] = True kwargs['quantization_config'] = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type='nf4' ) else: kwargs['torch_dtype'] = torch.float16 if 'llava' in model_name.lower(): # Load LLaVA model if 'lora' in model_name.lower() and model_base is not None: lora_cfg_pretrained = AutoConfig.from_pretrained(model_path) tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) print('Loading LLaVA from base model...') model = LlavaLlamaForCausalLM_gd.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs) token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features if model.lm_head.weight.shape[0] != token_num: model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype)) model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype)) print('Loading additional LLaVA weights...') if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')): non_lora_trainables = torch.load(os.path.join(model_path, 
'non_lora_trainables.bin'), map_location='cpu') else: # this is probably from HF Hub from huggingface_hub import hf_hub_download def load_from_hf(repo_id, filename, subfolder=None): cache_file = hf_hub_download( repo_id=repo_id, filename=filename, subfolder=subfolder) return torch.load(cache_file, map_location='cpu') non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin') non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()} if any(k.startswith('model.model.') for k in non_lora_trainables): non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()} model.load_state_dict(non_lora_trainables, strict=False) from peft import PeftModel print('Loading LoRA weights...') model = PeftModel.from_pretrained(model, model_path) print('Merging LoRA weights...') model = model.merge_and_unload() print('Model is loaded...') elif model_base is not None: # this may be mm projector only print('Loading LLaVA from base model...') if 'mpt' in model_name.lower(): if not os.path.isfile(os.path.join(model_path, 'configuration_mpt.py')): shutil.copyfile(os.path.join(model_base, 'configuration_mpt.py'), os.path.join(model_path, 'configuration_mpt.py')) tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True) cfg_pretrained = AutoConfig.from_pretrained(model_path, trust_remote_code=True) model = LlavaLlamaForCausalLM_gd.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs) else: tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) cfg_pretrained = AutoConfig.from_pretrained(model_path) model = LlavaLlamaForCausalLM_gd.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs) mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu') mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()} 
model.load_state_dict(mm_projector_weights, strict=False) else: if 'mpt' in model_name.lower(): tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True) model = LlavaLlamaForCausalLM_gd.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) else: tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) model = LlavaLlamaForCausalLM_gd.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) else: # Load language model if model_base is not None: # PEFT model from peft import PeftModel tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) model = AutoModelForCausalLM.from_pretrained(model_base, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="auto") print(f"Loading LoRA weights from {model_path}") model = PeftModel.from_pretrained(model, model_path) print(f"Merging weights") model = model.merge_and_unload() print('Convert to FP16...') model.to(torch.float16) else: use_fast = False if 'mpt' in model_name.lower(): tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True) model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs) else: tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) image_processor = None if 'llava' in model_name.lower(): mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False) mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True) if mm_use_im_patch_token: tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) if mm_use_im_start_end: tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) model.resize_token_embeddings(len(tokenizer)) vision_tower = model.get_vision_tower() if not vision_tower.is_loaded: vision_tower.load_model() vision_tower.to(device='cuda', dtype=torch.float16) image_processor = 
vision_tower.image_processor if hasattr(model.config, "max_sequence_length"): context_len = model.config.max_sequence_length else: context_len = 2048 return tokenizer, model, image_processor, context_len def construct_vision_model(self, path_vision_model_cfg): from detectron2.config import LazyConfig from llava.model.openseed import build_model from llava.model.openseed.BaseModel import BaseModel def get_config_from_name(cfg, dataset_name="flickr"): # adjust config according to dataset, flickr by default if 'sam' in dataset_name: cfg.update(cfg['SAM']) return cfg elif 'flickr' in dataset_name: cfg.update(cfg['flickr']) return cfg elif 'coco_instruct_train' in dataset_name: cfg.update(cfg['coco_instruct']) return cfg elif 'lisa' in dataset_name: cfg.update(cfg['LISA_REF']) return cfg elif 'llava' in dataset_name: cfg.update(cfg['llava']) return cfg elif 'vg' in dataset_name: cfg.update(cfg['vg']) return cfg elif 'part' in dataset_name and 'pascal_part' not in dataset_name and 'partimagenet' not in dataset_name: cfg.update(cfg['part']) return cfg elif 'pascal' in dataset_name or 'paco' in dataset_name or 'partimagenet' in dataset_name : cfg.update(cfg['PSACAL_PART']) return cfg elif 'coco' in dataset_name and 'refonly' in dataset_name: # if 'COCO' in cfg.keys(): cfg.update(cfg['COCO_REF']) return cfg elif 'refcoco' in dataset_name or "flickr_val" in dataset_name: cfg.update(cfg['REF']) return cfg elif 'coco' in dataset_name: if 'COCO' in cfg.keys(): cfg.update(cfg['COCO']) return cfg elif "mapillary" in dataset_name: if 'MAPILLARY' in cfg.keys(): cfg.update(cfg['MAPILLARY']) return cfg elif 'ade' in dataset_name: if 'ADE20K' in cfg.keys(): cfg.update(cfg['ADE20K']) return cfg elif 'imagenet' in dataset_name: if 'IMAGENET' in cfg.keys(): cfg.update(cfg['IMAGENET']) return cfg elif 'vlp' in dataset_name: cfg.update(cfg['VLP']) return cfg elif 'sun' in dataset_name: cfg.update(cfg['SUN']) return cfg elif 'object365' in dataset_name: cfg.update(cfg['OBJECT365']) return 
cfg elif 'scan' in dataset_name: cfg.update(cfg['SCAN']) return cfg elif 'cityscape' in dataset_name: cfg.update(cfg['CITY']) return cfg elif 'bdd' in dataset_name: cfg.update(cfg['BDD']) return cfg else: assert False, "dataset not support." self.cfg_vision_model = LazyConfig.load(path_vision_model_cfg) vision_model = BaseModel(self.cfg_vision_model, build_model(self.cfg_vision_model)) vision_model.eval() self.model.seg_model = vision_model self.model.seg_model.model = self.model.seg_model.model.to(self.model.device) # print("Configuring for Dataset Mapper ...") self.cfg_vision_model = get_config_from_name(self.cfg_vision_model) def load_parameters(self, path_model): print("Loading Whole Model ...") loaded_dict = dict() for model_file in os.listdir(path_model): if model_file.endswith('.bin') and model_file.startswith('pytorch_model'): loaded_dict.update(torch.load(os.path.join(path_model, model_file), map_location='cpu')) self.model.load_state_dict(loaded_dict, strict=True) @torch.inference_mode() def evaluate_sample(self, input_data, get_box=True, get_mask=False): text, boxes, masks = self.model.forward_eval(input_data) returns = [text,] if get_box: returns.append(boxes) if get_mask: returns.append(masks) return returns class Evaluator_MM_Inter(Evaluator_MM): def __init__(self, model_path, path_vision_model_cfg=None, path_inter_model_cfg=None): self.path_inter_model_cfg = path_inter_model_cfg super().__init__(model_path, path_vision_model_cfg) def construct_model(self, model_path, model_base=None, model_name=None, load_8bit=False, load_4bit=False, device_map="auto"): import os import shutil from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN kwargs = {"device_map": device_map} if load_8bit: kwargs['load_in_8bit'] = True elif load_4bit: kwargs['load_in_4bit'] = True kwargs['quantization_config'] = BitsAndBytesConfig( 
load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type='nf4' ) else: kwargs['torch_dtype'] = torch.float16 if 'llava' in model_name.lower(): # Load LLaVA model if 'lora' in model_name.lower() and model_base is not None: lora_cfg_pretrained = AutoConfig.from_pretrained(model_path) tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) print('Loading LLaVA from base model...') model = LlavaLlamaForCausalLM_joint_2st_it_only_ref_instr.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs) token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features if model.lm_head.weight.shape[0] != token_num: model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype)) model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype)) print('Loading additional LLaVA weights...') if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')): non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu') else: # this is probably from HF Hub from huggingface_hub import hf_hub_download def load_from_hf(repo_id, filename, subfolder=None): cache_file = hf_hub_download( repo_id=repo_id, filename=filename, subfolder=subfolder) return torch.load(cache_file, map_location='cpu') non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin') non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()} if any(k.startswith('model.model.') for k in non_lora_trainables): non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()} model.load_state_dict(non_lora_trainables, strict=False) from peft import PeftModel print('Loading LoRA weights...') model = PeftModel.from_pretrained(model, model_path) 
print('Merging LoRA weights...') model = model.merge_and_unload() print('Model is loaded...') elif model_base is not None: # this may be mm projector only print('Loading LLaVA from base model...') if 'mpt' in model_name.lower(): if not os.path.isfile(os.path.join(model_path, 'configuration_mpt.py')): shutil.copyfile(os.path.join(model_base, 'configuration_mpt.py'), os.path.join(model_path, 'configuration_mpt.py')) tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True) cfg_pretrained = AutoConfig.from_pretrained(model_path, trust_remote_code=True) model = LlavaLlamaForCausalLM_joint_2st_it_only_ref_instr.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs) else: tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) cfg_pretrained = AutoConfig.from_pretrained(model_path) model = LlavaLlamaForCausalLM_joint_2st_it_only_ref_instr.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs) mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu') mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()} model.load_state_dict(mm_projector_weights, strict=False) else: if 'mpt' in model_name.lower(): tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True) model = LlavaLlamaForCausalLM_joint_2st_it_only_ref_instr.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) else: tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) model = LlavaLlamaForCausalLM_joint_2st_it_only_ref_instr.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) else: # Load language model if model_base is not None: # PEFT model from peft import PeftModel tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) model = AutoModelForCausalLM.from_pretrained(model_base, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="auto") print(f"Loading LoRA weights from {model_path}") 
model = PeftModel.from_pretrained(model, model_path) print(f"Merging weights") model = model.merge_and_unload() print('Convert to FP16...') model.to(torch.float16) else: use_fast = False if 'mpt' in model_name.lower(): tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True) model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs) else: tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) image_processor = None if 'llava' in model_name.lower(): mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False) mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True) if mm_use_im_patch_token: tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) if mm_use_im_start_end: tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) model.resize_token_embeddings(len(tokenizer)) vision_tower = model.get_vision_tower() if not vision_tower.is_loaded: vision_tower.load_model() vision_tower.to(device='cuda', dtype=torch.float16) image_processor = vision_tower.image_processor if hasattr(model.config, "max_sequence_length"): context_len = model.config.max_sequence_length else: context_len = 2048 return tokenizer, model, image_processor, context_len def construct_vision_model(self, path_vision_model_cfg): from detectron2.config import LazyConfig from llava.model.openseed import build_model from llava.model.openseed.BaseModel import BaseModel def get_config_from_name(cfg, dataset_name="flickr"): # adjust config according to dataset, flickr by default if 'sam' in dataset_name: cfg.update(cfg['SAM']) return cfg elif 'flickr' in dataset_name: cfg.update(cfg['flickr']) return cfg elif 'coco_instruct_train' in dataset_name: cfg.update(cfg['coco_instruct']) return cfg elif 'lisa' in dataset_name: cfg.update(cfg['LISA_REF']) return cfg 
elif 'llava' in dataset_name: cfg.update(cfg['llava']) return cfg elif 'vg' in dataset_name: cfg.update(cfg['vg']) return cfg elif 'part' in dataset_name and 'pascal_part' not in dataset_name and 'partimagenet' not in dataset_name: cfg.update(cfg['part']) return cfg elif 'pascal' in dataset_name or 'paco' in dataset_name or 'partimagenet' in dataset_name : cfg.update(cfg['PSACAL_PART']) return cfg elif 'coco' in dataset_name and 'refonly' in dataset_name: # if 'COCO' in cfg.keys(): cfg.update(cfg['COCO_REF']) return cfg elif 'refcoco' in dataset_name or "flickr_val" in dataset_name: cfg.update(cfg['REF']) return cfg elif 'coco' in dataset_name: if 'COCO' in cfg.keys(): cfg.update(cfg['COCO']) return cfg elif "mapillary" in dataset_name: if 'MAPILLARY' in cfg.keys(): cfg.update(cfg['MAPILLARY']) return cfg elif 'ade' in dataset_name: if 'ADE20K' in cfg.keys(): cfg.update(cfg['ADE20K']) return cfg elif 'imagenet' in dataset_name: if 'IMAGENET' in cfg.keys(): cfg.update(cfg['IMAGENET']) return cfg elif 'vlp' in dataset_name: cfg.update(cfg['VLP']) return cfg elif 'sun' in dataset_name: cfg.update(cfg['SUN']) return cfg elif 'object365' in dataset_name: cfg.update(cfg['OBJECT365']) return cfg elif 'scan' in dataset_name: cfg.update(cfg['SCAN']) return cfg elif 'cityscape' in dataset_name: cfg.update(cfg['CITY']) return cfg elif 'bdd' in dataset_name: cfg.update(cfg['BDD']) return cfg else: assert False, "dataset not support." 
self.cfg_vision_model = LazyConfig.load(path_vision_model_cfg) vision_model = BaseModel(self.cfg_vision_model, build_model(self.cfg_vision_model)) vision_model.eval() self.model.seg_model = vision_model self.model.seg_model.model = self.model.seg_model.model.to(self.model.device) self.cfg_inter_model = LazyConfig.load(self.path_inter_model_cfg) self.model.initialize_interactive_modules(self.cfg_inter_model) self.model.interactive_model.model = self.model.interactive_model.model.to(self.model.device) # print("Configuring for Dataset Mapper ...") self.cfg_vision_model = get_config_from_name(self.cfg_vision_model) @torch.inference_mode() def evaluate_sample(self, input_data): text, boxes, masks, mask_inter = self.model.forward_eval(input_data) return text, boxes, masks, mask_inter def formatting(text, boxes, question_id): def find_start_idxes(sentence, word): window_size = len(word) start_indexes = [] assert len(sentence) > window_size if sentence == window_size: return [0] for start_index in range(len(sentence) - window_size+1): if sentence[start_index: start_index + window_size] == word: start_indexes.append(start_index) return start_indexes def extract_text(sentence): # Use regular expression to find and extract the text and number import re pattern = r"|" cleaned_text = re.sub(pattern, '', sentence) return cleaned_text def multiboxes_to_str(boxes): boxes_text = [] for box in boxes: boxes_text.append(list_to_str(box)) output_string = ";".join(boxes_text) return output_string.replace("];[", ";") def list_to_str(list_): list_str = [str(round(aa, 3)) for aa in list_] return "[" + ",".join(list_str) + "]" def format_sentence(splitted_sentence): joint_sentence = " ".join(splitted_sentence) return joint_sentence text_pure = "" text_boxes = "" boxes_pure = [] number = 0 seg_start_index = find_start_idxes(text, "") if len(seg_start_index) > 0: # text = text[:tail_start_index[0]] subtexts = text.split(" ") for subtext in subtexts: if "" in subtext: # subtext += "" start_idx 
= find_start_idxes(subtext, "")[0] text_pure = format_sentence([text_pure, format_sentence(subtext[:start_idx].split())]) text_boxes = format_sentence([text_boxes, format_sentence(subtext[:start_idx].split())]) text_ = extract_text(subtext[start_idx:]) text_pure = format_sentence([text_pure, format_sentence(text_.split())]) if number >= len(boxes): print("Error, There should be a wrong prediction.") text_boxes = format_sentence([text_boxes, format_sentence(text_.split())]) number += 1 continue text_boxes = format_sentence([text_boxes, format_sentence(text_.split()) + multiboxes_to_str(boxes[number].cpu().tolist())]) boxes_pure.append(multiboxes_to_str(boxes[number].cpu().tolist())) number += 1 else: text_pure = format_sentence([text_pure, format_sentence(subtext.split())]) text_boxes = format_sentence([text_boxes, format_sentence(subtext.split())]) return { "question_id": question_id, "text": text_pure, "text_boxes": text_boxes, "boxes": boxes_pure, } else: return { "question_id": question_id, "text": text, "text_boxes": text, "boxes": [] } def evaluate_(path_benchmarks, dir_image, evaluator, matching_threshold): def unresize_box(box, width, height, size): # ori_size = max(width, height) # ratio = ori_size / size ratio = min(width, height) / max(width, height) if width > height: # then the height dimension is padded, the y coordinates should be divided by ratio box[:, 1] = box[:, 1] / ratio box[:, 3] = box[:, 3] / ratio elif width < height: # then the height dimension is padded, the y coordinates should be divided by ratio box[:, 0] = box[:, 0] / ratio box[:, 2] = box[:, 2] / ratio return box def filter_empty_box(text, boxes_image): def extract_text(sentence): # Use regular expression to find and extract the text and number import re if " " in sentence: pattern = r"| " cleaned_text = re.sub(pattern, '', sentence) return cleaned_text else: cleaned_text = re.sub(r' \d+', '', sentence) cleaned_text = re.sub(r' ', '', cleaned_text) return cleaned_text has_gd = True if 
"" in text else False if len(boxes_image) == 0: return text, boxes_image else: if has_gd: sub_texts = text.split(" ") sub_texts_filtered = [] boxes_image_filtered = [] for box_per_gd, text_per_gd in zip(boxes_image, sub_texts): text_per_gd += " " ind_nonempty_box = torch.where(box_per_gd.abs().sum(dim=1)>0) if len(ind_nonempty_box[0]) < box_per_gd.shape[0]: # empty box encountered if len(ind_nonempty_box[0]) == 0: text_per_gd = " " + " ".join(extract_text(text_per_gd).split()) sub_texts_filtered.append(text_per_gd) # box is desperated continue else: box_per_gd = box_per_gd[ind_nonempty_box] boxes_image_filtered.append(box_per_gd) sub_texts_filtered.append(text_per_gd) else: boxes_image_filtered.append(box_per_gd) sub_texts_filtered.append(text_per_gd) sub_texts_filtered.append(sub_texts[-1]) text_filtered = "".join(sub_texts_filtered) return text_filtered, boxes_image_filtered else: text_filtered = " ".join(extract_text(text).split()) boxes_image_filtered = [] return text_filtered, boxes_image_filtered def debug(image, boxes, prefix): import cv2 # image = cv2.imread(path_image) def transform_str2numpy(boxes_str): boxes_str = boxes_str.replace(";", "];[") boxes_list = [] for box_str in boxes_str.split(";"): box_list = [float(aa) for aa in box_str[1:-1].split(",")] boxes_list.append(box_list) return boxes_list boxes_ = [] for box in boxes: boxes_.extend(transform_str2numpy(box)) boxes = torch.tensor(boxes_) image = image[..., ::-1] image = np.ascontiguousarray(image, dtype=np.uint8) height,width,_ = image.shape for box in boxes: box = (box.cpu() * torch.tensor([width, height, width, height])).int().squeeze() box = box.tolist() image = cv2.rectangle(image, (box[0], box[1]), (box[2], box[3]), (0, 255, 0)) cv2.imwrite(f"{prefix}_debug.jpg", image) datas = load_benchmark(dir_image, path_benchmarks)[:20] #! use first 20 samples for debug. 
data_mapper = evaluator.data_mapper device = evaluator.model.device outputs = [] for data in tqdm(datas): input_data = data_mapper(data)[0] for key, value in input_data.items(): if isinstance(value, torch.Tensor): input_data[key] = value.to(device) input_data["matching_threshold"] = matching_threshold text, boxes = evaluator.evaluate_sample([input_data]) text, boxes = filter_empty_box(text, boxes) boxes = [unresize_box(bb.detach().cpu(), input_data["width"], input_data["height"], 1024) for bb in boxes] output = formatting(text, boxes, input_data["question_id"]) # from ipdb import set_trace; set_trace() # debug(cv2.imread(input_data["file_name"]), output["boxes"], prefix=str(input_data["question_id"])) outputs.append(output) return outputs def evaluate(args=None): evaluator = Evaluator_MM( model_path=args.model_path, path_vision_model_cfg=args.vision_model_cfg, ) results = evaluate_(args.path_benchmark, dir_image=args.image_root, evaluator=evaluator, matching_threshold=args.matching_threshold) return results if __name__ == "__main__": import argparse args = argparse.ArgumentParser() args.add_argument("--model_path", type=str, default="xx") args.add_argument("--vision_model_cfg", type=str, default="xx") args.add_argument("--matching_threshold", type=float, default=0.2) args.add_argument("--path_benchmark", default="./dataset/qa1000_questions.jsonl") args.add_argument("--image_root", default="./dataset/coco/val2014") args = args.parse_args() results = evaluate(args) path_save = f"./LLaVA_G_{args.path_benchmark.split('/')[-1].split('.')[0]}_t{args.matching_threshold}.jsonl" #! 
sync print("Writing at: ", path_save) save_jsonl_file(results, path_save) # CUDA_VISIBLE_DEVICES=0 python llava/eval/LLaVA_G_Eval.py --model_path ./checkpoints/llava_stage2_new_joint_seg0.1_data_v3/checkpoint-8000 --vision_model_cfg configs/openseed/openseed_swint_lang_coco_instruct_coco_end_llava_bench.yaml --path_benchmark ./dataset/qa90_questions.jsonl --image_root ./dataset/coco/val2014 --matching_threshold 0.2 ================================================ FILE: llava/eval/eval_gpt_review.py ================================================ import argparse import json import os import openai import tqdm import ray import time NUM_SECONDS_TO_SLEEP = 3 @ray.remote(num_cpus=4) def get_eval(content: str, max_tokens: int): while True: try: response = openai.ChatCompletion.create( model='gpt-4', messages=[{ 'role': 'system', 'content': 'You are a helpful and precise assistant for checking the quality of the answer.' }, { 'role': 'user', 'content': content, }], temperature=0.2, # TODO: figure out which temperature is best for evaluation max_tokens=max_tokens, ) break except openai.error.RateLimitError: pass except Exception as e: print(e) time.sleep(NUM_SECONDS_TO_SLEEP) print('success!') return response['choices'][0]['message']['content'] def parse_score(review): try: score_pair = review.split('\n')[0] score_pair = score_pair.replace(',', ' ') sp = score_pair.split(' ') if len(sp) == 2: return [float(sp[0]), float(sp[1])] else: print('error', review) return [-1, -1] except Exception as e: print(e) print('error', review) return [-1, -1] if __name__ == '__main__': parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.') parser.add_argument('-q', '--question') # parser.add_argument('-a', '--answer') parser.add_argument('-a', '--answer-list', nargs='+', default=[]) parser.add_argument('-r', '--rule') parser.add_argument('-o', '--output') parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the 
output') args = parser.parse_args() ray.init() f_q = open(os.path.expanduser(args.question)) f_ans1 = open(os.path.expanduser(args.answer_list[0])) f_ans2 = open(os.path.expanduser(args.answer_list[1])) rule_dict = json.load(open(os.path.expanduser(args.rule), 'r')) review_file = open(f'{args.output}', 'w') js_list = [] handles = [] idx = 0 for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2): # if idx == 1: # break ques = json.loads(ques_js) ans1 = json.loads(ans1_js) ans2 = json.loads(ans2_js) category = json.loads(ques_js)['category'] if category in rule_dict: rule = rule_dict[category] else: rule = rule_dict['default'] prompt = rule['prompt'] role = rule['role'] content = (f'[Question]\n{ques["text"]}\n\n' f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n' f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n' f'[System]\n{prompt}\n\n') js_list.append({ 'id': idx+1, 'question_id': ques['question_id'], 'answer1_id': ans1['answer_id'], 'answer2_id': ans2['answer_id'], 'category': category}) idx += 1 handles.append(get_eval.remote(content, args.max_tokens)) # To avoid the rate limit set by OpenAI time.sleep(NUM_SECONDS_TO_SLEEP) reviews = ray.get(handles) for idx, review in enumerate(reviews): scores = parse_score(review) js_list[idx]['content'] = review js_list[idx]['tuple'] = scores review_file.write(json.dumps(js_list[idx]) + '\n') review_file.close() ================================================ FILE: llava/eval/eval_gpt_review_bench.py ================================================ import argparse import json import os import openai import time NUM_SECONDS_TO_SLEEP = 0.5 openai.api_type = "azure" openai.api_base = "https://xdecoder.openai.azure.com/" openai.api_version = "2023-03-15-preview" os.environ['OPENAI_API_KEY']='f0f8184713a549ba945bbcc19a06e032' openai.api_key = os.getenv("OPENAI_API_KEY") def get_eval(content: str, max_tokens: int): while True: try: response = openai.ChatCompletion.create( engine='gpt4a', messages=[{ 'role': 'system', 
'content': 'You are a helpful and precise assistant for checking the quality of the answer.' }, { 'role': 'user', 'content': content, }], temperature=0.2, # TODO: figure out which temperature is best for evaluation max_tokens=max_tokens, ) break except openai.error.RateLimitError: pass except Exception as e: print(e) time.sleep(NUM_SECONDS_TO_SLEEP) return response['choices'][0]['message']['content'] def parse_score(review): try: score_pair = review.split('\n')[0] score_pair = score_pair.replace(',', ' ') sp = score_pair.split(' ') if len(sp) == 2: return [float(sp[0]), float(sp[1])] else: print('error', review) return [-1, -1] except Exception as e: print(e) print('error', review) return [-1, -1] if __name__ == '__main__': parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.') parser.add_argument('-q', '--question') parser.add_argument('-c', '--context') parser.add_argument('-a', '--answer-list', nargs='+', default=[]) parser.add_argument('-r', '--rule') parser.add_argument('-o', '--output') parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output') args = parser.parse_args() f_q = open(os.path.expanduser(args.question)) f_ans1 = open(os.path.expanduser(args.answer_list[0])) f_ans2 = open(os.path.expanduser(args.answer_list[1])) rule_dict = json.load(open(os.path.expanduser(args.rule), 'r')) if os.path.isfile(os.path.expanduser(args.output)): cur_reviews = [json.loads(line) for line in open(os.path.expanduser(args.output))] else: cur_reviews = [] review_file = open(f'{args.output}', 'a') context_list = [json.loads(line) for line in open(os.path.expanduser(args.context))] image_to_context = {context['image']: context for context in context_list} handles = [] idx = 0 for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2): ques = json.loads(ques_js) ans1 = json.loads(ans1_js) ans2 = json.loads(ans2_js) inst = image_to_context[ques['image']] cap_str = '\n'.join(inst['caption']) category 
= 'llava_bench_' + json.loads(ques_js)['category'] if category in rule_dict: rule = rule_dict[category] else: assert False, f"Visual QA category not found in rule file: {category}." prompt = rule['prompt'] role = rule['role'] content = (f'[Context]\n{cap_str}\n\n' f'[Question]\n{ques["text"]}\n\n' f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n' f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n' f'[System]\n{prompt}\n\n') cur_js = { 'id': idx+1, 'question_id': ques['question_id'], 'answer1_id': ans1.get('answer_id', ans1['question_id']), 'answer2_id': ans2.get('answer_id', ans2['answer_id']), 'category': category } if idx >= len(cur_reviews): review = get_eval(content, args.max_tokens) scores = parse_score(review) cur_js['content'] = review cur_js['tuple'] = scores review_file.write(json.dumps(cur_js) + '\n') review_file.flush() else: print(f'Skipping {idx} as we already have it.') idx += 1 print(idx) review_file.close() ================================================ FILE: llava/eval/eval_gpt_review_visual.py ================================================ import argparse import json import os import openai import time NUM_SECONDS_TO_SLEEP = 0.5 openai.api_type = "azure" openai.api_base = "https://xdecoder.openai.azure.com/" openai.api_version = "2023-03-15-preview" os.environ['OPENAI_API_KEY']='f0f8184713a549ba945bbcc19a06e032' openai.api_key = os.getenv("OPENAI_API_KEY") def get_eval(content: str, max_tokens: int): while True: try: response = openai.ChatCompletion.create( engine='gpt4a', messages=[{ 'role': 'system', 'content': 'You are a helpful and precise assistant for checking the quality of the answer.' 
}, { 'role': 'user', 'content': content, }], temperature=0.2, # TODO: figure out which temperature is best for evaluation max_tokens=max_tokens, ) break except openai.error.RateLimitError: pass except Exception as e: print(e) time.sleep(NUM_SECONDS_TO_SLEEP) return response['choices'][0]['message']['content'] def parse_score(review): try: score_pair = review.split('\n')[0] score_pair = score_pair.replace(',', ' ') sp = score_pair.split(' ') if len(sp) == 2: return [float(sp[0]), float(sp[1])] else: print('error', review) return [-1, -1] except Exception as e: print(e) print('error', review) return [-1, -1] if __name__ == '__main__': parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.') parser.add_argument('-q', '--question') parser.add_argument('-c', '--context') parser.add_argument('-a', '--answer-list', nargs='+', default=[]) parser.add_argument('-r', '--rule') parser.add_argument('-o', '--output') parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output') args = parser.parse_args() f_q = open(os.path.expanduser(args.question)) f_ans1 = open(os.path.expanduser(args.answer_list[0])) f_ans2 = open(os.path.expanduser(args.answer_list[1])) rule_dict = json.load(open(os.path.expanduser(args.rule), 'r')) if os.path.isfile(os.path.expanduser(args.output)): cur_reviews = [json.loads(line) for line in open(os.path.expanduser(args.output))] else: cur_reviews = [] review_file = open(f'{args.output}', 'a') context_list = [json.loads(line) for line in open(os.path.expanduser(args.context))] image_to_context = {context['image']: context for context in context_list} handles = [] idx = 0 for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2): ques = json.loads(ques_js) ans1 = json.loads(ans1_js) ans2 = json.loads(ans2_js) inst = image_to_context[ques['image']] cap_str = '\n'.join(inst['captions']) box_str = '\n'.join([f'{instance["category"]}: {instance["bbox"]}' for instance in 
inst['instances']]) category = json.loads(ques_js)['category'] if category in rule_dict: rule = rule_dict[category] else: assert False, f"Visual QA category not found in rule file: {category}." prompt = rule['prompt'] role = rule['role'] content = (f'[Context]\n{cap_str}\n\n{box_str}\n\n' f'[Question]\n{ques["text"]}\n\n' f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n' f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n' f'[System]\n{prompt}\n\n') cur_js = { 'id': idx+1, 'question_id': ques['question_id'], 'answer1_id': ans1.get('answer_id', ans1['question_id']), 'answer2_id': ans2.get('answer_id', ans2['answer_id']) if 'answer_id' in ans2 else ans2['question_id'], 'category': category } if idx >= len(cur_reviews): review = get_eval(content, args.max_tokens) scores = parse_score(review) cur_js['content'] = review cur_js['tuple'] = scores review_file.write(json.dumps(cur_js) + '\n') review_file.flush() else: print(f'Skipping {idx} as we already have it.') idx += 1 print(idx) review_file.close() ================================================ FILE: llava/eval/eval_gpt_review_visual2.py ================================================ import argparse import json import os import openai import time NUM_SECONDS_TO_SLEEP = 0.5 os.environ['OPENAI_API_KEY']='233c45550c614b72b8f3c9309efecf06' openai.api_type = "azure" openai.api_base = 'https://azureopenaifiahmedeastus.openai.azure.com/' openai.api_version = '2023-03-15-preview' openai.api_key = "233c45550c614b72b8f3c9309efecf06" def get_eval(content: str, max_tokens: int): while True: try: response = openai.ChatCompletion.create( engine='gpt-4-32k-0314', messages=[{ 'role': 'system', 'content': 'You are a helpful and precise assistant for checking the quality of the answer.' 
}, { 'role': 'user', 'content': content, }], temperature=0.2, # TODO: figure out which temperature is best for evaluation max_tokens=max_tokens, ) break except openai.error.RateLimitError: pass except Exception as e: print(e) time.sleep(NUM_SECONDS_TO_SLEEP) return response['choices'][0]['message']['content'] def parse_score(review): try: score_pair = review.split('\n')[0] score_pair = score_pair.replace(',', ' ') sp = score_pair.split(' ') if len(sp) == 2: return [float(sp[0]), float(sp[1])] else: print('error', review) return [-1, -1] except Exception as e: print(e) print('error', review) return [-1, -1] if __name__ == '__main__': parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.') parser.add_argument('-q', '--question') parser.add_argument('-c', '--context') parser.add_argument('-a', '--answer-list', nargs='+', default=[]) parser.add_argument('-r', '--rule') parser.add_argument('-o', '--output') parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output') args = parser.parse_args() f_q = open(os.path.expanduser(args.question)) f_ans1 = open(os.path.expanduser(args.answer_list[0])) f_ans2 = open(os.path.expanduser(args.answer_list[1])) rule_dict = json.load(open(os.path.expanduser(args.rule), 'r')) if os.path.isfile(os.path.expanduser(args.output)): cur_reviews = [json.loads(line) for line in open(os.path.expanduser(args.output))] else: cur_reviews = [] review_file = open(f'{args.output}', 'a') context_list = [json.loads(line) for line in open(os.path.expanduser(args.context))] image_to_context = {context['image']: context for context in context_list} handles = [] idx = 0 for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2): ques = json.loads(ques_js) ans1 = json.loads(ans1_js) ans2 = json.loads(ans2_js) inst = image_to_context[ques['image']] cap_str = '\n'.join(inst['captions']) box_str = '\n'.join([f'{instance["category"]}: {instance["bbox"]}' for instance in 
inst['instances']]) category = json.loads(ques_js)['category'] if category in rule_dict: rule = rule_dict[category] else: assert False, f"Visual QA category not found in rule file: {category}." prompt = rule['prompt'] role = rule['role'] content = (f'[Context]\n{cap_str}\n\n{box_str}\n\n' f'[Question]\n{ques["text"]}\n\n' f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n' f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n' f'[System]\n{prompt}\n\n') cur_js = { 'id': idx+1, 'question_id': ques['question_id'], 'answer1_id': ans1.get('answer_id', ans1['question_id']), 'answer2_id': ans2.get('answer_id', ans2['answer_id']), 'category': category } if idx >= len(cur_reviews): review = get_eval(content, args.max_tokens) scores = parse_score(review) cur_js['content'] = review cur_js['tuple'] = scores review_file.write(json.dumps(cur_js) + '\n') review_file.flush() else: print(f'Skipping {idx} as we already have it.') idx += 1 print(idx) review_file.close() ================================================ FILE: llava/eval/eval_science_qa.py ================================================ import argparse import json import os import re import random def get_args(): parser = argparse.ArgumentParser() parser.add_argument('--base-dir', type=str) parser.add_argument('--result-file', type=str) parser.add_argument('--output-file', type=str) parser.add_argument('--output-result', type=str) parser.add_argument('--split', type=str, default='test') parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"]) return parser.parse_args() def convert_caps(results): fakecaps = [] for result in results: image_id = result['question_id'] caption = result['text'] fakecaps.append({"image_id": int(image_id), "caption": caption}) return fakecaps def get_pred_idx(prediction, choices, options): """ Get the index (e.g. 2) from the prediction (e.g. 
'C') """ if prediction in options[:len(choices)]: return options.index(prediction) else: return random.choice(range(len(choices))) if __name__ == "__main__": args = get_args() base_dir = args.base_dir split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split] problems = json.load(open(os.path.join(base_dir, "problems.json"))) predictions = [json.loads(line) for line in open(args.result_file)] predictions = {pred['question_id']: pred for pred in predictions} split_problems = {idx: problems[idx] for idx in split_indices} results = {'correct': [], 'incorrect': []} sqa_results = {} sqa_results['acc'] = None sqa_results['correct'] = None sqa_results['count'] = None sqa_results['results'] = {} sqa_results['outputs'] = {} for prob_id, prob in split_problems.items(): if prob_id not in predictions: continue pred = predictions[prob_id] pred_text = pred['text'] pattern = re.compile(r'The answer is ([A-Z]).') res = pattern.findall(pred_text) if len(res) == 1: answer = res[0] # 'A', 'B', ... 
else: answer = "FAILED" pred_idx = get_pred_idx(answer, prob['choices'], args.options) analysis = { 'question_id': prob_id, 'parsed_ans': answer, 'ground_truth': args.options[prob['answer']], 'question': pred['prompt'], 'pred': pred_text, 'is_multimodal': '' in pred['prompt'], } sqa_results['results'][prob_id] = get_pred_idx(answer, prob['choices'], args.options) sqa_results['outputs'][prob_id] = pred_text if pred_idx == prob['answer']: results['correct'].append(analysis) else: results['incorrect'].append(analysis) correct = len(results['correct']) total = len(results['correct']) + len(results['incorrect']) print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%') sqa_results['acc'] = correct / total * 100 sqa_results['correct'] = correct sqa_results['count'] = total with open(args.output_file, 'w') as f: json.dump(results, f, indent=2) with open(args.output_result, 'w') as f: json.dump(sqa_results, f, indent=2) ================================================ FILE: llava/eval/eval_science_qa_gpt4.py ================================================ import argparse import json import os import re import random from collections import defaultdict def get_args(): parser = argparse.ArgumentParser() parser.add_argument('--base-dir', type=str) parser.add_argument('--gpt4-result', type=str) parser.add_argument('--our-result', type=str) parser.add_argument('--split', type=str, default='test') parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"]) return parser.parse_args() def convert_caps(results): fakecaps = [] for result in results: image_id = result['question_id'] caption = result['text'] fakecaps.append({"image_id": int(image_id), "caption": caption}) return fakecaps def get_pred_idx(prediction, choices, options): """ Get the index (e.g. 2) from the prediction (e.g. 
'C') """ if prediction in options[:len(choices)]: return options.index(prediction) else: return random.choice(range(len(choices))) if __name__ == "__main__": args = get_args() base_dir = args.base_dir split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split] problems = json.load(open(os.path.join(base_dir, "problems.json"))) our_predictions = [json.loads(line) for line in open(args.our_result)] our_predictions = {pred['question_id']: pred for pred in our_predictions} split_problems = {idx: problems[idx] for idx in split_indices} gpt4_predictions = json.load(open(args.gpt4_result))['outputs'] results = defaultdict(lambda: 0) for prob_id, prob in split_problems.items(): if prob_id not in our_predictions: continue if prob_id not in gpt4_predictions: continue our_pred = our_predictions[prob_id]['text'] gpt4_pred = gpt4_predictions[prob_id] pattern = re.compile(r'The answer is ([A-Z]).') our_res = pattern.findall(our_pred) if len(our_res) == 1: our_answer = our_res[0] # 'A', 'B', ... else: our_answer = "FAILED" gpt4_res = pattern.findall(gpt4_pred) if len(gpt4_res) == 1: gpt4_answer = gpt4_res[0] # 'A', 'B', ... 
else: gpt4_answer = "FAILED" our_pred_idx = get_pred_idx(our_answer, prob['choices'], args.options) gpt4_pred_idx = get_pred_idx(gpt4_answer, prob['choices'], args.options) if gpt4_answer == 'FAILED': results['gpt4_failed'] += 1 # continue gpt4_pred_idx = our_pred_idx # if our_pred_idx != prob['answer']: # print(our_predictions[prob_id]['prompt']) # print('-----------------') # print(f'LECTURE: {prob["lecture"]}') # print(f'SOLUTION: {prob["solution"]}') # print('=====================') else: # continue pass # gpt4_pred_idx = our_pred_idx if gpt4_pred_idx == prob['answer']: results['correct'] += 1 else: results['incorrect'] += 1 if gpt4_pred_idx == prob['answer'] or our_pred_idx == prob['answer']: results['correct_upperbound'] += 1 correct = results['correct'] total = results['correct'] + results['incorrect'] print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%') print(f'Total: {total}, Correct (upper): {results["correct_upperbound"]}, Accuracy: {results["correct_upperbound"] / total * 100:.2f}%') print(f'Total: {total}, GPT-4 NO-ANS (RANDOM): {results["gpt4_failed"]}, Percentage: {results["gpt4_failed"] / total * 100:.2f}%') ================================================ FILE: llava/eval/eval_science_qa_gpt4_requery.py ================================================ import argparse import json import os import re import random from collections import defaultdict def get_args(): parser = argparse.ArgumentParser() parser.add_argument('--base-dir', type=str) parser.add_argument('--gpt4-result', type=str) parser.add_argument('--requery-result', type=str) parser.add_argument('--our-result', type=str) parser.add_argument('--output-result', type=str) parser.add_argument('--split', type=str, default='test') parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"]) return parser.parse_args() def convert_caps(results): fakecaps = [] for result in results: image_id = result['question_id'] caption = result['text'] 
fakecaps.append({"image_id": int(image_id), "caption": caption}) return fakecaps def get_pred_idx(prediction, choices, options): """ Get the index (e.g. 2) from the prediction (e.g. 'C') """ if prediction in options[:len(choices)]: return options.index(prediction) else: return random.choice(range(len(choices))) if __name__ == "__main__": args = get_args() base_dir = args.base_dir split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split] problems = json.load(open(os.path.join(base_dir, "problems.json"))) our_predictions = [json.loads(line) for line in open(args.our_result)] our_predictions = {pred['question_id']: pred for pred in our_predictions} split_problems = {idx: problems[idx] for idx in split_indices} requery_predictions = [json.loads(line) for line in open(args.requery_result)] requery_predictions = {pred['question_id']: pred for pred in requery_predictions} gpt4_predictions = json.load(open(args.gpt4_result))['outputs'] results = defaultdict(lambda: 0) sqa_results = {} sqa_results['acc'] = None sqa_results['correct'] = None sqa_results['count'] = None sqa_results['results'] = {} sqa_results['outputs'] = {} for prob_id, prob in split_problems.items(): if prob_id not in our_predictions: assert False if prob_id not in gpt4_predictions: assert False our_pred = our_predictions[prob_id]['text'] gpt4_pred = gpt4_predictions[prob_id] if prob_id not in requery_predictions: results['missing_requery'] += 1 requery_pred = "MISSING" else: requery_pred = requery_predictions[prob_id]['text'] pattern = re.compile(r'The answer is ([A-Z]).') our_res = pattern.findall(our_pred) if len(our_res) == 1: our_answer = our_res[0] # 'A', 'B', ... else: our_answer = "FAILED" requery_res = pattern.findall(requery_pred) if len(requery_res) == 1: requery_answer = requery_res[0] # 'A', 'B', ... else: requery_answer = "FAILED" gpt4_res = pattern.findall(gpt4_pred) if len(gpt4_res) == 1: gpt4_answer = gpt4_res[0] # 'A', 'B', ... 
else: gpt4_answer = "FAILED" our_pred_idx = get_pred_idx(our_answer, prob['choices'], args.options) gpt4_pred_idx = get_pred_idx(gpt4_answer, prob['choices'], args.options) requery_pred_idx = get_pred_idx(requery_answer, prob['choices'], args.options) results['total'] += 1 if gpt4_answer == 'FAILED': results['gpt4_failed'] += 1 if gpt4_pred_idx == prob['answer']: results['gpt4_correct'] += 1 if our_pred_idx == prob['answer']: results['gpt4_ourvisual_correct'] += 1 elif gpt4_pred_idx == prob['answer']: results['gpt4_correct'] += 1 results['gpt4_ourvisual_correct'] += 1 if our_pred_idx == prob['answer']: results['our_correct'] += 1 if requery_answer == 'FAILED': sqa_results['results'][prob_id] = our_pred_idx if our_pred_idx == prob['answer']: results['requery_correct'] += 1 else: sqa_results['results'][prob_id] = requery_pred_idx if requery_pred_idx == prob['answer']: results['requery_correct'] += 1 else: print(f""" Question ({args.options[prob['answer']]}): {our_predictions[prob_id]['prompt']} Our ({our_answer}): {our_pred} GPT-4 ({gpt4_answer}): {gpt4_pred} Requery ({requery_answer}): {requery_pred} print("=====================================") """) if gpt4_pred_idx == prob['answer'] or our_pred_idx == prob['answer']: results['correct_upperbound'] += 1 total = results['total'] print(f'Total: {total}, Our-Correct: {results["our_correct"]}, Accuracy: {results["our_correct"] / total * 100:.2f}%') print(f'Total: {total}, GPT-4-Correct: {results["gpt4_correct"]}, Accuracy: {results["gpt4_correct"] / total * 100:.2f}%') print(f'Total: {total}, GPT-4 NO-ANS (RANDOM): {results["gpt4_failed"]}, Percentage: {results["gpt4_failed"] / total * 100:.2f}%') print(f'Total: {total}, GPT-4-OursVisual-Correct: {results["gpt4_ourvisual_correct"]}, Accuracy: {results["gpt4_ourvisual_correct"] / total * 100:.2f}%') print(f'Total: {total}, Requery-Correct: {results["requery_correct"]}, Accuracy: {results["requery_correct"] / total * 100:.2f}%') print(f'Total: {total}, Correct upper: 
{results["correct_upperbound"]}, Accuracy: {results["correct_upperbound"] / total * 100:.2f}%') sqa_results['acc'] = results["requery_correct"] / total * 100 sqa_results['correct'] = results["requery_correct"] sqa_results['count'] = total with open(args.output_result, 'w') as f: json.dump(sqa_results, f, indent=2) ================================================ FILE: llava/eval/generate_webpage_data_from_table.py ================================================ """Generate json file for webpage.""" import json import os import re # models = ['llama', 'alpaca', 'gpt35', 'bard'] models = ['vicuna'] def read_jsonl(path: str, key: str=None): data = [] with open(os.path.expanduser(path)) as f: for line in f: if not line: continue data.append(json.loads(line)) if key is not None: data.sort(key=lambda x: x[key]) data = {item[key]: item for item in data} return data def trim_hanging_lines(s: str, n: int) -> str: s = s.strip() for _ in range(n): s = s.split('\n', 1)[1].strip() return s if __name__ == '__main__': questions = read_jsonl('table/question.jsonl', key='question_id') # alpaca_answers = read_jsonl('table/answer/answer_alpaca-13b.jsonl', key='question_id') # bard_answers = read_jsonl('table/answer/answer_bard.jsonl', key='question_id') # gpt35_answers = read_jsonl('table/answer/answer_gpt35.jsonl', key='question_id') # llama_answers = read_jsonl('table/answer/answer_llama-13b.jsonl', key='question_id') vicuna_answers = read_jsonl('table/answer/answer_vicuna-13b.jsonl', key='question_id') ours_answers = read_jsonl('table/results/llama-13b-hf-alpaca.jsonl', key='question_id') review_vicuna = read_jsonl('table/review/review_vicuna-13b_llama-13b-hf-alpaca.jsonl', key='question_id') # review_alpaca = read_jsonl('table/review/review_alpaca-13b_vicuna-13b.jsonl', key='question_id') # review_bard = read_jsonl('table/review/review_bard_vicuna-13b.jsonl', key='question_id') # review_gpt35 = read_jsonl('table/review/review_gpt35_vicuna-13b.jsonl', key='question_id') # 
review_llama = read_jsonl('table/review/review_llama-13b_vicuna-13b.jsonl', key='question_id') records = [] for qid in questions.keys(): r = { 'id': qid, 'category': questions[qid]['category'], 'question': questions[qid]['text'], 'answers': { # 'alpaca': alpaca_answers[qid]['text'], # 'llama': llama_answers[qid]['text'], # 'bard': bard_answers[qid]['text'], # 'gpt35': gpt35_answers[qid]['text'], 'vicuna': vicuna_answers[qid]['text'], 'ours': ours_answers[qid]['text'], }, 'evaluations': { # 'alpaca': review_alpaca[qid]['text'], # 'llama': review_llama[qid]['text'], # 'bard': review_bard[qid]['text'], 'vicuna': review_vicuna[qid]['content'], # 'gpt35': review_gpt35[qid]['text'], }, 'scores': { 'vicuna': review_vicuna[qid]['tuple'], # 'alpaca': review_alpaca[qid]['score'], # 'llama': review_llama[qid]['score'], # 'bard': review_bard[qid]['score'], # 'gpt35': review_gpt35[qid]['score'], }, } # cleanup data cleaned_evals = {} for k, v in r['evaluations'].items(): v = v.strip() lines = v.split('\n') # trim the first line if it's a pair of numbers if re.match(r'\d+[, ]+\d+', lines[0]): lines = lines[1:] v = '\n'.join(lines) cleaned_evals[k] = v.replace('Assistant 1', "**Assistant 1**").replace('Assistant 2', '**Assistant 2**') r['evaluations'] = cleaned_evals records.append(r) # Reorder the records, this is optional for r in records: if r['id'] <= 20: r['id'] += 60 else: r['id'] -= 20 for r in records: if r['id'] <= 50: r['id'] += 10 elif 50 < r['id'] <= 60: r['id'] -= 50 for r in records: if r['id'] == 7: r['id'] = 1 elif r['id'] < 7: r['id'] += 1 records.sort(key=lambda x: x['id']) # Write to file with open('webpage/data.json', 'w') as f: json.dump({'questions': records, 'models': models}, f, indent=2) ================================================ FILE: llava/eval/llava_mapper.py ================================================ # Copyright (c) Facebook, Inc. and its affiliates. 
# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/d2/detr/dataset_mapper.py import copy import logging import numpy as np import torch import PIL.Image as Image from detectron2.data import detection_utils as utils from detectron2.data import transforms as T from detectron2.data.transforms import TransformGen from detectron2.structures import BitMasks, Instances from pycocotools import mask as coco_mask from llava.model.openseed.utils import configurable from detectron2.structures import Boxes, BoxMode, PolygonMasks, RotatedBoxes from llava import conversation as conversation_lib from llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN # from llava.train.train_hao_seg_flickr import ,preprocess __all__ = ["COCOInstanceNewBaselineDatasetMapper"] def convert_coco_poly_to_mask(segmentations, height, width): masks = [] for polygons in segmentations: rles = coco_mask.frPyObjects(polygons, height, width) mask = coco_mask.decode(rles) if len(mask.shape) < 3: mask = mask[..., None] mask = torch.as_tensor(mask, dtype=torch.uint8) mask = mask.any(dim=2) masks.append(mask) if masks: masks = torch.stack(masks, dim=0) else: masks = torch.zeros((0, height, width), dtype=torch.uint8) return masks def preprocess_multimodal( sources, is_multimodal ): if not is_multimodal: return sources for source in sources: for sentence in source: if DEFAULT_IMAGE_TOKEN in sentence['value']: sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip() sentence['value'] = DEFAULT_IMAGE_TOKEN + '\n' + sentence['value'] sentence['value'] = sentence['value'].strip() if "mmtag" in conversation_lib.default_conversation.version: sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '' + DEFAULT_IMAGE_TOKEN + '') replace_token = DEFAULT_IMAGE_TOKEN if False: replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN sentence["value"] = 
sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token) return sources def build_transform_gen(cfg, is_train): """ Create a list of default :class:`Augmentation` from config. Now it includes resizing and flipping. Returns: list[Augmentation] """ if is_train: cfg_input = cfg['INPUT'] image_size = cfg_input['IMAGE_SIZE'] min_scale = cfg_input['MIN_SCALE'] max_scale = cfg_input['MAX_SCALE'] augmentation = [] # if cfg_input['RANDOM_FLIP'] != "none": # augmentation.append( # T.RandomFlip( # horizontal=cfg_input['RANDOM_FLIP'] == "horizontal", # vertical=cfg_input['RANDOM_FLIP'] == "vertical", # ) # ) augmentation.extend([ T.ResizeScale( min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size ), T.FixedSizeCrop(crop_size=(image_size, image_size)), ]) else: cfg_input = cfg['INPUT'] image_size = cfg_input['IMAGE_SIZE'] min_scale = cfg_input['MIN_SCALE'] max_scale = cfg_input['MAX_SCALE'] augmentation = [] # if cfg_input['RANDOM_FLIP'] != "none": # augmentation.append( # T.RandomFlip( # horizontal=cfg_input['RANDOM_FLIP'] == "horizontal", # vertical=cfg_input['RANDOM_FLIP'] == "vertical", # ) # ) augmentation.extend([ T.ResizeScale( min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size ), T.FixedSizeCrop(crop_size=(image_size, image_size)), ]) return augmentation # This is specifically designed for the COCO dataset. class COCOInstanceNewBaselineDatasetMapper: """ A callable which takes a dataset dict in Detectron2 Dataset format, and map it into a format used by MaskFormer. This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation. The callable currently does the following: 1. Read the image from "file_name" 2. Applies geometric transforms to the image and annotation 3. Find and applies suitable cropping to the image and annotation 4. 
Prepare image and annotation to Tensors """ @configurable def __init__( self, is_train=True, *, tfm_gens, image_format, tokenizer, image_processor, preprocess, ): """ NOTE: this interface is experimental. Args: is_train: for training or inference augmentations: a list of augmentations or deterministic transforms to apply tfm_gens: data augmentation image_format: an image format supported by :func:`detection_utils.read_image`. """ self.tfm_gens = tfm_gens logging.getLogger(__name__).info( "[COCOInstanceNewBaselineDatasetMapper] Full TransformGens used in training: {}".format(str(self.tfm_gens)) ) self.img_format = image_format self.is_train = is_train self.tokenizer = tokenizer self.processor = image_processor self.preprocess = preprocess @classmethod def from_config(cls, cfg, is_train=True,tokenizer=None,image_processor=None,preprocess=None): # Build augmentation tfm_gens = build_transform_gen(cfg, is_train) ret = { "is_train": is_train, "tfm_gens": tfm_gens, "image_format": cfg['INPUT']['FORMAT'], "tokenizer": tokenizer, "image_processor": image_processor, "preprocess": preprocess, } return ret def __call__(self, dataset_dict): """ Args: dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format. 
Returns: dict: a format that builtin models in detectron2 accept """ dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below image = utils.read_image(dataset_dict["file_name"], format=self.img_format) utils.check_image_size(dataset_dict, image) #########llava image processing image_clip = self.processor.preprocess(image, return_tensors='pt')['pixel_values'][0] dataset_dict["image_clip"] = image_clip ################## # TODO: get padding mask # by feeding a "segmentation mask" to the same transforms padding_mask = np.ones(image.shape[:2]) image, transforms = T.apply_transform_gens(self.tfm_gens, image) dataset_dict["image_ori"]=image # the crop transformation has default padding value 0 for segmentation padding_mask = transforms.apply_segmentation(padding_mask) padding_mask = ~ padding_mask.astype(bool) image_shape = image.shape[:2] # h, w dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1))) dataset_dict["padding_mask"] = torch.as_tensor(np.ascontiguousarray(padding_mask)) num_conversations = len(dataset_dict['conversations']) rd = np.random.choice(num_conversations) # selected_conversation, grounding_list = dataset_dict['conversations'][rd] # dataset_dict['conversation'] = [selected_conversation] selected_conversation = [aa[0] for aa in dataset_dict['conversations']] dataset_dict['conversation'] = selected_conversation sources = preprocess_multimodal( copy.deepcopy(dataset_dict['conversation']), True) #! 
Debug here # sources = copy.deepcopy(dataset_dict['conversation']) data_dict_conversation = self.preprocess( sources, self.tokenizer, has_image=True) data_dict_conversation = dict(input_ids=data_dict_conversation["input_ids"][0], labels=data_dict_conversation["labels"][0]) dataset_dict.update(data_dict_conversation) dataset_dict['tokenizer'] = self.tokenizer num_segs = 1 # sum([conv['value'].count('') for conv in selected_conversation]) # grounding_list= if "grounding_info" in dataset_dict and len(dataset_dict['grounding_info'])>0: anno_id2id=dict() for id,obj in enumerate(dataset_dict['grounding_info']): obj["bbox_mode"] = BoxMode.XYWH_ABS anno_id2id[obj['id']]=id id2class=[[] for _ in range(len(dataset_dict['grounding_info']))] annos = [ utils.transform_instance_annotations(obj, transforms, image_shape) for obj in dataset_dict["grounding_info"] ] # assert "segmentation" in annos[0] instances = utils.annotations_to_instances(annos, image_shape,mask_format="bitmask") h, w = instances.image_size # image_size_xyxy = torch.as_tensor([w, h, w, h], dtype=torch.float) if hasattr(instances, 'gt_masks'): gt_masks = instances.gt_masks # gt_masks = convert_coco_poly_to_mask(gt_masks.polygons, h, w) instances.gt_masks = gt_masks.tensor if grounding_list is None: dataset_dict['grounding']=False grounding_mask=[False for _ in range(num_segs)] dataset_dict['grounding_mask']=grounding_mask else: grounding_mask=[True if g is not None else False for g in grounding_list] dataset_dict['grounding_mask']=grounding_mask new_grounding_list=[g for g in grounding_list if g is not None] if sum(grounding_mask)==0: dataset_dict['grounding']=False else: dataset_dict['grounding']=True if dataset_dict['grounding']: # assert num_segs == len(grounding_list) for grounding_id,grounding in enumerate(new_grounding_list): if grounding is not None: for annid in grounding: id2class[anno_id2id[annid]].append(grounding_id) instances.gt_classes=id2class dataset_dict["instances"] = instances else: 
dataset_dict['grounding'] = False grounding_mask = [False for _ in range(num_segs)] dataset_dict['grounding_mask'] = grounding_mask return [dataset_dict] ================================================ FILE: llava/eval/model_qa.py ================================================ import argparse from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria import torch import os import json from tqdm import tqdm import shortuuid from llava.conversation import default_conversation from llava.utils import disable_torch_init # new stopping implementation class KeywordsStoppingCriteria(StoppingCriteria): def __init__(self, keywords, tokenizer, input_ids): self.keywords = keywords self.tokenizer = tokenizer self.start_len = None self.input_ids = input_ids def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: if self.start_len is None: self.start_len = self.input_ids.shape[1] else: outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len:], skip_special_tokens=True)[0] for keyword in self.keywords: if keyword in outputs: return True return False @torch.inference_mode() def eval_model(model_name, questions_file, answers_file): # Model disable_torch_init() model_name = os.path.expanduser(model_name) tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False) model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).cuda() ques_file = open(os.path.expanduser(questions_file), "r") ans_file = open(os.path.expanduser(answers_file), "w") for i, line in enumerate(tqdm(ques_file)): idx = json.loads(line)["question_id"] qs = json.loads(line)["text"] cat = json.loads(line)["category"] conv = default_conversation.copy() conv.append_message(conv.roles[0], qs) prompt = conv.get_prompt() inputs = tokenizer([prompt]) input_ids = torch.as_tensor(inputs.input_ids).cuda() stopping_criteria = KeywordsStoppingCriteria([conv.sep], tokenizer, input_ids) output_ids = model.generate( 
input_ids, do_sample=True, use_cache=True, temperature=0.7, max_new_tokens=1024, stopping_criteria=[stopping_criteria]) outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0] try: index = outputs.index(conv.sep, len(prompt)) except ValueError: outputs += conv.sep index = outputs.index(conv.sep, len(prompt)) outputs = outputs[len(prompt) + len(conv.roles[1]) + 2:index].strip() ans_id = shortuuid.uuid() ans_file.write(json.dumps({"question_id": idx, "text": outputs, "answer_id": ans_id, "model_id": model_name, "metadata": {}}) + "\n") ans_file.flush() ans_file.close() if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--model-name", type=str, default="facebook/opt-350m") parser.add_argument("--question-file", type=str, default="tables/question.jsonl") parser.add_argument("--answers-file", type=str, default="answer.jsonl") args = parser.parse_args() eval_model(args.model_name, args.question_file, args.answers_file) ================================================ FILE: llava/eval/model_vqa.py ================================================ import argparse import torch import os import json from tqdm import tqdm import shortuuid from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN from llava.conversation import conv_templates, SeparatorStyle from llava.model.builder import load_pretrained_model from llava.utils import disable_torch_init from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria from PIL import Image import math def split_list(lst, n): """Split a list into n (roughly) equal-sized chunks""" chunk_size = math.ceil(len(lst) / n) # integer division return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)] def get_chunk(lst, n, k): chunks = split_list(lst, n) return chunks[k] def eval_model(args): # Model disable_torch_init() model_path = os.path.expanduser(args.model_path) model_name = 
get_model_name_from_path(model_path) tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name) questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")] questions = get_chunk(questions, args.num_chunks, args.chunk_idx) answers_file = os.path.expanduser(args.answers_file) os.makedirs(os.path.dirname(answers_file), exist_ok=True) ans_file = open(answers_file, "w") for line in tqdm(questions): idx = line["question_id"] image_file = line["image"] qs = line["text"] cur_prompt = qs if model.config.mm_use_im_start_end: qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs else: qs = DEFAULT_IMAGE_TOKEN + '\n' + qs conv = conv_templates[args.conv_mode].copy() conv.append_message(conv.roles[0], qs) conv.append_message(conv.roles[1], None) prompt = conv.get_prompt() input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() try: image = Image.open(os.path.join(args.image_folder, "COCO_val2014_"+image_file)) except Exception: image = Image.open(os.path.join(args.image_folder, image_file)) image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 keywords = [stop_str] stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) # import pdb; pdb.set_trace() with torch.inference_mode(): output_ids = model.generate( input_ids, images=image_tensor.unsqueeze(0).half().cuda(), do_sample=True, temperature=args.temperature, top_p=args.top_p, num_beams=args.num_beams, # no_repeat_ngram_size=3, max_new_tokens=2048, use_cache=True) input_token_len = input_ids.shape[1] n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() if n_diff_input_output > 0: print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids') outputs = 
tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] outputs = outputs.strip() if outputs.endswith(stop_str): outputs = outputs[:-len(stop_str)] outputs = outputs.strip() ans_id = shortuuid.uuid() ans_file.write(json.dumps({"question_id": idx, "prompt": cur_prompt, "text": outputs, "answer_id": ans_id, "model_id": model_name, "metadata": {}}) + "\n") ans_file.flush() ans_file.close() if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--model-path", type=str, default="facebook/opt-350m") parser.add_argument("--model-base", type=str, default=None) parser.add_argument("--image-folder", type=str, default="") parser.add_argument("--question-file", type=str, default="tables/question.jsonl") parser.add_argument("--answers-file", type=str, default="answer.jsonl") parser.add_argument("--conv-mode", type=str, default="llava_v1") parser.add_argument("--num-chunks", type=int, default=1) parser.add_argument("--chunk-idx", type=int, default=0) parser.add_argument("--temperature", type=float, default=0.2) parser.add_argument("--top_p", type=float, default=None) parser.add_argument("--num_beams", type=int, default=1) args = parser.parse_args() eval_model(args) ================================================ FILE: llava/eval/model_vqa_science.py ================================================ import argparse import torch import os import json from tqdm import tqdm import shortuuid from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN from llava.conversation import conv_templates, SeparatorStyle from llava.model.builder import load_pretrained_model from llava.utils import disable_torch_init from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria from PIL import Image import math def split_list(lst, n): """Split a list into n (roughly) equal-sized chunks""" chunk_size = math.ceil(len(lst) / n) # integer division 
return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)] def get_chunk(lst, n, k): chunks = split_list(lst, n) return chunks[k] def eval_model(args): # Model disable_torch_init() model_path = os.path.expanduser(args.model_path) model_name = get_model_name_from_path(model_path) tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name) questions = json.load(open(os.path.expanduser(args.question_file), "r")) questions = get_chunk(questions, args.num_chunks, args.chunk_idx) answers_file = os.path.expanduser(args.answers_file) os.makedirs(os.path.dirname(answers_file), exist_ok=True) ans_file = open(answers_file, "w") for i, line in enumerate(tqdm(questions)): idx = line["id"] question = line['conversations'][0] gt_ans = line["conversations"][1] qs = question['value'].replace('', '').strip() cur_prompt = qs if 'image' in line: image_file = line["image"] image = Image.open(os.path.join(args.image_folder, image_file)) image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] images = image_tensor.unsqueeze(0).half().cuda() if getattr(model.config, 'mm_use_im_start_end', False): qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs else: qs = DEFAULT_IMAGE_TOKEN + '\n' + qs cur_prompt = '' + '\n' + cur_prompt else: images = None conv = conv_templates[args.conv_mode].copy() conv.append_message(conv.roles[0], qs) conv.append_message(conv.roles[1], None) prompt = conv.get_prompt() input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 keywords = [stop_str] stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) with torch.inference_mode(): output_ids = model.generate( input_ids, images=images, do_sample=True, temperature=0.2, max_new_tokens=1024, use_cache=True, 
stopping_criteria=[stopping_criteria]) input_token_len = input_ids.shape[1] n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() if n_diff_input_output > 0: print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids') outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] outputs = outputs.strip() if outputs.endswith(stop_str): outputs = outputs[:-len(stop_str)] outputs = outputs.strip() # prompt for answer if args.answer_prompter: outputs_reasoning = outputs input_ids = tokenizer_image_token(prompt + outputs_reasoning + ' ###\nANSWER:', tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() with torch.inference_mode(): output_ids = model.generate( input_ids, images=images, do_sample=True, temperature=0.2, max_new_tokens=64, use_cache=True, stopping_criteria=[stopping_criteria]) input_token_len = input_ids.shape[1] n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() if n_diff_input_output > 0: print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids') outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] outputs = outputs.strip() if outputs.endswith(stop_str): outputs = outputs[:-len(stop_str)] outputs = outputs.strip() outputs = outputs_reasoning + '\n The answer is ' + outputs ans_id = shortuuid.uuid() ans_file.write(json.dumps({"question_id": idx, "prompt": cur_prompt, "text": outputs, "answer_id": ans_id, "model_id": model_name, "metadata": {}}) + "\n") ans_file.flush() ans_file.close() if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--model-path", type=str, default="facebook/opt-350m") parser.add_argument("--model-base", type=str, default=None) parser.add_argument("--image-folder", type=str, default="") parser.add_argument("--question-file", type=str, default="tables/question.json") parser.add_argument("--answers-file", 
type=str, default="answer.jsonl") parser.add_argument("--conv-mode", type=str, default="llava_v0") parser.add_argument("--num-chunks", type=int, default=1) parser.add_argument("--chunk-idx", type=int, default=0) parser.add_argument("--answer-prompter", action="store_true") args = parser.parse_args() eval_model(args) ================================================ FILE: llava/eval/qa_baseline_gpt35.py ================================================ """Generate answers with GPT-3.5""" # Note: you need to be using OpenAI Python v0.27.0 for the code below to work import argparse import json import os import time import concurrent.futures import openai import tqdm import shortuuid MODEL = 'gpt-3.5-turbo' MODEL_ID = 'gpt-3.5-turbo:20230327' def get_answer(question_id: int, question: str, max_tokens: int): ans = { 'answer_id': shortuuid.uuid(), 'question_id': question_id, 'model_id': MODEL_ID, } for _ in range(3): try: response = openai.ChatCompletion.create( model=MODEL, messages=[{ 'role': 'system', 'content': 'You are a helpful assistant.' 
}, { 'role': 'user', 'content': question, }], max_tokens=max_tokens, ) ans['text'] = response['choices'][0]['message']['content'] return ans except Exception as e: print('[ERROR]', e) ans['text'] = '#ERROR#' time.sleep(1) return ans if __name__ == '__main__': parser = argparse.ArgumentParser(description='ChatGPT answer generation.') parser.add_argument('-q', '--question') parser.add_argument('-o', '--output') parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output') args = parser.parse_args() questions_dict = {} with open(os.path.expanduser(args.question)) as f: for line in f: if not line: continue q = json.loads(line) questions_dict[q['question_id']] = q['text'] answers = [] with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor: futures = [] for qid, question in questions_dict.items(): future = executor.submit(get_answer, qid, question, args.max_tokens) futures.append(future) for future in tqdm.tqdm(concurrent.futures.as_completed(futures), total=len(futures)): answers.append(future.result()) answers.sort(key=lambda x: x['question_id']) with open(os.path.expanduser(args.output), 'w') as f: table = [json.dumps(ans) for ans in answers] f.write('\n'.join(table)) ================================================ FILE: llava/eval/run_llava.py ================================================ import argparse import torch from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN from llava.conversation import conv_templates, SeparatorStyle from llava.model.builder import load_pretrained_model from llava.utils import disable_torch_init from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria from PIL import Image import requests from PIL import Image from io import BytesIO def load_image(image_file): if image_file.startswith('http') or image_file.startswith('https'): response = requests.get(image_file) 
image = Image.open(BytesIO(response.content)).convert('RGB') else: image = Image.open(image_file).convert('RGB') return image def eval_model(args): # Model disable_torch_init() model_name = get_model_name_from_path(args.model_path) tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name) qs = args.query if model.config.mm_use_im_start_end: qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs else: qs = DEFAULT_IMAGE_TOKEN + '\n' + qs if 'llama-2' in model_name.lower(): conv_mode = "llava_llama_2" elif "v1" in model_name.lower(): conv_mode = "llava_v1" elif "mpt" in model_name.lower(): conv_mode = "mpt" else: conv_mode = "llava_v0" if args.conv_mode is not None and conv_mode != args.conv_mode: print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode)) else: args.conv_mode = conv_mode conv = conv_templates[args.conv_mode].copy() conv.append_message(conv.roles[0], qs) conv.append_message(conv.roles[1], None) prompt = conv.get_prompt() image = load_image(args.image_file) image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().cuda() input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 keywords = [stop_str] stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) with torch.inference_mode(): output_ids = model.generate( input_ids, images=image_tensor, do_sample=True, temperature=0.2, max_new_tokens=1024, use_cache=True, stopping_criteria=[stopping_criteria]) input_token_len = input_ids.shape[1] n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() if n_diff_input_output > 0: print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids') outputs = 
tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] outputs = outputs.strip() if outputs.endswith(stop_str): outputs = outputs[:-len(stop_str)] outputs = outputs.strip() print(outputs) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--model-path", type=str, default="facebook/opt-350m") parser.add_argument("--model-base", type=str, default=None) parser.add_argument("--image-file", type=str, required=True) parser.add_argument("--query", type=str, required=True) parser.add_argument("--conv-mode", type=str, default=None) args = parser.parse_args() eval_model(args) ================================================ FILE: llava/eval/summarize_gpt_review.py ================================================ import json import os from collections import defaultdict import numpy as np import argparse def parse_args(): parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.') parser.add_argument('-d', '--dir', default=None) parser.add_argument('-f', '--files', nargs='*', default=None) parser.add_argument('-i', '--ignore', nargs='*', default=None) return parser.parse_args() if __name__ == '__main__': args = parse_args() if args.ignore is not None: args.ignore = [int(x) for x in args.ignore] if args.files is not None and len(args.files) > 0: review_files = args.files else: review_files = [x for x in os.listdir(args.dir) if x.endswith('.jsonl') and (x.startswith('gpt4_text') or x.startswith('reviews_') or x.startswith('review_'))] for review_file in sorted(review_files): config = os.path.basename(review_file).replace('gpt4_text_', '').replace('.jsonl', '') scores = defaultdict(list) print(config) with open(os.path.join(args.dir, review_file) if args.dir is not None else review_file) as f: for review_str in f: review = json.loads(review_str) if args.ignore is not None and review['question_id'] in args.ignore: continue if 'category' in review: scores[review['category']].append(review['tuple']) 
scores['all'].append(review['tuple']) else: if 'tuple' in review: scores['all'].append(review['tuple']) else: scores['all'].append(review['score']) for k, v in sorted(scores.items()): stats = np.asarray(v).mean(0).tolist() stats = [round(x, 3) for x in stats] # print(k, stats, round(stats[1]/stats[0]*100, 1)) print(k, round(stats[1]/stats[0]*100, 1)) print('=================================') ================================================ FILE: llava/eval/webpage/index.html ================================================ Who's GPT-4's favorite? Battles between State-of-the-Art Chatbots 🏔️ Vicuna Evaluation Examples Demo Blog Github Who's GPT-4's favorite? Battles between State-of-the-Art Chatbots Category Question keyboard_arrow_left keyboard_arrow_right */10 */10 Assistant #1 Assistant #2 (Vicuna, our model) GPT-4 Evaluation This website is co-authored with GPT-4. ================================================ FILE: llava/eval/webpage/script.js ================================================ // Description: Script for the evaluation webpage. let currentQuestionIndex = 1; // Store the model name mapping for later use. modelNameMapping = { "gpt35": "ChatGPT-3.5", "gpt4": "GPT-4", "alpaca": "Alpaca-13b", "vicuna": "Vicuna-13b", "llama": "LLaMA-13b", "bard": "Bard", }; modelFigureMapping = { "vicuna": "figures/vicuna.jpeg", // Image from: https://commons.wikimedia.org/wiki/File:ChatGPT_logo.svg "gpt35": "figures/chatgpt.svg", // Image from: https://www.reddit.com/r/logodesign/comments/1128aat/google_ai_bard_logo_design/ "bard": "figures/bard.jpg", // Image from: https://crfm.stanford.edu/2023/03/13/alpaca.html "alpaca": "figures/alpaca.png", // Image adapted from https://commons.wikimedia.org/wiki/File:Llama_on_Machu_Picchu.jpg "llama": "figures/llama.jpg", } // Store the question data in a mapping for later use. questionMapping = {}; // Store the question ids in a mapping for later use. categoryMapping = {}; // Store the number of questions for later use. 
questionsCount = 0; function text2Markdown(text) { // Normalize the text for markdown rendering. text = text.trim().replaceAll('\n\n', '\n').replaceAll('\n', '\n\n'); return marked.parse(text); } function capitalizeFirstChar(str) { if (!str || str.length === 0) { return str; } return str.charAt(0).toUpperCase() + str.slice(1); } function updateQuestionSelect(question_id) { const select = document.getElementById('question-select'); // Clear the question select. select.innerHTML = ''; // Populate the question select. category = questionMapping[question_id].category; categoryMapping[category].forEach(question_id => { const question = questionMapping[question_id]; const option = document.createElement('option'); option.value = question_id; option.textContent = 'Q' + question_id.toString() + ': ' + question.question; select.appendChild(option); }); select.value = question_id; } function updateModelSelect() { const select = document.getElementById('model-select'); img_path = modelFigureMapping[select.value]; document.getElementById('other-model-figure').src = img_path; } function populateModels(models) { const select = document.getElementById('model-select'); models.forEach(model => { const option = document.createElement('option'); option.value = model; option.textContent = modelNameMapping[model]; select.appendChild(option); }); updateModelSelect(); } function populateQuestions(questions) { const category_select = document.getElementById('category-select'); questionsCount = questions.length; questions.forEach(question => { const option = document.createElement('option'); // Store the question data in a mapping for later use. questionMapping[question.id] = { category: question.category, question: question.question, answers: question.answers, evaluations: question.evaluations, scores: question.scores, }; // Store the question id in the category mapping. 
if (question.category in categoryMapping) { categoryMapping[question.category].push(question.id); } else { categoryMapping[question.category] = [question.id]; const category_option = document.createElement('option'); category_option.value = question.category; category_option.textContent = capitalizeFirstChar(question.category); category_select.appendChild(category_option); } }); // Set the default category. updateQuestionSelect(currentQuestionIndex); } function displayQuestion(index) { const question = questionMapping[index].question; document.getElementById('selected-question').innerHTML = text2Markdown('**Question:** ' + question); displayAnswers(index); } function displayAnswers(index) { const question = questionMapping[index]; const otherModel = document.getElementById('model-select').value; // render the answers with markdown document.getElementById('other-model-answer').innerHTML = text2Markdown(question.answers[otherModel]); document.getElementById('our-model-answer').innerHTML = text2Markdown(question.answers.vicuna); // Display evaluation score = question.scores[otherModel]; score_text = modelNameMapping[otherModel] + " " + score[0] + "/10, Vicuna-13b " + score[1] + "/10"; document.getElementById('evaluation-header').textContent = "GPT-4 Evaluation" + " (Score: " + score_text + ")"; document.getElementById('evaluation-result').innerHTML = text2Markdown(question.evaluations[otherModel]); // Update model names let assistant1_title = "Assistant #1"; // (" + modelNameMapping[otherModel] + ")"; let assistant2_title = "Assistant #2 (Vicuna-13b, our model)"; // Update scores/labels. let assistant1_score_label = score[0].toString() + '/10'; let assistant2_score_label = score[1].toString() + '/10'; const colorRed ='#fa9'; // '#eb978d'; // const colorGreen = '#c9f2c9'; const colorBlue = '#8ef'; // '#71dbf9'; const colorYellow = '#fe7'; // '#fada57'; let otherModelHeaderColor = ''; let ourModelHeaderColor = ''; // Update the winner. 
if (score[0] == score[1]) { assistant1_title = '🏆 ' + assistant1_title; assistant1_score_label = '🏆 ' + assistant1_score_label; assistant2_title = '🏆 ' + assistant2_title; assistant2_score_label = '🏆 ' + assistant2_score_label; otherModelHeaderColor = colorYellow; ourModelHeaderColor = colorYellow; } else if (score[0] > score[1]) { assistant1_title = '🏆 ' + assistant1_title; assistant1_score_label = '🏆 ' + assistant1_score_label; otherModelHeaderColor = colorBlue; ourModelHeaderColor = colorRed; } else if (score[0] < score[1]) { assistant2_title = '🏆 ' + assistant2_title; assistant2_score_label = '🏆 ' + assistant2_score_label; otherModelHeaderColor = colorRed; ourModelHeaderColor = colorBlue; } document.getElementById('other-model-header-bg').style.backgroundColor = otherModelHeaderColor; document.getElementById('our-model-header').style.backgroundColor = ourModelHeaderColor; document.getElementById('other-model-header').textContent = assistant1_title; document.getElementById('our-model-header').textContent = assistant2_title; document.getElementById('other-score-label').textContent = assistant1_score_label; document.getElementById('our-score-label').textContent = assistant2_score_label; // Update expand buttons visibility for both cards after displaying answers // Reset the expanded state and update expand buttons visibility for both cards after displaying answers document.querySelectorAll('.expandable-card').forEach(card => { card.classList.remove('expanded'); updateExpandButtonVisibility(card); const expandBtn = card.querySelector('.expand-btn'); expandBtn.innerHTML = 'keyboard_arrow_down Show more'; // .textContent = 'Show more'; }); } document.getElementById('question-select').addEventListener('change', e => { currentQuestionIndex = parseInt(e.target.value); displayQuestion(currentQuestionIndex); }); document.getElementById('category-select').addEventListener('change', e => { let currentCategory = e.target.value; const questionIds = 
categoryMapping[currentCategory]; currentQuestionIndex = questionIds[0]; updateQuestionSelect(currentQuestionIndex); displayQuestion(currentQuestionIndex); }); // Update expand buttons whenever the model is changed document.getElementById('model-select').addEventListener('change', () => { displayAnswers(currentQuestionIndex); document.querySelectorAll('.expandable-card').forEach(card => { updateExpandButtonVisibility(card); }); updateModelSelect(); }); function switchQuestionAndCategory() { document.getElementById('question-select').value = currentQuestionIndex; old_category = document.getElementById('category-select').value; new_category = questionMapping[currentQuestionIndex].category; if (old_category != new_category) { document.getElementById('category-select').value = new_category; updateQuestionSelect(currentQuestionIndex); } displayQuestion(currentQuestionIndex); } document.getElementById('prev-question').addEventListener('click', () => { // Question index starts from 1. currentQuestionIndex = Math.max(1, currentQuestionIndex - 1); switchQuestionAndCategory(); }); document.getElementById('next-question').addEventListener('click', () => { // Question index starts from 1. currentQuestionIndex = Math.min(questionsCount, currentQuestionIndex + 1); switchQuestionAndCategory(); }); function updateExpandButtonVisibility(card) { const cardTextContainer = card.querySelector('.card-text-container'); const expandBtn = card.querySelector('.expand-btn'); if (cardTextContainer.scrollHeight > cardTextContainer.offsetHeight) { expandBtn.style.display = 'flex'; } else { expandBtn.style.display = 'none'; card.classList.add('expanded'); } } document.querySelectorAll('.expand-btn').forEach(btn => { btn.addEventListener('click', e => { const card = e.target.closest('.expandable-card'); card.classList.toggle('expanded'); const more = 'keyboard_arrow_down Show more'; const less = 'keyboard_arrow_up Show less'; e.target.innerHTML = card.classList.contains('expanded') ? 
less : more; }); }); ================================================ FILE: llava/eval/webpage/styles.css ================================================ body { font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; background-color: #f8f9fa; } .navbar-dark .navbar-nav .nav-link { color: #f1cf68; font-size: 1.1rem; padding: 0.5rem 0.6rem; } .card-header { font-weight: bold; } .card { box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); transition: 0.3s; } .card:hover { box-shadow: 0 8px 16px rgba(0, 0, 0, 0.2); } button { transition: background-color 0.3s; } button:hover { background-color: #007bff; } @media (max-width: 767px) { .form-row .form-group { margin-bottom: 10px; } } /* Extra styles */ .expandable-card .card-text-container { max-height: 200px; overflow-y: hidden; position: relative; } .expandable-card.expanded .card-text-container { max-height: none; } .expand-btn { position: relative; display: none; background-color: rgba(255, 255, 255, 0.8); color: #510c75; border-color: transparent; } .expand-btn:hover { background-color: rgba(200, 200, 200, 0.8); text-decoration: none; border-color: transparent; color: #510c75; } .expand-btn:focus { outline: none; text-decoration: none; } .expandable-card:not(.expanded) .card-text-container:after { content: ""; position: absolute; bottom: 0; left: 0; width: 100%; height: 90px; background: linear-gradient(rgba(255, 255, 255, 0.2), rgba(255, 255, 255, 1)); } .expandable-card:not(.expanded) .expand-btn { margin-top: -40px; } .card-body { padding-bottom: 5px; } .vertical-flex-layout { justify-content: center; align-items: center; height: 100%; display: flex; flex-direction: column; gap: 5px; } .figure-img { max-width: 100%; height: auto; } .adjustable-font-size { font-size: calc(0.5rem + 2vw); } ================================================ FILE: llava/mm_utils.py ================================================ from PIL import Image from io import BytesIO import base64 import torch from transformers import 
StoppingCriteria from llava.constants import IMAGE_TOKEN_INDEX def load_image_from_base64(image): return Image.open(BytesIO(base64.b64decode(image))) def process_images(images, image_processor, model_cfg): return image_processor(images, return_tensors='pt')['pixel_values'] def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None): prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('')] def insert_separator(X, sep): return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1] input_ids = [] offset = 0 if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id: offset = 1 input_ids.append(prompt_chunks[0][0]) for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)): input_ids.extend(x[offset:]) if return_tensors is not None: if return_tensors == 'pt': return torch.tensor(input_ids, dtype=torch.long) raise ValueError(f'Unsupported tensor type: {return_tensors}') return input_ids def tokenizer_image_token_inter(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None): prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('')] def insert_separator(X, sep): return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1] input_ids = [] offset = 0 if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id: offset = 1 input_ids.append(prompt_chunks[0][0]) for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)): input_ids.extend(x[offset:]) if return_tensors is not None: if return_tensors == 'pt': return torch.tensor(input_ids, dtype=torch.long) raise ValueError(f'Unsupported tensor type: {return_tensors}') return input_ids def get_model_name_from_path(model_path): model_path = model_path.strip("/") model_paths = model_path.split("/") if model_paths[-1].startswith('checkpoint-'): return model_paths[-2] + "_" + 
model_paths[-1] else: return model_paths[-1] class KeywordsStoppingCriteria(StoppingCriteria): def __init__(self, keywords, tokenizer, input_ids): self.keywords = keywords self.keyword_ids = [] for keyword in keywords: cur_keyword_ids = tokenizer(keyword).input_ids if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id: cur_keyword_ids = cur_keyword_ids[1:] self.keyword_ids.append(torch.tensor(cur_keyword_ids)) self.tokenizer = tokenizer self.start_len = input_ids.shape[1] def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)" # TODO offset = min(output_ids.shape[1] - self.start_len, 3) self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids] for keyword_id in self.keyword_ids: if output_ids[0, -keyword_id.shape[0]:] == keyword_id: return True outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0] for keyword in self.keywords: if keyword in outputs: return True return False ================================================ FILE: llava/model/__init__.py ================================================ from .language_model.llava_llama_gd import LlavaLlamaForCausalLM,LlavaLlamaForCausalLM_gd,LlavaLlamaForCausalLM_joint,LlavaLlamaForCausalLM_joint_2st, LlavaConfig\ ,LlavaLlamaForCausalLM_joint_2st_it_only_ref_instr from .language_model.llava_mpt import LlavaMPTForCausalLM, LlavaMPTConfig ================================================ FILE: llava/model/apply_delta.py ================================================ """ Usage: python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta """ import argparse import torch from tqdm import tqdm from transformers import AutoTokenizer, AutoModelForCausalLM from llava import LlavaLlamaForCausalLM def apply_delta(base_model_path, target_model_path, delta_path): 
print("Loading base model") base = AutoModelForCausalLM.from_pretrained( base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) print("Loading delta") delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) delta_tokenizer = AutoTokenizer.from_pretrained(delta_path) print("Applying delta") for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"): if name not in base.state_dict(): assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model' continue if param.data.shape == base.state_dict()[name].shape: param.data += base.state_dict()[name] else: assert name in ['model.embed_tokens.weight', 'lm_head.weight'], \ f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}' bparam = base.state_dict()[name] param.data[:bparam.shape[0], :bparam.shape[1]] += bparam print("Saving target model") delta.save_pretrained(target_model_path) delta_tokenizer.save_pretrained(target_model_path) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--base-model-path", type=str, required=True) parser.add_argument("--target-model-path", type=str, required=True) parser.add_argument("--delta-path", type=str, required=True) args = parser.parse_args() apply_delta(args.base_model_path, args.target_model_path, args.delta_path) ================================================ FILE: llava/model/builder.py ================================================ # Copyright 2023 Haotian Liu # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. import os import shutil from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig import torch from llava.model import * from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto"): kwargs = {"device_map": device_map} if load_8bit: kwargs['load_in_8bit'] = True elif load_4bit: kwargs['load_in_4bit'] = True kwargs['quantization_config'] = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type='nf4' ) else: kwargs['torch_dtype'] = torch.float16 if 'llava' in model_name.lower(): # Load LLaVA model if 'lora' in model_name.lower() and model_base is not None: lora_cfg_pretrained = AutoConfig.from_pretrained(model_path) tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) print('Loading LLaVA from base model...') model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs) token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features if model.lm_head.weight.shape[0] != token_num: model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype)) model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype)) print('Loading additional LLaVA weights...') if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')): non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu') else: # this is probably from HF Hub from huggingface_hub import hf_hub_download def load_from_hf(repo_id, filename, subfolder=None): cache_file = hf_hub_download( repo_id=repo_id, 
filename=filename, subfolder=subfolder) return torch.load(cache_file, map_location='cpu') non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin') non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()} if any(k.startswith('model.model.') for k in non_lora_trainables): non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()} model.load_state_dict(non_lora_trainables, strict=False) from peft import PeftModel print('Loading LoRA weights...') model = PeftModel.from_pretrained(model, model_path) print('Merging LoRA weights...') model = model.merge_and_unload() print('Model is loaded...') elif model_base is not None: # this may be mm projector only print('Loading LLaVA from base model...') if 'mpt' in model_name.lower(): if not os.path.isfile(os.path.join(model_path, 'configuration_mpt.py')): shutil.copyfile(os.path.join(model_base, 'configuration_mpt.py'), os.path.join(model_path, 'configuration_mpt.py')) tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True) cfg_pretrained = AutoConfig.from_pretrained(model_path, trust_remote_code=True) model = LlavaMPTForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs) else: tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) cfg_pretrained = AutoConfig.from_pretrained(model_path) model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs) mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu') mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()} model.load_state_dict(mm_projector_weights, strict=False) else: if 'mpt' in model_name.lower(): tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True) model = LlavaMPTForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) else: 
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) else: # Load language model if model_base is not None: # PEFT model from peft import PeftModel tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) model = AutoModelForCausalLM.from_pretrained(model_base, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="auto") print(f"Loading LoRA weights from {model_path}") model = PeftModel.from_pretrained(model, model_path) print(f"Merging weights") model = model.merge_and_unload() print('Convert to FP16...') model.to(torch.float16) else: use_fast = False if 'mpt' in model_name.lower(): tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True) model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs) else: tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) image_processor = None if 'llava' in model_name.lower(): mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False) mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True) if mm_use_im_patch_token: tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) if mm_use_im_start_end: tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) model.resize_token_embeddings(len(tokenizer)) vision_tower = model.get_vision_tower() if not vision_tower.is_loaded: vision_tower.load_model() vision_tower.to(device='cuda', dtype=torch.float16) image_processor = vision_tower.image_processor if hasattr(model.config, "max_sequence_length"): context_len = model.config.max_sequence_length else: context_len = 2048 return tokenizer, model, image_processor, context_len ================================================ FILE: llava/model/consolidate.py 
================================================ """ Usage: python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate """ import argparse import torch from transformers import AutoTokenizer, AutoModelForCausalLM from llava.model import * from llava.model.utils import auto_upgrade def consolidate_ckpt(src_path, dst_path): print("Loading model") auto_upgrade(src_path) src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False) src_model.save_pretrained(dst_path) src_tokenizer.save_pretrained(dst_path) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--src", type=str, required=True) parser.add_argument("--dst", type=str, required=True) args = parser.parse_args() consolidate_ckpt(args.src, args.dst) ================================================ FILE: llava/model/language_model/llava_llama.py ================================================ # Copyright 2023 Haotian Liu # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import detectron2.utils.comm as comm from typing import List, Optional, Tuple, Union import torch import torch.nn as nn from torch.nn import CrossEntropyLoss from transformers import AutoConfig, AutoModelForCausalLM, \ LlamaConfig, LlamaModel, LlamaForCausalLM from transformers.modeling_outputs import CausalLMOutputWithPast from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM class LlavaConfig(LlamaConfig): model_type = "llava" class LlavaLlamaModel(LlavaMetaModel, LlamaModel): config_class = LlavaConfig def __init__(self, config: LlamaConfig): super(LlavaLlamaModel, self).__init__(config) class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM): config_class = LlavaConfig def __init__(self, config): super(LlamaForCausalLM, self).__init__(config) self.model = LlavaLlamaModel(config) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Initialize weights and apply final processing self.post_init() def get_model(self): return self.model def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, images: Optional[torch.FloatTensor] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images) print(f"rank: 
{comm.get_rank()}",1) # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict ) print(f"rank: {comm.get_rank()}",2) hidden_states = outputs[0] logits = self.lm_head(hidden_states) loss = None if labels is not None: # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # Flatten the tokens loss_fct = CrossEntropyLoss() shift_logits = shift_logits.view(-1, self.config.vocab_size) shift_labels = shift_labels.view(-1) # Enable model/pipeline parallelism shift_labels = shift_labels.to(shift_logits.device) loss = loss_fct(shift_logits, shift_labels) print(f"rank: {comm.get_rank()}",2) if not return_dict: print(f"rank: {comm.get_rank()}", 3) output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output print(f"rank: {comm.get_rank()}",4) return CausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): if past_key_values: input_ids = input_ids[:, -1:] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} else: model_inputs = {"input_ids": input_ids} model_inputs.update( { "past_key_values": past_key_values, "use_cache": kwargs.get("use_cache"), "attention_mask": attention_mask, "images": kwargs.get("images", None), } ) return model_inputs AutoConfig.register("llava", LlavaConfig) AutoModelForCausalLM.register(LlavaConfig, 
LlavaLlamaForCausalLM) ================================================ FILE: llava/model/language_model/llava_llama_gd.py ================================================ # Copyright 2023 Haotian Liu # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import List, Optional, Tuple, Union IGNORE_INDEX=-100 import torch import torch.nn as nn from torch.nn import CrossEntropyLoss from transformers import AutoConfig, AutoModelForCausalLM, \ LlamaConfig, LlamaModel, LlamaForCausalLM from transformers.modeling_outputs import CausalLMOutputWithPast from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM, LlavaMetaForCausalLM_gd,LlavaMetaForCausalLM_gd_interactive import transformers # @dataclass class DataCollatorForSupervisedDataset(object): """Collate examples for supervised fine-tuning.""" # tokenizer: transformers.PreTrainedTokenizer def __call__(self, instances,tokenizer): input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels")) input_ids = torch.nn.utils.rnn.pad_sequence( input_ids, batch_first=True, padding_value=tokenizer.pad_token_id) labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX) input_ids = input_ids[:, :tokenizer.model_max_length] labels = labels[:, :tokenizer.model_max_length] batch = dict( input_ids=input_ids, labels=labels, attention_mask=input_ids.ne(tokenizer.pad_token_id), ) if 'image_clip' in instances[0]: images = [instance['image_clip'] for instance in 
instances] if all(x is not None and x.shape == images[0].shape for x in images): batch['images'] = torch.stack(images) else: batch['images'] = images return batch class LlavaConfig(LlamaConfig): model_type = "llava" class LlavaLlamaModel(LlavaMetaModel, LlamaModel): config_class = LlavaConfig def __init__(self, config: LlamaConfig): super(LlavaLlamaModel, self).__init__(config) class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM): config_class = LlavaConfig def __init__(self, config): super(LlamaForCausalLM, self).__init__(config) self.model = LlavaLlamaModel(config) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Initialize weights and apply final processing self.post_init() def get_model(self): return self.model def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, images: Optional[torch.FloatTensor] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images) # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, 
output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict ) hidden_states = outputs[0] logits = self.lm_head(hidden_states) loss = None if labels is not None: # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # Flatten the tokens loss_fct = CrossEntropyLoss() shift_logits = shift_logits.view(-1, self.config.vocab_size) shift_labels = shift_labels.view(-1) # Enable model/pipeline parallelism shift_labels = shift_labels.to(shift_logits.device) loss = loss_fct(shift_logits, shift_labels) if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output return CausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): if past_key_values: input_ids = input_ids[:, -1:] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} else: model_inputs = {"input_ids": input_ids} model_inputs.update( { "past_key_values": past_key_values, "use_cache": kwargs.get("use_cache"), "attention_mask": attention_mask, "images": kwargs.get("images", None), } ) return model_inputs class LlavaLlamaForCausalLM_gd(LlamaForCausalLM, LlavaMetaForCausalLM_gd): config_class = LlavaConfig def __init__(self, config): super(LlamaForCausalLM, self).__init__(config) self.model = LlavaLlamaModel(config) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Initialize weights and apply final processing self.post_init() def get_model(self): return self.model def forward(self,**batched_inputs): # print(kwargs.keys()) # 
images_for_llava=torch.stack([inp['image_clip'] for inp in batched_inputs['flickr']]) collator=DataCollatorForSupervisedDataset() if 'refcoco' in batched_inputs: if 'vg' in batched_inputs: llava_inputs = collator(batched_inputs['vg']+batched_inputs['refcoco'], tokenizer=batched_inputs['refcoco'][0]['tokenizer']) else: llava_inputs = collator( batched_inputs['refcoco'], tokenizer=batched_inputs['refcoco'][0]['tokenizer']) elif 'coco' in batched_inputs: llava_inputs=collator(batched_inputs['flickr']+batched_inputs['coco'],tokenizer=batched_inputs['flickr'][0]['tokenizer']) else: llava_inputs=collator(batched_inputs['flickr'],tokenizer=batched_inputs['flickr'][0]['tokenizer']) llava_inputs['seg_inputs']=batched_inputs return self.forward_inner(**llava_inputs) def forward_inner( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, images: Optional[torch.FloatTensor] = None, seg_inputs: Optional[torch.FloatTensor] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict _, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images) # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values, 
inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict ) ground_idx_coco=[] ground_idx = [torch.argwhere(lb == 32002)[:, 0] for lb in labels] if 'refcoco' in seg_inputs: if 'vg' in seg_inputs: vg_len=len(seg_inputs['vg']) ground_idx_flickr = ground_idx[:vg_len] padded_ground_idx_flickr = torch.nn.utils.rnn.pad_sequence(ground_idx_flickr, batch_first=True, padding_value=-1) padded_mask_flickr = padded_ground_idx_flickr != -1 padded_ground_idx_flickr[padded_ground_idx_flickr == -1] = 0 # ground_idx=[[-1] if len(idx)==0 else idx for idx in ground_idx] hidden_states = outputs[0] hidden_states_flickr = hidden_states[:vg_len] ground_hs_flickr = torch.gather(hidden_states_flickr, 1, padded_ground_idx_flickr[..., None].repeat(1, 1, hidden_states_flickr.shape[ -1])) seg_inputs['vg_text_embeddings'] = (ground_hs_flickr, padded_mask_flickr) flickr_len = len(seg_inputs['refcoco']) ##########flickr # if self.seg_model.model.coco_only: ground_idx_flickr = ground_idx[vg_len:vg_len+flickr_len] if 'vg' in seg_inputs else ground_idx[:flickr_len] padded_ground_idx_flickr = torch.nn.utils.rnn.pad_sequence(ground_idx_flickr, batch_first=True, padding_value=-1) padded_mask_flickr = padded_ground_idx_flickr != -1 padded_ground_idx_flickr[padded_ground_idx_flickr == -1] = 0 # ground_idx=[[-1] if len(idx)==0 else idx for idx in ground_idx] hidden_states = outputs[0] hidden_states_flickr = hidden_states[vg_len:vg_len+flickr_len] if 'vg' in seg_inputs else hidden_states[:flickr_len] ground_hs_flickr = torch.gather(hidden_states_flickr, 1, padded_ground_idx_flickr[..., None].repeat(1, 1, hidden_states_flickr.shape[ -1])) seg_inputs['refcoco_text_embeddings'] = (ground_hs_flickr, padded_mask_flickr) # seg_inputs['flickr']=seg_inputs['refcoco'] else: flickr_len=len(seg_inputs['flickr']) ground_idx = [torch.argwhere(lb == 32002)[:, 0] for lb in labels] zero_mask = [0 if len(idx) == 0 else 1 for idx 
in ground_idx] ##########flickr # if self.seg_model.model.coco_only: ground_idx_flickr=ground_idx[:flickr_len] padded_ground_idx_flickr = torch.nn.utils.rnn.pad_sequence(ground_idx_flickr, batch_first=True, padding_value=-1) padded_mask_flickr=padded_ground_idx_flickr!=-1 padded_ground_idx_flickr[padded_ground_idx_flickr==-1]=0 # ground_idx=[[-1] if len(idx)==0 else idx for idx in ground_idx] hidden_states = outputs[0] hidden_states_flickr=hidden_states[:flickr_len] ground_hs_flickr=torch.gather(hidden_states_flickr,1,padded_ground_idx_flickr[...,None].repeat(1,1,hidden_states_flickr.shape[-1])) seg_inputs['flickr_text_embeddings']=(ground_hs_flickr,padded_mask_flickr) ##########coco ground_idx_coco = ground_idx[flickr_len:] if len(ground_idx_coco)>0: for i,(idx,data) in enumerate(zip(ground_idx_coco,seg_inputs['coco'])): mask=data['grounding_mask'] ground_idx_coco[i]=idx[mask[:len(idx)]] padded_ground_idx_coco = torch.nn.utils.rnn.pad_sequence(ground_idx_coco, batch_first=True, padding_value=-1) padded_mask_coco = padded_ground_idx_coco != -1 padded_ground_idx_coco[padded_ground_idx_coco == -1] = 0 hidden_states = outputs[0] hidden_states_coco = hidden_states[flickr_len:] ground_hs_coco = torch.gather(hidden_states_coco, 1, padded_ground_idx_coco[..., None].repeat(1, 1, hidden_states_coco.shape[ -1])) seg_inputs['coco_text_embeddings'] = (ground_hs_coco, padded_mask_coco) ground_loss=self.seg_model(seg_inputs) if self.seg_model.model.coco_only and len(ground_idx_coco)>0: logits = self.lm_head(hidden_states_coco) else: logits = self.lm_head(hidden_states) loss = None if labels is not None: # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() if self.seg_model.model.coco_only and len(ground_idx_coco) > 0: shift_labels = labels[..., 1:][flickr_len:].contiguous() else: shift_labels = labels[..., 1:].contiguous() # Flatten the tokens loss_fct = CrossEntropyLoss() shift_logits = shift_logits.view(-1, self.config.vocab_size) shift_labels = 
shift_labels.view(-1) # Enable model/pipeline parallelism shift_labels = shift_labels.to(shift_logits.device) loss = loss_fct(shift_logits, shift_labels) if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output ground_loss['llava']=loss ground_loss['loss_total']=sum(ground_loss.values()) return CausalLMOutputWithPast( loss=ground_loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): if past_key_values: input_ids = input_ids[:, -1:] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} else: model_inputs = {"input_ids": input_ids} model_inputs.update( { "past_key_values": past_key_values, "use_cache": kwargs.get("use_cache"), "attention_mask": attention_mask, "images": kwargs.get("images", None), } ) return model_inputs def forward_eval(self, inputs): collator=DataCollatorForSupervisedDataset() llava_inputs=collator(inputs,tokenizer=inputs[0]['tokenizer']) llava_inputs['seg_inputs']=inputs return self.forward_inner_eval(**llava_inputs) def forward_inner_eval( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, images: Optional[torch.FloatTensor] = None, seg_inputs: Optional[torch.FloatTensor] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else 
self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict _, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images) output_ids, seg_hidden_states = self.auto_regressive_generate(attention_mask, past_key_values, inputs_embeds, output_attentions, seg_inputs[0]["tokenizer"], return_dict) output_text = seg_inputs[0]["tokenizer"].batch_decode([output_ids], skip_special_tokens=True)[0] if len(seg_hidden_states)==0: return output_text, [], [] seg_tokens = torch.cat(seg_hidden_states, dim=1) padded_mask = seg_tokens.new_ones(seg_tokens.shape[:2]) > 0 predicted_boxes, predicted_masks=self.seg_model.model.forward_eval(seg_inputs, (seg_tokens,padded_mask)) return output_text, predicted_boxes, predicted_masks def auto_regressive_generate(self, attention_mask, past_key_values, inputs_embeds, output_attentions, tokenizer, return_dict, temporature=0.0 ): ######## # llm_inputs['obj_num'] = False seg_token = tokenizer.encode("")[1] seg_token_list = [] output_ids = [] output_logits = [] length = inputs_embeds.shape[1] for i in range(1000): # import pdb;pdb.set_trace() if i == 0: results = self.model( input_ids=None, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=True, output_attentions=output_attentions, output_hidden_states=True, return_dict=return_dict ) else: attention_mask = cur_hidden.new_ones( 1, past_key_values[0][0].shape[-2] + 1, device="cuda") # print("Attention mask shape: ", attention_mask.shape) results = self.model( input_ids=torch.as_tensor([[cur_id]], device=inputs_embeds.device), attention_mask=attention_mask, past_key_values=past_key_values, # inputs_embeds=cur_hidden, use_cache=True, output_attentions=output_attentions, output_hidden_states=True, 
return_dict=return_dict ) cur_hidden = results.hidden_states[-1][:, -1:] # last layer last token logits = self.lm_head(results[0]) cur_logits = logits[0][-1] cur_id = int(torch.argmax(cur_logits)) if temporature < 1e-4: cur_id = int(torch.argmax(cur_logits)) else: probs = torch.softmax(cur_logits / temporature, dim=-1) cur_id = int(torch.multinomial(probs, num_samples=1)) past_key_values = results.past_key_values length += 1 if cur_id==seg_token: seg_token_list.append(cur_hidden) output_ids.append(cur_id) output_logits.append(cur_logits) if tokenizer.decode(output_ids).find("")!=-1: break return output_ids,seg_token_list class LlavaLlamaForCausalLM_joint(LlavaLlamaForCausalLM_gd): def forward(self,**batched_inputs): # print(kwargs.keys()) # images_for_llava=torch.stack([inp['image_clip'] for inp in batched_inputs['flickr']]) collator=DataCollatorForSupervisedDataset() assert 'refcoco' in batched_inputs and 'flickr' in batched_inputs and 'llava' in batched_inputs for data in batched_inputs['llava']: data['image_clip']=data['image'] llava_inputs = collator( batched_inputs['flickr']+batched_inputs['refcoco']+batched_inputs['llava'], tokenizer=batched_inputs['refcoco'][0]['tokenizer']) # if 'refcoco' in batched_inputs: # llava_inputs = collator( batched_inputs['refcoco'], # tokenizer=batched_inputs['refcoco'][0]['tokenizer']) # elif 'coco' in batched_inputs: # llava_inputs=collator(batched_inputs['flickr']+batched_inputs['coco'],tokenizer=batched_inputs['flickr'][0]['tokenizer']) # else: # llava_inputs=collator(batched_inputs['flickr'],tokenizer=batched_inputs['flickr'][0]['tokenizer']) llava_inputs['seg_inputs']=batched_inputs return self.forward_inner(**llava_inputs) def forward_inner( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, 
output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, images: Optional[torch.FloatTensor] = None, seg_inputs: Optional[torch.FloatTensor] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict _, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images) # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict ) ground_idx_coco=[] # if 'refcoco' in seg_inputs: flickr_len = len(seg_inputs['flickr']) ground_idx = [torch.argwhere(lb == 32002)[:, 0] for lb in labels] ##########flickr # if self.seg_model.model.coco_only: ground_idx_flickr = ground_idx[:flickr_len] padded_ground_idx_flickr = torch.nn.utils.rnn.pad_sequence(ground_idx_flickr, batch_first=True, padding_value=-1) padded_mask_flickr = padded_ground_idx_flickr != -1 padded_ground_idx_flickr[padded_ground_idx_flickr == -1] = 0 # ground_idx=[[-1] if len(idx)==0 else idx for idx in ground_idx] hidden_states = outputs[0] hidden_states_flickr = hidden_states[:flickr_len] ground_hs_flickr = torch.gather(hidden_states_flickr, 1, padded_ground_idx_flickr[..., None].repeat(1, 1, hidden_states_flickr.shape[ -1])) seg_inputs['flickr_text_embeddings'] = (ground_hs_flickr, padded_mask_flickr) # seg_inputs['flickr']=seg_inputs['refcoco'] # else: 
################################################# ################################################# refcoco_len=len(seg_inputs['refcoco']) ground_idx = [torch.argwhere(lb == 32002)[:, 0] for lb in labels] ##########flickr ground_idx_refcoco=ground_idx[flickr_len:flickr_len+refcoco_len] padded_ground_idx_refcoco = torch.nn.utils.rnn.pad_sequence(ground_idx_refcoco, batch_first=True, padding_value=-1) padded_mask_refcoco=padded_ground_idx_refcoco!=-1 padded_ground_idx_refcoco[padded_ground_idx_refcoco==-1]=0 # ground_idx=[[-1] if len(idx)==0 else idx for idx in ground_idx] # hidden_states = outputs[0] hidden_states_refcoco=hidden_states[flickr_len:flickr_len+refcoco_len] ground_hs_refcoco=torch.gather(hidden_states_refcoco,1,padded_ground_idx_refcoco[...,None].repeat(1,1,hidden_states_refcoco.shape[-1])) seg_inputs['refcoco_text_embeddings']=(ground_hs_refcoco,padded_mask_refcoco) ground_loss=self.seg_model(seg_inputs) # if self.seg_model.model.coco_only and len(ground_idx_coco)>0: # logits = self.lm_head(hidden_states_coco) # else: logits = self.lm_head(hidden_states) loss = None if labels is not None: # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() if self.seg_model.model.coco_only and len(ground_idx_coco) > 0: shift_labels = labels[..., 1:][flickr_len:].contiguous() else: shift_labels = labels[..., 1:].contiguous() # Flatten the tokens loss_fct = CrossEntropyLoss() shift_logits = shift_logits.view(-1, self.config.vocab_size) shift_labels = shift_labels.view(-1) # Enable model/pipeline parallelism shift_labels = shift_labels.to(shift_logits.device) loss = loss_fct(shift_logits, shift_labels) if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output ground_loss['llava']=loss ground_loss['loss_total']=sum(ground_loss.values()) return CausalLMOutputWithPast( loss=ground_loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, 
attentions=outputs.attentions, ) class LlavaLlamaForCausalLM_joint_2st(LlavaLlamaForCausalLM_gd): def forward(self,**batched_inputs): # print(kwargs.keys()) # images_for_llava=torch.stack([inp['image_clip'] for inp in batched_inputs['flickr']]) collator=DataCollatorForSupervisedDataset() assert 'coco' in batched_inputs and 'flickr' in batched_inputs and 'llava' in batched_inputs for data in batched_inputs['llava']: data['image_clip']=data['image'] llava_inputs = collator( batched_inputs['flickr']+batched_inputs['coco']+batched_inputs['llava'], tokenizer=batched_inputs['coco'][0]['tokenizer']) # if 'refcoco' in batched_inputs: # llava_inputs = collator( batched_inputs['refcoco'], # tokenizer=batched_inputs['refcoco'][0]['tokenizer']) # elif 'coco' in batched_inputs: # llava_inputs=collator(batched_inputs['flickr']+batched_inputs['coco'],tokenizer=batched_inputs['flickr'][0]['tokenizer']) # else: # llava_inputs=collator(batched_inputs['flickr'],tokenizer=batched_inputs['flickr'][0]['tokenizer']) llava_inputs['seg_inputs']=batched_inputs return self.forward_inner(**llava_inputs) def forward_inner( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, images: Optional[torch.FloatTensor] = None, seg_inputs: Optional[torch.FloatTensor] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict _, attention_mask, past_key_values, 
inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images) # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict ) flickr_len = len(seg_inputs['flickr']) ground_idx = [torch.argwhere(lb == 32002)[:, 0] for lb in labels] ##########flickr # if self.seg_model.model.coco_only: ground_idx_flickr = ground_idx[:flickr_len] padded_ground_idx_flickr = torch.nn.utils.rnn.pad_sequence(ground_idx_flickr, batch_first=True, padding_value=-1) padded_mask_flickr = padded_ground_idx_flickr != -1 padded_ground_idx_flickr[padded_ground_idx_flickr == -1] = 0 # ground_idx=[[-1] if len(idx)==0 else idx for idx in ground_idx] if self.seg_model.model.detach_seg: hidden_states = outputs[0].detach() else: hidden_states = outputs[0] hidden_states_flickr = hidden_states[:flickr_len] ground_hs_flickr = torch.gather(hidden_states_flickr, 1, padded_ground_idx_flickr[..., None].repeat(1, 1, hidden_states_flickr.shape[ -1])) seg_inputs['flickr_text_embeddings'] = (ground_hs_flickr, padded_mask_flickr) ##########coco coco_len = len(seg_inputs['coco']) ground_idx_coco = ground_idx[flickr_len:flickr_len+coco_len] if len(ground_idx_coco) > 0: for i, (idx, data) in enumerate(zip(ground_idx_coco, seg_inputs['coco'])): mask = data['grounding_mask'] ground_idx_coco[i] = idx[mask[:len(idx)]] padded_ground_idx_coco = torch.nn.utils.rnn.pad_sequence(ground_idx_coco, batch_first=True, padding_value=-1) padded_mask_coco = padded_ground_idx_coco != -1 padded_ground_idx_coco[padded_ground_idx_coco == -1] = 0 # hidden_states = outputs[0] hidden_states_coco = hidden_states[flickr_len:flickr_len+coco_len] ground_hs_coco = torch.gather(hidden_states_coco, 1, 
padded_ground_idx_coco[..., None].repeat(1, 1, hidden_states_coco.shape[ -1])) seg_inputs['coco_text_embeddings'] = (ground_hs_coco, padded_mask_coco) ground_loss = self.seg_model(seg_inputs) hidden_states_ = outputs[0] if self.seg_model.model.coco_only and len(ground_idx_coco) > 0: logits = self.lm_head(hidden_states_[flickr_len:]) else: logits = self.lm_head(hidden_states_) ############################################################ loss = None if labels is not None: # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() if self.seg_model.model.coco_only and len(ground_idx_coco) > 0: shift_labels = labels[..., 1:][flickr_len:].contiguous() else: shift_labels = labels[..., 1:].contiguous() # Flatten the tokens loss_fct = CrossEntropyLoss() shift_logits = shift_logits.view(-1, self.config.vocab_size) shift_labels = shift_labels.view(-1) # Enable model/pipeline parallelism shift_labels = shift_labels.to(shift_logits.device) loss = loss_fct(shift_logits, shift_labels) if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output ground_loss['llava']=loss ground_loss['loss_total']=sum(ground_loss.values()) ignore_list=[f'_{i}' for i in range(1,10)] ignore_list.append('interm') for key in list(ground_loss.keys()): if not key.endswith('_0') and key!='llava' and key !='loss_total': ground_loss.pop(key) return CausalLMOutputWithPast( loss=ground_loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) class LlavaLlamaForCausalLM_joint_2st_it_only_ref_instr(LlamaForCausalLM, LlavaMetaForCausalLM_gd_interactive): config_class = LlavaConfig def __init__(self, config): super(LlamaForCausalLM, self).__init__(config) self.model = LlavaLlamaModel(config) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Initialize weights and apply final processing self.post_init() def get_model(self): return 
self.model def forward(self,**batched_inputs): # print(kwargs.keys()) # images_for_llava=torch.stack([inp['image_clip'] for inp in batched_inputs['flickr']]) collator=DataCollatorForSupervisedDataset() # assert 'coco' in batched_inputs and 'flickr' in batched_inputs and 'llava' in batched_inputs and 'interactive' in batched_inputs # for data in batched_inputs['llava']: # data['image_clip']=data['image'] llava_inputs = collator( batched_inputs['interactive'], tokenizer=batched_inputs['interactive'][0]['tokenizer']) # if 'refcoco' in batched_inputs: # llava_inputs = collator( batched_inputs['refcoco'], # tokenizer=batched_inputs['refcoco'][0]['tokenizer']) # elif 'coco' in batched_inputs: # llava_inputs=collator(batched_inputs['flickr']+batched_inputs['coco'],tokenizer=batched_inputs['flickr'][0]['tokenizer']) # else: # llava_inputs=collator(batched_inputs['flickr'],tokenizer=batched_inputs['flickr'][0]['tokenizer']) llava_inputs['seg_inputs']=batched_inputs res1= self.forward_inner(**llava_inputs) loss_dict1=res1.loss prefix1='coco.' res1.loss=res1['loss']={prefix1+k:v for k,v in loss_dict1.items()} if 'interactiveref' in batched_inputs: llava_inputs = collator( batched_inputs['interactiveref'], tokenizer=batched_inputs['interactive'][0]['tokenizer']) batched_inputs['interactive']=batched_inputs['interactiveref'] llava_inputs['seg_inputs']=batched_inputs res2= self.forward_inner(**llava_inputs) loss_dict2=res2.loss prefix2='refcoco.' 
res2.loss=res2['loss']={prefix2+k:v for k,v in loss_dict2.items()} res1.loss.update(res2.loss) res1.loss['loss_total']=res1.loss['coco.loss_total']+res1.loss['refcoco.loss_total'] else: res1.loss['loss_total'] = res1.loss['coco.loss_total'] return res1 def forward_inner( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, images: Optional[torch.FloatTensor] = None, seg_inputs: Optional[torch.FloatTensor] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict obj_feats,inter_losses=self.interactive_model.float().forward(seg_inputs['interactive'],detach=False) obj_feats=[obj_feats[i][seg_inputs['interactive'][i]['grounding_index']][None] for i in range(len(obj_feats))] num_it=len(seg_inputs['interactive']) _, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images,obj_feats=obj_feats,num_it=num_it) # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict ) hidden_states_ = outputs[0] logits = self.lm_head(hidden_states_) loss = None if labels is not 
None: # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() # if self.seg_model.model.coco_only and len(ground_idx_coco) > 0: # shift_labels = labels[..., 1:][flickr_len:].contiguous() # else: shift_labels = labels[..., 1:].contiguous() # Flatten the tokens loss_fct = CrossEntropyLoss() shift_logits = shift_logits.view(-1, self.config.vocab_size) shift_labels = shift_labels.view(-1) # Enable model/pipeline parallelism shift_labels = shift_labels.to(shift_logits.device) loss = loss_fct(shift_logits, shift_labels) if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output ground_loss=dict() ground_loss['llava']=loss # for k,v in inter_losses.items(): # print(v.dtype) inter_losses={k:inter_losses[k].to(float) for k in inter_losses.keys()} ground_loss.update(inter_losses) # import pdb;pdb.set_trace() ground_loss['loss_total']=sum(ground_loss.values()) ignore_list=[f'_{i}' for i in range(1,10)] ignore_list.append('interm') for key in list(ground_loss.keys()): if not key.endswith('_0') and key!='llava' and key !='loss_total': ground_loss.pop(key) return CausalLMOutputWithPast( loss=ground_loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) def forward_eval(self, batched_inputs): if not (batched_inputs[0]["points"] is None): print("Get Interactive Data.") collator=DataCollatorForSupervisedDataset() llava_inputs=collator(batched_inputs,tokenizer=batched_inputs[0]['tokenizer']) llava_inputs['seg_inputs']=batched_inputs if "temporature" in batched_inputs[0].keys(): llava_inputs["temporature"] = batched_inputs[0]["temporature"] else: llava_inputs["temporature"] = 0 return self.forward_inner_eval_interactive(**llava_inputs) else: print("Do not Get Interactive Data.") collator=DataCollatorForSupervisedDataset() llava_inputs=collator(batched_inputs,tokenizer=batched_inputs[0]['tokenizer']) 
llava_inputs['seg_inputs']=batched_inputs if "temporature" in batched_inputs[0].keys(): llava_inputs["temporature"] = batched_inputs[0]["temporature"] else: llava_inputs["temporature"] = 0 return self.forward_inner_eval(**llava_inputs) def forward_inner_eval( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, images: Optional[torch.FloatTensor] = None, seg_inputs: Optional[torch.FloatTensor] = None, return_dict: Optional[bool] = None, temporature=0 ) -> Union[Tuple, CausalLMOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict _, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal_NoInter(input_ids, attention_mask, past_key_values, labels, images) output_ids, seg_hidden_states = self.auto_regressive_generate(attention_mask, past_key_values, inputs_embeds, output_attentions, seg_inputs[0]["tokenizer"], return_dict, temporature) output_text = seg_inputs[0]["tokenizer"].batch_decode([output_ids], skip_special_tokens=True)[0] if len(seg_hidden_states)==0: return output_text, [], [], None seg_tokens = torch.cat(seg_hidden_states, dim=1) padded_mask = seg_tokens.new_ones(seg_tokens.shape[:2]) > 0 predicted_boxes, predicted_masks=self.seg_model.model.forward_eval(seg_inputs, (seg_tokens,padded_mask)) return output_text, predicted_boxes, predicted_masks, None def forward_inner_eval_interactive( self, input_ids: torch.LongTensor = None, 
attention_mask: Optional[torch.Tensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, images: Optional[torch.FloatTensor] = None, seg_inputs: Optional[torch.FloatTensor] = None, return_dict: Optional[bool] = None, temporature=0 ) -> Union[Tuple, CausalLMOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict #! extra interaction part boxes = seg_inputs[0]['points'] seg_inputs[0]['targets'] = [dict()] seg_inputs[0]['targets'][0]['points'] = boxes if seg_inputs[0]['mode_inter'].lower() == "click": seg_inputs[0]['targets'][0]['pb'] = boxes.new_tensor([0.0]) elif seg_inputs[0]['mode_inter'].lower() == "box": seg_inputs[0]['targets'][0]['pb'] = boxes.new_tensor([1.0]) seg_inputs[0]['targets'][0]['is_part'] = [0] inter_masks, _, obj_feats =self.interactive_model.forward(seg_inputs) num_it=len(seg_inputs) # _, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images, obj_feats=obj_feats,num_it=num_it) output_ids, seg_hidden_states = self.auto_regressive_generate(attention_mask, past_key_values, inputs_embeds, output_attentions, seg_inputs[0]["tokenizer"], return_dict, temporature) output_text = seg_inputs[0]["tokenizer"].batch_decode([output_ids], skip_special_tokens=True) if len(seg_hidden_states)==0: return output_text[0], [], None, inter_masks seg_tokens = torch.cat(seg_hidden_states, dim=1) padded_mask = seg_tokens.new_ones(seg_tokens.shape[:2]) > 0 predicted_boxes, 
predicted_masks=self.seg_model.model.forward_eval(seg_inputs, (seg_tokens,padded_mask)) return output_text[0], predicted_boxes, predicted_masks, inter_masks def auto_regressive_generate(self, attention_mask, past_key_values, inputs_embeds, output_attentions, tokenizer, return_dict, temporature=0.0 ): ######## # llm_inputs['obj_num'] = False seg_token = tokenizer.encode("")[1] seg_token_list = [] output_ids = [] output_logits = [] length = inputs_embeds.shape[1] for i in range(1000): # import pdb;pdb.set_trace() if i == 0: results = self.model( input_ids=None, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=True, output_attentions=output_attentions, output_hidden_states=True, return_dict=return_dict ) else: attention_mask = cur_hidden.new_ones( 1, past_key_values[0][0].shape[-2] + 1, device="cuda") # print("Attention mask shape: ", attention_mask.shape) results = self.model( input_ids=torch.as_tensor([[cur_id]], device=inputs_embeds.device), attention_mask=attention_mask, past_key_values=past_key_values, # inputs_embeds=cur_hidden, use_cache=True, output_attentions=output_attentions, output_hidden_states=True, return_dict=return_dict ) cur_hidden = results.hidden_states[-1][:, -1:] # last layer last token logits = self.lm_head(results[0]) cur_logits = logits[0][-1] cur_id = int(torch.argmax(cur_logits)) if temporature < 1e-4: cur_id = int(torch.argmax(cur_logits)) else: probs = torch.softmax(cur_logits / temporature, dim=-1) cur_id = int(torch.multinomial(probs, num_samples=1)) past_key_values = results.past_key_values length += 1 if cur_id==seg_token: seg_token_list.append(cur_hidden) output_ids.append(cur_id) output_logits.append(cur_logits) if tokenizer.decode(output_ids).find("")!=-1: break return output_ids,seg_token_list AutoConfig.register("llava", LlavaConfig) AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM) AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM_gd) AutoModelForCausalLM.register(LlavaConfig, 
LlavaLlamaForCausalLM_joint) AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM_joint_2st) AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM_joint_2st_it_only_ref_instr) ================================================ FILE: llava/model/language_model/llava_mpt.py ================================================ # Copyright 2023 Haotian Liu # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import List, Optional, Tuple import warnings import torch import torch.nn.functional as F import math from transformers import AutoConfig, AutoModelForCausalLM from transformers.modeling_outputs import CausalLMOutputWithPast from .mpt.modeling_mpt import MPTConfig, MPTForCausalLM, MPTModel from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM class LlavaMPTConfig(MPTConfig): model_type = "llava_mpt" class LlavaMPTModel(LlavaMetaModel, MPTModel): config_class = LlavaMPTConfig def __init__(self, config: MPTConfig): config.hidden_size = config.d_model super(LlavaMPTModel, self).__init__(config) def embed_tokens(self, x): return self.wte(x) class LlavaMPTForCausalLM(MPTForCausalLM, LlavaMetaForCausalLM): config_class = LlavaMPTConfig supports_gradient_checkpointing = True def __init__(self, config): super(MPTForCausalLM, self).__init__(config) if not config.tie_word_embeddings: raise ValueError('MPTForCausalLM only supports tied word embeddings') self.transformer = LlavaMPTModel(config) self.logit_scale = None if config.logit_scale is not 
None: logit_scale = config.logit_scale if isinstance(logit_scale, str): if logit_scale == 'inv_sqrt_d_model': logit_scale = 1 / math.sqrt(config.d_model) else: raise ValueError(f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.") self.logit_scale = logit_scale def get_model(self): return self.transformer def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, LlavaMPTModel): module.gradient_checkpointing = value def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, images=None): return_dict = return_dict if return_dict is not None else self.config.return_dict use_cache = use_cache if use_cache is not None else self.config.use_cache input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images) outputs = self.transformer(input_ids=input_ids, inputs_embeds=inputs_embeds, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache) # FIXME: this is a hack to fix the multiple gpu inference issue in https://github.com/haotian-liu/LLaVA/issues/338 logits = F.linear(outputs.last_hidden_state.to(self.transformer.wte.weight.device), self.transformer.wte.weight) if self.logit_scale is not None: if self.logit_scale == 0: warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. 
This will produce uniform (uninformative) outputs.') logits *= self.logit_scale loss = None if labels is not None: labels = torch.roll(labels, shifts=-1) labels[:, -1] = -100 loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1)) return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states) def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): if inputs_embeds is not None: raise NotImplementedError('inputs_embeds is not implemented for MPT yet') attention_mask = kwargs['attention_mask'].bool() if attention_mask[:, -1].sum() != attention_mask.shape[0]: raise NotImplementedError('MPT does not support generation with right padding.') if self.transformer.attn_uses_sequence_id and self.training: sequence_id = torch.zeros_like(input_ids[:1]) else: sequence_id = None if past_key_values is not None: input_ids = input_ids[:, -1].unsqueeze(-1) if self.transformer.prefix_lm: prefix_mask = torch.ones_like(attention_mask) if kwargs.get('use_cache') == False: raise NotImplementedError('MPT with prefix_lm=True does not support use_cache=False.') else: prefix_mask = None return {'input_ids': input_ids, 'attention_mask': attention_mask, 'prefix_mask': prefix_mask, 'sequence_id': sequence_id, 'past_key_values': past_key_values, 'use_cache': kwargs.get('use_cache', True), "images": kwargs.get("images", None)} AutoConfig.register("llava_mpt", LlavaMPTConfig) AutoModelForCausalLM.register(LlavaMPTConfig, LlavaMPTForCausalLM) ================================================ FILE: llava/model/language_model/mpt/adapt_tokenizer.py ================================================ from typing import Union from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] NUM_SENTINEL_TOKENS: int = 100 def adapt_tokenizer_for_denoising(tokenizer: 
Tokenizer): """Adds sentinel tokens and padding token (if missing). Expands the tokenizer vocabulary to include sentinel tokens used in mixture-of-denoiser tasks as well as a padding token. All added tokens are added as special tokens. No tokens are added if sentinel tokens and padding token already exist. """ sentinels_to_add = [f'' for i in range(NUM_SENTINEL_TOKENS)] tokenizer.add_tokens(sentinels_to_add, special_tokens=True) if tokenizer.pad_token is None: tokenizer.add_tokens('', special_tokens=True) tokenizer.pad_token = '' assert tokenizer.pad_token_id is not None sentinels = ''.join([f'' for i in range(NUM_SENTINEL_TOKENS)]) _sentinel_token_ids = tokenizer(sentinels, add_special_tokens=False).input_ids tokenizer.sentinel_token_ids = _sentinel_token_ids class AutoTokenizerForMOD(AutoTokenizer): """AutoTokenizer + Adaptation for MOD. A simple wrapper around AutoTokenizer to make instantiating an MOD-adapted tokenizer a bit easier. MOD-adapted tokenizers have sentinel tokens (e.g., ), a padding token, and a property to get the token ids of the sentinel tokens. 
""" @classmethod def from_pretrained(cls, *args, **kwargs): """See `AutoTokenizer.from_pretrained` docstring.""" tokenizer = super().from_pretrained(*args, **kwargs) adapt_tokenizer_for_denoising(tokenizer) return tokenizer ================================================ FILE: llava/model/language_model/mpt/attention.py ================================================ """Attention layers.""" import math import warnings from typing import Optional import torch import torch.nn as nn from einops import rearrange from packaging import version from torch import nn from .norm import LPLayerNorm def _reset_is_causal(num_query_tokens: int, num_key_tokens: int, original_is_causal: bool): if original_is_causal and num_query_tokens != num_key_tokens: if num_query_tokens != 1: raise NotImplementedError('MPT does not support query and key with different number of tokens, unless number of query tokens is 1.') else: return False return original_is_causal def scaled_multihead_dot_product_attention(query, key, value, n_heads, past_key_value=None, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False): q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads) kv_n_heads = 1 if multiquery else n_heads k = rearrange(key, 'b s (h d) -> b h d s', h=kv_n_heads) v = rearrange(value, 'b s (h d) -> b h s d', h=kv_n_heads) if past_key_value is not None: if len(past_key_value) != 0: k = torch.cat([past_key_value[0], k], dim=3) v = torch.cat([past_key_value[1], v], dim=2) past_key_value = (k, v) (b, _, s_q, d) = q.shape s_k = k.size(-1) if softmax_scale is None: softmax_scale = 1 / math.sqrt(d) attn_weight = q.matmul(k) * softmax_scale if attn_bias is not None: _s_q = max(0, attn_bias.size(2) - s_q) _s_k = max(0, attn_bias.size(3) - s_k) attn_bias = attn_bias[:, :, _s_q:, _s_k:] if attn_bias.size(-1) != 1 and attn_bias.size(-1) != s_k or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q): raise 
RuntimeError(f'attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}.') attn_weight = attn_weight + attn_bias min_val = torch.finfo(q.dtype).min if key_padding_mask is not None: if attn_bias is not None: warnings.warn('Propogating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unneccessary computation/memory usage. Consider integrating ' + 'into attn_bias once and passing that to each attention ' + 'module instead.') attn_weight = attn_weight.masked_fill(~key_padding_mask.view((b, 1, 1, s_k)), min_val) if is_causal and (not q.size(2) == 1): s = max(s_q, s_k) causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16) causal_mask = causal_mask.tril() causal_mask = causal_mask.to(torch.bool) causal_mask = ~causal_mask causal_mask = causal_mask[-s_q:, -s_k:] attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k), min_val) attn_weight = torch.softmax(attn_weight, dim=-1) if dropout_p: attn_weight = torch.nn.functional.dropout(attn_weight, p=dropout_p, training=training, inplace=True) out = attn_weight.to(v.dtype).matmul(v) out = rearrange(out, 'b h s d -> b s (h d)') if needs_weights: return (out, attn_weight, past_key_value) return (out, None, past_key_value) def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]): for tensor in tensors: if tensor.dtype not in valid_dtypes: raise TypeError(f'tensor.dtype={tensor.dtype!r} must be in valid_dtypes={valid_dtypes!r}.') if not tensor.is_cuda: raise TypeError(f'Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r}).') def flash_attn_fn(query, key, value, n_heads, past_key_value=None, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False): try: from flash_attn import bert_padding, flash_attn_interface except: raise RuntimeError('Please install flash-attn==1.0.3.post0') 
check_valid_inputs(query, key, value) if past_key_value is not None: if len(past_key_value) != 0: key = torch.cat([past_key_value[0], key], dim=1) value = torch.cat([past_key_value[1], value], dim=1) past_key_value = (key, value) if attn_bias is not None: _s_q = max(0, attn_bias.size(2) - query.size(1)) _s_k = max(0, attn_bias.size(3) - key.size(1)) attn_bias = attn_bias[:, :, _s_q:, _s_k:] if attn_bias is not None: raise NotImplementedError(f'attn_bias not implemented for flash attn.') (batch_size, seqlen) = query.shape[:2] if key_padding_mask is None: key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool) query_padding_mask = key_padding_mask[:, -query.size(1):] (query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input(query, query_padding_mask) query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads) (key_unpad, _, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input(key, key_padding_mask) key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=1 if multiquery else n_heads) (value_unpad, _, _, _) = bert_padding.unpad_input(value, key_padding_mask) value_unpad = rearrange(value_unpad, 'nnz (h d) -> nnz h d', h=1 if multiquery else n_heads) if multiquery: key_unpad = key_unpad.expand(key_unpad.size(0), n_heads, key_unpad.size(-1)) value_unpad = value_unpad.expand(value_unpad.size(0), n_heads, value_unpad.size(-1)) dropout_p = dropout_p if training else 0.0 reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal) output_unpad = flash_attn_interface.flash_attn_unpadded_func(query_unpad, key_unpad, value_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale=softmax_scale, causal=reset_is_causal, return_attn_probs=needs_weights) output = bert_padding.pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size, seqlen) return (output, None, past_key_value) def triton_flash_attn_fn(query, key, value, n_heads, past_key_value=None, 
softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False): try: from .flash_attn_triton import flash_attn_func except: _installed = False if version.parse(torch.__version__) < version.parse('2.0.0'): _installed = True try: from flash_attn.flash_attn_triton import flash_attn_func except: _installed = False if not _installed: raise RuntimeError('Requirements for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU and `pip install .[gpu]` if installing from llm-foundry source or `pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). Note: (1) requires you have CMake and PyTorch already installed.') check_valid_inputs(query, key, value) if past_key_value is not None: if len(past_key_value) != 0: key = torch.cat([past_key_value[0], key], dim=1) value = torch.cat([past_key_value[1], value], dim=1) past_key_value = (key, value) if attn_bias is not None: _s_q = max(0, attn_bias.size(2) - query.size(1)) _s_k = max(0, attn_bias.size(3) - key.size(1)) attn_bias = attn_bias[:, :, _s_q:, _s_k:] if dropout_p: raise NotImplementedError(f'Dropout not implemented for attn_impl: triton.') if needs_weights: raise NotImplementedError(f'attn_impl: triton cannot return attn weights.') if key_padding_mask is not None: warnings.warn('Propagating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unnecessary computation/memory usage. 
Consider integrating ' + 'into attn_bias once and passing that to each attention ' + 'module instead.') (b_size, s_k) = key_padding_mask.shape[:2] if attn_bias is None: attn_bias = query.new_zeros(b_size, 1, 1, s_k) attn_bias = attn_bias.masked_fill(~key_padding_mask.view((b_size, 1, 1, s_k)), torch.finfo(query.dtype).min) query = rearrange(query, 'b s (h d) -> b s h d', h=n_heads) key = rearrange(key, 'b s (h d) -> b s h d', h=1 if multiquery else n_heads) value = rearrange(value, 'b s (h d) -> b s h d', h=1 if multiquery else n_heads) if multiquery: key = key.expand(*key.shape[:2], n_heads, key.size(-1)) value = value.expand(*value.shape[:2], n_heads, value.size(-1)) reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal) attn_output = flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale) output = attn_output.view(*attn_output.shape[:2], -1) return (output, None, past_key_value) class MultiheadAttention(nn.Module): """Multi-head self attention. Using torch or triton attention implemetation enables user to also use additive bias. 
""" def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, low_precision_layernorm: bool=False, verbose: int=0, device: Optional[str]=None): super().__init__() self.attn_impl = attn_impl self.clip_qkv = clip_qkv self.qk_ln = qk_ln self.d_model = d_model self.n_heads = n_heads self.softmax_scale = softmax_scale if self.softmax_scale is None: self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads) self.attn_dropout_p = attn_pdrop self.Wqkv = nn.Linear(self.d_model, 3 * self.d_model, device=device) fuse_splits = (d_model, 2 * d_model) self.Wqkv._fused = (0, fuse_splits) if self.qk_ln: layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm self.q_ln = layernorm_class(self.d_model, device=device) self.k_ln = layernorm_class(self.d_model, device=device) if self.attn_impl == 'flash': self.attn_fn = flash_attn_fn elif self.attn_impl == 'triton': self.attn_fn = triton_flash_attn_fn if verbose: warnings.warn('While `attn_impl: triton` can be faster than `attn_impl: flash` ' + 'it uses more memory. When training larger models this can trigger ' + 'alloc retries which hurts performance. If encountered, we recommend ' + 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.') elif self.attn_impl == 'torch': self.attn_fn = scaled_multihead_dot_product_attention if torch.cuda.is_available() and verbose: warnings.warn('Using `attn_impl: torch`. 
If your model does not use `alibi` or ' + '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' + 'we recommend using `attn_impl: triton`.') else: raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.') self.out_proj = nn.Linear(self.d_model, self.d_model, device=device) self.out_proj._is_residual = True def forward(self, x, past_key_value=None, attn_bias=None, attention_mask=None, is_causal=True, needs_weights=False): qkv = self.Wqkv(x) if self.clip_qkv: qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv) (query, key, value) = qkv.chunk(3, dim=2) key_padding_mask = attention_mask if self.qk_ln: dtype = query.dtype query = self.q_ln(query).to(dtype) key = self.k_ln(key).to(dtype) (context, attn_weights, past_key_value) = self.attn_fn(query, key, value, self.n_heads, past_key_value=past_key_value, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights) return (self.out_proj(context), attn_weights, past_key_value) class MultiQueryAttention(nn.Module): """Multi-Query self attention. Using torch or triton attention implemetation enables user to also use additive bias. 
""" def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, low_precision_layernorm: bool=False, verbose: int=0, device: Optional[str]=None): super().__init__() self.attn_impl = attn_impl self.clip_qkv = clip_qkv self.qk_ln = qk_ln self.d_model = d_model self.n_heads = n_heads self.head_dim = d_model // n_heads self.softmax_scale = softmax_scale if self.softmax_scale is None: self.softmax_scale = 1 / math.sqrt(self.head_dim) self.attn_dropout_p = attn_pdrop self.Wqkv = nn.Linear(d_model, d_model + 2 * self.head_dim, device=device) fuse_splits = (d_model, d_model + self.head_dim) self.Wqkv._fused = (0, fuse_splits) if self.qk_ln: layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm self.q_ln = layernorm_class(d_model, device=device) self.k_ln = layernorm_class(self.head_dim, device=device) if self.attn_impl == 'flash': self.attn_fn = flash_attn_fn elif self.attn_impl == 'triton': self.attn_fn = triton_flash_attn_fn if verbose: warnings.warn('While `attn_impl: triton` can be faster than `attn_impl: flash` ' + 'it uses more memory. When training larger models this can trigger ' + 'alloc retries which hurts performance. If encountered, we recommend ' + 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.') elif self.attn_impl == 'torch': self.attn_fn = scaled_multihead_dot_product_attention if torch.cuda.is_available() and verbose: warnings.warn('Using `attn_impl: torch`. 
If your model does not use `alibi` or ' + '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' + 'we recommend using `attn_impl: triton`.') else: raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.') self.out_proj = nn.Linear(self.d_model, self.d_model, device=device) self.out_proj._is_residual = True def forward(self, x, past_key_value=None, attn_bias=None, attention_mask=None, is_causal=True, needs_weights=False): qkv = self.Wqkv(x) if self.clip_qkv: qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv) (query, key, value) = qkv.split([self.d_model, self.head_dim, self.head_dim], dim=2) key_padding_mask = attention_mask if self.qk_ln: dtype = query.dtype query = self.q_ln(query).to(dtype) key = self.k_ln(key).to(dtype) (context, attn_weights, past_key_value) = self.attn_fn(query, key, value, self.n_heads, past_key_value=past_key_value, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights, multiquery=True) return (self.out_proj(context), attn_weights, past_key_value) def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id): if attn_impl == 'flash': return None elif attn_impl in ['torch', 'triton']: if alibi: if (prefix_lm or not causal) or use_sequence_id: return (1, n_heads, seq_len, seq_len) return (1, n_heads, 1, seq_len) elif prefix_lm or use_sequence_id: return (1, 1, seq_len, seq_len) return None else: raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.') def build_attn_bias(attn_impl, attn_bias, n_heads, seq_len, causal=False, alibi=False, alibi_bias_max=8): if attn_impl == 'flash': return None elif attn_impl in ['torch', 'triton']: if alibi: (device, dtype) = (attn_bias.device, attn_bias.dtype) attn_bias = attn_bias.add(build_alibi_bias(n_heads, seq_len, full=not causal, alibi_bias_max=alibi_bias_max, device=device, dtype=dtype)) return attn_bias 
else: raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.') def gen_slopes(n_heads, alibi_bias_max=8, device=None): _n_heads = 2 ** math.ceil(math.log2(n_heads)) m = torch.arange(1, _n_heads + 1, dtype=torch.float32, device=device) m = m.mul(alibi_bias_max / _n_heads) slopes = 1.0 / torch.pow(2, m) if _n_heads != n_heads: slopes = torch.concat([slopes[1::2], slopes[::2]])[:n_heads] return slopes.view(1, n_heads, 1, 1) def build_alibi_bias(n_heads, seq_len, full=False, alibi_bias_max=8, device=None, dtype=None): alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device).view(1, 1, 1, seq_len) if full: alibi_bias = alibi_bias - torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device).view(1, 1, seq_len, 1) alibi_bias = alibi_bias.abs().mul(-1) slopes = gen_slopes(n_heads, alibi_bias_max, device=device) alibi_bias = alibi_bias * slopes return alibi_bias.to(dtype=dtype) ATTN_CLASS_REGISTRY = {'multihead_attention': MultiheadAttention, 'multiquery_attention': MultiQueryAttention} ================================================ FILE: llava/model/language_model/mpt/blocks.py ================================================ """GPT Blocks used for the GPT Model.""" from typing import Dict, Optional, Tuple import torch import torch.nn as nn from .attention import ATTN_CLASS_REGISTRY from .norm import NORM_CLASS_REGISTRY class MPTMLP(nn.Module): def __init__(self, d_model: int, expansion_ratio: int, device: Optional[str]=None): super().__init__() self.up_proj = nn.Linear(d_model, expansion_ratio * d_model, device=device) self.act = nn.GELU(approximate='none') self.down_proj = nn.Linear(expansion_ratio * d_model, d_model, device=device) self.down_proj._is_residual = True def forward(self, x): return self.down_proj(self.act(self.up_proj(x))) class MPTBlock(nn.Module): def __init__(self, d_model: int, n_heads: int, expansion_ratio: int, attn_config: Dict={'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': 
False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}, resid_pdrop: float=0.0, norm_type: str='low_precision_layernorm', verbose: int=0, device: Optional[str]=None, **kwargs): del kwargs super().__init__() norm_class = NORM_CLASS_REGISTRY[norm_type.lower()] attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']] self.norm_1 = norm_class(d_model, device=device) self.attn = attn_class(attn_impl=attn_config['attn_impl'], clip_qkv=attn_config['clip_qkv'], qk_ln=attn_config['qk_ln'], softmax_scale=attn_config['softmax_scale'], attn_pdrop=attn_config['attn_pdrop'], d_model=d_model, n_heads=n_heads, verbose=verbose, device=device) self.norm_2 = norm_class(d_model, device=device) self.ffn = MPTMLP(d_model=d_model, expansion_ratio=expansion_ratio, device=device) self.resid_attn_dropout = nn.Dropout(resid_pdrop) self.resid_ffn_dropout = nn.Dropout(resid_pdrop) def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]: a = self.norm_1(x) (b, attn_weights, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal) x = x + self.resid_attn_dropout(b) m = self.norm_2(x) n = self.ffn(m) x = x + self.resid_ffn_dropout(n) return (x, attn_weights, past_key_value) ================================================ FILE: llava/model/language_model/mpt/configuration_mpt.py ================================================ """A HuggingFace-style model configuration.""" from typing import Dict, Optional, Union from transformers import PretrainedConfig attn_config_defaults: Dict = {'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 
'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8} init_config_defaults: Dict = {'name': 'kaiming_normal_', 'fan_mode': 'fan_in', 'init_nonlinearity': 'relu', 'init_div_is_residual': True, 'emb_init_std': None, 'emb_init_uniform_lim': None, 'init_std': None, 'init_gain': 0.0} class MPTConfig(PretrainedConfig): model_type = 'mpt' def __init__(self, d_model: int=2048, n_heads: int=16, n_layers: int=24, expansion_ratio: int=4, max_seq_len: int=2048, vocab_size: int=50368, resid_pdrop: float=0.0, emb_pdrop: float=0.0, learned_pos_emb: bool=True, attn_config: Dict=attn_config_defaults, init_device: str='cpu', logit_scale: Optional[Union[float, str]]=None, no_bias: bool=False, verbose: int=0, embedding_fraction: float=1.0, norm_type: str='low_precision_layernorm', use_cache: bool=False, init_config: Dict=init_config_defaults, **kwargs): """The MPT configuration class. Args: d_model (int): The size of the embedding dimension of the model. n_heads (int): The number of attention heads. n_layers (int): The number of layers in the model. expansion_ratio (int): The ratio of the up/down scale in the MLP. max_seq_len (int): The maximum sequence length of the model. vocab_size (int): The size of the vocabulary. resid_pdrop (float): The dropout probability applied to the attention output before combining with residual. emb_pdrop (float): The dropout probability for the embedding layer. learned_pos_emb (bool): Whether to use learned positional embeddings attn_config (Dict): A dictionary used to configure the model's attention module: attn_type (str): type of attention to use. Options: multihead_attention, multiquery_attention attn_pdrop (float): The dropout probability for the attention layers. attn_impl (str): The attention implementation to use. One of 'torch', 'flash', or 'triton'. qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer. 
clip_qkv (Optional[float]): If not None, clip the queries, keys, and values in the attention layer to this value. softmax_scale (Optional[float]): If not None, scale the softmax in the attention layer by this value. If None, use the default scale of ``1/sqrt(d_keys)``. prefix_lm (Optional[bool]): Whether the model should operate as a Prefix LM. This requires passing an extra `prefix_mask` argument which indicates which tokens belong to the prefix. Tokens in the prefix can attend to one another bi-directionally. Tokens outside the prefix use causal attention. attn_uses_sequence_id (Optional[bool]): Whether to restrict attention to tokens that have the same sequence_id. When the model is in `train` mode, this requires passing an extra `sequence_id` argument which indicates which sub-sequence each token belongs to. Defaults to ``False`` meaning any provided `sequence_id` will be ignored. alibi (bool): Whether to use the alibi bias instead of position embeddings. alibi_bias_max (int): The maximum value of the alibi bias. init_device (str): The device to use for parameter initialization. logit_scale (Optional[Union[float, str]]): If not None, scale the logits by this value. no_bias (bool): Whether to use bias in all layers. verbose (int): The verbosity level. 0 is silent. embedding_fraction (float): The fraction to scale the gradients of the embedding layer by. norm_type (str): choose type of norm to use multiquery_attention (bool): Whether to use multiquery attention implementation. use_cache (bool): Whether or not the model should return the last key/values attentions init_config (Dict): A dictionary used to configure the model initialization: init_config.name: The parameter initialization scheme to use. Options: 'default_', 'baseline_', 'kaiming_uniform_', 'kaiming_normal_', 'neox_init_', 'small_init_', 'xavier_uniform_', or 'xavier_normal_'. These mimic the parameter initialization methods in PyTorch. 
init_div_is_residual (Union[int, float, str, bool]): Value to divide initial weights by if ``module._is_residual`` is True. emb_init_std (Optional[float]): The standard deviation of the normal distribution used to initialize the embedding layer. emb_init_uniform_lim (Optional[Union[Tuple[float, float], float]]): The lower and upper limits of the uniform distribution used to initialize the embedding layer. Mutually exclusive with ``emb_init_std``. init_std (float): The standard deviation of the normal distribution used to initialize the model, if using the baseline_ parameter initialization scheme. init_gain (float): The gain to use for parameter initialization with kaiming or xavier initialization schemes. fan_mode (str): The fan mode to use for parameter initialization with kaiming initialization schemes. init_nonlinearity (str): The nonlinearity to use for parameter initialization with kaiming initialization schemes. --- See llmfoundry.models.utils.param_init_fns.py for info on other param init config options """ self.d_model = d_model self.n_heads = n_heads self.n_layers = n_layers self.expansion_ratio = expansion_ratio self.max_seq_len = max_seq_len self.vocab_size = vocab_size self.resid_pdrop = resid_pdrop self.emb_pdrop = emb_pdrop self.learned_pos_emb = learned_pos_emb self.attn_config = attn_config self.init_device = init_device self.logit_scale = logit_scale self.no_bias = no_bias self.verbose = verbose self.embedding_fraction = embedding_fraction self.norm_type = norm_type self.use_cache = use_cache self.init_config = init_config if 'name' in kwargs: del kwargs['name'] if 'loss_fn' in kwargs: del kwargs['loss_fn'] super().__init__(**kwargs) self._validate_config() def _set_config_defaults(self, config, config_defaults): for (k, v) in config_defaults.items(): if k not in config: config[k] = v return config def _validate_config(self): self.attn_config = self._set_config_defaults(self.attn_config, attn_config_defaults) self.init_config = 
self._set_config_defaults(self.init_config, init_config_defaults) if self.d_model % self.n_heads != 0: raise ValueError('d_model must be divisible by n_heads') if any((prob < 0 or prob > 1 for prob in [self.attn_config['attn_pdrop'], self.resid_pdrop, self.emb_pdrop])): raise ValueError("self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1") if self.attn_config['attn_impl'] not in ['torch', 'flash', 'triton']: raise ValueError(f"Unknown attn_impl={self.attn_config['attn_impl']}") if self.attn_config['prefix_lm'] and self.attn_config['attn_impl'] not in ['torch', 'triton']: raise NotImplementedError('prefix_lm only implemented with torch and triton attention.') if self.attn_config['alibi'] and self.attn_config['attn_impl'] not in ['torch', 'triton']: raise NotImplementedError('alibi only implemented with torch and triton attention.') if self.attn_config['attn_uses_sequence_id'] and self.attn_config['attn_impl'] not in ['torch', 'triton']: raise NotImplementedError('attn_uses_sequence_id only implemented with torch and triton attention.') if self.embedding_fraction > 1 or self.embedding_fraction <= 0: raise ValueError('model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!') if isinstance(self.logit_scale, str) and self.logit_scale != 'inv_sqrt_d_model': raise ValueError(f"self.logit_scale={self.logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.") if self.init_config.get('name', None) is None: raise ValueError(f"self.init_config={self.init_config!r} 'name' needs to be set.") if not self.learned_pos_emb and (not self.attn_config['alibi']): raise ValueError(f'Positional information must be provided to the model using either learned_pos_emb or alibi.') ================================================ FILE: llava/model/language_model/mpt/custom_embedding.py ================================================ import torch import torch.nn as nn import torch.nn.functional as 
F from torch import Tensor class SharedEmbedding(nn.Embedding): def forward(self, input: Tensor, unembed: bool=False) -> Tensor: if unembed: return F.linear(input, self.weight) return super().forward(input) ================================================ FILE: llava/model/language_model/mpt/flash_attn_triton.py ================================================ """ Copied from https://github.com/HazyResearch/flash-attention/blob/eff9fe6b8076df59d64d7a3f464696738a3c7c24/flash_attn/flash_attn_triton.py update imports to use 'triton_pre_mlir' *Experimental* implementation of FlashAttention in Triton. Tested with triton==2.0.0.dev20221202. Triton 2.0 has a new backend (MLIR) but seems like it doesn't yet work for head dimensions other than 64: https://github.com/openai/triton/blob/d376020f90002757eea3ea9475d4f7cfc2ec5ead/python/triton/ops/flash_attention.py#L207 We'll update this implementation with the new Triton backend once this is fixed. We use the FlashAttention implementation from Phil Tillet a starting point. https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py Changes: - Implement both causal and non-causal attention. - Implement both self-attention and cross-attention. - Support arbitrary seqlens (not just multiples of 128), for both forward and backward. - Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both forward and backward. - Support attention bias. - Speed up the forward pass a bit, and only store the LSE instead of m and l. - Make the backward for d=128 much faster by reducing register spilling. - Optionally parallelize the backward pass across seqlen_k, to deal with the case of small batch size * nheads. Caution: - This is an *experimental* implementation. The forward pass should be quite robust but I'm not 100% sure that the backward pass doesn't have race conditions (due to the Triton compiler). - This implementation has only been tested on A100. 
- If you plan to use headdim other than 64 and 128, you should test for race conditions (due to the Triton compiler), as done in tests/test_flash_attn.py "test_flash_attn_triton_race_condition". I've tested and fixed many race conditions for different head dimensions (40, 48, 64, 128, 80, 88, 96), but I'm still not 100% confident that there are none left for other head dimensions. Differences between this Triton version and the CUDA version: - Triton version doesn't support dropout. - Triton forward is generally faster than CUDA forward, while Triton backward is generally slower than CUDA backward. Overall Triton forward + backward is slightly slower than CUDA forward + backward. - Triton version doesn't support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor). - Triton version supports attention bias, while CUDA version doesn't. """ import math import torch import triton_pre_mlir as triton import triton_pre_mlir.language as tl @triton.heuristics({'EVEN_M': lambda args: args['seqlen_q'] % args['BLOCK_M'] == 0, 'EVEN_N': lambda args: args['seqlen_k'] % args['BLOCK_N'] == 0, 'EVEN_HEADDIM': lambda args: args['headdim'] == args['BLOCK_HEADDIM']}) @triton.jit def _fwd_kernel(Q, K, V, Bias, Out, Lse, TMP, softmax_scale, stride_qb, stride_qh, stride_qm, stride_kb, stride_kh, stride_kn, stride_vb, stride_vh, stride_vn, stride_bb, stride_bh, stride_bm, stride_ob, stride_oh, stride_om, nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim, CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K, BIAS_TYPE: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_HEADDIM: tl.constexpr, EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): start_m = tl.program_id(0) off_hb = tl.program_id(1) off_b = off_hb // nheads off_h = off_hb % nheads offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_n = tl.arange(0, BLOCK_N) offs_d = tl.arange(0, BLOCK_HEADDIM) q_ptrs = Q + off_b * stride_qb + off_h * stride_qh + 
(offs_m[:, None] * stride_qm + offs_d[None, :]) k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :]) v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :]) if BIAS_TYPE == 'vector': b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n elif BIAS_TYPE == 'matrix': b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + (offs_m[:, None] * stride_bm + offs_n[None, :]) t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf') m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf') acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32) if EVEN_M & EVEN_N: if EVEN_HEADDIM: q = tl.load(q_ptrs) else: q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0) elif EVEN_HEADDIM: q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0) else: q = tl.load(q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0) end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k) for start_n in range(0, end_n, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) if EVEN_N & EVEN_M: if EVEN_HEADDIM: k = tl.load(k_ptrs + start_n * stride_kn) else: k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0) elif EVEN_HEADDIM: k = tl.load(k_ptrs + start_n * stride_kn, mask=(start_n + offs_n)[:, None] < seqlen_k, other=0.0) else: k = tl.load(k_ptrs + start_n * stride_kn, mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk += tl.dot(q, k, trans_b=True) if not EVEN_N: qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float('-inf')) if IS_CAUSAL: qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float('-inf')) if BIAS_TYPE != 'none': if BIAS_TYPE == 'vector': if EVEN_N: bias = tl.load(b_ptrs + start_n).to(tl.float32) else: bias = 
tl.load(b_ptrs + start_n, mask=start_n + offs_n < seqlen_k, other=0.0).to(tl.float32) bias = bias[None, :] elif BIAS_TYPE == 'matrix': if EVEN_M & EVEN_N: bias = tl.load(b_ptrs + start_n).to(tl.float32) else: bias = tl.load(b_ptrs + start_n, mask=(offs_m[:, None] < seqlen_q) & ((start_n + offs_n)[None, :] < seqlen_k), other=0.0).to(tl.float32) qk = qk * softmax_scale + bias m_ij = tl.maximum(tl.max(qk, 1), lse_i) p = tl.exp(qk - m_ij[:, None]) else: m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i) p = tl.exp(qk * softmax_scale - m_ij[:, None]) l_ij = tl.sum(p, 1) acc_o_scale = tl.exp(m_i - m_ij) tl.store(t_ptrs, acc_o_scale) acc_o_scale = tl.load(t_ptrs) acc_o = acc_o * acc_o_scale[:, None] if EVEN_N & EVEN_M: if EVEN_HEADDIM: v = tl.load(v_ptrs + start_n * stride_vn) else: v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0) elif EVEN_HEADDIM: v = tl.load(v_ptrs + start_n * stride_vn, mask=(start_n + offs_n)[:, None] < seqlen_k, other=0.0) else: v = tl.load(v_ptrs + start_n * stride_vn, mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0) p = p.to(v.dtype) acc_o += tl.dot(p, v) m_i = m_ij l_i_new = tl.exp(lse_i - m_ij) + l_ij lse_i = m_ij + tl.log(l_i_new) o_scale = tl.exp(m_i - lse_i) tl.store(t_ptrs, o_scale) o_scale = tl.load(t_ptrs) acc_o = acc_o * o_scale[:, None] start_m = tl.program_id(0) offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m tl.store(lse_ptrs, lse_i) offs_d = tl.arange(0, BLOCK_HEADDIM) out_ptrs = Out + off_b * stride_ob + off_h * stride_oh + (offs_m[:, None] * stride_om + offs_d[None, :]) if EVEN_M: if EVEN_HEADDIM: tl.store(out_ptrs, acc_o) else: tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim) elif EVEN_HEADDIM: tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q) else: tl.store(out_ptrs, acc_o, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)) @triton.jit def 
_bwd_preprocess_do_o_dot(Out, DO, Delta, stride_ob, stride_oh, stride_om, stride_dob, stride_doh, stride_dom, nheads, seqlen_q, seqlen_q_rounded, headdim, BLOCK_M: tl.constexpr, BLOCK_HEADDIM: tl.constexpr): start_m = tl.program_id(0) off_hb = tl.program_id(1) off_b = off_hb // nheads off_h = off_hb % nheads offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_d = tl.arange(0, BLOCK_HEADDIM) o = tl.load(Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :], mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0).to(tl.float32) do = tl.load(DO + off_b * stride_dob + off_h * stride_doh + offs_m[:, None] * stride_dom + offs_d[None, :], mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0).to(tl.float32) delta = tl.sum(o * do, axis=1) tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta) @triton.jit def _bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim, EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr): if EVEN_N & EVEN_M: if EVEN_HEADDIM: tl.store(dv_ptrs, dv) tl.store(dk_ptrs, dk) else: tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim) tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim) elif EVEN_HEADDIM: tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k) tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k) else: tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)) tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)) @triton.jit def _bwd_kernel_one_col_block(start_n, Q, K, V, Bias, DO, DQ, DK, DV, LSE, D, softmax_scale, stride_qm, stride_kn, stride_vn, stride_bm, stride_dom, stride_dqm, stride_dkn, stride_dvn, seqlen_q, seqlen_k, headdim, ATOMIC_ADD: tl.constexpr, BIAS_TYPE: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_HEADDIM: tl.constexpr, EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: 
tl.constexpr): begin_m = 0 if not IS_CAUSAL else start_n * BLOCK_N // BLOCK_M * BLOCK_M offs_qm = begin_m + tl.arange(0, BLOCK_M) offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) offs_m = tl.arange(0, BLOCK_M) offs_d = tl.arange(0, BLOCK_HEADDIM) q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_d[None, :]) k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :]) v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :]) do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :]) dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :]) if BIAS_TYPE == 'vector': b_ptrs = Bias + offs_n elif BIAS_TYPE == 'matrix': b_ptrs = Bias + (offs_qm[:, None] * stride_bm + offs_n[None, :]) dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) if begin_m >= seqlen_q: dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :]) dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :]) _bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim, EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM) return if EVEN_N & EVEN_M: if EVEN_HEADDIM: k = tl.load(k_ptrs) v = tl.load(v_ptrs) else: k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0) v = tl.load(v_ptrs, mask=offs_d[None, :] < headdim, other=0.0) elif EVEN_HEADDIM: k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0) v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0) else: k = tl.load(k_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0) v = tl.load(v_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0) num_block_m = tl.cdiv(seqlen_q, BLOCK_M) for start_m in range(begin_m, num_block_m * BLOCK_M, BLOCK_M): start_m = tl.multiple_of(start_m, BLOCK_M) offs_m_curr = start_m + offs_m if EVEN_M & EVEN_HEADDIM: q = tl.load(q_ptrs) elif EVEN_HEADDIM: q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0) else: q = 
tl.load(q_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0) qk = tl.dot(q, k, trans_b=True) if not EVEN_N: qk = tl.where(offs_n[None, :] < seqlen_k, qk, float('-inf')) if IS_CAUSAL: qk = tl.where(offs_m_curr[:, None] >= offs_n[None, :], qk, float('-inf')) if BIAS_TYPE != 'none': tl.debug_barrier() if BIAS_TYPE == 'vector': if EVEN_N: bias = tl.load(b_ptrs).to(tl.float32) else: bias = tl.load(b_ptrs, mask=offs_n < seqlen_k, other=0.0).to(tl.float32) bias = bias[None, :] elif BIAS_TYPE == 'matrix': if EVEN_M & EVEN_N: bias = tl.load(b_ptrs).to(tl.float32) else: bias = tl.load(b_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k), other=0.0).to(tl.float32) qk = qk * softmax_scale + bias if not EVEN_M & EVEN_HEADDIM: tl.debug_barrier() lse_i = tl.load(LSE + offs_m_curr) if BIAS_TYPE == 'none': p = tl.exp(qk * softmax_scale - lse_i[:, None]) else: p = tl.exp(qk - lse_i[:, None]) if EVEN_M & EVEN_HEADDIM: do = tl.load(do_ptrs) else: do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0) dv += tl.dot(p.to(do.dtype), do, trans_a=True) if not EVEN_M & EVEN_HEADDIM: tl.debug_barrier() dp = tl.dot(do, v, trans_b=True) if not EVEN_HEADDIM: tl.debug_barrier() Di = tl.load(D + offs_m_curr) ds = (p * (dp - Di[:, None]) * softmax_scale).to(q.dtype) dk += tl.dot(ds, q, trans_a=True) if not EVEN_M & EVEN_HEADDIM: tl.debug_barrier() if not ATOMIC_ADD: if EVEN_M & EVEN_HEADDIM: dq = tl.load(dq_ptrs, eviction_policy='evict_last') dq += tl.dot(ds, k) tl.store(dq_ptrs, dq, eviction_policy='evict_last') elif EVEN_HEADDIM: dq = tl.load(dq_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0, eviction_policy='evict_last') dq += tl.dot(ds, k) tl.store(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q, eviction_policy='evict_last') else: dq = tl.load(dq_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0, eviction_policy='evict_last') dq += tl.dot(ds, 
k) tl.store(dq_ptrs, dq, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), eviction_policy='evict_last') else: dq = tl.dot(ds, k) if EVEN_M & EVEN_HEADDIM: tl.atomic_add(dq_ptrs, dq) elif EVEN_HEADDIM: tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q) else: tl.atomic_add(dq_ptrs, dq, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim)) dq_ptrs += BLOCK_M * stride_dqm q_ptrs += BLOCK_M * stride_qm do_ptrs += BLOCK_M * stride_dom if BIAS_TYPE == 'matrix': b_ptrs += BLOCK_M * stride_bm dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :]) dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :]) _bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim, EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM) def init_to_zero(name): return lambda nargs: nargs[name].zero_() @triton.autotune(configs=[triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'SEQUENCE_PARALLEL': False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')), triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'SEQUENCE_PARALLEL': True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ'))], key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL', 'BLOCK_HEADDIM']) @triton.heuristics({'EVEN_M': lambda args: args['seqlen_q'] % args['BLOCK_M'] == 0, 'EVEN_N': lambda args: args['seqlen_k'] % args['BLOCK_N'] == 0, 'EVEN_HEADDIM': lambda args: args['headdim'] == args['BLOCK_HEADDIM']}) @triton.jit def _bwd_kernel(Q, K, V, Bias, DO, DQ, DK, DV, LSE, D, softmax_scale, stride_qb, stride_qh, stride_qm, stride_kb, stride_kh, stride_kn, stride_vb, stride_vh, stride_vn, stride_bb, stride_bh, stride_bm, stride_dob, stride_doh, stride_dom, stride_dqb, stride_dqh, stride_dqm, stride_dkb, stride_dkh, stride_dkn, stride_dvb, stride_dvh, stride_dvn, nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim, CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K, BIAS_TYPE: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_HEADDIM: 
tl.constexpr, SEQUENCE_PARALLEL: tl.constexpr, EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): off_hb = tl.program_id(1) off_b = off_hb // nheads off_h = off_hb % nheads Q += off_b * stride_qb + off_h * stride_qh K += off_b * stride_kb + off_h * stride_kh V += off_b * stride_vb + off_h * stride_vh DO += off_b * stride_dob + off_h * stride_doh DQ += off_b * stride_dqb + off_h * stride_dqh DK += off_b * stride_dkb + off_h * stride_dkh DV += off_b * stride_dvb + off_h * stride_dvh if BIAS_TYPE != 'none': Bias += off_b * stride_bb + off_h * stride_bh D += off_hb * seqlen_q_rounded LSE += off_hb * seqlen_q_rounded if not SEQUENCE_PARALLEL: num_block_n = tl.cdiv(seqlen_k, BLOCK_N) for start_n in range(0, num_block_n): _bwd_kernel_one_col_block(start_n, Q, K, V, Bias, DO, DQ, DK, DV, LSE, D, softmax_scale, stride_qm, stride_kn, stride_vn, stride_bm, stride_dom, stride_dqm, stride_dkn, stride_dvn, seqlen_q, seqlen_k, headdim, ATOMIC_ADD=False, BIAS_TYPE=BIAS_TYPE, IS_CAUSAL=IS_CAUSAL, BLOCK_HEADDIM=BLOCK_HEADDIM, EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N) else: start_n = tl.program_id(0) _bwd_kernel_one_col_block(start_n, Q, K, V, Bias, DO, DQ, DK, DV, LSE, D, softmax_scale, stride_qm, stride_kn, stride_vn, stride_bm, stride_dom, stride_dqm, stride_dkn, stride_dvn, seqlen_q, seqlen_k, headdim, ATOMIC_ADD=True, BIAS_TYPE=BIAS_TYPE, IS_CAUSAL=IS_CAUSAL, BLOCK_HEADDIM=BLOCK_HEADDIM, EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N) def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None): (batch, seqlen_q, nheads, d) = q.shape (_, seqlen_k, _, _) = k.shape assert k.shape == (batch, seqlen_k, nheads, d) assert v.shape == (batch, seqlen_k, nheads, d) assert d <= 128, 'FlashAttention only support head dimensions up to 128' assert q.dtype == k.dtype == v.dtype, 'All tensors must have the same type' 
assert q.dtype in [torch.float16, torch.bfloat16], 'Only support fp16 and bf16' assert q.is_cuda and k.is_cuda and v.is_cuda softmax_scale = softmax_scale or 1.0 / math.sqrt(d) has_bias = bias is not None bias_type = 'none' if has_bias: assert bias.dtype in [q.dtype, torch.float] assert bias.is_cuda assert bias.dim() == 4 if bias.stride(-1) != 1: bias = bias.contiguous() if bias.shape[2:] == (1, seqlen_k): bias_type = 'vector' elif bias.shape[2:] == (seqlen_q, seqlen_k): bias_type = 'matrix' else: raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k) or (seqlen_q, seqlen_k)') bias = bias.expand(batch, nheads, seqlen_q, seqlen_k) bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0) seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128 lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32) tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32) o = torch.empty_like(q) BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16) BLOCK = 128 num_warps = 4 if d <= 64 else 8 grid = lambda META: (triton.cdiv(seqlen_q, META['BLOCK_M']), batch * nheads) _fwd_kernel[grid](q, k, v, bias, o, lse, tmp, softmax_scale, q.stride(0), q.stride(2), q.stride(1), k.stride(0), k.stride(2), k.stride(1), v.stride(0), v.stride(2), v.stride(1), *bias_strides, o.stride(0), o.stride(2), o.stride(1), nheads, seqlen_q, seqlen_k, seqlen_q_rounded, d, seqlen_q // 32, seqlen_k // 32, bias_type, causal, BLOCK_HEADDIM, BLOCK_M=BLOCK, BLOCK_N=BLOCK, num_warps=num_warps, num_stages=1) return (o, lse, softmax_scale) def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=False, softmax_scale=None): if do.stride(-1) != 1: do = do.contiguous() (batch, seqlen_q, nheads, d) = q.shape (_, seqlen_k, _, _) = k.shape assert d <= 128 seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128 assert lse.shape == (batch, nheads, seqlen_q_rounded) assert q.stride(-1) == k.stride(-1) == 
v.stride(-1) == o.stride(-1) == 1 assert dq.stride(-1) == dk.stride(-1) == dv.stride(-1) == 1 softmax_scale = softmax_scale or 1.0 / math.sqrt(d) dq_accum = torch.empty_like(q, dtype=torch.float32) delta = torch.empty_like(lse) BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16) grid = lambda META: (triton.cdiv(seqlen_q, META['BLOCK_M']), batch * nheads) _bwd_preprocess_do_o_dot[grid](o, do, delta, o.stride(0), o.stride(2), o.stride(1), do.stride(0), do.stride(2), do.stride(1), nheads, seqlen_q, seqlen_q_rounded, d, BLOCK_M=128, BLOCK_HEADDIM=BLOCK_HEADDIM) has_bias = bias is not None bias_type = 'none' if has_bias: assert bias.dtype in [q.dtype, torch.float] assert bias.is_cuda assert bias.dim() == 4 assert bias.stride(-1) == 1 if bias.shape[2:] == (1, seqlen_k): bias_type = 'vector' elif bias.shape[2:] == (seqlen_q, seqlen_k): bias_type = 'matrix' else: raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k) or (seqlen_q, seqlen_k)') bias = bias.expand(batch, nheads, seqlen_q, seqlen_k) bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0) grid = lambda META: (triton.cdiv(seqlen_k, META['BLOCK_N']) if META['SEQUENCE_PARALLEL'] else 1, batch * nheads) _bwd_kernel[grid](q, k, v, bias, do, dq_accum, dk, dv, lse, delta, softmax_scale, q.stride(0), q.stride(2), q.stride(1), k.stride(0), k.stride(2), k.stride(1), v.stride(0), v.stride(2), v.stride(1), *bias_strides, do.stride(0), do.stride(2), do.stride(1), dq_accum.stride(0), dq_accum.stride(2), dq_accum.stride(1), dk.stride(0), dk.stride(2), dk.stride(1), dv.stride(0), dv.stride(2), dv.stride(1), nheads, seqlen_q, seqlen_k, seqlen_q_rounded, d, seqlen_q // 32, seqlen_k // 32, bias_type, causal, BLOCK_HEADDIM) dq.copy_(dq_accum) class FlashAttnQKVPackedFunc(torch.autograd.Function): @staticmethod def forward(ctx, qkv, bias=None, causal=False, softmax_scale=None): """ qkv: (batch, seqlen, 3, nheads, headdim) bias: optional, shape broadcastible to (batch, nheads, seqlen, 
seqlen). For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen). ALiBi mask for non-causal would have shape (1, nheads, seqlen, seqlen) """ if qkv.stride(-1) != 1: qkv = qkv.contiguous() (o, lse, ctx.softmax_scale) = _flash_attn_forward(qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], bias=bias, causal=causal, softmax_scale=softmax_scale) ctx.save_for_backward(qkv, o, lse, bias) ctx.causal = causal return o @staticmethod def backward(ctx, do): (qkv, o, lse, bias) = ctx.saved_tensors assert not ctx.needs_input_grad[1], 'FlashAttention does not support bias gradient yet' with torch.inference_mode(): dqkv = torch.empty_like(qkv) _flash_attn_backward(do, qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], o, lse, dqkv[:, :, 0], dqkv[:, :, 1], dqkv[:, :, 2], bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale) return (dqkv, None, None, None) flash_attn_qkvpacked_func = FlashAttnQKVPackedFunc.apply class FlashAttnKVPackedFunc(torch.autograd.Function): @staticmethod def forward(ctx, q, kv, bias=None, causal=False, softmax_scale=None): """ q: (batch, seqlen_q, nheads, headdim) kv: (batch, seqlen_k, 2, nheads, headdim) bias: optional, shape broadcastible to (batch, nheads, seqlen_q, seqlen_k). For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k). 
ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k) """ (q, kv) = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, kv]] (o, lse, ctx.softmax_scale) = _flash_attn_forward(q, kv[:, :, 0], kv[:, :, 1], bias=bias, causal=causal, softmax_scale=softmax_scale) ctx.save_for_backward(q, kv, o, lse, bias) ctx.causal = causal return o @staticmethod def backward(ctx, do): (q, kv, o, lse, bias) = ctx.saved_tensors if len(ctx.needs_input_grad) >= 3: assert not ctx.needs_input_grad[2], 'FlashAttention does not support bias gradient yet' with torch.inference_mode(): dq = torch.empty_like(q) dkv = torch.empty_like(kv) _flash_attn_backward(do, q, kv[:, :, 0], kv[:, :, 1], o, lse, dq, dkv[:, :, 0], dkv[:, :, 1], bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale) return (dq, dkv, None, None, None) flash_attn_kvpacked_func = FlashAttnKVPackedFunc.apply class FlashAttnFunc(torch.autograd.Function): @staticmethod def forward(ctx, q, k, v, bias=None, causal=False, softmax_scale=None): """ q: (batch_size, seqlen_q, nheads, headdim) k, v: (batch_size, seqlen_k, nheads, headdim) bias: optional, shape broadcastible to (batch, nheads, seqlen_q, seqlen_k). For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k). 
ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k) """ (q, k, v) = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v]] (o, lse, ctx.softmax_scale) = _flash_attn_forward(q, k, v, bias=bias, causal=causal, softmax_scale=softmax_scale) ctx.save_for_backward(q, k, v, o, lse, bias) ctx.causal = causal return o @staticmethod def backward(ctx, do): (q, k, v, o, lse, bias) = ctx.saved_tensors assert not ctx.needs_input_grad[3], 'FlashAttention does not support bias gradient yet' with torch.inference_mode(): dq = torch.empty_like(q) dk = torch.empty_like(k) dv = torch.empty_like(v) _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale) return (dq, dk, dv, None, None, None) flash_attn_func = FlashAttnFunc.apply ================================================ FILE: llava/model/language_model/mpt/hf_prefixlm_converter.py ================================================ """Converts Huggingface Causal LM to Prefix LM. Conversion does lightweight surgery on a HuggingFace Causal LM to convert it to a Prefix LM. Prefix LMs accepts a `bidirectional_mask` input in `forward` and treat the input prompt as the prefix in `generate`. 
""" import math import warnings from types import MethodType from typing import Any, Dict, List, Optional, Tuple, Union import torch from transformers.models.bloom.modeling_bloom import BaseModelOutputWithPastAndCrossAttentions, BloomForCausalLM, BloomModel, CausalLMOutputWithCrossAttentions, CrossEntropyLoss from transformers.models.bloom.modeling_bloom import _expand_mask as _expand_mask_bloom from transformers.models.bloom.modeling_bloom import _make_causal_mask as _make_causal_mask_bloom from transformers.models.bloom.modeling_bloom import logging from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoForCausalLM from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM from transformers.models.gptj.modeling_gptj import GPTJForCausalLM from transformers.models.opt.modeling_opt import OPTForCausalLM from transformers.models.opt.modeling_opt import _expand_mask as _expand_mask_opt from transformers.models.opt.modeling_opt import _make_causal_mask as _make_causal_mask_opt logger = logging.get_logger(__name__) _SUPPORTED_GPT_MODELS = (GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM) CAUSAL_GPT_TYPES = Union[GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM] def _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_TYPES: """Converts a GPT-style Causal LM to a Prefix LM. Supported HuggingFace model classes: - `GPT2LMHeadModel` - `GPTNeoForCausalLM` - `GPTNeoXForCausalLM` - `GPTJForCausalLM` See `convert_hf_causal_lm_to_prefix_lm` for more details. """ if hasattr(model, '_prefix_lm_converted'): return model assert isinstance(model, _SUPPORTED_GPT_MODELS) assert model.config.add_cross_attention == False, 'Only supports GPT-style decoder-only models' def _get_attn_modules(model: CAUSAL_GPT_TYPES) -> List[torch.nn.Module]: """Helper that gets a list of the model's attention modules. 
Each module has a `bias` buffer used for causal masking. The Prefix LM conversion adds logic to dynamically manipulate these biases to support Prefix LM attention masking. """ attn_modules = [] if isinstance(model, GPTNeoXForCausalLM): blocks = model.gpt_neox.layers else: blocks = model.transformer.h for block in blocks: if isinstance(model, GPTNeoForCausalLM): if block.attn.attention_type != 'global': continue attn_module = block.attn.attention elif isinstance(model, GPTNeoXForCausalLM): attn_module = block.attention else: attn_module = block.attn attn_modules.append(attn_module) return attn_modules setattr(model, '_original_forward', getattr(model, 'forward')) setattr(model, '_original_generate', getattr(model, 'generate')) def forward(self: CAUSAL_GPT_TYPES, input_ids: Optional[torch.LongTensor]=None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]]=None, attention_mask: Optional[torch.FloatTensor]=None, bidirectional_mask: Optional[torch.Tensor]=None, token_type_ids: Optional[torch.LongTensor]=None, position_ids: Optional[torch.LongTensor]=None, head_mask: Optional[torch.FloatTensor]=None, inputs_embeds: Optional[torch.FloatTensor]=None, labels: Optional[torch.LongTensor]=None, use_cache: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, return_dict: Optional[bool]=None): """Wraps original forward to enable PrefixLM attention.""" def call_og_forward(): if isinstance(self, GPTNeoXForCausalLM): return self._original_forward(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, head_mask=head_mask, inputs_embeds=inputs_embeds, labels=labels, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict) else: return self._original_forward(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, head_mask=head_mask, 
inputs_embeds=inputs_embeds, labels=labels, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict) if bidirectional_mask is None: return call_og_forward() assert isinstance(bidirectional_mask, torch.Tensor) attn_modules = _get_attn_modules(model) (b, s) = bidirectional_mask.shape max_length = attn_modules[0].bias.shape[-1] if s > max_length: raise ValueError(f'bidirectional_mask sequence length (={s}) exceeds the ' + f'max length allowed by the model ({max_length}).') assert s <= max_length if s < max_length: pad = torch.zeros((int(b), int(max_length - s)), dtype=bidirectional_mask.dtype, device=bidirectional_mask.device) bidirectional_mask = torch.cat([bidirectional_mask, pad], dim=1) bidirectional = bidirectional_mask.unsqueeze(1).unsqueeze(1) for attn_module in attn_modules: attn_module.bias.data = torch.logical_or(attn_module.bias.data, bidirectional) output = call_og_forward() for attn_module in attn_modules: attn_module.bias.data = torch.tril(attn_module.bias.data[0, 0])[None, None] return output def generate(self: CAUSAL_GPT_TYPES, *args: tuple, **kwargs: Dict[str, Any]): """Wraps original generate to enable PrefixLM attention.""" attn_modules = _get_attn_modules(model) for attn_module in attn_modules: attn_module.bias.data[:] = 1 output = self._original_generate(*args, **kwargs) for attn_module in attn_modules: attn_module.bias.data = torch.tril(attn_module.bias.data[0, 0])[None, None] return output setattr(model, 'forward', MethodType(forward, model)) setattr(model, 'generate', MethodType(generate, model)) setattr(model, '_prefix_lm_converted', True) return model def _convert_bloom_causal_lm_to_prefix_lm(model: BloomForCausalLM) -> BloomForCausalLM: """Converts a BLOOM Causal LM to a Prefix LM. Supported HuggingFace model classes: - `BloomForCausalLM` See `convert_hf_causal_lm_to_prefix_lm` for more details. 
""" if hasattr(model, '_prefix_lm_converted'): return model assert isinstance(model, BloomForCausalLM) assert model.config.add_cross_attention == False, 'Only supports BLOOM decoder-only models' def _prepare_attn_mask(self: BloomModel, attention_mask: torch.Tensor, bidirectional_mask: Optional[torch.Tensor], input_shape: Tuple[int, int], past_key_values_length: int) -> torch.BoolTensor: combined_attention_mask = None device = attention_mask.device (_, src_length) = input_shape if src_length > 1: combined_attention_mask = _make_causal_mask_bloom(input_shape, device=device, past_key_values_length=past_key_values_length) if bidirectional_mask is not None: assert attention_mask.shape == bidirectional_mask.shape expanded_bidirectional_mask = _expand_mask_bloom(bidirectional_mask, tgt_length=src_length) combined_attention_mask = torch.logical_and(combined_attention_mask, expanded_bidirectional_mask) expanded_attn_mask = _expand_mask_bloom(attention_mask, tgt_length=src_length) combined_attention_mask = expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask return combined_attention_mask def _build_alibi_tensor(self: BloomModel, batch_size: int, query_length: int, key_length: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor: num_heads = self.config.n_head closest_power_of_2 = 2 ** math.floor(math.log2(num_heads)) base = torch.tensor(2 ** (-2 ** (-(math.log2(closest_power_of_2) - 3))), device=device, dtype=torch.float32) powers = torch.arange(1, 1 + closest_power_of_2, device=device, dtype=torch.int32) slopes = torch.pow(base, powers) if closest_power_of_2 != num_heads: extra_base = torch.tensor(2 ** (-2 ** (-(math.log2(2 * closest_power_of_2) - 3))), device=device, dtype=torch.float32) num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2) extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=device, dtype=torch.int32) slopes = torch.cat([slopes, torch.pow(extra_base, 
extra_powers)], dim=0) qa = torch.arange(query_length, device=device, dtype=torch.int32).view(-1, 1) ka = torch.arange(key_length, device=device, dtype=torch.int32).view(1, -1) diffs = qa - ka + key_length - query_length diffs = -diffs.abs() alibi = slopes.view(1, num_heads, 1, 1) * diffs.view(1, 1, query_length, key_length) alibi = alibi.expand(batch_size, -1, -1, -1).reshape(-1, query_length, key_length) return alibi.to(dtype) KeyValueT = Tuple[torch.Tensor, torch.Tensor] def forward(self: BloomModel, input_ids: Optional[torch.LongTensor]=None, past_key_values: Optional[Tuple[KeyValueT, ...]]=None, attention_mask: Optional[torch.Tensor]=None, bidirectional_mask: Optional[torch.Tensor]=None, head_mask: Optional[torch.LongTensor]=None, inputs_embeds: Optional[torch.LongTensor]=None, use_cache: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, return_dict: Optional[bool]=None, **deprecated_arguments) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: if deprecated_arguments.pop('position_ids', False) is not False: warnings.warn('`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. 
' + 'You can safely ignore passing `position_ids`.', FutureWarning) if len(deprecated_arguments) > 0: raise ValueError(f'Got unexpected arguments: {deprecated_arguments}') output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict if input_ids is not None and inputs_embeds is not None: raise ValueError('You cannot specify both input_ids and inputs_embeds at the same time') elif input_ids is not None: (batch_size, seq_length) = input_ids.shape elif inputs_embeds is not None: (batch_size, seq_length, _) = inputs_embeds.shape else: raise ValueError('You have to specify either input_ids or inputs_embeds') if past_key_values is None: past_key_values = tuple([None] * len(self.h)) head_mask = self.get_head_mask(head_mask, self.config.n_layer) if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) hidden_states = self.word_embeddings_layernorm(inputs_embeds) presents = () if use_cache else None all_self_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None seq_length_with_past = seq_length past_key_values_length = 0 if past_key_values[0] is not None: tmp = past_key_values[0][0] past_key_values_length = tmp.shape[2] seq_length_with_past = seq_length_with_past + past_key_values_length if attention_mask is None: attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device) else: attention_mask = attention_mask.to(hidden_states.device) alibi = self._build_alibi_tensor(batch_size=batch_size, query_length=seq_length, key_length=seq_length_with_past, dtype=hidden_states.dtype, device=hidden_states.device) causal_mask = self._prepare_attn_mask(attention_mask, 
bidirectional_mask, input_shape=(batch_size, seq_length), past_key_values_length=past_key_values_length) for (i, (block, layer_past)) in enumerate(zip(self.h, past_key_values)): if output_hidden_states: hst = (hidden_states,) all_hidden_states = all_hidden_states + hst if self.gradient_checkpointing and self.training: if use_cache: logger.warning('`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...') use_cache = False def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs, use_cache=use_cache, output_attentions=output_attentions) return custom_forward outputs = torch.utils.checkpoint.checkpoint(create_custom_forward(block), hidden_states, alibi, causal_mask, head_mask[i]) else: outputs = block(hidden_states, layer_past=layer_past, attention_mask=causal_mask, head_mask=head_mask[i], use_cache=use_cache, output_attentions=output_attentions, alibi=alibi) hidden_states = outputs[0] if use_cache is True: presents = presents + (outputs[1],) if output_attentions: oa = (outputs[2 if use_cache else 1],) all_self_attentions = all_self_attentions + oa hidden_states = self.ln_f(hidden_states) if output_hidden_states: hst = (hidden_states,) all_hidden_states = all_hidden_states + hst if not return_dict: return tuple((v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)) return BaseModelOutputWithPastAndCrossAttentions(last_hidden_state=hidden_states, past_key_values=presents, hidden_states=all_hidden_states, attentions=all_self_attentions) setattr(model.transformer, '_prepare_attn_mask', MethodType(_prepare_attn_mask, model.transformer)) setattr(model.transformer, '_build_alibi_tensor', MethodType(_build_alibi_tensor, model.transformer)) setattr(model.transformer, 'forward', MethodType(forward, model.transformer)) KeyValueT = Tuple[torch.Tensor, torch.Tensor] def forward(self: BloomForCausalLM, input_ids: Optional[torch.LongTensor]=None, past_key_values: 
Optional[Tuple[KeyValueT, ...]]=None, attention_mask: Optional[torch.Tensor]=None, bidirectional_mask: Optional[torch.Tensor]=None, head_mask: Optional[torch.Tensor]=None, inputs_embeds: Optional[torch.Tensor]=None, labels: Optional[torch.Tensor]=None, use_cache: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, return_dict: Optional[bool]=None, **deprecated_arguments) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]: """Replacement forward method for BloomCausalLM.""" if deprecated_arguments.pop('position_ids', False) is not False: warnings.warn('`position_ids` have no functionality in BLOOM and will be removed ' + 'in v5.0.0. You can safely ignore passing `position_ids`.', FutureWarning) if len(deprecated_arguments) > 0: raise ValueError(f'Got unexpected arguments: {deprecated_arguments}') return_dict = return_dict if return_dict is not None else self.config.use_return_dict transformer_outputs = self.transformer(input_ids, past_key_values=past_key_values, attention_mask=attention_mask, bidirectional_mask=bidirectional_mask, head_mask=head_mask, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict) hidden_states = transformer_outputs[0] lm_logits = self.lm_head(hidden_states) loss = None if labels is not None: shift_logits = lm_logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() (batch_size, seq_length, vocab_size) = shift_logits.shape loss_fct = CrossEntropyLoss() loss = loss_fct(shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)) if not return_dict: output = (lm_logits,) + transformer_outputs[1:] return (loss,) + output if loss is not None else output return CausalLMOutputWithCrossAttentions(loss=loss, logits=lm_logits, past_key_values=transformer_outputs.past_key_values, hidden_states=transformer_outputs.hidden_states, 
attentions=transformer_outputs.attentions) def prepare_inputs_for_generation(self: BloomForCausalLM, input_ids: torch.LongTensor, past: Optional[torch.Tensor]=None, attention_mask: Optional[torch.Tensor]=None, **kwargs) -> dict: if past: input_ids = input_ids[:, -1].unsqueeze(-1) bidirectional_mask = None if past[0][0].shape[0] == input_ids.shape[0]: past = self._convert_to_bloom_cache(past) else: bidirectional_mask = torch.ones_like(input_ids) return {'input_ids': input_ids, 'past_key_values': past, 'use_cache': True, 'attention_mask': attention_mask, 'bidirectional_mask': bidirectional_mask} setattr(model, 'forward', MethodType(forward, model)) setattr(model, 'prepare_inputs_for_generation', MethodType(prepare_inputs_for_generation, model)) setattr(model, '_prefix_lm_converted', True) return model def _convert_opt_causal_lm_to_prefix_lm(model: OPTForCausalLM) -> OPTForCausalLM: """Converts an OPT Causal LM to a Prefix LM. Supported HuggingFace model classes: - `OPTForCausalLM` See `convert_hf_causal_lm_to_prefix_lm` for more details. 
""" if hasattr(model, '_prefix_lm_converted'): return model assert isinstance(model, OPTForCausalLM) assert model.config.add_cross_attention == False, 'Only supports OPT decoder-only models' setattr(model, '_original_forward', getattr(model, 'forward')) setattr(model, '_original_generate', getattr(model, 'generate')) model.model.decoder.bidirectional_mask = None def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): combined_attention_mask = None if input_shape[-1] > 1: if self.bidirectional_mask == 'g': (bsz, src_length) = input_shape combined_attention_mask = torch.zeros((bsz, 1, src_length, src_length + past_key_values_length), dtype=inputs_embeds.dtype, device=inputs_embeds.device) else: combined_attention_mask = _make_causal_mask_opt(input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length).to(inputs_embeds.device) if self.bidirectional_mask is not None: assert attention_mask.shape == self.bidirectional_mask.shape expanded_bidirectional_mask = _expand_mask_opt(self.bidirectional_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(inputs_embeds.device) combined_attention_mask = torch.maximum(expanded_bidirectional_mask, combined_attention_mask) if attention_mask is not None: expanded_attn_mask = _expand_mask_opt(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(inputs_embeds.device) combined_attention_mask = expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask return combined_attention_mask setattr(model.model.decoder, '_prepare_decoder_attention_mask', MethodType(_prepare_decoder_attention_mask, model.model.decoder)) def forward(self: OPTForCausalLM, input_ids: Optional[torch.LongTensor]=None, attention_mask: Optional[torch.Tensor]=None, bidirectional_mask: Optional[torch.ByteTensor]=None, head_mask: Optional[torch.Tensor]=None, past_key_values: Optional[List[torch.FloatTensor]]=None, inputs_embeds: 
Optional[torch.FloatTensor]=None, labels: Optional[torch.LongTensor]=None, use_cache: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, return_dict: Optional[bool]=None): def call_og_forward(): return self._original_forward(input_ids=input_ids, attention_mask=attention_mask, head_mask=head_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, labels=labels, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict) if bidirectional_mask is None: return call_og_forward() self.model.decoder.bidirectional_mask = bidirectional_mask try: outputs = call_og_forward() except: self.model.decoder.bidirectional_mask = None raise self.model.decoder.bidirectional_mask = None return outputs def generate(self: OPTForCausalLM, *args: tuple, **kwargs: Dict[str, Any]): """Wraps original generate to enable PrefixLM-style attention.""" self.model.decoder.bidirectional_mask = 'g' try: output = self._original_generate(*args, **kwargs) except: self.model.decoder.bidirectional_mask = None raise self.model.decoder.bidirectional_mask = None return output setattr(model, 'forward', MethodType(forward, model)) setattr(model, 'generate', MethodType(generate, model)) setattr(model, '_prefix_lm_converted', True) return model _SUPPORTED_HF_MODELS = _SUPPORTED_GPT_MODELS + (BloomForCausalLM, OPTForCausalLM) CAUSAL_LM_TYPES = Union[GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM, BloomForCausalLM, OPTForCausalLM] def convert_hf_causal_lm_to_prefix_lm(model: CAUSAL_LM_TYPES) -> CAUSAL_LM_TYPES: """Converts a HuggingFace Causal LM to a Prefix LM. 
Supported HuggingFace model classes: - `GPT2LMHeadModel` - `GPTNeoForCausalLM` - `GPTNeoXForCausalLM` - `GPTJForCausalLM` - `BloomForCausalLM` - `OPTForCausalLM` Conversion to a Prefix LM is done by modifying the `forward` method, and possibly also the `generate` method and/or select underlying methods depending on the model class. These changes preserve the model API, but add a new input to `forward`: "bidirectional_mask". Notes on training: To actually train the converted model as a Prefix LM, training batches will need to indicate the prefix/target structure by including `bidirectional_mask` as part of the batch inputs. **This is not a standard input and requires custom layers either within or after your dataloader.** In addition to adding `bidirectional_mask` to the batch, this custom code should modify `labels` such that `batch['labels'][batch['bidirectional_mask'] == 1] == -100`. That is, the prefix portion of the sequence should not generate any loss. Loss should only be generated by the target portion of the sequence. Notes on `GPTNeoForCausalLM`: To simplify the implementation, "global" and "local" attention layers are handled differently. For "global" layers, we handle conversion as described above. For "local" layers, which use a causal attention mask within a restricted local window, we do not alter the masking. Notes on `forward` method conversion: After conversion, the `forward` method will handle a new input, `bidirectional_mask`, which should be a [batch_size, seq_length] byte tensor, where 1 indicates token positions belonging to the prefix (prefix tokens can attend to one another bidirectionally), and 0 indicates token positions belonging to the target. The new `forward` method will incorporate `bidirectional_mask` (if supplied) into the existing causal mask, call the original `forward` method, and (if the causal mask is a buffer) reset the causal masks before returning the result. 
Notes on `generate` method conversion: After conversion, the `generate` method will have the same signature but will internally convert all causal masks to be purely bidirectional, call the original `generate` method, and (where appropriate) reset the causal masks before returning the result. This works thanks to the logic of the HuggingFace `generate` API, which first encodes the token "prompt" passed to `generate` (which is treated as the prefix) and then sequentially generates each new token. Encodings are cached as generation happens, so all prefix tokens can attend to one another (as expected in a Prefix LM) and generated tokens can only attend to prefix tokens and previously-generated tokens (also as expected in a Prefix LM). To preserve the API, the original methods are renamed to `_original_forward` and `_original_generate`, and replaced with new `forward` and `generate` methods that wrap them, respectively. Although implementation details vary by model class. """ if isinstance(model, _SUPPORTED_GPT_MODELS): return _convert_gpt_causal_lm_to_prefix_lm(model) elif isinstance(model, BloomForCausalLM): return _convert_bloom_causal_lm_to_prefix_lm(model) elif isinstance(model, OPTForCausalLM): return _convert_opt_causal_lm_to_prefix_lm(model) else: raise TypeError(f'Cannot convert model to Prefix LM. ' + f'Model does not belong to set of supported HF models:' + f'\n{_SUPPORTED_HF_MODELS}') def add_bidirectional_mask_if_missing(batch: Dict[str, Any]): """Attempts to add bidirectional_mask to batch if missing. 
Raises: KeyError if bidirectional_mask is missing and can't be inferred """ if 'bidirectional_mask' not in batch: if batch.get('mode', None) == 'icl_task': batch['bidirectional_mask'] = batch['attention_mask'].clone() for (i, continuation_indices) in enumerate(batch['continuation_indices']): batch['bidirectional_mask'][i, continuation_indices] = 0 elif 'labels' in batch and 'attention_mask' in batch: batch['bidirectional_mask'] = torch.logical_and(torch.eq(batch['attention_mask'], 1), torch.eq(batch['labels'], -100)).type_as(batch['attention_mask']) else: raise KeyError('No bidirectional_mask in batch and not sure how to construct one.') ================================================ FILE: llava/model/language_model/mpt/meta_init_context.py ================================================ from contextlib import contextmanager import torch import torch.nn as nn @contextmanager def init_empty_weights(include_buffers: bool=False): """Meta initialization context manager. A context manager under which models are initialized with all parameters on the meta device, therefore creating an empty model. Useful when just initializing the model would blow the available RAM. Args: include_buffers (`bool`, *optional*, defaults to `False`): Whether or not to also put all buffers on the meta device while initializing. Example: ```python import torch.nn as nn # Initialize a model with 100 billions parameters in no time and without using any RAM. with init_empty_weights(): tst = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)]) ``` Any model created under this context manager has no weights. As such you can't do something like `model.to(some_device)` with it. To load weights inside your empty model, see [`load_checkpoint_and_dispatch`]. """ with init_on_device(torch.device('meta'), include_buffers=include_buffers) as f: yield f @contextmanager def init_on_device(device: torch.device, include_buffers: bool=False): """Device initialization context manager. 
A context manager under which models are initialized with all parameters on the specified device. Args: device (`torch.device`): Device to initialize all parameters on. include_buffers (`bool`, *optional*, defaults to `False`): Whether or not to also put all buffers on the meta device while initializing. Example: ```python import torch.nn as nn with init_on_device(device=torch.device("cuda")): tst = nn.Liner(100, 100) # on `cuda` device ``` """ old_register_parameter = nn.Module.register_parameter if include_buffers: old_register_buffer = nn.Module.register_buffer def register_empty_parameter(module, name, param): old_register_parameter(module, name, param) if param is not None: param_cls = type(module._parameters[name]) kwargs = module._parameters[name].__dict__ module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs) def register_empty_buffer(module, name, buffer): old_register_buffer(module, name, buffer) if buffer is not None: module._buffers[name] = module._buffers[name].to(device) if include_buffers: tensor_constructors_to_patch = {torch_function_name: getattr(torch, torch_function_name) for torch_function_name in ['empty', 'zeros', 'ones', 'full']} else: tensor_constructors_to_patch = {} def patch_tensor_constructor(fn): def wrapper(*args, **kwargs): kwargs['device'] = device return fn(*args, **kwargs) return wrapper try: nn.Module.register_parameter = register_empty_parameter if include_buffers: nn.Module.register_buffer = register_empty_buffer for torch_function_name in tensor_constructors_to_patch.keys(): setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name))) yield finally: nn.Module.register_parameter = old_register_parameter if include_buffers: nn.Module.register_buffer = old_register_buffer for (torch_function_name, old_torch_function) in tensor_constructors_to_patch.items(): setattr(torch, torch_function_name, old_torch_function) ================================================ 
FILE: llava/model/language_model/mpt/modeling_mpt.py ================================================ """A simple, flexible implementation of a GPT model. Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py """ import math import warnings from typing import List, Optional, Tuple, Union import torch import torch.nn as nn import torch.nn.functional as F from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from .attention import attn_bias_shape, build_attn_bias from .blocks import MPTBlock from .custom_embedding import SharedEmbedding from .norm import NORM_CLASS_REGISTRY from .configuration_mpt import MPTConfig from .adapt_tokenizer import AutoTokenizerForMOD, adapt_tokenizer_for_denoising from .hf_prefixlm_converter import add_bidirectional_mask_if_missing, convert_hf_causal_lm_to_prefix_lm from .meta_init_context import init_empty_weights from .param_init_fns import MODEL_INIT_REGISTRY, generic_param_init_fn_ try: from .flash_attn_triton import flash_attn_func except: pass Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] class MPTPreTrainedModel(PreTrainedModel): config_class = MPTConfig base_model_prefix = 'model' _no_split_modules = ['MPTBlock'] class MPTModel(MPTPreTrainedModel): def __init__(self, config: MPTConfig): config._validate_config() super().__init__(config) self.attn_impl = config.attn_config['attn_impl'] self.prefix_lm = config.attn_config['prefix_lm'] self.attn_uses_sequence_id = config.attn_config['attn_uses_sequence_id'] self.alibi = config.attn_config['alibi'] self.alibi_bias_max = config.attn_config['alibi_bias_max'] if config.init_device == 'mixed': if dist.get_local_rank() == 0: config.init_device = 'cpu' else: config.init_device = 'meta' if config.norm_type.lower() not in NORM_CLASS_REGISTRY.keys(): norm_options = ' | '.join(NORM_CLASS_REGISTRY.keys()) raise 
NotImplementedError(f'Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options}).') norm_class = NORM_CLASS_REGISTRY[config.norm_type.lower()] self.embedding_fraction = config.embedding_fraction self.wte = SharedEmbedding(config.vocab_size, config.d_model, device=config.init_device) if not self.alibi: self.wpe = torch.nn.Embedding(config.max_seq_len, config.d_model, device=config.init_device) self.emb_drop = nn.Dropout(config.emb_pdrop) self.blocks = nn.ModuleList([MPTBlock(device=config.init_device, **config.to_dict()) for _ in range(config.n_layers)]) self.norm_f = norm_class(config.d_model, device=config.init_device) if config.init_device != 'meta': print(f'You are using config.init_device={config.init_device!r}, but you can also use config.init_device="meta" with Composer + FSDP for fast initialization.') self.apply(self.param_init_fn) self.is_causal = not self.prefix_lm self._attn_bias_initialized = False self.attn_bias = None self.attn_bias_shape = attn_bias_shape(self.attn_impl, config.n_heads, config.max_seq_len, self.alibi, prefix_lm=self.prefix_lm, causal=self.is_causal, use_sequence_id=self.attn_uses_sequence_id) if config.no_bias: for module in self.modules(): if hasattr(module, 'bias') and isinstance(module.bias, nn.Parameter): if config.verbose: warnings.warn(f'Removing bias ({module.bias}) from {module}.') module.register_parameter('bias', None) if config.verbose and config.verbose > 2: print(self) if 'verbose' not in self.config.init_config: self.config.init_config['verbose'] = self.config.verbose if self.config.init_config['verbose'] > 1: init_fn_name = self.config.init_config['name'] warnings.warn(f'Using {init_fn_name} initialization.') self.gradient_checkpointing = False def get_input_embeddings(self): return self.wte def set_input_embeddings(self, value): self.wte = value @torch.no_grad() def _attn_bias(self, device, dtype, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: 
Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None): if not self._attn_bias_initialized: if self.attn_bias_shape: self.attn_bias = torch.zeros(self.attn_bias_shape, device=device, dtype=dtype) self.attn_bias = build_attn_bias(self.attn_impl, self.attn_bias, self.config.n_heads, self.config.max_seq_len, causal=self.is_causal, alibi=self.alibi, alibi_bias_max=self.alibi_bias_max) self._attn_bias_initialized = True if self.attn_impl == 'flash': return (self.attn_bias, attention_mask) if self.attn_bias is not None: self.attn_bias = self.attn_bias.to(dtype=dtype, device=device) attn_bias = self.attn_bias if self.prefix_lm: assert isinstance(attn_bias, torch.Tensor) assert isinstance(prefix_mask, torch.Tensor) attn_bias = self._apply_prefix_mask(attn_bias, prefix_mask) if self.attn_uses_sequence_id and sequence_id is not None: assert isinstance(attn_bias, torch.Tensor) attn_bias = self._apply_sequence_id(attn_bias, sequence_id) if attention_mask is not None: s_k = attention_mask.shape[-1] if attn_bias is None: attn_bias = torch.zeros((1, 1, 1, s_k), device=device, dtype=dtype) else: _s_k = max(0, attn_bias.size(-1) - s_k) attn_bias = attn_bias[:, :, :, _s_k:] if prefix_mask is not None and attention_mask.shape != prefix_mask.shape: raise ValueError(f'attention_mask shape={attention_mask.shape} ' + f'and prefix_mask shape={prefix_mask.shape} are not equal.') min_val = torch.finfo(attn_bias.dtype).min attn_bias = attn_bias.masked_fill(~attention_mask.view(-1, 1, 1, s_k), min_val) return (attn_bias, None) def _apply_prefix_mask(self, attn_bias: torch.Tensor, prefix_mask: torch.Tensor): (s_k, s_q) = attn_bias.shape[-2:] if s_k != self.config.max_seq_len or s_q != self.config.max_seq_len: raise ValueError('attn_bias does not match the expected shape. 
' + f'The last two dimensions should both be {self.config.max_length} ' + f'but are {s_k} and {s_q}.') seq_len = prefix_mask.shape[-1] if seq_len > self.config.max_seq_len: raise ValueError(f'prefix_mask sequence length cannot exceed max_seq_len={self.config.max_seq_len}') attn_bias = attn_bias[..., :seq_len, :seq_len] causal = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool, device=prefix_mask.device)).view(1, 1, seq_len, seq_len) prefix = prefix_mask.view(-1, 1, 1, seq_len) cannot_attend = ~torch.logical_or(causal, prefix.bool()) min_val = torch.finfo(attn_bias.dtype).min attn_bias = attn_bias.masked_fill(cannot_attend, min_val) return attn_bias def _apply_sequence_id(self, attn_bias: torch.Tensor, sequence_id: torch.LongTensor): seq_len = sequence_id.shape[-1] if seq_len > self.config.max_seq_len: raise ValueError(f'sequence_id sequence length cannot exceed max_seq_len={self.config.max_seq_len}') attn_bias = attn_bias[..., :seq_len, :seq_len] cannot_attend = torch.logical_not(torch.eq(sequence_id.view(-1, seq_len, 1), sequence_id.view(-1, 1, seq_len))).unsqueeze(1) min_val = torch.finfo(attn_bias.dtype).min attn_bias = attn_bias.masked_fill(cannot_attend, min_val) return attn_bias def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, inputs_embeds: Optional[torch.Tensor]=None): return_dict = return_dict if return_dict is not None else self.config.return_dict use_cache = use_cache if use_cache is not None else self.config.use_cache if attention_mask is not None: attention_mask = attention_mask.bool() if prefix_mask is not None: prefix_mask = prefix_mask.bool() if not return_dict: raise 
NotImplementedError('return_dict False is not implemented yet for MPT') if output_attentions: if self.attn_impl != 'torch': raise NotImplementedError('output_attentions is not implemented for MPT when using attn_impl `flash` or `triton`.') if attention_mask is not None and attention_mask[:, 0].sum() != attention_mask.shape[0] and self.training: raise NotImplementedError('MPT does not support training with left padding.') if self.prefix_lm and prefix_mask is None: raise ValueError('prefix_mask is a required argument when MPT is configured with prefix_lm=True.') if self.training: if self.attn_uses_sequence_id and sequence_id is None: raise ValueError('sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True ' + 'and the model is in train mode.') elif self.attn_uses_sequence_id is False and sequence_id is not None: warnings.warn('MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. ' + 'This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True.') if input_ids is not None: S = input_ids.size(1) assert S <= self.config.max_seq_len, f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}' tok_emb = self.wte(input_ids) else: assert inputs_embeds is not None assert self.alibi, 'inputs_embeds is not implemented for MPT unless for alibi.' 
S = inputs_embeds.size(1) tok_emb = inputs_embeds if self.alibi: x = tok_emb else: past_position = 0 if past_key_values is not None: if len(past_key_values) != self.config.n_layers: raise ValueError(f'past_key_values must provide a past_key_value for each attention ' + f'layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r}).') past_position = past_key_values[0][0].size(1) if self.attn_impl == 'torch': past_position = past_key_values[0][0].size(3) if S + past_position > self.config.max_seq_len: raise ValueError(f'Cannot forward input with past sequence length {past_position} and current sequence length {S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.') pos = torch.arange(past_position, S + past_position, dtype=torch.long, device=input_ids.device).unsqueeze(0) if attention_mask is not None: pos = torch.clamp(pos - torch.cumsum((~attention_mask).to(torch.int32), dim=1)[:, past_position:], min=0) pos_emb = self.wpe(pos) x = tok_emb + pos_emb if self.embedding_fraction == 1: x = self.emb_drop(x) else: x_shrunk = x * self.embedding_fraction + x.detach() * (1 - self.embedding_fraction) assert isinstance(self.emb_drop, nn.Module) x = self.emb_drop(x_shrunk) (attn_bias, attention_mask) = self._attn_bias(device=x.device, dtype=torch.float32, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id) if use_cache and past_key_values is None: past_key_values = [() for _ in range(self.config.n_layers)] all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None for (b_idx, block) in enumerate(self.blocks): if output_hidden_states: assert all_hidden_states is not None all_hidden_states = all_hidden_states + (x,) past_key_value = past_key_values[b_idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: (x, attn_weights, past_key_value) = 
torch.utils.checkpoint.checkpoint(block, x, past_key_value, attn_bias, attention_mask, self.is_causal) else: (x, attn_weights, past_key_value) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal) if past_key_values is not None: past_key_values[b_idx] = past_key_value if output_attentions: assert all_self_attns is not None all_self_attns = all_self_attns + (attn_weights,) x = self.norm_f(x) if output_hidden_states: assert all_hidden_states is not None all_hidden_states = all_hidden_states + (x,) return BaseModelOutputWithPast(last_hidden_state=x, past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns) def param_init_fn(self, module): init_fn_name = self.config.init_config['name'] MODEL_INIT_REGISTRY[init_fn_name](module=module, n_layers=self.config.n_layers, d_model=self.config.d_model, **self.config.init_config) def fsdp_wrap_fn(self, module): return isinstance(module, MPTBlock) def activation_checkpointing_fn(self, module): return isinstance(module, MPTBlock) class MPTForCausalLM(MPTPreTrainedModel): def __init__(self, config: MPTConfig): super().__init__(config) if not config.tie_word_embeddings: raise ValueError('MPTForCausalLM only supports tied word embeddings') print(f'Instantiating an MPTForCausalLM model from {__file__}') self.transformer = MPTModel(config) for child in self.transformer.children(): if isinstance(child, torch.nn.ModuleList): continue if isinstance(child, torch.nn.Module): child._fsdp_wrap = True self.logit_scale = None if config.logit_scale is not None: logit_scale = config.logit_scale if isinstance(logit_scale, str): if logit_scale == 'inv_sqrt_d_model': logit_scale = 1 / math.sqrt(config.d_model) else: raise ValueError(f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.") self.logit_scale = logit_scale def get_input_embeddings(self): return self.transformer.wte def 
set_input_embeddings(self, value): self.transformer.wte = value def get_output_embeddings(self): return self.transformer.wte def set_output_embeddings(self, new_embeddings): self.transformer.wte = new_embeddings def set_decoder(self, decoder): self.transformer = decoder def get_decoder(self): return self.transformer def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, inputs_embeds: Optional[torch.FloatTensor]=None): return_dict = return_dict if return_dict is not None else self.config.return_dict use_cache = use_cache if use_cache is not None else self.config.use_cache if inputs_embeds is not None: raise NotImplementedError('inputs_embeds has to be None (for hf/peft support).') outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache) logits = self.transformer.wte(outputs.last_hidden_state.to(self.transformer.wte.weight.device), True) if self.logit_scale is not None: if self.logit_scale == 0: warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. 
This will produce uniform (uninformative) outputs.') logits *= self.logit_scale loss = None if labels is not None: labels = torch.roll(labels, shifts=-1) labels[:, -1] = -100 loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1)) return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions) def param_init_fn(self, module): init_fn_name = self.config.init_config['name'] MODEL_INIT_REGISTRY[init_fn_name](module=module, n_layers=self.config.n_layers, d_model=self.config.d_model, **self.config.init_config) def fsdp_wrap_fn(self, module): return isinstance(module, MPTBlock) def activation_checkpointing_fn(self, module): return isinstance(module, MPTBlock) def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): if inputs_embeds is not None: raise NotImplementedError('inputs_embeds is not implemented for MPT yet') attention_mask = kwargs['attention_mask'].bool() if attention_mask[:, -1].sum() != attention_mask.shape[0]: raise NotImplementedError('MPT does not support generation with right padding.') if self.transformer.attn_uses_sequence_id and self.training: sequence_id = torch.zeros_like(input_ids[:1]) else: sequence_id = None if past_key_values is not None: input_ids = input_ids[:, -1].unsqueeze(-1) if self.transformer.prefix_lm: prefix_mask = torch.ones_like(attention_mask) if kwargs.get('use_cache') == False: raise NotImplementedError('MPT with prefix_lm=True does not support use_cache=False.') else: prefix_mask = None return {'input_ids': input_ids, 'attention_mask': attention_mask, 'prefix_mask': prefix_mask, 'sequence_id': sequence_id, 'past_key_values': past_key_values, 'use_cache': kwargs.get('use_cache', True)} @staticmethod def _reorder_cache(past_key_values, beam_idx): """Used by HuggingFace generate when using beam search with kv-caching. 
See https://github.com/huggingface/transformers/blob/3ec7a47664ebe40c40f4b722f6bb1cd30c3821ec/src/transformers/models/gpt2/modeling_gpt2.py#L1122-L1133 for an example in transformers. """ reordered_past = [] for layer_past in past_key_values: reordered_past += [tuple((past_state.index_select(0, beam_idx) for past_state in layer_past))] return reordered_past ================================================ FILE: llava/model/language_model/mpt/norm.py ================================================ import torch def _cast_if_autocast_enabled(tensor): if torch.is_autocast_enabled(): if tensor.device.type == 'cuda': dtype = torch.get_autocast_gpu_dtype() elif tensor.device.type == 'cpu': dtype = torch.get_autocast_cpu_dtype() else: raise NotImplementedError() return tensor.to(dtype=dtype) return tensor class LPLayerNorm(torch.nn.LayerNorm): def __init__(self, normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None): super().__init__(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine, device=device, dtype=dtype) def forward(self, x): module_device = x.device downcast_x = _cast_if_autocast_enabled(x) downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight downcast_bias = _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias with torch.autocast(enabled=False, device_type=module_device.type): return torch.nn.functional.layer_norm(downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps) def rms_norm(x, weight=None, eps=1e-05): output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) if weight is not None: return output * weight return output class RMSNorm(torch.nn.Module): def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None): super().__init__() self.eps = eps if weight: self.weight = torch.nn.Parameter(torch.ones(normalized_shape, dtype=dtype, device=device)) else: 
self.register_parameter('weight', None) def forward(self, x): return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype) class LPRMSNorm(RMSNorm): def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None): super().__init__(normalized_shape=normalized_shape, eps=eps, weight=weight, dtype=dtype, device=device) def forward(self, x): downcast_x = _cast_if_autocast_enabled(x) downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight with torch.autocast(enabled=False, device_type=x.device.type): return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype) NORM_CLASS_REGISTRY = {'layernorm': torch.nn.LayerNorm, 'low_precision_layernorm': LPLayerNorm, 'rmsnorm': RMSNorm, 'low_precision_rmsnorm': LPRMSNorm} ================================================ FILE: llava/model/language_model/mpt/param_init_fns.py ================================================ import math import warnings from collections.abc import Sequence from functools import partial from typing import Optional, Tuple, Union import torch from torch import nn from .norm import NORM_CLASS_REGISTRY def torch_default_param_init_fn_(module: nn.Module, verbose: int=0, **kwargs): del kwargs if verbose > 1: warnings.warn(f"Initializing network using module's reset_parameters attribute") if hasattr(module, 'reset_parameters'): module.reset_parameters() def fused_init_helper_(module: nn.Module, init_fn_): _fused = getattr(module, '_fused', None) if _fused is None: raise RuntimeError(f'Internal logic error') (dim, splits) = _fused splits = (0, *splits, module.weight.size(dim)) for (s, e) in zip(splits[:-1], splits[1:]): slice_indices = [slice(None)] * module.weight.ndim slice_indices[dim] = slice(s, e) init_fn_(module.weight[slice_indices]) def generic_param_init_fn_(module: nn.Module, init_fn_, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: 
Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs): del kwargs if verbose > 1: warnings.warn(f'If model has bias parameters they are initialized to 0.') init_div_is_residual = init_div_is_residual if init_div_is_residual is False: div_is_residual = 1.0 elif init_div_is_residual is True: div_is_residual = math.sqrt(2 * n_layers) elif isinstance(init_div_is_residual, float) or isinstance(init_div_is_residual, int): div_is_residual = init_div_is_residual elif isinstance(init_div_is_residual, str) and init_div_is_residual.isnumeric(): div_is_residual = float(init_div_is_residual) else: div_is_residual = 1.0 raise ValueError(f'Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}') if init_div_is_residual is not False: if verbose > 1: warnings.warn(f'Initializing _is_residual layers then dividing them by {div_is_residual:.3f}. ' + f'Set `init_div_is_residual: false` in init config to disable this.') if isinstance(module, nn.Linear): if hasattr(module, '_fused'): fused_init_helper_(module, init_fn_) else: init_fn_(module.weight) if module.bias is not None: torch.nn.init.zeros_(module.bias) if init_div_is_residual is not False and getattr(module, '_is_residual', False): with torch.no_grad(): module.weight.div_(div_is_residual) elif isinstance(module, nn.Embedding): if emb_init_std is not None: std = emb_init_std if std == 0: warnings.warn(f'Embedding layer initialized to 0.') emb_init_fn_ = partial(torch.nn.init.normal_, mean=0.0, std=std) if verbose > 1: warnings.warn(f'Embedding layer initialized using normal distribution with mean=0 and std={std!r}.') elif emb_init_uniform_lim is not None: lim = emb_init_uniform_lim if isinstance(lim, Sequence): if len(lim) > 2: raise ValueError(f'Uniform init requires a min and a max limit. 
User input: {lim}.') if lim[0] == lim[1]: warnings.warn(f'Embedding layer initialized to {lim[0]}.') else: if lim == 0: warnings.warn(f'Embedding layer initialized to 0.') lim = [-lim, lim] (a, b) = lim emb_init_fn_ = partial(torch.nn.init.uniform_, a=a, b=b) if verbose > 1: warnings.warn(f'Embedding layer initialized using uniform distribution in range {lim}.') else: emb_init_fn_ = init_fn_ emb_init_fn_(module.weight) elif isinstance(module, tuple(set(NORM_CLASS_REGISTRY.values()))): if verbose > 1: warnings.warn(f'Norm weights are set to 1. If norm layer has a bias it is initialized to 0.') if hasattr(module, 'weight') and module.weight is not None: torch.nn.init.ones_(module.weight) if hasattr(module, 'bias') and module.bias is not None: torch.nn.init.zeros_(module.bias) elif isinstance(module, nn.MultiheadAttention): if module._qkv_same_embed_dim: assert module.in_proj_weight is not None assert module.q_proj_weight is None and module.k_proj_weight is None and (module.v_proj_weight is None) assert d_model is not None _d = d_model splits = (0, _d, 2 * _d, 3 * _d) for (s, e) in zip(splits[:-1], splits[1:]): init_fn_(module.in_proj_weight[s:e]) else: assert module.q_proj_weight is not None and module.k_proj_weight is not None and (module.v_proj_weight is not None) assert module.in_proj_weight is None init_fn_(module.q_proj_weight) init_fn_(module.k_proj_weight) init_fn_(module.v_proj_weight) if module.in_proj_bias is not None: torch.nn.init.zeros_(module.in_proj_bias) if module.bias_k is not None: torch.nn.init.zeros_(module.bias_k) if module.bias_v is not None: torch.nn.init.zeros_(module.bias_v) init_fn_(module.out_proj.weight) if init_div_is_residual is not False and getattr(module.out_proj, '_is_residual', False): with torch.no_grad(): module.out_proj.weight.div_(div_is_residual) if module.out_proj.bias is not None: torch.nn.init.zeros_(module.out_proj.bias) else: for _ in module.parameters(recurse=False): raise NotImplementedError(f'{module.__class__.__name__} 
parameters are not initialized by param_init_fn.') def _normal_init_(std, mean=0.0): return partial(torch.nn.init.normal_, mean=mean, std=std) def _normal_param_init_fn_(module: nn.Module, std: float, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs): del kwargs init_fn_ = _normal_init_(std=std) if verbose > 1: warnings.warn(f'Using torch.nn.init.normal_ init fn mean=0.0, std={std}') generic_param_init_fn_(module=module, init_fn_=init_fn_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose) def baseline_param_init_fn_(module: nn.Module, init_std: float, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs): del kwargs if init_std is None: raise ValueError("You must set model.init_config['init_std'] to a float value to use the default initialization scheme.") _normal_param_init_fn_(module=module, std=init_std, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose) def small_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs): del kwargs std = math.sqrt(2 / (5 * d_model)) _normal_param_init_fn_(module=module, std=std, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose) def 
neox_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs): """From section 2.3.1 of GPT-NeoX-20B: An Open-Source AutoregressiveLanguage Model — Black et. al. (2022) see https://github.com/EleutherAI/gpt-neox/blob/9610391ab319403cef079b438edd016a2443af54/megatron/model/init_functions.py#L151 and https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/transformer.py """ del kwargs residual_div = n_layers / math.sqrt(10) if verbose > 1: warnings.warn(f'setting init_div_is_residual to {residual_div}') small_param_init_fn_(module=module, d_model=d_model, n_layers=n_layers, init_div_is_residual=residual_div, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose) def kaiming_uniform_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, fan_mode: str='fan_in', init_nonlinearity: str='leaky_relu', verbose: int=0, **kwargs): del kwargs if verbose > 1: warnings.warn(f'Using nn.init.kaiming_uniform_ init fn with parameters: ' + f'a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}') kaiming_uniform_ = partial(nn.init.kaiming_uniform_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity) generic_param_init_fn_(module=module, init_fn_=kaiming_uniform_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose) def kaiming_normal_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, 
init_gain: float=0, fan_mode: str='fan_in', init_nonlinearity: str='leaky_relu', verbose: int=0, **kwargs): del kwargs if verbose > 1: warnings.warn(f'Using nn.init.kaiming_normal_ init fn with parameters: ' + f'a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}') kaiming_normal_ = partial(torch.nn.init.kaiming_normal_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity) generic_param_init_fn_(module=module, init_fn_=kaiming_normal_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose) def xavier_uniform_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, verbose: int=0, **kwargs): del kwargs xavier_uniform_ = partial(torch.nn.init.xavier_uniform_, gain=init_gain) if verbose > 1: warnings.warn(f'Using torch.nn.init.xavier_uniform_ init fn with parameters: ' + f'gain={init_gain}') generic_param_init_fn_(module=module, init_fn_=xavier_uniform_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose) def xavier_normal_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, verbose: int=0, **kwargs): xavier_normal_ = partial(torch.nn.init.xavier_normal_, gain=init_gain) if verbose > 1: warnings.warn(f'Using torch.nn.init.xavier_normal_ init fn with parameters: ' + f'gain={init_gain}') generic_param_init_fn_(module=module, init_fn_=xavier_normal_, d_model=d_model, n_layers=n_layers, 
init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose) MODEL_INIT_REGISTRY = {'default_': torch_default_param_init_fn_, 'baseline_': baseline_param_init_fn_, 'kaiming_uniform_': kaiming_uniform_param_init_fn_, 'kaiming_normal_': kaiming_normal_param_init_fn_, 'neox_init_': neox_param_init_fn_, 'small_init_': small_param_init_fn_, 'xavier_uniform_': xavier_uniform_param_init_fn_, 'xavier_normal_': xavier_normal_param_init_fn_} ================================================ FILE: llava/model/llava_arch.py ================================================ # Copyright 2023 Haotian Liu # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from abc import ABC, abstractmethod
import torch
import torch.nn as nn
from .multimodal_encoder.builder import build_vision_tower
from .openseed import build_model
from .openseed.BaseModel import BaseModel
# Grounding special tokens; they are registered with the tokenizer in
# initialize_vision_tokenizer().
# NOTE(review): in this copy all three literals are empty strings -- the token
# text appears to have been lost in extraction; confirm against the upstream repo.
grounding_start=""
grounding_end=""
SEG_TOKEN=""
from llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN


class LlavaMetaModel:
    """Mixin adding a vision tower and a linear ``mm_projector`` to a model.

    ``__init__`` forwards ``config`` to the next class in the MRO (presumably
    a transformers base model -- confirm at the concrete subclass).
    """

    def __init__(self, config):
        super(LlavaMetaModel, self).__init__(config)

        # Only build the multimodal parts when the config names a vision tower;
        # delay_load=True is passed through to build_vision_tower.
        if hasattr(config, "mm_vision_tower"):
            self.vision_tower = build_vision_tower(config, delay_load=True)
            self.mm_projector = nn.Linear(config.mm_hidden_size, config.hidden_size)

    def get_vision_tower(self):
        """Return the vision tower, or None if none has been attached.

        Unwraps the single-element list used when the tower is stored for FSDP
        (see initialize_vision_modules).
        """
        vision_tower = getattr(self, 'vision_tower', None)
        if type(vision_tower) is list:
            vision_tower = vision_tower[0]
        return vision_tower

    def initialize_vision_modules(self, model_args, fsdp=None):
        """Build the vision tower and projector from ``model_args``.

        Records the vision settings on ``self.config``, creates
        ``mm_projector`` if missing, and, when
        ``model_args.pretrain_mm_mlp_adapter`` is set, loads projector weights
        from that checkpoint (keys containing ``'mm_projector'``, with the
        ``'mm_projector.'`` prefix stripped).

        Args:
            model_args: Provides ``vision_tower``, ``mm_vision_select_layer``,
                ``mm_vision_select_feature`` and ``pretrain_mm_mlp_adapter``.
            fsdp: When non-empty, the tower is stored wrapped in a list.
        """
        vision_tower = model_args.vision_tower
        mm_vision_select_layer = model_args.mm_vision_select_layer
        mm_vision_select_feature = model_args.mm_vision_select_feature
        pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter

        self.config.mm_vision_tower = vision_tower

        vision_tower = build_vision_tower(model_args)

        # Under FSDP the tower is hidden inside a list -- presumably so it is
        # not registered as a submodule; get_vision_tower() unwraps it.
        if fsdp is not None and len(fsdp) > 0:
            self.vision_tower = [vision_tower]
        else:
            self.vision_tower = vision_tower

        self.config.use_mm_proj = True
        self.config.mm_hidden_size = vision_tower.hidden_size
        self.config.mm_vision_select_layer = mm_vision_select_layer
        self.config.mm_vision_select_feature = mm_vision_select_feature

        if not hasattr(self, 'mm_projector'):
            self.mm_projector = nn.Linear(self.config.mm_hidden_size, self.config.hidden_size)

        if pretrain_mm_mlp_adapter is not None:
            mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')

            def get_w(weights, keyword):
                # Keep entries whose key contains `keyword`, stripping the
                # "<keyword>." prefix so they match mm_projector's state_dict.
                return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
            # self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))
            self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))


class LlavaMetaForCausalLM(ABC):
    """Mixin that splices vision features into a causal LM's input embeddings.

    Subclasses implement get_model(), returning an object with
    ``get_vision_tower()``, ``mm_projector`` and ``embed_tokens``.
    """

    @abstractmethod
    def get_model(self):
        pass

    def get_vision_tower(self):
        return self.get_model().get_vision_tower()

    def encode_images(self, images):
        """Run ``images`` through the vision tower, then ``mm_projector``."""
        image_features = self.get_model().get_vision_tower()(images)
        image_features = self.get_model().mm_projector(image_features)
        return image_features

    def prepare_inputs_labels_for_multimodal(
        self, input_ids, attention_mask, past_key_values, labels, images
    ):
        """Replace image placeholders in ``input_ids`` with image features.

        Each ``IMAGE_TOKEN_INDEX`` placeholder is replaced by the projected
        features of the next image. Label positions covered by image features
        are set to ``IGNORE_INDEX``. Ragged sequences are right-padded (zeros
        for embeddings, ``IGNORE_INDEX`` for labels) and stacked.

        Returns:
            ``(input_ids, attention_mask, past_key_values, inputs_embeds,
            labels)``. On the early-exit path (no vision tower, no images, or
            ``input_ids.shape[1] == 1``), ``input_ids`` is returned unchanged
            and ``inputs_embeds`` is None. Otherwise ``input_ids`` is None and
            ``inputs_embeds`` holds the stacked embeddings.
        """
        vision_tower = self.get_vision_tower()
        if vision_tower is None or images is None or input_ids.shape[1] == 1:
            # One-token step with a cache: rebuild the mask as all ones over
            # the cached length (presumably dim -2 of the last cache tensor)
            # plus the new token.
            if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[1] == 1:
                attention_mask = torch.ones((attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1), dtype=attention_mask.dtype, device=attention_mask.device)
            return input_ids, attention_mask, past_key_values, None, labels

        if type(images) is list or images.ndim == 5:
            # Several images per sample: encode them all in one call, split
            # back per sample, then merge each sample's images into a single
            # feature sequence.
            concat_images = torch.cat([image for image in images], dim=0)
            image_features = self.encode_images(concat_images)
            split_sizes = [image.shape[0] for image in images]
            image_features = torch.split(image_features, split_sizes, dim=0)
            image_features = [x.flatten(0, 1) for x in image_features]
        else:
            image_features = self.encode_images(images)

        new_input_embeds = []
        new_labels = [] if labels is not None else None
        cur_image_idx = 0
        orig_embeds_params = getattr(self, 'orig_embeds_params', None)
        if orig_embeds_params is not None:
            # Overwrite all embedding rows except the last three with the
            # snapshot taken in initialize_vision_tokenizer, so only the last
            # three rows keep their trained values.
            orig_embeds_params_in = orig_embeds_params[0]
            orig_embeds_params_out = orig_embeds_params[1]
            # st_inp=self.tokenizer.encode(grounding_start)[1]
            # st_out=self.tokenizer.encode(grounding_start)[1]
            with torch.no_grad():
                self.get_input_embeddings().weight[:-3] = orig_embeds_params_in[:-3].data
                # if self.tokenizer.decode([len(self.tokenizer)-1])=='':
                self.get_output_embeddings().weight[:-3] = orig_embeds_params_out[:-3].data
        for batch_idx, cur_input_ids in enumerate(input_ids):
            if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0:
                # multimodal LLM, but the current sample is not multimodal.
                # The zero-weighted projector term does not change the values;
                # presumably it keeps mm_projector in the autograd graph.
                cur_input_embeds = self.get_model().embed_tokens(cur_input_ids)
                cur_input_embeds = cur_input_embeds + (0. * self.get_model().mm_projector(vision_tower.dummy_feature)).sum()
                new_input_embeds.append(cur_input_embeds)
                if labels is not None:
                    new_labels.append(labels[batch_idx])
                # A text-only sample still consumes one image slot.
                cur_image_idx += 1
                continue
            image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
            cur_new_input_embeds = []
            if labels is not None:
                cur_labels = labels[batch_idx]
                cur_new_labels = []
                assert cur_labels.shape == cur_input_ids.shape
            # Consume placeholders left to right; after each one, cur_input_ids
            # and cur_labels are trimmed to the text that follows it.
            while image_token_indices.numel() > 0:
                cur_image_features = image_features[cur_image_idx]
                image_token_start = image_token_indices[0]
                if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
                    # Assumes the placeholder sits between a start token and an
                    # end token. Text before the start token is detached; the
                    # start/end token embeddings keep their gradient.
                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start-1]).detach())
                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_start-1:image_token_start]))
                    cur_new_input_embeds.append(cur_image_features)
                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_start+1:image_token_start+2]))
                    if labels is not None:
                        cur_new_labels.append(cur_labels[:image_token_start])
                        cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
                        cur_new_labels.append(cur_labels[image_token_start:image_token_start+1])
                        cur_labels = cur_labels[image_token_start+2:]
                else:
                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start]))
                    cur_new_input_embeds.append(cur_image_features)
                    if labels is not None:
                        cur_new_labels.append(cur_labels[:image_token_start])
                        cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
                        cur_labels = cur_labels[image_token_start+1:]
                cur_image_idx += 1
                # Skip the placeholder (and the end token in start/end mode).
                if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
                    cur_input_ids = cur_input_ids[image_token_start+2:]
                else:
                    cur_input_ids = cur_input_ids[image_token_start+1:]
                image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
            # Embed whatever text follows the last image.
            if cur_input_ids.numel() > 0:
                if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids).detach())
                else:
                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids))
                if labels is not None:
                    cur_new_labels.append(cur_labels)
            cur_new_input_embeds = [x.to(device=self.device) for x in cur_new_input_embeds]
            cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0)
            new_input_embeds.append(cur_new_input_embeds)
            if labels is not None:
                cur_new_labels = torch.cat(cur_new_labels, dim=0)
                new_labels.append(cur_new_labels)

        if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds):
            # Ragged batch: right-pad every sequence to the longest one.
            max_len = max(x.shape[0] for x in new_input_embeds)

            new_input_embeds_align = []
            for cur_new_embed in new_input_embeds:
                cur_new_embed = torch.cat((cur_new_embed, torch.zeros((max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0)
                new_input_embeds_align.append(cur_new_embed)
            new_input_embeds = torch.stack(new_input_embeds_align, dim=0)

            if labels is not None:
                new_labels_align = []
                # Keep the unpadded labels; their lengths drive the mask below.
                _new_labels = new_labels
                for cur_new_label in new_labels:
                    cur_new_label = torch.cat((cur_new_label, torch.full((max_len - cur_new_label.shape[0],), IGNORE_INDEX, dtype=cur_new_label.dtype, device=cur_new_label.device)), dim=0)
                    new_labels_align.append(cur_new_label)
                new_labels = torch.stack(new_labels_align, dim=0)

            if attention_mask is not None:
                # Extend each mask on the left with True, by the number of
                # positions this sample grew. Then pad it on the right with
                # False up to max_len.
                # NOTE(review): `_new_labels` and `labels.shape` exist only
                # when labels is not None. With labels=None and ragged
                # lengths, this loop raises NameError.
                new_attention_mask = []
                for cur_attention_mask, cur_new_labels, cur_new_labels_align in zip(attention_mask, _new_labels, new_labels):
                    new_attn_mask_pad_left = torch.full((cur_new_labels.shape[0] - labels.shape[1],), True, dtype=attention_mask.dtype, device=attention_mask.device)
                    new_attn_mask_pad_right = torch.full((cur_new_labels_align.shape[0] - cur_new_labels.shape[0],), False, dtype=attention_mask.dtype, device=attention_mask.device)
                    cur_new_attention_mask = torch.cat((new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right), dim=0)
                    new_attention_mask.append(cur_new_attention_mask)
                attention_mask = torch.stack(new_attention_mask, dim=0)
                assert attention_mask.shape == new_labels.shape
        else:
            # All sequences have the same length: stack directly and extend the
            # mask on the left by the number of positions the images added.
            new_input_embeds = torch.stack(new_input_embeds, dim=0)
            if labels is not None:
                new_labels = torch.stack(new_labels, dim=0)

            if attention_mask is not None:
                new_attn_mask_pad_left = torch.full((attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1]), True, dtype=attention_mask.dtype, device=attention_mask.device)
                attention_mask = torch.cat((new_attn_mask_pad_left, attention_mask), dim=1)
                assert attention_mask.shape == new_input_embeds.shape[:2]

        return None, attention_mask, past_key_values, new_input_embeds, new_labels

    def initialize_vision_tokenizer(self, model_args, tokenizer):
        """Register the special tokens and set up their embeddings.

        - ``mm_use_im_patch_token``: adds the image patch token.
        - ``mm_use_im_start_end``: adds the image start/end tokens plus the
          three grounding tokens, and initialises the new rows to the mean of
          the existing rows.
        - Neither flag: adds only the three grounding tokens.

        With ``tune_mm_mlp_adapter``, the start/end and grounding branches
        snapshot both embedding matrices on CUDA into ``orig_embeds_params``
        and mark the embeddings trainable. In the patch-token-only branch,
        that flag instead freezes the input/output embeddings.
        """
        if model_args.mm_use_im_patch_token:
            tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))

        if model_args.mm_use_im_start_end:
            num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, grounding_start, grounding_end, SEG_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))

            if num_new_tokens > 0:
                input_embeddings = self.get_input_embeddings().weight.data
                output_embeddings = self.get_output_embeddings().weight.data

                # Initialise the new rows with the mean of the pre-existing rows.
                input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
                    dim=0, keepdim=True)
                output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
                    dim=0, keepdim=True)

                input_embeddings[-num_new_tokens:] = input_embeddings_avg
                output_embeddings[-num_new_tokens:] = output_embeddings_avg

            if model_args.tune_mm_mlp_adapter:
                # Snapshot read back in prepare_inputs_labels_for_multimodal.
                self.orig_embeds_params = [self.get_input_embeddings().weight.data.clone().cuda(), self.get_output_embeddings().weight.data.clone().cuda()]
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = True
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = True

            if model_args.pretrain_mm_mlp_adapter:
                mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
                embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
                # NOTE(review): up to five tokens are added above, so this
                # assert fails whenever more than two of them are new. Confirm
                # the intended count.
                assert num_new_tokens == 2
                if input_embeddings.shape == embed_tokens_weight.shape:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
                elif embed_tokens_weight.shape[0] == num_new_tokens:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight
                else:
                    raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
        elif model_args.mm_use_im_patch_token:
            if model_args.tune_mm_mlp_adapter:
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = False
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = False
        else:
            # import pdb; pdb.set_trace()
            num_new_tokens = tokenizer.add_tokens([grounding_start, grounding_end, SEG_TOKEN], special_tokens=True)
            # NOTE(review): `inits`/`nums` are computed but only referenced by
            # the commented-out code below.
            inits=['[',']','.']
            nums=[tokenizer.encode(init)[1] for init in inits]
            # inp_embs = self.get_input_embeddings().weight.data[nums]
            # out_embs = self.get_output_embeddings().weight.data[nums]
            self.resize_token_embeddings(len(tokenizer))

            if num_new_tokens > 0:
                # print("Emb length:", len(self.get_input_embeddings().weight.data))
                # if len(self.get_input_embeddings().weight.data) > 0:
                # if len(self.get_input_embeddings().weight.data) > 0:
                #     self.get_input_embeddings().weight.data[-num_new_tokens:] = inp_embs
                #     self.get_output_embeddings().weight.data[-num_new_tokens:] = out_embs
                input_embeddings = self.get_input_embeddings().weight.data
                output_embeddings = self.get_output_embeddings().weight.data
                # Initialise the new rows with the mean of the pre-existing rows.
                input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
                    dim=0, keepdim=True)
                output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
                    dim=0, keepdim=True)
                # Write the means into the new rows.
                input_embeddings[-num_new_tokens:] = input_embeddings_avg
                output_embeddings[-num_new_tokens:] = output_embeddings_avg

            if model_args.tune_mm_mlp_adapter:
                # Snapshot read back in prepare_inputs_labels_for_multimodal.
                self.orig_embeds_params = [self.get_input_embeddings().weight.data.clone().cuda(), self.get_output_embeddings().weight.data.clone().cuda()]
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = True
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = True


class LlavaMetaForCausalLM_gd(ABC):
    """Grounding variant of the multimodal causal-LM mixin.

    Differs from LlavaMetaForCausalLM in these ways:
    - encode_images casts features to the projector's weight dtype.
    - Text embeddings are never detached.
    - A segmentation model can be attached via initialize_seg_modules().
    """

    @abstractmethod
    def get_model(self):
        pass

    def get_vision_tower(self):
        return self.get_model().get_vision_tower()

    def encode_images(self, images):
        """Encode ``images`` and project them, matching ``mm_projector``'s dtype."""
        image_features = self.get_model().get_vision_tower()(images)
        image_features = self.get_model().mm_projector(image_features.to(self.get_model().mm_projector.state_dict()["weight"].dtype))
        return image_features

    def prepare_inputs_labels_for_multimodal(
        self, input_ids, attention_mask, past_key_values, labels, images
    ):
        """Replace image placeholders in ``input_ids`` with image features.

        Same algorithm and return tuple as
        LlavaMetaForCausalLM.prepare_inputs_labels_for_multimodal. The
        difference is that no text embedding is detached, even when
        ``tune_mm_mlp_adapter`` and ``mm_use_im_start_end`` are both set.
        """
        vision_tower = self.get_vision_tower()
        if vision_tower is None or images is None or input_ids.shape[1] == 1:
            # One-token step with a cache: rebuild the mask as all ones over
            # the cached length plus the new token.
            if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[1] == 1:
                attention_mask = torch.ones((attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1), dtype=attention_mask.dtype, device=attention_mask.device)
            return input_ids, attention_mask, past_key_values, None, labels

        if type(images) is list or images.ndim == 5:
            # Several images per sample: encode together, split per sample,
            # and merge each sample's images into one feature sequence.
            concat_images = torch.cat([image for image in images], dim=0)
            image_features = self.encode_images(concat_images)
            split_sizes = [image.shape[0] for image in images]
            image_features = torch.split(image_features, split_sizes, dim=0)
            image_features = [x.flatten(0, 1) for x in image_features]
        else:
            image_features = self.encode_images(images)

        new_input_embeds = []
        new_labels = [] if labels is not None else None
        cur_image_idx = 0
        orig_embeds_params = getattr(self, 'orig_embeds_params', None)
        if orig_embeds_params is not None:
            # Restore all embedding rows except the last three from the
            # snapshot taken in initialize_vision_tokenizer.
            orig_embeds_params_in = orig_embeds_params[0]
            orig_embeds_params_out = orig_embeds_params[1]
            # st_inp=self.tokenizer.encode(grounding_start)[1]
            # st_out=self.tokenizer.encode(grounding_start)[1]
            with torch.no_grad():
                self.get_input_embeddings().weight[:-3] = orig_embeds_params_in[:-3].data
                # if self.tokenizer.decode([len(self.tokenizer)-1])=='':
                self.get_output_embeddings().weight[:-3] = orig_embeds_params_out[:-3].data
        for batch_idx, cur_input_ids in enumerate(input_ids):
            if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0:
                # multimodal LLM, but the current sample is not multimodal.
                # The zero-weighted projector term leaves the values unchanged.
                cur_input_embeds = self.get_model().embed_tokens(cur_input_ids)
                cur_input_embeds = cur_input_embeds + (0. * self.get_model().mm_projector(vision_tower.dummy_feature)).sum()
                new_input_embeds.append(cur_input_embeds)
                if labels is not None:
                    new_labels.append(labels[batch_idx])
                cur_image_idx += 1
                continue
            image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
            cur_new_input_embeds = []
            if labels is not None:
                cur_labels = labels[batch_idx]
                cur_new_labels = []
                assert cur_labels.shape == cur_input_ids.shape
            while image_token_indices.numel() > 0:
                cur_image_features = image_features[cur_image_idx]
                image_token_start = image_token_indices[0]
                if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start-1]))
                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_start-1:image_token_start]))
                    cur_new_input_embeds.append(cur_image_features)
                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_start+1:image_token_start+2]))
                    if labels is not None:
                        cur_new_labels.append(cur_labels[:image_token_start])
                        cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
                        cur_new_labels.append(cur_labels[image_token_start:image_token_start+1])
                        cur_labels = cur_labels[image_token_start+2:]
                else:
                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start]))
                    cur_new_input_embeds.append(cur_image_features)
                    if labels is not None:
                        cur_new_labels.append(cur_labels[:image_token_start])
                        cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
                        cur_labels = cur_labels[image_token_start+1:]
                cur_image_idx += 1
                if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
                    cur_input_ids = cur_input_ids[image_token_start+2:]
                else:
                    cur_input_ids = cur_input_ids[image_token_start+1:]
                image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
            if cur_input_ids.numel() > 0:
                # Both branches are identical in this variant (no detach).
                if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids))
                else:
                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids))
                if labels is not None:
                    cur_new_labels.append(cur_labels)
            cur_new_input_embeds = [x.to(device=self.device) for x in cur_new_input_embeds]
            cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0)
            new_input_embeds.append(cur_new_input_embeds)
            if labels is not None:
                cur_new_labels = torch.cat(cur_new_labels, dim=0)
                new_labels.append(cur_new_labels)

        if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds):
            # Ragged batch: right-pad every sequence to the longest one.
            max_len = max(x.shape[0] for x in new_input_embeds)

            new_input_embeds_align = []
            for cur_new_embed in new_input_embeds:
                cur_new_embed = torch.cat((cur_new_embed, torch.zeros((max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0)
                new_input_embeds_align.append(cur_new_embed)
            new_input_embeds = torch.stack(new_input_embeds_align, dim=0)

            if labels is not None:
                new_labels_align = []
                _new_labels = new_labels
                for cur_new_label in new_labels:
                    cur_new_label = torch.cat((cur_new_label, torch.full((max_len - cur_new_label.shape[0],), IGNORE_INDEX, dtype=cur_new_label.dtype, device=cur_new_label.device)), dim=0)
                    new_labels_align.append(cur_new_label)
                new_labels = torch.stack(new_labels_align, dim=0)

            if attention_mask is not None:
                # NOTE(review): relies on `_new_labels`/`labels`, so this
                # raises NameError when labels is None and lengths differ.
                new_attention_mask = []
                for cur_attention_mask, cur_new_labels, cur_new_labels_align in zip(attention_mask, _new_labels, new_labels):
                    new_attn_mask_pad_left = torch.full((cur_new_labels.shape[0] - labels.shape[1],), True, dtype=attention_mask.dtype, device=attention_mask.device)
                    new_attn_mask_pad_right = torch.full((cur_new_labels_align.shape[0] - cur_new_labels.shape[0],), False, dtype=attention_mask.dtype, device=attention_mask.device)
                    cur_new_attention_mask = torch.cat((new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right), dim=0)
                    new_attention_mask.append(cur_new_attention_mask)
                attention_mask = torch.stack(new_attention_mask, dim=0)
                assert attention_mask.shape == new_labels.shape
        else:
            new_input_embeds = torch.stack(new_input_embeds, dim=0)
            if labels is not None:
                new_labels = torch.stack(new_labels, dim=0)

            if attention_mask is not None:
                new_attn_mask_pad_left = torch.full((attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1]), True, dtype=attention_mask.dtype, device=attention_mask.device)
                attention_mask = torch.cat((new_attn_mask_pad_left, attention_mask), dim=1)
                assert attention_mask.shape == new_input_embeds.shape[:2]

        return None, attention_mask, past_key_values, new_input_embeds, new_labels

    def initialize_vision_tokenizer(self, model_args, tokenizer):
        """Register special tokens and set up their embeddings.

        Same as LlavaMetaForCausalLM.initialize_vision_tokenizer.
        """
        if model_args.mm_use_im_patch_token:
            tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))

        if model_args.mm_use_im_start_end:
            num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, grounding_start, grounding_end, SEG_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))

            if num_new_tokens > 0:
                input_embeddings = self.get_input_embeddings().weight.data
                output_embeddings = self.get_output_embeddings().weight.data

                # Initialise the new rows with the mean of the pre-existing rows.
                input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
                    dim=0, keepdim=True)
                output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
                    dim=0, keepdim=True)

                input_embeddings[-num_new_tokens:] = input_embeddings_avg
                output_embeddings[-num_new_tokens:] = output_embeddings_avg

            if model_args.tune_mm_mlp_adapter:
                self.orig_embeds_params = [self.get_input_embeddings().weight.data.clone().cuda(), self.get_output_embeddings().weight.data.clone().cuda()]
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = True
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = True

            if model_args.pretrain_mm_mlp_adapter:
                mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
                embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
                # NOTE(review): up to five tokens are added above; this assert
                # fails whenever more than two of them are new.
                assert num_new_tokens == 2
                if input_embeddings.shape == embed_tokens_weight.shape:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
                elif embed_tokens_weight.shape[0] == num_new_tokens:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight
                else:
                    raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
        elif model_args.mm_use_im_patch_token:
            if model_args.tune_mm_mlp_adapter:
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = False
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = False
        else:
            # import pdb; pdb.set_trace()
            num_new_tokens = tokenizer.add_tokens([grounding_start, grounding_end, SEG_TOKEN], special_tokens=True)
            # NOTE(review): `inits`/`nums` are computed but never used.
            inits=['[',']','.']
            nums=[tokenizer.encode(init)[1] for init in inits]
            self.resize_token_embeddings(len(tokenizer))

            if num_new_tokens > 0:
                input_embeddings = self.get_input_embeddings().weight.data
                output_embeddings = self.get_output_embeddings().weight.data
                # Initialise the new rows with the mean of the pre-existing rows.
                input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
                    dim=0, keepdim=True)
                output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
                    dim=0, keepdim=True)
                # Write the means into the new rows.
                input_embeddings[-num_new_tokens:] = input_embeddings_avg
                output_embeddings[-num_new_tokens:] = output_embeddings_avg

            if model_args.tune_mm_mlp_adapter:
                self.orig_embeds_params = [self.get_input_embeddings().weight.data.clone().cuda(), self.get_output_embeddings().weight.data.clone().cuda()]
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = True
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = True

    def initialize_seg_modules(self, cfg):
        """Attach the segmentation model to ``self.seg_model``.

        Builds ``BaseModel(cfg, build_model(cfg))`` and loads weights from
        ``cfg.MODEL.WEIGHTS`` via ``from_pretrained``.
        """
        seg_model = BaseModel(cfg, build_model(cfg))
        seg_model = seg_model.from_pretrained(cfg.MODEL.WEIGHTS)
        self.seg_model = seg_model

    def freeze_seg_modules(self):
        """Disable gradients for every parameter of ``self.seg_model``."""
        for p in self.seg_model.parameters():
            p.requires_grad = False


class LlavaMetaForCausalLM_gd_interactive(ABC):

    @abstractmethod
    def get_model(self):
        pass

    def get_vision_tower(self):
        return self.get_model().get_vision_tower()

    def encode_images(self, images):
        image_features = self.get_model().get_vision_tower()(images)
        image_features = self.get_model().mm_projector(image_features.to(self.get_model().mm_projector.state_dict()["weight"].dtype))
        return image_features

    def prepare_inputs_labels_for_multimodal(
        self,
input_ids, attention_mask, past_key_values, labels, images,obj_feats=None,num_it=0 ): vision_tower = self.get_vision_tower() if vision_tower is None or images is None or input_ids.shape[1] == 1: if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[1] == 1: attention_mask = torch.ones((attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1), dtype=attention_mask.dtype, device=attention_mask.device) return input_ids, attention_mask, past_key_values, None, labels if type(images) is list or images.ndim == 5: concat_images = torch.cat([image for image in images], dim=0) image_features = self.encode_images(concat_images) split_sizes = [image.shape[0] for image in images] image_features = torch.split(image_features, split_sizes, dim=0) image_features = [x.flatten(0, 1) for x in image_features] else: image_features = self.encode_images(images) new_input_embeds = [] new_labels = [] if labels is not None else None cur_image_idx = 0 orig_embeds_params = getattr(self, 'orig_embeds_params', None) if orig_embeds_params is not None: orig_embeds_params_in = orig_embeds_params[0] orig_embeds_params_out = orig_embeds_params[1] # st_inp=self.tokenizer.encode(grounding_start)[1] # st_out=self.tokenizer.encode(grounding_start)[1] with torch.no_grad(): self.get_input_embeddings().weight[:-3] = orig_embeds_params_in[:-3].data # if self.tokenizer.decode([len(self.tokenizer)-1])=='': self.get_output_embeddings().weight[:-3] = orig_embeds_params_out[:-3].data for batch_idx, cur_input_ids in enumerate(input_ids): if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0: # multimodal LLM, but the current sample is not multimodal cur_input_embeds = self.get_model().embed_tokens(cur_input_ids) cur_input_embeds = cur_input_embeds + (0. 
* self.get_model().mm_projector(vision_tower.dummy_feature)).sum() new_input_embeds.append(cur_input_embeds) if labels is not None: new_labels.append(labels[batch_idx]) cur_image_idx += 1 continue image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0] cur_new_input_embeds = [] if labels is not None: cur_labels = labels[batch_idx] cur_new_labels = [] assert cur_labels.shape == cur_input_ids.shape while image_token_indices.numel() > 0: cur_image_features = image_features[cur_image_idx] image_token_start = image_token_indices[0] if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False): cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start-1])) cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_start-1:image_token_start])) cur_new_input_embeds.append(cur_image_features) cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_start+1:image_token_start+2])) if labels is not None: cur_new_labels.append(cur_labels[:image_token_start]) cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype)) cur_new_labels.append(cur_labels[image_token_start:image_token_start+1]) cur_labels = cur_labels[image_token_start+2:] else: cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start])) cur_new_input_embeds.append(cur_image_features) if labels is not None: cur_new_labels.append(cur_labels[:image_token_start]) cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype)) cur_labels = cur_labels[image_token_start+1:] cur_image_idx += 1 if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False): cur_input_ids = cur_input_ids[image_token_start+2:] else: cur_input_ids = cur_input_ids[image_token_start+1:] 
image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0] if cur_input_ids.numel() > 0: if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False): cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids)) else: cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids)) if batch_idx >= len(input_ids) - num_it: obj_idx = cur_input_ids == 1273 idx_in_inter=batch_idx-(len(input_ids)-num_it) cur_new_input_embeds[-1][obj_idx] = obj_feats[idx_in_inter].to(cur_new_input_embeds[-1].dtype) if labels is not None: cur_labels[cur_labels==1273]=IGNORE_INDEX cur_new_labels.append(cur_labels) cur_new_input_embeds = [x.to(device=self.device) for x in cur_new_input_embeds] cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0) new_input_embeds.append(cur_new_input_embeds) if labels is not None: cur_new_labels = torch.cat(cur_new_labels, dim=0) new_labels.append(cur_new_labels) if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds): max_len = max(x.shape[0] for x in new_input_embeds) new_input_embeds_align = [] for cur_new_embed in new_input_embeds: cur_new_embed = torch.cat((cur_new_embed, torch.zeros((max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0) new_input_embeds_align.append(cur_new_embed) new_input_embeds = torch.stack(new_input_embeds_align, dim=0) if labels is not None: new_labels_align = [] _new_labels = new_labels for cur_new_label in new_labels: cur_new_label = torch.cat((cur_new_label, torch.full((max_len - cur_new_label.shape[0],), IGNORE_INDEX, dtype=cur_new_label.dtype, device=cur_new_label.device)), dim=0) new_labels_align.append(cur_new_label) new_labels = torch.stack(new_labels_align, dim=0) if attention_mask is not None: new_attention_mask = [] for cur_attention_mask, cur_new_labels, cur_new_labels_align in zip(attention_mask, _new_labels, new_labels): 
new_attn_mask_pad_left = torch.full((cur_new_labels.shape[0] - labels.shape[1],), True, dtype=attention_mask.dtype, device=attention_mask.device) new_attn_mask_pad_right = torch.full((cur_new_labels_align.shape[0] - cur_new_labels.shape[0],), False, dtype=attention_mask.dtype, device=attention_mask.device) cur_new_attention_mask = torch.cat((new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right), dim=0) new_attention_mask.append(cur_new_attention_mask) attention_mask = torch.stack(new_attention_mask, dim=0) assert attention_mask.shape == new_labels.shape else: new_input_embeds = torch.stack(new_input_embeds, dim=0) if labels is not None: new_labels = torch.stack(new_labels, dim=0) if attention_mask is not None: new_attn_mask_pad_left = torch.full((attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1]), True, dtype=attention_mask.dtype, device=attention_mask.device) attention_mask = torch.cat((new_attn_mask_pad_left, attention_mask), dim=1) assert attention_mask.shape == new_input_embeds.shape[:2] return None, attention_mask, past_key_values, new_input_embeds, new_labels def prepare_inputs_labels_for_multimodal_NoInter( self, input_ids, attention_mask, past_key_values, labels, images ): vision_tower = self.get_vision_tower() if vision_tower is None or images is None or input_ids.shape[1] == 1: if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[1] == 1: attention_mask = torch.ones((attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1), dtype=attention_mask.dtype, device=attention_mask.device) return input_ids, attention_mask, past_key_values, None, labels if type(images) is list or images.ndim == 5: concat_images = torch.cat([image for image in images], dim=0) image_features = self.encode_images(concat_images) split_sizes = [image.shape[0] for image in images] image_features = torch.split(image_features, split_sizes, dim=0) image_features = [x.flatten(0, 1) for x in 
image_features] else: image_features = self.encode_images(images) new_input_embeds = [] new_labels = [] if labels is not None else None cur_image_idx = 0 orig_embeds_params = getattr(self, 'orig_embeds_params', None) if orig_embeds_params is not None: orig_embeds_params_in = orig_embeds_params[0] orig_embeds_params_out = orig_embeds_params[1] # st_inp=self.tokenizer.encode(grounding_start)[1] # st_out=self.tokenizer.encode(grounding_start)[1] with torch.no_grad(): self.get_input_embeddings().weight[:-3] = orig_embeds_params_in[:-3].data # if self.tokenizer.decode([len(self.tokenizer)-1])=='': self.get_output_embeddings().weight[:-3] = orig_embeds_params_out[:-3].data for batch_idx, cur_input_ids in enumerate(input_ids): if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0: # multimodal LLM, but the current sample is not multimodal cur_input_embeds = self.get_model().embed_tokens(cur_input_ids) cur_input_embeds = cur_input_embeds + (0. * self.get_model().mm_projector(vision_tower.dummy_feature)).sum() new_input_embeds.append(cur_input_embeds) if labels is not None: new_labels.append(labels[batch_idx]) cur_image_idx += 1 continue image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0] cur_new_input_embeds = [] if labels is not None: cur_labels = labels[batch_idx] cur_new_labels = [] assert cur_labels.shape == cur_input_ids.shape while image_token_indices.numel() > 0: cur_image_features = image_features[cur_image_idx] image_token_start = image_token_indices[0] if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False): cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start-1])) cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_start-1:image_token_start])) cur_new_input_embeds.append(cur_image_features) cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_start+1:image_token_start+2])) if labels is not 
None: cur_new_labels.append(cur_labels[:image_token_start]) cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype)) cur_new_labels.append(cur_labels[image_token_start:image_token_start+1]) cur_labels = cur_labels[image_token_start+2:] else: cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start])) cur_new_input_embeds.append(cur_image_features) if labels is not None: cur_new_labels.append(cur_labels[:image_token_start]) cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype)) cur_labels = cur_labels[image_token_start+1:] cur_image_idx += 1 if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False): cur_input_ids = cur_input_ids[image_token_start+2:] else: cur_input_ids = cur_input_ids[image_token_start+1:] image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0] if cur_input_ids.numel() > 0: if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False): cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids)) else: cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids)) if labels is not None: cur_new_labels.append(cur_labels) cur_new_input_embeds = [x.to(device=self.device) for x in cur_new_input_embeds] cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0) new_input_embeds.append(cur_new_input_embeds) if labels is not None: cur_new_labels = torch.cat(cur_new_labels, dim=0) new_labels.append(cur_new_labels) if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds): max_len = max(x.shape[0] for x in new_input_embeds) new_input_embeds_align = [] for cur_new_embed in new_input_embeds: cur_new_embed = torch.cat((cur_new_embed, torch.zeros((max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, 
device=cur_new_embed.device)), dim=0) new_input_embeds_align.append(cur_new_embed) new_input_embeds = torch.stack(new_input_embeds_align, dim=0) if labels is not None: new_labels_align = [] _new_labels = new_labels for cur_new_label in new_labels: cur_new_label = torch.cat((cur_new_label, torch.full((max_len - cur_new_label.shape[0],), IGNORE_INDEX, dtype=cur_new_label.dtype, device=cur_new_label.device)), dim=0) new_labels_align.append(cur_new_label) new_labels = torch.stack(new_labels_align, dim=0) if attention_mask is not None: new_attention_mask = [] for cur_attention_mask, cur_new_labels, cur_new_labels_align in zip(attention_mask, _new_labels, new_labels): new_attn_mask_pad_left = torch.full((cur_new_labels.shape[0] - labels.shape[1],), True, dtype=attention_mask.dtype, device=attention_mask.device) new_attn_mask_pad_right = torch.full((cur_new_labels_align.shape[0] - cur_new_labels.shape[0],), False, dtype=attention_mask.dtype, device=attention_mask.device) cur_new_attention_mask = torch.cat((new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right), dim=0) new_attention_mask.append(cur_new_attention_mask) attention_mask = torch.stack(new_attention_mask, dim=0) assert attention_mask.shape == new_labels.shape else: new_input_embeds = torch.stack(new_input_embeds, dim=0) if labels is not None: new_labels = torch.stack(new_labels, dim=0) if attention_mask is not None: new_attn_mask_pad_left = torch.full((attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1]), True, dtype=attention_mask.dtype, device=attention_mask.device) attention_mask = torch.cat((new_attn_mask_pad_left, attention_mask), dim=1) assert attention_mask.shape == new_input_embeds.shape[:2] return None, attention_mask, past_key_values, new_input_embeds, new_labels def initialize_vision_tokenizer(self, model_args, tokenizer): if model_args.mm_use_im_patch_token: tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) 
self.resize_token_embeddings(len(tokenizer)) if model_args.mm_use_im_start_end: num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, grounding_start, grounding_end, SEG_TOKEN], special_tokens=True) self.resize_token_embeddings(len(tokenizer)) if num_new_tokens > 0: input_embeddings = self.get_input_embeddings().weight.data output_embeddings = self.get_output_embeddings().weight.data input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( dim=0, keepdim=True) output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( dim=0, keepdim=True) input_embeddings[-num_new_tokens:] = input_embeddings_avg output_embeddings[-num_new_tokens:] = output_embeddings_avg if model_args.tune_mm_mlp_adapter: self.orig_embeds_params = [self.get_input_embeddings().weight.data.clone().cuda(), self.get_output_embeddings().weight.data.clone().cuda()] for p in self.get_input_embeddings().parameters(): p.requires_grad = True for p in self.get_output_embeddings().parameters(): p.requires_grad = True if model_args.pretrain_mm_mlp_adapter: mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu') embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight'] assert num_new_tokens == 2 if input_embeddings.shape == embed_tokens_weight.shape: input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:] elif embed_tokens_weight.shape[0] == num_new_tokens: input_embeddings[-num_new_tokens:] = embed_tokens_weight else: raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. 
Numer of new tokens: {num_new_tokens}.") elif model_args.mm_use_im_patch_token: if model_args.tune_mm_mlp_adapter: for p in self.get_input_embeddings().parameters(): p.requires_grad = False for p in self.get_output_embeddings().parameters(): p.requires_grad = False else: # import pdb; pdb.set_trace() num_new_tokens = tokenizer.add_tokens([grounding_start, grounding_end, SEG_TOKEN], special_tokens=True) inits=['[',']','.'] nums=[tokenizer.encode(init)[1] for init in inits] # inp_embs = self.get_input_embeddings().weight.data[nums] # out_embs = self.get_output_embeddings().weight.data[nums] self.resize_token_embeddings(len(tokenizer)) if num_new_tokens > 0: input_embeddings = self.get_input_embeddings().weight.data output_embeddings = self.get_output_embeddings().weight.data # input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( dim=0, keepdim=True) output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( dim=0, keepdim=True) # input_embeddings[-num_new_tokens:] = input_embeddings_avg output_embeddings[-num_new_tokens:] = output_embeddings_avg if model_args.tune_mm_mlp_adapter: self.orig_embeds_params = [self.get_input_embeddings().weight.data.clone().cuda(), self.get_output_embeddings().weight.data.clone().cuda()] for p in self.get_input_embeddings().parameters(): p.requires_grad = True for p in self.get_output_embeddings().parameters(): p.requires_grad = True def initialize_seg_modules(self, cfg): seg_model = BaseModel(cfg, build_model(cfg)) seg_model = seg_model.from_pretrained(cfg.MODEL.WEIGHTS) self.seg_model = seg_model def initialize_interactive_modules(self, cfg): from .semsam.BaseModel import BaseModel as SemSamBaseModel from .semsam import build_model as build_semsam_model seg_model = SemSamBaseModel(cfg, build_semsam_model(cfg)) if not (cfg.MODEL.WEIGHTS == "None"): seg_model = seg_model.from_pretrained(cfg.MODEL.WEIGHTS) self.interactive_model = seg_model def freeze_seg_modules(self): for p in self.seg_model.parameters(): 
p.requires_grad = False ================================================ FILE: llava/model/make_delta.py ================================================ """ Usage: python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta """ import argparse import torch from tqdm import tqdm from transformers import AutoTokenizer, AutoModelForCausalLM from llava.model.utils import auto_upgrade def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id): print("Loading base model") base = AutoModelForCausalLM.from_pretrained( base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) print("Loading target model") auto_upgrade(target_model_path) target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) print("Calculating delta") for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"): if name not in base.state_dict(): assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model' continue if param.data.shape == base.state_dict()[name].shape: param.data -= base.state_dict()[name] else: assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}' bparam = base.state_dict()[name] param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam print("Saving delta") if hub_repo_id: kwargs = {"push_to_hub": True, "repo_id": hub_repo_id} else: kwargs = {} target.save_pretrained(delta_path, **kwargs) target_tokenizer = AutoTokenizer.from_pretrained(target_model_path) target_tokenizer.save_pretrained(delta_path, **kwargs) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--base-model-path", type=str, required=True) parser.add_argument("--target-model-path", type=str, required=True) parser.add_argument("--delta-path", 
type=str, required=True) parser.add_argument("--hub-repo-id", type=str, default=None) args = parser.parse_args() make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id) ================================================ FILE: llava/model/multimodal_encoder/builder.py ================================================ from .clip_encoder import CLIPVisionTower def build_vision_tower(vision_tower_cfg, **kwargs): vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None)) if vision_tower.startswith("openai") or vision_tower.startswith("laion"): return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) raise ValueError(f'Unknown vision tower: {vision_tower}') ================================================ FILE: llava/model/multimodal_encoder/clip_encoder.py ================================================ import torch import torch.nn as nn from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig class CLIPVisionTower(nn.Module): def __init__(self, vision_tower, args, delay_load=False): super().__init__() self.is_loaded = False self.vision_tower_name = vision_tower self.select_layer = args.mm_vision_select_layer self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch') if not delay_load: self.load_model() else: self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name) def load_model(self): self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name) self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name,cache_dir="/comp_robot/zhanghao/.cache/hugging_face/") self.vision_tower.requires_grad_(False) self.is_loaded = True def feature_select(self, image_forward_outs): image_features = image_forward_outs.hidden_states[self.select_layer] if self.select_feature == 'patch': image_features = image_features[:, 1:] elif self.select_feature == 'cls_patch': image_features = image_features else: raise 
ValueError(f'Unexpected select feature: {self.select_feature}') return image_features @torch.no_grad() def forward(self, images): if type(images) is list: image_features = [] for image in images: image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True) image_feature = self.feature_select(image_forward_out).to(image.dtype) image_features.append(image_feature) else: image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) image_features = self.feature_select(image_forward_outs).to(images.dtype) return image_features @property def dummy_feature(self): return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) @property def dtype(self): return self.vision_tower.dtype @property def device(self): return self.vision_tower.device @property def config(self): if self.is_loaded: return self.vision_tower.config else: return self.cfg_only @property def hidden_size(self): return self.config.hidden_size @property def num_patches(self): return (self.config.image_size // self.config.patch_size) ** 2 ================================================ FILE: llava/model/openseed/BaseModel.py ================================================ import os import logging import torch import torch.nn as nn # from utils.model import align_and_update_state_dicts logger = logging.getLogger(__name__) def align_and_update_state_dicts(model_state_dict, ckpt_state_dict): model_keys = sorted(model_state_dict.keys()) ckpt_keys = sorted(ckpt_state_dict.keys()) result_dicts = {} matched_log = [] unmatched_log = [] unloaded_log = [] for model_key in model_keys: model_weight = model_state_dict[model_key] if model_key in ckpt_keys: ckpt_weight = ckpt_state_dict[model_key] if model_weight.shape == ckpt_weight.shape: result_dicts[model_key] = ckpt_weight ckpt_keys.pop(ckpt_keys.index(model_key)) matched_log.append("Loaded {}, Model Shape: {} <-> Ckpt Shape: 
{}".format(model_key, model_weight.shape, ckpt_weight.shape)) else: unmatched_log.append( "*UNMATCHED* {}, Model Shape: {} <-> Ckpt Shape: {}".format(model_key, model_weight.shape, ckpt_weight.shape)) else: unloaded_log.append("*UNLOADED* {}, Model Shape: {}".format(model_key, model_weight.shape)) # if is_main_process(): # for info in matched_log: # logger.info(info) # for info in unloaded_log: # logger.warning(info) # for key in ckpt_keys: # logger.warning("$UNUSED$ {}, Ckpt Shape: {}".format(key, ckpt_state_dict[key].shape)) # for info in unmatched_log: # logger.warning(info) return result_dicts class BaseModel(nn.Module): def __init__(self, opt, module: nn.Module): super(BaseModel, self).__init__() self.opt = opt self.model = module def forward(self, *inputs, **kwargs): outputs = self.model(*inputs, **kwargs) return outputs def save_pretrained(self, save_dir): torch.save(self.model.state_dict(), save_dir) def from_pretrained(self, load_dir): state_dict = torch.load(load_dir, map_location='cpu') state_dict = align_and_update_state_dicts(self.model.state_dict(), state_dict) self.model.load_state_dict(state_dict, strict=False) return self ================================================ FILE: llava/model/openseed/__init__.py ================================================ from __future__ import absolute_import from __future__ import division from __future__ import print_function from .architectures import build_model ================================================ FILE: llava/model/openseed/architectures/__init__.py ================================================ from .openseed_model import * # from .openseed_model_decouple_train import * from .build import build_model ================================================ FILE: llava/model/openseed/architectures/build.py ================================================ from .registry import model_entrypoints from .registry import is_model def build_model(config, **kwargs): model_name = config['MODEL']['NAME'] if not 
is_model(model_name): raise ValueError(f'Unkown model: {model_name}') return model_entrypoints(model_name)(config, **kwargs) ================================================ FILE: llava/model/openseed/architectures/openseed_model.py ================================================ # ------------------------------------------------------------------------ # DINO # Copyright (c) 2023 IDEA. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------ # Modified from MaskDINO https://github.com/IDEA-Research/MaskDINO by Feng Li and Hao Zhang. # ------------------------------------------------------------------------ from typing import Tuple import torch from torch import nn from torch.nn import functional as F from .registry import register_model from ..utils import configurable, box_ops #, get_class_names from ..backbone import build_backbone, Backbone from ..body import build_openseed_head from ..modules import sem_seg_postprocess, HungarianMatcher, SetCriterion from detectron2.structures import Boxes, ImageList, Instances, BitMasks from detectron2.utils.memory import retry_if_cuda_oom from detectron2.data import MetadataCatalog import random class OpenSeeD(nn.Module): """ Main class for mask classification semantic segmentation architectures. 
""" @configurable def __init__( self, *, backbone: Backbone, sem_seg_head: nn.Module, criterion: nn.Module, num_queries: int, object_mask_threshold: float, overlap_threshold: float, metadata, size_divisibility: int, sem_seg_postprocess_before_inference: bool, pixel_mean: Tuple[float], pixel_std: Tuple[float], # inference semantic_on: bool, panoptic_on: bool, instance_on: bool, test_topk_per_image: int, data_loader: str, pano_temp: float, focus_on_box: bool = False, transform_eval: bool = False, semantic_ce_loss: bool = False, train_dataset_name: str, background: bool, coco_on=True, coco_mask_on=True, o365_on=True, merge_class=False, coco_only=False, detach_seg=False, eval_train=False, ): """ Args: backbone: a backbone module, must follow detectron2's backbone interface sem_seg_head: a module that predicts semantic segmentation from backbone features criterion: a module that defines the loss num_queries: int, number of queries object_mask_threshold: float, threshold to filter query based on classification score for panoptic segmentation inference overlap_threshold: overlap threshold used in general inference for panoptic segmentation metadata: dataset meta, get `thing` and `stuff` category names for panoptic segmentation inference size_divisibility: Some backbones require the input height and width to be divisible by a specific integer. We can use this to override such requirement. sem_seg_postprocess_before_inference: whether to resize the prediction back to original input size before semantic segmentation inference or after. For high-resolution dataset like Mapillary, resizing predictions before inference will cause OOM error. 
pixel_mean, pixel_std: list or tuple with #channels element, representing the per-channel mean and std to be used to normalize the input image semantic_on: bool, whether to output semantic segmentation prediction instance_on: bool, whether to output instance segmentation prediction panoptic_on: bool, whether to output panoptic segmentation prediction test_topk_per_image: int, instance segmentation parameter, keep topk instances per image """ super().__init__() self.backbone = backbone self.pano_temp = pano_temp self.sem_seg_head = sem_seg_head self.criterion = criterion self.num_queries = num_queries self.overlap_threshold = overlap_threshold self.object_mask_threshold = object_mask_threshold self.metadata = metadata if size_divisibility < 0: # use backbone size_divisibility if not set size_divisibility = self.backbone.size_divisibility self.size_divisibility = size_divisibility self.sem_seg_postprocess_before_inference = sem_seg_postprocess_before_inference self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) self.detach_seg=detach_seg self.eval_train=eval_train # additional args self.semantic_on = semantic_on self.instance_on = instance_on self.panoptic_on = panoptic_on self.test_topk_per_image = test_topk_per_image self.data_loader = data_loader self.focus_on_box = focus_on_box self.transform_eval = transform_eval self.semantic_ce_loss = semantic_ce_loss self.train_class_names = dict() self.train_dataset_name = train_dataset_name self.coco_mask_on = coco_mask_on self.task_switch = {'coco': coco_on, 'o365': o365_on} self.num_correct_gd=0 self.num_total_gd=0 self.num_correct_ref = 0 self.num_total_ref = 0 self.num_correct_coco = 0 self.num_total_coco = 0 self.coco_only=coco_only self.loss_dict=None self.mean_iou=0.0 ######## self.total_union=0.0 self.total_intersection=0.0 # self.cIoU=0.0 print("self.task_switch ", self.task_switch) # HACK for only two 
datasets for seg and det if not self.semantic_on: assert self.sem_seg_postprocess_before_inference @classmethod def from_config(cls, cfg): enc_cfg = cfg['MODEL']['ENCODER'] dec_cfg = cfg['MODEL']['DECODER'] # Loss parameters: deep_supervision = dec_cfg['DEEP_SUPERVISION'] no_object_weight = dec_cfg['NO_OBJECT_WEIGHT'] # loss weights class_weight = dec_cfg['CLASS_WEIGHT'] cost_class_weight = dec_cfg['COST_CLASS_WEIGHT'] cost_dice_weight = dec_cfg['COST_DICE_WEIGHT'] dice_weight = dec_cfg['DICE_WEIGHT'] cost_mask_weight = dec_cfg['COST_MASK_WEIGHT'] mask_weight = dec_cfg['MASK_WEIGHT'] cost_box_weight = dec_cfg['COST_BOX_WEIGHT'] box_weight = dec_cfg['BOX_WEIGHT'] cost_giou_weight = dec_cfg['COST_GIOU_WEIGHT'] giou_weight = dec_cfg['GIOU_WEIGHT'] # building matcher matcher = HungarianMatcher( cost_class=cost_class_weight, cost_mask=cost_mask_weight, cost_dice=cost_dice_weight, cost_box=cost_box_weight, cost_giou=cost_giou_weight, num_points=dec_cfg['TRAIN_NUM_POINTS'], ) # MaskDINO losses and weight_dict weight_dict = {"loss_mask_cls_0": class_weight} weight_dict.update({"loss_mask_bce_0": mask_weight, "loss_mask_dice_0": dice_weight}) weight_dict.update({"loss_bbox_0":box_weight,"loss_giou_0":giou_weight}) # two stage is the query selection scheme if dec_cfg['TWO_STAGE']: interm_weight_dict = {} interm_weight_dict.update({k + f'_interm': v for k, v in weight_dict.items()}) weight_dict.update(interm_weight_dict) # denoising training dn = dec_cfg['DN'] dn = 'no' # TODO hack for dn lable loss if dn == "standard": weight_dict.update({k + f"_dn": v for k, v in weight_dict.items() if k!="loss_mask" and k!="loss_dice" }) dn_losses=["dn_labels", "boxes"] elif dn == "seg": weight_dict.update({k + f"_dn": v for k, v in weight_dict.items()}) dn_losses=["dn_labels", "masks", "boxes"] else: dn_losses=[] if deep_supervision: dec_layers = dec_cfg['DEC_LAYERS'] aux_weight_dict = {} for i in range(dec_layers): aux_weight_dict.update({k.replace('_0', '_{}'.format(i+1)): v for k, v in 
weight_dict.items()}) weight_dict.update(aux_weight_dict) if dec_cfg['BOX']: losses = ["labels", "masks","boxes"] else: losses = ["labels", "masks"] # update task switch task_switch = {} task_switch.update({'bbox': dec_cfg.get('DETECTION', True), 'mask': dec_cfg.get('MASK', True)}) top_x_layers = {'mask': dec_cfg.get('TOP_MASK_LAYERS', 10), 'box': dec_cfg.get('TOP_DETECTION_LAYERS', 10)} weight_multiplier= dec_cfg.get('WEIGHT_MULTIPLIER', 1.0) weight_dict={k:v*weight_multiplier for k,v in weight_dict.items()} # building criterion criterion = SetCriterion( enc_cfg['NUM_CLASSES'], matcher=matcher, weight_dict=weight_dict, top_x_layers=top_x_layers, eos_coef=no_object_weight, losses=losses, num_points=dec_cfg['TRAIN_NUM_POINTS'], oversample_ratio=dec_cfg['OVERSAMPLE_RATIO'], importance_sample_ratio=dec_cfg['IMPORTANCE_SAMPLE_RATIO'], grounding_weight=None, dn=dec_cfg['DN'], dn_losses=dn_losses, panoptic_on=dec_cfg['PANO_BOX_LOSS'], semantic_ce_loss=dec_cfg['TEST']['SEMANTIC_ON'] and dec_cfg['SEMANTIC_CE_LOSS'] and not dec_cfg['TEST']['PANOPTIC_ON'], ) # build model extra = {'task_switch': task_switch} backbone = build_backbone(cfg) # lang_encoder = build_language_encoder(cfg) sem_seg_head = build_openseed_head(cfg, backbone.output_shape(), None, extra=extra) return { "backbone": backbone, "sem_seg_head": sem_seg_head, "criterion": criterion, "num_queries": dec_cfg['NUM_OBJECT_QUERIES'], "object_mask_threshold": dec_cfg['TEST']['OBJECT_MASK_THRESHOLD'], "overlap_threshold": dec_cfg['TEST']['OVERLAP_THRESHOLD'], "metadata": MetadataCatalog.get(cfg['DATASETS']['TRAIN'][0]), "size_divisibility": dec_cfg['SIZE_DIVISIBILITY'], "sem_seg_postprocess_before_inference": ( dec_cfg['TEST']['SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE'] or dec_cfg['TEST']['PANOPTIC_ON'] or dec_cfg['TEST']['INSTANCE_ON'] ), "pixel_mean": cfg['INPUT']['PIXEL_MEAN'], "pixel_std": cfg['INPUT']['PIXEL_STD'], # inference "semantic_on": dec_cfg['TEST']['SEMANTIC_ON'], "instance_on": 
dec_cfg['TEST']['INSTANCE_ON'], "panoptic_on": dec_cfg['TEST']['PANOPTIC_ON'], "test_topk_per_image": cfg['COCO']['TEST']['DETECTIONS_PER_IMAGE'], "data_loader": None, "focus_on_box": cfg['MODEL']['DECODER']['TEST']['TEST_FOUCUS_ON_BOX'], "transform_eval": cfg['MODEL']['DECODER']['TEST']['PANO_TRANSFORM_EVAL'], "pano_temp": cfg['MODEL']['DECODER']['TEST']['PANO_TEMPERATURE'], "semantic_ce_loss": cfg['MODEL']['DECODER']['TEST']['SEMANTIC_ON'] and cfg['MODEL']['DECODER']['SEMANTIC_CE_LOSS'] and not cfg['MODEL']['DECODER']['TEST']['PANOPTIC_ON'], "train_dataset_name": cfg['DATASETS']['TRAIN'], # HACK for only two training set "background": cfg['MODEL'].get('BACKGROUND', True), "coco_on": dec_cfg.get('COCO', True), "coco_mask_on": dec_cfg.get('COCO_MASK', True), "o365_on": dec_cfg.get('O365', True), "coco_only": dec_cfg.get('COCO_ONLY', False), "detach_seg": cfg.get('detach_seg', False), "eval_train": cfg.get('eval_train', False), } @property def device(self): return self.pixel_mean.device def forward(self, batched_inputs, inference_task='seg'): # import ipdb; ipdb.set_trace() # print("Num images per batch:",len(batched_inputs['flickr'])) if self.training: losses = {} losses_ = dict() if 'flickr' in batched_inputs and not self.coco_only: self.criterion.conversation=False losses_flickr = self.forward_seg(batched_inputs['flickr'], task='seg',default_text_embeddings=batched_inputs['flickr_text_embeddings'],data_type='gd') for key, value in losses_flickr.items(): losses_['flickr.'+str(key)] = losses_flickr[key] self.loss_dict=losses_flickr if 'refcoco' in batched_inputs and not self.coco_only: self.criterion.conversation=False losses_ref = self.forward_seg(batched_inputs['refcoco'], task='seg',default_text_embeddings=batched_inputs['refcoco_text_embeddings'],data_type='ref') for key, value in losses_ref.items(): losses_['refcoco.'+str(key)] = losses_ref[key] if 'vg' in batched_inputs and not self.coco_only: self.criterion.conversation = False losses_ref = 
self.forward_seg(batched_inputs['vg'], task='det', default_text_embeddings=batched_inputs['vg_text_embeddings'], data_type='ref') for key, value in losses_ref.items(): losses_['vg.' + str(key)] = losses_ref[key] # self.loss_dict=losses_flickr if 'coco' in batched_inputs: # if self.loss_dict is None: # # else: valid_idx=[] for idx,input in enumerate(batched_inputs['coco']): if input['grounding']: valid_idx.append(idx) if len(valid_idx)==0: self.criterion.conversation = True losses_flickr = self.forward_seg(batched_inputs['flickr'], task='seg', default_text_embeddings=batched_inputs[ 'flickr_text_embeddings'], data_type='coco') self.loss_dict = losses_flickr for key, value in self.loss_dict.items(): losses['coco.' + str(key)] = self.loss_dict[key] * 0.0 else: batched_inputs['coco']=[batched_inputs['coco'][idx] for idx in valid_idx] text_embed=batched_inputs['coco_text_embeddings'] text_embed=text_embed[0][valid_idx],text_embed[1][valid_idx] self.criterion.conversation = True losses_coco_instruct = self.forward_seg(batched_inputs['coco'], task='seg',default_text_embeddings=text_embed) for key, value in losses_coco_instruct.items(): losses['coco.'+str(key)] = losses_coco_instruct[key] losses.update(losses_) # if self.task_switch['coco']: # self.criterion.num_classes = 133 if 'pano' in self.train_dataset_name[0] else 80 # # self.criterion.num_classes = 133 # task = 'seg' # if not self.coco_mask_on: # task = 'det' # # import ipdb; ipdb.set_trace() # losses_coco = self.forward_seg(batched_inputs['coco'], task=task) # new_losses_coco = {} # for key, value in losses_coco.items(): # new_losses_coco['coco.'+str(key)] = losses_coco[key] # losses.update(new_losses_coco) # if self.task_switch['o365']: # self.criterion.num_classes = 365 # losses_o365 = self.forward_seg(batched_inputs['o365'], task='det') # new_losses_o365 = {} # for key, value in losses_o365.items(): # new_losses_o365['o365.'+str(key)] = losses_o365[key] # losses.update(new_losses_o365) return losses else: 
processed_results = self.forward_seg(batched_inputs, task=inference_task) return processed_results def forward_seg(self, batched_inputs, task='seg',default_text_embeddings=None,data_type='gd'): images = [x["image"].to(self.device) for x in batched_inputs] images = [(x - self.pixel_mean) / self.pixel_std for x in images] images = ImageList.from_tensors(images, self.size_divisibility) features = self.backbone(images.tensor) # features={k:v.to(torch.bfloat16) for k,v in features.items()} if self.training: # mask classification target if "instances" in batched_inputs[0]: gt_instances = [x["instances"].to(self.device) for x in batched_inputs] targets = self.prepare_targets(gt_instances, images, task=task) else: targets = None outputs, mask_dict = self.sem_seg_head(features, targets=None, task=task,default_text_embeddings=default_text_embeddings) ##########eval training if self.eval_train: pred_logits=outputs["pred_logits"] pred_boxes=outputs["pred_boxes"] pred_masks=outputs["pred_masks"]>0 num_total=0 num_correct=0 mask_iou=0.0 # total_union=0.0 # total_intersection=0.0 scale_factor=[1024./max(data['height'],data['width']) for data in batched_inputs] for i in range(len(pred_logits)): matched_idx=torch.argmax(pred_logits[i],dim=0) matched_boxes=pred_boxes[i][matched_idx] matched_masks=pred_masks[i][matched_idx] gt_boxes_=targets[i]['boxes'] gt_masks_=targets[i]['masks'] gt_ground_labels=targets[i]['labels'] gt_ground_labels_=[] for lb in gt_ground_labels: gt_ground_labels_.extend(lb) max_lb=max(gt_ground_labels_) lb2gt_idx=dict() for lb in range(max_lb+1): lb2gt_idx[lb]=[] for idx,lbs in enumerate(gt_ground_labels): for lb in lbs: lb2gt_idx[lb].append(idx) for lb in range(max_lb+1): pred_box=box_ops.box_cxcywh_to_xyxy(matched_boxes[lb][None]) gt_boxes=box_ops.box_cxcywh_to_xyxy(gt_boxes_[lb2gt_idx[lb]]) pred_mask=matched_masks[lb] gt_mask=gt_masks_[lb2gt_idx[lb]][0] pred_mask = F.interpolate( pred_mask[None,None].float(), size=(gt_mask.shape[-2], gt_mask.shape[-1]), 
mode="bilinear", align_corners=False, )[0,0]>0.5 if len(gt_boxes)==0: continue mask_iou+=float(torch.sum(pred_mask*gt_mask)/torch.sum(torch.logical_or(pred_mask,gt_mask))) self.total_union+=float(torch.sum(torch.logical_or(pred_mask,gt_mask)))/scale_factor[i]**2 self.total_intersection+=float(torch.sum(pred_mask*gt_mask))/scale_factor[i]**2 # self.mask_iou+=mask_iou if box_ops.box_iou(pred_box,gt_boxes)[0].max()>0.5: num_correct+=1 else: pass num_total+=1 print(f"{data_type} cIoU:" ,self.total_intersection/self.total_union) name_correct='num_correct_'+data_type name_total='num_total_'+data_type try: gathered_list=[None for _ in range(torch.distributed.get_world_size())] torch.distributed.all_gather_object(gathered_list,num_correct) # self.num_correct+=sum(gathered_list) num_correct_value=getattr(self,name_correct) setattr(self,name_correct,num_correct_value+sum(gathered_list)) gathered_list=[None for _ in range(torch.distributed.get_world_size())] torch.distributed.all_gather_object(gathered_list,num_total) # self.num_total+=sum(gathered_list) num_total_value=getattr(self,name_total) setattr(self,name_total,num_total_value+sum(gathered_list)) gathered_list=[None for _ in range(torch.distributed.get_world_size())] torch.distributed.all_gather_object(gathered_list,mask_iou) # self.mask_iou+=sum(gathered_list) if torch.distributed.get_rank()==0: print(f"{data_type} acc: ",getattr(self, name_correct) / getattr(self, name_total)) # print("mask_iou: ",self.mask_iou/self.num_total) except Exception as e: # self.num_correct+=num_correct # self.num_total+=num_total num_correct_value = getattr(self, name_correct) setattr(self, name_correct, num_correct_value + num_correct) num_total_value = getattr(self, name_total) setattr(self, name_total, num_total_value + num_total) try: print(f"{data_type} rank{torch.distributed.get_rank()} acc: ", getattr(self, name_correct) / getattr(self, name_total)) except Exception as e: print(f"{data_type} acc: ", getattr(self, name_correct) / 
getattr(self, name_total)) ########################### # bipartite matching-based loss self.criterion.default_text_embeddings = default_text_embeddings losses = self.criterion(outputs, targets, mask_dict, task=task) for k in list(losses.keys()): if k in self.criterion.weight_dict: losses[k] *= self.criterion.weight_dict[k] else: # remove this loss if not specified in `weight_dict` losses.pop(k) return losses else: outputs, _ = self.sem_seg_head(features) mask_cls_results = outputs["pred_logits"] mask_box_results = outputs["pred_boxes"] if 'seg' in task: if task == 'seg': self.semantic_on = self.panoptic_on = self.sem_seg_postprocess_before_inference = self.instance_on = True if task == 'inst_seg': self.semantic_on = self.panoptic_on = False self.instance_on = True self.sem_seg_postprocess_before_inference = True if task == 'sem_pan_seg': self.semantic_on = self.panoptic_on = True self.instance_on = False self.sem_seg_postprocess_before_inference = True if task == 'inst_pan_seg': self.instance_on = self.panoptic_on = True self.semantic_on = False self.sem_seg_postprocess_before_inference = True if task == 'sem_seg': self.instance_on = self.panoptic_on = False self.semantic_on = True self.sem_seg_postprocess_before_inference = True mask_pred_results = outputs["pred_masks"] # upsample masks mask_pred_results = F.interpolate( mask_pred_results, size=(images.tensor.shape[-2], images.tensor.shape[-1]), mode="bilinear", align_corners=False, ) else: self.semantic_on = self.panoptic_on = self.sem_seg_postprocess_before_inference = False self.instance_on = True mask_pred_results = torch.zeros(mask_box_results.shape[0], mask_box_results.shape[1],2, 2).to(mask_box_results) del outputs processed_results = [] for mask_cls_result, mask_pred_result, mask_box_result, input_per_image, image_size in zip( mask_cls_results, mask_pred_results, mask_box_results, batched_inputs, images.image_sizes ): height = input_per_image.get("height", image_size[0]) width = 
input_per_image.get("width", image_size[1]) processed_results.append({}) new_size = (images.tensor.shape[-2], images.tensor.shape[-1]) # padded size (divisible to 32) if self.sem_seg_postprocess_before_inference: mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)( mask_pred_result, image_size, height, width ) mask_cls_result = mask_cls_result.to(mask_pred_result) # semantic segmentation inference if self.semantic_on: r = retry_if_cuda_oom(self.semantic_inference)(mask_cls_result, mask_pred_result) if not self.sem_seg_postprocess_before_inference: r = retry_if_cuda_oom(sem_seg_postprocess)(r, image_size, height, width) processed_results[-1]["sem_seg"] = r # panoptic segmentation inference if self.panoptic_on: panoptic_r = retry_if_cuda_oom(self.panoptic_inference)(mask_cls_result, mask_pred_result) processed_results[-1]["panoptic_seg"] = panoptic_r # instance segmentation inference if self.instance_on: mask_box_result = mask_box_result.to(mask_pred_result) height = new_size[0]/image_size[0]*height width = new_size[1]/image_size[1]*width mask_box_result = self.box_postprocess(mask_box_result, height, width) instance_r = retry_if_cuda_oom(self.instance_inference)(mask_cls_result, mask_pred_result, mask_box_result) processed_results[-1]["instances"] = instance_r del mask_pred_results return processed_results def prepare_targets(self, targets, images, task='seg'): h_pad, w_pad = images.tensor.shape[-2:] new_targets = [] for targets_per_image in targets: # pad gt h, w = targets_per_image.image_size image_size_xyxy = torch.as_tensor([w, h, w, h], dtype=torch.float, device=self.device) if task != 'det': gt_masks = targets_per_image.gt_masks padded_masks = torch.zeros((gt_masks.shape[0], h_pad, w_pad), dtype=gt_masks.dtype, device=gt_masks.device) padded_masks[:, : gt_masks.shape[1], : gt_masks.shape[2]] = gt_masks else: padded_masks = None new_targets.append( { "labels": targets_per_image.gt_classes, "masks": padded_masks, 
"boxes":box_ops.box_xyxy_to_cxcywh(targets_per_image.gt_boxes.tensor)/image_size_xyxy } ) return new_targets def semantic_inference(self, mask_cls, mask_pred): # if use cross-entropy loss in training, evaluate with softmax if self.semantic_ce_loss: mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1] mask_pred = mask_pred.sigmoid() semseg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred) return semseg # if use focal loss in training, evaluate with sigmoid. As sigmoid is mainly for detection and not sharp # enough for semantic and panoptic segmentation, we additionally use use softmax with a temperature to # make the score sharper. else: T = self.pano_temp mask_cls = mask_cls.sigmoid() if self.transform_eval: mask_cls = F.softmax(mask_cls / T, dim=-1) # already sigmoid mask_pred = mask_pred.sigmoid() semseg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred) return semseg def panoptic_inference(self, mask_cls, mask_pred): # As we use focal loss in training, evaluate with sigmoid. As sigmoid is mainly for detection and not sharp # enough for semantic and panoptic segmentation, we additionally use use softmax with a temperature to # make the score sharper. 
prob = 0.5 T = self.pano_temp scores, labels = mask_cls.sigmoid().max(-1) mask_pred = mask_pred.sigmoid() keep = labels.ne(self.sem_seg_head.num_classes) & (scores > self.object_mask_threshold) # added process if self.transform_eval: scores, labels = F.softmax(mask_cls.sigmoid() / T, dim=-1).max(-1) cur_scores = scores[keep] cur_classes = labels[keep] cur_masks = mask_pred[keep] cur_prob_masks = cur_scores.view(-1, 1, 1) * cur_masks h, w = cur_masks.shape[-2:] panoptic_seg = torch.zeros((h, w), dtype=torch.int32, device=cur_masks.device) segments_info = [] current_segment_id = 0 if cur_masks.shape[0] == 0: # We didn't detect any mask :( return panoptic_seg, segments_info else: # take argmax cur_mask_ids = cur_prob_masks.argmax(0) stuff_memory_list = {} for k in range(cur_classes.shape[0]): pred_class = cur_classes[k].item() isthing = pred_class in self.metadata.thing_dataset_id_to_contiguous_id.values() mask_area = (cur_mask_ids == k).sum().item() original_area = (cur_masks[k] >= prob).sum().item() mask = (cur_mask_ids == k) & (cur_masks[k] >= prob) if mask_area > 0 and original_area > 0 and mask.sum().item() > 0: if mask_area / original_area < self.overlap_threshold: continue # merge stuff regions if not isthing: if int(pred_class) in stuff_memory_list.keys(): panoptic_seg[mask] = stuff_memory_list[int(pred_class)] continue else: stuff_memory_list[int(pred_class)] = current_segment_id + 1 current_segment_id += 1 panoptic_seg[mask] = current_segment_id segments_info.append( { "id": current_segment_id, "isthing": bool(isthing), "category_id": int(pred_class), } ) return panoptic_seg, segments_info def instance_inference(self, mask_cls, mask_pred, mask_box_result): # mask_pred is already processed to have the same shape as original input image_size = mask_pred.shape[-2:] scores = mask_cls.sigmoid() # [100, 80] labels = torch.arange(self.sem_seg_head.num_classes, device=self.device).unsqueeze(0).repeat(self.num_queries, 1).flatten(0, 1) scores_per_image, topk_indices 
= scores.flatten(0, 1).topk(self.test_topk_per_image, sorted=False) # select 100 labels_per_image = labels[topk_indices] topk_indices = topk_indices // self.sem_seg_head.num_classes mask_pred = mask_pred[topk_indices] # if this is panoptic segmentation, we only keep the "thing" classes if self.panoptic_on: keep = torch.zeros_like(scores_per_image).bool() for i, lab in enumerate(labels_per_image): # print(i, len(keep), lab) keep[i] = lab in self.metadata.thing_dataset_id_to_contiguous_id.values() scores_per_image = scores_per_image[keep] labels_per_image = labels_per_image[keep] mask_pred = mask_pred[keep] result = Instances(image_size) # mask (before sigmoid) result.pred_masks = (mask_pred > 0).float() # half mask box half pred box mask_box_result = mask_box_result[topk_indices] if self.panoptic_on: mask_box_result = mask_box_result[keep] result.pred_boxes = Boxes(mask_box_result) # Uncomment the following to get boxes from masks (this is slow) # result.pred_boxes = BitMasks(mask_pred > 0).get_bounding_boxes() # calculate average mask prob if self.sem_seg_postprocess_before_inference: mask_scores_per_image = (mask_pred.sigmoid().flatten(1) * result.pred_masks.flatten(1)).sum(1) / (result.pred_masks.flatten(1).sum(1) + 1e-6) else: mask_scores_per_image = 1.0 # labels_per_image = labels_per_image + 1 # HACK for o365 classification if self.focus_on_box: mask_scores_per_image = 1.0 result.scores = scores_per_image * mask_scores_per_image result.pred_classes = labels_per_image return result def box_postprocess(self, out_bbox, img_h, img_w): # postprocess box height and width boxes = box_ops.box_cxcywh_to_xyxy(out_bbox) scale_fct = torch.tensor([img_w, img_h, img_w, img_h]) scale_fct = scale_fct.to(out_bbox) boxes = boxes * scale_fct return boxes def forward_eval(self, batched_inputs, text_embeddings): # import ipdb; ipdb.set_trace() # print("Num images per batch:",len(batched_inputs['flickr'])) if self.training: raise NotImplementedError else: 
self.criterion.conversation=False box_results, seg_results = self.forward_inner_eval( batched_inputs, task='seg', default_text_embeddings=text_embeddings, ) return box_results, seg_results def forward_inner_eval(self, batched_inputs, task='seg',default_text_embeddings=None): images = [x["image"].to(self.device) for x in batched_inputs] images = [(x - self.pixel_mean) / self.pixel_std for x in images] images = ImageList.from_tensors(images, self.size_divisibility) matching_threshold = batched_inputs[0]["matching_threshold"] if "matching_threshold" in batched_inputs[0].keys() else None features = self.backbone(images.tensor) # features={k:v.to(torch.bfloat16) for k,v in features.items()} # mask classification target if "instances" in batched_inputs[0]: gt_instances = [x["instances"].to(self.device) for x in batched_inputs] targets = self.prepare_targets(gt_instances, images, task=task) else: targets = None default_text_embeddings_ = [default_text_embeddings[0].float(), default_text_embeddings[1]] outputs, mask_dict = self.sem_seg_head(features, targets=None, task=task,default_text_embeddings=default_text_embeddings_) ##########eval training pred_logits=outputs["pred_logits"] pred_boxes=outputs["pred_boxes"] pred_masks=outputs["pred_masks"]>0 # scale_factor=[1024./max(data['height'],data['width']) for data in batched_inputs] matched_pred_boxes = [] matched_pred_masks = [] for i in range(len(pred_logits)): if len(pred_logits) > 1: raise NotImplementedError num_grounding = pred_logits.shape[2] for gd_idx in range(num_grounding): if matching_threshold is None: matched_idx = torch.argmax(pred_logits[i, :, gd_idx],dim=0) matched_boxes = pred_boxes[i][matched_idx] matched_boxes = matched_boxes[None, :] else: matched_idx = torch.where(pred_logits[i, :, gd_idx].softmax(dim=0) > matching_threshold)[0] # print(matched_idx, pred_logits[i, :, gd_idx].softmax(dim=0)[matched_idx]) if matched_idx.shape[0] == 0: #* if there is no one object satisfy threshold, then select the best 
matched one. matched_boxes = pred_boxes.new_zeros((1, 4)) matched_masks = pred_boxes.new_zeros((1, 256, 256)) else: matched_boxes = pred_boxes[i][matched_idx] matched_masks = pred_masks[i][matched_idx] # matched_masks=pred_masks[i][matched_idx] matched_boxes_processed = [] for lb in range(matched_boxes.shape[0]): pred_box=box_ops.box_cxcywh_to_xyxy(matched_boxes[lb][None]) matched_boxes_processed.append(pred_box) matched_pred_boxes.append(torch.cat(matched_boxes_processed, dim=0)) matched_pred_masks.append(matched_masks) return matched_pred_boxes, matched_pred_masks @register_model def get_segmentation_model(cfg, **kwargs): return OpenSeeD(cfg) ================================================ FILE: llava/model/openseed/architectures/openseed_model_decouple_train.py ================================================ # ------------------------------------------------------------------------ # DINO # Copyright (c) 2023 IDEA. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------ # Modified from MaskDINO https://github.com/IDEA-Research/MaskDINO by Hao Zhang and Feng Li. # ------------------------------------------------------------------------ from typing import Tuple import torch from torch import nn from torch.nn import functional as F from .registry import register_model from ..utils import configurable, box_ops, get_class_names from ..backbone import build_backbone, Backbone from ..body import build_openseed_head from ..modules import sem_seg_postprocess, HungarianMatcher, SetCriterion from ..language import build_language_encoder from detectron2.structures import Boxes, ImageList, Instances, BitMasks from detectron2.utils.memory import retry_if_cuda_oom from detectron2.data import MetadataCatalog import random import json class OpenSeeD(nn.Module): """ Main class for mask classification semantic segmentation architectures. 
""" @configurable def __init__( self, *, backbone: Backbone, sem_seg_head: nn.Module, num_queries: int, object_mask_threshold: float, overlap_threshold: float, metadata, size_divisibility: int, sem_seg_postprocess_before_inference: bool, pixel_mean: Tuple[float], pixel_std: Tuple[float], # inference semantic_on: bool, panoptic_on: bool, instance_on: bool, test_topk_per_image: int, data_loader: str, pano_temp: float, focus_on_box: bool = False, transform_eval: bool = False, semantic_ce_loss: bool = False, train_dataset_name: str, background: bool, coco_on=True, coco_mask_on=True, o365_on=True, criterion_coco=None, criterion_o365=None, split_panno=False, ): """ Args: backbone: a backbone module, must follow detectron2's backbone interface sem_seg_head: a module that predicts semantic segmentation from backbone features criterion: a module that defines the loss num_queries: int, number of queries object_mask_threshold: float, threshold to filter query based on classification score for panoptic segmentation inference overlap_threshold: overlap threshold used in general inference for panoptic segmentation metadata: dataset meta, get `thing` and `stuff` category names for panoptic segmentation inference size_divisibility: Some backbones require the input height and width to be divisible by a specific integer. We can use this to override such requirement. sem_seg_postprocess_before_inference: whether to resize the prediction back to original input size before semantic segmentation inference or after. For high-resolution dataset like Mapillary, resizing predictions before inference will cause OOM error. 
pixel_mean, pixel_std: list or tuple with #channels element, representing the per-channel mean and std to be used to normalize the input image semantic_on: bool, whether to output semantic segmentation prediction instance_on: bool, whether to output instance segmentation prediction panoptic_on: bool, whether to output panoptic segmentation prediction test_topk_per_image: int, instance segmentation parameter, keep topk instances per image """ super().__init__() self.backbone = backbone self.pano_temp = pano_temp self.sem_seg_head = sem_seg_head self.num_queries = num_queries self.overlap_threshold = overlap_threshold self.object_mask_threshold = object_mask_threshold self.metadata = metadata if size_divisibility < 0: # use backbone size_divisibility if not set size_divisibility = self.backbone.size_divisibility self.size_divisibility = size_divisibility self.sem_seg_postprocess_before_inference = sem_seg_postprocess_before_inference self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) self.split_panno=split_panno # additional args self.semantic_on = semantic_on self.instance_on = instance_on self.panoptic_on = panoptic_on self.test_topk_per_image = test_topk_per_image self.data_loader = data_loader self.focus_on_box = focus_on_box self.transform_eval = transform_eval self.semantic_ce_loss = semantic_ce_loss self.train_class_names = dict() self.train_dataset_name = train_dataset_name self.coco_mask_on = coco_mask_on self.task_switch = {'coco': coco_on, 'o365': o365_on} self.criterion_coco=criterion_coco self.criterion_o365=criterion_o365 print("self.task_switch ", self.task_switch) # HACK for only two datasets for seg and det if coco_on: task = 'seg' if not coco_mask_on: task = 'det' self.train_class_names[task] = get_class_names(train_dataset_name[0], background=background) self.train_class_names[task] = [a.replace("-merged", "").replace("-other", 
"").replace("-stuff", "") for a in self.train_class_names[task]] train_class_names = [] for name in self.train_class_names[task]: names = name.split('-') if len(names) > 1: assert len(names) == 2 train_class_names.append(names[1] + ' ' + names[0]) else: train_class_names.append(name) self.train_class_names[task] = train_class_names if o365_on and len(train_dataset_name)>1: for dt in train_dataset_name: if "o365" in train_dataset_name or "object365" in train_dataset_name: break self.train_class_names['det'] = get_class_names(dt, background=background) self.train_class_names['det'] = [a.lower().split('/') for a in self.train_class_names['det']] if not self.semantic_on: assert self.sem_seg_postprocess_before_inference @classmethod def from_config(cls, cfg): enc_cfg = cfg['MODEL']['ENCODER'] dec_cfg = cfg['MODEL']['DECODER'] # Loss parameters: deep_supervision = dec_cfg['DEEP_SUPERVISION'] no_object_weight = dec_cfg['NO_OBJECT_WEIGHT'] # loss weights class_weight = dec_cfg['CLASS_WEIGHT'] cost_class_weight = dec_cfg['COST_CLASS_WEIGHT'] cost_dice_weight = dec_cfg['COST_DICE_WEIGHT'] dice_weight = dec_cfg['DICE_WEIGHT'] cost_mask_weight = dec_cfg['COST_MASK_WEIGHT'] mask_weight = dec_cfg['MASK_WEIGHT'] cost_box_weight = dec_cfg['COST_BOX_WEIGHT'] box_weight = dec_cfg['BOX_WEIGHT'] cost_giou_weight = dec_cfg['COST_GIOU_WEIGHT'] giou_weight = dec_cfg['GIOU_WEIGHT'] # building matcher matcher = HungarianMatcher( cost_class=cost_class_weight, cost_mask=cost_mask_weight, cost_dice=cost_dice_weight, cost_box=cost_box_weight, cost_giou=cost_giou_weight, num_points=dec_cfg['TRAIN_NUM_POINTS'], ) # MaskDINO losses and weight_dict weight_dict = {"loss_mask_cls_0": class_weight} weight_dict.update({"loss_mask_bce_0": mask_weight, "loss_mask_dice_0": dice_weight}) weight_dict.update({"loss_bbox_0":box_weight,"loss_giou_0":giou_weight}) # two stage is the query selection scheme if dec_cfg['TWO_STAGE']: interm_weight_dict = {} interm_weight_dict.update({k + f'_interm': v for k, v in 
weight_dict.items()}) weight_dict.update(interm_weight_dict) # denoising training dn = dec_cfg['DN'] # TODO hack for dn lable loss if dn == "standard": weight_dict.update({k + f"_dn": v for k, v in weight_dict.items() if k!="loss_mask" and k!="loss_dice" }) dn_losses=["dn_labels", "boxes"] elif dn == "seg": weight_dict.update({k + f"_dn": v for k, v in weight_dict.items()}) dn_losses=["labels", "masks", "boxes"] else: dn_losses=[] if deep_supervision: dec_layers = dec_cfg['DEC_LAYERS'] aux_weight_dict = {} for i in range(dec_layers): aux_weight_dict.update({k.replace('_0', '_{}'.format(i+1)): v for k, v in weight_dict.items()}) weight_dict.update(aux_weight_dict) if dec_cfg['BOX']: losses = ["labels", "masks","boxes"] else: losses = ["labels", "masks"] # update task switch task_switch = {} task_switch.update({'bbox': dec_cfg.get('DETECTION', True), 'mask': dec_cfg.get('MASK', True)}) top_x_layers = {'mask': dec_cfg.get('TOP_MASK_LAYERS', 10), 'box': dec_cfg.get('TOP_DETECTION_LAYERS', 10)} # building criterion criterion_coco = SetCriterion( enc_cfg['NUM_CLASSES'], matcher=matcher, weight_dict=weight_dict, top_x_layers=top_x_layers, eos_coef=no_object_weight, losses=losses, num_points=dec_cfg['TRAIN_NUM_POINTS'], oversample_ratio=dec_cfg['OVERSAMPLE_RATIO'], importance_sample_ratio=dec_cfg['IMPORTANCE_SAMPLE_RATIO'], grounding_weight=None, dn=dec_cfg['DN'], dn_losses=dn_losses, panoptic_on=dec_cfg['PANO_BOX_LOSS'], semantic_ce_loss=dec_cfg['TEST']['SEMANTIC_ON'] and dec_cfg['SEMANTIC_CE_LOSS'] and not dec_cfg['TEST']['PANOPTIC_ON'], ) criterion_o365 = SetCriterion( enc_cfg.get('NUM_CLASSES_O365', 365), matcher=matcher, weight_dict=weight_dict, top_x_layers=top_x_layers, eos_coef=no_object_weight, losses=losses, num_points=dec_cfg['TRAIN_NUM_POINTS'], oversample_ratio=dec_cfg['OVERSAMPLE_RATIO'], importance_sample_ratio=dec_cfg['IMPORTANCE_SAMPLE_RATIO'], grounding_weight=None, dn=dec_cfg['DN'], dn_losses=dn_losses, panoptic_on=dec_cfg['PANO_BOX_LOSS'], 
semantic_ce_loss=dec_cfg['TEST']['SEMANTIC_ON'] and dec_cfg['SEMANTIC_CE_LOSS'] and not dec_cfg['TEST'][ 'PANOPTIC_ON'], ) # build model extra = {'task_switch': task_switch} backbone = build_backbone(cfg) lang_encoder = build_language_encoder(cfg) sem_seg_head = build_openseed_head(cfg, backbone.output_shape(), lang_encoder, extra=extra) return { "backbone": backbone, "sem_seg_head": sem_seg_head, "criterion_coco": criterion_coco, "criterion_o365": criterion_o365, "num_queries": dec_cfg['NUM_OBJECT_QUERIES'], "object_mask_threshold": dec_cfg['TEST']['OBJECT_MASK_THRESHOLD'], "overlap_threshold": dec_cfg['TEST']['OVERLAP_THRESHOLD'], "metadata": MetadataCatalog.get(cfg['DATASETS']['TRAIN'][0]), "size_divisibility": dec_cfg['SIZE_DIVISIBILITY'], "sem_seg_postprocess_before_inference": ( dec_cfg['TEST']['SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE'] or dec_cfg['TEST']['PANOPTIC_ON'] or dec_cfg['TEST']['INSTANCE_ON'] ), "pixel_mean": cfg['INPUT']['PIXEL_MEAN'], "pixel_std": cfg['INPUT']['PIXEL_STD'], # inference "semantic_on": dec_cfg['TEST']['SEMANTIC_ON'], "instance_on": dec_cfg['TEST']['INSTANCE_ON'], "panoptic_on": dec_cfg['TEST']['PANOPTIC_ON'], "test_topk_per_image": cfg['COCO']['TEST']['DETECTIONS_PER_IMAGE'], "data_loader": None, "focus_on_box": cfg['MODEL']['DECODER']['TEST']['TEST_FOUCUS_ON_BOX'], "transform_eval": cfg['MODEL']['DECODER']['TEST']['PANO_TRANSFORM_EVAL'], "pano_temp": cfg['MODEL']['DECODER']['TEST']['PANO_TEMPERATURE'], "semantic_ce_loss": cfg['MODEL']['DECODER']['TEST']['SEMANTIC_ON'] and cfg['MODEL']['DECODER']['SEMANTIC_CE_LOSS'] and not cfg['MODEL']['DECODER']['TEST']['PANOPTIC_ON'], "train_dataset_name": cfg['DATASETS']['TRAIN'], # HACK for only two training set "background": cfg['MODEL'].get('BACKGROUND', True), "coco_on": dec_cfg.get('COCO', True), "coco_mask_on": dec_cfg.get('COCO_MASK', True), "o365_on": dec_cfg.get('O365', True), "split_panno": dec_cfg.get('PANO_CRITERION', True), } @property def device(self): return 
self.pixel_mean.device def forward(self, batched_inputs, inference_task='seg'): # import ipdb; ipdb.set_trace() if self.training: losses = {} if self.task_switch['coco'] and 'coco' in batched_inputs: self.criterion_coco.num_classes = 133 if 'pano' in self.train_dataset_name[0] else 80 # self.criterion.num_classes = 133 task = 'seg' if not self.coco_mask_on: task = 'det' # import ipdb; ipdb.set_trace() losses_coco = self.forward_seg(batched_inputs['coco'], task=task) new_losses_coco = {} for key, value in losses_coco.items(): new_losses_coco['coco.'+str(key)] = losses_coco[key] losses.update(new_losses_coco) if self.task_switch['o365'] and 'o365' in batched_inputs: self.criterion_o365.num_classes = 365 losses_o365 = self.forward_seg(batched_inputs['o365'], task='det') new_losses_o365 = {} for key, value in losses_o365.items(): new_losses_o365['o365.'+str(key)] = losses_o365[key] losses.update(new_losses_o365) return losses else: processed_results = self.forward_seg(batched_inputs, task=inference_task) return processed_results def forward_seg(self, batched_inputs, task='seg'): """ Args: batched_inputs: a list, batched outputs of :class:`DatasetMapper`. Each item in the list contains the inputs for one image. For now, each item in the list is a dict that contains: * "image": Tensor, image in (C, H, W) format. * "instances": per-region ground truth * Other information that's included in the original dicts, such as: "height", "width" (int): the output resolution of the model (may be different from input resolution), used in inference. Returns: list[dict]: each dict has the results for one image. The dict contains the following keys: * "sem_seg": A Tensor that represents the per-pixel segmentation prediced by the head. The prediction has shape KxHxW that represents the logits of each class for each pixel. * "panoptic_seg": A tuple that represent panoptic output panoptic_seg (Tensor): of shape (height, width) where the values are ids for each segment. 
segments_info (list[dict]): Describe each segment in `panoptic_seg`. Each dict contains keys "id", "category_id", "isthing". """ images = [x["image"].to(self.device) for x in batched_inputs] images = [(x - self.pixel_mean) / self.pixel_std for x in images] images = ImageList.from_tensors(images, self.size_divisibility) features = self.backbone(images.tensor) if self.training: if task == "det" and self.task_switch['o365']: train_class_names = [random.sample(name, 1)[0] for name in self.train_class_names['det']] else: train_class_names = self.train_class_names[task] self.sem_seg_head.predictor.lang_encoder.get_text_embeddings(train_class_names, is_eval=False) # mask classification target if "instances" in batched_inputs[0]: gt_instances = [x["instances"].to(self.device) for x in batched_inputs] targets = self.prepare_targets(gt_instances, images, task=task) else: targets = None outputs, mask_dict = self.sem_seg_head(features, targets=targets, task=task) # bipartite matching-based loss if task=='det': criterion=self.criterion_o365 losses = self.criterion_o365(outputs, targets, mask_dict, task=task) else: criterion=self.criterion_coco losses = self.criterion_coco(outputs, targets, mask_dict, task=task) # else for k in list(losses.keys()): if k in criterion.weight_dict: losses[k] *= criterion.weight_dict[k] else: # remove this loss if not specified in `weight_dict` losses.pop(k) return losses else: outputs, _ = self.sem_seg_head(features) mask_cls_results = outputs["pred_logits"] mask_box_results = outputs["pred_boxes"] if 'seg' in task: if task == 'seg': self.semantic_on = self.panoptic_on = self.sem_seg_postprocess_before_inference = self.instance_on = True if task == 'pan_seg': self.semantic_on = self.instance_on = False self.panoptic_on = True self.sem_seg_postprocess_before_inference = True if task == 'inst_seg': self.semantic_on = self.panoptic_on = False self.instance_on = True self.sem_seg_postprocess_before_inference = True if task == 'sem_pan_seg': 
self.semantic_on = self.panoptic_on = True self.instance_on = False self.sem_seg_postprocess_before_inference = True if task == 'inst_pan_seg': self.instance_on = self.panoptic_on = True self.semantic_on = False self.sem_seg_postprocess_before_inference = True if task == 'sem_seg': self.instance_on = self.panoptic_on = False self.semantic_on = True self.sem_seg_postprocess_before_inference = True mask_pred_results = outputs["pred_masks"] # upsample masks mask_pred_results = F.interpolate( mask_pred_results, size=(images.tensor.shape[-2], images.tensor.shape[-1]), mode="bilinear", align_corners=False, ) else: self.semantic_on = self.panoptic_on = self.sem_seg_postprocess_before_inference = False self.instance_on = True mask_pred_results = torch.zeros(mask_box_results.shape[0], mask_box_results.shape[1],2, 2).to(mask_box_results) del outputs processed_results = [] for mask_cls_result, mask_pred_result, mask_box_result, input_per_image, image_size in zip( mask_cls_results, mask_pred_results, mask_box_results, batched_inputs, images.image_sizes ): height = input_per_image.get("height", image_size[0]) width = input_per_image.get("width", image_size[1]) processed_results.append({}) new_size = (images.tensor.shape[-2], images.tensor.shape[-1]) # padded size (divisible to 32) if self.sem_seg_postprocess_before_inference: mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)( mask_pred_result, image_size, height, width ) mask_cls_result = mask_cls_result.to(mask_pred_result) # semantic segmentation inference if self.semantic_on: r = retry_if_cuda_oom(self.semantic_inference)(mask_cls_result, mask_pred_result) if not self.sem_seg_postprocess_before_inference: r = retry_if_cuda_oom(sem_seg_postprocess)(r, image_size, height, width) processed_results[-1]["sem_seg"] = r # panoptic segmentation inference if self.panoptic_on: panoptic_r = retry_if_cuda_oom(self.panoptic_inference)(mask_cls_result, mask_pred_result) processed_results[-1]["panoptic_seg"] = panoptic_r # instance 
segmentation inference if self.instance_on: mask_box_result = mask_box_result.to(mask_pred_result) height = new_size[0]/image_size[0]*height width = new_size[1]/image_size[1]*width mask_box_result = self.box_postprocess(mask_box_result, height, width) instance_r = retry_if_cuda_oom(self.instance_inference)( mask_cls_result[:self.sem_seg_head.predictor.num_queries_test], mask_pred_result[:self.sem_seg_head.predictor.num_queries_test], mask_box_result[:self.sem_seg_head.predictor.num_queries_test], True) # instance_r = retry_if_cuda_oom(self.instance_inference)(mask_cls_result, mask_pred_result, mask_box_result) processed_results[-1]["instances"] = instance_r del mask_pred_results return processed_results def prepare_targets(self, targets, images, task='seg'): h_pad, w_pad = images.tensor.shape[-2:] new_targets = [] for targets_per_image in targets: # pad gt h, w = targets_per_image.image_size image_size_xyxy = torch.as_tensor([w, h, w, h], dtype=torch.float, device=self.device) if task != 'det': gt_masks = targets_per_image.gt_masks padded_masks = torch.zeros((gt_masks.shape[0], h_pad, w_pad), dtype=gt_masks.dtype, device=gt_masks.device) padded_masks[:, : gt_masks.shape[1], : gt_masks.shape[2]] = gt_masks else: padded_masks = None new_targets.append( { "labels": targets_per_image.gt_classes, "masks": padded_masks, "boxes":box_ops.box_xyxy_to_cxcywh(targets_per_image.gt_boxes.tensor)/image_size_xyxy } ) return new_targets def semantic_inference(self, mask_cls, mask_pred): # if use cross-entropy loss in training, evaluate with softmax if self.semantic_ce_loss: mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1] mask_pred = mask_pred.sigmoid() semseg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred) return semseg # if use focal loss in training, evaluate with sigmoid. As sigmoid is mainly for detection and not sharp # enough for semantic and panoptic segmentation, we additionally use use softmax with a temperature to # make the score sharper. 
else: T = self.pano_temp mask_cls = mask_cls.sigmoid() if self.transform_eval: mask_cls = F.softmax(mask_cls / T, dim=-1) # already sigmoid mask_pred = mask_pred.sigmoid() semseg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred) return semseg def panoptic_inference(self, mask_cls, mask_pred): # As we use focal loss in training, evaluate with sigmoid. As sigmoid is mainly for detection and not sharp # enough for semantic and panoptic segmentation, we additionally use use softmax with a temperature to # make the score sharper. prob = 0.5 T = self.pano_temp scores, labels = mask_cls.sigmoid().max(-1) mask_pred = mask_pred.sigmoid() keep = labels.ne(self.sem_seg_head.num_classes) & (scores > self.object_mask_threshold) # added process if self.transform_eval: scores, labels = F.softmax(mask_cls.sigmoid() / T, dim=-1).max(-1) cur_scores = scores[keep] cur_classes = labels[keep] cur_masks = mask_pred[keep] cur_prob_masks = cur_scores.view(-1, 1, 1) * cur_masks h, w = cur_masks.shape[-2:] panoptic_seg = torch.zeros((h, w), dtype=torch.int32, device=cur_masks.device) segments_info = [] current_segment_id = 0 if cur_masks.shape[0] == 0: # We didn't detect any mask :( return panoptic_seg, segments_info else: # take argmax cur_mask_ids = cur_prob_masks.argmax(0) stuff_memory_list = {} for k in range(cur_classes.shape[0]): pred_class = cur_classes[k].item() isthing = pred_class in self.metadata.thing_dataset_id_to_contiguous_id.values() mask_area = (cur_mask_ids == k).sum().item() original_area = (cur_masks[k] >= prob).sum().item() mask = (cur_mask_ids == k) & (cur_masks[k] >= prob) if mask_area > 0 and original_area > 0 and mask.sum().item() > 0: if mask_area / original_area < self.overlap_threshold: continue # merge stuff regions if not isthing: if int(pred_class) in stuff_memory_list.keys(): panoptic_seg[mask] = stuff_memory_list[int(pred_class)] continue else: stuff_memory_list[int(pred_class)] = current_segment_id + 1 current_segment_id += 1 panoptic_seg[mask] = 
current_segment_id segments_info.append( { "id": current_segment_id, "isthing": bool(isthing), "category_id": int(pred_class), } ) return panoptic_seg, segments_info def instance_inference(self, mask_cls, mask_pred, mask_box_result,split_anno): # mask_pred is already processed to have the same shape as original input image_size = mask_pred.shape[-2:] scores = mask_cls.sigmoid() # [100, 80] if split_anno: labels = torch.arange(self.sem_seg_head.num_classes, device=self.device).unsqueeze(0).repeat( self.sem_seg_head.predictor.num_queries_test, 1).flatten(0, 1) else: labels = torch.arange(self.sem_seg_head.num_classes, device=self.device).unsqueeze(0).repeat(self.num_queries, 1).flatten(0, 1) scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.test_topk_per_image, sorted=False) # select 100 labels_per_image = labels[topk_indices] topk_indices = topk_indices // self.sem_seg_head.num_classes mask_pred = mask_pred[topk_indices] # if this is panoptic segmentation, we only keep the "thing" classes if self.panoptic_on: keep = torch.zeros_like(scores_per_image).bool() for i, lab in enumerate(labels_per_image): # print(i, len(keep), lab) keep[i] = lab in self.metadata.thing_dataset_id_to_contiguous_id.values() scores_per_image = scores_per_image[keep] labels_per_image = labels_per_image[keep] mask_pred = mask_pred[keep] result = Instances(image_size) # mask (before sigmoid) result.pred_masks = (mask_pred > 0).float() # half mask box half pred box mask_box_result = mask_box_result[topk_indices] if self.panoptic_on: mask_box_result = mask_box_result[keep] result.pred_boxes = Boxes(mask_box_result) # Uncomment the following to get boxes from masks (this is slow) # result.pred_boxes = BitMasks(mask_pred > 0).get_bounding_boxes() # calculate average mask prob if self.sem_seg_postprocess_before_inference: mask_scores_per_image = (mask_pred.sigmoid().flatten(1) * result.pred_masks.flatten(1)).sum(1) / (result.pred_masks.flatten(1).sum(1) + 1e-6) else: 
mask_scores_per_image = 1.0 # labels_per_image = labels_per_image + 1 # HACK for o365 classification if self.focus_on_box: mask_scores_per_image = 1.0 result.scores = scores_per_image * mask_scores_per_image result.pred_classes = labels_per_image return result def box_postprocess(self, out_bbox, img_h, img_w): # postprocess box height and width boxes = box_ops.box_cxcywh_to_xyxy(out_bbox) scale_fct = torch.tensor([img_w, img_h, img_w, img_h]) scale_fct = scale_fct.to(out_bbox) boxes = boxes * scale_fct return boxes @register_model def get_segmentation_model(cfg, **kwargs): return OpenSeeD(cfg) ================================================ FILE: llava/model/openseed/architectures/registry.py ================================================ _model_entrypoints = {} def register_model(fn): module_name_split = fn.__module__.split('.') model_name = module_name_split[-1] _model_entrypoints[model_name] = fn return fn def model_entrypoints(model_name): return _model_entrypoints[model_name] def is_model(model_name): return model_name in _model_entrypoints ================================================ FILE: llava/model/openseed/backbone/__init__.py ================================================ from .build import build_backbone from .focal import * from .focal_dw import * from .swin import * from .backbone import * ================================================ FILE: llava/model/openseed/backbone/backbone.py ================================================ # Copyright (c) Facebook, Inc. and its affiliates. import torch.nn as nn from detectron2.modeling import ShapeSpec # from ..layers import ShapeSpec __all__ = ["Backbone"] class Backbone(nn.Module): """ Abstract base class for network backbones. """ def __init__(self): """ The `__init__` method of any subclass can specify its own set of arguments. """ super().__init__() def forward(self): """ Subclasses must override this method, but adhere to the same return type. 
Returns: dict[str->Tensor]: mapping from feature name (e.g., "res2") to tensor """ pass @property def size_divisibility(self) -> int: """ Some backbones require the input height and width to be divisible by a specific integer. This is typically true for encoder / decoder type networks with lateral connection (e.g., FPN) for which feature maps need to match dimension in the "bottom up" and "top down" paths. Set to 0 if no specific input size divisibility is required. """ return 0 def output_shape(self): """ Returns: dict[str->ShapeSpec] """ # this is a backward-compatible default return { name: ShapeSpec( channels=self._out_feature_channels[name], stride=self._out_feature_strides[name] ) for name in self._out_features } ================================================ FILE: llava/model/openseed/backbone/build.py ================================================ from .registry import model_entrypoints from .registry import is_model from .backbone import * def build_backbone(config, **kwargs): model_name = config['MODEL']['BACKBONE']['NAME'] if not is_model(model_name): raise ValueError(f'Unkown model: {model_name}') return model_entrypoints(model_name)(config, **kwargs) ================================================ FILE: llava/model/openseed/backbone/focal.py ================================================ # -------------------------------------------------------- # FocalNet for Semantic Segmentation # Copyright (c) 2022 Microsoft # Licensed under The MIT License [see LICENSE for details] # Written by Jianwei Yang # -------------------------------------------------------- import math import time import numpy as np import logging import torch import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint as checkpoint from timm.models.layers import DropPath, to_2tuple, trunc_normal_ from detectron2.utils.file_io import PathManager from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec from .registry import register_backbone 
logger = logging.getLogger(__name__) class Mlp(nn.Module): """ Multilayer perceptron.""" def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc1 = nn.Linear(in_features, hidden_features) self.act = act_layer() self.fc2 = nn.Linear(hidden_features, out_features) self.drop = nn.Dropout(drop) def forward(self, x): x = self.fc1(x) x = self.act(x) x = self.drop(x) x = self.fc2(x) x = self.drop(x) return x class FocalModulation(nn.Module): """ Focal Modulation Args: dim (int): Number of input channels. proj_drop (float, optional): Dropout ratio of output. Default: 0.0 focal_level (int): Number of focal levels focal_window (int): Focal window size at focal level 1 focal_factor (int, default=2): Step to increase the focal window use_postln (bool, default=False): Whether use post-modulation layernorm """ def __init__(self, dim, proj_drop=0., focal_level=2, focal_window=7, focal_factor=2, use_postln=False, use_postln_in_modulation=False, scaling_modulator=False): super().__init__() self.dim = dim # specific args for focalv3 self.focal_level = focal_level self.focal_window = focal_window self.focal_factor = focal_factor self.use_postln_in_modulation = use_postln_in_modulation self.scaling_modulator = scaling_modulator self.f = nn.Linear(dim, 2*dim+(self.focal_level+1), bias=True) self.h = nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0, groups=1, bias=True) self.act = nn.GELU() self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) self.focal_layers = nn.ModuleList() if self.use_postln_in_modulation: self.ln = nn.LayerNorm(dim) for k in range(self.focal_level): kernel_size = self.focal_factor*k + self.focal_window self.focal_layers.append( nn.Sequential( nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=1, groups=dim, padding=kernel_size//2, bias=False), nn.GELU(), ) ) def forward(self, 
x): """ Forward function. Args: x: input features with shape of (B, H, W, C) """ B, nH, nW, C = x.shape x = self.f(x) x = x.permute(0, 3, 1, 2).contiguous() q, ctx, gates = torch.split(x, (C, C, self.focal_level+1), 1) ctx_all = 0 for l in range(self.focal_level): ctx = self.focal_layers[l](ctx) ctx_all = ctx_all + ctx*gates[:, l:l+1] ctx_global = self.act(ctx.mean(2, keepdim=True).mean(3, keepdim=True)) ctx_all = ctx_all + ctx_global*gates[:,self.focal_level:] if self.scaling_modulator: ctx_all = ctx_all / (self.focal_level + 1) x_out = q * self.h(ctx_all) x_out = x_out.permute(0, 2, 3, 1).contiguous() if self.use_postln_in_modulation: x_out = self.ln(x_out) x_out = self.proj(x_out) x_out = self.proj_drop(x_out) return x_out class FocalModulationBlock(nn.Module): """ Focal Modulation Block. Args: dim (int): Number of input channels. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. drop (float, optional): Dropout rate. Default: 0.0 drop_path (float, optional): Stochastic depth rate. Default: 0.0 act_layer (nn.Module, optional): Activation layer. Default: nn.GELU norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm focal_level (int): number of focal levels focal_window (int): focal kernel size at level 1 """ def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, focal_level=2, focal_window=9, use_postln=False, use_postln_in_modulation=False, scaling_modulator=False, use_layerscale=False, layerscale_value=1e-4): super().__init__() self.dim = dim self.mlp_ratio = mlp_ratio self.focal_window = focal_window self.focal_level = focal_level self.use_postln = use_postln self.use_layerscale = use_layerscale self.norm1 = norm_layer(dim) self.modulation = FocalModulation( dim, focal_window=self.focal_window, focal_level=self.focal_level, proj_drop=drop, use_postln_in_modulation=use_postln_in_modulation, scaling_modulator=scaling_modulator ) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) self.H = None self.W = None self.gamma_1 = 1.0 self.gamma_2 = 1.0 if self.use_layerscale: self.gamma_1 = nn.Parameter(layerscale_value * torch.ones((dim)), requires_grad=True) self.gamma_2 = nn.Parameter(layerscale_value * torch.ones((dim)), requires_grad=True) def forward(self, x): """ Forward function. Args: x: Input feature, tensor size (B, H*W, C). H, W: Spatial resolution of the input feature. """ B, L, C = x.shape H, W = self.H, self.W assert L == H * W, "input feature has wrong size" shortcut = x if not self.use_postln: x = self.norm1(x) x = x.view(B, H, W, C) # FM x = self.modulation(x).view(B, H * W, C) if self.use_postln: x = self.norm1(x) # FFN x = shortcut + self.drop_path(self.gamma_1 * x) if self.use_postln: x = x + self.drop_path(self.gamma_2 * self.norm2(self.mlp(x))) else: x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) return x class BasicLayer(nn.Module): """ A basic focal modulation layer for one stage. 
Args: dim (int): Number of feature channels depth (int): Depths of this stage. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. drop (float, optional): Dropout rate. Default: 0.0 drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None focal_level (int): Number of focal levels focal_window (int): Focal window size at focal level 1 use_conv_embed (bool): Use overlapped convolution for patch embedding or now. Default: False use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False """ def __init__(self, dim, depth, mlp_ratio=4., drop=0., drop_path=0., norm_layer=nn.LayerNorm, downsample=None, focal_window=9, focal_level=2, use_conv_embed=False, use_postln=False, use_postln_in_modulation=False, scaling_modulator=False, use_layerscale=False, use_checkpoint=False ): super().__init__() self.depth = depth self.use_checkpoint = use_checkpoint # build blocks self.blocks = nn.ModuleList([ FocalModulationBlock( dim=dim, mlp_ratio=mlp_ratio, drop=drop, drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, focal_window=focal_window, focal_level=focal_level, use_postln=use_postln, use_postln_in_modulation=use_postln_in_modulation, scaling_modulator=scaling_modulator, use_layerscale=use_layerscale, norm_layer=norm_layer) for i in range(depth)]) # patch merging layer if downsample is not None: self.downsample = downsample( patch_size=2, in_chans=dim, embed_dim=2*dim, use_conv_embed=use_conv_embed, norm_layer=norm_layer, is_stem=False ) else: self.downsample = None def forward(self, x, H, W): """ Forward function. Args: x: Input feature, tensor size (B, H*W, C). H, W: Spatial resolution of the input feature. 
""" for blk in self.blocks: blk.H, blk.W = H, W if self.use_checkpoint: x = checkpoint.checkpoint(blk, x) else: x = blk(x) if self.downsample is not None: x_reshaped = x.transpose(1, 2).view(x.shape[0], x.shape[-1], H, W) x_down = self.downsample(x_reshaped) x_down = x_down.flatten(2).transpose(1, 2) Wh, Ww = (H + 1) // 2, (W + 1) // 2 return x, H, W, x_down, Wh, Ww else: return x, H, W, x, H, W class PatchEmbed(nn.Module): """ Image to Patch Embedding Args: patch_size (int): Patch token size. Default: 4. in_chans (int): Number of input image channels. Default: 3. embed_dim (int): Number of linear projection output channels. Default: 96. norm_layer (nn.Module, optional): Normalization layer. Default: None use_conv_embed (bool): Whether use overlapped convolution for patch embedding. Default: False is_stem (bool): Is the stem block or not. """ def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None, use_conv_embed=False, is_stem=False): super().__init__() patch_size = to_2tuple(patch_size) self.patch_size = patch_size self.in_chans = in_chans self.embed_dim = embed_dim if use_conv_embed: # if we choose to use conv embedding, then we treat the stem and non-stem differently if is_stem: kernel_size = 7; padding = 2; stride = 4 else: kernel_size = 3; padding = 1; stride = 2 self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding) else: self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) if norm_layer is not None: self.norm = norm_layer(embed_dim) else: self.norm = None def forward(self, x): """Forward function.""" _, _, H, W = x.size() if W % self.patch_size[1] != 0: x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1])) if H % self.patch_size[0] != 0: x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) x = self.proj(x) # B C Wh Ww if self.norm is not None: Wh, Ww = x.size(2), x.size(3) x = x.flatten(2).transpose(1, 2) x = self.norm(x) x = 
x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww) return x class FocalNet(nn.Module): """ FocalNet backbone. Args: pretrain_img_size (int): Input image size for training the pretrained model, used in absolute postion embedding. Default 224. patch_size (int | tuple(int)): Patch size. Default: 4. in_chans (int): Number of input image channels. Default: 3. embed_dim (int): Number of linear projection output channels. Default: 96. depths (tuple[int]): Depths of each Swin Transformer stage. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. drop_rate (float): Dropout rate. drop_path_rate (float): Stochastic depth rate. Default: 0.2. norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. patch_norm (bool): If True, add normalization after patch embedding. Default: True. out_indices (Sequence[int]): Output from which stages. frozen_stages (int): Stages to be frozen (stop grad and set eval mode). -1 means not freezing any parameters. focal_levels (Sequence[int]): Number of focal levels at four stages focal_windows (Sequence[int]): Focal window sizes at first focal level at four stages use_conv_embed (bool): Whether use overlapped convolution for patch embedding use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
""" def __init__(self, pretrain_img_size=1600, patch_size=4, in_chans=3, embed_dim=96, depths=[2, 2, 6, 2], mlp_ratio=4., drop_rate=0., drop_path_rate=0.2, norm_layer=nn.LayerNorm, patch_norm=True, out_indices=[0, 1, 2, 3], frozen_stages=-1, focal_levels=[2,2,2,2], focal_windows=[9,9,9,9], use_conv_embed=False, use_postln=False, use_postln_in_modulation=False, scaling_modulator=False, use_layerscale=False, use_checkpoint=False, ): super().__init__() self.pretrain_img_size = pretrain_img_size self.num_layers = len(depths) self.embed_dim = embed_dim self.patch_norm = patch_norm self.out_indices = out_indices self.frozen_stages = frozen_stages # split image into non-overlapping patches self.patch_embed = PatchEmbed( patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, norm_layer=norm_layer if self.patch_norm else None, use_conv_embed=use_conv_embed, is_stem=True) self.pos_drop = nn.Dropout(p=drop_rate) # stochastic depth dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule # build layers self.layers = nn.ModuleList() for i_layer in range(self.num_layers): layer = BasicLayer( dim=int(embed_dim * 2 ** i_layer), depth=depths[i_layer], mlp_ratio=mlp_ratio, drop=drop_rate, drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], norm_layer=norm_layer, downsample=PatchEmbed if (i_layer < self.num_layers - 1) else None, focal_window=focal_windows[i_layer], focal_level=focal_levels[i_layer], use_conv_embed=use_conv_embed, use_postln=use_postln, use_postln_in_modulation=use_postln_in_modulation, scaling_modulator=scaling_modulator, use_layerscale=use_layerscale, use_checkpoint=use_checkpoint) self.layers.append(layer) num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] self.num_features = num_features # add a norm layer for each output for i_layer in out_indices: layer = norm_layer(num_features[i_layer]) layer_name = f'norm{i_layer}' self.add_module(layer_name, layer) self._freeze_stages() 
def _freeze_stages(self): if self.frozen_stages >= 0: self.patch_embed.eval() for param in self.patch_embed.parameters(): param.requires_grad = False if self.frozen_stages >= 2: self.pos_drop.eval() for i in range(0, self.frozen_stages - 1): m = self.layers[i] m.eval() for param in m.parameters(): param.requires_grad = False def init_weights(self, pretrained=None): """Initialize the weights in backbone. Args: pretrained (str, optional): Path to pre-trained weights. Defaults to None. """ def _init_weights(m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) if isinstance(pretrained, str): self.apply(_init_weights) logger = get_root_logger() load_checkpoint(self, pretrained, strict=False, logger=logger) elif pretrained is None: self.apply(_init_weights) else: raise TypeError('pretrained must be a str or None') def load_weights(self, pretrained_dict=None, pretrained_layers=[], verbose=True): model_dict = self.state_dict() missed_dict = [k for k in model_dict.keys() if k not in pretrained_dict] logger.info(f'=> Missed keys {missed_dict}') unexpected_dict = [k for k in pretrained_dict.keys() if k not in model_dict] logger.info(f'=> Unexpected keys {unexpected_dict}') pretrained_dict = { k: v for k, v in pretrained_dict.items() if k in model_dict.keys() } need_init_state_dict = {} for k, v in pretrained_dict.items(): need_init = ( ( k.split('.')[0] in pretrained_layers or pretrained_layers[0] == '*' ) and 'relative_position_index' not in k and 'attn_mask' not in k ) if need_init: # if verbose: # logger.info(f'=> init {k} from {pretrained}') if ('pool_layers' in k) or ('focal_layers' in k) and v.size() != model_dict[k].size(): table_pretrained = v table_current = model_dict[k] fsize1 = table_pretrained.shape[2] fsize2 = table_current.shape[2] # NOTE: different from interpolation used 
in self-attention, we use padding or clipping for focal conv if fsize1 < fsize2: table_pretrained_resized = torch.zeros(table_current.shape) table_pretrained_resized[:, :, (fsize2-fsize1)//2:-(fsize2-fsize1)//2, (fsize2-fsize1)//2:-(fsize2-fsize1)//2] = table_pretrained v = table_pretrained_resized elif fsize1 > fsize2: table_pretrained_resized = table_pretrained[:, :, (fsize1-fsize2)//2:-(fsize1-fsize2)//2, (fsize1-fsize2)//2:-(fsize1-fsize2)//2] v = table_pretrained_resized if ("modulation.f" in k or "pre_conv" in k): table_pretrained = v table_current = model_dict[k] if table_pretrained.shape != table_current.shape: if len(table_pretrained.shape) == 2: dim = table_pretrained.shape[1] assert table_current.shape[1] == dim L1 = table_pretrained.shape[0] L2 = table_current.shape[0] if L1 < L2: table_pretrained_resized = torch.zeros(table_current.shape) # copy for linear project table_pretrained_resized[:2*dim] = table_pretrained[:2*dim] # copy for global token gating table_pretrained_resized[-1] = table_pretrained[-1] # copy for first multiple focal levels table_pretrained_resized[2*dim:2*dim+(L1-2*dim-1)] = table_pretrained[2*dim:-1] # reassign pretrained weights v = table_pretrained_resized elif L1 > L2: raise NotImplementedError elif len(table_pretrained.shape) == 1: dim = table_pretrained.shape[0] L1 = table_pretrained.shape[0] L2 = table_current.shape[0] if L1 < L2: table_pretrained_resized = torch.zeros(table_current.shape) # copy for linear project table_pretrained_resized[:dim] = table_pretrained[:dim] # copy for global token gating table_pretrained_resized[-1] = table_pretrained[-1] # copy for first multiple focal levels # table_pretrained_resized[dim:2*dim+(L1-2*dim-1)] = table_pretrained[2*dim:-1] # reassign pretrained weights v = table_pretrained_resized elif L1 > L2: raise NotImplementedError need_init_state_dict[k] = v self.load_state_dict(need_init_state_dict, strict=False) def forward(self, x): """Forward function.""" tic = time.time() x = 
self.patch_embed(x) Wh, Ww = x.size(2), x.size(3) x = x.flatten(2).transpose(1, 2) x = self.pos_drop(x) outs = {} for i in range(self.num_layers): layer = self.layers[i] x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww) if i in self.out_indices: norm_layer = getattr(self, f'norm{i}') x_out = norm_layer(x_out) out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous() outs["res{}".format(i + 2)] = out if len(self.out_indices) == 0: outs["res5"] = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous() toc = time.time() return outs def train(self, mode=True): """Convert the model into training mode while keep layers freezed.""" super(FocalNet, self).train(mode) self._freeze_stages() class D2FocalNet(FocalNet, Backbone): def __init__(self, cfg, input_shape): pretrain_img_size = cfg['BACKBONE']['FOCAL']['PRETRAIN_IMG_SIZE'] patch_size = cfg['BACKBONE']['FOCAL']['PATCH_SIZE'] in_chans = 3 embed_dim = cfg['BACKBONE']['FOCAL']['EMBED_DIM'] depths = cfg['BACKBONE']['FOCAL']['DEPTHS'] mlp_ratio = cfg['BACKBONE']['FOCAL']['MLP_RATIO'] drop_rate = cfg['BACKBONE']['FOCAL']['DROP_RATE'] drop_path_rate = cfg['BACKBONE']['FOCAL']['DROP_PATH_RATE'] norm_layer = nn.LayerNorm patch_norm = cfg['BACKBONE']['FOCAL']['PATCH_NORM'] use_checkpoint = cfg['BACKBONE']['FOCAL']['USE_CHECKPOINT'] out_indices = cfg['BACKBONE']['FOCAL']['OUT_INDICES'] scaling_modulator = cfg['BACKBONE']['FOCAL'].get('SCALING_MODULATOR', False) super().__init__( pretrain_img_size, patch_size, in_chans, embed_dim, depths, mlp_ratio, drop_rate, drop_path_rate, norm_layer, patch_norm, out_indices, focal_levels=cfg['BACKBONE']['FOCAL']['FOCAL_LEVELS'], focal_windows=cfg['BACKBONE']['FOCAL']['FOCAL_WINDOWS'], use_conv_embed=cfg['BACKBONE']['FOCAL']['USE_CONV_EMBED'], use_postln=cfg['BACKBONE']['FOCAL']['USE_POSTLN'], use_postln_in_modulation=cfg['BACKBONE']['FOCAL']['USE_POSTLN_IN_MODULATION'], scaling_modulator=scaling_modulator, 
use_layerscale=cfg['BACKBONE']['FOCAL']['USE_LAYERSCALE'], use_checkpoint=use_checkpoint, ) self._out_features = cfg['BACKBONE']['FOCAL']['OUT_FEATURES'] self._out_feature_strides = { "res2": 4, "res3": 8, "res4": 16, "res5": 32, } self._out_feature_channels = { "res2": self.num_features[0], "res3": self.num_features[1], "res4": self.num_features[2], "res5": self.num_features[3], } def forward(self, x): """ Args: x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``. Returns: dict[str->Tensor]: names and the corresponding features """ assert ( x.dim() == 4 ), f"SwinTransformer takes an input of shape (N, C, H, W). Got {x.shape} instead!" outputs = {} y = super().forward(x) for k in y.keys(): if k in self._out_features: outputs[k] = y[k] return outputs def output_shape(self): return { name: ShapeSpec( channels=self._out_feature_channels[name], stride=self._out_feature_strides[name] ) for name in self._out_features } @property def size_divisibility(self): return 32 @register_backbone def get_focal_backbone(cfg): focal = D2FocalNet(cfg['MODEL'], 224) if cfg['MODEL']['BACKBONE']['LOAD_PRETRAINED'] is True: filename = cfg['MODEL']['BACKBONE']['PRETRAINED'] logger.info(f'=> init from {filename}') with PathManager.open(filename, "rb") as f: ckpt = torch.load(f)['model'] focal.load_weights(ckpt, cfg['MODEL']['BACKBONE']['FOCAL'].get('PRETRAINED_LAYERS', ['*']), cfg['VERBOSE']) return focal ================================================ FILE: llava/model/openseed/backbone/focal_dw.py ================================================ # -------------------------------------------------------- # FocalNet for Semantic Segmentation # Copyright (c) 2022 Microsoft # Licensed under The MIT License [see LICENSE for details] # Written by Jianwei Yang # -------------------------------------------------------- import math import time import numpy as np import logging import torch import torch.nn as nn import torch.nn.functional as F import 
torch.utils.checkpoint as checkpoint from timm.models.layers import DropPath, to_2tuple, trunc_normal_ from detectron2.utils.file_io import PathManager from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec from .registry import register_backbone logger = logging.getLogger(__name__) class Mlp(nn.Module): """ Multilayer perceptron.""" def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc1 = nn.Linear(in_features, hidden_features) self.act = act_layer() self.fc2 = nn.Linear(hidden_features, out_features) self.drop = nn.Dropout(drop) def forward(self, x): x = self.fc1(x) x = self.act(x) x = self.drop(x) x = self.fc2(x) x = self.drop(x) return x class FocalModulation(nn.Module): """ Focal Modulation Args: dim (int): Number of input channels. proj_drop (float, optional): Dropout ratio of output. Default: 0.0 focal_level (int): Number of focal levels focal_window (int): Focal window size at focal level 1 focal_factor (int, default=2): Step to increase the focal window use_postln (bool, default=False): Whether use post-modulation layernorm """ def __init__(self, dim, proj_drop=0., focal_level=2, focal_window=7, focal_factor=2, use_postln=False, use_postln_in_modulation=False, scaling_modulator=False): super().__init__() self.dim = dim # specific args for focalv3 self.focal_level = focal_level self.focal_window = focal_window self.focal_factor = focal_factor self.use_postln_in_modulation = use_postln_in_modulation self.scaling_modulator = scaling_modulator self.f = nn.Linear(dim, 2*dim+(self.focal_level+1), bias=True) self.h = nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0, groups=1, bias=True) self.act = nn.GELU() self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) self.focal_layers = nn.ModuleList() if self.use_postln_in_modulation: self.ln = nn.LayerNorm(dim) for 
k in range(self.focal_level): kernel_size = self.focal_factor*k + self.focal_window self.focal_layers.append( nn.Sequential( nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=1, groups=dim, padding=kernel_size//2, bias=False), nn.GELU(), ) ) def forward(self, x): """ Forward function. Args: x: input features with shape of (B, H, W, C) """ B, nH, nW, C = x.shape x = self.f(x) x = x.permute(0, 3, 1, 2).contiguous() q, ctx, gates = torch.split(x, (C, C, self.focal_level+1), 1) ctx_all = 0 for l in range(self.focal_level): ctx = self.focal_layers[l](ctx) ctx_all = ctx_all + ctx*gates[:, l:l+1] ctx_global = self.act(ctx.mean(2, keepdim=True).mean(3, keepdim=True)) ctx_all = ctx_all + ctx_global*gates[:,self.focal_level:] if self.scaling_modulator: ctx_all = ctx_all / (self.focal_level + 1) x_out = q * self.h(ctx_all) x_out = x_out.permute(0, 2, 3, 1).contiguous() if self.use_postln_in_modulation: x_out = self.ln(x_out) x_out = self.proj(x_out) x_out = self.proj_drop(x_out) return x_out class FocalModulationBlock(nn.Module): """ Focal Modulation Block. Args: dim (int): Number of input channels. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. drop (float, optional): Dropout rate. Default: 0.0 drop_path (float, optional): Stochastic depth rate. Default: 0.0 act_layer (nn.Module, optional): Activation layer. Default: nn.GELU norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm focal_level (int): number of focal levels focal_window (int): focal kernel size at level 1 """ def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, focal_level=2, focal_window=9, use_postln=False, use_postln_in_modulation=False, scaling_modulator=False, use_layerscale=False, layerscale_value=1e-4): super().__init__() self.dim = dim self.mlp_ratio = mlp_ratio self.focal_window = focal_window self.focal_level = focal_level self.use_postln = use_postln self.use_layerscale = use_layerscale self.dw1 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim) self.norm1 = norm_layer(dim) self.modulation = FocalModulation( dim, focal_window=self.focal_window, focal_level=self.focal_level, proj_drop=drop, use_postln_in_modulation=use_postln_in_modulation, scaling_modulator=scaling_modulator ) self.dw2 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) self.H = None self.W = None self.gamma_1 = 1.0 self.gamma_2 = 1.0 if self.use_layerscale: self.gamma_1 = nn.Parameter(layerscale_value * torch.ones((dim)), requires_grad=True) self.gamma_2 = nn.Parameter(layerscale_value * torch.ones((dim)), requires_grad=True) def forward(self, x): """ Forward function. Args: x: Input feature, tensor size (B, H*W, C). H, W: Spatial resolution of the input feature. 
""" B, L, C = x.shape H, W = self.H, self.W assert L == H * W, "input feature has wrong size" x = x.view(B, H, W, C).permute(0, 3, 1, 2).contiguous() x = x + self.dw1(x) x = x.permute(0, 2, 3, 1).contiguous().view(B, L, C) shortcut = x if not self.use_postln: x = self.norm1(x) x = x.view(B, H, W, C) # FM x = self.modulation(x).view(B, H * W, C) x = shortcut + self.drop_path(self.gamma_1 * x) if self.use_postln: x = self.norm1(x) x = x.view(B, H, W, C).permute(0, 3, 1, 2).contiguous() x = x + self.dw2(x) x = x.permute(0, 2, 3, 1).contiguous().view(B, L, C) if not self.use_postln: x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) else: x = x + self.drop_path(self.gamma_2 * self.mlp(x)) x = self.norm2(x) return x class BasicLayer(nn.Module): """ A basic focal modulation layer for one stage. Args: dim (int): Number of feature channels depth (int): Depths of this stage. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. drop (float, optional): Dropout rate. Default: 0.0 drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None focal_level (int): Number of focal levels focal_window (int): Focal window size at focal level 1 use_conv_embed (bool): Use overlapped convolution for patch embedding or now. Default: False use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False
    """

    def __init__(self,
                 dim,
                 depth,
                 mlp_ratio=4.,
                 drop=0.,
                 drop_path=0.,
                 norm_layer=nn.LayerNorm,
                 downsample=None,
                 focal_window=9,
                 focal_level=2,
                 use_conv_embed=False,
                 use_postln=False,
                 use_postln_in_modulation=False,
                 scaling_modulator=False,
                 use_layerscale=False,
                 use_checkpoint=False,
                 use_pre_norm=False,
                 ):
        super().__init__()
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks
        self.blocks = nn.ModuleList([
            FocalModulationBlock(
                dim=dim,
                mlp_ratio=mlp_ratio,
                drop=drop,
                # A list gives each block its own stochastic-depth rate.
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                focal_window=focal_window,
                focal_level=focal_level,
                use_postln=use_postln,
                use_postln_in_modulation=use_postln_in_modulation,
                scaling_modulator=scaling_modulator,
                use_layerscale=use_layerscale,
                norm_layer=norm_layer)
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(
                patch_size=2,
                in_chans=dim, embed_dim=2*dim,
                use_conv_embed=use_conv_embed,
                norm_layer=norm_layer,
                is_stem=False,
                use_pre_norm=use_pre_norm
            )
        else:
            self.downsample = None

    def forward(self, x, H, W):
        """ Forward function.

        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.

        Returns:
            tuple: ``(x, H, W, x_down, Wh, Ww)``. The first three are this stage's
            output and its size; the last three are the input for the next stage
            and its size. Without a downsample layer the last three repeat
            ``(x, H, W)``.
        """
        for blk in self.blocks:
            # Blocks read their spatial size from attributes rather than arguments.
            blk.H, blk.W = H, W
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
        if self.downsample is not None:
            # (B, H*W, C) -> (B, C, H, W) for the convolutional downsample layer.
            x_reshaped = x.transpose(1, 2).view(x.shape[0], x.shape[-1], H, W)
            x_down = self.downsample(x_reshaped)
            x_down = x_down.flatten(2).transpose(1, 2)
            # Assumes the downsample halves H and W with ceil rounding (built with patch_size=2).
            Wh, Ww = (H + 1) // 2, (W + 1) // 2
            return x, H, W, x_down, Wh, Ww
        else:
            return x, H, W, x, H, W


# NOTE(review): legacy fixed-input-size PatchEmbed, superseded by the active
# PatchEmbed class defined further down in this file; kept commented out.
# class PatchEmbed(nn.Module):
#     r""" Image to Patch Embedding
#     Args:
#         img_size (int): Image size. Default: 224.
#         patch_size (int): Patch token size. Default: 4.
#         in_chans (int): Number of input image channels. Default: 3.
#         embed_dim (int): Number of linear projection output channels. Default: 96.
#         norm_layer (nn.Module, optional): Normalization layer. Default: None
#     """

#     def __init__(self, img_size=(224, 224), patch_size=4, in_chans=3, embed_dim=96,
#         use_conv_embed=False, norm_layer=None, is_stem=False, use_pre_norm=False):
#         super().__init__()
#         patch_size = to_2tuple(patch_size)
#         patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
#         self.img_size = img_size
#         self.patch_size = patch_size
#         self.patches_resolution = patches_resolution
#         self.num_patches = patches_resolution[0] * patches_resolution[1]

#         self.in_chans = in_chans
#         self.embed_dim = embed_dim
#         self.use_pre_norm = use_pre_norm

#         if use_conv_embed:
#             # if we choose to use conv embedding, then we treat the stem and non-stem differently
#             if is_stem:
#                 kernel_size = 7; padding = 3; stride = 4
#             else:
#                 kernel_size = 3; padding = 1; stride = 2
#             self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
#         else:
#             self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

#         if self.use_pre_norm:
#             if norm_layer is not None:
#                 self.norm = norm_layer(in_chans)
#             else:
#                 self.norm = None
#         else:
#             if norm_layer is not None:
#                 self.norm = norm_layer(embed_dim)
#             else:
#                 self.norm = None

#     def forward(self, x):
#         B, C, H, W = x.shape
#         # FIXME look at relaxing size constraints
#         assert H == self.img_size[0] and W == self.img_size[1], \
#             f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
# if self.use_pre_norm: # if self.norm is not None: # x = x.flatten(2).transpose(1, 2) # B Ph*Pw C # x = self.norm(x).transpose(1, 2).view(B, C, H, W) # x = self.proj(x).flatten(2).transpose(1, 2) # else: # x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C # if self.norm is not None: # x = self.norm(x) # return x # def flops(self): # Ho, Wo = self.patches_resolution # flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) # if self.norm is not None: # flops += Ho * Wo * self.embed_dim # return flops class PatchEmbed(nn.Module): """ Image to Patch Embedding Args: patch_size (int): Patch token size. Default: 4. in_chans (int): Number of input image channels. Default: 3. embed_dim (int): Number of linear projection output channels. Default: 96. norm_layer (nn.Module, optional): Normalization layer. Default: None use_conv_embed (bool): Whether use overlapped convolution for patch embedding. Default: False is_stem (bool): Is the stem block or not. """ def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None, use_conv_embed=False, is_stem=False, use_pre_norm=False): super().__init__() patch_size = to_2tuple(patch_size) self.patch_size = patch_size self.in_chans = in_chans self.embed_dim = embed_dim self.use_pre_norm = use_pre_norm if use_conv_embed: # if we choose to use conv embedding, then we treat the stem and non-stem differently if is_stem: kernel_size = 7; padding = 3; stride = 4 else: kernel_size = 3; padding = 1; stride = 2 self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding) else: self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) if self.use_pre_norm: if norm_layer is not None: self.norm = norm_layer(in_chans) else: self.norm = None else: if norm_layer is not None: self.norm = norm_layer(embed_dim) else: self.norm = None def forward(self, x): """Forward function.""" B, C, H, W = x.size() if W % self.patch_size[1] != 0: 
x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1])) if H % self.patch_size[0] != 0: x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) if self.use_pre_norm: if self.norm is not None: x = x.flatten(2).transpose(1, 2) # B Ph*Pw C x = self.norm(x).transpose(1, 2).view(B, C, H, W) x = self.proj(x) else: x = self.proj(x) # B C Wh Ww if self.norm is not None: Wh, Ww = x.size(2), x.size(3) x = x.flatten(2).transpose(1, 2) x = self.norm(x) x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww) return x class FocalNet(nn.Module): """ FocalNet backbone. Args: pretrain_img_size (int): Input image size for training the pretrained model, used in absolute postion embedding. Default 224. patch_size (int | tuple(int)): Patch size. Default: 4. in_chans (int): Number of input image channels. Default: 3. embed_dim (int): Number of linear projection output channels. Default: 96. depths (tuple[int]): Depths of each Swin Transformer stage. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. drop_rate (float): Dropout rate. drop_path_rate (float): Stochastic depth rate. Default: 0.2. norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. patch_norm (bool): If True, add normalization after patch embedding. Default: True. out_indices (Sequence[int]): Output from which stages. frozen_stages (int): Stages to be frozen (stop grad and set eval mode). -1 means not freezing any parameters. focal_levels (Sequence[int]): Number of focal levels at four stages focal_windows (Sequence[int]): Focal window sizes at first focal level at four stages use_conv_embed (bool): Whether use overlapped convolution for patch embedding use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
""" def __init__(self, pretrain_img_size=1600, patch_size=4, in_chans=3, embed_dim=96, depths=[2, 2, 6, 2], mlp_ratio=4., drop_rate=0., drop_path_rate=0.2, norm_layer=nn.LayerNorm, patch_norm=True, out_indices=[0, 1, 2, 3], frozen_stages=-1, focal_levels=[2,2,2,2], focal_windows=[9,9,9,9], use_pre_norms=[False, False, False, False], use_conv_embed=False, use_postln=False, use_postln_in_modulation=False, scaling_modulator=False, use_layerscale=False, use_checkpoint=False, ): super().__init__() self.pretrain_img_size = pretrain_img_size self.num_layers = len(depths) self.embed_dim = embed_dim self.patch_norm = patch_norm self.out_indices = out_indices self.frozen_stages = frozen_stages # split image into non-overlapping patches self.patch_embed = PatchEmbed( patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, norm_layer=norm_layer if self.patch_norm else None, use_conv_embed=use_conv_embed, is_stem=True, use_pre_norm=False) self.pos_drop = nn.Dropout(p=drop_rate) # stochastic depth dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule # build layers self.layers = nn.ModuleList() for i_layer in range(self.num_layers): layer = BasicLayer( dim=int(embed_dim * 2 ** i_layer), depth=depths[i_layer], mlp_ratio=mlp_ratio, drop=drop_rate, drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], norm_layer=norm_layer, downsample=PatchEmbed if (i_layer < self.num_layers - 1) else None, focal_window=focal_windows[i_layer], focal_level=focal_levels[i_layer], use_pre_norm=use_pre_norms[i_layer], use_conv_embed=use_conv_embed, use_postln=use_postln, use_postln_in_modulation=use_postln_in_modulation, scaling_modulator=scaling_modulator, use_layerscale=use_layerscale, use_checkpoint=use_checkpoint) self.layers.append(layer) num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] self.num_features = num_features # self.norm = norm_layer(num_features[-1]) # add a norm layer for each output for i_layer 
in self.out_indices: layer = norm_layer(num_features[i_layer]) layer_name = f'norm{i_layer}' self.add_module(layer_name, layer) self._freeze_stages() def _freeze_stages(self): if self.frozen_stages >= 0: self.patch_embed.eval() for param in self.patch_embed.parameters(): param.requires_grad = False if self.frozen_stages >= 2: self.pos_drop.eval() for i in range(0, self.frozen_stages - 1): m = self.layers[i] m.eval() for param in m.parameters(): param.requires_grad = False def init_weights(self, pretrained=None): """Initialize the weights in backbone. Args: pretrained (str, optional): Path to pre-trained weights. Defaults to None. """ def _init_weights(m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) if isinstance(pretrained, str): self.apply(_init_weights) logger = get_root_logger() load_checkpoint(self, pretrained, strict=False, logger=logger) elif pretrained is None: self.apply(_init_weights) else: raise TypeError('pretrained must be a str or None') def load_weights(self, pretrained_dict=None, pretrained_layers=[], verbose=True): model_dict = self.state_dict() missed_dict = [k for k in model_dict.keys() if k not in pretrained_dict] logger.info(f'=> Missed keys {missed_dict}') unexpected_dict = [k for k in pretrained_dict.keys() if k not in model_dict] logger.info(f'=> Unexpected keys {unexpected_dict}') pretrained_dict = { k: v for k, v in pretrained_dict.items() if k in model_dict.keys() } need_init_state_dict = {} for k, v in pretrained_dict.items(): need_init = ( ( k.split('.')[0] in pretrained_layers or pretrained_layers[0] == '*' ) and 'relative_position_index' not in k and 'attn_mask' not in k ) if need_init: # if verbose: # logger.info(f'=> init {k} from {pretrained}') if ('pool_layers' in k) or ('focal_layers' in k) and v.size() != model_dict[k].size(): 
table_pretrained = v table_current = model_dict[k] fsize1 = table_pretrained.shape[2] fsize2 = table_current.shape[2] # NOTE: different from interpolation used in self-attention, we use padding or clipping for focal conv if fsize1 < fsize2: table_pretrained_resized = torch.zeros(table_current.shape) table_pretrained_resized[:, :, (fsize2-fsize1)//2:-(fsize2-fsize1)//2, (fsize2-fsize1)//2:-(fsize2-fsize1)//2] = table_pretrained v = table_pretrained_resized elif fsize1 > fsize2: table_pretrained_resized = table_pretrained[:, :, (fsize1-fsize2)//2:-(fsize1-fsize2)//2, (fsize1-fsize2)//2:-(fsize1-fsize2)//2] v = table_pretrained_resized if ("modulation.f" in k or "pre_conv" in k): table_pretrained = v table_current = model_dict[k] if table_pretrained.shape != table_current.shape: if len(table_pretrained.shape) == 2: dim = table_pretrained.shape[1] assert table_current.shape[1] == dim L1 = table_pretrained.shape[0] L2 = table_current.shape[0] if L1 < L2: table_pretrained_resized = torch.zeros(table_current.shape) # copy for linear project table_pretrained_resized[:2*dim] = table_pretrained[:2*dim] # copy for global token gating table_pretrained_resized[-1] = table_pretrained[-1] # copy for first multiple focal levels table_pretrained_resized[2*dim:2*dim+(L1-2*dim-1)] = table_pretrained[2*dim:-1] # reassign pretrained weights v = table_pretrained_resized elif L1 > L2: raise NotImplementedError elif len(table_pretrained.shape) == 1: dim = table_pretrained.shape[0] L1 = table_pretrained.shape[0] L2 = table_current.shape[0] if L1 < L2: table_pretrained_resized = torch.zeros(table_current.shape) # copy for linear project table_pretrained_resized[:dim] = table_pretrained[:dim] # copy for global token gating table_pretrained_resized[-1] = table_pretrained[-1] # copy for first multiple focal levels # table_pretrained_resized[dim:2*dim+(L1-2*dim-1)] = table_pretrained[2*dim:-1] # reassign pretrained weights v = table_pretrained_resized elif L1 > L2: raise NotImplementedError 
need_init_state_dict[k] = v self.load_state_dict(need_init_state_dict, strict=False) def forward(self, x): """Forward function.""" tic = time.time() x = self.patch_embed(x) Wh, Ww = x.size(2), x.size(3) x = x.flatten(2).transpose(1, 2) x = self.pos_drop(x) outs = {} for i in range(self.num_layers): layer = self.layers[i] x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww) if i in self.out_indices: norm_layer = getattr(self, f'norm{i}') x_out = norm_layer(x_out) out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous() outs["res{}".format(i + 2)] = out if len(self.out_indices) == 0: outs["res5"] = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous() toc = time.time() return outs def train(self, mode=True): """Convert the model into training mode while keep layers freezed.""" super(FocalNet, self).train(mode) self._freeze_stages() class D2FocalNet(FocalNet, Backbone): def __init__(self, cfg, input_shape): pretrain_img_size = cfg['BACKBONE']['FOCAL']['PRETRAIN_IMG_SIZE'] patch_size = cfg['BACKBONE']['FOCAL']['PATCH_SIZE'] in_chans = 3 embed_dim = cfg['BACKBONE']['FOCAL']['EMBED_DIM'] depths = cfg['BACKBONE']['FOCAL']['DEPTHS'] mlp_ratio = cfg['BACKBONE']['FOCAL']['MLP_RATIO'] drop_rate = cfg['BACKBONE']['FOCAL']['DROP_RATE'] drop_path_rate = cfg['BACKBONE']['FOCAL']['DROP_PATH_RATE'] norm_layer = nn.LayerNorm patch_norm = cfg['BACKBONE']['FOCAL']['PATCH_NORM'] use_checkpoint = cfg['BACKBONE']['FOCAL']['USE_CHECKPOINT'] out_indices = cfg['BACKBONE']['FOCAL']['OUT_INDICES'] scaling_modulator = cfg['BACKBONE']['FOCAL'].get('SCALING_MODULATOR', False) super().__init__( pretrain_img_size, patch_size, in_chans, embed_dim, depths, mlp_ratio, drop_rate, drop_path_rate, norm_layer, patch_norm, out_indices, focal_levels=cfg['BACKBONE']['FOCAL']['FOCAL_LEVELS'], focal_windows=cfg['BACKBONE']['FOCAL']['FOCAL_WINDOWS'], use_conv_embed=cfg['BACKBONE']['FOCAL']['USE_CONV_EMBED'], use_postln=cfg['BACKBONE']['FOCAL']['USE_POSTLN'], 
use_postln_in_modulation=cfg['BACKBONE']['FOCAL']['USE_POSTLN_IN_MODULATION'], scaling_modulator=scaling_modulator, use_layerscale=cfg['BACKBONE']['FOCAL']['USE_LAYERSCALE'], use_checkpoint=use_checkpoint, ) self._out_features = cfg['BACKBONE']['FOCAL']['OUT_FEATURES'] self._out_feature_strides = { "res2": 4, "res3": 8, "res4": 16, "res5": 32, } self._out_feature_channels = { "res2": self.num_features[0], "res3": self.num_features[1], "res4": self.num_features[2], "res5": self.num_features[3], } def forward(self, x): """ Args: x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``. Returns: dict[str->Tensor]: names and the corresponding features """ assert ( x.dim() == 4 ), f"SwinTransformer takes an input of shape (N, C, H, W). Got {x.shape} instead!" outputs = {} y = super().forward(x) for k in y.keys(): if k in self._out_features: outputs[k] = y[k] return outputs def output_shape(self): return { name: ShapeSpec( channels=self._out_feature_channels[name], stride=self._out_feature_strides[name] ) for name in self._out_features } @property def size_divisibility(self): return 32 @register_backbone def get_focal_backbone(cfg): focal = D2FocalNet(cfg['MODEL'], 224) if cfg['MODEL']['BACKBONE']['LOAD_PRETRAINED'] is True: filename = cfg['MODEL']['BACKBONE']['PRETRAINED'] logger.info(f'=> init from {filename}') with PathManager.open(filename, "rb") as f: ckpt = torch.load(f)['model'] focal.load_weights(ckpt, cfg['MODEL']['BACKBONE']['FOCAL'].get('PRETRAINED_LAYERS', ['*']), cfg['VERBOSE']) return focal ================================================ FILE: llava/model/openseed/backbone/registry.py ================================================ _model_entrypoints = {} def register_backbone(fn): module_name_split = fn.__module__.split('.') model_name = module_name_split[-1] _model_entrypoints[model_name] = fn return fn def model_entrypoints(model_name): return _model_entrypoints[model_name] def is_model(model_name): return model_name in 
_model_entrypoints ================================================ FILE: llava/model/openseed/backbone/swin.py ================================================ # -------------------------------------------------------- # Swin Transformer # Copyright (c) 2021 Microsoft # Licensed under The MIT License [see LICENSE for details] # Written by Ze Liu, Yutong Lin, Yixuan Wei # -------------------------------------------------------- # Copyright (c) Facebook, Inc. and its affiliates. # Modified by Bowen Cheng from https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation/blob/main/mmseg/models/backbones/swin_transformer.py import logging import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint as checkpoint from timm.models.layers import DropPath, to_2tuple, trunc_normal_ from detectron2.modeling import Backbone, ShapeSpec from detectron2.utils.file_io import PathManager from .registry import register_backbone logger = logging.getLogger(__name__) class Mlp(nn.Module): """Multilayer perceptron.""" def __init__( self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0 ): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc1 = nn.Linear(in_features, hidden_features) self.act = act_layer() self.fc2 = nn.Linear(hidden_features, out_features) self.drop = nn.Dropout(drop) def forward(self, x): x = self.fc1(x) x = self.act(x) x = self.drop(x) x = self.fc2(x) x = self.drop(x) return x def window_partition(x, window_size): """ Args: x: (B, H, W, C) window_size (int): window size Returns: windows: (num_windows*B, window_size, window_size, C) """ B, H, W, C = x.shape x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) return windows def window_reverse(windows, window_size, H, W): """ Args: windows: 
(num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    # NOTE(review): assumes H and W are multiples of window_size (callers pass padded sizes).
    # Batch size = total windows / windows per image ((H/ws) * (W/ws)).
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x


class WindowAttention(nn.Module):
    """Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output.
Default: 0.0
    """

    def __init__(
        self,
        dim,
        window_size,
        num_heads,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
        )  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        # Encode the shifted offset (dy, dx) as one row of the bias table: dy * (2*Ww-1) + dx.
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=0.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """Forward function.
Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: additive mask (0 for allowed pairs, a large negative value otherwise)
                with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        qkv = (
            self.qkv(x)
            .reshape(B_, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = q @ k.transpose(-2, -1)

        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)
        ].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
        )  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(
            2, 0, 1
        ).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            # (B_, nH, N, N) -> (B, nW, nH, N, N); the (nW, N, N) mask broadcasts over batch and heads.
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class SwinTransformerBlock(nn.Module):
    """Swin Transformer Block.

    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer.
Default: nn.LayerNorm
    """

    def __init__(
        self,
        dim,
        num_heads,
        window_size=7,
        shift_size=0,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim,
            window_size=to_2tuple(self.window_size),
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )

        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop
        )

        # Spatial size of the token grid; the caller assigns these before each forward().
        self.H = None
        self.W = None

    def forward(self, x, mask_matrix):
        """Forward function.

        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature (read from self.H / self.W).
            mask_matrix: Attention mask for cyclic shift; only used when shift_size > 0.
""" B, L, C = x.shape H, W = self.H, self.W assert L == H * W, "input feature has wrong size" # HACK model will not upsampling # if min([H, W]) <= self.window_size: # if window size is larger than input resolution, we don't partition windows # self.shift_size = 0 # self.window_size = min([H,W]) shortcut = x x = self.norm1(x) x = x.view(B, H, W, C) # pad feature maps to multiples of window size pad_l = pad_t = 0 pad_r = (self.window_size - W % self.window_size) % self.window_size pad_b = (self.window_size - H % self.window_size) % self.window_size x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) _, Hp, Wp, _ = x.shape # cyclic shift if self.shift_size > 0: shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) attn_mask = mask_matrix else: shifted_x = x attn_mask = None # partition windows x_windows = window_partition( shifted_x, self.window_size ) # nW*B, window_size, window_size, C x_windows = x_windows.view( -1, self.window_size * self.window_size, C ) # nW*B, window_size*window_size, C # W-MSA/SW-MSA attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C # merge windows attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C # reverse cyclic shift if self.shift_size > 0: x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) else: x = shifted_x if pad_r > 0 or pad_b > 0: x = x[:, :H, :W, :].contiguous() x = x.view(B, H * W, C) # FFN x = shortcut + self.drop_path(x) x = x + self.drop_path(self.mlp(self.norm2(x))) return x class PatchMerging(nn.Module): """Patch Merging Layer Args: dim (int): Number of input channels. norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm
    """

    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x, H, W):
        """Forward function.

        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.

        Returns:
            Tensor of shape (B, ceil(H/2) * ceil(W/2), 2 * dim). Odd H or W is
            zero-padded to even before the 2x2 neighbourhoods are concatenated.
        """
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        x = x.view(B, H, W, C)

        # padding
        pad_input = (H % 2 == 1) or (W % 2 == 1)
        if pad_input:
            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x


class BasicLayer(nn.Module):
    """A basic Swin Transformer layer for one stage.

    Args:
        dim (int): Number of feature channels
        depth (int): Depths of this stage.
        num_heads (int): Number of attention head.
        window_size (int): Local window size. Default: 7.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | list[float], optional): Stochastic depth rate; a list supplies
            one rate per block. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory.
""" def __init__( self, dim, depth, num_heads, window_size=7, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, drop=0.0, attn_drop=0.0, drop_path=0.0, norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, ): super().__init__() self.window_size = window_size self.shift_size = window_size // 2 self.depth = depth self.use_checkpoint = use_checkpoint # build blocks self.blocks = nn.ModuleList( [ SwinTransformerBlock( dim=dim, num_heads=num_heads, window_size=window_size, shift_size=0 if (i % 2 == 0) else window_size // 2, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, attn_drop=attn_drop, drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer, ) for i in range(depth) ] ) # patch merging layer if downsample is not None: self.downsample = downsample(dim=dim, norm_layer=norm_layer) else: self.downsample = None def forward(self, x, H, W): """Forward function. Args: x: Input feature, tensor size (B, H*W, C). H, W: Spatial resolution of the input feature. 
""" # calculate attention mask for SW-MSA Hp = int(np.ceil(H / self.window_size)) * self.window_size Wp = int(np.ceil(W / self.window_size)) * self.window_size img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 h_slices = ( slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None), ) w_slices = ( slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None), ) cnt = 0 for h in h_slices: for w in w_slices: img_mask[:, h, w, :] = cnt cnt += 1 mask_windows = window_partition( img_mask, self.window_size ) # nW, window_size, window_size, 1 mask_windows = mask_windows.view(-1, self.window_size * self.window_size) attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill( attn_mask == 0, float(0.0) ).type(x.dtype) for blk in self.blocks: blk.H, blk.W = H, W if self.use_checkpoint: x = checkpoint.checkpoint(blk, x, attn_mask) else: x = blk(x, attn_mask) if self.downsample is not None: x_down = self.downsample(x, H, W) Wh, Ww = (H + 1) // 2, (W + 1) // 2 return x, H, W, x_down, Wh, Ww else: return x, H, W, x, H, W class PatchEmbed(nn.Module): """Image to Patch Embedding Args: patch_size (int): Patch token size. Default: 4. in_chans (int): Number of input image channels. Default: 3. embed_dim (int): Number of linear projection output channels. Default: 96. norm_layer (nn.Module, optional): Normalization layer. 
Default: None
    """

    def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        patch_size = to_2tuple(patch_size)
        self.patch_size = patch_size

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        """Forward function.

        Zero-pads the bottom/right of ``x`` (B, C, H, W) up to a multiple of
        ``patch_size`` and returns a (B, embed_dim, Wh, Ww) tensor.
        """
        # padding
        _, _, H, W = x.size()
        if W % self.patch_size[1] != 0:
            x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
        if H % self.patch_size[0] != 0:
            x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))

        x = self.proj(x)  # B C Wh Ww
        if self.norm is not None:
            Wh, Ww = x.size(2), x.size(3)
            x = x.flatten(2).transpose(1, 2)
            x = self.norm(x)
            x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)

        return x


class SwinTransformer(nn.Module):
    """Swin Transformer backbone.
        A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
          https://arxiv.org/pdf/2103.14030

    Args:
        pretrain_img_size (int): Input image size for training the pretrained model,
            used in absolute position embedding. Default 224.
        patch_size (int | tuple(int)): Patch size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        depths (tuple[int]): Depths of each Swin Transformer stage.
        num_heads (tuple[int]): Number of attention head of each stage.
        window_size (int): Window size. Default: 7.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
        drop_rate (float): Dropout rate.
        attn_drop_rate (float): Attention dropout rate. Default: 0.
        drop_path_rate (float): Stochastic depth rate. Default: 0.2.
        norm_layer (nn.Module): Normalization layer.
Default: nn.LayerNorm. ape (bool): If True, add absolute position embedding to the patch embedding. Default: False. patch_norm (bool): If True, add normalization after patch embedding. Default: True. out_indices (Sequence[int]): Output from which stages. frozen_stages (int): Stages to be frozen (stop grad and set eval mode). -1 means not freezing any parameters. use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. """ def __init__( self, pretrain_img_size=224, patch_size=4, in_chans=3, embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.2, norm_layer=nn.LayerNorm, ape=False, patch_norm=True, out_indices=(0, 1, 2, 3), frozen_stages=-1, use_checkpoint=False, ): super().__init__() self.pretrain_img_size = pretrain_img_size self.num_layers = len(depths) self.embed_dim = embed_dim self.ape = ape self.patch_norm = patch_norm self.out_indices = out_indices self.frozen_stages = frozen_stages # split image into non-overlapping patches self.patch_embed = PatchEmbed( patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, norm_layer=norm_layer if self.patch_norm else None, ) # absolute position embedding if self.ape: pretrain_img_size = to_2tuple(pretrain_img_size) patch_size = to_2tuple(patch_size) patches_resolution = [ pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1], ] self.absolute_pos_embed = nn.Parameter( torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1]) ) trunc_normal_(self.absolute_pos_embed, std=0.02) self.pos_drop = nn.Dropout(p=drop_rate) # stochastic depth dpr = [ x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) ] # stochastic depth decay rule # build layers self.layers = nn.ModuleList() for i_layer in range(self.num_layers): layer = BasicLayer( dim=int(embed_dim * 2 ** i_layer), depth=depths[i_layer], num_heads=num_heads[i_layer], 
window_size=window_size, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])], norm_layer=norm_layer, downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, use_checkpoint=use_checkpoint, ) self.layers.append(layer) num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] self.num_features = num_features # add a norm layer for each output for i_layer in out_indices: layer = norm_layer(num_features[i_layer]) layer_name = f"norm{i_layer}" self.add_module(layer_name, layer) self._freeze_stages() def _freeze_stages(self): if self.frozen_stages >= 0: self.patch_embed.eval() for param in self.patch_embed.parameters(): param.requires_grad = False if self.frozen_stages >= 1 and self.ape: self.absolute_pos_embed.requires_grad = False if self.frozen_stages >= 2: self.pos_drop.eval() for i in range(0, self.frozen_stages - 1): m = self.layers[i] m.eval() for param in m.parameters(): param.requires_grad = False def init_weights(self, pretrained=None): """Initialize the weights in backbone. Args: pretrained (str, optional): Path to pre-trained weights. Defaults to None. 
""" def _init_weights(m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=0.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) def load_weights(self, pretrained_dict=None, pretrained_layers=[], verbose=True): model_dict = self.state_dict() pretrained_dict = { k: v for k, v in pretrained_dict.items() if k in model_dict.keys() } need_init_state_dict = {} for k, v in pretrained_dict.items(): need_init = ( ( k.split('.')[0] in pretrained_layers or pretrained_layers[0] == '*' ) and 'relative_position_index' not in k and 'attn_mask' not in k ) if need_init: # if verbose: # logger.info(f'=> init {k} from {pretrained}') if 'relative_position_bias_table' in k and v.size() != model_dict[k].size(): relative_position_bias_table_pretrained = v relative_position_bias_table_current = model_dict[k] L1, nH1 = relative_position_bias_table_pretrained.size() L2, nH2 = relative_position_bias_table_current.size() if nH1 != nH2: logger.info(f"Error in loading {k}, passing") else: if L1 != L2: logger.info( '=> load_pretrained: resized variant: {} to {}' .format((L1, nH1), (L2, nH2)) ) S1 = int(L1 ** 0.5) S2 = int(L2 ** 0.5) relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate( relative_position_bias_table_pretrained.permute(1, 0).view(1, nH1, S1, S1), size=(S2, S2), mode='bicubic') v = relative_position_bias_table_pretrained_resized.view(nH2, L2).permute(1, 0) if 'absolute_pos_embed' in k and v.size() != model_dict[k].size(): absolute_pos_embed_pretrained = v absolute_pos_embed_current = model_dict[k] _, L1, C1 = absolute_pos_embed_pretrained.size() _, L2, C2 = absolute_pos_embed_current.size() if C1 != C1: logger.info(f"Error in loading {k}, passing") else: if L1 != L2: logger.info( '=> load_pretrained: resized variant: {} to {}' .format((1, L1, C1), (1, L2, C2)) ) S1 = int(L1 ** 0.5) S2 = int(L2 ** 0.5) 
absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.reshape(-1, S1, S1, C1) absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.permute(0, 3, 1, 2) absolute_pos_embed_pretrained_resized = torch.nn.functional.interpolate( absolute_pos_embed_pretrained, size=(S2, S2), mode='bicubic') v = absolute_pos_embed_pretrained_resized.permute(0, 2, 3, 1).flatten(1, 2) need_init_state_dict[k] = v self.load_state_dict(need_init_state_dict, strict=False) def forward(self, x): """Forward function.""" x = self.patch_embed(x) Wh, Ww = x.size(2), x.size(3) if self.ape: # interpolate the position embedding to the corresponding size absolute_pos_embed = F.interpolate( self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic" ) x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C else: x = x.flatten(2).transpose(1, 2) x = self.pos_drop(x) outs = {} for i in range(self.num_layers): layer = self.layers[i] x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww) if i in self.out_indices: norm_layer = getattr(self, f"norm{i}") x_out = norm_layer(x_out) out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous() outs["res{}".format(i + 2)] = out if len(self.out_indices) == 0: outs["res5"] = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous() return outs def train(self, mode=True): """Convert the model into training mode while keep layers freezed.""" super(SwinTransformer, self).train(mode) self._freeze_stages() class D2SwinTransformer(SwinTransformer, Backbone): def __init__(self, cfg, pretrain_img_size, patch_size, in_chans, embed_dim, depths, num_heads, window_size, mlp_ratio, qkv_bias, qk_scale, drop_rate, attn_drop_rate, drop_path_rate, norm_layer, ape, patch_norm, out_indices, use_checkpoint): super().__init__( pretrain_img_size, patch_size, in_chans, embed_dim, depths, num_heads, window_size, mlp_ratio, qkv_bias, qk_scale, drop_rate, attn_drop_rate, drop_path_rate, norm_layer, ape, patch_norm, out_indices, 
use_checkpoint=use_checkpoint, ) self._out_features = cfg['OUT_FEATURES'] self._out_feature_strides = { "res2": 4, "res3": 8, "res4": 16, "res5": 32, } self._out_feature_channels = { "res2": self.num_features[0], "res3": self.num_features[1], "res4": self.num_features[2], "res5": self.num_features[3], } def forward(self, x): """ Args: x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``. Returns: dict[str->Tensor]: names and the corresponding features """ assert ( x.dim() == 4 ), f"SwinTransformer takes an input of shape (N, C, H, W). Got {x.shape} instead!" outputs = {} y = super().forward(x) for k in y.keys(): if k in self._out_features: outputs[k] = y[k] return outputs def output_shape(self): feature_names = list(set(self._out_feature_strides.keys()) & set(self._out_features)) return { name: ShapeSpec( channels=self._out_feature_channels[name], stride=self._out_feature_strides[name] ) for name in feature_names } @property def size_divisibility(self): return 32 @register_backbone def get_swin_backbone(cfg): swin_cfg = cfg['MODEL']['BACKBONE']['SWIN'] pretrain_img_size = swin_cfg['PRETRAIN_IMG_SIZE'] patch_size = swin_cfg['PATCH_SIZE'] in_chans = 3 embed_dim = swin_cfg['EMBED_DIM'] depths = swin_cfg['DEPTHS'] num_heads = swin_cfg['NUM_HEADS'] window_size = swin_cfg['WINDOW_SIZE'] mlp_ratio = swin_cfg['MLP_RATIO'] qkv_bias = swin_cfg['QKV_BIAS'] qk_scale = swin_cfg['QK_SCALE'] drop_rate = swin_cfg['DROP_RATE'] attn_drop_rate = swin_cfg['ATTN_DROP_RATE'] drop_path_rate = swin_cfg['DROP_PATH_RATE'] norm_layer = nn.LayerNorm ape = swin_cfg['APE'] patch_norm = swin_cfg['PATCH_NORM'] use_checkpoint = swin_cfg['USE_CHECKPOINT'] out_indices = swin_cfg.get('OUT_INDICES', [0,1,2,3]) swin = D2SwinTransformer( swin_cfg, pretrain_img_size, patch_size, in_chans, embed_dim, depths, num_heads, window_size, mlp_ratio, qkv_bias, qk_scale, drop_rate, attn_drop_rate, drop_path_rate, norm_layer, ape, patch_norm, out_indices, 
use_checkpoint=use_checkpoint, ) if cfg['MODEL']['BACKBONE']['LOAD_PRETRAINED'] is True: filename = cfg['MODEL']['BACKBONE']['PRETRAINED'] with PathManager.open(filename, "rb") as f: ckpt = torch.load(f, map_location='cpu')['model'] swin.load_weights(ckpt, swin_cfg.get('PRETRAINED_LAYERS', ['*']), cfg['VERBOSE']) return swin ================================================ FILE: llava/model/openseed/body/__init__.py ================================================ from .build import build_openseed_head ================================================ FILE: llava/model/openseed/body/build.py ================================================ from .registry import model_entrypoints from .registry import is_model from .openseed_head import * def build_openseed_head(config, *args, **kwargs): model_name = config['MODEL']['HEAD'] if not is_model(model_name): raise ValueError(f'Unkown model: {model_name}') body = model_entrypoints(model_name)(config, *args, **kwargs) return body ================================================ FILE: llava/model/openseed/body/decoder/__init__.py ================================================ from .build import build_decoder from .openseed_decoder import * from .openseed_decoder_decouple import * ================================================ FILE: llava/model/openseed/body/decoder/build.py ================================================ from .registry import model_entrypoints from .registry import is_model def build_decoder(config, *args, **kwargs): model_name = config['MODEL']['DECODER']['NAME'] if not is_model(model_name): raise ValueError(f'Unkown model: {model_name}') return model_entrypoints(model_name)(config, *args, **kwargs) ================================================ FILE: llava/model/openseed/body/decoder/modules.py ================================================ from typing import Optional import torch from torch import nn, Tensor from torch.nn import functional as F from timm.models.layers import trunc_normal_ from 
detectron2.layers import Conv2d import fvcore.nn.weight_init as weight_init class SelfAttentionLayer(nn.Module): def __init__(self, d_model, nhead, dropout=0.0, activation="relu", normalize_before=False): super().__init__() self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) self.norm = nn.LayerNorm(d_model) self.dropout = nn.Dropout(dropout) self.activation = _get_activation_fn(activation) self.normalize_before = normalize_before self._reset_parameters() def _reset_parameters(self): for p in self.parameters(): if p.dim() > 1: nn.init.xavier_uniform_(p) def with_pos_embed(self, tensor, pos: Optional[Tensor]): return tensor if pos is None else tensor + pos def forward_post(self, tgt, tgt_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, query_pos: Optional[Tensor] = None): q = k = self.with_pos_embed(tgt, query_pos) tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0] tgt = tgt + self.dropout(tgt2) tgt = self.norm(tgt) return tgt def forward_pre(self, tgt, tgt_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, query_pos: Optional[Tensor] = None): tgt2 = self.norm(tgt) q = k = self.with_pos_embed(tgt2, query_pos) tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0] tgt = tgt + self.dropout(tgt2) return tgt def forward(self, tgt, tgt_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, query_pos: Optional[Tensor] = None): if self.normalize_before: return self.forward_pre(tgt, tgt_mask, tgt_key_padding_mask, query_pos) return self.forward_post(tgt, tgt_mask, tgt_key_padding_mask, query_pos) class CrossAttentionLayer(nn.Module): def __init__(self, d_model, nhead, dropout=0.0, activation="relu", normalize_before=False): super().__init__() self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) self.norm = nn.LayerNorm(d_model) self.dropout = 
nn.Dropout(dropout) self.activation = _get_activation_fn(activation) self.normalize_before = normalize_before self._reset_parameters() def _reset_parameters(self): for p in self.parameters(): if p.dim() > 1: nn.init.xavier_uniform_(p) def with_pos_embed(self, tensor, pos: Optional[Tensor]): return tensor if pos is None else tensor + pos def forward_post(self, tgt, memory, memory_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None, query_pos: Optional[Tensor] = None): tgt2, avg_attn = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), key=self.with_pos_embed(memory, pos), value=memory, attn_mask=memory_mask, key_padding_mask=memory_key_padding_mask) tgt = tgt + self.dropout(tgt2) tgt = self.norm(tgt) return tgt, avg_attn def forward_pre(self, tgt, memory, memory_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None, query_pos: Optional[Tensor] = None): tgt2 = self.norm(tgt) tgt2, avg_attn = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), key=self.with_pos_embed(memory, pos), value=memory, attn_mask=memory_mask, key_padding_mask=memory_key_padding_mask) tgt = tgt + self.dropout(tgt2) return tgt, avg_attn def forward(self, tgt, memory, memory_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None, query_pos: Optional[Tensor] = None): if self.normalize_before: return self.forward_pre(tgt, memory, memory_mask, memory_key_padding_mask, pos, query_pos) return self.forward_post(tgt, memory, memory_mask, memory_key_padding_mask, pos, query_pos) class FFNLayer(nn.Module): def __init__(self, d_model, dim_feedforward=2048, dropout=0.0, activation="relu", normalize_before=False): super().__init__() # Implementation of Feedforward model self.linear1 = nn.Linear(d_model, dim_feedforward) self.dropout = nn.Dropout(dropout) self.linear2 = nn.Linear(dim_feedforward, d_model) self.norm = 
nn.LayerNorm(d_model) self.activation = _get_activation_fn(activation) self.normalize_before = normalize_before self._reset_parameters() def _reset_parameters(self): for p in self.parameters(): if p.dim() > 1: nn.init.xavier_uniform_(p) def with_pos_embed(self, tensor, pos: Optional[Tensor]): return tensor if pos is None else tensor + pos def forward_post(self, tgt): tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) tgt = tgt + self.dropout(tgt2) tgt = self.norm(tgt) return tgt def forward_pre(self, tgt): tgt2 = self.norm(tgt) tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) tgt = tgt + self.dropout(tgt2) return tgt def forward(self, tgt): if self.normalize_before: return self.forward_pre(tgt) return self.forward_post(tgt) def _get_activation_fn(activation): """Return an activation function given a string""" if activation == "relu": return F.relu if activation == "gelu": return F.gelu if activation == "glu": return F.glu raise RuntimeError(F"activation should be relu/gelu, not {activation}.") class MLP(nn.Module): """ Very simple multi-layer perceptron (also called FFN)""" def __init__(self, input_dim, hidden_dim, output_dim, num_layers): super().__init__() self.num_layers = num_layers h = [hidden_dim] * (num_layers - 1) self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) def forward(self, x): for i, layer in enumerate(self.layers): x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) return x ================================================ FILE: llava/model/openseed/body/decoder/openseed_decoder.py ================================================ # ------------------------------------------------------------------------ # DINO # Copyright (c) 2023 IDEA. All Rights Reserved. 
# Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------ # Modified from MaskDINO https://github.com/IDEA-Research/MaskDINO by Feng Li and Hao Zhang. # ------------------------------------------------------------------------ import logging import fvcore.nn.weight_init as weight_init import torch from torch import nn from torch.nn import functional as F from detectron2.layers import Conv2d from detectron2.utils.registry import Registry from detectron2.structures import BitMasks from timm.models.layers import trunc_normal_ from .registry import register_decoder from .utils.dino_decoder import TransformerDecoder, DeformableTransformerDecoderLayer from .utils import MLP, gen_encoder_output_proposals, inverse_sigmoid from ...utils import box_ops from ...utils import configurable class OpenSeeDDecoder(nn.Module): @configurable def __init__( self, # lang_encoder: nn.Module, in_channels, mask_classification=True, *, num_classes: int, hidden_dim: int, dim_proj: int, num_queries: int, nheads: int, dim_feedforward: int, dec_layers: int, mask_dim: int, enforce_input_project: bool, two_stage: bool, dn: str, noise_scale:float, dn_num:int, initialize_box_type:bool, initial_pred:bool, learn_tgt: bool, total_num_feature_levels: int = 4, dropout: float = 0.0, activation: str = 'relu', nhead: int = 8, dec_n_points: int = 4, return_intermediate_dec: bool = True, query_dim: int = 4, dec_layer_share: bool = False, semantic_ce_loss: bool = False, ): """ NOTE: this interface is experimental. 
Args: in_channels: channels of the input features mask_classification: whether to add mask classifier or not num_classes: number of classes hidden_dim: Transformer feature dimension num_queries: number of queries nheads: number of heads dim_feedforward: feature dimension in feedforward network enc_layers: number of Transformer encoder layers dec_layers: number of Transformer decoder layers pre_norm: whether to use pre-LayerNorm or not mask_dim: mask feature dimension enforce_input_project: add input project 1x1 conv even if input channels and hidden dim is identical d_model: transformer dimension dropout: dropout rate activation: activation function nhead: num heads in multi-head attention dec_n_points: number of sampling points in decoder return_intermediate_dec: return the intermediate results of decoder query_dim: 4 -> (x, y, w, h) dec_layer_share: whether to share each decoder layer semantic_ce_loss: use ce loss for semantic segmentation """ super().__init__() assert mask_classification, "Only support mask classification model" self.mask_classification = mask_classification self.num_feature_levels = total_num_feature_levels self.initial_pred = initial_pred # define Transformer decoder here self.dn=dn self.learn_tgt = learn_tgt self.noise_scale=noise_scale self.dn_num=dn_num self.num_heads = nheads self.num_layers = dec_layers self.two_stage=two_stage self.initialize_box_type = initialize_box_type self.total_num_feature_levels = total_num_feature_levels self.num_queries = num_queries self.semantic_ce_loss = semantic_ce_loss # learnable query features if not two_stage or self.learn_tgt: self.query_feat = nn.Embedding(num_queries, hidden_dim) if not two_stage and initialize_box_type == 'no': self.query_embed = nn.Embedding(num_queries, 4) if two_stage: self.enc_output = nn.Linear(hidden_dim, hidden_dim) self.enc_output_norm = nn.LayerNorm(hidden_dim) self.input_proj = nn.ModuleList() for _ in range(self.num_feature_levels): if in_channels != hidden_dim or 
enforce_input_project: self.input_proj.append(Conv2d(in_channels, hidden_dim, kernel_size=1)) weight_init.c2_xavier_fill(self.input_proj[-1]) else: self.input_proj.append(nn.Sequential()) self.num_classes=num_classes # output FFNs assert self.mask_classification, "why not class embedding?" # self.label_enc=nn.Embedding(505, hidden_dim) # this is a hack for o365+coco (365+133=498) self.dim_proj = dim_proj # self.lang_encoder = lang_encoder self.lang_mapper = nn.Parameter(torch.empty(dim_proj, hidden_dim)) trunc_normal_(self.lang_mapper, std=.02) self.class_embed = nn.Parameter(torch.empty(hidden_dim, dim_proj)) trunc_normal_(self.class_embed, std=.02) self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3) # init decoder self.decoder_norm = decoder_norm = nn.LayerNorm(hidden_dim) decoder_layer = DeformableTransformerDecoderLayer(hidden_dim, dim_feedforward, dropout, activation, self.num_feature_levels, nhead, dec_n_points) self.decoder = TransformerDecoder(decoder_layer, self.num_layers, decoder_norm, return_intermediate=return_intermediate_dec, d_model=hidden_dim, query_dim=query_dim, num_feature_levels=self.num_feature_levels, dec_layer_share=dec_layer_share, ) self.hidden_dim = hidden_dim self._bbox_embed = _bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3) nn.init.constant_(_bbox_embed.layers[-1].weight.data, 0) nn.init.constant_(_bbox_embed.layers[-1].bias.data, 0) box_embed_layerlist = [_bbox_embed for i in range(self.num_layers)] # share box prediction each layer self.bbox_embed = nn.ModuleList(box_embed_layerlist) self.decoder.bbox_embed = self.bbox_embed self.logit_scale = nn.Parameter(torch.ones([])) self.default_text_embeddings = None #for grounding tokens self.default_text_embeddings_mask = None #for grounding tokens @classmethod def from_config(cls, cfg, in_channels, mask_classification, extra): ret = {} ret["in_channels"] = in_channels # ret["lang_encoder"] = lang_encoder ret["mask_classification"] = mask_classification enc_cfg = 
cfg['MODEL']['ENCODER'] dec_cfg = cfg['MODEL']['DECODER'] ret["num_classes"] = enc_cfg['NUM_CLASSES'] ret["hidden_dim"] = dec_cfg['HIDDEN_DIM'] ret["dim_proj"] = cfg['MODEL']['DIM_PROJ'] ret["num_queries"] = dec_cfg['NUM_OBJECT_QUERIES'] # Transformer parameters: ret["nheads"] = dec_cfg['NHEADS'] ret["dim_feedforward"] = dec_cfg['DIM_FEEDFORWARD'] ret["dec_layers"] = dec_cfg['DEC_LAYERS'] ret["enforce_input_project"] = dec_cfg['ENFORCE_INPUT_PROJ'] ret["mask_dim"] = enc_cfg['MASK_DIM'] ret["two_stage"] = dec_cfg['TWO_STAGE'] ret["initialize_box_type"] = dec_cfg['INITIALIZE_BOX_TYPE'] # ['no', 'bitmask', 'mask2box'] ret["dn"] = dec_cfg['DN'] ret["noise_scale"] = dec_cfg['DN_NOISE_SCALE'] ret["dn_num"] = dec_cfg['DN_NUM'] ret["initial_pred"] = dec_cfg['INITIAL_PRED'] ret["learn_tgt"] = dec_cfg['LEARN_TGT'] ret["total_num_feature_levels"] = dec_cfg['TOTAL_NUM_FEATURE_LEVELS'] ret["semantic_ce_loss"] = dec_cfg['TEST']['SEMANTIC_ON'] and dec_cfg['SEMANTIC_CE_LOSS'] and not dec_cfg['TEST']['PANOPTIC_ON'] return ret def prepare_for_dn(self, targets, tgt, refpoint_emb, batch_size): """ modified from dn-detr. 
You can refer to dn-detr https://github.com/IDEA-Research/DN-DETR/blob/main/models/dn_dab_deformable_detr/dn_components.py for more details :param dn_args: scalar, noise_scale :param tgt: original tgt (content) in the matching part :param refpoint_emb: positional anchor queries in the matching part :param batch_size: bs """ if self.training: scalar, noise_scale = self.dn_num, self.noise_scale known = [(torch.ones_like(t['labels'])).cuda() for t in targets] know_idx = [torch.nonzero(t) for t in known] known_num = [sum(k) for k in known] # use fix number of dn queries if max(known_num) > 0: scalar = scalar // (int(max(known_num))) else: scalar = 0 if scalar == 0: input_query_label = None input_query_bbox = None attn_mask = None mask_dict = None return input_query_label, input_query_bbox, attn_mask, mask_dict # can be modified to selectively denosie some label or boxes; also known label prediction unmask_bbox = unmask_label = torch.cat(known) labels = torch.cat([t['labels'] for t in targets]) # use languge as denosing content queries. 
# if task == 'det': # labels = labels # o365 start from 133 class boxes = torch.cat([t['boxes'] for t in targets]) batch_idx = torch.cat([torch.full_like(t['labels'].long(), i) for i, t in enumerate(targets)]) # known known_indice = torch.nonzero(unmask_label + unmask_bbox) known_indice = known_indice.view(-1) # noise known_indice = known_indice.repeat(scalar, 1).view(-1) known_labels = labels.repeat(scalar, 1).view(-1) known_bid = batch_idx.repeat(scalar, 1).view(-1) known_bboxs = boxes.repeat(scalar, 1) known_labels_expaned = known_labels.clone() known_bbox_expand = known_bboxs.clone() if noise_scale > 0: diff = torch.zeros_like(known_bbox_expand) diff[:, :2] = known_bbox_expand[:, 2:] / 2 diff[:, 2:] = known_bbox_expand[:, 2:] known_bbox_expand += torch.mul((torch.rand_like(known_bbox_expand) * 2 - 1.0), diff).cuda() * noise_scale known_bbox_expand = known_bbox_expand.clamp(min=0.0, max=1.0) m = known_labels_expaned.long().to('cuda') # import ipdb; ipdb.set_trace() input_label_embed = torch.gather(self.default_text_embeddings, 0, m[:, None].repeat(1, self.dim_proj)) @ self.lang_mapper input_bbox_embed = inverse_sigmoid(known_bbox_expand) single_pad = int(max(known_num)) pad_size = int(single_pad * scalar) padding_label = input_label_embed.new_zeros(pad_size, self.hidden_dim) padding_bbox = input_bbox_embed.new_zeros(pad_size, 4) if not refpoint_emb is None: input_query_label = torch.cat([padding_label, tgt], dim=0).repeat(batch_size, 1, 1) input_query_bbox = torch.cat([padding_bbox, refpoint_emb], dim=0).repeat(batch_size, 1, 1) else: input_query_label = padding_label.repeat(batch_size, 1, 1) input_query_bbox = padding_bbox.repeat(batch_size, 1, 1) # map map_known_indice = input_label_embed.new_tensor([]) if len(known_num): map_known_indice = torch.cat( [input_label_embed.new_tensor(range(num)) for num in known_num]) # [1,2, 1,2,3] map_known_indice = torch.cat([map_known_indice + single_pad * i for i in range(scalar)]).long() if len(known_bid): 
input_query_label[(known_bid.long(), map_known_indice)] = input_label_embed input_query_bbox[(known_bid.long(), map_known_indice)] = input_bbox_embed tgt_size = pad_size + self.num_queries attn_mask = input_label_embed.new_ones(tgt_size, tgt_size) < 0 # match query cannot see the reconstruct attn_mask[pad_size:, :pad_size] = True # reconstruct cannot see each other for i in range(scalar): if i == 0: attn_mask[single_pad * i:single_pad * (i + 1), single_pad * (i + 1):pad_size] = True if i == scalar - 1: attn_mask[single_pad * i:single_pad * (i + 1), :single_pad * i] = True else: attn_mask[single_pad * i:single_pad * (i + 1), single_pad * (i + 1):pad_size] = True attn_mask[single_pad * i:single_pad * (i + 1), :single_pad * i] = True mask_dict = { 'known_indice': torch.as_tensor(known_indice).long(), 'batch_idx': torch.as_tensor(batch_idx).long(), 'map_known_indice': torch.as_tensor(map_known_indice).long(), 'known_lbs_bboxes': (known_labels, known_bboxs), 'know_idx': know_idx, 'pad_size': pad_size, 'scalar': scalar, } else: if not refpoint_emb is None: input_query_label = tgt.repeat(batch_size, 1, 1) input_query_bbox = refpoint_emb.repeat(batch_size, 1, 1) else: input_query_label = None input_query_bbox = None attn_mask = None mask_dict = None # 100*batch*256 if not input_query_bbox is None: input_query_label = input_query_label input_query_bbox = input_query_bbox return input_query_label, input_query_bbox, attn_mask, mask_dict def dn_post_process(self,outputs_class,outputs_coord,mask_dict,outputs_mask): """ post process of dn after output from the transformer put the dn part in the mask_dict """ assert mask_dict['pad_size'] > 0 output_known_class = outputs_class[:, :, :mask_dict['pad_size'], :] outputs_class = outputs_class[:, :, mask_dict['pad_size']:, :] output_known_coord = outputs_coord[:, :, :mask_dict['pad_size'], :] outputs_coord = outputs_coord[:, :, mask_dict['pad_size']:, :] output_known_mask = None if outputs_mask is not None: output_known_mask = 
outputs_mask[:, :, :mask_dict['pad_size'], :] outputs_mask = outputs_mask[:, :, mask_dict['pad_size']:, :] out = {'pred_logits': output_known_class[-1], 'pred_boxes': output_known_coord[-1],'pred_masks': None if output_known_mask is None else output_known_mask[-1]} out['aux_outputs'] = self._set_aux_loss(output_known_class, output_known_mask,output_known_coord) mask_dict['output_known_lbs_bboxes']=out return outputs_class, outputs_coord, outputs_mask def get_valid_ratio(self, mask): _, H, W = mask.shape valid_H = torch.sum(~mask[:, :, 0], 1) valid_W = torch.sum(~mask[:, 0, :], 1) valid_ratio_h = valid_H.float() / H valid_ratio_w = valid_W.float() / W valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1) return valid_ratio def pred_box(self, reference, hs, ref0=None): """ :param reference: reference box coordinates from each decoder layer :param hs: content :param ref0: whether there are prediction from the first layer """ if ref0 is None: outputs_coord_list = [] else: outputs_coord_list = [ref0] for dec_lid, (layer_ref_sig, layer_bbox_embed, layer_hs) in enumerate(zip(reference[:-1], self.bbox_embed, hs)): layer_delta_unsig = layer_bbox_embed(layer_hs) layer_outputs_unsig = layer_delta_unsig + inverse_sigmoid(layer_ref_sig) layer_outputs_unsig = layer_outputs_unsig.sigmoid() outputs_coord_list.append(layer_outputs_unsig) outputs_coord_list = torch.stack(outputs_coord_list) return outputs_coord_list def compute_similarity(self, v_emb,name='default'): v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7) t_emb = self.default_text_embeddings output = self.logit_scale.exp() * v_emb @ t_emb.transpose(1, 2) output[~self.default_text_embeddings_mask[:,None].repeat(1,output.shape[1],1)] = -100. 
# output = v_emb @ t_emb.unsqueeze(0).transpose(1, 2) return output def forward(self, x, mask_features, masks, targets=None, target_queries=None, target_vlp=None, task='seg',default_text_embeddings=None, extra={}): """ task: seg/det """ self.default_text_embeddings,self.default_text_embeddings_mask=default_text_embeddings self.dn="no" assert len(x) == self.num_feature_levels do_seg = (task != 'det') # if task is det, not do segmentation training size_list = [] # disable mask, it does not affect performance enable_mask = 0 if masks is not None: for src in x: if src.size(2) % 32 or src.size(3) % 32: enable_mask = 1 if enable_mask == 0: masks = [torch.zeros((src.size(0), src.size(2), src.size(3)), device=src.device, dtype=torch.bool) for src in x] src_flatten = [] mask_flatten = [] spatial_shapes = [] for i in range(self.num_feature_levels): idx=self.num_feature_levels-1-i bs, c , h, w=x[idx].shape size_list.append(x[i].shape[-2:]) spatial_shapes.append(x[idx].shape[-2:]) src_flatten.append(self.input_proj[idx](x[idx]).flatten(2).transpose(1, 2)) mask_flatten.append(masks[i].flatten(1)) src_flatten = torch.cat(src_flatten, 1) # bs, \sum{hxw}, c mask_flatten = torch.cat(mask_flatten, 1) # bs, \sum{hxw} spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device) level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1) predictions_class = [] predictions_mask = [] if self.two_stage: output_memory, output_proposals = gen_encoder_output_proposals(src_flatten, mask_flatten, spatial_shapes) output_memory = self.enc_output_norm(self.enc_output(output_memory)) output_memory_ = output_memory @ self.class_embed enc_outputs_class_unselected = self.compute_similarity(output_memory_,default_text_embeddings) enc_outputs_class_unselected[output_proposals.sum(-1).isinf()] = float("-inf") # enc_outputs_class_unselected = 
self.class_embed(output_memory) enc_outputs_coord_unselected = self._bbox_embed( output_memory) + output_proposals # (bs, \sum{hw}, 4) unsigmoid topk = self.num_queries topk_proposals = torch.topk(enc_outputs_class_unselected.max(-1)[0], topk, dim=1)[1] refpoint_embed_undetach = torch.gather(enc_outputs_coord_unselected, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)) # unsigmoid refpoint_embed = refpoint_embed_undetach.detach() tgt_undetach = torch.gather(output_memory, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, self.hidden_dim)) # unsigmoid outputs_class, outputs_mask = self.forward_prediction_heads(tgt_undetach.transpose(0, 1), mask_features, do_seg) tgt = tgt_undetach.detach() if self.learn_tgt: tgt = self.query_feat.weight[None].repeat(bs, 1, 1) interm_outputs=dict() interm_outputs['pred_logits'] = outputs_class interm_outputs['pred_boxes'] = refpoint_embed_undetach.sigmoid() interm_outputs['pred_masks'] = outputs_mask if self.initialize_box_type != 'no' and do_seg: # convert masks into boxes to better initialize box in the decoder assert self.initial_pred flaten_mask = outputs_mask.detach().flatten(0, 1) h, w = outputs_mask.shape[-2:] if self.initialize_box_type == 'bitmask': # slower, but more accurate refpoint_embed = BitMasks(flaten_mask > 0).get_bounding_boxes().tensor.cuda() elif self.initialize_box_type == 'mask2box': # faster conversion refpoint_embed = box_ops.masks_to_boxes(flaten_mask > 0).cuda() else: assert NotImplementedError refpoint_embed = box_ops.box_xyxy_to_cxcywh(refpoint_embed) / torch.as_tensor([w, h, w, h], dtype=torch.float).cuda() refpoint_embed = refpoint_embed.reshape(outputs_mask.shape[0], outputs_mask.shape[1], 4) refpoint_embed = inverse_sigmoid(refpoint_embed) elif not self.two_stage: tgt = self.query_feat.weight[None].repeat(bs, 1, 1) refpoint_embed = self.query_embed.weight[None].repeat(bs, 1, 1) tgt_mask = None mask_dict = None if self.dn != "no" and self.training: assert targets is not None input_query_label, 
input_query_bbox, tgt_mask, mask_dict = \ self.prepare_for_dn(targets, None, None, x[0].shape[0]) if mask_dict is not None: tgt=torch.cat([input_query_label, tgt],dim=1) # direct prediction from the matching and denoising part in the begining if self.initial_pred: outputs_class, outputs_mask = self.forward_prediction_heads(tgt.transpose(0, 1), mask_features, self.training and do_seg) predictions_class.append(outputs_class) predictions_mask.append(outputs_mask) if self.dn != "no" and self.training and mask_dict is not None: refpoint_embed=torch.cat([input_query_bbox,refpoint_embed],dim=1) # print('tgt',tgt.dtype) # print('src_flatten',src_flatten.dtype) # print('refpoint',refpoint_embed.dtype) tgt=tgt.to(src_flatten.dtype) refpoint_embed=refpoint_embed.to(src_flatten.dtype) hs, references = self.decoder( tgt=tgt.transpose(0, 1), memory=src_flatten.transpose(0, 1), memory_key_padding_mask=mask_flatten, pos=None, refpoints_unsigmoid=refpoint_embed.transpose(0, 1), level_start_index=level_start_index, spatial_shapes=spatial_shapes, valid_ratios=valid_ratios, tgt_mask=tgt_mask ) for i, output in enumerate(hs): outputs_class, outputs_mask = self.forward_prediction_heads(output.transpose(0, 1), mask_features, (self.training or (i == len(hs)-1)) and do_seg) predictions_class.append(outputs_class) predictions_mask.append(outputs_mask) # iteratively box prediction if self.initial_pred: out_boxes = self.pred_box(references, hs, refpoint_embed.sigmoid()) assert len(predictions_class) == self.num_layers + 1 else: out_boxes = self.pred_box(references, hs) if mask_dict is not None: predictions_mask = None if not do_seg else torch.stack(predictions_mask) predictions_class =torch.stack(predictions_class) predictions_class, out_boxes,predictions_mask=\ self.dn_post_process(predictions_class, out_boxes, mask_dict, predictions_mask) predictions_class = list(predictions_class) if predictions_mask is None: predictions_class[-1] = predictions_class[-1] + 0.0 * self.lang_mapper[0, 0] for 
i in range(self.mask_embed.num_layers): predictions_class[-1] = predictions_class[-1] + 0.0 * (self.mask_embed.layers[i].weight[0][0] + self.mask_embed.layers[i].bias[0]) # avoid no mask loss predictions_class[-1] = predictions_class[-1] + 0.0 * mask_features[0][0][0][0] # avoid no mask loss if do_seg: predictions_mask = list(predictions_mask) elif self.training: # this is to insure self.label_enc participate in the model predictions_class[-1] = predictions_class[-1] + 0.0 * self.lang_mapper[0, 0] for i in range(self.mask_embed.num_layers): predictions_class[-1] = predictions_class[-1] + 0.0 * ( self.mask_embed.layers[i].weight[0][0] + self.mask_embed.layers[i].bias[ 0]) # avoid no mask loss predictions_class[-1] = predictions_class[-1] + 0.0 * mask_features[0][0][0][0] # avoid no mask loss out = { 'pred_logits': predictions_class[-1], 'pred_masks': None if not do_seg else predictions_mask[-1], 'pred_boxes':out_boxes[-1], 'aux_outputs': self._set_aux_loss( predictions_class if self.mask_classification else None, predictions_mask,out_boxes ) } if self.two_stage: out['interm_outputs'] = interm_outputs return out, mask_dict def forward_prediction_heads(self, output, mask_features, pred_mask=True): decoder_output = self.decoder_norm(output) decoder_output = decoder_output.transpose(0, 1) class_embed = decoder_output @ self.class_embed outputs_class = self.compute_similarity(class_embed,self.default_text_embeddings) outputs_mask = None if pred_mask: mask_embed = self.mask_embed(decoder_output) outputs_mask = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features) return outputs_class, outputs_mask @torch.jit.unused def _set_aux_loss(self, outputs_class, outputs_seg_masks, out_boxes=None): # this is a workaround to make torchscript happy, as torchscript # doesn't support dictionary with non-homogeneous values, such # as a dict having both a Tensor and a list. 
# if self.mask_classification: if out_boxes is None: return [ {"pred_logits": a, "pred_masks": b} for a, b in zip(outputs_class[:-1], outputs_seg_masks[:-1]) ] elif outputs_seg_masks is None: return [ {"pred_logits": a, "pred_boxes": c} for a, c in zip(outputs_class[:-1], out_boxes[:-1]) ] else: return [ {"pred_logits": a, "pred_masks": b, "pred_boxes":c} for a, b, c in zip(outputs_class[:-1], outputs_seg_masks[:-1], out_boxes[:-1]) ] @register_decoder def get_maskdino_transformer_decoder(cfg, in_channels, mask_classification, extra): return OpenSeeDDecoder(cfg, in_channels, mask_classification, extra) ================================================ FILE: llava/model/openseed/body/decoder/openseed_decoder_decouple.py ================================================ # ------------------------------------------------------------------------ # DINO # Copyright (c) 2022 IDEA. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------ # Modified from Mask2Former https://github.com/facebookresearch/Mask2Former by Feng Li and Hao Zhang. 
import logging import fvcore.nn.weight_init as weight_init import torch from torch import nn from torch.nn import functional as F from detectron2.layers import Conv2d from detectron2.utils.registry import Registry from detectron2.structures import BitMasks from timm.models.layers import trunc_normal_ from .registry import register_decoder from .utils.dino_decoder import TransformerDecoder, DeformableTransformerDecoderLayer from .utils import MLP, gen_encoder_output_proposals, inverse_sigmoid from ...utils import box_ops from ...utils import configurable class MaskDINODecoder(nn.Module): @configurable def __init__( self, lang_encoder: nn.Module, in_channels, mask_classification=True, *, num_classes: int, hidden_dim: int, dim_proj: int, num_queries: int, nheads: int, dim_feedforward: int, dec_layers: int, mask_dim: int, enforce_input_project: bool, two_stage: bool, dn: str, noise_scale:float, dn_num:int, initialize_box_type:bool, initial_pred:bool, learn_tgt: bool, total_num_feature_levels: int = 4, dropout: float = 0.0, activation: str = 'relu', nhead: int = 8, dec_n_points: int = 4, return_intermediate_dec: bool = True, query_dim: int = 4, dec_layer_share: bool = False, semantic_ce_loss: bool = False, no_update=False, num_queries_stuff=100, num_queries_test=300, ): """ NOTE: this interface is experimental. 
Args: in_channels: channels of the input features mask_classification: whether to add mask classifier or not num_classes: number of classes hidden_dim: Transformer feature dimension num_queries: number of queries nheads: number of heads dim_feedforward: feature dimension in feedforward network enc_layers: number of Transformer encoder layers dec_layers: number of Transformer decoder layers pre_norm: whether to use pre-LayerNorm or not mask_dim: mask feature dimension enforce_input_project: add input project 1x1 conv even if input channels and hidden dim is identical d_model: transformer dimension dropout: dropout rate activation: activation function nhead: num heads in multi-head attention dec_n_points: number of sampling points in decoder return_intermediate_dec: return the intermediate results of decoder query_dim: 4 -> (x, y, w, h) dec_layer_share: whether to share each decoder layer semantic_ce_loss: use ce loss for semantic segmentation """ super().__init__() assert mask_classification, "Only support mask classification model" self.mask_classification = mask_classification self.num_feature_levels = total_num_feature_levels self.initial_pred = initial_pred # define Transformer decoder here self.dn=dn self.learn_tgt = learn_tgt self.noise_scale=noise_scale self.dn_num=dn_num self.num_heads = nheads self.num_layers = dec_layers self.two_stage=two_stage self.initialize_box_type = initialize_box_type self.total_num_feature_levels = total_num_feature_levels self.num_queries = num_queries self.num_queries_test=num_queries_test self.semantic_ce_loss = semantic_ce_loss self.no_update=no_update # learnable query features # if not two_stage or self.learn_tgt: self.num_queries_stuff=num_queries_stuff self.query_feat = nn.Embedding(num_queries_stuff, hidden_dim) # if not two_stage and initialize_box_type == 'no': self.query_embed = nn.Embedding(num_queries_stuff, 4) if two_stage: self.enc_output = nn.Linear(hidden_dim, hidden_dim) self.enc_output_norm = 
nn.LayerNorm(hidden_dim) self.input_proj = nn.ModuleList() for _ in range(self.num_feature_levels): if in_channels != hidden_dim or enforce_input_project: self.input_proj.append(Conv2d(in_channels, hidden_dim, kernel_size=1)) weight_init.c2_xavier_fill(self.input_proj[-1]) else: self.input_proj.append(nn.Sequential()) self.num_classes=num_classes # output FFNs assert self.mask_classification, "why not class embedding?" # self.label_enc=nn.Embedding(505, hidden_dim) # this is a hack for o365+coco (365+133=498) self.dim_proj = dim_proj self.lang_encoder = lang_encoder self.lang_mapper = nn.Parameter(torch.empty(dim_proj, hidden_dim)) trunc_normal_(self.lang_mapper, std=.02) self.class_embed = nn.Parameter(torch.empty(hidden_dim, dim_proj)) trunc_normal_(self.class_embed, std=.02) self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3) # init decoder self.decoder_norm = decoder_norm = nn.LayerNorm(hidden_dim) decoder_layer = DeformableTransformerDecoderLayer(hidden_dim, dim_feedforward, dropout, activation, self.num_feature_levels, nhead, dec_n_points) self.decoder = TransformerDecoder(decoder_layer, self.num_layers, decoder_norm, return_intermediate=return_intermediate_dec, d_model=hidden_dim, query_dim=query_dim, num_feature_levels=self.num_feature_levels, dec_layer_share=dec_layer_share, ) self.hidden_dim = hidden_dim self._bbox_embed = _bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3) nn.init.constant_(_bbox_embed.layers[-1].weight.data, 0) nn.init.constant_(_bbox_embed.layers[-1].bias.data, 0) box_embed_layerlist = [_bbox_embed for i in range(self.num_layers)] # share box prediction each layer self.bbox_embed = nn.ModuleList(box_embed_layerlist) self.decoder.bbox_embed = self.bbox_embed @classmethod def from_config(cls, cfg, in_channels, lang_encoder, mask_classification, extra): ret = {} ret["in_channels"] = in_channels ret["lang_encoder"] = lang_encoder ret["mask_classification"] = mask_classification enc_cfg = cfg['MODEL']['ENCODER'] dec_cfg = 
cfg['MODEL']['DECODER'] ret["num_classes"] = enc_cfg['NUM_CLASSES'] ret["hidden_dim"] = dec_cfg['HIDDEN_DIM'] ret["dim_proj"] = cfg['MODEL']['DIM_PROJ'] ret["num_queries"] = dec_cfg['NUM_OBJECT_QUERIES'] ret["num_queries_test"] = dec_cfg.get('NUM_OBJECT_QUERIES_TEST',300) # Transformer parameters: ret["nheads"] = dec_cfg['NHEADS'] ret["dim_feedforward"] = dec_cfg['DIM_FEEDFORWARD'] ret["dec_layers"] = dec_cfg['DEC_LAYERS'] ret["enforce_input_project"] = dec_cfg['ENFORCE_INPUT_PROJ'] ret["mask_dim"] = enc_cfg['MASK_DIM'] ret["two_stage"] = dec_cfg['TWO_STAGE'] ret["initialize_box_type"] = dec_cfg['INITIALIZE_BOX_TYPE'] # ['no', 'bitmask', 'mask2box'] ret["dn"] = dec_cfg['DN'] ret["noise_scale"] = dec_cfg['DN_NOISE_SCALE'] ret["dn_num"] = dec_cfg['DN_NUM'] ret["initial_pred"] = dec_cfg['INITIAL_PRED'] ret["learn_tgt"] = dec_cfg['LEARN_TGT'] ret["total_num_feature_levels"] = dec_cfg['TOTAL_NUM_FEATURE_LEVELS'] ret["semantic_ce_loss"] = dec_cfg['TEST']['SEMANTIC_ON'] and dec_cfg['SEMANTIC_CE_LOSS'] and not dec_cfg['TEST']['PANOPTIC_ON'] ret["no_update"]=dec_cfg.get("no_update",False) return ret def prepare_for_dn(self, targets, tgt, refpoint_emb, batch_size,task="other"): """ modified from dn-detr. 
You can refer to dn-detr https://github.com/IDEA-Research/DN-DETR/blob/main/models/dn_dab_deformable_detr/dn_components.py for more details :param dn_args: scalar, noise_scale :param tgt: original tgt (content) in the matching part :param refpoint_emb: positional anchor queries in the matching part :param batch_size: bs """ if self.training: scalar, noise_scale = self.dn_num, self.noise_scale known = [(torch.ones_like(t['labels'])).cuda() for t in targets] know_idx = [torch.nonzero(t) for t in known] known_num = [sum(k) for k in known] # use fix number of dn queries if max(known_num) > 0: scalar = scalar // (int(max(known_num))) else: scalar = 0 if task=="cls": scalar=1 if scalar == 0: input_query_label = None input_query_bbox = None attn_mask = None mask_dict = None return input_query_label, input_query_bbox, attn_mask, mask_dict # can be modified to selectively denosie some label or boxes; also known label prediction unmask_bbox = unmask_label = torch.cat(known) labels = torch.cat([t['labels'] for t in targets]) # use languge as denosing content queries. 
# if task == 'det': # labels = labels # o365 start from 133 class boxes = torch.cat([t['boxes'] for t in targets]) batch_idx = torch.cat([torch.full_like(t['labels'].long(), i) for i, t in enumerate(targets)]) # known known_indice = torch.nonzero(unmask_label + unmask_bbox) known_indice = known_indice.view(-1) # noise known_indice = known_indice.repeat(scalar, 1).view(-1) known_labels = labels.repeat(scalar, 1).view(-1) known_bid = batch_idx.repeat(scalar, 1).view(-1) known_bboxs = boxes.repeat(scalar, 1) known_labels_expaned = known_labels.clone() known_bbox_expand = known_bboxs.clone() if noise_scale > 0: diff = torch.zeros_like(known_bbox_expand) diff[:, :2] = known_bbox_expand[:, 2:] / 2 diff[:, 2:] = known_bbox_expand[:, 2:] known_bbox_expand += torch.mul((torch.rand_like(known_bbox_expand) * 2 - 1.0), diff).cuda() * noise_scale known_bbox_expand = known_bbox_expand.clamp(min=0.0, max=1.0) if task=="cls": known_labels_expaned=torch.zeros_like(known_labels_expaned) m = known_labels_expaned.long().to('cuda') # import ipdb; ipdb.set_trace() if task=="cls": input_label_embed=self.cls_emb(m) else: input_label_embed = torch.gather(self.lang_encoder.default_text_embeddings, 0, m[:, None].repeat(1, self.dim_proj)) @ self.lang_mapper input_bbox_embed = inverse_sigmoid(known_bbox_expand) single_pad = int(max(known_num)) pad_size = int(single_pad * scalar) padding_label = input_label_embed.new_zeros(pad_size, self.hidden_dim) padding_bbox = input_bbox_embed.new_zeros(pad_size, 4) if not refpoint_emb is None: input_query_label = torch.cat([padding_label, tgt], dim=0).repeat(batch_size, 1, 1) input_query_bbox = torch.cat([padding_bbox, refpoint_emb], dim=0).repeat(batch_size, 1, 1) else: input_query_label = padding_label.repeat(batch_size, 1, 1) input_query_bbox = padding_bbox.repeat(batch_size, 1, 1) # map map_known_indice = input_label_embed.new_tensor([]) if len(known_num): map_known_indice = torch.cat( [input_label_embed.new_tensor(range(num)) for num in known_num]) # 
[1,2, 1,2,3] map_known_indice = torch.cat([map_known_indice + single_pad * i for i in range(scalar)]).long() if len(known_bid): input_query_label[(known_bid.long(), map_known_indice)] = input_label_embed input_query_bbox[(known_bid.long(), map_known_indice)] = input_bbox_embed tgt_size = pad_size + self.num_queries+self.num_queries_stuff attn_mask = input_label_embed.new_ones(tgt_size, tgt_size) < 0 # match query cannot see the reconstruct attn_mask[pad_size:, :pad_size] = True # reconstruct cannot see each other for i in range(scalar): if i == 0: attn_mask[single_pad * i:single_pad * (i + 1), single_pad * (i + 1):pad_size] = True if i == scalar - 1: attn_mask[single_pad * i:single_pad * (i + 1), :single_pad * i] = True else: attn_mask[single_pad * i:single_pad * (i + 1), single_pad * (i + 1):pad_size] = True attn_mask[single_pad * i:single_pad * (i + 1), :single_pad * i] = True mask_dict = { 'known_indice': torch.as_tensor(known_indice).long(), 'batch_idx': torch.as_tensor(batch_idx).long(), 'map_known_indice': torch.as_tensor(map_known_indice).long(), 'known_lbs_bboxes': (known_labels, known_bboxs), 'know_idx': know_idx, 'pad_size': pad_size, 'scalar': scalar, } else: if not refpoint_emb is None: input_query_label = tgt.repeat(batch_size, 1, 1) input_query_bbox = refpoint_emb.repeat(batch_size, 1, 1) else: input_query_label = None input_query_bbox = None attn_mask = None mask_dict = None # 100*batch*256 if not input_query_bbox is None: input_query_label = input_query_label input_query_bbox = input_query_bbox return input_query_label, input_query_bbox, attn_mask, mask_dict def dn_post_process(self,outputs_class,outputs_coord,mask_dict,outputs_mask): """ post process of dn after output from the transformer put the dn part in the mask_dict """ assert mask_dict['pad_size'] > 0 output_known_class = outputs_class[:, :, :mask_dict['pad_size'], :] outputs_class = outputs_class[:, :, mask_dict['pad_size']:, :] output_known_coord = outputs_coord[:, :, 
:mask_dict['pad_size'], :] outputs_coord = outputs_coord[:, :, mask_dict['pad_size']:, :] output_known_mask = None if outputs_mask is not None: output_known_mask = outputs_mask[:, :, :mask_dict['pad_size'], :] outputs_mask = outputs_mask[:, :, mask_dict['pad_size']:, :] out = {'pred_logits': output_known_class[-1], 'pred_boxes': output_known_coord[-1],'pred_masks': None if output_known_mask is None else output_known_mask[-1]} out['aux_outputs'] = self._set_aux_loss(output_known_class, output_known_mask,output_known_coord) mask_dict['output_known_lbs_bboxes']=out return outputs_class, outputs_coord, outputs_mask def get_valid_ratio(self, mask): _, H, W = mask.shape valid_H = torch.sum(~mask[:, :, 0], 1) valid_W = torch.sum(~mask[:, 0, :], 1) valid_ratio_h = valid_H.float() / H valid_ratio_w = valid_W.float() / W valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1) return valid_ratio def pred_box(self, reference, hs, ref0=None): """ :param reference: reference box coordinates from each decoder layer :param hs: content :param ref0: whether there are prediction from the first layer """ if ref0 is None: outputs_coord_list = [] else: outputs_coord_list = [ref0] for dec_lid, (layer_ref_sig, layer_bbox_embed, layer_hs) in enumerate(zip(reference[:-1], self.bbox_embed, hs)): layer_delta_unsig = layer_bbox_embed(layer_hs) layer_outputs_unsig = layer_delta_unsig + inverse_sigmoid(layer_ref_sig) layer_outputs_unsig = layer_outputs_unsig.sigmoid() outputs_coord_list.append(layer_outputs_unsig) outputs_coord_list = torch.stack(outputs_coord_list) return outputs_coord_list def forward_cls(self, x, mask_features, masks, targets=None, target_queries=None, target_vlp=None, extra={}): """ task: seg/det """ assert len(x) == self.num_feature_levels do_seg = False# if task is det, not do segmentation training size_list = [] # disable mask, it does not affect performance enable_mask = 0 if masks is not None: for src in x: if src.size(2) % 32 or src.size(3) % 32: enable_mask = 1 
if enable_mask == 0: masks = [torch.zeros((src.size(0), src.size(2), src.size(3)), device=src.device, dtype=torch.bool) for src in x] src_flatten = [] mask_flatten = [] spatial_shapes = [] for i in range(self.num_feature_levels): idx = self.num_feature_levels - 1 - i bs, c, h, w = x[idx].shape size_list.append(x[i].shape[-2:]) spatial_shapes.append(x[idx].shape[-2:]) src_flatten.append(self.input_proj[idx](x[idx]).flatten(2).transpose(1, 2)) mask_flatten.append(masks[i].flatten(1)) src_flatten = torch.cat(src_flatten, 1) # bs, \sum{hxw}, c mask_flatten = torch.cat(mask_flatten, 1) # bs, \sum{hxw} spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device) level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1) predictions_class = [] predictions_mask = [] # if self.two_stage: # output_memory, output_proposals = gen_encoder_output_proposals(src_flatten, mask_flatten, spatial_shapes) # output_memory = self.enc_output_norm(self.enc_output(output_memory)) # output_memory_ = output_memory @ self.class_embed # enc_outputs_class_unselected = self.lang_encoder.compute_similarity(output_memory_) # enc_outputs_class_unselected[output_proposals.sum(-1).isinf()] = float("-inf") # # enc_outputs_class_unselected = self.class_embed(output_memory) # enc_outputs_coord_unselected = self._bbox_embed( # output_memory) + output_proposals # (bs, \sum{hw}, 4) unsigmoid # topk = self.num_queries # topk_proposals = torch.topk(enc_outputs_class_unselected.max(-1)[0], topk, dim=1)[1] # refpoint_embed_undetach = torch.gather(enc_outputs_coord_unselected, 1, # topk_proposals.unsqueeze(-1).repeat(1, 1, 4)) # unsigmoid # refpoint_embed = refpoint_embed_undetach.detach() # # tgt_undetach = torch.gather(output_memory, 1, # topk_proposals.unsqueeze(-1).repeat(1, 1, self.hidden_dim)) # unsigmoid # outputs_class, outputs_mask = 
self.forward_prediction_heads(tgt_undetach.transpose(0, 1), mask_features, # do_seg) # tgt = tgt_undetach.detach() # if self.learn_tgt: # tgt = self.query_feat.weight[None].repeat(bs, 1, 1) # interm_outputs = dict() # interm_outputs['pred_logits'] = outputs_class # interm_outputs['pred_boxes'] = refpoint_embed_undetach.sigmoid() # interm_outputs['pred_masks'] = outputs_mask # # if self.initialize_box_type != 'no' and do_seg: # # convert masks into boxes to better initialize box in the decoder # assert self.initial_pred # flaten_mask = outputs_mask.detach().flatten(0, 1) # h, w = outputs_mask.shape[-2:] # if self.initialize_box_type == 'bitmask': # slower, but more accurate # refpoint_embed = BitMasks(flaten_mask > 0).get_bounding_boxes().tensor.cuda() # elif self.initialize_box_type == 'mask2box': # faster conversion # refpoint_embed = box_ops.masks_to_boxes(flaten_mask > 0).cuda() # else: # assert NotImplementedError # refpoint_embed = box_ops.box_xyxy_to_cxcywh(refpoint_embed) / torch.as_tensor([w, h, w, h], # dtype=torch.float).cuda() # refpoint_embed = refpoint_embed.reshape(outputs_mask.shape[0], outputs_mask.shape[1], 4) # refpoint_embed = inverse_sigmoid(refpoint_embed) # elif not self.two_stage: # tgt = self.query_feat.weight[None].repeat(bs, 1, 1) # refpoint_embed = self.query_embed.weight[None].repeat(bs, 1, 1) tgt_mask = None mask_dict = None # if self.dn != "no" and self.training: assert targets is not None input_query_label, input_query_bbox, tgt_mask, mask_dict = \ self.prepare_for_dn(targets, None, None, x[0].shape[0],task="cls") # if mask_dict is not None: tgt = input_query_label refpoint_embed = input_query_bbox # direct prediction from the matching and denoising part in the begining if self.initial_pred: outputs_class, outputs_mask = self.forward_prediction_heads(tgt.transpose(0, 1), mask_features, self.training and do_seg) predictions_class.append(outputs_class) predictions_mask.append(outputs_mask) # if self.dn != "no" and self.training and 
mask_dict is not None: # tgt=tgt.float() hs, references = self.decoder( tgt=tgt.transpose(0, 1), memory=src_flatten.transpose(0, 1), memory_key_padding_mask=mask_flatten, pos=None, refpoints_unsigmoid=refpoint_embed.transpose(0, 1), level_start_index=level_start_index, spatial_shapes=spatial_shapes, valid_ratios=valid_ratios, tgt_mask=None, no_update=self.no_update, ) for i, output in enumerate(hs): outputs_class, outputs_mask = self.forward_prediction_heads(output.transpose(0, 1), mask_features, ( self.training or (i == len(hs) - 1)) and do_seg) predictions_class.append(outputs_class) predictions_mask.append(outputs_mask) # iteratively box prediction # if self.initial_pred: # out_boxes = self.pred_box(references, hs, refpoint_embed.sigmoid()) # assert len(predictions_class) == self.num_layers + 1 # else: # out_boxes = self.pred_box(references, hs) if mask_dict is not None: predictions_mask = None if not do_seg else torch.stack(predictions_mask) predictions_class = torch.stack(predictions_class) # predictions_class, out_boxes, predictions_mask = \ # self.dn_post_process(predictions_class, out_boxes, mask_dict, predictions_mask) predictions_class = list(predictions_class) if predictions_mask is None: predictions_class[-1] = predictions_class[-1] + 0.0 * self.lang_mapper[0, 0] for i in range(self.mask_embed.num_layers): predictions_class[-1] = predictions_class[-1] + 0.0 * ( self.mask_embed.layers[i].weight[0][0] + self.mask_embed.layers[i].bias[ 0]) # avoid no mask loss predictions_class[-1] = predictions_class[-1] + 0.0 * mask_features[0][0][0][0] # avoid no mask loss # if do_seg: # predictions_mask = list(predictions_mask) elif self.training: # this is to insure self.label_enc participate in the model predictions_class[-1] = predictions_class[-1] + 0.0 * self.lang_mapper[0, 0] for i in range(self.mask_embed.num_layers): predictions_class[-1] = predictions_class[-1] + 0.0 * ( self.mask_embed.layers[i].weight[0][0] + self.mask_embed.layers[i].bias[ 0]) # avoid no 
mask loss predictions_class[-1] = predictions_class[-1] + 0.0 * mask_features[0][0][0][0] # avoid no mask loss out = { 'pred_logits': predictions_class[-1], 'pred_masks': None if not do_seg else predictions_mask[-1], # 'pred_boxes': out_boxes[-1], 'aux_outputs': self._set_aux_loss( predictions_class if self.mask_classification else None, predictions_mask ) } # if self.two_stage: # out['interm_outputs'] = interm_outputs return out, None def forward(self, x, mask_features, masks, targets=None, target_queries=None, target_vlp=None, task='seg', extra={}): """ task: seg/det """ assert len(x) == self.num_feature_levels do_seg = (task != 'det') # if task is det, not do segmentation training size_list = [] # disable mask, it does not affect performance enable_mask = 0 if masks is not None: for src in x: if src.size(2) % 32 or src.size(3) % 32: enable_mask = 1 if enable_mask == 0: masks = [torch.zeros((src.size(0), src.size(2), src.size(3)), device=src.device, dtype=torch.bool) for src in x] src_flatten = [] mask_flatten = [] spatial_shapes = [] for i in range(self.num_feature_levels): idx=self.num_feature_levels-1-i bs, c , h, w=x[idx].shape size_list.append(x[i].shape[-2:]) spatial_shapes.append(x[idx].shape[-2:]) src_flatten.append(self.input_proj[idx](x[idx]).flatten(2).transpose(1, 2)) mask_flatten.append(masks[i].flatten(1)) src_flatten = torch.cat(src_flatten, 1) # bs, \sum{hxw}, c mask_flatten = torch.cat(mask_flatten, 1) # bs, \sum{hxw} spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device) level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1) predictions_class = [] predictions_mask = [] # if self.two_stage: output_memory, output_proposals = gen_encoder_output_proposals(src_flatten, mask_flatten, spatial_shapes) output_memory = self.enc_output_norm(self.enc_output(output_memory)) output_memory_ = 
output_memory @ self.class_embed enc_outputs_class_unselected = self.lang_encoder.compute_similarity(output_memory_) enc_outputs_class_unselected[output_proposals.sum(-1).isinf()] = float("-inf") # enc_outputs_class_unselected = self.class_embed(output_memory) enc_outputs_coord_unselected = self._bbox_embed( output_memory) + output_proposals # (bs, \sum{hw}, 4) unsigmoid topk = self.num_queries if self.training else self.num_queries_test topk_proposals = torch.topk(enc_outputs_class_unselected.max(-1)[0], topk, dim=1)[1] refpoint_embed_undetach = torch.gather(enc_outputs_coord_unselected, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)) # unsigmoid tgt_undetach = torch.gather(output_memory, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, self.hidden_dim)) # unsigmoid tgt_stuff = self.query_feat.weight[None].repeat(bs, 1, 1) refpoint_embed_stuff = self.query_embed.weight[None].repeat(bs, 1, 1) # if not (self.) tgt_undetach=torch.cat([tgt_undetach,tgt_stuff],dim=1) refpoint_embed_undetach=torch.cat([refpoint_embed_undetach,refpoint_embed_stuff],dim=1) refpoint_embed = refpoint_embed_undetach.detach() outputs_class, outputs_mask = self.forward_prediction_heads(tgt_undetach.transpose(0, 1), mask_features, do_seg) tgt = tgt_undetach.detach() if self.learn_tgt: tgt = self.query_feat.weight[None].repeat(bs, 1, 1) interm_outputs=dict() interm_outputs['pred_logits'] = outputs_class interm_outputs['pred_boxes'] = refpoint_embed_undetach.sigmoid() interm_outputs['pred_masks'] = outputs_mask if self.initialize_box_type != 'no' and do_seg: # convert masks into boxes to better initialize box in the decoder assert self.initial_pred flaten_mask = outputs_mask.detach().flatten(0, 1) h, w = outputs_mask.shape[-2:] if self.initialize_box_type == 'bitmask': # slower, but more accurate refpoint_embed = BitMasks(flaten_mask > 0).get_bounding_boxes().tensor.cuda() elif self.initialize_box_type == 'mask2box': # faster conversion refpoint_embed = box_ops.masks_to_boxes(flaten_mask > 0).cuda() 
else: assert NotImplementedError refpoint_embed = box_ops.box_xyxy_to_cxcywh(refpoint_embed) / torch.as_tensor([w, h, w, h], dtype=torch.float).cuda() refpoint_embed = refpoint_embed.reshape(outputs_mask.shape[0], outputs_mask.shape[1], 4) refpoint_embed = inverse_sigmoid(refpoint_embed) # elif not self.two_stage: tgt_mask = None mask_dict = None if self.dn != "no" and self.training: assert targets is not None input_query_label, input_query_bbox, tgt_mask, mask_dict = \ self.prepare_for_dn(targets, None, None, x[0].shape[0]) if mask_dict is not None: tgt=torch.cat([input_query_label, tgt],dim=1) # direct prediction from the matching and denoising part in the begining if self.initial_pred: outputs_class, outputs_mask = self.forward_prediction_heads(tgt.transpose(0, 1), mask_features, self.training and do_seg) if not (task == 'seg' or not self.training): outputs_class=outputs_class[:,:-self.num_queries_stuff] # outputs_mask=outputs_mask[:,:-self.num_queries_stuff] predictions_class.append(outputs_class) predictions_mask.append(outputs_mask) if self.dn != "no" and self.training and mask_dict is not None: refpoint_embed=torch.cat([input_query_bbox,refpoint_embed],dim=1) hs, references = self.decoder( tgt=tgt.transpose(0, 1), memory=src_flatten.transpose(0, 1), memory_key_padding_mask=mask_flatten, pos=None, refpoints_unsigmoid=refpoint_embed.transpose(0, 1), level_start_index=level_start_index, spatial_shapes=spatial_shapes, valid_ratios=valid_ratios, tgt_mask=tgt_mask ) if not (task=='seg' or not self.training): hs=[hs_[:,:-self.num_queries_stuff] for hs_ in hs] references=[references_[:,:-self.num_queries_stuff] for references_ in references] refpoint_embed=refpoint_embed[:,:-self.num_queries_stuff] for i, output in enumerate(hs): outputs_class, outputs_mask = self.forward_prediction_heads(output.transpose(0, 1), mask_features, (self.training or (i == len(hs)-1)) and do_seg) predictions_class.append(outputs_class) predictions_mask.append(outputs_mask) # iteratively 
box prediction if self.initial_pred: out_boxes = self.pred_box(references, hs, refpoint_embed.sigmoid()) assert len(predictions_class) == self.num_layers + 1 else: out_boxes = self.pred_box(references, hs) if mask_dict is not None: predictions_mask = None if not do_seg else torch.stack(predictions_mask) predictions_class =torch.stack(predictions_class) predictions_class, out_boxes,predictions_mask=\ self.dn_post_process(predictions_class, out_boxes, mask_dict, predictions_mask) predictions_class = list(predictions_class) if predictions_mask is None: predictions_class[-1] = predictions_class[-1] + 0.0 * self.lang_mapper[0, 0] for i in range(self.mask_embed.num_layers): predictions_class[-1] = predictions_class[-1] + 0.0 * (self.mask_embed.layers[i].weight[0][0] + self.mask_embed.layers[i].bias[0]) # avoid no mask loss predictions_class[-1] = predictions_class[-1] + 0.0 * mask_features[0][0][0][0] # avoid no mask loss if do_seg: predictions_mask = list(predictions_mask) elif self.training: # this is to insure self.label_enc participate in the model predictions_class[-1] = predictions_class[-1] + 0.0 * self.lang_mapper[0, 0] for i in range(self.mask_embed.num_layers): predictions_class[-1] = predictions_class[-1] + 0.0 * ( self.mask_embed.layers[i].weight[0][0] + self.mask_embed.layers[i].bias[ 0]) # avoid no mask loss predictions_class[-1] = predictions_class[-1] + 0.0 * mask_features[0][0][0][0] # avoid no mask loss out = { 'pred_logits': predictions_class[-1], 'pred_masks': None if not do_seg else predictions_mask[-1], 'pred_boxes':out_boxes[-1], 'aux_outputs': self._set_aux_loss( predictions_class if self.mask_classification else None, predictions_mask,out_boxes ) } if self.two_stage: out['interm_outputs'] = interm_outputs return out, mask_dict def forward_prediction_heads(self, output, mask_features, pred_mask=True): decoder_output = self.decoder_norm(output) decoder_output = decoder_output.transpose(0, 1) class_embed = decoder_output @ self.class_embed 
outputs_class = self.lang_encoder.compute_similarity(class_embed) outputs_mask = None if pred_mask: mask_embed = self.mask_embed(decoder_output) outputs_mask = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features) return outputs_class, outputs_mask @torch.jit.unused def _set_aux_loss(self, outputs_class, outputs_seg_masks, out_boxes=None): # this is a workaround to make torchscript happy, as torchscript # doesn't support dictionary with non-homogeneous values, such # as a dict having both a Tensor and a list. # if self.mask_classification: if out_boxes is None: if outputs_seg_masks is None: return [ {"pred_logits": a} for a in outputs_class[:-1] ] else: return [ {"pred_logits": a, "pred_masks": b} for a, b in zip(outputs_class[:-1], outputs_seg_masks[:-1]) ] elif outputs_seg_masks is None: return [ {"pred_logits": a, "pred_boxes": c} for a, c in zip(outputs_class[:-1], out_boxes[:-1]) ] else: return [ {"pred_logits": a, "pred_masks": b, "pred_boxes":c} for a, b, c in zip(outputs_class[:-1], outputs_seg_masks[:-1], out_boxes[:-1]) ] @register_decoder def get_maskdino_transformer_decoder(cfg, in_channels, lang_encoder, mask_classification, extra): return MaskDINODecoder(cfg, in_channels, lang_encoder, mask_classification, extra) ================================================ FILE: llava/model/openseed/body/decoder/registry.py ================================================ _model_entrypoints = {} def register_decoder(fn): module_name_split = fn.__module__.split('.') model_name = module_name_split[-1] _model_entrypoints[model_name] = fn return fn def model_entrypoints(model_name): return _model_entrypoints[model_name] def is_model(model_name): return model_name in _model_entrypoints ================================================ FILE: llava/model/openseed/body/decoder/utils/__init__.py ================================================ from .utils import * ================================================ FILE: 
llava/model/openseed/body/decoder/utils/dino_decoder.py ================================================ # ------------------------------------------------------------------------ # DINO # Copyright (c) 2022 IDEA. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------ # Modified from DINO https://github.com/IDEA-Research/DINO by Feng Li and Hao Zhang. # ------------------------------------------------------------------------ from typing import Optional, List, Union import torch from torch import nn, Tensor from torch.cuda.amp import autocast from .utils import MLP, _get_clones, _get_activation_fn, gen_sineembed_for_position, inverse_sigmoid from ...encoder.ops.modules import MSDeformAttn class TransformerDecoder(nn.Module): def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False, d_model=256, query_dim=4, modulate_hw_attn=True, num_feature_levels=1, deformable_decoder=True, decoder_query_perturber=None, dec_layer_number=None, # number of queries each layer in decoder rm_dec_query_scale=True, dec_layer_share=False, dec_layer_dropout_prob=None, task_switch=None, ): super().__init__() if num_layers > 0: self.layers = _get_clones(decoder_layer, num_layers, layer_share=dec_layer_share) else: self.layers = [] self.num_layers = num_layers self.norm = norm self.return_intermediate = return_intermediate assert return_intermediate, "support return_intermediate only" self.query_dim = query_dim assert query_dim in [2, 4], "query_dim should be 2/4 but {}".format(query_dim) self.num_feature_levels = num_feature_levels self.ref_point_head = MLP(query_dim // 2 * d_model, d_model, d_model, 2) if not deformable_decoder: self.query_pos_sine_scale = MLP(d_model, d_model, d_model, 2) else: self.query_pos_sine_scale = None if rm_dec_query_scale: self.query_scale = None else: raise NotImplementedError self.query_scale = MLP(d_model, d_model, 
d_model, 2) self.bbox_embed = None self.class_embed = None self.d_model = d_model self.modulate_hw_attn = modulate_hw_attn self.deformable_decoder = deformable_decoder if not deformable_decoder and modulate_hw_attn: self.ref_anchor_head = MLP(d_model, d_model, 2, 2) else: self.ref_anchor_head = None self.decoder_query_perturber = decoder_query_perturber self.box_pred_damping = None self.dec_layer_number = dec_layer_number if dec_layer_number is not None: assert isinstance(dec_layer_number, list) assert len(dec_layer_number) == num_layers # assert dec_layer_number[0] == self.dec_layer_dropout_prob = dec_layer_dropout_prob if dec_layer_dropout_prob is not None: assert isinstance(dec_layer_dropout_prob, list) assert len(dec_layer_dropout_prob) == num_layers for i in dec_layer_dropout_prob: assert 0.0 <= i <= 1.0 self.task_switch = task_switch self._reset_parameters() def _reset_parameters(self): for p in self.parameters(): if p.dim() > 1: nn.init.xavier_uniform_(p) for m in self.modules(): if isinstance(m, MSDeformAttn): m._reset_parameters() def forward(self, tgt, memory, tgt_mask: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None, refpoints_unsigmoid: Optional[Tensor] = None, # num_queries, bs, 2 # for memory level_start_index: Optional[Tensor] = None, # num_levels spatial_shapes: Optional[Tensor] = None, # bs, num_levels, 2 valid_ratios: Optional[Tensor] = None, # misc extra: Optional[Tensor] = {}, # extra information ): """ Input: - tgt: nq, bs, d_model - memory: hw, bs, d_model - pos: hw, bs, d_model - refpoints_unsigmoid: nq, bs, 2/4 - valid_ratios/spatial_shapes: bs, nlevel, 2 """ output = tgt intermediate = [] reference_points = refpoints_unsigmoid.sigmoid() ref_points = [reference_points] if 'lang_refpoint_embed' in extra.keys() and 'grounding_tokens' in extra.keys(): reference_points = torch.cat((reference_points, 
extra['lang_refpoint_embed'].transpose(0,1).sigmoid()), dim=0) output = torch.cat((output, extra['grounding_tokens']), dim=0) for layer_id, layer in enumerate(self.layers): # preprocess ref points if self.training and self.decoder_query_perturber is not None and layer_id != 0: reference_points = self.decoder_query_perturber(reference_points) reference_points_input = reference_points[:, :, None] \ * torch.cat([valid_ratios, valid_ratios], -1)[None, :] # nq, bs, nlevel, 4 # print('reference_points_input', reference_points_input.dtype) # print('memory', memory.dtype) reference_points_input=reference_points_input.to(memory.dtype) query_sine_embed = gen_sineembed_for_position(reference_points_input[:, :, 0, :], dim=output.shape[-1]//2) # nq, bs, 256*2 # print('query_sine_embed', query_sine_embed.dtype) raw_query_pos = self.ref_point_head(query_sine_embed) # nq, bs, 256 pos_scale = self.query_scale(output) if self.query_scale is not None else 1 query_pos = pos_scale * raw_query_pos output = layer( tgt=output, tgt_query_pos=query_pos, tgt_query_sine_embed=query_sine_embed, tgt_key_padding_mask=tgt_key_padding_mask, tgt_reference_points=reference_points_input, memory=memory, memory_key_padding_mask=memory_key_padding_mask, memory_level_start_index=level_start_index, memory_spatial_shapes=spatial_shapes, memory_pos=pos, self_attn_mask=tgt_mask, cross_attn_mask=memory_mask, task_switch=self.task_switch, extra=extra, ) # grounding language token reference point will not update and saved if (self.task_switch is not None) and (extra is not None) and (self.task_switch['grounding']) and ('grounding_len' in extra) and extra['task']=='seg': _reference_points = reference_points[-extra['grounding_len']:] reference_points = reference_points[:-extra['grounding_len']] _output = output[-extra['grounding_len']:] output = output[:-extra['grounding_len']] # iter update if self.bbox_embed is not None: reference_before_sigmoid = inverse_sigmoid(reference_points) delta_unsig = 
self.bbox_embed[layer_id](output) outputs_unsig = delta_unsig + reference_before_sigmoid new_reference_points = outputs_unsig.sigmoid() reference_points = new_reference_points.detach() # if layer_id != self.num_layers - 1: ref_points.append(new_reference_points) intermediate.append(self.norm(output)) # add back grounding language token if (self.task_switch is not None) and (extra is not None) and (self.task_switch['grounding']) and ('grounding_len' in extra) and extra['task']=='seg': reference_points = torch.cat((reference_points, _reference_points)) output = torch.cat((output, _output)) return [ [itm_out.transpose(0, 1) for itm_out in intermediate], [itm_refpoint.transpose(0, 1) for itm_refpoint in ref_points] ] class DeformableTransformerDecoderLayer(nn.Module): def __init__(self, d_model=256, d_ffn=1024, dropout=0.1, activation="relu", n_levels=4, n_heads=8, n_points=4, use_deformable_box_attn=False, key_aware_type=None, ): super().__init__() # cross attention if use_deformable_box_attn: raise NotImplementedError else: self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points) self.dropout1 = nn.Dropout(dropout) self.norm1 = nn.LayerNorm(d_model) # self attention self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout) self.dropout2 = nn.Dropout(dropout) self.norm2 = nn.LayerNorm(d_model) # ffn self.linear1 = nn.Linear(d_model, d_ffn) self.activation = _get_activation_fn(activation) self.dropout3 = nn.Dropout(dropout) self.linear2 = nn.Linear(d_ffn, d_model) self.dropout4 = nn.Dropout(dropout) self.norm3 = nn.LayerNorm(d_model) self.key_aware_type = key_aware_type self.key_aware_proj = None def rm_self_attn_modules(self): self.self_attn = None self.dropout2 = None self.norm2 = None @staticmethod def with_pos_embed(tensor, pos): return tensor if pos is None else tensor + pos def forward_ffn(self, tgt): tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt)))) tgt = tgt + self.dropout4(tgt2) tgt = self.norm3(tgt) return tgt 
@autocast(enabled=False) def forward(self, # for tgt tgt: Optional[Tensor], # nq, bs, d_model tgt_query_pos: Optional[Tensor] = None, # pos for query. MLP(Sine(pos)) tgt_query_sine_embed: Optional[Tensor] = None, # pos for query. Sine(pos) tgt_key_padding_mask: Optional[Tensor] = None, tgt_reference_points: Optional[Tensor] = None, # nq, bs, 4 # for memory memory: Optional[Tensor] = None, # hw, bs, d_model memory_key_padding_mask: Optional[Tensor] = None, memory_level_start_index: Optional[Tensor] = None, # num_levels memory_spatial_shapes: Optional[Tensor] = None, # bs, num_levels, 2 memory_pos: Optional[Tensor] = None, # pos for memory # sa self_attn_mask: Optional[Tensor] = None, # mask used for self-attention cross_attn_mask: Optional[Tensor] = None, # mask used for cross-attention # misc task_switch: Optional[Tensor] = {}, # extra information extra: Optional[Tensor] = {}, # extra information ): """ Input: - tgt/tgt_query_pos: nq, bs, d_model - """ # self attention if self.self_attn is not None: q = k = self.with_pos_embed(tgt, tgt_query_pos) tgt2 = self.self_attn(q, k, tgt, attn_mask=self_attn_mask)[0] tgt = tgt + self.dropout2(tgt2) tgt = self.norm2(tgt) # exclude grounding token for cross attention if (task_switch is not None) and (extra is not None) and (task_switch['grounding']) and ('grounding_len' in extra) and extra['task']=='seg': _grounding_lang_tokens = tgt[-extra['grounding_len']:,] _grounding_lang_pos = tgt_query_pos[-extra['grounding_len']:,] _grounding_ref_points = tgt_reference_points[-extra['grounding_len']:,] tgt = tgt[:-extra['grounding_len'],] tgt_query_pos = tgt_query_pos[:-extra['grounding_len'],] tgt_reference_points = tgt_reference_points[:-extra['grounding_len'],] # cross attention if self.key_aware_type is not None: if self.key_aware_type == 'mean': tgt = tgt + memory.mean(0, keepdim=True) elif self.key_aware_type == 'proj_mean': tgt = tgt + self.key_aware_proj(memory).mean(0, keepdim=True) else: raise NotImplementedError("Unknown 
key_aware_type: {}".format(self.key_aware_type)) tgt2 = self.cross_attn(self.with_pos_embed(tgt, tgt_query_pos).transpose(0, 1), tgt_reference_points.transpose(0, 1).contiguous(), memory.transpose(0, 1), memory_spatial_shapes, memory_level_start_index, memory_key_padding_mask).transpose(0, 1) # TODO: check whether add grounding lang token to cross attention is better tgt = tgt + self.dropout1(tgt2) # add back grounding token for self attention if (task_switch is not None) and (extra is not None) and (task_switch['grounding']) and ('grounding_len' in extra) and extra['task']=='seg': tgt = torch.cat((tgt, _grounding_lang_tokens)) tgt = self.norm1(tgt) tgt = self.forward_ffn(tgt) # ffn return tgt ================================================ FILE: llava/model/openseed/body/decoder/utils/utils.py ================================================ import torch import copy from torch import nn, Tensor import os import math import torch.nn.functional as F from torch import nn class MLP(nn.Module): """ Very simple multi-layer perceptron (also called FFN)""" def __init__(self, input_dim, hidden_dim, output_dim, num_layers): super().__init__() self.num_layers = num_layers h = [hidden_dim] * (num_layers - 1) self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) def forward(self, x): for i, layer in enumerate(self.layers): x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) return x def inverse_sigmoid(x, eps=1e-5): x = x.clamp(min=0, max=1) x1 = x.clamp(min=eps) x2 = (1 - x).clamp(min=eps) return torch.log(x1/x2) def gen_encoder_output_proposals(memory:Tensor, memory_padding_mask:Tensor, spatial_shapes:Tensor): """ Input: - memory: bs, \sum{hw}, d_model - memory_padding_mask: bs, \sum{hw} - spatial_shapes: nlevel, 2 Output: - output_memory: bs, \sum{hw}, d_model - output_proposals: bs, \sum{hw}, 4 """ N_, S_, C_ = memory.shape base_scale = 4.0 proposals = [] _cur = 0 for lvl, (H_, W_) in enumerate(spatial_shapes): 
mask_flatten_ = memory_padding_mask[:, _cur:(_cur + H_ * W_)].view(N_, H_, W_, 1) valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1) valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1) grid_y, grid_x = torch.meshgrid(torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device), torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device)) grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(N_, 1, 1, 2) grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale wh = torch.ones_like(grid) * 0.05 * (2.0 ** lvl) proposal = torch.cat((grid, wh), -1).view(N_, -1, 4) proposals.append(proposal) _cur += (H_ * W_) output_proposals = torch.cat(proposals, 1) output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True) output_proposals = torch.log(output_proposals / (1 - output_proposals)) output_proposals = output_proposals.masked_fill(memory_padding_mask.unsqueeze(-1), float('inf')) output_proposals = output_proposals.masked_fill(~output_proposals_valid, float('inf')) output_memory = memory output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0)) output_memory = output_memory.masked_fill(~output_proposals_valid, float(0)) return output_memory, output_proposals def gen_sineembed_for_position(pos_tensor, dim=128): # n_query, bs, _ = pos_tensor.size() # sineembed_tensor = torch.zeros(n_query, bs, 256) scale = 2 * math.pi dim_t = torch.arange(dim, dtype=torch.float32, device=pos_tensor.device) dim_t = 10000 ** (2 * (dim_t // 2) / dim) x_embed = pos_tensor[:, :, 0] * scale y_embed = pos_tensor[:, :, 1] * scale pos_x = x_embed[:, :, None] / dim_t pos_y = y_embed[:, :, None] / dim_t pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2) pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2) if pos_tensor.size(-1) == 2: pos = 
torch.cat((pos_y, pos_x), dim=2) elif pos_tensor.size(-1) == 4: w_embed = pos_tensor[:, :, 2] * scale pos_w = w_embed[:, :, None] / dim_t pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2) h_embed = pos_tensor[:, :, 3] * scale pos_h = h_embed[:, :, None] / dim_t pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2) pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2) else: raise ValueError("Unknown pos_tensor shape(-1):{}".format(pos_tensor.size(-1))) return pos.to(pos_tensor.dtype) def _get_activation_fn(activation): """Return an activation function given a string""" if activation == "relu": return F.relu if activation == "gelu": return F.gelu if activation == "glu": return F.glu if activation == "prelu": return nn.PReLU() if activation == "selu": return F.selu raise RuntimeError(F"activation should be relu/gelu, not {activation}.") def _get_clones(module, N, layer_share=False): if layer_share: return nn.ModuleList([module for i in range(N)]) else: return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) ================================================ FILE: llava/model/openseed/body/encoder/__init__.py ================================================ from .build import build_encoder ================================================ FILE: llava/model/openseed/body/encoder/build.py ================================================ from .registry import model_entrypoints from .registry import is_model from .transformer_encoder_fpn import * from .encoder_deform import * def build_encoder(config, *args, **kwargs): model_name = config['MODEL']['ENCODER']['NAME'] if not is_model(model_name): raise ValueError(f'Unkown model: {model_name}') return model_entrypoints(model_name)(config, *args, **kwargs) ================================================ FILE: llava/model/openseed/body/encoder/encoder_deform.py ================================================ # 
------------------------------------------------------------------------ # DINO # Copyright (c) 2023 IDEA. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------ # Modified from MaskDINO https://github.com/IDEA-Research/MaskDINO by Feng Li and Hao Zhang. # ------------------------------------------------------------------------ import logging import numpy as np from typing import Callable, Dict, List, Optional, Tuple, Union import fvcore.nn.weight_init as weight_init import torch from torch import nn from torch.nn import functional as F from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_ from torch.cuda.amp import autocast from detectron2.layers import Conv2d, ShapeSpec, get_norm # from detectron2.modeling import SEM_SEG_HEADS_REGISTRY from .registry import register_encoder from ...utils import configurable from ...modules import PositionEmbeddingSine from ..transformer_blocks import _get_clones, _get_activation_fn from .ops.modules import MSDeformAttn from torch.utils import checkpoint # MSDeformAttn Transformer encoder in deformable detr class MSDeformAttnTransformerEncoderOnly(nn.Module): def __init__(self, d_model=256, nhead=8, num_encoder_layers=6, dim_feedforward=1024, dropout=0.1, activation="relu", num_feature_levels=4, enc_n_points=4,): super().__init__() self.d_model = d_model self.nhead = nhead encoder_layer = MSDeformAttnTransformerEncoderLayer(d_model, dim_feedforward, dropout, activation, num_feature_levels, nhead, enc_n_points) self.encoder = MSDeformAttnTransformerEncoder(encoder_layer, num_encoder_layers) self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model)) self._reset_parameters() def _reset_parameters(self): for p in self.parameters(): if p.dim() > 1: nn.init.xavier_uniform_(p) for m in self.modules(): if isinstance(m, MSDeformAttn): m._reset_parameters() normal_(self.level_embed) def 
get_valid_ratio(self, mask): _, H, W = mask.shape valid_H = torch.sum(~mask[:, :, 0], 1) valid_W = torch.sum(~mask[:, 0, :], 1) valid_ratio_h = valid_H.float() / H valid_ratio_w = valid_W.float() / W valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1) return valid_ratio def forward(self, srcs, masks, pos_embeds, use_ckpt=False): enable_mask=0 if masks is not None: for src in srcs: if src.size(2)%32 or src.size(3)%32: enable_mask = 1 if enable_mask==0: masks = [torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) for x in srcs] # prepare input for encoder src_flatten = [] mask_flatten = [] lvl_pos_embed_flatten = [] spatial_shapes = [] for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)): bs, c, h, w = src.shape spatial_shape = (h, w) spatial_shapes.append(spatial_shape) src = src.flatten(2).transpose(1, 2) mask = mask.flatten(1) pos_embed = pos_embed.flatten(2).transpose(1, 2) lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1) lvl_pos_embed_flatten.append(lvl_pos_embed) src_flatten.append(src) mask_flatten.append(mask) src_flatten = torch.cat(src_flatten, 1) mask_flatten = torch.cat(mask_flatten, 1) lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device) level_start_index = torch.cat((spatial_shapes.new_zeros((1, )), spatial_shapes.prod(1).cumsum(0)[:-1])) valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1) # encoder memory = self.encoder(src_flatten, spatial_shapes, level_start_index, valid_ratios, lvl_pos_embed_flatten, mask_flatten, use_ckpt=use_ckpt) return memory, spatial_shapes, level_start_index class MSDeformAttnTransformerEncoderLayer(nn.Module): def __init__(self, d_model=256, d_ffn=1024, dropout=0.1, activation="relu", n_levels=4, n_heads=8, n_points=4): super().__init__() # self attention self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points) 
self.dropout1 = nn.Dropout(dropout) self.norm1 = nn.LayerNorm(d_model) # ffn self.linear1 = nn.Linear(d_model, d_ffn) self.activation = _get_activation_fn(activation) self.dropout2 = nn.Dropout(dropout) self.linear2 = nn.Linear(d_ffn, d_model) self.dropout3 = nn.Dropout(dropout) self.norm2 = nn.LayerNorm(d_model) @staticmethod def with_pos_embed(tensor, pos): return tensor if pos is None else tensor + pos def forward_ffn(self, src): src2 = self.linear2(self.dropout2(self.activation(self.linear1(src)))) src = src + self.dropout3(src2) src = self.norm2(src) return src def forward(self, src, pos, reference_points, spatial_shapes, level_start_index, padding_mask=None): # self attention src2 = self.self_attn(self.with_pos_embed(src, pos), reference_points, src, spatial_shapes, level_start_index, padding_mask) src = src + self.dropout1(src2) src = self.norm1(src) # ffn src = self.forward_ffn(src) return src class MSDeformAttnTransformerEncoder(nn.Module): def __init__(self, encoder_layer, num_layers): super().__init__() self.layers = _get_clones(encoder_layer, num_layers) self.num_layers = num_layers @staticmethod def get_reference_points(spatial_shapes, valid_ratios, device): reference_points_list = [] for lvl, (H_, W_) in enumerate(spatial_shapes): ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device), torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device)) ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_) ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_) ref = torch.stack((ref_x, ref_y), -1) reference_points_list.append(ref) reference_points = torch.cat(reference_points_list, 1) reference_points = reference_points[:, :, None] * valid_ratios[:, None] return reference_points def forward(self, src, spatial_shapes, level_start_index, valid_ratios, pos=None, padding_mask=None, use_ckpt=False): output = src reference_points = self.get_reference_points(spatial_shapes, 
valid_ratios, device=src.device) for _, layer in enumerate(self.layers): if use_ckpt: output = checkpoint.checkpoint(layer,output, pos, reference_points, spatial_shapes, level_start_index, padding_mask) else: output = layer(output, pos, reference_points, spatial_shapes, level_start_index, padding_mask) return output class OpenSeeDEncoder(nn.Module): """ This is the multi-scale encoder in detection models, also named as pixel decoder in segmentation models. """ @configurable def __init__( self, input_shape: Dict[str, ShapeSpec], *, transformer_dropout: float, transformer_nheads: int, transformer_dim_feedforward: int, transformer_enc_layers: int, conv_dim: int, mask_dim: int, norm: Optional[Union[str, Callable]] = None, # deformable transformer encoder args transformer_in_features: List[str], common_stride: int, num_feature_levels: int, total_num_feature_levels: int, feature_order: str, use_ckpt=False, ): """ NOTE: this interface is experimental. Args: input_shape: shapes (channels and stride) of the input features transformer_dropout: dropout probability in transformer transformer_nheads: number of heads in transformer transformer_dim_feedforward: dimension of feedforward network transformer_enc_layers: number of transformer encoder layers conv_dims: number of output channels for the intermediate conv layers. mask_dim: number of output channels for the final conv layer. norm (str or callable): normalization for all conv layers num_feature_levels: feature scales used total_num_feature_levels: total feautre scales used (include the downsampled features) feature_order: 'low2high' or 'high2low', i.e., 'low2high' means low-resolution features are put in the first. 
""" super().__init__() self.use_ckpt = use_ckpt transformer_input_shape = { k: v for k, v in input_shape.items() if k in transformer_in_features } # this is the input shape of pixel decoder input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride) self.in_features = [k for k, v in input_shape] # starting from "res2" to "res5" self.feature_strides = [v.stride for k, v in input_shape] self.feature_channels = [v.channels for k, v in input_shape] self.feature_order = feature_order if feature_order == "low2high": transformer_input_shape = sorted(transformer_input_shape.items(), key=lambda x: -x[1].stride) else: transformer_input_shape = sorted(transformer_input_shape.items(), key=lambda x: x[1].stride) self.transformer_in_features = [k for k, v in transformer_input_shape] # starting from "res2" to "res5" transformer_in_channels = [v.channels for k, v in transformer_input_shape] self.transformer_feature_strides = [v.stride for k, v in transformer_input_shape] # to decide extra FPN layers self.maskdino_num_feature_levels = num_feature_levels # always use 3 scales self.total_num_feature_levels = total_num_feature_levels self.common_stride = common_stride self.transformer_num_feature_levels = len(self.transformer_in_features) self.low_resolution_index = transformer_in_channels.index(max(transformer_in_channels)) self.high_resolution_index = 0 if self.feature_order == 'low2high' else -1 if self.transformer_num_feature_levels > 1: input_proj_list = [] for in_channels in transformer_in_channels[::-1]: input_proj_list.append(nn.Sequential( nn.Conv2d(in_channels, conv_dim, kernel_size=1), nn.GroupNorm(32, conv_dim), )) # input projectino for downsample in_channels = max(transformer_in_channels) for _ in range(self.total_num_feature_levels - self.transformer_num_feature_levels): # exclude the res2 input_proj_list.append(nn.Sequential( nn.Conv2d(in_channels, conv_dim, kernel_size=3, stride=2, padding=1), nn.GroupNorm(32, conv_dim), )) in_channels = conv_dim 
self.input_proj = nn.ModuleList(input_proj_list) else: self.input_proj = nn.ModuleList([ nn.Sequential( nn.Conv2d(transformer_in_channels[-1], conv_dim, kernel_size=1), nn.GroupNorm(32, conv_dim), )]) for proj in self.input_proj: nn.init.xavier_uniform_(proj[0].weight, gain=1) nn.init.constant_(proj[0].bias, 0) self.transformer = MSDeformAttnTransformerEncoderOnly( d_model=conv_dim, dropout=transformer_dropout, nhead=transformer_nheads, dim_feedforward=transformer_dim_feedforward, num_encoder_layers=transformer_enc_layers, num_feature_levels=self.total_num_feature_levels, ) N_steps = conv_dim // 2 self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True) self.mask_dim = mask_dim # use 1x1 conv instead self.mask_features = Conv2d( conv_dim, mask_dim, kernel_size=1, stride=1, padding=0, ) weight_init.c2_xavier_fill(self.mask_features) # extra fpn levels stride = min(self.transformer_feature_strides) self.num_fpn_levels = max(int(np.log2(stride) - np.log2(self.common_stride)), 1) lateral_convs = [] output_convs = [] use_bias = norm == "" for idx, in_channels in enumerate(self.feature_channels[:self.num_fpn_levels]): lateral_norm = get_norm(norm, conv_dim) output_norm = get_norm(norm, conv_dim) lateral_conv = Conv2d( in_channels, conv_dim, kernel_size=1, bias=use_bias, norm=lateral_norm ) output_conv = Conv2d( conv_dim, conv_dim, kernel_size=3, stride=1, padding=1, bias=use_bias, norm=output_norm, activation=F.relu, ) weight_init.c2_xavier_fill(lateral_conv) weight_init.c2_xavier_fill(output_conv) self.add_module("adapter_{}".format(idx + 1), lateral_conv) self.add_module("layer_{}".format(idx + 1), output_conv) lateral_convs.append(lateral_conv) output_convs.append(output_conv) # Place convs into top-down order (from low to high resolution) # to make the top-down computation in forward clearer. 
self.lateral_convs = lateral_convs[::-1] self.output_convs = output_convs[::-1] @classmethod def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec], *args, **kwargs): enc_cfg = cfg['MODEL']['ENCODER'] dec_cfg = cfg['MODEL']['DECODER'] ret = {} ret["input_shape"] = { k: v for k, v in input_shape.items() if k in enc_cfg['IN_FEATURES'] } ret["conv_dim"] = enc_cfg['CONVS_DIM'] ret["mask_dim"] = enc_cfg['MASK_DIM'] ret["norm"] = enc_cfg['NORM'] ret["transformer_dropout"] = dec_cfg['DROPOUT'] ret["transformer_nheads"] = dec_cfg['NHEADS'] ret["transformer_dim_feedforward"] = dec_cfg['DIM_FEEDFORWARD'] # deformable transformer encoder ret[ "transformer_enc_layers" ] = enc_cfg['TRANSFORMER_ENC_LAYERS'] # a separate config ret["transformer_in_features"] = enc_cfg['DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES'] # ['res3', 'res4', 'res5'] ret["common_stride"] = enc_cfg['COMMON_STRIDE'] ret["total_num_feature_levels"] = enc_cfg['TOTAL_NUM_FEATURE_LEVELS'] ret["num_feature_levels"] = enc_cfg['NUM_FEATURE_LEVELS'] ret["feature_order"] = enc_cfg['FEATURE_ORDER'] ret["use_ckpt"] = enc_cfg.get('USE_CKPT', False) return ret @autocast(enabled=False) def forward_features(self, features, masks): """ :param features: multi-scale features from the backbone :param masks: image mask :return: enhanced multi-scale features and mask feature (1/4 resolution) for the decoder to produce binary mask """ # backbone features srcs = [] pos = [] # additional downsampled features srcsl = [] posl = [] if self.total_num_feature_levels > self.transformer_num_feature_levels: smallest_feat = features[self.transformer_in_features[self.low_resolution_index]] _len_srcs = self.transformer_num_feature_levels for l in range(_len_srcs, self.total_num_feature_levels): if l == _len_srcs: src = self.input_proj[l](smallest_feat) else: src = self.input_proj[l](srcsl[-1]) srcsl.append(src) posl.append(self.pe_layer(src)) srcsl = srcsl[::-1] # Reverse feature maps for idx, f in 
enumerate(self.transformer_in_features[::-1]): x = features[f] # deformable detr does not support half precision srcs.append(self.input_proj[idx](x)) pos.append(self.pe_layer(x)) srcs.extend(srcsl) if self.feature_order == 'low2high' else srcsl.extend(srcs) pos.extend(posl) if self.feature_order == 'low2high' else posl.extend(pos) if self.feature_order != 'low2high': srcs = srcsl pos = posl y, spatial_shapes, level_start_index = self.transformer(srcs, masks, pos, use_ckpt=self.use_ckpt) bs = y.shape[0] split_size_or_sections = [None] * self.total_num_feature_levels for i in range(self.total_num_feature_levels): if i < self.total_num_feature_levels - 1: split_size_or_sections[i] = level_start_index[i + 1] - level_start_index[i] else: split_size_or_sections[i] = y.shape[1] - level_start_index[i] y = torch.split(y, split_size_or_sections, dim=1) out = [] multi_scale_features = [] num_cur_levels = 0 for i, z in enumerate(y): out.append(z.transpose(1, 2).view(bs, -1, spatial_shapes[i][0], spatial_shapes[i][1])) # append `out` with extra FPN levels # Reverse feature maps into top-down order (from low to high resolution) convert = False if out[0].dtype == torch.bfloat16: out = [out_.float() for out_ in out] convert = True for idx, f in enumerate(self.in_features[:self.num_fpn_levels][::-1]): x = features[f] lateral_conv = self.lateral_convs[idx] output_conv = self.output_convs[idx] cur_fpn = lateral_conv(x) # Following FPN implementation, we use nearest upsampling here y = F.interpolate(out[self.high_resolution_index], size=cur_fpn.shape[-2:], mode="bilinear", align_corners=False) if convert: y = y.bfloat16() y=cur_fpn + y y = output_conv(y) out.append(y) if convert: out = [out_.bfloat16() for out_ in out] for o in out: if num_cur_levels < self.total_num_feature_levels: multi_scale_features.append(o) num_cur_levels += 1 return self.mask_features(out[-1]), out[0], multi_scale_features @register_encoder def get_maskdino_encoder_deform(cfg, input_shape): """ Build a pixel 
decoder from `cfg.MODEL.MASK_FORMER.PIXEL_DECODER_NAME`. """ model = OpenSeeDEncoder(cfg, input_shape) forward_features = getattr(model, "forward_features", None) if not callable(forward_features): raise ValueError( "Only SEM_SEG_HEADS with forward_features method can be used as pixel decoder. " f"Please implement forward_features for {name} to only return mask features." ) return model ================================================ FILE: llava/model/openseed/body/encoder/ops/functions/__init__.py ================================================ # ------------------------------------------------------------------------------------------------ # Deformable DETR # Copyright (c) 2020 SenseTime. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------------------------------ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 # ------------------------------------------------------------------------------------------------ # Copyright (c) Facebook, Inc. and its affiliates. # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR from .ms_deform_attn_func import MSDeformAttnFunction ================================================ FILE: llava/model/openseed/body/encoder/ops/functions/ms_deform_attn_func.py ================================================ # ------------------------------------------------------------------------------------------------ # Deformable DETR # Copyright (c) 2020 SenseTime. All Rights Reserved. 
# Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------------------------------ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 # ------------------------------------------------------------------------------------------------ # Copyright (c) Facebook, Inc. and its affiliates. # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR from __future__ import absolute_import from __future__ import print_function from __future__ import division import torch import torch.nn.functional as F from torch.autograd import Function from torch.autograd.function import once_differentiable try: import MultiScaleDeformableAttention as MSDA except ModuleNotFoundError as e: info_string = ( "\n\nPlease compile MultiScaleDeformableAttention CUDA op with the following commands:\n" "\t`cd mask2former/modeling/pixel_decoder/ops`\n" "\t`sh make.sh`\n" ) raise ModuleNotFoundError(info_string) class MSDeformAttnFunction(Function): @staticmethod def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step): ctx.im2col_step = im2col_step output = MSDA.ms_deform_attn_forward( value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step) ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights) return output @staticmethod @once_differentiable def backward(ctx, grad_output): value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors grad_value, grad_sampling_loc, grad_attn_weight = \ MSDA.ms_deform_attn_backward( value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step) return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None def 
ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights): # for debug and test only, # need to use cuda version instead N_, S_, M_, D_ = value.shape _, Lq_, M_, L_, P_, _ = sampling_locations.shape value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) sampling_grids = 2 * sampling_locations - 1 sampling_value_list = [] for lid_, (H_, W_) in enumerate(value_spatial_shapes): # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_ value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_) # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2 sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1) # N_*M_, D_, Lq_, P_ sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_, mode='bilinear', padding_mode='zeros', align_corners=False) sampling_value_list.append(sampling_value_l_) # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_) attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_) output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_) return output.transpose(1, 2).contiguous() ================================================ FILE: llava/model/openseed/body/encoder/ops/make.sh ================================================ #!/usr/bin/env bash # ------------------------------------------------------------------------------------------------ # Deformable DETR # Copyright (c) 2020 SenseTime. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------------------------------ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 # ------------------------------------------------------------------------------------------------ # Copyright (c) Facebook, Inc. 
and its affiliates. # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR python setup.py build install --user ================================================ FILE: llava/model/openseed/body/encoder/ops/modules/__init__.py ================================================ # ------------------------------------------------------------------------------------------------ # Deformable DETR # Copyright (c) 2020 SenseTime. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------------------------------ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 # ------------------------------------------------------------------------------------------------ # Copyright (c) Facebook, Inc. and its affiliates. # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR from .ms_deform_attn import MSDeformAttn ================================================ FILE: llava/model/openseed/body/encoder/ops/modules/ms_deform_attn.py ================================================ # ------------------------------------------------------------------------------------------------ # Deformable DETR # Copyright (c) 2020 SenseTime. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------------------------------ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 # ------------------------------------------------------------------------------------------------ # Copyright (c) Facebook, Inc. and its affiliates. 
# Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR from __future__ import absolute_import from __future__ import print_function from __future__ import division import warnings import math import torch from torch import nn import torch.nn.functional as F from torch.nn.init import xavier_uniform_, constant_ from ..functions import MSDeformAttnFunction from ..functions.ms_deform_attn_func import ms_deform_attn_core_pytorch def _is_power_of_2(n): if (not isinstance(n, int)) or (n < 0): raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))) return (n & (n-1) == 0) and n != 0 class MSDeformAttn(nn.Module): def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4): """ Multi-Scale Deformable Attention Module :param d_model hidden dimension :param n_levels number of feature levels :param n_heads number of attention heads :param n_points number of sampling points per attention head per feature level """ super().__init__() if d_model % n_heads != 0: raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads)) _d_per_head = d_model // n_heads # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation if not _is_power_of_2(_d_per_head): warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 " "which is more efficient in our CUDA implementation.") self.im2col_step = 128 self.d_model = d_model self.n_levels = n_levels self.n_heads = n_heads self.n_points = n_points self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2) self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points) self.value_proj = nn.Linear(d_model, d_model) self.output_proj = nn.Linear(d_model, d_model) self._reset_parameters() def _reset_parameters(self): constant_(self.sampling_offsets.weight.data, 0.) 
thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1) for i in range(self.n_points): grid_init[:, :, i, :] *= i + 1 with torch.no_grad(): self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) constant_(self.attention_weights.weight.data, 0.) constant_(self.attention_weights.bias.data, 0.) xavier_uniform_(self.value_proj.weight.data) constant_(self.value_proj.bias.data, 0.) xavier_uniform_(self.output_proj.weight.data) constant_(self.output_proj.bias.data, 0.) def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None): """ :param query (N, Length_{query}, C) :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C) :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}] :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements :return output (N, Length_{query}, C) """ N, Len_q, _ = query.shape N, Len_in, _ = input_flatten.shape assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in value = self.value_proj(input_flatten) if input_padding_mask is not None: value = value.masked_fill(input_padding_mask[..., None], float(0)) value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads) sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, 
self.n_points, 2) attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points) attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points) # N, Len_q, n_heads, n_levels, n_points, 2 if reference_points.shape[-1] == 2: offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1) sampling_locations = reference_points[:, :, None, :, None, :] \ + sampling_offsets / offset_normalizer[None, None, None, :, None, :] elif reference_points.shape[-1] == 4: sampling_locations = reference_points[:, :, None, :, None, :2] \ + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 else: raise ValueError( 'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1])) # try: # print(value.dtype) convert=False if value.dtype== torch.bfloat16: value = value.float() attention_weights = attention_weights.float() sampling_locations = sampling_locations.float() convert=True output = MSDeformAttnFunction.apply( value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step) if convert: output = output.bfloat16() # except: # # CPU # output = ms_deform_attn_core_pytorch(value, input_spatial_shapes, sampling_locations, attention_weights) # # For FLOPs calculation only # output = ms_deform_attn_core_pytorch(value, input_spatial_shapes, sampling_locations, attention_weights) output = self.output_proj(output) return output ================================================ FILE: llava/model/openseed/body/encoder/ops/setup.py ================================================ # ------------------------------------------------------------------------------------------------ # Deformable DETR # Copyright (c) 2020 SenseTime. All Rights Reserved. 
# Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------------------------------ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 # ------------------------------------------------------------------------------------------------ # Copyright (c) Facebook, Inc. and its affiliates. # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR import os import glob import torch from torch.utils.cpp_extension import CUDA_HOME from torch.utils.cpp_extension import CppExtension from torch.utils.cpp_extension import CUDAExtension from setuptools import find_packages from setuptools import setup requirements = ["torch", "torchvision"] def get_extensions(): this_dir = os.path.dirname(os.path.abspath(__file__)) extensions_dir = os.path.join(this_dir, "src") main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu")) sources = main_file + source_cpu extension = CppExtension extra_compile_args = {"cxx": []} define_macros = [] # Force cuda since torch ask for a device, not if cuda is in fact available. if (os.environ.get('FORCE_CUDA') or torch.cuda.is_available()) and CUDA_HOME is not None: extension = CUDAExtension sources += source_cuda define_macros += [("WITH_CUDA", None)] extra_compile_args["nvcc"] = [ "-DCUDA_HAS_FP16=1", "-D__CUDA_NO_HALF_OPERATORS__", "-D__CUDA_NO_HALF_CONVERSIONS__", "-D__CUDA_NO_HALF2_OPERATORS__", ] else: if CUDA_HOME is None: raise NotImplementedError('CUDA_HOME is None. Please set environment variable CUDA_HOME.') else: raise NotImplementedError('No CUDA runtime is found. 
Please set FORCE_CUDA=1 or test it by running torch.cuda.is_available().') sources = [os.path.join(extensions_dir, s) for s in sources] include_dirs = [extensions_dir] ext_modules = [ extension( "MultiScaleDeformableAttention", sources, include_dirs=include_dirs, define_macros=define_macros, extra_compile_args=extra_compile_args, ) ] return ext_modules setup( name="MultiScaleDeformableAttention", version="1.0", author="Weijie Su", url="https://github.com/fundamentalvision/Deformable-DETR", description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention", packages=find_packages(exclude=("configs", "tests",)), ext_modules=get_extensions(), cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, ) ================================================ FILE: llava/model/openseed/body/encoder/ops/src/cpu/ms_deform_attn_cpu.cpp ================================================ /*! ************************************************************************************************** * Deformable DETR * Copyright (c) 2020 SenseTime. All Rights Reserved. * Licensed under the Apache License, Version 2.0 [see LICENSE for details] ************************************************************************************************** * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 ************************************************************************************************** */ /*! * Copyright (c) Facebook, Inc. and its affiliates. 
* Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR */ #include #include #include at::Tensor ms_deform_attn_cpu_forward( const at::Tensor &value, const at::Tensor &spatial_shapes, const at::Tensor &level_start_index, const at::Tensor &sampling_loc, const at::Tensor &attn_weight, const int im2col_step) { AT_ERROR("Not implement on cpu"); } std::vector ms_deform_attn_cpu_backward( const at::Tensor &value, const at::Tensor &spatial_shapes, const at::Tensor &level_start_index, const at::Tensor &sampling_loc, const at::Tensor &attn_weight, const at::Tensor &grad_output, const int im2col_step) { AT_ERROR("Not implement on cpu"); } ================================================ FILE: llava/model/openseed/body/encoder/ops/src/cpu/ms_deform_attn_cpu.h ================================================ /*! ************************************************************************************************** * Deformable DETR * Copyright (c) 2020 SenseTime. All Rights Reserved. * Licensed under the Apache License, Version 2.0 [see LICENSE for details] ************************************************************************************************** * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 ************************************************************************************************** */ /*! * Copyright (c) Facebook, Inc. and its affiliates. 
* Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR */ #pragma once #include at::Tensor ms_deform_attn_cpu_forward( const at::Tensor &value, const at::Tensor &spatial_shapes, const at::Tensor &level_start_index, const at::Tensor &sampling_loc, const at::Tensor &attn_weight, const int im2col_step); std::vector ms_deform_attn_cpu_backward( const at::Tensor &value, const at::Tensor &spatial_shapes, const at::Tensor &level_start_index, const at::Tensor &sampling_loc, const at::Tensor &attn_weight, const at::Tensor &grad_output, const int im2col_step); ================================================ FILE: llava/model/openseed/body/encoder/ops/src/cuda/ms_deform_attn_cuda.cu ================================================ /*! ************************************************************************************************** * Deformable DETR * Copyright (c) 2020 SenseTime. All Rights Reserved. * Licensed under the Apache License, Version 2.0 [see LICENSE for details] ************************************************************************************************** * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 ************************************************************************************************** */ /*! * Copyright (c) Facebook, Inc. and its affiliates. 
* Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
*/

#include <vector>
#include "cuda/ms_deform_im2col_cuda.cuh"

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>


// Forward pass of multi-scale deformable attention on CUDA.
// Shapes follow from the size() reads below: value is (batch, spatial_size, num_heads, channels),
// sampling_loc has num_query at dim 1 and num_point at dim 4. Returns
// (batch, num_query, num_heads * channels). The batch is processed in chunks of
// im2col_step_ = min(batch, im2col_step) samples, which must divide batch.
at::Tensor ms_deform_attn_cuda_forward(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const int im2col_step)
{
    TORCH_CHECK(value.is_contiguous(), "value tensor has to be contiguous");
    TORCH_CHECK(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
    TORCH_CHECK(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
    TORCH_CHECK(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
    TORCH_CHECK(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");

    TORCH_CHECK(value.is_cuda(), "value must be a CUDA tensor");
    TORCH_CHECK(spatial_shapes.is_cuda(), "spatial_shapes must be a CUDA tensor");
    TORCH_CHECK(level_start_index.is_cuda(), "level_start_index must be a CUDA tensor");
    TORCH_CHECK(sampling_loc.is_cuda(), "sampling_loc must be a CUDA tensor");
    TORCH_CHECK(attn_weight.is_cuda(), "attn_weight must be a CUDA tensor");

    const int batch = value.size(0);
    const int spatial_size = value.size(1);
    const int num_heads = value.size(2);
    const int channels = value.size(3);

    const int num_levels = spatial_shapes.size(0);

    const int num_query = sampling_loc.size(1);
    const int num_point = sampling_loc.size(4);

    const int im2col_step_ = std::min(batch, im2col_step);

    // TORCH_CHECK concatenates its message arguments (no printf-style formatting).
    TORCH_CHECK(batch % im2col_step_ == 0,
                "batch(", batch, ") must be divisible by im2col_step(", im2col_step_, ")");

    auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());

    const int batch_n = im2col_step_;
    auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
    // element offsets between consecutive samples in each input
    auto per_value_size = spatial_size * num_heads * channels;
    auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
    auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
    for (int n = 0; n < batch/im2col_step_; ++n)
    {
        auto columns = output_n.select(0, n);
        AT_DISPATCH_FLOATING_TYPES(value.scalar_type(), "ms_deform_attn_forward_cuda", ([&] {
            ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
                value.data_ptr<scalar_t>() + n * im2col_step_ * per_value_size,
                spatial_shapes.data_ptr<int64_t>(),
                level_start_index.data_ptr<int64_t>(),
                sampling_loc.data_ptr<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
                attn_weight.data_ptr<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
                batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
                columns.data_ptr<scalar_t>());

        }));
    }

    output = output.view({batch, num_query, num_heads*channels});

    return output;
}


// Backward pass: returns {grad_value, grad_sampling_loc, grad_attn_weight}, each
// shaped like the corresponding input (allocated with zeros_like). Same chunking
// and contiguity/CUDA requirements as the forward pass, plus grad_output.
std::vector<at::Tensor> ms_deform_attn_cuda_backward(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const at::Tensor &grad_output,
    const int im2col_step)
{

    TORCH_CHECK(value.is_contiguous(), "value tensor has to be contiguous");
    TORCH_CHECK(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
    TORCH_CHECK(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
    TORCH_CHECK(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
    TORCH_CHECK(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
    TORCH_CHECK(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");

    TORCH_CHECK(value.is_cuda(), "value must be a CUDA tensor");
    TORCH_CHECK(spatial_shapes.is_cuda(), "spatial_shapes must be a CUDA tensor");
    TORCH_CHECK(level_start_index.is_cuda(), "level_start_index must be a CUDA tensor");
    TORCH_CHECK(sampling_loc.is_cuda(), "sampling_loc must be a CUDA tensor");
    TORCH_CHECK(attn_weight.is_cuda(), "attn_weight must be a CUDA tensor");
    TORCH_CHECK(grad_output.is_cuda(), "grad_output must be a CUDA tensor");

    const int batch = value.size(0);
    const int spatial_size = value.size(1);
    const int num_heads = value.size(2);
    const int channels = value.size(3);

    const int num_levels = spatial_shapes.size(0);

    const int num_query = sampling_loc.size(1);
    const int num_point = sampling_loc.size(4);

    const int im2col_step_ = std::min(batch, im2col_step);

    TORCH_CHECK(batch % im2col_step_ == 0,
                "batch(", batch, ") must be divisible by im2col_step(", im2col_step_, ")");

    auto grad_value = at::zeros_like(value);
    auto grad_sampling_loc = at::zeros_like(sampling_loc);
    auto grad_attn_weight = at::zeros_like(attn_weight);

    const int batch_n = im2col_step_;
    auto per_value_size = spatial_size * num_heads * channels;
    auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
    auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
    auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});

    for (int n = 0; n < batch/im2col_step_; ++n)
    {
        auto grad_output_g = grad_output_n.select(0, n);
        AT_DISPATCH_FLOATING_TYPES(value.scalar_type(), "ms_deform_attn_backward_cuda", ([&] {
            ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
                                    grad_output_g.data_ptr<scalar_t>(),
                                    value.data_ptr<scalar_t>() + n * im2col_step_ * per_value_size,
                                    spatial_shapes.data_ptr<int64_t>(),
                                    level_start_index.data_ptr<int64_t>(),
                                    sampling_loc.data_ptr<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
                                    attn_weight.data_ptr<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
                                    batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
                                    grad_value.data_ptr<scalar_t>() +  n * im2col_step_ * per_value_size,
                                    grad_sampling_loc.data_ptr<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
                                    grad_attn_weight.data_ptr<scalar_t>() + n * im2col_step_ * per_attn_weight_size);

        }));
    }

    return {
        grad_value, grad_sampling_loc, grad_attn_weight
    };
}


================================================
FILE: llava/model/openseed/body/encoder/ops/src/cuda/ms_deform_attn_cuda.h
================================================
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/

/*!
* Copyright (c) Facebook, Inc. and its affiliates.
* Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
*/

#pragma once
#include <torch/extension.h>

// CUDA forward; returns (batch, num_query, num_heads * channels).
at::Tensor ms_deform_attn_cuda_forward(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const int im2col_step);

// CUDA backward; returns {grad_value, grad_sampling_loc, grad_attn_weight}.
std::vector<at::Tensor> ms_deform_attn_cuda_backward(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const at::Tensor &grad_output,
    const int im2col_step);


================================================
FILE: llava/model/openseed/body/encoder/ops/src/cuda/ms_deform_im2col_cuda.cuh
================================================
/*!
**************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************
* Modified from DCN (https://github.com/msracver/Deformable-ConvNets)
* Copyright (c) 2018 Microsoft
**************************************************************************
*/

/*!
* Copyright (c) Facebook, Inc. and its affiliates.
* Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
*/

// NOTE(review): this chunk comes from a flattened text dump in which the
// contents of angle brackets appear to have been stripped: the six #include
// targets below, the template parameter lists after every `template`
// keyword (so `scalar_t` / `blockSize` are never declared here), and the
// kernel launch configurations (visible only as `<<>>`). The code cannot be
// compiled as shown; TODO restore those pieces from the upstream
// Deformable-DETR sources before relying on it.
#include
#include
#include
#include
#include
#include

// Grid-stride loop: each thread starts at its global id and advances by the
// total number of launched threads until it has covered [0, n).
#define CUDA_KERNEL_LOOP(i, n) \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
      i < (n); \
      i += blockDim.x * gridDim.x)

// Threads per block used by ms_deformable_im2col_cuda, and the upper bound on
// threads per block in ms_deformable_col2im_cuda.
const int CUDA_NUM_THREADS = 1024;

// Number of blocks of `num_threads` threads needed to cover N work items
// (ceiling division).
inline int GET_BLOCKS(const int N, const int num_threads)
{
  return (N + num_threads - 1) / num_threads;
}


// Bilinearly samples one scalar (head m, channel c) at the fractional
// position (h, w) of a single feature level.
//
// The stride arithmetic implies bottom_data is indexed as
// [y][x][head][channel]: column stride nheads * channels, row stride
// width * nheads * channels. Each of the four neighbouring corners is read
// only if it lies inside [0, height-1] x [0, width-1]; out-of-range corners
// contribute zero.
template 
__device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data, 
                                                   const int &height, const int &width, const int &nheads, const int &channels,
                                                   const scalar_t &h, const scalar_t &w, const int &m, const int &c)
{
  const int h_low = floor(h);
  const int w_low = floor(w);
  const int h_high = h_low + 1;
  const int w_high = w_low + 1;

  // Fractional offsets (lh, lw) and their complements (hh, hw) give the
  // bilinear corner weights below.
  const scalar_t lh = h - h_low;
  const scalar_t lw = w - w_low;
  const scalar_t hh = 1 - lh, hw = 1 - lw;

  const int w_stride = nheads * channels;
  const int h_stride = width * w_stride;
  const int h_low_ptr_offset = h_low * h_stride;
  const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
  const int w_low_ptr_offset = w_low * w_stride;
  const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
  const int base_ptr = m * channels + c;

  // v1: (h_low, w_low) corner
  scalar_t v1 = 0;
  if (h_low >= 0 && w_low >= 0)
  {
    const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
    v1 = bottom_data[ptr1];
  }
  // v2: (h_low, w_high) corner
  scalar_t v2 = 0;
  if (h_low >= 0 && w_high <= width - 1)
  {
    const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
    v2 = bottom_data[ptr2];
  }
  // v3: (h_high, w_low) corner
  scalar_t v3 = 0;
  if (h_high <= height - 1 && w_low >= 0)
  {
    const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
    v3 = bottom_data[ptr3];
  }
  // v4: (h_high, w_high) corner
  scalar_t v4 = 0;
  if (h_high <= height - 1 && w_high <= width - 1)
  {
    const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
    v4 = bottom_data[ptr4];
  }

  const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;

  const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
  return val;
}


// Backward of ms_deform_attn_im2col_bilinear for one (h, w, m, c) sample.
//
// Outputs:
//  - grad_value: each in-range corner receives w_i * top_grad * attn_weight
//    through atomicAdd (presumably because many samples/threads can hit the
//    same value cell).
//  - *grad_attn_weight: overwritten (plain store) with top_grad times the
//    bilinearly sampled value.
//  - grad_sampling_loc[0] / [1]: overwritten with the gradient w.r.t. the
//    w (x) and h (y) sampling location. The width/height factors are the
//    derivative of the callers' mapping h_im = loc_h * spatial_h - 0.5
//    (and likewise for w).
// Because the location/weight outputs are plain stores, the callers pass a
// per-thread scratch slot here and reduce the slots afterwards.
template 
__device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data, 
                                                   const int &height, const int &width, const int &nheads, const int &channels,
                                                   const scalar_t &h, const scalar_t &w, const int &m, const int &c,
                                                   const scalar_t &top_grad,
                                                   const scalar_t &attn_weight,
                                                   scalar_t* &grad_value, 
                                                   scalar_t* grad_sampling_loc,
                                                   scalar_t* grad_attn_weight)
{
  const int h_low = floor(h);
  const int w_low = floor(w);
  const int h_high = h_low + 1;
  const int w_high = w_low + 1;

  const scalar_t lh = h - h_low;
  const scalar_t lw = w - w_low;
  const scalar_t hh = 1 - lh, hw = 1 - lw;

  const int w_stride = nheads * channels;
  const int h_stride = width * w_stride;
  const int h_low_ptr_offset = h_low * h_stride;
  const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
  const int w_low_ptr_offset = w_low * w_stride;
  const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
  const int base_ptr = m * channels + c;

  const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
  const scalar_t top_grad_value = top_grad * attn_weight;
  // Partial derivatives of the interpolated value w.r.t. h and w,
  // accumulated corner by corner.
  scalar_t grad_h_weight = 0, grad_w_weight = 0;

  scalar_t v1 = 0;
  if (h_low >= 0 && w_low >= 0)
  {
    const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
    v1 = bottom_data[ptr1];
    grad_h_weight -= hw * v1;
    grad_w_weight -= hh * v1;
    atomicAdd(grad_value+ptr1, w1*top_grad_value);
  }
  scalar_t v2 = 0;
  if (h_low >= 0 && w_high <= width - 1)
  {
    const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
    v2 = bottom_data[ptr2];
    grad_h_weight -= lw * v2;
    grad_w_weight += hh * v2;
    atomicAdd(grad_value+ptr2, w2*top_grad_value);
  }
  scalar_t v3 = 0;
  if (h_high <= height - 1 && w_low >= 0)
  {
    const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
    v3 = bottom_data[ptr3];
    grad_h_weight += hw * v3;
    grad_w_weight -= lh * v3;
    atomicAdd(grad_value+ptr3, w3*top_grad_value); 
  }
  scalar_t v4 = 0;
  if (h_high <= height - 1 && w_high <= width - 1)
  {
    const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
    v4 = bottom_data[ptr4];
    grad_h_weight += lw * v4;
    grad_w_weight += lh * v4;
    atomicAdd(grad_value+ptr4, w4*top_grad_value);
  }

  const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
  *grad_attn_weight = top_grad * val;
  *grad_sampling_loc = width * grad_w_weight * top_grad_value;
  *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value;
}


// Same math as ms_deform_attn_col2im_bilinear, but the attention-weight and
// sampling-location gradients are accumulated with atomicAdd directly into
// the given output pointers instead of being stored to a per-thread slot.
// Many threads can therefore contribute to the same output without any
// shared-memory reduction (used by ms_deformable_col2im_gpu_kernel_gm).
template 
__device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data, 
                                                   const int &height, const int &width, const int &nheads, const int &channels,
                                                   const scalar_t &h, const scalar_t &w, const int &m, const int &c,
                                                   const scalar_t &top_grad,
                                                   const scalar_t &attn_weight,
                                                   scalar_t* &grad_value, 
                                                   scalar_t* grad_sampling_loc,
                                                   scalar_t* grad_attn_weight)
{
  const int h_low = floor(h);
  const int w_low = floor(w);
  const int h_high = h_low + 1;
  const int w_high = w_low + 1;

  const scalar_t lh = h - h_low;
  const scalar_t lw = w - w_low;
  const scalar_t hh = 1 - lh, hw = 1 - lw;

  const int w_stride = nheads * channels;
  const int h_stride = width * w_stride;
  const int h_low_ptr_offset = h_low * h_stride;
  const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
  const int w_low_ptr_offset = w_low * w_stride;
  const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
  const int base_ptr = m * channels + c;

  const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
  const scalar_t top_grad_value = top_grad * attn_weight;
  scalar_t grad_h_weight = 0, grad_w_weight = 0;

  scalar_t v1 = 0;
  if (h_low >= 0 && w_low >= 0)
  {
    const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
    v1 = bottom_data[ptr1];
    grad_h_weight -= hw * v1;
    grad_w_weight -= hh * v1;
    atomicAdd(grad_value+ptr1, w1*top_grad_value);
  }
  scalar_t v2 = 0;
  if (h_low >= 0 && w_high <= width - 1)
  {
    const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
    v2 = bottom_data[ptr2];
    grad_h_weight -= lw * v2;
    grad_w_weight += hh * v2;
    atomicAdd(grad_value+ptr2, w2*top_grad_value);
  }
  scalar_t v3 = 0;
  if (h_high <= height - 1 && w_low >= 0)
  {
    const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
    v3 = bottom_data[ptr3];
    grad_h_weight += hw * v3;
    grad_w_weight -= lh * v3;
    atomicAdd(grad_value+ptr3, w3*top_grad_value); 
  }
  scalar_t v4 = 0;
  if (h_high <= height - 1 && w_high <= width - 1)
  {
    const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
    v4 = bottom_data[ptr4];
    grad_h_weight += lw * v4;
    grad_w_weight += lh * v4;
    atomicAdd(grad_value+ptr4, w4*top_grad_value);
  }

  const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
  atomicAdd(grad_attn_weight, top_grad * val); 
  atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value);
  atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value);
}


// Forward kernel: one loop iteration per output element.
//
// `index` is decoded with channels fastest, then heads, queries and batch:
//   index = ((b * num_query + q) * num_heads + m) * channels + c.
// sampling_index (= index / channels) selects the (b, q, m) row of
// data_attn_weight (num_levels * num_point entries) and of data_sampling_loc
// (two entries per point, read as [w, h]). data_spatial_shapes is read as
// (h, w) pairs per level, and data_level_start_index gives each level's
// first spatial position inside data_value.
// The output is the attention-weighted sum of bilinear samples over all
// levels and points. Locations are mapped with loc * size - 0.5, so loc is
// presumably normalized to [0, 1] (TODO confirm). Samples whose mapped
// position lies entirely outside the level contribute nothing.
template 
__global__ void ms_deformable_im2col_gpu_kernel(const int n,
                                                const scalar_t *data_value, 
                                                const int64_t *data_spatial_shapes,
                                                const int64_t *data_level_start_index, 
                                                const scalar_t *data_sampling_loc,
                                                const scalar_t *data_attn_weight,
                                                const int batch_size, 
                                                const int spatial_size, 
                                                const int num_heads,
                                                const int channels, 
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *data_col)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp; 
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    // NOTE(review): q_col is computed but never used in this kernel.
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    scalar_t *data_col_ptr = data_col + index;
    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
    scalar_t col = 0;
    
    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride);
      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;

        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight;
        }

        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
      }
    }
    *data_col_ptr = col;
  }
}

// Backward kernel with a fixed-size shared-memory cache and a serial
// reduction by thread 0.
//
// `blockSize` is not declared in the visible text; it is presumably a
// template parameter lost in extraction (TODO confirm). For each (level,
// point), every thread computes its location/weight gradient into its own
// shared slot. Thread 0 then sums the blockSize slots and writes the total
// with a plain (non-atomic) store. That is only correct if all threads of a
// block share one sampling_index. The launcher presumably ensures this by
// using blockDim.x == channels: ms_deformable_col2im_cuda sets
// num_threads = channels for channels <= 1024, but its launch configuration
// is not visible here. The __syncthreads() calls inside the loop also
// require every thread of a block to run the same number of iterations.
template 
__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n,
                                                const scalar_t *grad_col,
                                                const scalar_t *data_value,
                                                const int64_t *data_spatial_shapes,
                                                const int64_t *data_level_start_index, 
                                                const scalar_t *data_sampling_loc,
                                                const scalar_t *data_attn_weight,
                                                const int batch_size, 
                                                const int spatial_size, 
                                                const int num_heads,
                                                const int channels, 
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *grad_value,
                                                scalar_t *grad_sampling_loc,
                                                scalar_t *grad_attn_weight)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    // Per-thread scratch: two location-gradient entries and one weight-gradient entry.
    __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
    __shared__ scalar_t cache_grad_attn_weight[blockSize];
    unsigned int tid = threadIdx.x;
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp; 
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    const int grad_sampling_ptr = data_weight_ptr;
    // Advance the output pointers to this (b, q, m) row; they then step by
    // one point per inner-loop iteration.
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        // Zero this thread's slot so out-of-range samples contribute nothing.
        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
        *(cache_grad_attn_weight+threadIdx.x)=0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          ms_deform_attn_col2im_bilinear(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr, 
            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
        }
        
        __syncthreads();
        if (tid == 0)
        {
          scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
          int sid=2;
          // NOTE(review): this loop variable shadows the outer `tid`.
          for (unsigned int tid = 1; tid < blockSize; ++tid)
          {
            _grad_w += cache_grad_sampling_loc[sid];
            _grad_h += cache_grad_sampling_loc[sid + 1];
            _grad_a += cache_grad_attn_weight[tid];
            sid += 2;
          }
          
          
          *grad_sampling_loc = _grad_w;
          *(grad_sampling_loc + 1) = _grad_h;
          *grad_attn_weight = _grad_a;
        }
        // Keep the slots intact until thread 0 has finished reading them.
        __syncthreads();

        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}


// Like ..._blocksize_aware_reduce_v1, but thread 0 does not sum serially.
// Instead the shared slots are combined by a parallel tree reduction that
// halves the active range each step. The reduction starts at blockSize/2
// and drops odd remainders, so it assumes blockSize is a power of two.
// ms_deformable_col2im_cuda selects this kernel only for channels in
// {64, 128, 256, 512, 1024}.
template 
__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n,
                                                const scalar_t *grad_col,
                                                const scalar_t *data_value,
                                                const int64_t *data_spatial_shapes,
                                                const int64_t *data_level_start_index, 
                                                const scalar_t *data_sampling_loc,
                                                const scalar_t *data_attn_weight,
                                                const int batch_size, 
                                                const int spatial_size, 
                                                const int num_heads,
                                                const int channels, 
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *grad_value,
                                                scalar_t *grad_sampling_loc,
                                                scalar_t *grad_attn_weight)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
    __shared__ scalar_t cache_grad_attn_weight[blockSize];
    unsigned int tid = threadIdx.x;
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp; 
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
        *(cache_grad_attn_weight+threadIdx.x)=0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          ms_deform_attn_col2im_bilinear(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr, 
            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
        }
        
        __syncthreads();

        // Tree reduction: slot tid absorbs slot tid + s, halving s each step.
        for (unsigned int s=blockSize/2; s>0; s>>=1)
        {
          if (tid < s) {
            const unsigned int xid1 = tid << 1;
            const unsigned int xid2 = (tid + s) << 1;
            cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
            cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
            cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
          }
          __syncthreads();
        }

        if (tid == 0)
        { 
          *grad_sampling_loc = cache_grad_sampling_loc[0];
          *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
          *grad_attn_weight = cache_grad_attn_weight[0];
        }
        __syncthreads();

        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}


// Backward kernel for block sizes not fixed at compile time. The scratch
// space is dynamically sized shared memory (`extern __shared__`):
// 2 * blockDim.x location entries followed by blockDim.x weight entries.
// The launch must therefore supply at least 3 * blockDim.x * sizeof(scalar_t)
// bytes of dynamic shared memory; TODO confirm, since the launch
// configuration is not visible. Thread 0 reduces the slots serially.
// ms_deformable_col2im_cuda uses this kernel for channels < 64 that are not
// powers of two.
template 
__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n,
                                                const scalar_t *grad_col,
                                                const scalar_t *data_value,
                                                const int64_t *data_spatial_shapes,
                                                const int64_t *data_level_start_index, 
                                                const scalar_t *data_sampling_loc,
                                                const scalar_t *data_attn_weight,
                                                const int batch_size, 
                                                const int spatial_size, 
                                                const int num_heads,
                                                const int channels, 
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *grad_value,
                                                scalar_t *grad_sampling_loc,
                                                scalar_t *grad_attn_weight)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    extern __shared__ int _s[];
    scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
    scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
    unsigned int tid = threadIdx.x;
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp; 
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
        *(cache_grad_attn_weight+threadIdx.x)=0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          ms_deform_attn_col2im_bilinear(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr, 
            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
        }
        
        __syncthreads();
        if (tid == 0)
        {
          scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
          int sid=2;
          // NOTE(review): this loop variable shadows the outer `tid`.
          for (unsigned int tid = 1; tid < blockDim.x; ++tid)
          {
            _grad_w += cache_grad_sampling_loc[sid];
            _grad_h += cache_grad_sampling_loc[sid + 1];
            _grad_a += cache_grad_attn_weight[tid];
            sid += 2;
          }
          
          
          *grad_sampling_loc = _grad_w;
          *(grad_sampling_loc + 1) = _grad_h;
          *grad_attn_weight = _grad_a;
        }
        __syncthreads();

        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}

// Dynamic-shared-memory backward kernel with a tree reduction that also
// works for non-power-of-two blockDim.x. `spre` tracks the active length
// before each halving step; when it is odd, the extra element at
// tid + 2*s is folded in too. ms_deformable_col2im_cuda uses this kernel
// for channels >= 64 (up to 1024) that are not powers of two.
template 
__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n,
                                                const scalar_t *grad_col,
                                                const scalar_t *data_value,
                                                const int64_t *data_spatial_shapes,
                                                const int64_t *data_level_start_index, 
                                                const scalar_t *data_sampling_loc,
                                                const scalar_t *data_attn_weight,
                                                const int batch_size, 
                                                const int spatial_size, 
                                                const int num_heads,
                                                const int channels, 
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *grad_value,
                                                scalar_t *grad_sampling_loc,
                                                scalar_t *grad_attn_weight)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    extern __shared__ int _s[];
    scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
    scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
    unsigned int tid = threadIdx.x;
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp; 
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
        *(cache_grad_attn_weight+threadIdx.x)=0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          ms_deform_attn_col2im_bilinear(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr, 
            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
        }
        
        __syncthreads();

        for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
        {
          if (tid < s) {
            const unsigned int xid1 = tid << 1;
            const unsigned int xid2 = (tid + s) << 1;
            cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
            cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
            cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
            // Odd active length: fold in the leftover element at tid + 2*s.
            if (tid + (s << 1) < spre)
            {
              cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
              cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
              cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
            } 
          }
          __syncthreads();
        }

        if (tid == 0)
        {
          *grad_sampling_loc = cache_grad_sampling_loc[0];
          *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
          *grad_attn_weight = cache_grad_attn_weight[0];
        }
        __syncthreads();

        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}

// Same as ms_deformable_col2im_gpu_kernel_shm_reduce_v2, except that thread 0
// accumulates the block's partial sum with atomicAdd instead of a plain
// store. ms_deformable_col2im_cuda uses it when channels > 1024 and is a
// multiple of 1024. num_threads is then capped at 1024, so several blocks
// appear to contribute to the same (b, q, m) output.
template 
__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n,
                                                const scalar_t *grad_col,
                                                const scalar_t *data_value,
                                                const int64_t *data_spatial_shapes,
                                                const int64_t *data_level_start_index, 
                                                const scalar_t *data_sampling_loc,
                                                const scalar_t *data_attn_weight,
                                                const int batch_size, 
                                                const int spatial_size, 
                                                const int num_heads,
                                                const int channels, 
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *grad_value,
                                                scalar_t *grad_sampling_loc,
                                                scalar_t *grad_attn_weight)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    extern __shared__ int _s[];
    scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
    scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
    unsigned int tid = threadIdx.x;
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp; 
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
        *(cache_grad_attn_weight+threadIdx.x)=0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          ms_deform_attn_col2im_bilinear(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr, 
            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
        }
        
        __syncthreads();

        for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
        {
          if (tid < s) {
            const unsigned int xid1 = tid << 1;
            const unsigned int xid2 = (tid + s) << 1;
            cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
            cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
            cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
            if (tid + (s << 1) < spre)
            {
              cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
              cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
              cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
            }
          }
          __syncthreads();
        }

        if (tid == 0)
        {
          // Atomic because other blocks may add to the same output entries.
          atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]);
          atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]);
          atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]);
        }
        __syncthreads();

        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}


// Fallback backward kernel with no shared-memory reduction. Every thread
// accumulates its location/weight gradients straight into the outputs via
// ms_deform_attn_col2im_bilinear_gm (atomicAdd). ms_deformable_col2im_cuda
// uses it when channels > 1024 and is not a multiple of 1024.
template 
__global__ void ms_deformable_col2im_gpu_kernel_gm(const int n,
                                                const scalar_t *grad_col,
                                                const scalar_t *data_value,
                                                const int64_t *data_spatial_shapes,
                                                const int64_t *data_level_start_index, 
                                                const scalar_t *data_sampling_loc,
                                                const scalar_t *data_attn_weight,
                                                const int batch_size, 
                                                const int spatial_size, 
                                                const int num_heads,
                                                const int channels, 
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *grad_value,
                                                scalar_t *grad_sampling_loc,
                                                scalar_t *grad_attn_weight)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp; 
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          ms_deform_attn_col2im_bilinear_gm(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr, 
            grad_sampling_loc, grad_attn_weight);
        }
        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}


// Host launcher for the forward kernel: one work item per
// (batch, query, head, channel), CUDA_NUM_THREADS threads per block.
// NOTE(review): num_actual_kernels duplicates num_kernels and is presumably
// consumed by the launch configuration stripped from this dump (TODO
// confirm). Launch errors are only printf'd, not raised, so callers cannot
// detect a failure through this function.
template 
void ms_deformable_im2col_cuda(cudaStream_t stream,
                              const scalar_t* data_value,
                              const int64_t* data_spatial_shapes, 
                              const int64_t* data_level_start_index, 
                              const scalar_t* data_sampling_loc,
                              const scalar_t* data_attn_weight,
                              const int batch_size,
                              const int spatial_size, 
                              const int num_heads, 
                              const int channels, 
                              const int num_levels, 
                              const int num_query,
                              const int num_point,
                              scalar_t* data_col)
{
  const int num_kernels = batch_size * num_query * num_heads * channels;
  const int num_actual_kernels = batch_size * num_query * num_heads * channels;
  const int num_threads = CUDA_NUM_THREADS;
  ms_deformable_im2col_gpu_kernel
      <<>>(
      num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight, 
      batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col);
  
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess)
  {
    printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
  }

}

// Host launcher for the backward pass. Picks a kernel by channel count:
//   channels > 1024, multiple of 1024 -> ..._shm_reduce_v2_multi_blocks
//   channels > 1024, otherwise        -> ..._gm (per-thread atomics)
//   channels in {1, 2, 4, 8, 16, 32}  -> ..._blocksize_aware_reduce_v1
//   channels in {64, ..., 1024} (pow2)-> ..._blocksize_aware_reduce_v2
//   other channels < 64               -> ..._shm_reduce_v1
//   other channels >= 64              -> ..._shm_reduce_v2
// num_threads is channels capped at CUDA_NUM_THREADS.
// NOTE(review): as in ms_deformable_im2col_cuda, errors are only printf'd.
template 
void ms_deformable_col2im_cuda(cudaStream_t stream,
                              const scalar_t* grad_col,
                              const scalar_t* data_value,
                              const int64_t * data_spatial_shapes,
                              const int64_t * data_level_start_index,
                              const scalar_t * data_sampling_loc,
                              const scalar_t * data_attn_weight,
                              const int batch_size, 
                              const int spatial_size, 
                              const int num_heads,
                              const int channels, 
                              const int num_levels,
                              const int num_query,
                              const int num_point, 
                              scalar_t* grad_value,
                              scalar_t* grad_sampling_loc,
                              scalar_t* grad_attn_weight)
{
  const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels;
  const int num_kernels = batch_size * num_query * num_heads * channels;
  const int num_actual_kernels = batch_size * num_query * num_heads * channels;
  if (channels > 1024)
  {
    if ((channels & 1023) == 0)
    {
      ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks
          <<>>(
                      num_kernels, 
                      grad_col,
                      data_value,
                      data_spatial_shapes,
                      data_level_start_index, 
                      data_sampling_loc,
                      data_attn_weight,
                      batch_size, 
                      spatial_size, 
                      num_heads,
                      channels, 
                      num_levels,
                      num_query,
                      num_point,
                      grad_value,
                      grad_sampling_loc,
                      grad_attn_weight);
    }
    else
    {
      ms_deformable_col2im_gpu_kernel_gm
          <<>>(
                      num_kernels, 
                      grad_col,
                      data_value,
                      data_spatial_shapes,
                      data_level_start_index, 
                      data_sampling_loc,
                      data_attn_weight,
                      batch_size, 
                      spatial_size, 
                      num_heads,
                      channels, 
                      num_levels,
                      num_query,
                      num_point,
                      grad_value,
                      grad_sampling_loc,
                      grad_attn_weight);
    }
  }
  else{
    switch(channels)
    {
      case 1:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
        <<>>(
                      num_kernels, 
                      grad_col,
                      data_value,
                      data_spatial_shapes,
                      data_level_start_index, 
                      data_sampling_loc,
                      data_attn_weight,
                      batch_size, 
                      spatial_size, 
                      num_heads,
                      channels, 
                      num_levels,
                      num_query,
                      num_point,
                      grad_value,
                      grad_sampling_loc,
                      grad_attn_weight);
        break;
      case 2:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
        <<>>(
                      num_kernels, 
                      grad_col,
                      data_value,
                      data_spatial_shapes,
                      data_level_start_index, 
                      data_sampling_loc,
                      data_attn_weight,
                      batch_size, 
                      spatial_size, 
                      num_heads,
                      channels, 
                      num_levels,
                      num_query,
                      num_point,
                      grad_value,
                      grad_sampling_loc,
                      grad_attn_weight);
        break;
      case 4:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
        <<>>(
                      num_kernels, 
                      grad_col,
                      data_value,
                      data_spatial_shapes,
                      data_level_start_index, 
                      data_sampling_loc,
                      data_attn_weight,
                      batch_size, 
                      spatial_size, 
                      num_heads,
                      channels, 
                      num_levels,
                      num_query,
                      num_point,
                      grad_value,
                      grad_sampling_loc,
                      grad_attn_weight);
        break;
      case 8:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
        <<>>(
                      num_kernels, 
                      grad_col,
                      data_value,
                      data_spatial_shapes,
                      data_level_start_index, 
                      data_sampling_loc,
                      data_attn_weight,
                      batch_size, 
                      spatial_size, 
                      num_heads,
                      channels, 
                      num_levels,
                      num_query,
                      num_point,
                      grad_value,
                      grad_sampling_loc,
                      grad_attn_weight);
        break;
      case 16:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
        <<>>(
                      num_kernels, 
                      grad_col,
                      data_value,
                      data_spatial_shapes,
                      data_level_start_index, 
                      data_sampling_loc,
                      data_attn_weight,
                      batch_size, 
                      spatial_size, 
                      num_heads,
                      channels, 
                      num_levels,
                      num_query,
                      num_point,
                      grad_value,
                      grad_sampling_loc,
                      grad_attn_weight);
        break;
      case 32:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
        <<>>(
                      num_kernels, 
                      grad_col,
                      data_value,
                      data_spatial_shapes,
                      data_level_start_index, 
                      data_sampling_loc,
                      data_attn_weight,
                      batch_size, 
                      spatial_size, 
                      num_heads,
                      channels, 
                      num_levels,
                      num_query,
                      num_point,
                      grad_value,
                      grad_sampling_loc,
                      grad_attn_weight);
        break;
      case 64:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2
        <<>>(
                      num_kernels, 
                      grad_col,
                      data_value,
                      data_spatial_shapes,
                      data_level_start_index, 
                      data_sampling_loc,
                      data_attn_weight,
                      batch_size, 
                      spatial_size, 
                      num_heads,
                      channels, 
                      num_levels,
                      num_query,
                      num_point,
                      grad_value,
                      grad_sampling_loc,
                      grad_attn_weight);
        break;
      case 128:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2
        <<>>(
                      num_kernels, 
                      grad_col,
                      data_value,
                      data_spatial_shapes,
                      data_level_start_index, 
                      data_sampling_loc,
                      data_attn_weight,
                      batch_size, 
                      spatial_size, 
                      num_heads,
                      channels, 
                      num_levels,
                      num_query,
                      num_point,
                      grad_value,
                      grad_sampling_loc,
                      grad_attn_weight);
        break;
      case 256:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2
        <<>>(
                      num_kernels, 
                      grad_col,
                      data_value,
                      data_spatial_shapes,
                      data_level_start_index, 
                      data_sampling_loc,
                      data_attn_weight,
                      batch_size, 
                      spatial_size, 
                      num_heads,
                      channels, 
                      num_levels,
                      num_query,
                      num_point,
                      grad_value,
                      grad_sampling_loc,
                      grad_attn_weight);
        break;
      case 512:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2
        <<>>(
                      num_kernels, 
                      grad_col,
                      data_value,
                      data_spatial_shapes,
                      data_level_start_index, 
                      data_sampling_loc,
                      data_attn_weight,
                      batch_size, 
                      spatial_size, 
                      num_heads,
                      channels, 
                      num_levels,
                      num_query,
                      num_point,
                      grad_value,
                      grad_sampling_loc,
                      grad_attn_weight);
        break;
      case 1024:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2
        <<>>(
                      num_kernels, 
                      grad_col,
                      data_value,
                      data_spatial_shapes,
                      data_level_start_index, 
                      data_sampling_loc,
                      data_attn_weight,
                      batch_size, 
                      spatial_size, 
                      num_heads,
                      channels, 
                      num_levels,
                      num_query,
                      num_point,
                      grad_value,
                      grad_sampling_loc,
                      grad_attn_weight);
        break;
      default:
        // Channel counts that are not powers of two: use the dynamic
        // shared-memory kernels.
        if (channels < 64)
        {
          ms_deformable_col2im_gpu_kernel_shm_reduce_v1
          <<>>(
                        num_kernels, 
                        grad_col,
                        data_value,
                        data_spatial_shapes,
                        data_level_start_index, 
                        data_sampling_loc,
                        data_attn_weight,
                        batch_size, 
                        spatial_size, 
                        num_heads,
                        channels, 
                        num_levels,
                        num_query,
                        num_point,
                        grad_value,
                        grad_sampling_loc,
                        grad_attn_weight);
        }
        else
        {
          ms_deformable_col2im_gpu_kernel_shm_reduce_v2
          <<>>(
                        num_kernels, 
                        grad_col,
                        data_value,
                        data_spatial_shapes,
                        data_level_start_index, 
                        data_sampling_loc,
                        data_attn_weight,
                        batch_size, 
                        spatial_size, 
                        num_heads,
                        channels, 
                        num_levels,
                        num_query,
                        num_point,
                        grad_value,
                        grad_sampling_loc,
                        grad_attn_weight);
        }
    }
  }
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess)
  {
    printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
  }

}

================================================
FILE: llava/model/openseed/body/encoder/ops/src/ms_deform_attn.h
================================================
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/

/*!
* Copyright (c) Facebook, Inc. and its affiliates. * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR */ #pragma once #include "cpu/ms_deform_attn_cpu.h" #ifdef WITH_CUDA #include "cuda/ms_deform_attn_cuda.h" #endif at::Tensor ms_deform_attn_forward( const at::Tensor &value, const at::Tensor &spatial_shapes, const at::Tensor &level_start_index, const at::Tensor &sampling_loc, const at::Tensor &attn_weight, const int im2col_step) { if (value.type().is_cuda()) { #ifdef WITH_CUDA return ms_deform_attn_cuda_forward( value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step); #else AT_ERROR("Not compiled with GPU support"); #endif } AT_ERROR("Not implemented on the CPU"); } std::vector ms_deform_attn_backward( const at::Tensor &value, const at::Tensor &spatial_shapes, const at::Tensor &level_start_index, const at::Tensor &sampling_loc, const at::Tensor &attn_weight, const at::Tensor &grad_output, const int im2col_step) { if (value.type().is_cuda()) { #ifdef WITH_CUDA return ms_deform_attn_cuda_backward( value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step); #else AT_ERROR("Not compiled with GPU support"); #endif } AT_ERROR("Not implemented on the CPU"); } ================================================ FILE: llava/model/openseed/body/encoder/ops/src/vision.cpp ================================================ /*! ************************************************************************************************** * Deformable DETR * Copyright (c) 2020 SenseTime. All Rights Reserved. * Licensed under the Apache License, Version 2.0 [see LICENSE for details] ************************************************************************************************** * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 ************************************************************************************************** */ /*! 
* Copyright (c) Facebook, Inc. and its affiliates. * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR */ #include "ms_deform_attn.h" PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward"); m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward"); } ================================================ FILE: llava/model/openseed/body/encoder/ops/test.py ================================================ # ------------------------------------------------------------------------------------------------ # Deformable DETR # Copyright (c) 2020 SenseTime. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------------------------------ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 # ------------------------------------------------------------------------------------------------ # Copyright (c) Facebook, Inc. and its affiliates. 
# Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR

# Sanity checks for the multi-scale deformable attention op: compares the
# compiled MSDeformAttnFunction against the pure-PyTorch reference
# ms_deform_attn_core_pytorch, and runs torch.autograd.gradcheck on it.
# All inputs are created with .cuda(), so a CUDA device is required.

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import time
import torch
import torch.nn as nn
from torch.autograd import gradcheck

from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch


# Problem sizes. Tensors below are drawn as value (N, S, M, D),
# sampling_locations (N, Lq, M, L, P, 2) and attention_weights (N, Lq, M, L, P).
# Presumably N=batch, M=heads, D=channels per head, Lq=queries, L=levels,
# P=points per level -- naming assumed from Deformable DETR; TODO confirm.
N, M, D = 1, 2, 2
Lq, L, P = 2, 2, 2
# One (H, W) row per feature level.
shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
# Offset of each level's first position in the flattened S axis: [0, H0*W0, ...].
level_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1]))
# Total number of flattened spatial positions over all levels.
S = sum([(H*W).item() for H, W in shapes])


# The checks below draw their inputs with torch.rand in call order, so the
# inputs each check sees depend on this seed and on the order of the calls.
torch.manual_seed(3)


@torch.no_grad()
def check_forward_equal_with_pytorch_double():
    """Compare the op's forward output with the PyTorch reference in float64.

    Prints whether torch.allclose (default tolerances) holds, plus the maximum
    absolute and relative error.
    """
    value = torch.rand(N, S, M, D).cuda() * 0.01
    sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
    attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
    # Normalize jointly over the last two axes (L, P) so they sum to 1.
    attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
    im2col_step = 2
    output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu()
    output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu()
    fwdok = torch.allclose(output_cuda, output_pytorch)
    max_abs_err = (output_cuda - output_pytorch).abs().max()
    # NOTE(review): divides by |output_pytorch|; an exact zero there yields inf/nan.
    max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()

    print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')


@torch.no_grad()
def check_forward_equal_with_pytorch_float():
    """Same comparison as the float64 check, in float32 with rtol=1e-2, atol=1e-3."""
    value = torch.rand(N, S, M, D).cuda() * 0.01
    sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
    attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
    # Normalize jointly over the last two axes (L, P) so they sum to 1.
    attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
    im2col_step = 2
    output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu()
    output_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step).detach().cpu()
    fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3)
    max_abs_err = (output_cuda - output_pytorch).abs().max()
    # NOTE(review): divides by |output_pytorch|; an exact zero there yields inf/nan.
    max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()

    print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')


def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True):
    """Run torch.autograd.gradcheck on MSDeformAttnFunction in float64.

    Args:
        channels: size of the last axis of ``value``.
        grad_value, grad_sampling_loc, grad_attn_weight: whether the
            corresponding input has requires_grad set (and so is checked).
    """
    value = torch.rand(N, S, M, channels).cuda() * 0.01
    sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
    attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
    # Normalize jointly over the last two axes (L, P) so they sum to 1.
    attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
    im2col_step = 2
    func = MSDeformAttnFunction.apply

    # requires_grad is set on the float32 tensors; the .double() copies passed
    # to gradcheck below inherit it because .double() is differentiable.
    value.requires_grad = grad_value
    sampling_locations.requires_grad = grad_sampling_loc
    attention_weights.requires_grad = grad_attn_weight

    gradok = gradcheck(func, (value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step))

    print(f'* {gradok} check_gradient_numerical(D={channels})')


if __name__ == '__main__':
    check_forward_equal_with_pytorch_double()
    check_forward_equal_with_pytorch_float()

    # Mix of odd, small and large channel counts.
    for channels in [30, 32, 64, 71, 1025, 2048, 3096]:
        check_gradient_numerical(channels, True, True, True)
================================================
FILE: llava/model/openseed/body/encoder/registry.py
================================================
# Registry of encoder builder functions, keyed by the defining module's short name.
_model_entrypoints = {}


def register_encoder(fn):
    """Decorator: store ``fn`` under the last component of ``fn.__module__``.

    For example, a function defined in module ``a.b.transformer_encoder_fpn``
    is stored under ``"transformer_encoder_fpn"``. Returns ``fn`` unchanged.
    """
    module_name_split = fn.__module__.split('.')
    model_name = module_name_split[-1]
    _model_entrypoints[model_name] = fn
    return fn


def model_entrypoints(model_name):
    """Return the function registered under ``model_name`` (KeyError if absent)."""
    return _model_entrypoints[model_name]


def is_model(model_name):
    """Return True if a function is registered under ``model_name``."""
    return model_name in _model_entrypoints
================================================
FILE:
llava/model/openseed/body/encoder/transformer_encoder_fpn.py ================================================ # Copyright (c) Facebook, Inc. and its affiliates. import logging import numpy as np from typing import Callable, Dict, List, Optional, Tuple, Union import torch from torch import nn from torch.nn import functional as F from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_ from torch.cuda.amp import autocast import fvcore.nn.weight_init as weight_init from detectron2.layers import Conv2d, DeformConv, ShapeSpec, get_norm from .registry import register_encoder from ..transformer_blocks import TransformerEncoder, TransformerEncoderLayer, _get_clones, _get_activation_fn from ...modules import PositionEmbeddingSine from ...utils import configurable # This is a modified FPN decoder. class BasePixelDecoder(nn.Module): def __init__( self, input_shape: Dict[str, ShapeSpec], *, conv_dim: int, mask_dim: int, mask_on: bool, norm: Optional[Union[str, Callable]] = None, ): """ NOTE: this interface is experimental. Args: input_shape: shapes (channels and stride) of the input features conv_dims: number of output channels for the intermediate conv layers. mask_dim: number of output channels for the final conv layer. 
norm (str or callable): normalization for all conv layers """ super().__init__() input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride) self.in_features = [k for k, v in input_shape] # starting from "res2" to "res5" feature_channels = [v.channels for k, v in input_shape] lateral_convs = [] output_convs = [] use_bias = norm == "" for idx, in_channels in enumerate(feature_channels): if idx == len(self.in_features) - 1: output_norm = get_norm(norm, conv_dim) output_conv = Conv2d( in_channels, conv_dim, kernel_size=3, stride=1, padding=1, bias=use_bias, norm=output_norm, activation=F.relu, ) weight_init.c2_xavier_fill(output_conv) self.add_module("layer_{}".format(idx + 1), output_conv) lateral_convs.append(None) output_convs.append(output_conv) else: lateral_norm = get_norm(norm, conv_dim) output_norm = get_norm(norm, conv_dim) lateral_conv = Conv2d( in_channels, conv_dim, kernel_size=1, bias=use_bias, norm=lateral_norm ) output_conv = Conv2d( conv_dim, conv_dim, kernel_size=3, stride=1, padding=1, bias=use_bias, norm=output_norm, activation=F.relu, ) weight_init.c2_xavier_fill(lateral_conv) weight_init.c2_xavier_fill(output_conv) self.add_module("adapter_{}".format(idx + 1), lateral_conv) self.add_module("layer_{}".format(idx + 1), output_conv) lateral_convs.append(lateral_conv) output_convs.append(output_conv) # Place convs into top-down order (from low to high resolution) # to make the top-down computation in forward clearer. 
self.lateral_convs = lateral_convs[::-1] self.output_convs = output_convs[::-1] self.mask_on = mask_on if self.mask_on: self.mask_dim = mask_dim self.mask_features = Conv2d( conv_dim, mask_dim, kernel_size=3, stride=1, padding=1, ) weight_init.c2_xavier_fill(self.mask_features) self.maskformer_num_feature_levels = 3 # always use 3 scales @classmethod def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]): enc_cfg = cfg['MODEL']['ENCODER'] ret = {} ret["input_shape"] = { k: v for k, v in input_shape.items() if k in enc_cfg['IN_FEATURES'] } ret["conv_dim"] = enc_cfg['CONVS_DIM'] ret["mask_dim"] = enc_cfg['MASK_DIM'] ret["norm"] = enc_cfg['NORM'] return ret def forward_features(self, features): multi_scale_features = [] num_cur_levels = 0 # Reverse feature maps into top-down order (from low to high resolution) for idx, f in enumerate(self.in_features[::-1]): x = features[f] lateral_conv = self.lateral_convs[idx] output_conv = self.output_convs[idx] if lateral_conv is None: y = output_conv(x) else: cur_fpn = lateral_conv(x) # Following FPN implementation, we use nearest upsampling here y = cur_fpn + F.interpolate(y, size=cur_fpn.shape[-2:], mode="nearest") y = output_conv(y) if num_cur_levels < self.maskformer_num_feature_levels: multi_scale_features.append(y) num_cur_levels += 1 mask_features = self.mask_features(y) if self.mask_on else None return mask_features, None, multi_scale_features def forward(self, features, targets=None): logger = logging.getLogger(__name__) logger.warning("Calling forward() may cause unpredicted behavior of PixelDecoder module.") return self.forward_features(features) class TransformerEncoderOnly(nn.Module): def __init__( self, d_model=512, nhead=8, num_encoder_layers=6, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False, ): super().__init__() encoder_layer = TransformerEncoderLayer( d_model, nhead, dim_feedforward, dropout, activation, normalize_before ) encoder_norm = nn.LayerNorm(d_model) if 
normalize_before else None self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) self._reset_parameters() self.d_model = d_model self.nhead = nhead def _reset_parameters(self): for p in self.parameters(): if p.dim() > 1: nn.init.xavier_uniform_(p) def forward(self, src, mask, pos_embed): # flatten NxCxHxW to HWxNxC bs, c, h, w = src.shape src = src.flatten(2).permute(2, 0, 1) pos_embed = pos_embed.flatten(2).permute(2, 0, 1) if mask is not None: mask = mask.flatten(1) memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) return memory.permute(1, 2, 0).view(bs, c, h, w) # This is a modified FPN decoder with extra Transformer encoder that processes the lowest-resolution feature map. class TransformerEncoderPixelDecoder(BasePixelDecoder): @configurable def __init__( self, input_shape: Dict[str, ShapeSpec], *, transformer_dropout: float, transformer_nheads: int, transformer_dim_feedforward: int, transformer_enc_layers: int, transformer_pre_norm: bool, conv_dim: int, mask_dim: int, mask_on: int, norm: Optional[Union[str, Callable]] = None, ): """ NOTE: this interface is experimental. Args: input_shape: shapes (channels and stride) of the input features transformer_dropout: dropout probability in transformer transformer_nheads: number of heads in transformer transformer_dim_feedforward: dimension of feedforward network transformer_enc_layers: number of transformer encoder layers transformer_pre_norm: whether to use pre-layernorm or not conv_dims: number of output channels for the intermediate conv layers. mask_dim: number of output channels for the final conv layer. 
norm (str or callable): normalization for all conv layers """ super().__init__(input_shape, conv_dim=conv_dim, mask_dim=mask_dim, norm=norm, mask_on=mask_on) input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride) self.in_features = [k for k, v in input_shape] # starting from "res2" to "res5" feature_strides = [v.stride for k, v in input_shape] feature_channels = [v.channels for k, v in input_shape] in_channels = feature_channels[len(self.in_features) - 1] self.input_proj = Conv2d(in_channels, conv_dim, kernel_size=1) weight_init.c2_xavier_fill(self.input_proj) self.transformer = TransformerEncoderOnly( d_model=conv_dim, dropout=transformer_dropout, nhead=transformer_nheads, dim_feedforward=transformer_dim_feedforward, num_encoder_layers=transformer_enc_layers, normalize_before=transformer_pre_norm, ) N_steps = conv_dim // 2 self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True) # update layer use_bias = norm == "" output_norm = get_norm(norm, conv_dim) output_conv = Conv2d( conv_dim, conv_dim, kernel_size=3, stride=1, padding=1, bias=use_bias, norm=output_norm, activation=F.relu, ) weight_init.c2_xavier_fill(output_conv) delattr(self, "layer_{}".format(len(self.in_features))) self.add_module("layer_{}".format(len(self.in_features)), output_conv) self.output_convs[0] = output_conv @classmethod def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]): enc_cfg = cfg['MODEL']['ENCODER'] dec_cfg = cfg['MODEL']['DECODER'] ret = super().from_config(cfg, input_shape) ret["transformer_dropout"] = dec_cfg['DROPOUT'] ret["transformer_nheads"] = dec_cfg['NHEADS'] ret["transformer_dim_feedforward"] = dec_cfg['DIM_FEEDFORWARD'] ret["transformer_enc_layers"] = enc_cfg['TRANSFORMER_ENC_LAYERS'] # a separate config ret["transformer_pre_norm"] = dec_cfg['PRE_NORM'] ret['mask_on'] = cfg['MODEL']['DECODER']['MASK'] return ret def forward_features(self, features): multi_scale_features = [] num_cur_levels = 0 # Reverse feature maps into top-down order (from low 
to high resolution) for idx, f in enumerate(self.in_features[::-1]): x = features[f] lateral_conv = self.lateral_convs[idx] output_conv = self.output_convs[idx] if lateral_conv is None: transformer = self.input_proj(x) pos = self.pe_layer(x) transformer = self.transformer(transformer, None, pos) y = output_conv(transformer) # save intermediate feature as input to Transformer decoder transformer_encoder_features = transformer else: cur_fpn = lateral_conv(x) # Following FPN implementation, we use nearest upsampling here y = cur_fpn + F.interpolate(y, size=cur_fpn.shape[-2:], mode="nearest") y = output_conv(y) if num_cur_levels < self.maskformer_num_feature_levels: multi_scale_features.append(y) num_cur_levels += 1 mask_features = self.mask_features(y) if self.mask_on else None return mask_features, transformer_encoder_features, multi_scale_features def forward(self, features, targets=None): logger = logging.getLogger(__name__) logger.warning("Calling forward() may cause unpredicted behavior of PixelDecoder module.") return self.forward_features(features) @register_encoder def get_transformer_encoder_fpn(cfg, input_shape): """ Build a pixel decoder from `cfg.MODEL.MASK_FORMER.PIXEL_DECODER_NAME`. """ model = TransformerEncoderPixelDecoder(cfg, input_shape) forward_features = getattr(model, "forward_features", None) if not callable(forward_features): raise ValueError( "Only SEM_SEG_HEADS with forward_features method can be used as pixel decoder. " f"Please implement forward_features for {name} to only return mask features." ) return model ================================================ FILE: llava/model/openseed/body/openseed_head.py ================================================ # ------------------------------------------------------------------------ # Copyright (c) 2022 IDEA. All Rights Reserved. 
# Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------ # Modified from Mask2Former https://github.com/facebookresearch/Mask2Former by Feng Li and Hao Zhang. # ------------------------------------------------------------------------------ import logging from typing import Callable, Dict, List, Optional, Tuple, Union from torch import nn from detectron2.layers import Conv2d, ShapeSpec, get_norm from detectron2.modeling import SEM_SEG_HEADS_REGISTRY from .registry import register_body from .encoder import build_encoder from .decoder import build_decoder from ..utils import configurable class OpenSeeDHead(nn.Module): @configurable def __init__( self, input_shape: Dict[str, ShapeSpec], *, num_classes: int, pixel_decoder: nn.Module, loss_weight: float = 1.0, ignore_value: int = -1, transformer_predictor: nn.Module, ): """ Args: input_shape: shapes (channels and stride) of the input features num_classes: number of classes to predict pixel_decoder: the pixel decoder module loss_weight: loss weight ignore_value: category id to be ignored during training. 
transformer_predictor: the transformer decoder that makes prediction transformer_in_feature: input feature name to the transformer_predictor """ super().__init__() input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride) self.in_features = [k for k, v in input_shape] self.ignore_value = ignore_value self.common_stride = 4 self.loss_weight = loss_weight self.pixel_decoder = pixel_decoder self.predictor = transformer_predictor self.num_classes = num_classes @classmethod def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec], lang_encoder: nn.Module, extra: dict): enc_cfg = cfg['MODEL']['ENCODER'] dec_cfg = cfg['MODEL']['DECODER'] transformer_predictor_in_channels = enc_cfg['CONVS_DIM'] return { "input_shape": { k: v for k, v in input_shape.items() if k in enc_cfg['IN_FEATURES'] }, "ignore_value": enc_cfg['IGNORE_VALUE'], "num_classes": enc_cfg.get('NUM_CLASSES', None), "pixel_decoder": build_encoder(cfg, input_shape), "loss_weight": enc_cfg['LOSS_WEIGHT'], "transformer_predictor": build_decoder( cfg, transformer_predictor_in_channels, mask_classification=True, extra=extra, ), } def forward(self, features, mask=None,targets=None, target_queries=None, target_vlp=None, task='seg', extra={},default_text_embeddings=None): mask_features, transformer_encoder_features, multi_scale_features = self.pixel_decoder.forward_features(features, mask) predictions = self.predictor(multi_scale_features, mask_features, mask, targets=targets, target_queries=target_queries, target_vlp=target_vlp, task=task, extra=extra,default_text_embeddings=default_text_embeddings) return predictions @register_body def get_maskdino_head(cfg, input_shape, lang_encoder, extra): return OpenSeeDHead(cfg, input_shape, lang_encoder, extra) ================================================ FILE: llava/model/openseed/body/registry.py ================================================ _model_entrypoints = {} def register_body(fn): module_name_split = fn.__module__.split('.') model_name = 
module_name_split[-1] _model_entrypoints[model_name] = fn return fn def model_entrypoints(model_name): return _model_entrypoints[model_name] def is_model(model_name): return model_name in _model_entrypoints ================================================ FILE: llava/model/openseed/body/transformer_blocks.py ================================================ # Copyright (c) Facebook, Inc. and its affiliates. # Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/transformer.py """ Transformer class. Copy-paste from torch.nn.Transformer with modifications: * positional encodings are passed in MHattention * extra LN at the end of encoder is removed * decoder returns a stack of activations from all decoding layers """ import copy from typing import List, Optional import torch import torch.nn.functional as F from torch import Tensor, nn class Transformer(nn.Module): def __init__( self, d_model=512, nhead=8, num_encoder_layers=6, num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False, return_intermediate_dec=False, ): super().__init__() encoder_layer = TransformerEncoderLayer( d_model, nhead, dim_feedforward, dropout, activation, normalize_before ) encoder_norm = nn.LayerNorm(d_model) if normalize_before else None self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) decoder_layer = TransformerDecoderLayer( d_model, nhead, dim_feedforward, dropout, activation, normalize_before ) decoder_norm = nn.LayerNorm(d_model) self.decoder = TransformerDecoder( decoder_layer, num_decoder_layers, decoder_norm, return_intermediate=return_intermediate_dec, ) self._reset_parameters() self.d_model = d_model self.nhead = nhead def _reset_parameters(self): for p in self.parameters(): if p.dim() > 1: nn.init.xavier_uniform_(p) def forward(self, src, mask, query_embed, pos_embed): # flatten NxCxHxW to HWxNxC bs, c, h, w = src.shape src = src.flatten(2).permute(2, 0, 1) pos_embed = 
pos_embed.flatten(2).permute(2, 0, 1) query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) if mask is not None: mask = mask.flatten(1) tgt = torch.zeros_like(query_embed) memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) hs = self.decoder( tgt, memory, memory_key_padding_mask=mask, pos=pos_embed, query_pos=query_embed ) return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w) class TransformerEncoder(nn.Module): def __init__(self, encoder_layer, num_layers, norm=None): super().__init__() self.layers = _get_clones(encoder_layer, num_layers) self.num_layers = num_layers self.norm = norm def forward( self, src, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None, ): output = src for layer in self.layers: output = layer( output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pos=pos ) if self.norm is not None: output = self.norm(output) return output class TransformerDecoder(nn.Module): def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False): super().__init__() self.layers = _get_clones(decoder_layer, num_layers) self.num_layers = num_layers self.norm = norm self.return_intermediate = return_intermediate def forward( self, tgt, memory, tgt_mask: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None, query_pos: Optional[Tensor] = None, ): output = tgt intermediate = [] for layer in self.layers: output = layer( output, memory, tgt_mask=tgt_mask, memory_mask=memory_mask, tgt_key_padding_mask=tgt_key_padding_mask, memory_key_padding_mask=memory_key_padding_mask, pos=pos, query_pos=query_pos, ) if self.return_intermediate: intermediate.append(self.norm(output)) if self.norm is not None: output = self.norm(output) if self.return_intermediate: intermediate.pop() intermediate.append(output) if self.return_intermediate: 
return torch.stack(intermediate) return output.unsqueeze(0) class TransformerEncoderLayer(nn.Module): def __init__( self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False, ): super().__init__() self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) # Implementation of Feedforward model self.linear1 = nn.Linear(d_model, dim_feedforward) self.dropout = nn.Dropout(dropout) self.linear2 = nn.Linear(dim_feedforward, d_model) self.norm1 = nn.LayerNorm(d_model) self.norm2 = nn.LayerNorm(d_model) self.dropout1 = nn.Dropout(dropout) self.dropout2 = nn.Dropout(dropout) self.activation = _get_activation_fn(activation) self.normalize_before = normalize_before def with_pos_embed(self, tensor, pos: Optional[Tensor]): return tensor if pos is None else tensor + pos def forward_post( self, src, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None, ): q = k = self.with_pos_embed(src, pos) src2 = self.self_attn( q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask )[0] src = src + self.dropout1(src2) src = self.norm1(src) src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) src = src + self.dropout2(src2) src = self.norm2(src) return src def forward_pre( self, src, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None, ): src2 = self.norm1(src) q = k = self.with_pos_embed(src2, pos) src2 = self.self_attn( q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask )[0] src = src + self.dropout1(src2) src2 = self.norm2(src) src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) src = src + self.dropout2(src2) return src def forward( self, src, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None, ): if self.normalize_before: return self.forward_pre(src, src_mask, src_key_padding_mask, 
pos) return self.forward_post(src, src_mask, src_key_padding_mask, pos) class TransformerDecoderLayer(nn.Module): def __init__( self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False, ): super().__init__() self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) # Implementation of Feedforward model self.linear1 = nn.Linear(d_model, dim_feedforward) self.dropout = nn.Dropout(dropout) self.linear2 = nn.Linear(dim_feedforward, d_model) self.norm1 = nn.LayerNorm(d_model) self.norm2 = nn.LayerNorm(d_model) self.norm3 = nn.LayerNorm(d_model) self.dropout1 = nn.Dropout(dropout) self.dropout2 = nn.Dropout(dropout) self.dropout3 = nn.Dropout(dropout) self.activation = _get_activation_fn(activation) self.normalize_before = normalize_before def with_pos_embed(self, tensor, pos: Optional[Tensor]): return tensor if pos is None else tensor + pos def forward_post( self, tgt, memory, tgt_mask: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None, query_pos: Optional[Tensor] = None, ): q = k = self.with_pos_embed(tgt, query_pos) tgt2 = self.self_attn( q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask )[0] tgt = tgt + self.dropout1(tgt2) tgt = self.norm1(tgt) tgt2 = self.multihead_attn( query=self.with_pos_embed(tgt, query_pos), key=self.with_pos_embed(memory, pos), value=memory, attn_mask=memory_mask, key_padding_mask=memory_key_padding_mask, )[0] tgt = tgt + self.dropout2(tgt2) tgt = self.norm2(tgt) tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) tgt = tgt + self.dropout3(tgt2) tgt = self.norm3(tgt) return tgt def forward_pre( self, tgt, memory, tgt_mask: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, 
memory_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None, query_pos: Optional[Tensor] = None, ): tgt2 = self.norm1(tgt) q = k = self.with_pos_embed(tgt2, query_pos) tgt2 = self.self_attn( q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask )[0] tgt = tgt + self.dropout1(tgt2) tgt2 = self.norm2(tgt) tgt2 = self.multihead_attn( query=self.with_pos_embed(tgt2, query_pos), key=self.with_pos_embed(memory, pos), value=memory, attn_mask=memory_mask, key_padding_mask=memory_key_padding_mask, )[0] tgt = tgt + self.dropout2(tgt2) tgt2 = self.norm3(tgt) tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) tgt = tgt + self.dropout3(tgt2) return tgt def forward( self, tgt, memory, tgt_mask: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None, query_pos: Optional[Tensor] = None, ): if self.normalize_before: return self.forward_pre( tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos, ) return self.forward_post( tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos, ) def _get_clones(module, N): return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) def _get_activation_fn(activation): """Return an activation function given a string""" if activation == "relu": return F.relu if activation == "gelu": return F.gelu if activation == "glu": return F.glu raise RuntimeError(f"activation should be relu/gelu, not {activation}.") ================================================ FILE: llava/model/openseed/language/LangEncoder/__init__.py ================================================ from __future__ import absolute_import from __future__ import division from __future__ import print_function from .build import build_lang_encoder from .build import build_tokenizer from .transformer import * 
================================================ FILE: llava/model/openseed/language/LangEncoder/build.py ================================================ import os from transformers import CLIPTokenizer, CLIPTokenizerFast from transformers import AutoTokenizer from .registry import lang_encoders from .registry import is_lang_encoder def build_lang_encoder(config_encoder, tokenizer, verbose, **kwargs): model_name = config_encoder['NAME'] if not is_lang_encoder(model_name): raise ValueError(f'Unkown model: {model_name}') return lang_encoders(model_name)(config_encoder, tokenizer, verbose, **kwargs) def build_tokenizer(config_encoder): tokenizer = None os.environ['TOKENIZERS_PARALLELISM'] = 'true' if config_encoder['TOKENIZER'] == 'clip': pretrained_tokenizer = config_encoder.get( 'PRETRAINED_TOKENIZER', 'openai/clip-vit-base-patch32' ) tokenizer = CLIPTokenizer.from_pretrained(pretrained_tokenizer) tokenizer.add_special_tokens({'cls_token': tokenizer.eos_token}) elif config_encoder['TOKENIZER'] == 'clip-fast': pretrained_tokenizer = config_encoder.get( 'PRETRAINED_TOKENIZER', 'openai/clip-vit-base-patch32' ) tokenizer = CLIPTokenizerFast.from_pretrained(pretrained_tokenizer, from_slow=True) else: tokenizer = AutoTokenizer.from_pretrained(config_encoder['TOKENIZER']) return tokenizer ================================================ FILE: llava/model/openseed/language/LangEncoder/registry.py ================================================ _lang_encoders = {} def register_lang_encoder(fn): module_name_split = fn.__module__.split('.') model_name = module_name_split[-1] _lang_encoders[model_name] = fn return fn def lang_encoders(model_name): return _lang_encoders[model_name] def is_lang_encoder(model_name): return model_name in _lang_encoders ================================================ FILE: llava/model/openseed/language/LangEncoder/transformer.py ================================================ from collections import OrderedDict from typing import Tuple, Union 
import logging import os import numpy as np import torch import torch.nn.functional as F from torch import nn from timm.models.layers import DropPath, trunc_normal_ from .registry import register_lang_encoder from detectron2.utils.comm import is_main_process from utils.model import register_norm_module logger = logging.getLogger(__name__) @register_norm_module class LayerNorm(nn.Module): def __init__(self, hidden_size, eps=1e-12): """Construct a layernorm module in the TF style (epsilon inside the square root). """ super(LayerNorm, self).__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.bias = nn.Parameter(torch.zeros(hidden_size)) self.variance_epsilon = eps def forward(self, x): pdtype = x.dtype x = x.float() u = x.mean(-1, keepdim=True) s = (x - u).pow(2).mean(-1, keepdim=True) x = (x - u) / torch.sqrt(s + self.variance_epsilon) return self.weight * x.to(pdtype) + self.bias class QuickGELU(nn.Module): def forward(self, x: torch.Tensor): return x * torch.sigmoid(1.702 * x) class ResidualAttentionBlock(nn.Module): def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, drop_path: float = 0.0): super().__init__() self.attn = nn.MultiheadAttention(d_model, n_head) self.ln_1 = LayerNorm(d_model) self.mlp = nn.Sequential(OrderedDict([ ("c_fc", nn.Linear(d_model, d_model * 4)), ("gelu", QuickGELU()), ("c_proj", nn.Linear(d_model * 4, d_model)) ])) self.ln_2 = LayerNorm(d_model) self.attn_mask = attn_mask self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() def attention(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None): self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) \ if self.attn_mask is not None else None return self.attn( x, x, x, key_padding_mask=key_padding_mask, need_weights=False, attn_mask=self.attn_mask )[0] def forward(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None): x = x + self.drop_path(self.attention(self.ln_1(x), key_padding_mask=key_padding_mask)) x = x + self.drop_path(self.mlp(self.ln_2(x))) return x class Transformer(nn.Module): def __init__(self, context_length: int, vocab_size: int, width: int, layers: int, heads: int, drop_path: float = 0.0, autogressive: bool =True): super().__init__() self.token_embedding = nn.Embedding(vocab_size, width) self.context_length = context_length self.positional_embedding = nn.Parameter( torch.empty(self.context_length, width) ) self.width = width self.layers = layers self.autogressive = autogressive attn_mask = self.build_attention_mask() if autogressive else None dpr = [x.item() for x in torch.linspace(0, drop_path, layers)] # stochastic depth decay rule self.resblocks = nn.ModuleList( [ ResidualAttentionBlock(width, heads, attn_mask, dpr[i]) for i in range(layers) ] ) self.ln_final = LayerNorm(width) trunc_normal_(self.positional_embedding, std=.02) # nn.init.normal_(self.token_embedding, std=.02) trunc_normal_(self.token_embedding.weight, std=.02) self.apply(self._init_weights) @property def dim_out(self): return self.width def build_attention_mask(self): # lazily create causal attention mask, with full attention between the vision tokens # pytorch uses additive attention mask; fill with -inf mask = torch.empty(self.context_length, self.context_length) mask.fill_(float("-inf")) mask.triu_(1) # zero out the lower diagonal return mask def _init_weights(self, m): if isinstance(m, (nn.Linear, nn.Conv2d)): if is_main_process(): logger.info('=> init weight of Linear/Conv2d from trunc norm') 
trunc_normal_(m.weight, std=0.02) if m.bias is not None: if is_main_process(): logger.info('=> init bias of Linear/Conv2d to zeros') nn.init.constant_(m.bias, 0) elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)): nn.init.constant_(m.bias, 0) def load_pretrained(self, pretrained='', pretrained_layers=[], verbose=True): if os.path.isfile(pretrained): pretrained_dict = torch.load(pretrained, map_location='cpu') logging.info(f'=> loading pretrained model {pretrained}') model_dict = self.state_dict() stripped_key = lambda x: x[13:] if x.startswith('lang_encoder.') else x pretrained_dict = { stripped_key(k): v for k, v in pretrained_dict.items() if stripped_key(k) in model_dict.keys() } need_init_state_dict = {} for k, v in pretrained_dict.items(): need_init = ( k.split('.')[0] in pretrained_layers or pretrained_layers[0] == '*' ) if need_init: if verbose: logger.info(f'=> init {k} from {pretrained}') if 'positional_embedding' in k and v.size() != model_dict[k].size(): positional_embedding_pretrained = v positional_embedding_current = model_dict[k] L1, nH1 = positional_embedding_pretrained.size() L2, nH2 = positional_embedding_current.size() if nH1 != nH2: logger.info(f"Error in loading {k}, passing") else: if L1 != L2: logger.info( '=> load_pretrained: resized variant: {} to {}' .format((L1, nH1), (L2, nH2)) ) posemb = positional_embedding_pretrained.float() posemb_grid = posemb.unsqueeze(dim=0).permute(0, 2, 1) posemb_grid = torch.nn.functional.interpolate(posemb_grid, size=L2, mode='linear') posemb_grid = posemb_grid.permute(0, 2, 1).squeeze(dim=0) v = posemb_grid need_init_state_dict[k] = v self.load_state_dict(need_init_state_dict, strict=False) @torch.jit.ignore def no_weight_decay(self): return { 'positional_embedding', 'token_embedding', } def forward(self, input_ids, attention_mask=None): key_padding_mask = (attention_mask == 0) if (not self.autogressive and attention_mask is not None) else None # key_padding_mask = (input_ids == 0) if not self.autogressive 
else None x = self.token_embedding(input_ids) # [batch_size, n_ctx, d_model] x = x + self.positional_embedding x = x.permute(1, 0, 2) # NLD -> LND for block in self.resblocks: x = block(x, key_padding_mask) x = x.permute(1, 0, 2) # LND -> NLD x = self.ln_final(x) return {'last_hidden_state': x} @register_lang_encoder def lang_encoder(config_encoder, tokenizer, verbose, **kwargs): transformer = Transformer( context_length=config_encoder['CONTEXT_LENGTH'], vocab_size=tokenizer.vocab_size, width=config_encoder['WIDTH'], layers=config_encoder['LAYERS'], heads=config_encoder['HEADS'], autogressive=config_encoder.get('AUTOGRESSIVE', True) ) if config_encoder.get('LOAD_PRETRAINED', False): transformer.load_pretrained(config_encoder['PRETRAINED'], config_encoder.get('PRETRAINED_LAYERS', ['*'])) return transformer ================================================ FILE: llava/model/openseed/language/__init__.py ================================================ # from .vlpencoder import * # from .encoder import * # # from .loss import * # from .build import build_language_encoder ================================================ FILE: llava/model/openseed/language/build.py ================================================ from .registry import model_entrypoints from .registry import is_model def build_language_encoder(config, **kwargs): model_name = config['MODEL']['TEXT']['ARCH'] if not is_model(model_name): raise ValueError(f'Unkown model: {model_name}') return model_entrypoints(model_name)(config, **kwargs) ================================================ FILE: llava/model/openseed/language/encoder.py ================================================ import torch from torch import nn from torch.nn import functional as F from timm.models.layers import trunc_normal_ from .registry import register_model from ..utils import configurable from .LangEncoder import build_tokenizer, build_lang_encoder from utils.prompt_engineering import prompt_engineering, get_prompt_templates class 
LanguageEncoder(nn.Module): @configurable def __init__( self, tokenizer, tokenizer_type, lang_encoder, lang_projection, max_token_num, ): super().__init__() self.tokenizer = tokenizer self.tokenizer_type = tokenizer_type self.lang_encoder = lang_encoder self.lang_proj = lang_projection self.max_token_num = max_token_num self.logit_scale = nn.Parameter(torch.ones([])) @classmethod def from_config(cls, cfg): # build up text encoder tokenizer = build_tokenizer(cfg['MODEL']['TEXT']) tokenizer_type = cfg['MODEL']['TEXT']['TOKENIZER'] lang_encoder = build_lang_encoder(cfg['MODEL']['TEXT'], tokenizer, cfg['VERBOSE']) max_token_num = cfg['MODEL']['TEXT']['CONTEXT_LENGTH'] dim_lang = cfg['MODEL']['TEXT']['WIDTH'] dim_projection = cfg['MODEL']['DIM_PROJ'] lang_projection = nn.Parameter(torch.empty(dim_lang, dim_projection)) trunc_normal_(lang_projection, std=.02) return { "tokenizer": tokenizer, "tokenizer_type": tokenizer_type, "lang_encoder": lang_encoder, "lang_projection": lang_projection, "max_token_num": max_token_num, } # @torch.no_grad() def get_text_embeddings(self, class_names, name='default', is_eval=False, add_bgd=False, prompt=True, norm=True): if not is_eval: if prompt: # randomly sample one template arbitary_concepts = [ prompt_engineering(class_names[label].replace('-other','').replace('-merged','').replace('-stuff',''), topk=10000, suffix='.') \ for label in range(len(class_names)) ] if add_bgd: arbitary_concepts.append("A background in coco.") else: arbitary_concepts = class_names input_ids = [] attention_masks = [] for txt in arbitary_concepts: tokens = self.tokenizer( txt, padding='max_length', truncation=True, max_length=self.max_token_num, return_tensors='pt' ) tokens['input_ids'].squeeze_() tokens['attention_mask'].squeeze_() input_ids.append(tokens['input_ids']) attention_masks.append(tokens['attention_mask']) arbitary_tokens = torch.stack(input_ids) arbitary_attention_masks = torch.stack(attention_masks) text_emb = 
self.forward_language((arbitary_tokens.cuda(), arbitary_attention_masks.cuda()), norm=norm) setattr(self, '{}_text_embeddings'.format(name), text_emb) else: with torch.no_grad(): def extract_mean_emb(txts): tokens = self.tokenizer( txts, padding='max_length', truncation=True, max_length=self.max_token_num, return_tensors='pt' ) clss_embedding = self.forward_language((tokens['input_ids'].cuda(), tokens['attention_mask'].cuda()), norm=norm) clss_embedding = clss_embedding.mean(dim=0) clss_embedding /= clss_embedding.norm() return clss_embedding templates = get_prompt_templates() clss_embeddings = [] for clss in class_names: txts = [template.format(clss.replace('-other','').replace('-merged','').replace('-stuff','')) for template in templates] clss_embeddings.append(extract_mean_emb(txts)) if add_bgd: txts = ["A background in coco."] clss_embeddings.append(extract_mean_emb(txts)) text_emb = torch.stack(clss_embeddings, dim=0) setattr(self, '{}_text_embeddings'.format(name), text_emb) # @torch.no_grad() def forward_language(self, texts, norm=True): x = self.lang_encoder(*texts) x = x['last_hidden_state'] if self.tokenizer_type == 'clip': x = x[torch.arange(x.size(0)), texts[0].argmax(dim=-1)] else: x = x[:, 0] x = x @ self.lang_proj if norm: x = x / (x.norm(dim=-1, keepdim=True) + 1e-7) return x def compute_similarity(self, v_emb, name='default'): v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7) t_emb = getattr(self, '{}_text_embeddings'.format(name)) output = self.logit_scale.exp() * v_emb @ t_emb.unsqueeze(0).transpose(1, 2) return output @register_model def get_language_model(cfg, **kwargs): return LanguageEncoder(cfg) ================================================ FILE: llava/model/openseed/language/registry.py ================================================ _model_entrypoints = {} def register_model(fn): module_name_split = fn.__module__.split('.') model_name = module_name_split[-1] _model_entrypoints[model_name] = fn return fn def 
model_entrypoints(model_name): return _model_entrypoints[model_name] def is_model(model_name): return model_name in _model_entrypoints ================================================ FILE: llava/model/openseed/language/vlpencoder.py ================================================ # -------------------------------------------------------- # X-Decoder -- Generalized Decoding for Pixel, Image, and Language # Copyright (c) 2022 Microsoft # Licensed under The MIT License [see LICENSE for details] # Written by Xueyan Zou (xueyan@cs.wisc.edu) # -------------------------------------------------------- import torch from torch import nn from torch.nn import functional as F from timm.models.layers import trunc_normal_ from .registry import register_model from ..utils import configurable from .LangEncoder import build_tokenizer, build_lang_encoder from utils.prompt_engineering import prompt_engineering, get_prompt_templates class LanguageEncoder(nn.Module): @configurable def __init__( self, tokenizer, tokenizer_type, lang_encoder, lang_projection, max_token_num, queue_operator, ): super().__init__() # seg self.tokenizer = tokenizer self.tokenizer_type = tokenizer_type self.lang_encoder = lang_encoder self.lang_proj = lang_projection self.max_token_num = max_token_num self.logit_scale = nn.Parameter(torch.ones([])) # captioning & retrieval for key, value in queue_operator.items(): self.register_buffer(key, value) @classmethod def from_config(cls, cfg): # build up text encoder for seg tokenizer = build_tokenizer(cfg['MODEL']['TEXT']) tokenizer_type = cfg['MODEL']['TEXT']['TOKENIZER'] lang_encoder = build_lang_encoder(cfg['MODEL']['TEXT'], tokenizer, cfg['VERBOSE']) max_token_num = cfg['MODEL']['TEXT']['CONTEXT_LENGTH'] dim_lang = cfg['MODEL']['TEXT']['WIDTH'] dim_projection = cfg['MODEL']['DIM_PROJ'] lang_projection = nn.Parameter(torch.empty(dim_lang, dim_projection)) trunc_normal_(lang_projection, std=.02) # tested not working better queue_operator = {} return { "tokenizer": 
tokenizer, "tokenizer_type": tokenizer_type, "lang_encoder": lang_encoder, "lang_projection": lang_projection, "max_token_num": max_token_num, "queue_operator": queue_operator, } def get_text_embeddings(self, class_names, name='default', is_eval=False, add_bgd=False, prompt=True, norm=True): if not is_eval: if prompt: # randomly sample one template arbitary_concepts = [ prompt_engineering(class_names[label].replace('-other','').replace('-merged','').replace('-stuff',''), topk=10000, suffix='.') \ for label in range(len(class_names)) ] if add_bgd: arbitary_concepts.append("A background in coco.") else: arbitary_concepts = class_names input_ids = [] attention_masks = [] for txt in arbitary_concepts: tokens = self.tokenizer( txt, padding='max_length', truncation=True, max_length=self.max_token_num, return_tensors='pt' ) tokens['input_ids'].squeeze_() tokens['attention_mask'].squeeze_() input_ids.append(tokens['input_ids']) attention_masks.append(tokens['attention_mask']) arbitary_tokens = torch.stack(input_ids) arbitary_attention_masks = torch.stack(attention_masks) text_emb = self.forward_language((arbitary_tokens.cuda(), arbitary_attention_masks.cuda()), norm=norm) setattr(self, '{}_text_embeddings'.format(name), text_emb) else: with torch.no_grad(): def extract_mean_emb(txts): tokens = self.tokenizer( txts, padding='max_length', truncation=True, max_length=self.max_token_num, return_tensors='pt' ) clss_embedding = self.forward_language((tokens['input_ids'].cuda(), tokens['attention_mask'].cuda()), norm=norm) clss_embedding = clss_embedding.mean(dim=0) clss_embedding /= clss_embedding.norm() return clss_embedding templates = get_prompt_templates() clss_embeddings = [] if prompt: for clss in class_names: txts = [template.format(clss.replace('-other','').replace('-merged','').replace('-stuff','')) for template in templates] clss_embeddings.append(extract_mean_emb(txts)) else: clss_embeddings.append(extract_mean_emb(class_names)) if add_bgd: txts = ["A background in 
coco."] clss_embeddings.append(extract_mean_emb(txts)) text_emb = torch.stack(clss_embeddings, dim=0) setattr(self, '{}_text_embeddings'.format(name), text_emb) def get_text_token_embeddings(self, txts, name='default', token=False, norm=False): if not token: tokens = self.tokenizer( txts, padding='max_length', truncation=True, max_length=self.max_token_num, return_tensors='pt' ) tokens = {key: value.cuda() for key, value in tokens.items()} else: tokens = txts token_emb, class_emb = self.forward_language_token((tokens['input_ids'], tokens['attention_mask']), norm=norm) ret = {"tokens": tokens, "token_emb": token_emb, "class_emb": class_emb,} setattr(self, '{}_token_embeddings'.format(name), ret) return ret def forward_language(self, texts, norm=True): x = self.lang_encoder(*texts) x = x['last_hidden_state'] if self.tokenizer_type == 'clip': x = x[torch.arange(x.size(0)), texts[0].argmax(dim=-1)] else: x = x[:, 0] x = x @ self.lang_proj if norm: x = x / (x.norm(dim=-1, keepdim=True) + 1e-7) return x def forward_language_token(self, texts, norm=False): x = self.lang_encoder(*texts) token_x = x['last_hidden_state'] if self.tokenizer_type == 'clip': class_x = token_x[torch.arange(token_x.size(0)), texts[0].argmax(dim=-1)] else: class_x = token_x[:, 0] class_x = class_x @ self.lang_proj token_x = token_x @ self.lang_proj if norm: class_x = class_x / (class_x.norm(dim=-1, keepdim=True) + 1e-7) token_x = token_x / (token_x.norm(dim=-1, keepdim=True) + 1e-7) return token_x, class_x def compute_similarity(self, v_emb, name='default', fake=False): if fake: return None v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7) t_emb = getattr(self, '{}_text_embeddings'.format(name)) output = self.logit_scale.exp() * v_emb @ t_emb.unsqueeze(0).transpose(1, 2) return output @register_model def get_language_model(cfg, **kwargs): return LanguageEncoder(cfg) ================================================ FILE: llava/model/openseed/modules/__init__.py 
================================================
from .point_features import *
from .position_encoding import *
from .postprocessing import *
from .attention import *
from .matcher import *
from .criterion import *

================================================
FILE: llava/model/openseed/modules/attention.py
================================================
import warnings
from typing import Optional, Tuple

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
from torch.nn.parameter import Parameter
from torch.overrides import has_torch_function, handle_torch_function
from torch.nn.functional import pad, linear, softmax, dropout


def multi_head_attention_forward(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    embed_dim_to_check: int,
    num_heads: int,
    in_proj_weight: Tensor,
    in_proj_bias: Tensor,
    bias_k: Optional[Tensor],
    bias_v: Optional[Tensor],
    add_zero_attn: bool,
    dropout_p: float,
    out_proj_weight: Tensor,
    out_proj_bias: Tensor,
    training: bool = True,
    key_padding_mask: Optional[Tensor] = None,
    need_weights: bool = True,
    attn_mask: Optional[Tensor] = None,
    use_separate_proj_weight: bool = False,
    q_proj_weight: Optional[Tensor] = None,
    k_proj_weight: Optional[Tensor] = None,
    v_proj_weight: Optional[Tensor] = None,
    static_k: Optional[Tensor] = None,
    static_v: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
    r"""Functional multi-head attention (sequence-first layout).

    This appears to be adapted from PyTorch's
    ``torch.nn.functional.multi_head_attention_forward``. Visible local
    changes: the ``key_padding_mask.size(0) == bsz`` assert is commented out,
    and the padding mask is broadcast with a single ``unsqueeze(1)`` (see the
    NOTE near the end of the function).

    Args:
        query, key, value: map a query and a set of key-value pairs to an output.
            See "Attention Is All You Need" for more details.
        embed_dim_to_check: total dimension of the model.
        num_heads: parallel attention heads.
        in_proj_weight, in_proj_bias: input projection weight and bias.
        bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
        add_zero_attn: add a new batch of zeros to the key and
                       value sequences at dim=1.
        dropout_p: probability of an element to be zeroed.
        out_proj_weight, out_proj_bias: the output projection weight and bias.
        training: apply dropout if is ``True``.
        key_padding_mask: if provided, specified padding elements in the key will
            be ignored by the attention. This is a binary mask. When the value is True,
            the corresponding value on the attention layer will be filled with -inf.
        need_weights: output attn_output_weights.
        attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
            the batches while a 3D mask allows to specify a different mask for the entries of each batch.
        use_separate_proj_weight: the function accept the proj. weights for query, key,
            and value in different forms. If false, in_proj_weight will be used, which is
            a combination of q_proj_weight, k_proj_weight, v_proj_weight.
        q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
        static_k, static_v: static key and value used for attention operators.


    Shape:
        Inputs:
        - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
          the embedding dimension.
        - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
          If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
          will be unchanged. If a BoolTensor is provided, the positions with the
          value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
          (NOTE(review): this is the upstream contract; given this copy's single
          ``unsqueeze(1)``, verify the shape callers actually pass.)
        - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
          3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
          S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
          positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
          while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
          are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
          is provided, it will be added to the attention weight.
        - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
          N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
        - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
          N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.

        Outputs:
        - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
          E is the embedding dimension.
        - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
          L is the target sequence length, S is the source sequence length.
    """
    # Defer to __torch_function__ overrides on any tensor-like argument.
    tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias)
    if has_torch_function(tens_ops):
        return handle_torch_function(
            multi_head_attention_forward,
            tens_ops,
            query,
            key,
            value,
            embed_dim_to_check,
            num_heads,
            in_proj_weight,
            in_proj_bias,
            bias_k,
            bias_v,
            add_zero_attn,
            dropout_p,
            out_proj_weight,
            out_proj_bias,
            training=training,
            key_padding_mask=key_padding_mask,
            need_weights=need_weights,
            attn_mask=attn_mask,
            use_separate_proj_weight=use_separate_proj_weight,
            q_proj_weight=q_proj_weight,
            k_proj_weight=k_proj_weight,
            v_proj_weight=v_proj_weight,
            static_k=static_k,
            static_v=static_v,
        )
    tgt_len, bsz, embed_dim = query.size()
    assert embed_dim == embed_dim_to_check
    # allow MHA to have different sizes for the feature dimension
    assert key.size(0) == value.size(0) and key.size(1) == value.size(1)

    head_dim = embed_dim // num_heads
    assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
    scaling = float(head_dim) ** -0.5

    # --- Input projections -------------------------------------------------
    # Packed in_proj_weight rows are [q; k; v], each embed_dim rows tall.
    if not use_separate_proj_weight:
        if (query is key or torch.equal(query, key)) and (key is value or torch.equal(key, value)):
            # self-attention
            q, k, v = linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1)

        elif key is value or torch.equal(key, value):
            # encoder-decoder attention
            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = 0
            _end = embed_dim
            _w = in_proj_weight[_start:_end, :]
            if _b is not None:
                _b = _b[_start:_end]
            q = linear(query, _w, _b)

            if key is None:
                assert value is None
                k = None
                v = None
            else:

                # This is inline in_proj function with in_proj_weight and in_proj_bias
                _b = in_proj_bias
                _start = embed_dim
                _end = None
                _w = in_proj_weight[_start:, :]
                if _b is not None:
                    _b = _b[_start:]
                k, v = linear(key, _w, _b).chunk(2, dim=-1)

        else:
            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = 0
            _end = embed_dim
            _w = in_proj_weight[_start:_end, :]
            if _b is not None:
                _b = _b[_start:_end]
            q = linear(query, _w, _b)

            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = embed_dim
            _end = embed_dim * 2
            _w = in_proj_weight[_start:_end, :]
            if _b is not None:
                _b = _b[_start:_end]
            k = linear(key, _w, _b)

            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = embed_dim * 2
            _end = None
            _w = in_proj_weight[_start:, :]
            if _b is not None:
                _b = _b[_start:]
            v = linear(value, _w, _b)
    else:
        # Separate q/k/v weights (used when kdim/vdim differ from embed_dim).
        q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight)
        len1, len2 = q_proj_weight_non_opt.size()
        assert len1 == embed_dim and len2 == query.size(-1)

        k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight)
        len1, len2 = k_proj_weight_non_opt.size()
        assert len1 == embed_dim and len2 == key.size(-1)

        v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight)
        len1, len2 = v_proj_weight_non_opt.size()
        assert len1 == embed_dim and len2 == value.size(-1)

        if in_proj_bias is not None:
            q = linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim])
            k = linear(key, k_proj_weight_non_opt, in_proj_bias[embed_dim : (embed_dim * 2)])
            v = linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2) :])
        else:
            q = linear(query, q_proj_weight_non_opt, in_proj_bias)
            k = linear(key, k_proj_weight_non_opt, in_proj_bias)
            v = linear(value, v_proj_weight_non_opt, in_proj_bias)
    # Scale queries by 1/sqrt(head_dim) before the dot product.
    q = q * scaling

    # --- Mask validation/normalization ---------------------------------------
    if attn_mask is not None:
        assert (
            attn_mask.dtype == torch.float32
            or attn_mask.dtype == torch.float64
            or attn_mask.dtype == torch.float16
            or attn_mask.dtype == torch.uint8
            or attn_mask.dtype == torch.bool
        ), "Only float, byte, and bool types are supported for attn_mask, not {}".format(attn_mask.dtype)
        if attn_mask.dtype == torch.uint8:
            warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
            attn_mask = attn_mask.to(torch.bool)

        if attn_mask.dim() == 2:
            attn_mask = attn_mask.unsqueeze(0)
            if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
                raise RuntimeError("The size of the 2D attn_mask is not correct.")
        elif attn_mask.dim() == 3:
            if list(attn_mask.size()) != [bsz * num_heads, query.size(0), key.size(0)]:
                raise RuntimeError("The size of the 3D attn_mask is not correct.")
        else:
            raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim()))
        # attn_mask's dim is 3 now.

    # convert ByteTensor key_padding_mask to bool
    if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
        warnings.warn(
            "Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead."
        )
        key_padding_mask = key_padding_mask.to(torch.bool)

    # Optionally append learned bias_k/bias_v as one extra key/value position
    # (masks are padded by one column to match).
    if bias_k is not None and bias_v is not None:
        if static_k is None and static_v is None:
            k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = pad(attn_mask, (0, 1))
            if key_padding_mask is not None:
                key_padding_mask = pad(key_padding_mask, (0, 1))
        else:
            assert static_k is None, "bias cannot be added to static key."
            assert static_v is None, "bias cannot be added to static value."
    else:
        assert bias_k is None
        assert bias_v is None

    # Split heads: (len, bsz, embed_dim) -> (bsz * num_heads, len, head_dim).
    q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    if k is not None:
        k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
    if v is not None:
        v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)

    if static_k is not None:
        assert static_k.size(0) == bsz * num_heads
        assert static_k.size(2) == head_dim
        k = static_k

    if static_v is not None:
        assert static_v.size(0) == bsz * num_heads
        assert static_v.size(2) == head_dim
        v = static_v

    src_len = k.size(1)

    if key_padding_mask is not None:
        # assert key_padding_mask.size(0) == bsz
        assert key_padding_mask.size(1) == src_len

    if add_zero_attn:
        src_len += 1
        k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1)
        v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1)
        if attn_mask is not None:
            attn_mask = pad(attn_mask, (0, 1))
        if key_padding_mask is not None:
            key_padding_mask = pad(key_padding_mask, (0, 1))

    # --- Scaled dot-product attention ------------------------------------
    attn_output_weights = torch.bmm(q, k.transpose(1, 2))
    assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_output_weights.masked_fill_(attn_mask, float("-inf"))
        else:
            attn_output_weights += attn_mask

    if key_padding_mask is not None:
        # NOTE(review): upstream PyTorch broadcasts a (bsz, src_len) mask with
        # unsqueeze(1).unsqueeze(2). Here only unsqueeze(1) is applied, so a
        # 2D mask of shape (bsz, src_len) would line bsz up against the
        # num_heads axis of the (bsz, num_heads, tgt_len, src_len) view.
        # Together with the disabled bsz assert above, this suggests callers
        # pass a different (possibly 3D, per-query) mask. Confirm the shape.
        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
        attn_output_weights = attn_output_weights.masked_fill(
            key_padding_mask.unsqueeze(1),
            float("-inf"),
        )
        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)

    attn_output_weights = softmax(attn_output_weights, dim=-1)
    attn_output_weights = dropout(attn_output_weights, p=dropout_p, training=training)

    attn_output = torch.bmm(attn_output_weights, v)
    assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
    # Merge heads back: (bsz * num_heads, tgt_len, head_dim) -> (tgt_len, bsz, embed_dim).
    attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
    attn_output = linear(attn_output, out_proj_weight, out_proj_bias)

    if need_weights:
        # average attention weights over heads
        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
        return attn_output, attn_output_weights.sum(dim=1) / num_heads
    else:
        return attn_output, None


# This class exists solely for Transformer; it has an annotation stating
# that bias is never None, which appeases TorchScript
class _LinearWithBias(nn.Linear):
    bias: Tensor  # type: ignore

    def __init__(self, in_features: int, out_features: int) -> None:
        super().__init__(in_features, out_features, bias=True)  # type: ignore


class MultiheadAttention(nn.Module):
    r"""Allows the model to jointly attend to information
    from different representation subspaces.
    See `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_

    .. math::
        \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O

    where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.

    Unlike ``torch.nn.MultiheadAttention``, ``forward`` here calls the
    ``multi_head_attention_forward`` defined in this module.

    Args:
        embed_dim: total dimension of the model.
        num_heads: parallel attention heads.
        dropout: a Dropout layer on attn_output_weights. Default: 0.0.
        bias: add bias as module parameter. Default: True.
        add_bias_kv: add bias to the key and value sequences at dim=0.
        add_zero_attn: add a new batch of zeros to the key and
                       value sequences at dim=1.
        kdim: total number of features in key. Default: None.
        vdim: total number of features in value. Default: None.

        Note that if :attr:`kdim` and :attr:`vdim` are None, they will be set
        to :attr:`embed_dim` such that query, key, and value have the same
        number of features.

    Examples::

        >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
        >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
    """
    bias_k: Optional[torch.Tensor]
    bias_v: Optional[torch.Tensor]

    def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None):
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        # Separate q/k/v weights when key/value widths differ; otherwise one packed matrix.
        if self._qkv_same_embed_dim is False:
            self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
            self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
            self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
            self.register_parameter('in_proj_weight', None)
        else:
            self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
            self.register_parameter('q_proj_weight', None)
            self.register_parameter('k_proj_weight', None)
            self.register_parameter('v_proj_weight', None)

        if bias:
            self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
        else:
            self.register_parameter('in_proj_bias', None)
        self.out_proj = _LinearWithBias(embed_dim, embed_dim)

        if add_bias_kv:
            self.bias_k = Parameter(torch.empty(1, 1, embed_dim))
            self.bias_v = Parameter(torch.empty(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self._reset_parameters()

    def _reset_parameters(self):
        # Xavier-uniform projection weights, zero biases, Xavier-normal bias_k/bias_v.
        if self._qkv_same_embed_dim:
            xavier_uniform_(self.in_proj_weight)
        else:
            xavier_uniform_(self.q_proj_weight)
            xavier_uniform_(self.k_proj_weight)
            xavier_uniform_(self.v_proj_weight)

        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.)
            constant_(self.out_proj.bias, 0.)
        if self.bias_k is not None:
            xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            xavier_normal_(self.bias_v)

    def __setstate__(self, state):
        # Support loading old MultiheadAttention checkpoints generated by v1.1.0
        if '_qkv_same_embed_dim' not in state:
            state['_qkv_same_embed_dim'] = True

        super(MultiheadAttention, self).__setstate__(state)

    def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: Optional[Tensor] = None,
                need_weights: bool = True, attn_mask: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
        r"""
    Args:
        query, key, value: map a query and a set of key-value pairs to an output.
            See "Attention Is All You Need" for more details.
        key_padding_mask: if provided, specified padding elements in the key will
            be ignored by the attention. When given a binary mask and a value is True,
            the corresponding value on the attention layer will be ignored. When given
            a byte mask and a value is non-zero, the corresponding value on the attention
            layer will be ignored.
        need_weights: output attn_output_weights.
        attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
            the batches while a 3D mask allows to specify a different mask for the entries of each batch.

    Shapes for inputs:
        - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
          the embedding dimension.
        - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
          If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
          will be unchanged. If a BoolTensor is provided, the positions with the
          value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
        - attn_mask: if a 2D mask: :math:`(L, S)` where L is the target sequence length, S is the
          source sequence length.

          If a 3D mask: :math:`(N\cdot\text{num\_heads}, L, S)` where N is the batch size, L is the target sequence
          length, S is the source sequence length. ``attn_mask`` ensures that position i is allowed to attend
          the unmasked positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
          while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
          are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
          is provided, it will be added to the attention weight.

    Shapes for outputs:
        - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
          E is the embedding dimension.
        - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
          L is the target sequence length, S is the source sequence length.
        """
        # Both branches call this module's multi_head_attention_forward; they
        # differ only in whether separate q/k/v projection weights are passed.
        if not self._qkv_same_embed_dim:
            return multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask, use_separate_proj_weight=True,
                q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
                v_proj_weight=self.v_proj_weight)
        else:
            return multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask)

================================================
FILE: llava/model/openseed/modules/criterion.py
================================================
# 
------------------------------------------------------------------------ # DINO # Copyright (c) 2023 IDEA. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------ # Modified from MaskDINO https://github.com/IDEA-Research/MaskDINO by Hao Zhang and Feng Li. # ------------------------------------------------------------------------ """ MaskFormer criterion. """ import logging import torch import torch.nn.functional as F from torch import nn from timm.loss import SoftTargetCrossEntropy from detectron2.utils.comm import get_world_size from detectron2.projects.point_rend.point_features import ( get_uncertain_point_coords_with_randomness, point_sample, ) from ..utils.misc import is_dist_avail_and_initialized, nested_tensor_from_tensor_list, _max_by_axis from ..utils import box_ops import random def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2): """ Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. Args: inputs: A float tensor of arbitrary shape. The predictions for each example. targets: A float tensor with the same shape as inputs. Stores the binary classification label for each element in inputs (0 for the negative class and 1 for the positive class). alpha: (optional) Weighting factor in range (0,1) to balance positive vs negative examples. Default = -1 (no weighting). gamma: Exponent of the modulating factor (1 - p_t) to balance easy vs hard examples. Returns: Loss tensor """ prob = inputs.sigmoid() ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") p_t = prob * targets + (1 - prob) * (1 - targets) loss = ce_loss * ((1 - p_t) ** gamma) if alpha >= 0: alpha_t = alpha * targets + (1 - alpha) * (1 - targets) loss = alpha_t * loss return loss.mean(-1).mean(-1).sum()*10. 
/ num_boxes def dice_loss( inputs: torch.Tensor, targets: torch.Tensor, num_masks: float, ): """ Compute the DICE loss, similar to generalized IOU for masks Args: inputs: A float tensor of arbitrary shape. The predictions for each example. targets: A float tensor with the same shape as inputs. Stores the binary classification label for each element in inputs (0 for the negative class and 1 for the positive class). """ inputs = inputs.sigmoid() inputs = inputs.flatten(1) numerator = 2 * (inputs * targets).sum(-1) denominator = inputs.sum(-1) + targets.sum(-1) loss = 1 - (numerator + 1) / (denominator + 1) return loss.sum() / num_masks dice_loss_jit = torch.jit.script( dice_loss ) # type: torch.jit.ScriptModule def sigmoid_ce_loss( inputs: torch.Tensor, targets: torch.Tensor, num_masks: float, ): """ Args: inputs: A float tensor of arbitrary shape. The predictions for each example. targets: A float tensor with the same shape as inputs. Stores the binary classification label for each element in inputs (0 for the negative class and 1 for the positive class). Returns: Loss tensor """ loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") return loss.mean(1).sum() / num_masks sigmoid_ce_loss_jit = torch.jit.script( sigmoid_ce_loss ) # type: torch.jit.ScriptModule def calculate_uncertainty(logits): """ We estimate uncerainty as L1 distance between 0.0 and the logit prediction in 'logits' for the foreground class in `classes`. Args: logits (Tensor): A tensor of shape (R, 1, ...) for class-specific or class-agnostic, where R is the total number of predicted masks in all images and C is the number of foreground classes. The values are logits. Returns: scores (Tensor): A tensor of shape (R, 1, ...) that contains uncertainty scores with the most uncertain locations having the highest uncertainty score. 
""" assert logits.shape[1] == 1 gt_class_logits = logits.clone() return -(torch.abs(gt_class_logits)) class SetCriterion(nn.Module): """This class computes the loss for DETR. The process happens in two steps: 1) we compute hungarian assignment between ground truth boxes and the outputs of the model 2) we supervise each pair of matched ground-truth / prediction (supervise class and box) """ def __init__(self, num_classes, matcher, weight_dict, eos_coef, top_x_layers, losses, num_points, oversample_ratio, importance_sample_ratio, grounding_weight, dn="no",dn_losses=[], panoptic_on=False, semantic_ce_loss=False): """Create the criterion. Parameters: num_classes: number of object categories, omitting the special no-object category matcher: module able to compute a matching between targets and proposals weight_dict: dict containing as key the names of the losses and as values their relative weight. eos_coef: relative classification weight applied to the no-object category losses: list of all the losses to be applied. See get_loss for list of available losses. 
""" super().__init__() self.num_classes = num_classes self.matcher = matcher self.weight_dict = weight_dict self.eos_coef = eos_coef self.top_x_layers = top_x_layers self.losses = losses self.dn = dn self.dn_losses = dn_losses empty_weight = torch.ones(self.num_classes + 1) empty_weight[-1] = self.eos_coef self.register_buffer("empty_weight", empty_weight) # pointwise mask loss parameters self.num_points = num_points self.oversample_ratio = oversample_ratio self.importance_sample_ratio = importance_sample_ratio self.focal_alpha = 0.25 self.panoptic_on = panoptic_on self.semantic_ce_loss = semantic_ce_loss self.grounding_weight = grounding_weight self.conversation=False def loss_labels_ce(self, outputs, targets, indices, num_masks, layer_id=None, extra=None): """Classification loss (NLL) targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes] """ if layer_id > self.top_x_layers['mask']: return {"loss_mask_cls_0": 0} assert "pred_logits" in outputs if indices is None or len(targets) == 0: loss_ce = outputs['pred_logits'].sum() * 0.0 losses = {"loss_mask_cls_0": loss_ce} return losses src_logits = outputs["pred_logits"].type(self.empty_weight.dtype) idx = self._get_src_permutation_idx(indices) target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)]) target_classes = torch.full( src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device ) target_classes[idx] = target_classes_o loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight) losses = {"loss_mask_cls_0": loss_ce} return losses def loss_labels_masked(self, outputs, targets, indices, num_boxes, log=True, layer_id=None, extra=None): """Classification loss (Binary focal loss) targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes] """ if layer_id > self.top_x_layers['mask']: return {"loss_mask_cls_0": 0} assert 'pred_logits' in outputs if indices is None or 
len(targets) == 0: loss_ce = outputs['pred_logits'].sum() * 0.0 losses = {"loss_mask_cls_0": loss_ce} return losses src_logits = outputs['pred_logits'] idx = self._get_src_permutation_idx(indices) target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)]) target_classes = torch.full(src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device) target_classes[idx] = target_classes_o target_classes_onehot = torch.zeros([src_logits.shape[0], src_logits.shape[1], src_logits.shape[2]+1], dtype=src_logits.dtype, layout=src_logits.layout, device=src_logits.device) target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1) target_classes_onehot = target_classes_onehot[:,:,:-1] loss_ce = sigmoid_focal_loss(src_logits[idx], target_classes_onehot[idx], num_boxes, alpha=self.focal_alpha, gamma=2) * src_logits.shape[1] losses = {'loss_mask_cls_0': loss_ce} return losses def loss_labels(self, outputs, targets, indices, num_boxes, log=True, layer_id=None, extra=None): """Classification loss (Binary focal loss) targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes] """ if layer_id > self.top_x_layers['mask']: return {"loss_mask_cls_0": 0} assert 'pred_logits' in outputs if indices is None or len(targets) == 0: loss_ce = outputs['pred_logits'].sum() * 0.0 losses = {"loss_mask_cls_0": loss_ce} return losses src_logits = outputs['pred_logits'] idx = self._get_src_permutation_idx(indices) # target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)]) # target_classes = torch.full(src_logits.shape[:2], self.num_classes, # dtype=torch.int64, device=src_logits.device) # target_classes[idx] = target_classes_o target_classes_onehot = torch.zeros([src_logits.shape[0], src_logits.shape[1], src_logits.shape[2]], dtype=src_logits.dtype, layout=src_logits.layout, device=src_logits.device) for batch_id,indices_ in enumerate(indices): for src,tgt in zip(*indices_): 
gt_lbs=targets[batch_id]['labels'][tgt] target_classes_onehot[batch_id,src,gt_lbs]=1 # target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1) # target_classes_onehot = target_classes_onehot[:,:,:-1] loss_ce = sigmoid_focal_loss(src_logits, target_classes_onehot, num_boxes, alpha=self.focal_alpha, gamma=2) * src_logits.shape[1] losses = {'loss_mask_cls_0': loss_ce} return losses def loss_boxes(self, outputs, targets, indices, num_boxes, layer_id=None, extra=None): """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4] The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size. """ if layer_id >= self.top_x_layers['box']: return {"loss_bbox_0": 0, "loss_giou_0": 0} assert 'pred_boxes' in outputs if indices is None or len(targets) == 0: loss = outputs['pred_boxes'].sum() * 0.0 losses = {"loss_bbox_0": loss, "loss_giou_0": loss} return losses idx = self._get_src_permutation_idx(indices) src_boxes = outputs['pred_boxes'][idx] target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0) loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none') losses = {} losses['loss_bbox_0'] = loss_bbox.sum() / num_boxes loss_giou = 1 - torch.diag(box_ops.generalized_box_iou( box_ops.box_cxcywh_to_xyxy(src_boxes), box_ops.box_cxcywh_to_xyxy(target_boxes))) losses['loss_giou_0'] = loss_giou.sum() / num_boxes return losses def loss_boxes_panoptic(self, outputs, targets, indices, num_boxes, layer_id=None, extra=None): """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4] The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size. 
""" if layer_id >= self.top_x_layers['box']: return {"loss_bbox_0": 0, "loss_giou_0": 0} assert 'pred_boxes' in outputs if indices is None or len(targets) == 0: loss = outputs['pred_boxes'].sum() * 0.0 losses = {"loss_bbox_0": loss, "loss_giou_0": loss} return losses idx = self._get_src_permutation_idx(indices) src_boxes = outputs['pred_boxes'][idx] target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0) target_labels = torch.cat([t['labels'][i] for t, (_, i) in zip(targets, indices)], dim=0) isthing=target_labels<80 target_boxes=target_boxes[isthing] src_boxes=src_boxes[isthing] loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none') losses = {} losses['loss_bbox_0'] = loss_bbox.sum() / num_boxes loss_giou = 1 - torch.diag(box_ops.generalized_box_iou( box_ops.box_cxcywh_to_xyxy(src_boxes), box_ops.box_cxcywh_to_xyxy(target_boxes))) losses['loss_giou_0'] = loss_giou.sum() / num_boxes return losses def loss_masks(self, outputs, targets, indices, num_masks, layer_id=None, extra=None): """Compute the losses related to the masks: the focal loss and the dice loss. 
targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w] """ if layer_id >= self.top_x_layers['mask']: return {"loss_mask_bce_0": 0, "loss_mask_dice_0": 0} assert "pred_masks" in outputs if indices is None or len(targets) == 0: loss = outputs['pred_masks'].sum() * 0.0 losses = {"loss_mask_bce_0": loss, "loss_mask_dice_0": loss} return losses src_idx = self._get_src_permutation_idx(indices) tgt_idx = self._get_tgt_permutation_idx(indices) src_masks = outputs["pred_masks"] src_masks = src_masks[src_idx] masks = [t["masks"] for t in targets] # TODO use valid to mask invalid areas due to padding in loss target_masks, valid = nested_tensor_from_tensor_list(masks).decompose() target_masks = target_masks.to(src_masks) target_masks = target_masks[tgt_idx] # No need to upsample predictions as we are using normalized coordinates :) # N x 1 x H x W src_masks = src_masks[:, None] target_masks = target_masks[:, None] with torch.no_grad(): # sample point_coords point_coords = get_uncertain_point_coords_with_randomness( src_masks.float(), lambda logits: calculate_uncertainty(logits.float()), self.num_points, self.oversample_ratio, self.importance_sample_ratio, ) # get gt labels point_labels = point_sample( target_masks.float(), point_coords.float(), align_corners=False, ).squeeze(1) point_logits = point_sample( src_masks.float(), point_coords.float(), align_corners=False, ).squeeze(1) losses = { "loss_mask_bce_0": sigmoid_ce_loss_jit(point_logits, point_labels, num_masks), "loss_mask_dice_0": dice_loss_jit(point_logits, point_labels, num_masks), } del src_masks del target_masks return losses def prep_for_dn(self,mask_dict): output_known_lbs_bboxes = mask_dict['output_known_lbs_bboxes'] known_indice = mask_dict['known_indice'] scalar,pad_size=mask_dict['scalar'],mask_dict['pad_size'] assert pad_size % scalar==0 single_pad=pad_size//scalar num_tgt = known_indice.numel() return output_known_lbs_bboxes,num_tgt,single_pad,scalar def 
_get_src_permutation_idx(self, indices): # permute predictions following indices batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) src_idx = torch.cat([src for (src, _) in indices]) return batch_idx, src_idx def _get_tgt_permutation_idx(self, indices): # permute targets following indices batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]) tgt_idx = torch.cat([tgt for (_, tgt) in indices]) return batch_idx, tgt_idx def get_loss(self, loss, outputs, targets, indices, num_masks=None, layer_id=None, extra=None): loss_map = { 'labels': self.loss_labels_ce if self.semantic_ce_loss else self.loss_labels, 'labels_dn': self.loss_labels_ce if self.semantic_ce_loss else self.loss_labels_masked, 'dn_labels': self.loss_labels_ce if self.semantic_ce_loss else self.loss_labels_masked, 'masks': self.loss_masks, 'boxes': self.loss_boxes_panoptic if self.panoptic_on else self.loss_boxes, } assert loss in loss_map, f"do you really want to compute {loss} loss?" return loss_map[loss](outputs, targets, indices, num_masks, layer_id=layer_id, extra=extra) def forward(self, outputs, targets, mask_dict=None, extra=None, task='seg'): """This performs the loss computation. Parameters: outputs: dict of tensors, see the output specification of the model for the format targets: list of dicts, such that len(targets) == batch_size. 
The expected keys in each dict depends on the losses applied, see each loss' doc """ # TODO: use different matching and loss weight when only detection outputs_without_aux = {k: v for k, v in outputs.items() if k != "aux_outputs"} match_cost = ["cls", "box", "mask"] if task == 'det' or task == 'seg_from_teacher': match_cost = ["cls", "box"] # Retrieve the matching between the outputs of the last layer and the targets if self.dn != "no" and mask_dict is not None: output_known_lbs_bboxes,num_tgt,single_pad,scalar = self.prep_for_dn(mask_dict) exc_idx = [] for i in range(len(targets)): if len(targets[i]['labels']) > 0: t = torch.arange(0, len(targets[i]['labels'])).long().cuda() t = t.unsqueeze(0).repeat(scalar, 1) tgt_idx = t.flatten() output_idx = (torch.tensor(range(scalar)) * single_pad).long().cuda().unsqueeze(1) + t output_idx = output_idx.flatten() else: output_idx = tgt_idx = torch.tensor([]).long().cuda() exc_idx.append((output_idx, tgt_idx)) extra=dict() # if task == "seg": # extra['split_pano']={'n_q_th':300} # # else: # extra['split_pano'] = None indices = self.matcher(outputs_without_aux, targets, match_cost, extra=extra) # Compute the average number of target boxes accross all nodes, for normalization purposes num_masks = sum(len(t["labels"]) for t in targets) num_masks = torch.as_tensor( [num_masks], dtype=torch.float, device=next(iter(outputs.values())).device ) if is_dist_avail_and_initialized() and not self.conversation: torch.distributed.all_reduce(num_masks) num_masks = torch.clamp(num_masks / get_world_size(), min=1).item() else: num_masks = torch.clamp(num_masks, min=1).item() # Compute all the requested losses losses = {} for loss in self.losses: if task == 'det' and loss == 'masks': continue # not compute mask loss for detection data only losses.update(self.get_loss(loss, outputs, targets, indices, num_masks, layer_id=0, extra=extra)) if self.dn != "no" and mask_dict is not None: l_dict={} for loss in self.dn_losses: if task == 'det' and loss 
== 'masks': continue # not compute mask loss for detection data only if loss == 'labels': loss='labels_dn' l_dict.update(self.get_loss(loss, output_known_lbs_bboxes, targets, exc_idx, num_masks*scalar, layer_id=0)) l_dict = {k + f'_dn': v for k, v in l_dict.items()} losses.update(l_dict) elif self.dn != "no": l_dict = dict() l_dict['loss_bbox_0_dn'] = torch.as_tensor(0.).to('cuda') l_dict['loss_giou_0_dn'] = torch.as_tensor(0.).to('cuda') l_dict['loss_mask_cls_0_dn'] = torch.as_tensor(0.).to('cuda') if task != 'det' and 'masks' in self.dn_losses: l_dict['loss_mask_bce_0_dn'] = torch.as_tensor(0.).to('cuda') l_dict['loss_mask_dice_0_dn'] = torch.as_tensor(0.).to('cuda') losses.update(l_dict) # In case of auxiliary losses, we repeat this process with the output of each intermediate layer. if "aux_outputs" in outputs: for i, aux_outputs in enumerate(outputs["aux_outputs"]): indices = self.matcher(aux_outputs, targets, match_cost, extra=extra) for loss in self.losses: if task == 'det' and loss == 'masks': continue # not compute mask loss for detection data only l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_masks, layer_id=(i+1), extra=extra) l_dict = {k.replace('_0', f"_{i+1}"): v for k, v in l_dict.items()} losses.update(l_dict) if 'interm_outputs' in outputs: start = 0 else: start = 1 if i>=start: if self.dn != "no" and mask_dict is not None: out_=output_known_lbs_bboxes['aux_outputs'][i] l_dict = {} for loss in self.dn_losses: if task == 'det' and loss == 'masks': continue # not compute mask loss for detection data only if loss == 'labels': loss = 'labels_dn' l_dict.update( self.get_loss(loss, out_, targets, exc_idx, num_masks * scalar, layer_id=(i+1), extra=extra)) l_dict = {k.replace('_0', f"_{i+1}_dn"): v for k, v in l_dict.items()} losses.update(l_dict) elif self.dn != "no": l_dict = dict() l_dict[f'loss_bbox_{i+1}_dn'] = torch.as_tensor(0.).to('cuda') l_dict[f'loss_giou_{i+1}_dn'] = torch.as_tensor(0.).to('cuda') 
l_dict[f'loss_mask_cls_{i+1}_dn'] = torch.as_tensor(0.).to('cuda') if self.dn == "seg" and task != 'det' and 'masks' in self.dn_losses: l_dict[f'loss_mask_bce_{i+1}_dn'] = torch.as_tensor(0.).to('cuda') l_dict[f'loss_mask_dice_{i+1}_dn'] = torch.as_tensor(0.).to('cuda') losses.update(l_dict) # interm_outputs loss if 'interm_outputs' in outputs: interm_outputs = outputs['interm_outputs'] indices = self.matcher(interm_outputs, targets, match_cost, extra=extra) full_set = ['labels', 'masks', 'boxes'] for loss in list(set(self.losses) and set(full_set)): if task == 'det' and loss == 'masks': continue # not compute mask loss for detection data only l_dict = self.get_loss(loss, interm_outputs, targets, indices, num_masks, layer_id=-1, extra=extra) l_dict = {k + f'_interm': v for k, v in l_dict.items()} losses.update(l_dict) return losses def __repr__(self): head = "Criterion " + self.__class__.__name__ body = [ "matcher: {}".format(self.matcher.__repr__(_repr_indent=8)), "losses: {}".format(self.losses), "weight_dict: {}".format(self.weight_dict), "num_classes: {}".format(self.num_classes), "eos_coef: {}".format(self.eos_coef), "num_points: {}".format(self.num_points), "oversample_ratio: {}".format(self.oversample_ratio), "importance_sample_ratio: {}".format(self.importance_sample_ratio), ] _repr_indent = 4 lines = [head] + [" " * _repr_indent + line for line in body] return "\n".join(lines) ================================================ FILE: llava/model/openseed/modules/matcher.py ================================================ # ------------------------------------------------------------------------ # DINO # Copyright (c) 2023 IDEA. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------ # Modified from MaskDINO https://github.com/IDEA-Research/MaskDINO by Hao Zhang and Feng Li. 
# ------------------------------------------------------------------------ """ Modules to compute the matching cost and solve the corresponding LSAP. """ import torch import torch.nn.functional as F import numpy as np from scipy.optimize import linear_sum_assignment from torch import nn from torch.cuda.amp import autocast from detectron2.projects.point_rend.point_features import point_sample from ..utils.box_ops import generalized_box_iou,box_cxcywh_to_xyxy def batch_dice_loss(inputs: torch.Tensor, targets: torch.Tensor): """ Compute the DICE loss, similar to generalized IOU for masks Args: inputs: A float tensor of arbitrary shape. The predictions for each example. targets: A float tensor with the same shape as inputs. Stores the binary classification label for each element in inputs (0 for the negative class and 1 for the positive class). """ inputs = inputs.sigmoid() inputs = inputs.flatten(1) numerator = 2 * torch.einsum("nc,mc->nm", inputs, targets) denominator = inputs.sum(-1)[:, None] + targets.sum(-1)[None, :] loss = 1 - (numerator + 1) / (denominator + 1) return loss batch_dice_loss_jit = torch.jit.script( batch_dice_loss ) # type: torch.jit.ScriptModule def batch_sigmoid_ce_loss(inputs: torch.Tensor, targets: torch.Tensor): """ Args: inputs: A float tensor of arbitrary shape. The predictions for each example. targets: A float tensor with the same shape as inputs. Stores the binary classification label for each element in inputs (0 for the negative class and 1 for the positive class). 
Returns: Loss tensor """ hw = inputs.shape[1] pos = F.binary_cross_entropy_with_logits( inputs, torch.ones_like(inputs), reduction="none" ) neg = F.binary_cross_entropy_with_logits( inputs, torch.zeros_like(inputs), reduction="none" ) loss = torch.einsum("nc,mc->nm", pos, targets) + torch.einsum( "nc,mc->nm", neg, (1 - targets) ) return loss / hw batch_sigmoid_ce_loss_jit = torch.jit.script( batch_sigmoid_ce_loss ) # type: torch.jit.ScriptModule class HungarianMatcher(nn.Module): """This class computes an assignment between the targets and the predictions of the network For efficiency reasons, the targets don't include the no_object. Because of this, in general, there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, while the others are un-matched (and thus treated as non-objects). """ def __init__(self, cost_class: float = 1, cost_mask: float = 1, cost_dice: float = 1, num_points: int = 0, cost_box: float = 0, cost_giou: float = 0, panoptic_on: bool = False): """Creates the matcher Params: cost_class: This is the relative weight of the classification error in the matching cost cost_mask: This is the relative weight of the focal loss of the binary mask in the matching cost cost_dice: This is the relative weight of the dice loss of the binary mask in the matching cost """ super().__init__() self.cost_class = cost_class self.cost_mask = cost_mask self.cost_dice = cost_dice self.cost_box = cost_box self.cost_giou = cost_giou self.panoptic_on = panoptic_on assert cost_class != 0 or cost_mask != 0 or cost_dice != 0, "all costs cant be 0" self.num_points = num_points @torch.no_grad() def memory_efficient_forward(self, outputs, targets, cost=["cls", "box", "mask"],split_pano=None): """More memory-friendly matching. 
Change cost to compute only certain loss in matching""" bs, num_queries = outputs["pred_logits"].shape[:2] indices = [] # Iterate through batch size for b in range(bs): out_bbox = outputs["pred_boxes"][b].float() if 'box' in cost: tgt_bbox=targets[b]["boxes"] cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1) cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox)) else: cost_bbox = torch.tensor(0).to(out_bbox) cost_giou = torch.tensor(0).to(out_bbox) out_prob = outputs["pred_logits"][b].sigmoid().float() # [num_queries, num_classes] tgt_ids = targets[b]["labels"] cost_class=torch.zeros_like(cost_bbox) alpha = 0.25 gamma = 2.0 neg_cost_class = (1 - alpha) * (out_prob ** gamma) * (-(1 - out_prob + 1e-6).log()) pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-6).log()) for idx, tgt_ids_ in enumerate(tgt_ids): if len(tgt_ids_) == 0: continue cost_class_tmp = pos_cost_class[:, tgt_ids_] - neg_cost_class[:, tgt_ids_] cost_class_tmp = cost_class_tmp.mean(dim=1, keepdim=False) cost_class[:, idx] = cost_class_tmp # Compute the classification cost. Contrary to the loss, we don't use the NLL, # but approximate it in 1 - proba[target class]. # The 1 is a constant that doesn't change the matching, it can be ommitted. # cost_class = -out_prob[:, tgt_ids] if 'mask' in cost: out_mask = outputs["pred_masks"][b].float() # [num_queries, H_pred, W_pred] # gt masks are already padded when preparing target tgt_mask = targets[b]["masks"].to(out_mask).float() out_mask = out_mask[:, None] tgt_mask = tgt_mask[:, None] # all masks share the same set of points for efficient matching! 
point_coords = torch.rand(1, self.num_points, 2, device=out_mask.device, dtype=tgt_mask.dtype) # get gt labels tgt_mask = point_sample( tgt_mask.float(), point_coords.repeat(tgt_mask.shape[0], 1, 1).float(), align_corners=False, ).squeeze(1) out_mask = point_sample( out_mask.float(), point_coords.repeat(out_mask.shape[0], 1, 1).float(), align_corners=False, ).squeeze(1) with autocast(enabled=False): out_mask = out_mask.float() tgt_mask = tgt_mask.float() # Compute the focal loss between masks cost_mask = batch_sigmoid_ce_loss_jit(out_mask, tgt_mask) # Compute the dice loss betwen masks cost_dice = batch_dice_loss_jit(out_mask, tgt_mask) else: cost_mask = torch.tensor(0).to(out_bbox) cost_dice = torch.tensor(0).to(out_bbox) # Final cost matrix if self.panoptic_on: isthing = tgt_ids<80 cost_bbox[:, ~isthing] = cost_bbox[:, isthing].mean() cost_giou[:, ~isthing] = cost_giou[:, isthing].mean() cost_bbox[cost_bbox.isnan()] = 0.0 cost_giou[cost_giou.isnan()] = 0.0 C = ( self.cost_mask * cost_mask + self.cost_class * cost_class + self.cost_dice * cost_dice + self.cost_box*cost_bbox + self.cost_giou*cost_giou ) C = C.reshape(num_queries, -1).cpu() # if split_pano is not None: # n_q_th=split_pano['n_q_th'] # th_mask=tgt_ids<80 # There are 80 COCO thing classes (should be modified when trained with other panoptic datasets) # C[n_q_th:,th_mask]=1e4 # C[:n_q_th,~th_mask]=1e4 indices.append(linear_sum_assignment(C)) return [ (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices ] @torch.no_grad() def forward(self, outputs, targets, cost=["cls", "box", "mask"], mode='default', extra={}): """Performs the matching Params: outputs: This is a dict that contains at least these entries: "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits "pred_masks": Tensor of dim [batch_size, num_queries, H_pred, W_pred] with the predicted masks targets: This is a list of targets (len(targets) = batch_size), 
where each target is a dict containing: "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth objects in the target) containing the class labels "masks": Tensor of dim [num_target_boxes, H_gt, W_gt] containing the target masks Returns: A list of size batch_size, containing tuples of (index_i, index_j) where: - index_i is the indices of the selected predictions (in order) - index_j is the indices of the corresponding selected targets (in order) For each batch element, it holds: len(index_i) = len(index_j) = min(num_queries, num_target_boxes) """ if mode == 'default': if extra is not None: split_pano = extra.get('split_pano', None) else: split_pano=None return self.memory_efficient_forward(outputs, targets, cost,split_pano=split_pano) else: assert False, "Mode {} is not supported.".format(mode) def __repr__(self, _repr_indent=4): head = "Matcher " + self.__class__.__name__ body = [ "cost_class: {}".format(self.cost_class), "cost_mask: {}".format(self.cost_mask), "cost_dice: {}".format(self.cost_dice), ] lines = [head] + [" " * _repr_indent + line for line in body] return "\n".join(lines) ================================================ FILE: llava/model/openseed/modules/point_features.py ================================================ # Copyright (c) Facebook, Inc. and its affiliates. import torch from torch.nn import functional as F from detectron2.layers import cat, shapes_to_tensor from detectron2.structures import BitMasks, Boxes # from ..layers import cat, shapes_to_tensor # from ..structures import BitMasks, Boxes """ Shape shorthand in this module: N: minibatch dimension size, i.e. the number of RoIs for instance segmenation or the number of images for semantic segmenation. R: number of ROIs, combined over all images, in the minibatch P: number of points """ def point_sample(input, point_coords, **kwargs): """ A wrapper around :function:`torch.nn.functional.grid_sample` to support 3D point_coords tensors. 
Unlike :function:`torch.nn.functional.grid_sample` it assumes `point_coords` to lie inside [0, 1] x [0, 1] square. Args: input (Tensor): A tensor of shape (N, C, H, W) that contains features map on a H x W grid. point_coords (Tensor): A tensor of shape (N, P, 2) or (N, Hgrid, Wgrid, 2) that contains [0, 1] x [0, 1] normalized point coordinates. Returns: output (Tensor): A tensor of shape (N, C, P) or (N, C, Hgrid, Wgrid) that contains features for points in `point_coords`. The features are obtained via bilinear interplation from `input` the same way as :function:`torch.nn.functional.grid_sample`. """ add_dim = False if point_coords.dim() == 3: add_dim = True point_coords = point_coords.unsqueeze(2) output = F.grid_sample(input, 2.0 * point_coords - 1.0, **kwargs) if add_dim: output = output.squeeze(3) return output def generate_regular_grid_point_coords(R, side_size, device): """ Generate regular square grid of points in [0, 1] x [0, 1] coordinate space. Args: R (int): The number of grids to sample, one for each region. side_size (int): The side size of the regular grid. device (torch.device): Desired device of returned tensor. Returns: (Tensor): A tensor of shape (R, side_size^2, 2) that contains coordinates for the regular grids. """ aff = torch.tensor([[[0.5, 0, 0.5], [0, 0.5, 0.5]]], device=device) r = F.affine_grid(aff, torch.Size((1, 1, side_size, side_size)), align_corners=False) return r.view(1, -1, 2).expand(R, -1, -1) def get_uncertain_point_coords_with_randomness( coarse_logits, uncertainty_func, num_points, oversample_ratio, importance_sample_ratio ): """ Sample points in [0, 1] x [0, 1] coordinate space based on their uncertainty. The unceratinties are calculated for each point using 'uncertainty_func' function that takes point's logit prediction as input. See PointRend paper for details. Args: coarse_logits (Tensor): A tensor of shape (N, C, Hmask, Wmask) or (N, 1, Hmask, Wmask) for class-specific or class-agnostic prediction. 
uncertainty_func: A function that takes a Tensor of shape (N, C, P) or (N, 1, P) that contains logit predictions for P points and returns their uncertainties as a Tensor of shape (N, 1, P). num_points (int): The number of points P to sample. oversample_ratio (int): Oversampling parameter. importance_sample_ratio (float): Ratio of points that are sampled via importnace sampling. Returns: point_coords (Tensor): A tensor of shape (N, P, 2) that contains the coordinates of P sampled points. """ assert oversample_ratio >= 1 assert importance_sample_ratio <= 1 and importance_sample_ratio >= 0 num_boxes = coarse_logits.shape[0] num_sampled = int(num_points * oversample_ratio) point_coords = torch.rand(num_boxes, num_sampled, 2, device=coarse_logits.device, dtype=coarse_logits.dtype) point_logits = point_sample(coarse_logits, point_coords, align_corners=False) # It is crucial to calculate uncertainty based on the sampled prediction value for the points. # Calculating uncertainties of the coarse predictions first and sampling them for points leads # to incorrect results. # To illustrate this: assume uncertainty_func(logits)=-abs(logits), a sampled point between # two coarse predictions with -1 and 1 logits has 0 logits, and therefore 0 uncertainty value. # However, if we calculate uncertainties for the coarse predictions first, # both will have -1 uncertainty, and the sampled point will get -1 uncertainty. 
point_uncertainties = uncertainty_func(point_logits) num_uncertain_points = int(importance_sample_ratio * num_points) num_random_points = num_points - num_uncertain_points idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1] shift = num_sampled * torch.arange(num_boxes, dtype=torch.long, device=coarse_logits.device) idx += shift[:, None] point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view( num_boxes, num_uncertain_points, 2 ) if num_random_points > 0: point_coords = cat( [ point_coords, torch.rand(num_boxes, num_random_points, 2, device=coarse_logits.device), ], dim=1, ) return point_coords def get_uncertain_point_coords_on_grid(uncertainty_map, num_points): """ Find `num_points` most uncertain points from `uncertainty_map` grid. Args: uncertainty_map (Tensor): A tensor of shape (N, 1, H, W) that contains uncertainty values for a set of points on a regular H x W grid. num_points (int): The number of points P to select. Returns: point_indices (Tensor): A tensor of shape (N, P) that contains indices from [0, H x W) of the most uncertain points. point_coords (Tensor): A tensor of shape (N, P, 2) that contains [0, 1] x [0, 1] normalized coordinates of the most uncertain points from the H x W grid. """ R, _, H, W = uncertainty_map.shape h_step = 1.0 / float(H) w_step = 1.0 / float(W) num_points = min(H * W, num_points) point_indices = torch.topk(uncertainty_map.view(R, H * W), k=num_points, dim=1)[1] point_coords = torch.zeros(R, num_points, 2, dtype=torch.float, device=uncertainty_map.device) point_coords[:, :, 0] = w_step / 2.0 + (point_indices % W).to(torch.float) * w_step point_coords[:, :, 1] = h_step / 2.0 + (point_indices // W).to(torch.float) * h_step return point_indices, point_coords def point_sample_fine_grained_features(features_list, feature_scales, boxes, point_coords): """ Get features from feature maps in `features_list` that correspond to specific point coordinates inside each bounding box from `boxes`. 
Args: features_list (list[Tensor]): A list of feature map tensors to get features from. feature_scales (list[float]): A list of scales for tensors in `features_list`. boxes (list[Boxes]): A list of I Boxes objects that contain R_1 + ... + R_I = R boxes all together. point_coords (Tensor): A tensor of shape (R, P, 2) that contains [0, 1] x [0, 1] box-normalized coordinates of the P sampled points. Returns: point_features (Tensor): A tensor of shape (R, C, P) that contains features sampled from all features maps in feature_list for P sampled points for all R boxes in `boxes`. point_coords_wrt_image (Tensor): A tensor of shape (R, P, 2) that contains image-level coordinates of P points. """ cat_boxes = Boxes.cat(boxes) num_boxes = [b.tensor.size(0) for b in boxes] point_coords_wrt_image = get_point_coords_wrt_image(cat_boxes.tensor, point_coords) split_point_coords_wrt_image = torch.split(point_coords_wrt_image, num_boxes) point_features = [] for idx_img, point_coords_wrt_image_per_image in enumerate(split_point_coords_wrt_image): point_features_per_image = [] for idx_feature, feature_map in enumerate(features_list): h, w = feature_map.shape[-2:] scale = shapes_to_tensor([w, h]) / feature_scales[idx_feature] point_coords_scaled = point_coords_wrt_image_per_image / scale.to(feature_map.device) point_features_per_image.append( point_sample( feature_map[idx_img].unsqueeze(0), point_coords_scaled.unsqueeze(0), align_corners=False, ) .squeeze(0) .transpose(1, 0) ) point_features.append(cat(point_features_per_image, dim=1)) return cat(point_features, dim=0), point_coords_wrt_image def get_point_coords_wrt_image(boxes_coords, point_coords): """ Convert box-normalized [0, 1] x [0, 1] point cooordinates to image-level coordinates. Args: boxes_coords (Tensor): A tensor of shape (R, 4) that contains bounding boxes. coordinates. point_coords (Tensor): A tensor of shape (R, P, 2) that contains [0, 1] x [0, 1] box-normalized coordinates of the P sampled points. 
Returns: point_coords_wrt_image (Tensor): A tensor of shape (R, P, 2) that contains image-normalized coordinates of P sampled points. """ with torch.no_grad(): point_coords_wrt_image = point_coords.clone() point_coords_wrt_image[:, :, 0] = point_coords_wrt_image[:, :, 0] * ( boxes_coords[:, None, 2] - boxes_coords[:, None, 0] ) point_coords_wrt_image[:, :, 1] = point_coords_wrt_image[:, :, 1] * ( boxes_coords[:, None, 3] - boxes_coords[:, None, 1] ) point_coords_wrt_image[:, :, 0] += boxes_coords[:, None, 0] point_coords_wrt_image[:, :, 1] += boxes_coords[:, None, 1] return point_coords_wrt_image def sample_point_labels(instances, point_coords): """ Sample point labels from ground truth mask given point_coords. Args: instances (list[Instances]): A list of N Instances, where N is the number of images in the batch. So, i_th elememt of the list contains R_i objects and R_1 + ... + R_N is equal to R. The ground-truth gt_masks in each instance will be used to compute labels. points_coords (Tensor): A tensor of shape (R, P, 2), where R is the total number of instances and P is the number of points for each instance. The coordinates are in the absolute image pixel coordinate space, i.e. [0, H] x [0, W]. Returns: Tensor: A tensor of shape (R, P) that contains the labels of P sampled points. """ with torch.no_grad(): gt_mask_logits = [] point_coords_splits = torch.split( point_coords, [len(instances_per_image) for instances_per_image in instances] ) for i, instances_per_image in enumerate(instances): if len(instances_per_image) == 0: continue assert isinstance( instances_per_image.gt_masks, BitMasks ), "Point head works with GT in 'bitmask' format. Set INPUT.MASK_FORMAT to 'bitmask'." 
gt_bit_masks = instances_per_image.gt_masks.tensor h, w = instances_per_image.gt_masks.image_size scale = torch.tensor([w, h], dtype=torch.float, device=gt_bit_masks.device) points_coord_grid_sample_format = point_coords_splits[i] / scale gt_mask_logits.append( point_sample( gt_bit_masks.to(torch.float32).unsqueeze(1), points_coord_grid_sample_format, align_corners=False, ).squeeze(1) ) point_labels = cat(gt_mask_logits) return point_labels ================================================ FILE: llava/model/openseed/modules/position_encoding.py ================================================ # Copyright (c) Facebook, Inc. and its affiliates. # # Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/position_encoding.py """ Various positional encodings for the transformer. """ import math import torch from torch import nn class PositionEmbeddingSine(nn.Module): """ This is a more standard version of the position embedding, very similar to the one used by the Attention is all you need paper, generalized to work on images. 
""" def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): super().__init__() self.num_pos_feats = num_pos_feats self.temperature = temperature self.normalize = normalize if scale is not None and normalize is False: raise ValueError("normalize should be True if scale is passed") if scale is None: scale = 2 * math.pi self.scale = scale def forward(self, x, mask=None): if mask is None: mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) not_mask = ~mask y_embed = not_mask.cumsum(1, dtype=x.dtype) x_embed = not_mask.cumsum(2, dtype=x.dtype) if self.normalize: eps = 1e-6 y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale dim_t = torch.arange(self.num_pos_feats, dtype=x.dtype, device=x.device) dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) pos_x = x_embed[:, :, :, None] / dim_t pos_y = y_embed[:, :, :, None] / dim_t pos_x = torch.stack( (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 ).flatten(3) pos_y = torch.stack( (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 ).flatten(3) pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) return pos def __repr__(self, _repr_indent=4): head = "Positional encoding " + self.__class__.__name__ body = [ "num_pos_feats: {}".format(self.num_pos_feats), "temperature: {}".format(self.temperature), "normalize: {}".format(self.normalize), "scale: {}".format(self.scale), ] # _repr_indent = 4 lines = [head] + [" " * _repr_indent + line for line in body] return "\n".join(lines) ================================================ FILE: llava/model/openseed/modules/postprocessing.py ================================================ # Copyright (c) Facebook, Inc. and its affiliates. 
import torch from torch.nn import functional as F from detectron2.structures import Instances, ROIMasks # perhaps should rename to "resize_instance" def detector_postprocess( results: Instances, output_height: int, output_width: int, mask_threshold: float = 0.5 ): """ Resize the output instances. The input images are often resized when entering an object detector. As a result, we often need the outputs of the detector in a different resolution from its inputs. This function will resize the raw outputs of an R-CNN detector to produce outputs according to the desired output resolution. Args: results (Instances): the raw outputs from the detector. `results.image_size` contains the input image resolution the detector sees. This object might be modified in-place. output_height, output_width: the desired output resolution. Returns: Instances: the resized output from the model, based on the output resolution """ if isinstance(output_width, torch.Tensor): # This shape might (but not necessarily) be tensors during tracing. # Converts integer tensors to float temporaries to ensure true # division is performed when computing scale_x and scale_y. output_width_tmp = output_width.float() output_height_tmp = output_height.float() new_size = torch.stack([output_height, output_width]) else: new_size = (output_height, output_width) output_width_tmp = output_width output_height_tmp = output_height scale_x, scale_y = ( output_width_tmp / results.image_size[1], output_height_tmp / results.image_size[0], ) results = Instances(new_size, **results.get_fields()) if results.has("pred_boxes"): output_boxes = results.pred_boxes elif results.has("proposal_boxes"): output_boxes = results.proposal_boxes else: output_boxes = None assert output_boxes is not None, "Predictions must contain boxes!" 
output_boxes.scale(scale_x, scale_y) output_boxes.clip(results.image_size) results = results[output_boxes.nonempty()] if results.has("pred_masks"): if isinstance(results.pred_masks, ROIMasks): roi_masks = results.pred_masks else: # pred_masks is a tensor of shape (N, 1, M, M) roi_masks = ROIMasks(results.pred_masks[:, 0, :, :]) results.pred_masks = roi_masks.to_bitmasks( results.pred_boxes, output_height, output_width, mask_threshold ).tensor # TODO return ROIMasks/BitMask object in the future if results.has("pred_keypoints"): results.pred_keypoints[:, :, 0] *= scale_x results.pred_keypoints[:, :, 1] *= scale_y return results def bbox_postprocess(result, input_size, img_size, output_height, output_width): """ result: [xc,yc,w,h] range [0,1] to [x1,y1,x2,y2] range [0,w], [0,h] """ if result is None: return None scale = torch.tensor([input_size[1], input_size[0], input_size[1], input_size[0]])[None,:].to(result.device) result = result.sigmoid() * scale x1,y1,x2,y2 = result[:,0] - result[:,2]/2, result[:,1] - result[:,3]/2, result[:,0] + result[:,2]/2, result[:,1] + result[:,3]/2 h,w = img_size x1 = x1.clamp(min=0, max=w) y1 = y1.clamp(min=0, max=h) x2 = x2.clamp(min=0, max=w) y2 = y2.clamp(min=0, max=h) box = torch.stack([x1,y1,x2,y2]).permute(1,0) scale = torch.tensor([output_width/w, output_height/h, output_width/w, output_height/h])[None,:].to(result.device) box = box*scale return box def sem_seg_postprocess(result, img_size, output_height, output_width): """ Return semantic segmentation predictions in the original resolution. The input images are often resized when entering semantic segmentor. Moreover, in same cases, they also padded inside segmentor to be divisible by maximum network stride. As a result, we often need the predictions of the segmentor in a different resolution from its inputs. Args: result (Tensor): semantic segmentation prediction logits. 
A tensor of shape (C, H, W), where C is the number of classes, and H, W are the height and width of the prediction. img_size (tuple): image size that segmentor is taking as input. output_height, output_width: the desired output resolution. Returns: semantic segmentation prediction (Tensor): A tensor of the shape (C, output_height, output_width) that contains per-pixel soft predictions. """ result = result[:, : img_size[0], : img_size[1]].expand(1, -1, -1, -1) result = F.interpolate( result, size=(output_height, output_width), mode="bicubic", align_corners=False, antialias=True )[0] return result ================================================ FILE: llava/model/openseed/utils/__init__.py ================================================ from .config import * from .misc import * ================================================ FILE: llava/model/openseed/utils/box_ops.py ================================================ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved """ Utilities for bounding box manipulation and GIoU. 
""" import torch from torchvision.ops.boxes import box_area def box_cxcywh_to_xyxy(x): x_c, y_c, w, h = x.unbind(-1) b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)] return torch.stack(b, dim=-1) def box_xyxy_to_cxcywh(x): x0, y0, x1, y1 = x.unbind(-1) b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)] return torch.stack(b, dim=-1) def box_xywh_to_xyxy(x): x0, y0, x1, y1 = x.unbind(-1) b = [x0, y0, (x0 + x1), (y0 + y1)] return torch.stack(b, dim=-1) # modified from torchvision to also return the union def box_iou(boxes1, boxes2): area1 = box_area(boxes1) area2 = box_area(boxes2) lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] wh = (rb - lt).clamp(min=0) # [N,M,2] inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] union = area1[:, None] + area2 - inter iou = inter / (union+1e-6) return iou, union def generalized_box_iou(boxes1, boxes2): """ Generalized IoU from https://giou.stanford.edu/ The boxes should be in [x0, y0, x1, y1] format Returns a [N, M] pairwise matrix, where N = len(boxes1) and M = len(boxes2) """ # degenerate boxes gives inf / nan results # so do an early check assert (boxes1[:, 2:] >= boxes1[:, :2]).all() assert (boxes2[:, 2:] >= boxes2[:, :2]).all() iou, union = box_iou(boxes1, boxes2) lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) wh = (rb - lt).clamp(min=0) # [N,M,2] area = wh[:, :, 0] * wh[:, :, 1] return iou - (area - union) / (area+1e-6) def masks_to_boxes(masks): """Compute the bounding boxes around the provided masks The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions. 
Returns a [N, 4] tensors, with the boxes in xyxy format """ if masks.numel() == 0: return torch.zeros((0, 4), device=masks.device) h, w = masks.shape[-2:] y = torch.arange(0, h, dtype=torch.float) x = torch.arange(0, w, dtype=torch.float) y, x = torch.meshgrid(y, x) x_mask = (masks * x.unsqueeze(0)) x_max = x_mask.flatten(1).max(-1)[0] x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] y_mask = (masks * y.unsqueeze(0)) y_max = y_mask.flatten(1).max(-1)[0] y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] return torch.stack([x_min, y_min, x_max, y_max], 1) ================================================ FILE: llava/model/openseed/utils/config.py ================================================ # -*- coding: utf-8 -*- # Copyright (c) Facebook, Inc. and its affiliates. import functools import inspect def configurable(init_func=None, *, from_config=None): """ Decorate a function or a class's __init__ method so that it can be called with a :class:`CfgNode` object using a :func:`from_config` function that translates :class:`CfgNode` to arguments. Examples: :: # Usage 1: Decorator on __init__: class A: @configurable def __init__(self, a, b=2, c=3): pass @classmethod def from_config(cls, cfg): # 'cfg' must be the first argument # Returns kwargs to be passed to __init__ return {"a": cfg.A, "b": cfg.B} a1 = A(a=1, b=2) # regular construction a2 = A(cfg) # construct with a cfg a3 = A(cfg, b=3, c=4) # construct with extra overwrite # Usage 2: Decorator on any function. Needs an extra from_config argument: @configurable(from_config=lambda cfg: {"a: cfg.A, "b": cfg.B}) def a_func(a, b=2, c=3): pass a1 = a_func(a=1, b=2) # regular call a2 = a_func(cfg) # call with a cfg a3 = a_func(cfg, b=3, c=4) # call with extra overwrite Args: init_func (callable): a class's ``__init__`` method in usage 1. The class must have a ``from_config`` classmethod which takes `cfg` as the first argument. 
from_config (callable): the from_config function in usage 2. It must take `cfg` as its first argument. """ if init_func is not None: assert ( inspect.isfunction(init_func) and from_config is None and init_func.__name__ == "__init__" ), "Incorrect use of @configurable. Check API documentation for examples." @functools.wraps(init_func) def wrapped(self, *args, **kwargs): try: from_config_func = type(self).from_config except AttributeError as e: raise AttributeError( "Class with @configurable must have a 'from_config' classmethod." ) from e if not inspect.ismethod(from_config_func): raise TypeError("Class with @configurable must have a 'from_config' classmethod.") if _called_with_cfg(*args, **kwargs): explicit_args = _get_args_from_config(from_config_func, *args, **kwargs) init_func(self, **explicit_args) else: init_func(self, *args, **kwargs) return wrapped else: if from_config is None: return configurable # @configurable() is made equivalent to @configurable assert inspect.isfunction( from_config ), "from_config argument of configurable must be a function!" def wrapper(orig_func): @functools.wraps(orig_func) def wrapped(*args, **kwargs): if _called_with_cfg(*args, **kwargs): explicit_args = _get_args_from_config(from_config, *args, **kwargs) return orig_func(**explicit_args) else: return orig_func(*args, **kwargs) wrapped.from_config = from_config return wrapped return wrapper def _called_with_cfg(*args, **kwargs): """ Returns: bool: whether the arguments contain CfgNode and should be considered forwarded to from_config. """ from omegaconf import DictConfig, OmegaConf, ListConfig # from detectron2.config import LazyConfig if len(args) and (isinstance(args[0], (dict)) or (isinstance(args[0], (DictConfig)))): return True if isinstance(kwargs.pop("cfg", None), (dict)): return True # `from_config`'s first argument is forced to be "cfg". # So the above check covers all cases. 
return False def _get_args_from_config(from_config_func, *args, **kwargs): """ Use `from_config` to obtain explicit arguments. Returns: dict: arguments to be used for cls.__init__ """ signature = inspect.signature(from_config_func) if list(signature.parameters.keys())[0] != "cfg": if inspect.isfunction(from_config_func): name = from_config_func.__name__ else: name = f"{from_config_func.__self__}.from_config" raise TypeError(f"{name} must take 'cfg' as the first argument!") support_var_arg = any( param.kind in [param.VAR_POSITIONAL, param.VAR_KEYWORD] for param in signature.parameters.values() ) if support_var_arg: # forward all arguments to from_config, if from_config accepts them ret = from_config_func(*args, **kwargs) else: # forward supported arguments to from_config supported_arg_names = set(signature.parameters.keys()) extra_kwargs = {} for name in list(kwargs.keys()): if name not in supported_arg_names: extra_kwargs[name] = kwargs.pop(name) ret = from_config_func(*args, **kwargs) # forward the other arguments to __init__ ret.update(extra_kwargs) return ret ================================================ FILE: llava/model/openseed/utils/misc.py ================================================ # Copyright (c) Facebook, Inc. and its affiliates. # Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/util/misc.py # -------------------------------------------------------- # X-Decoder -- Generalized Decoding for Pixel, Image, and Language # Copyright (c) 2022 Microsoft # Licensed under The MIT License [see LICENSE for details] # Modified by Xueyan Zou (xueyan@cs.wisc.edu) # -------------------------------------------------------- """ Misc functions, including distributed helpers. Mostly copy-paste from torchvision references. 
""" from typing import List, Optional import torch import torch.distributed as dist import torchvision from torch import Tensor # from utils.constants import * def _max_by_axis(the_list): # type: (List[List[int]]) -> List[int] maxes = the_list[0] for sublist in the_list[1:]: for index, item in enumerate(sublist): maxes[index] = max(maxes[index], item) return maxes class NestedTensor(object): def __init__(self, tensors, mask: Optional[Tensor]): self.tensors = tensors self.mask = mask def to(self, device): # type: (Device) -> NestedTensor # noqa cast_tensor = self.tensors.to(device) mask = self.mask if mask is not None: assert mask is not None cast_mask = mask.to(device) else: cast_mask = None return NestedTensor(cast_tensor, cast_mask) def decompose(self): return self.tensors, self.mask def __repr__(self): return str(self.tensors) def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): # TODO make this more general if tensor_list[0].ndim == 3: if torchvision._is_tracing(): # nested_tensor_from_tensor_list() does not export well to ONNX # call _onnx_nested_tensor_from_tensor_list() instead return _onnx_nested_tensor_from_tensor_list(tensor_list) # TODO make it support different-sized images max_size = _max_by_axis([list(img.shape) for img in tensor_list]) # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) batch_shape = [len(tensor_list)] + max_size b, c, h, w = batch_shape dtype = tensor_list[0].dtype device = tensor_list[0].device tensor = torch.zeros(batch_shape, dtype=dtype, device=device) mask = torch.ones((b, h, w), dtype=torch.bool, device=device) for img, pad_img, m in zip(tensor_list, tensor, mask): pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) m[: img.shape[1], : img.shape[2]] = False elif tensor_list[0].ndim == 2: if torchvision._is_tracing(): # nested_tensor_from_tensor_list() does not export well to ONNX # call _onnx_nested_tensor_from_tensor_list() instead return 
_onnx_nested_tensor_from_tensor_list(tensor_list) # TODO make it support different-sized images max_size = _max_by_axis([list(txt.shape) for txt in tensor_list]) # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) batch_shape = [len(tensor_list)] + max_size b, c, l = batch_shape dtype = tensor_list[0].dtype device = tensor_list[0].device tensor = torch.zeros(batch_shape, dtype=dtype, device=device) mask = torch.ones((b, l), dtype=torch.bool, device=device) for txt, pad_txt, m in zip(tensor_list, tensor, mask): pad_txt[: txt.shape[0], : txt.shape[1]] = txt m[: txt.shape[1]] = False else: raise ValueError("not supported") return NestedTensor(tensor, mask) def _collate_and_pad_divisibility(tensor_list: list, div=32): max_size = [] for i in range(tensor_list[0].dim()): max_size_i = torch.max( torch.tensor([img.shape[i] for img in tensor_list]).to(torch.float32) ).to(torch.int64) max_size.append(max_size_i) max_size = tuple(max_size) c,h,w = max_size pad_h = (div - h % div) if h % div != 0 else 0 pad_w = (div - w % div) if w % div != 0 else 0 max_size = (c,h+pad_h,w+pad_w) # work around for # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) # m[: img.shape[1], :img.shape[2]] = False # which is not yet supported in onnx padded_imgs = [] padded_masks = [] for img in tensor_list: padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))] padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0])) padded_imgs.append(padded_img) m = torch.zeros_like(img[0], dtype=torch.int, device=img.device) padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1) padded_masks.append(padded_mask.to(torch.bool)) return padded_imgs # _onnx_nested_tensor_from_tensor_list() is an implementation of # nested_tensor_from_tensor_list() that is supported by ONNX tracing. 
@torch.jit.unused def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor: max_size = [] for i in range(tensor_list[0].dim()): max_size_i = torch.max( torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32) ).to(torch.int64) max_size.append(max_size_i) max_size = tuple(max_size) # work around for # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) # m[: img.shape[1], :img.shape[2]] = False # which is not yet supported in onnx padded_imgs = [] padded_masks = [] for img in tensor_list: padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))] padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0])) padded_imgs.append(padded_img) m = torch.zeros_like(img[0], dtype=torch.int, device=img.device) padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1) padded_masks.append(padded_mask.to(torch.bool)) tensor = torch.stack(padded_imgs) mask = torch.stack(padded_masks) return NestedTensor(tensor, mask=mask) def is_dist_avail_and_initialized(): if not dist.is_available(): return False if not dist.is_initialized(): return False return True # def get_class_names(name, background=True): # if name is None: # return None # if 'refcoco' in name: # class_names = ["noun"] # elif 'coco' in name and 'pan' not in name: # class_names = COCO_INSTANCE_CLASSES + ["background"] # elif 'coco' in name: # class_names = COCO_PANOPTIC_CLASSES + ["background"] # elif 'ade20k_full' in name: # class_names = ADE20K_847 + ["background"] # elif 'ade' in name: # class_names = ADE_PANOPTIC_CLASSES + ["background"] # elif 'voc' in name: # class_names = PASCAL_CLASSES + ["background"] # elif 'vlp' in name: # class_names = ["noun"] # elif 'tsv' in name: # class_names = ["noun"] # elif 'phrasecut' in name: # class_names = ["noun"] # elif 'openimage' in name: # class_names = ["noun"] # elif 'imagenet' in name: # class_names = IMAGENET_CLASSES # elif 'context_459' in name: # 
class_names = PASCAL_CONTEXT_459 + ["background"] # elif 'context_59' in name: # class_names = PASCAL_CONTEXT_59 + ["background"] # elif 'context_33' in name: # class_names = PASCAL_CONTEXT_33 # elif 'sunrgbd_37' in name: # class_names = SUN_RGBD_37 # elif 'scannet_41' in name: # class_names = SCAN_40 # elif 'scannet_38' in name: # class_names = SCAN_37 # elif 'scannet_21' in name: # class_names = SCAN_20 # elif 'object365' in name: # class_names = OBJECT365 # elif 'lvis' in name: # class_names = LVIS_CATEGORIES # elif 'seginw' in name: # class_names = SEGINW_CATEGORIES[name.replace('_train', '').replace('_val', '')] + ["background"] # elif name == 'cityscapes_fine_sem_seg_val': # class_names = CITYSCAPES # elif name == 'cityscapes_fine_instance_seg_val': # class_names = CITYSCAPES_THING + ["background"] # elif name in ['cityscapes_fine_panoptic_val', 'cityscapes_fine_panoptic_train']: # class_names = CITYSCAPES + ["background"] # elif name == 'bdd10k_val_sem_seg': # class_names = BDD_SEM # elif name == 'bdd10k_40_panoptic_val': # class_names = BDD_PANO # else: # assert False, "text dataset name {} is not defined".format(name) # # if background == False and "background" in class_names: # class_names.pop(class_names.index("background")) # # return class_names # TODO: add background to # def get_class_names(name): # if name is None: # return None # elif 'refcoco' in name: # return ["background"] # elif 'coco' in name: # return COCO_PANOPTIC_CLASSES + ["background"] # elif 'ade20k_full' in name: # return ADE20K_847 + ["background"] # elif 'ade' in name: # return ADE_PANOPTIC_CLASSES + ["background"] # elif 'scannet_41' in name: # return SCAN_40 # elif 'scannet_21' in name: # return SCAN_20 # elif 'sun' in name: # return SUN_RGBD_37 # elif name == 'cityscapes_fine_sem_seg_val': # return CITYSCAPES + ["background"] # elif name == 'cityscapes_fine_instance_seg_val': # return CITYSCAPES_THING + ["background"] # elif name in ['cityscapes_fine_panoptic_val']: # return 
CITYSCAPES + ["background"] # elif name == 'bdd10k_val_sem_seg': # return BDD_SEM + ["background"] # elif name == 'bdd10k_40_panoptic_val': # return BDD_PANO + ["background"] # elif 'vlp' in name: # return ["background"] # else: # assert False, "text dataset name {} is not defined".format(name) ================================================ FILE: llava/model/semsam/BaseModel.py ================================================ import os import logging import torch import torch.nn as nn from utils.model import align_and_update_state_dicts logger = logging.getLogger(__name__) class BaseModel(nn.Module): def __init__(self, opt, module: nn.Module): super(BaseModel, self).__init__() self.opt = opt self.model = module def forward(self, *inputs, **kwargs): outputs = self.model(*inputs, **kwargs) return outputs def save_pretrained(self, save_dir): torch.save(self.model.state_dict(), save_path) def from_pretrained(self, load_dir): state_dict = torch.load(load_dir, map_location='cpu') # import pdb;pdb.set_trace() # import pdb;pdb.set_trace() if 'model' in state_dict: state_dict=state_dict['model'] state_dict={k[6:]:v for k,v in state_dict.items()} # if self.opt['MODEL']['LLAMA'].get('lora_r',0)>0: # new_sd = dict() # for k,v in state_dict.items(): # if k.startswith("llama."): # if k.startswith("llama.base_model."): # new_sd=state_dict # break # new_sd[k.replace("llama.","llama.base_model.model.")]=v # else: # new_sd[k]=v # else: # new_sd = state_dict new_sd = align_and_update_state_dicts(self.model.state_dict(), state_dict) self.model.load_state_dict(new_sd, strict=False) return self ================================================ FILE: llava/model/semsam/__init__.py ================================================ from __future__ import absolute_import from __future__ import division from __future__ import print_function from .architectures import build_model ================================================ FILE: llava/model/semsam/architectures/__init__.py 
================================================
from .idino_model_partwhole_all_llm_ref_feats_all_det_pretrainv1 import *
from .build import build_model

================================================
FILE: llava/model/semsam/architectures/build.py
================================================
from .registry import model_entrypoints
from .registry import is_model


def build_model(config, **kwargs):
    """Build the model registered under ``config['MODEL']['NAME']``.

    The registered entry point is called as ``fn(config, **kwargs)``.

    Raises:
        ValueError: if no model is registered under that name.
    """
    model_name = config['MODEL']['NAME']

    if not is_model(model_name):
        raise ValueError(f'Unkown model: {model_name}')

    return model_entrypoints(model_name)(config, **kwargs)

================================================
FILE: llava/model/semsam/architectures/idino_model_partwhole_all_llm_ref_feats_all_det_pretrainv1.py
================================================
# Copyright (c) Facebook, Inc. and its affiliates.
from typing import Tuple

import torch
from torch import nn
from torch.nn import functional as F
import transformers

from .registry import register_model
from ..utils import configurable, box_ops
from ..backbone import build_backbone, Backbone
from ..body import build_openseed_head
from ..modules import sem_seg_postprocess, HungarianMatcher
from ..modules import SetCriterionLLM as SetCriterion

from detectron2.structures import Boxes, ImageList, Instances, BitMasks
from detectron2.utils.memory import retry_if_cuda_oom
from detectron2.data import MetadataCatalog
import torch.distributed as dist
import random
import os
import torchvision
from PIL import Image


def dice_loss(
        inputs: torch.Tensor,
        targets: torch.Tensor,
        # num_masks,
):
    """
    Compute the DICE loss, similar to generalized IOU for masks
    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions (logits) for each example; sigmoid is applied here.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                (0 for the negative class and 1 for the positive class).
                NOTE(review): only ``inputs`` is flattened below, so targets
                presumably must already be (N, -1); confirm with callers.
    Returns:
        Scalar tensor: sum of the per-mask dice losses (not divided by the
        number of masks).
    """
    inputs = inputs.sigmoid()
    inputs = inputs.flatten(1)
    numerator = 2 * (inputs * targets).sum(-1)
    denominator = inputs.sum(-1) + targets.sum(-1)
    # +1 smoothing keeps the ratio defined for empty masks.
    loss = 1 - (numerator + 1) / (denominator + 1)
    # only match the lowest loss
    # loss = loss.view(-1, 3)
    # loss = loss.min(1)[0]
    return loss.sum()
    # return loss


def iou_score_loss(inputs, targets):
    # Element-wise (unreduced) mean-squared error; despite the local name
    # ``ce_loss`` this is not a cross-entropy.
    ce_loss = F.mse_loss(inputs, targets, reduction="none")
    return ce_loss


dice_loss_jit = torch.jit.script(
    dice_loss
)  # type: torch.jit.ScriptModule


def sigmoid_ce_loss(
        inputs: torch.Tensor,
        targets: torch.Tensor,
        # num_masks,
):
    """
    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions (logits) for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                (0 for the negative class and 1 for the positive class).
    Returns:
        Scalar tensor: binary cross-entropy averaged over dim 1 and then summed
        over the remaining (mask) dimension.
    """
    loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    loss = loss.mean(1)
    # loss = loss.view(-1, 3).min(1)[0]
    return loss.sum()
    # return loss


sigmoid_ce_loss_jit = torch.jit.script(
    sigmoid_ce_loss
)  # type: torch.jit.ScriptModule


def calculate_uncertainty(logits):
    """
    We estimate uncertainty as L1 distance between 0.0 and the logit prediction in 'logits' for the
        foreground class in `classes`.
    Args:
        logits (Tensor): A tensor of shape (R, 1, ...) for class-specific or
            class-agnostic, where R is the total number of predicted masks in all images and C is
            the number of foreground classes. The values are logits.
    Returns:
        scores (Tensor): A tensor of shape (R, 1, ...) that contains uncertainty scores with
            the most uncertain locations having the highest uncertainty score
            (i.e. -|logit|, largest for logits closest to 0).
    """
    assert logits.shape[1] == 1
    gt_class_logits = logits.clone()
    return -(torch.abs(gt_class_logits))


def sigmoid_focal_loss(inputs, targets, alpha: float = 0.25, gamma: float = 2):
    """
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions (logits) for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                (0 for the negative class and 1 for the positive class).
        alpha: (optional) Weighting factor in range (0,1) to balance
                positive vs negative examples. Default: 0.25; pass a negative
                value to disable the weighting.
        gamma: Exponent of the modulating factor (1 - p_t) to
               balance easy vs hard examples. Default: 2.
    Returns:
        Scalar tensor: the focal loss summed over all elements (no normalization).
    """
    prob = inputs.sigmoid()
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    p_t = prob * targets + (1 - prob) * (1 - targets)
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    return loss.sum()


class SemanticSAM(nn.Module):
    """
    Main class for mask classification semantic segmentation architectures.

    In this variant, training (``forward`` with ``self.training``) runs
    interactive point/box-prompted decoding per image and returns projected
    object features plus weighted losses; eval runs ``evaluate_demo``.
    """

    @configurable
    def __init__(
            self,
            *,
            backbone: Backbone,
            sem_seg_head: nn.Module,
            criterion: nn.Module,
            num_queries: int,
            object_mask_threshold: float,
            overlap_threshold: float,
            metadata,
            size_divisibility: int,
            sem_seg_postprocess_before_inference: bool,
            pixel_mean: Tuple[float],
            pixel_std: Tuple[float],
            # inference
            semantic_on: bool,
            panoptic_on: bool,
            instance_on: bool,
            test_topk_per_image: int,
            data_loader: str,
            pano_temp: float,
            focus_on_box: bool = False,
            transform_eval: bool = False,
            semantic_ce_loss: bool = False,
            train_dataset_name: str,
            background: bool,
            coco_on=True,
            coco_mask_on=True,
            o365_on=True,
            ade_on=True,
            merge_class=False,
            sam_on: bool = True,
            pascal_part_on: bool = True,
            regenerate_point: bool = False,
            num_mask_tokens: int = 3,
            interactive_pretrain=False,
            match_loss=True,
            num_vg=2,
            vis_out="vis/",
            coco_old=True,
            clip_on=False,
    ):
        """
        Args:
            backbone: a backbone module, must follow detectron2's backbone interface
            sem_seg_head: a module that predicts semantic segmentation from backbone features
            criterion: a module that defines the loss
            num_queries: int, number of queries
            object_mask_threshold: float, threshold to filter query based on classification score
                for panoptic segmentation inference
            overlap_threshold: overlap threshold used in general inference for panoptic segmentation
            metadata: dataset meta, get `thing` and `stuff` category names for panoptic
                segmentation inference
            size_divisibility: Some backbones require the input height and width to be divisible by a
                specific integer. We can use this to override such requirement.
                A negative value means "use backbone.size_divisibility".
            sem_seg_postprocess_before_inference: whether to resize the prediction back
                to original input size before semantic segmentation inference or after.
                For high-resolution dataset like Mapillary, resizing predictions before
                inference will cause OOM error.
            pixel_mean, pixel_std: list or tuple with #channels element, representing
                the per-channel mean and std to be used to normalize the input image
            semantic_on: bool, whether to output semantic segmentation prediction
            instance_on: bool, whether to output instance segmentation prediction
            panoptic_on: bool, whether to output panoptic segmentation prediction
            test_topk_per_image: int, instance segmentation parameter, keep topk instances per image
            num_mask_tokens: number of candidate masks/features predicted per prompt.

        NOTE(review): ``background`` and ``merge_class`` are accepted but not used
        in this constructor.
        """
        super().__init__()
        self.backbone = backbone
        self.pano_temp = pano_temp
        self.sem_seg_head = sem_seg_head
        self.criterion = criterion
        self.num_queries = num_queries
        self.overlap_threshold = overlap_threshold
        self.object_mask_threshold = object_mask_threshold
        self.metadata = metadata
        self.num_vg = num_vg
        if size_divisibility < 0:
            size_divisibility = self.backbone.size_divisibility
        self.size_divisibility = size_divisibility
        self.sem_seg_postprocess_before_inference = sem_seg_postprocess_before_inference
        # Non-persistent buffers: follow the module's device but are not saved in state_dict.
        self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
        self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)

        # additional args
        self.semantic_on = semantic_on
        self.instance_on = instance_on
        self.panoptic_on = panoptic_on
        self.test_topk_per_image = test_topk_per_image

        self.data_loader = data_loader
        self.focus_on_box = focus_on_box
        self.transform_eval = transform_eval
        self.semantic_ce_loss = semantic_ce_loss

        self.coco_keys = None
        self.train_class_names = dict()
        self.train_dataset_name = train_dataset_name
        self.coco_mask_on = coco_mask_on
        self.task_switch = {'coco': coco_on, 'o365': o365_on, 'sam': sam_on, 'pascal_part': pascal_part_on, "ade": ade_on}
        self.interactive_pretrain = interactive_pretrain
        # First call of prepare_targets_sam caches its targets as a fallback for empty batches.
        self.dbg = False
        self.positive = 0
        self.num_objs = 0
        self.num_hits = 0
        self.num_refer = 0
        self.ref_iou = 0.0
        self.random_iou = 0.0
        self.match_loss = match_loss
        self.clip_on = clip_on
        self.num_all_masks = 0.
        self.coco_old = coco_old
        self.multimodal_cfg = {'is_multimodal': True, 'image_token_len': 140, 'use_im_start_end': True}
        self.logit_scale = nn.Parameter(torch.ones([]))
        self.vis_out = vis_out
        # Projects 256-d object features to 4096-d before they are returned to the caller.
        self.obj_projector = nn.Linear(256, 4096)
        print("self.task_switch ", self.task_switch)
        if not self.semantic_on:
            assert self.sem_seg_postprocess_before_inference

        self.max_num_instance = 100
        self.num_mask_tokens = num_mask_tokens
        self.regenerate_point = regenerate_point

    @classmethod
    def from_config(cls, cfg):
        """Translate a nested config dict into the keyword arguments of ``__init__``.

        Builds the Hungarian matcher, the loss weight dict (with ``_interm``,
        ``_dn`` and per-decoder-layer copies as configured), the criterion, the
        backbone and the segmentation head.
        """
        enc_cfg = cfg['MODEL']['ENCODER']
        dec_cfg = cfg['MODEL']['DECODER']

        # Loss parameters:
        deep_supervision = dec_cfg['DEEP_SUPERVISION']
        no_object_weight = dec_cfg['NO_OBJECT_WEIGHT']

        # loss weights
        iou_weight = dec_cfg['IOU_WEIGHT']
        class_weight = dec_cfg['CLASS_WEIGHT']
        cost_class_weight = dec_cfg['COST_CLASS_WEIGHT']
        cost_dice_weight = dec_cfg['COST_DICE_WEIGHT']
        dice_weight = dec_cfg['DICE_WEIGHT']
        cost_mask_weight = dec_cfg['COST_MASK_WEIGHT']
        mask_weight = dec_cfg['MASK_WEIGHT']
        cost_box_weight = dec_cfg['COST_BOX_WEIGHT']
        box_weight = dec_cfg['BOX_WEIGHT']
        cost_giou_weight = dec_cfg['COST_GIOU_WEIGHT']
        giou_weight = dec_cfg['GIOU_WEIGHT']
        refer_weight = dec_cfg['REFER_WEIGHT']
        fix_backbone = cfg.get('fix_backbone', False)

        # building matcher
        matcher = HungarianMatcher(
            cost_class=cost_class_weight,
            cost_mask=cost_mask_weight,
            cost_dice=cost_dice_weight,
            cost_box=cost_box_weight,
            cost_giou=cost_giou_weight,
            num_points=dec_cfg['TRAIN_NUM_POINTS'],
        )

        # MaskDINO losses and weight_dict
        weight_dict = {"loss_mask_cls_0": class_weight}
        weight_dict.update({"loss_mask_bce_0": mask_weight, "loss_mask_dice_0": dice_weight})
        weight_dict.update({"loss_bbox_0": box_weight, "loss_giou_0": giou_weight})
        weight_dict.update({"iou_score_loss_0": iou_weight})
        weight_dict.update({"loss_mask_part_cls_0": class_weight})
        # two stage is the query selection scheme
        if dec_cfg['TWO_STAGE']:
            interm_weight_dict = {}
            interm_weight_dict.update({k + f'_interm': v for k, v in weight_dict.items()})
            weight_dict.update(interm_weight_dict)
        # denoising training
        dn = dec_cfg['DN']
        # TODO hack for dn lable loss
        if dn == "standard":
            # NOTE(review): no key built above is literally "loss_mask" or
            # "loss_dice", so this filter currently excludes nothing.
            weight_dict.update({k + f"_dn": v for k, v in weight_dict.items() if k != "loss_mask" and k != "loss_dice"})
            dn_losses = ["dn_labels", "boxes"]
        elif dn == "seg":
            weight_dict.update({k + f"_dn": v for k, v in weight_dict.items()})
            dn_losses = ["masks", "dn_labels", "boxes"]
        else:
            dn_losses = []
        if deep_supervision:
            dec_layers = dec_cfg['DEC_LAYERS']
            aux_weight_dict = {}
            # Copy every "_0" weight to "_1" ... "_{dec_layers}" for auxiliary decoder outputs.
            for i in range(dec_layers):
                aux_weight_dict.update({k.replace('_0', '_{}'.format(i + 1)): v for k, v in weight_dict.items()})
            weight_dict.update(aux_weight_dict)
        if dec_cfg['BOX']:
            losses = ["masks", "labels", "boxes"]
        else:
            losses = ["masks", "labels", ]
        if dec_cfg['PART']:
            losses.append('labels_part')
        weight_dict.update({'all': 1.0, 'sam': 1.0, 'pas': 1.0})

        # update task switch
        task_switch = {}
        task_switch.update({'bbox': dec_cfg.get('DETECTION', True), 'mask': dec_cfg.get('MASK', True)})
        weight_multiplier= dec_cfg.get('WEIGHT_MULTIPLIER', 1.0)
        # Global scale applied to every loss weight.
        weight_dict={k:v*weight_multiplier for k,v in weight_dict.items()}
        # building criterion
        criterion = SetCriterion(
            enc_cfg['NUM_CLASSES'],
            matcher=matcher,
            weight_dict=weight_dict,
            # top_x_layers=top_x_layers,
            eos_coef=no_object_weight,
            losses=losses,
            num_points=dec_cfg['TRAIN_NUM_POINTS'],
            oversample_ratio=dec_cfg['OVERSAMPLE_RATIO'],
            importance_sample_ratio=dec_cfg['IMPORTANCE_SAMPLE_RATIO'],
            # grounding_weight=None,
            dn=dec_cfg['DN'],
            dn_losses=dn_losses,
            panoptic_on=dec_cfg['PANO_BOX_LOSS'],
            semantic_ce_loss=dec_cfg['TEST']['SEMANTIC_ON'] and dec_cfg['SEMANTIC_CE_LOSS'] and not dec_cfg['TEST'][
                'PANOPTIC_ON'],
            num_mask_tokens=dec_cfg.get('NUM_INTERACTIVE_TOKENS', 3)
        )

        # build model
        extra = {'task_switch': task_switch}
        backbone = build_backbone(cfg)
        if fix_backbone:
            for name, param in backbone.named_parameters():
                param.requires_grad = False
        # backbone
        sem_seg_head = build_openseed_head(cfg, backbone.output_shape(), None, extra=extra)
        if fix_backbone:
            for name, param in sem_seg_head.named_parameters():
                param.requires_grad = False

        return {
            "backbone": backbone,
            "sem_seg_head": sem_seg_head,
            "criterion": criterion,
            "num_queries": dec_cfg['NUM_OBJECT_QUERIES'],
            "object_mask_threshold": dec_cfg['TEST']['OBJECT_MASK_THRESHOLD'],
            "overlap_threshold": dec_cfg['TEST']['OVERLAP_THRESHOLD'],
            "metadata": MetadataCatalog.get(cfg['DATASETS']['TRAIN'][0]),
            "size_divisibility": dec_cfg['SIZE_DIVISIBILITY'],
            "sem_seg_postprocess_before_inference": (
                    dec_cfg['TEST']['SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE']
                    or dec_cfg['TEST']['PANOPTIC_ON']
                    or dec_cfg['TEST']['INSTANCE_ON']
            ),
            "pixel_mean": cfg['INPUT']['PIXEL_MEAN'],
            "pixel_std": cfg['INPUT']['PIXEL_STD'],
            # inference
            "semantic_on": dec_cfg['TEST']['SEMANTIC_ON'],
            "instance_on": dec_cfg['TEST']['INSTANCE_ON'],
            "panoptic_on": dec_cfg['TEST']['PANOPTIC_ON'],
            "test_topk_per_image": cfg['COCO']['TEST']['DETECTIONS_PER_IMAGE'],
            "data_loader": None,
            "focus_on_box": cfg['MODEL']['DECODER']['TEST']['TEST_FOUCUS_ON_BOX'],
            "transform_eval": cfg['MODEL']['DECODER']['TEST']['PANO_TRANSFORM_EVAL'],
            "pano_temp": cfg['MODEL']['DECODER']['TEST']['PANO_TEMPERATURE'],
            "semantic_ce_loss": cfg['MODEL']['DECODER']['TEST']['SEMANTIC_ON'] and cfg['MODEL']['DECODER'][
                'SEMANTIC_CE_LOSS'] and not cfg['MODEL']['DECODER']['TEST']['PANOPTIC_ON'],
            "train_dataset_name": cfg['DATASETS']['TRAIN'],  # HACK for only two training set
            "background": cfg['MODEL'].get('BACKGROUND', True),
            "coco_on": dec_cfg.get('COCO', True),
            "coco_mask_on": dec_cfg.get('COCO_MASK', True),
            "o365_on": dec_cfg.get('O365', True),
            "sam_on": dec_cfg.get('SAM', True),
            "pascal_part_on": dec_cfg.get('PASCAL', True),
            "regenerate_point": dec_cfg.get('RE_POINT', False),
            "num_mask_tokens": dec_cfg.get('NUM_INTERACTIVE_TOKENS', 3),
            "ade_on": dec_cfg.get('ADE', False),
            "interactive_pretrain": dec_cfg.get('pretrain', False),
            "match_loss": dec_cfg.get('match_loss', True),
            "vis_out": os.path.join(cfg.get('OUTPUT_DIR', 'out'), str(cfg.get('VIS_OUT', 'vis'))),
            "coco_old": cfg.get("coco_old", True),
            # "points_per_side_eval": cfg.get("points_per_side_eval", 30),
            "clip_on": cfg.get("clip", False),
        }

    @property
    def device(self):
        return self.pixel_mean.device

    def evaluate_demo(self, batched_inputs, all_whole=None, all_parts=None, mask_features=None, multi_scale_features=None, return_features=False):
        """Run prompted inference on a single image.

        Reads ``batched_inputs[0]['image']`` and ``batched_inputs[0]['targets']``
        (whose ``[0]['points']`` entry is rescaled in place).

        Returns:
            (pred_masks, pred_ious, projected object features). Masks and
            features come from the first of the ``num_mask_tokens`` candidates,
            because ``index`` is overwritten with zeros below.

        NOTE(review): the ``mask_features`` / ``multi_scale_features`` arguments
        are overwritten by the pixel decoder output, and ``all_whole``,
        ``all_parts`` and ``return_features`` are unused.
        """
        assert len(batched_inputs) == 1, "only support batch size equal to 1"
        prediction_switch = {'part': False, 'whole': False, 'seg': True, 'det': True}
        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
        images = ImageList.from_tensors(images, self.size_divisibility)

        targets = batched_inputs[0]['targets']
        height = images[0].shape[1]
        width = images[0].shape[2]
        padded_h = images.tensor.shape[-2]  # divisible by size_divisibility
        padded_w = images.tensor.shape[-1]
        # Rescale the prompt coordinates by (unpadded size / padded size);
        # presumably converts image-normalized coords to padded-tensor-normalized ones.
        targets[0]['points'] = targets[0]['points'] * torch.as_tensor([width, height, width, height], dtype=torch.float,
                                                                      device=self.device) / torch.as_tensor(
            [padded_w, padded_h, padded_w, padded_h], dtype=torch.float, device=self.device)

        features = self.backbone(images.tensor)
        mask_features, transformer_encoder_features, multi_scale_features = self.sem_seg_head.pixel_decoder.forward_features(
            features, None)
        outputs, mask_dict = self.sem_seg_head.predictor(multi_scale_features, mask_features, None, targets=targets,
                                                         target_queries=None, target_vlp=None, task='demo',
                                                         extra=prediction_switch)

        pred_ious = None
        if 'pred_ious' in outputs.keys():
            pred_ious = outputs["pred_ious"]
        # NOTE(review): the view hard-codes 3 candidates while the reshapes below
        # use self.num_mask_tokens; the argmax result is then discarded.
        _, index = pred_ious.view(-1, 3).max(1)
        index = torch.zeros_like(index)
        obj_feats = outputs['obj_features'][0].view(-1, self.num_mask_tokens, 256)
        obj_feats = torch.gather(obj_feats, 1, index[..., None, None].repeat(1, 1, 256))[:, 0]
        mask_pred_results = outputs["pred_masks"]
        # upsample masks
        mask_pred_results = F.interpolate(
            mask_pred_results.float(),
            size=(images.tensor.shape[-2], images.tensor.shape[-1]),
            mode="bilinear",
            align_corners=False,
        )
        mask_pred_results = mask_pred_results.view(-1, self.num_mask_tokens, images.tensor.shape[-2],
                                                   images.tensor.shape[-1])
        mask_pred_results = torch.gather(mask_pred_results, 1, index[..., None, None, None].repeat(1, 1, images.tensor.shape[-2],
                                                                                               images.tensor.shape[-1]))
        pred_masks = mask_pred_results[:, 0]

        image_size = images.image_sizes[0]

        height = image_size[0]
        width = image_size[1]
        if self.sem_seg_postprocess_before_inference:
            pred_masks = retry_if_cuda_oom(sem_seg_postprocess)(
                pred_masks, image_size, height, width
            )
        return pred_masks, pred_ious, self.obj_projector(obj_feats)

    def forward(self, batched_inputs, inference_task='seg',detach=False):
        """Training: return (projected per-image object features, weighted losses).

        Losses not listed in ``criterion.weight_dict`` are dropped; the remaining
        ones are multiplied by their weight and re-keyed as ``'inter.<name>'``.
        Eval: delegate to ``evaluate_demo``.

        NOTE(review): with ``detach=True`` the projector output is returned
        without the trailing ``[0]`` indexing used when ``detach=False``, so the
        two paths return differently shaped features; confirm this is intended.
        """
        if self.training:
            obj_feats,inter_losses= self.forward_det_pretrain(batched_inputs)
            for k in list(inter_losses.keys()):
                if k in self.criterion.weight_dict:
                    inter_losses[k] *= self.criterion.weight_dict[k]
                    # losses[k] *= scale
                else:
                    # remove this loss if not specified in `weight_dict`
                    inter_losses.pop(k)
            new_losses = {}
            for key, value in inter_losses.items():
                new_losses['inter' + '.' + str(key)] = inter_losses[key]
            if detach:
                return [self.obj_projector(feat.detach()) for feat in obj_feats],new_losses
            else:
                return [self.obj_projector(feat)[0] for feat in obj_feats],new_losses
        else:
            return self.evaluate_demo(batched_inputs)

    def forward_det_pretrain(self, batched_inputs, task='seg',
                             prediction_switch={'part': True, 'whole': True, 'seg': True, 'det': True},
                             dataname='coco', semantic=False):
        """Decode each image with its own prompts and accumulate the losses.

        Returns:
            (list of per-image object features selected via ``criterion.index``,
            dict of losses averaged over the total ``num_masks`` across images).

        NOTE(review): the ``prediction_switch`` argument is replaced by a local
        dict below, so the default is never used. When ``clip_on`` is set, the
        name ``preprocess`` is not defined in this module (NameError). When no
        ``"instances"`` are present, ``targets`` is None and the loop below
        would fail on ``enumerate(None)``.
        """
        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
        images = ImageList.from_tensors(images, self.size_divisibility)

        features = self.backbone(images.tensor)
        mask_features, transformer_encoder_features, multi_scale_features = self.sem_seg_head.pixel_decoder.forward_features(
            features, None)
        if self.clip_on:
            image = \
                preprocess.preprocess(Image.fromarray(batched_inputs[0]['image_ori']), return_tensors='pt')['pixel_values'][
                    0]
        prediction_switch = {'part': False, 'whole': False, 'seg': True, 'det': True}
        # self.criterion.num_classes = len(train_class_names)
        train_class_names_part = None
        # if prediction_switch['part']:
        #     train_class_names_part = self.train_class_names[dataname + '_part']
        #     self.criterion.num_classes_part = len(train_class_names)

        if "instances" in batched_inputs[0]:
            gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
            targets = self.prepare_targets_sam(gt_instances, images, prediction_switch=prediction_switch)
        else:
            targets = None
            print("empty targets", targets, task)
        if prediction_switch['whole']:
            prediction_switch['whole'] = False
        if prediction_switch['part']:
            prediction_switch['part'] = False
        # tgt_temp=[]
        obj_features_ls=[]
        losses_total=None
        num_masks=0
        # Each image is decoded separately with its own slice of the features.
        for i,tgt in enumerate(targets):
            tgt_temp=[tgt]
            outputs_gt, mask_dict = self.sem_seg_head.predictor([feat[i:i+1] for feat in multi_scale_features],
                                                                mask_features[i:i+1], None, targets=tgt_temp,
                                                                target_queries=None, target_vlp=None, task='seg',
                                                                extra=prediction_switch)
            self.criterion.index = torch.zeros_like(batched_inputs[i]['instances'].gt_classes).to(outputs_gt['obj_features'].device)
            losses, index = self.criterion(outputs_gt, tgt_temp, mask_dict, task='seg', extra=prediction_switch,
                                           return_idx=True)
            # The criterion's stored index (not the returned one) selects the candidate per object.
            index=self.criterion.index
            bs, n, h, w = outputs_gt["pred_masks"].shape
            obj_features = outputs_gt['obj_features'].view(bs, -1, self.num_mask_tokens, 256)
            obj_features = torch.gather(obj_features, 2, index[None][..., None, None].repeat(1, 1, 1, 256))[:,:, 0]
            # mask_pred_results = outputs_gt["pred_masks"][0].view(-1, self.num_mask_tokens, h, w)
            obj_features_ls.append(obj_features)
            num_masks+=losses['num_masks']
            # Accumulate mask-count-weighted losses so the final average is per mask.
            if losses_total is None:
                losses_total=dict()
                for key in losses.keys():
                    if key != 'num_masks':
                        losses_total[key] = losses[key] * losses['num_masks']
            else:
                for key in losses.keys():
                    if key != 'num_masks':
                        losses_total[key]+=losses[key]*losses['num_masks']
        for key in losses_total.keys():
            if key != 'num_masks':
                losses_total[key] = losses_total[key]/num_masks
        return obj_features_ls,losses_total

    def prepare_targets_sam(self, targets, images, prediction_switch, task='seg', min_box=0.33, max_box=1.0):
        """Convert per-image instances into prompt-based training targets.

        For each image, a random ``box_start`` in
        [int(N * min_box), int(N * max_box)] is drawn. The first ``box_start``
        instances get a point prompt, encoded as a 6-pixel box around a random
        foreground pixel; the remaining instances are prompted by their
        ground-truth boxes (see ``boxes_dn``). Boxes are normalized by the
        unpadded image size, and masks are zero-padded to the batch tensor size.

        The first call caches ``targets`` in ``self.empty_targets``; later calls
        whose first image has no instances reuse that cache and mark
        ``'empty': True``.
        """
        h_pad, w_pad = images.tensor.shape[-2:]
        new_targets = []
        # box_start = random.randint(int((self.max_num_instance - 1)/2), self.max_num_instance - 1)  # box based interactive after this number; about 1/4
        # if random.random()<0.5 and self.dbg:
        #     # import pdb;pdb.set_trace()
        #     targets[0]=targets[0][:0]
        if not self.dbg:
            self.empty_targets = targets
            self.dbg = True
        if len(targets[0]) == 0:
            empty = True
            targets = self.empty_targets
        else:
            empty = False
        for targets_per_image in targets:
            gt_boxes = targets_per_image.gt_boxes if torch.is_tensor(
                targets_per_image.gt_boxes) else targets_per_image.gt_boxes.tensor
            # empty=len(gt_boxes)==0
            assert len(gt_boxes)>0
            self.max_num_instance = len(gt_boxes)
            box_start = random.randint(int(self.max_num_instance * min_box),
                                       int(self.max_num_instance * max_box))
            # pad gt
            h, w = targets_per_image.image_size
            image_size_xyxy = torch.as_tensor([w, h, w, h], dtype=torch.float, device=self.device)

            gt_masks = targets_per_image.gt_masks if torch.is_tensor(
                targets_per_image.gt_masks) else targets_per_image.gt_masks.tensor
            padded_masks = torch.zeros((gt_masks.shape[0], h_pad, w_pad), dtype=gt_masks.dtype,
                                       device=gt_masks.device)
            padded_masks[:, : gt_masks.shape[1], : gt_masks.shape[2]] = gt_masks
            num_mask = targets_per_image.gt_classes.shape[0]
            index = torch.arange(num_mask)
            # Repeat instance indices when fewer instances than max_num_instance exist.
            if self.max_num_instance > num_mask:
                rep = 0 if num_mask == 0 else int(self.max_num_instance / num_mask) + 1
                index = index.repeat(rep)
            index = index[:self.max_num_instance]
            # if self.regenerate_point and box_start > 0:
            point_coords = []
            if box_start > 0:
                for i in range(box_start):
                    mask = gt_masks[index[i]].clone()
                    candidate_indices = mask.nonzero()
                    if len(candidate_indices) == 0:
                        print('wrong')
                        selected_point = torch.tensor([0, 0]).cuda()
                    else:
                        selected_index = random.randint(0, len(candidate_indices) - 1)
                        # nonzero() yields (row, col); flip to (x, y).
                        selected_point = candidate_indices[selected_index].flip(0)
                    # Expand the point into an xyxy box of +/- 3 pixels.
                    selected_point = torch.cat([selected_point - 3, selected_point + 3], 0)
                    point_coords.append(selected_point)
                point_coords = torch.stack(point_coords).to('cuda')
            # else:
            #     point_coords = targets_per_image.point_coords[index[:box_start]]
            #     point_coords = targets_per_image.gt_boxes.tensor[index[:box_start]]
            new_target = {
                "ori_mask_num": len(targets_per_image.gt_classes),
                "labels": targets_per_image.gt_classes[index] if prediction_switch['whole'] else None,
                "masks": padded_masks[index],
                "boxes": box_ops.box_xyxy_to_cxcywh(gt_boxes[index]) / image_size_xyxy,
                "points": box_ops.box_xyxy_to_cxcywh(point_coords) / image_size_xyxy if len(point_coords) > 0 else None,
                # "pb":torch.randint(2,(min(self.max_num_instance,len(targets_per_image.gt_classes)),),device=gt_masks.device),
                # 0 for point-prompted entries, 1 for box-prompted entries.
                "pb": torch.cat([torch.zeros(box_start), torch.ones(self.max_num_instance - box_start)], 0),
                "gt_whole_classes": targets_per_image.gt_whole_classes[index] if targets_per_image.has(
                    'gt_whole_classes') and prediction_switch['whole'] else None,
                "gt_part_classes": targets_per_image.gt_part_classes[index] if targets_per_image.has(
                    'gt_part_classes') and prediction_switch['part'] else None,
            }
            # handle coco data format
            if prediction_switch['whole'] and not prediction_switch['part']:
                new_target['gt_whole_classes'] = targets_per_image.gt_classes[index]

            # Prompt boxes: point-derived boxes first, then GT boxes for the rest.
            if new_target["points"] is not None:
                new_target["boxes_dn"] = torch.cat([new_target["points"], new_target["boxes"][box_start:]], 0)
            else:
                new_target["boxes_dn"] = new_target["boxes"][box_start:]

            new_target['box_start'] = box_start
            new_target['empty'] = empty
            new_targets.append(new_target)

        return new_targets


@register_model
def get_segmentation_model(cfg, **kwargs):
    # NOTE(review): extra kwargs are ignored; cfg is presumably routed through
    # SemanticSAM.from_config by the @configurable decorator.
    return SemanticSAM(cfg)

================================================
FILE: llava/model/semsam/architectures/registry.py
================================================
# Maps model name (last component of the defining module's dotted name) -> entry point.
_model_entrypoints = {}


def register_model(fn):
    """Register ``fn`` under the last component of its module name and return it unchanged."""
    module_name_split = fn.__module__.split('.')
    model_name = module_name_split[-1]
    _model_entrypoints[model_name] = fn
    return fn


def model_entrypoints(model_name):
    """Return the entry point registered under ``model_name`` (KeyError if missing)."""
    return _model_entrypoints[model_name]


def is_model(model_name):
    """Return True if an entry point is registered under ``model_name``."""
    return model_name in _model_entrypoints

================================================
FILE: llava/model/semsam/backbone/__init__.py
================================================
from .build import build_backbone

from .focal import *
from .focal_dw import *
from .swin import *
from .backbone import *

================================================
FILE: llava/model/semsam/backbone/backbone.py
================================================
# Copyright (c) Facebook, Inc. and its affiliates.
import torch.nn as nn

from detectron2.modeling import ShapeSpec

# from ..layers import ShapeSpec

__all__ = ["Backbone"]


class Backbone(nn.Module):
    """
    Abstract base class for network backbones.
    """

    def __init__(self):
        """
        The `__init__` method of any subclass can specify its own set of arguments.
        """
        super().__init__()

    def forward(self):
        """
        Subclasses must override this method, but adhere to the same return type.

        Returns:
            dict[str->Tensor]: mapping from feature name (e.g., "res2") to tensor
        """
        pass

    @property
    def size_divisibility(self) -> int:
        """
        Some backbones require the input height and width to be divisible by a
        specific integer. This is typically true for encoder / decoder type networks
        with lateral connection (e.g., FPN) for which feature maps need to match
        dimension in the "bottom up" and "top down" paths. Set to 0 if no specific
        input size divisibility is required.
        """
        return 0

    def output_shape(self):
        """
        Returns:
            dict[str->ShapeSpec]

        Requires the subclass to define ``_out_features``,
        ``_out_feature_channels`` and ``_out_feature_strides``.
        """
        # this is a backward-compatible default
        return {
            name: ShapeSpec(
                channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
            )
            for name in self._out_features
        }

================================================
FILE: llava/model/semsam/backbone/build.py
================================================
from .registry import model_entrypoints
from .registry import is_model

from .backbone import *


def build_backbone(config, **kwargs):
    """Build the backbone registered under ``config['MODEL']['BACKBONE']['NAME']``.

    Raises:
        ValueError: if no backbone is registered under that name.
    """
    model_name = config['MODEL']['BACKBONE']['NAME']
    if not is_model(model_name):
        raise ValueError(f'Unkown model: {model_name}')

    return model_entrypoints(model_name)(config, **kwargs)

================================================
FILE: llava/model/semsam/backbone/focal.py
================================================
# --------------------------------------------------------
# FocalNet for Semantic Segmentation
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Jianwei Yang
# --------------------------------------------------------
import math
import time
import numpy as np
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint

from timm.models.layers import DropPath, to_2tuple, trunc_normal_

from detectron2.utils.file_io import PathManager
from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec

from .registry import register_backbone

logger = logging.getLogger(__name__)


class Mlp(nn.Module):
    """ Multilayer perceptron.

    Computes fc1 -> act -> dropout -> fc2 -> dropout over the last dimension.
    ``hidden_features`` and ``out_features`` fall back to ``in_features`` when
    they are None (or otherwise falsy).
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        # One dropout module, applied both after the activation and after fc2.
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class FocalModulation(nn.Module):
    """ Focal Modulation

    Args:
        dim (int): Number of input channels.
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
        focal_level (int): Number of focal levels (stacked depthwise convs).
        focal_window (int): Focal window (kernel) size at focal level 1.
        focal_factor (int, default=2): Kernel-size increment per focal level.
        use_postln (bool, default=False): Accepted but not used by this class.
        use_postln_in_modulation (bool, default=False): Apply a LayerNorm to the
            modulated output before the final projection.
        scaling_modulator (bool, default=False): Divide the aggregated context by
            ``focal_level + 1``.
    """

    def __init__(self, dim, proj_drop=0., focal_level=2, focal_window=7, focal_factor=2, use_postln=False, use_postln_in_modulation=False, scaling_modulator=False):

        super().__init__()
        self.dim = dim

        # specific args for focalv3
        self.focal_level = focal_level
        self.focal_window = focal_window
        self.focal_factor = focal_factor
        self.use_postln_in_modulation = use_postln_in_modulation
        self.scaling_modulator = scaling_modulator

        # Projects to [query (dim) | context (dim) | one gate per level + global (focal_level + 1)];
        # forward() splits along this layout.
        self.f = nn.Linear(dim, 2*dim+(self.focal_level+1), bias=True)
        # 1x1 conv turning aggregated context into the modulator.
        self.h = nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0, groups=1, bias=True)

        self.act = nn.GELU()
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.focal_layers = nn.ModuleList()

        if self.use_postln_in_modulation:
            self.ln = nn.LayerNorm(dim)

        for k in range(self.focal_level):
            # Kernel grows by focal_factor per level; groups=dim makes each conv depthwise,
            # and padding=kernel_size//2 keeps the spatial size for odd kernels.
            kernel_size = self.focal_factor*k + self.focal_window
            self.focal_layers.append(
                nn.Sequential(
                    nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=1, groups=dim,
                        padding=kernel_size//2, bias=False),
                    nn.GELU(),
                    )
                )

    def forward(self, x):
        """ Forward function.

        Args:
            x: input features with shape of (B, H, W, C); C must equal ``dim``
               for the split below to line up with ``self.f``'s output.

        Returns:
            Tensor of shape (B, H, W, dim).
        """
        B, nH, nW, C = x.shape
        x = self.f(x)
        x = x.permute(0, 3, 1, 2).contiguous()
        q, ctx, gates = torch.split(x, (C, C, self.focal_level+1), 1)

        # Hierarchical context: each level convolves the previous level's output,
        # and its contribution is weighted by that level's gate.
        ctx_all = 0
        for l in range(self.focal_level):
            ctx = self.focal_layers[l](ctx)
            ctx_all = ctx_all + ctx*gates[:, l:l+1]
        # Global context: spatial mean of the last level, weighted by the final gate.
        ctx_global = self.act(ctx.mean(2, keepdim=True).mean(3, keepdim=True))
        ctx_all = ctx_all + ctx_global*gates[:,self.focal_level:]
        if self.scaling_modulator:
            ctx_all = ctx_all / (self.focal_level + 1)

        x_out = q * self.h(ctx_all)
        x_out = x_out.permute(0, 2, 3, 1).contiguous()
        if self.use_postln_in_modulation:
            x_out = self.ln(x_out)
        x_out = self.proj(x_out)
        x_out = self.proj_drop(x_out)
        return x_out


class FocalModulationBlock(nn.Module):
    """ Focal Modulation Block.

    Args:
        dim (int): Number of input channels.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        drop (float, optional): Dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        focal_level (int): number of focal levels
        focal_window (int): focal kernel size at level 1
        use_postln (bool): Normalize after (instead of before) modulation / MLP.
        use_layerscale (bool): Use learnable per-channel residual scales
            initialized to ``layerscale_value``; otherwise both scales are 1.0.

    ``self.H`` / ``self.W`` must be set before calling forward (BasicLayer.forward
    assigns them for every block).
    """

    def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 focal_level=2, focal_window=9,
                 use_postln=False, use_postln_in_modulation=False,
                 scaling_modulator=False,
                 use_layerscale=False,
                 layerscale_value=1e-4):
        super().__init__()
        self.dim = dim
        self.mlp_ratio = mlp_ratio
        self.focal_window = focal_window
        self.focal_level = focal_level
        self.use_postln = use_postln
        self.use_layerscale = use_layerscale

        self.norm1 = norm_layer(dim)
        self.modulation = FocalModulation(
            dim, focal_window=self.focal_window, focal_level=self.focal_level, proj_drop=drop, use_postln_in_modulation=use_postln_in_modulation, scaling_modulator=scaling_modulator
        )

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        # Spatial size of the token grid; set externally before forward().
        self.H = None
        self.W = None

        self.gamma_1 = 1.0
        self.gamma_2 = 1.0
        if self.use_layerscale:
            self.gamma_1 = nn.Parameter(layerscale_value * torch.ones((dim)), requires_grad=True)
            self.gamma_2 = nn.Parameter(layerscale_value * torch.ones((dim)), requires_grad=True)

    def forward(self, x):
        """ Forward function.

        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature (read from self.H / self.W).

        Returns:
            Tensor of size (B, H*W, C).
        """
        B, L, C = x.shape
        H, W = self.H, self.W
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        if not self.use_postln:
            x = self.norm1(x)
        x = x.view(B, H, W, C)

        # FM
        x = self.modulation(x).view(B, H * W, C)
        if self.use_postln:
            x = self.norm1(x)

        # FFN
        x = shortcut + self.drop_path(self.gamma_1 * x)

        if self.use_postln:
            x = x + self.drop_path(self.gamma_2 * self.norm2(self.mlp(x)))
        else:
            x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))

        return x


class BasicLayer(nn.Module):
    """ A basic focal modulation layer for one stage.

    Args:
        dim (int): Number of feature channels
        depth (int): Number of FocalModulationBlocks in this stage.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
        drop (float, optional): Dropout rate. Default: 0.0
        drop_path (float | list[float], optional): Stochastic depth rate; a list
            supplies one rate per block. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer class at the end of
            the layer, built with PatchEmbed-style keyword arguments. Default: None
        focal_level (int): Number of focal levels
        focal_window (int): Focal window size at focal level 1
        use_conv_embed (bool): Use overlapped convolution for patch embedding or not. Default: False
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
    """

    def __init__(self,
                 dim,
                 depth,
                 mlp_ratio=4.,
                 drop=0.,
                 drop_path=0.,
                 norm_layer=nn.LayerNorm,
                 downsample=None,
                 focal_window=9,
                 focal_level=2,
                 use_conv_embed=False,
                 use_postln=False,
                 use_postln_in_modulation=False,
                 scaling_modulator=False,
                 use_layerscale=False,
                 use_checkpoint=False
                 ):
        super().__init__()
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks
        self.blocks = nn.ModuleList([
            FocalModulationBlock(
                dim=dim,
                mlp_ratio=mlp_ratio,
                drop=drop,
                # NOTE(review): only a list is indexed per block; a tuple would be
                # passed whole to every block.
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                focal_window=focal_window,
                focal_level=focal_level,
                use_postln=use_postln,
                use_postln_in_modulation=use_postln_in_modulation,
                scaling_modulator=scaling_modulator,
                use_layerscale=use_layerscale,
                norm_layer=norm_layer)
            for i in range(depth)])

        # patch merging layer: halves resolution (patch_size=2) and doubles channels
        if downsample is not None:
            self.downsample = downsample(
                patch_size=2,
                in_chans=dim, embed_dim=2*dim,
                use_conv_embed=use_conv_embed,
                norm_layer=norm_layer,
                is_stem=False
            )
        else:
            self.downsample = None

    def forward(self, x, H, W):
        """ Forward function.

        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.

        Returns:
            (x, H, W, x_down, Wh, Ww): the stage output at (H, W) and the input for
            the next stage. Without a downsample layer, x_down is x and Wh, Ww = H, W.
        """
        for blk in self.blocks:
            blk.H, blk.W = H, W
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
        if self.downsample is not None:
            x_reshaped = x.transpose(1, 2).view(x.shape[0], x.shape[-1], H, W)
            x_down = self.downsample(x_reshaped)
            x_down = x_down.flatten(2).transpose(1, 2)
            # Ceil-halving, matching PatchEmbed's padding of odd sizes.
            Wh, Ww = (H + 1) // 2, (W + 1) // 2
            return x, H, W, x_down, Wh, Ww
        else:
            return x, H, W, x, H, W


class PatchEmbed(nn.Module):
    """ Image to Patch Embedding

    Args:
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
        use_conv_embed (bool): Whether use overlapped convolution for patch embedding. Default: False
        is_stem (bool): Is the stem block or not.
    """

    def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None, use_conv_embed=False, is_stem=False):
        super().__init__()
        patch_size = to_2tuple(patch_size)
        self.patch_size = patch_size

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        if use_conv_embed:
            # if we choose to use conv embedding, then we treat the stem and non-stem differently
            if is_stem:
                kernel_size = 7; padding = 2; stride = 4
            else:
                kernel_size = 3; padding = 1; stride = 2
            self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
        else:
            self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        """Forward function.

        Pads the input on the right/bottom up to a multiple of ``patch_size``, then
        projects it; returns a (B, embed_dim, Wh, Ww) map.
        """
        _, _, H, W = x.size()
        if W % self.patch_size[1] != 0:
            x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
        if H % self.patch_size[0] != 0:
            x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))

        x = self.proj(x)  # B C Wh Ww
        if self.norm is not None:
            # Normalize over channels: flatten to tokens, norm, restore the map.
            Wh, Ww = x.size(2), x.size(3)
            x = x.flatten(2).transpose(1, 2)
            x = self.norm(x)
            x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)

        return x


class FocalNet(nn.Module):
    """ FocalNet backbone.

    Args:
        pretrain_img_size (int): Input image size for training the pretrained model.
            Stored as ``self.pretrain_img_size``. Default: 1600.
        patch_size (int | tuple(int)): Patch size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
            Stage i has ``embed_dim * 2 ** i`` channels.
        depths (tuple[int]): Number of focal modulation blocks in each stage.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
        drop_rate (float): Dropout rate.
        drop_path_rate (float): Stochastic depth rate, increased linearly from 0
            over all blocks. Default: 0.2.
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        patch_norm (bool): If True, add normalization after patch embedding. Default: True.
        out_indices (Sequence[int]): Stages whose outputs are returned; each gets a
            norm layer registered as ``norm{i}``.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters.
        focal_levels (Sequence[int]): Number of focal levels at four stages
        focal_windows (Sequence[int]): Focal window sizes at first focal level at four stages
        use_conv_embed (bool): Whether use overlapped convolution for patch embedding
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self,
                 pretrain_img_size=1600,
                 patch_size=4,
                 in_chans=3,
                 embed_dim=96,
                 depths=[2, 2, 6, 2],
                 mlp_ratio=4.,
                 drop_rate=0.,
                 drop_path_rate=0.2,
                 norm_layer=nn.LayerNorm,
                 patch_norm=True,
                 out_indices=[0, 1, 2, 3],
                 frozen_stages=-1,
                 focal_levels=[2,2,2,2],
                 focal_windows=[9,9,9,9],
                 use_conv_embed=False,
                 use_postln=False,
                 use_postln_in_modulation=False,
                 scaling_modulator=False,
                 use_layerscale=False,
                 use_checkpoint=False,
                 ):
        super().__init__()

        self.pretrain_img_size = pretrain_img_size
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.patch_norm = patch_norm
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None,
            use_conv_embed=use_conv_embed, is_stem=True)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build layers; every stage except the last ends with a PatchEmbed downsample
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(
                dim=int(embed_dim * 2 ** i_layer),
                depth=depths[i_layer],
                mlp_ratio=mlp_ratio,
                drop=drop_rate,
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                norm_layer=norm_layer,
                downsample=PatchEmbed if (i_layer < self.num_layers - 1) else None,
                focal_window=focal_windows[i_layer],
                focal_level=focal_levels[i_layer],
                use_conv_embed=use_conv_embed,
                use_postln=use_postln,
                use_postln_in_modulation=use_postln_in_modulation,
                scaling_modulator=scaling_modulator,
                use_layerscale=use_layerscale,
                use_checkpoint=use_checkpoint)
            self.layers.append(layer)

        num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
        self.num_features = num_features

        # add a norm layer for each output
        for i_layer in out_indices:
            layer = norm_layer(num_features[i_layer])
            layer_name = f'norm{i_layer}'
            self.add_module(layer_name, layer)

        self._freeze_stages()
def _freeze_stages(self): if self.frozen_stages >= 0: self.patch_embed.eval() for param in self.patch_embed.parameters(): param.requires_grad = False if self.frozen_stages >= 2: self.pos_drop.eval() for i in range(0, self.frozen_stages - 1): m = self.layers[i] m.eval() for param in m.parameters(): param.requires_grad = False def init_weights(self, pretrained=None): """Initialize the weights in backbone. Args: pretrained (str, optional): Path to pre-trained weights. Defaults to None. """ def _init_weights(m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) if isinstance(pretrained, str): self.apply(_init_weights) logger = get_root_logger() load_checkpoint(self, pretrained, strict=False, logger=logger) elif pretrained is None: self.apply(_init_weights) else: raise TypeError('pretrained must be a str or None') def load_weights(self, pretrained_dict=None, pretrained_layers=[], verbose=True): model_dict = self.state_dict() missed_dict = [k for k in model_dict.keys() if k not in pretrained_dict] logger.info(f'=> Missed keys {missed_dict}') unexpected_dict = [k for k in pretrained_dict.keys() if k not in model_dict] logger.info(f'=> Unexpected keys {unexpected_dict}') pretrained_dict = { k: v for k, v in pretrained_dict.items() if k in model_dict.keys() } need_init_state_dict = {} for k, v in pretrained_dict.items(): need_init = ( ( k.split('.')[0] in pretrained_layers or pretrained_layers[0] == '*' ) and 'relative_position_index' not in k and 'attn_mask' not in k ) if need_init: # if verbose: # logger.info(f'=> init {k} from {pretrained}') if ('pool_layers' in k) or ('focal_layers' in k) and v.size() != model_dict[k].size(): table_pretrained = v table_current = model_dict[k] fsize1 = table_pretrained.shape[2] fsize2 = table_current.shape[2] # NOTE: different from interpolation used 
in self-attention, we use padding or clipping for focal conv if fsize1 < fsize2: table_pretrained_resized = torch.zeros(table_current.shape) table_pretrained_resized[:, :, (fsize2-fsize1)//2:-(fsize2-fsize1)//2, (fsize2-fsize1)//2:-(fsize2-fsize1)//2] = table_pretrained v = table_pretrained_resized elif fsize1 > fsize2: table_pretrained_resized = table_pretrained[:, :, (fsize1-fsize2)//2:-(fsize1-fsize2)//2, (fsize1-fsize2)//2:-(fsize1-fsize2)//2] v = table_pretrained_resized if ("modulation.f" in k or "pre_conv" in k): table_pretrained = v table_current = model_dict[k] if table_pretrained.shape != table_current.shape: if len(table_pretrained.shape) == 2: dim = table_pretrained.shape[1] assert table_current.shape[1] == dim L1 = table_pretrained.shape[0] L2 = table_current.shape[0] if L1 < L2: table_pretrained_resized = torch.zeros(table_current.shape) # copy for linear project table_pretrained_resized[:2*dim] = table_pretrained[:2*dim] # copy for global token gating table_pretrained_resized[-1] = table_pretrained[-1] # copy for first multiple focal levels table_pretrained_resized[2*dim:2*dim+(L1-2*dim-1)] = table_pretrained[2*dim:-1] # reassign pretrained weights v = table_pretrained_resized elif L1 > L2: raise NotImplementedError elif len(table_pretrained.shape) == 1: dim = table_pretrained.shape[0] L1 = table_pretrained.shape[0] L2 = table_current.shape[0] if L1 < L2: table_pretrained_resized = torch.zeros(table_current.shape) # copy for linear project table_pretrained_resized[:dim] = table_pretrained[:dim] # copy for global token gating table_pretrained_resized[-1] = table_pretrained[-1] # copy for first multiple focal levels # table_pretrained_resized[dim:2*dim+(L1-2*dim-1)] = table_pretrained[2*dim:-1] # reassign pretrained weights v = table_pretrained_resized elif L1 > L2: raise NotImplementedError need_init_state_dict[k] = v self.load_state_dict(need_init_state_dict, strict=False) def forward(self, x): """Forward function.""" tic = time.time() x = 
self.patch_embed(x) Wh, Ww = x.size(2), x.size(3) x = x.flatten(2).transpose(1, 2) x = self.pos_drop(x) outs = {} for i in range(self.num_layers): layer = self.layers[i] x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww) if i in self.out_indices: norm_layer = getattr(self, f'norm{i}') x_out = norm_layer(x_out) out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous() outs["res{}".format(i + 2)] = out if len(self.out_indices) == 0: outs["res5"] = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous() toc = time.time() return outs def train(self, mode=True): """Convert the model into training mode while keep layers freezed.""" super(FocalNet, self).train(mode) self._freeze_stages() class D2FocalNet(FocalNet, Backbone): def __init__(self, cfg, input_shape): pretrain_img_size = cfg['BACKBONE']['FOCAL']['PRETRAIN_IMG_SIZE'] patch_size = cfg['BACKBONE']['FOCAL']['PATCH_SIZE'] in_chans = 3 embed_dim = cfg['BACKBONE']['FOCAL']['EMBED_DIM'] depths = cfg['BACKBONE']['FOCAL']['DEPTHS'] mlp_ratio = cfg['BACKBONE']['FOCAL']['MLP_RATIO'] drop_rate = cfg['BACKBONE']['FOCAL']['DROP_RATE'] drop_path_rate = cfg['BACKBONE']['FOCAL']['DROP_PATH_RATE'] norm_layer = nn.LayerNorm patch_norm = cfg['BACKBONE']['FOCAL']['PATCH_NORM'] use_checkpoint = cfg['BACKBONE']['FOCAL']['USE_CHECKPOINT'] out_indices = cfg['BACKBONE']['FOCAL']['OUT_INDICES'] scaling_modulator = cfg['BACKBONE']['FOCAL'].get('SCALING_MODULATOR', False) super().__init__( pretrain_img_size, patch_size, in_chans, embed_dim, depths, mlp_ratio, drop_rate, drop_path_rate, norm_layer, patch_norm, out_indices, focal_levels=cfg['BACKBONE']['FOCAL']['FOCAL_LEVELS'], focal_windows=cfg['BACKBONE']['FOCAL']['FOCAL_WINDOWS'], use_conv_embed=cfg['BACKBONE']['FOCAL']['USE_CONV_EMBED'], use_postln=cfg['BACKBONE']['FOCAL']['USE_POSTLN'], use_postln_in_modulation=cfg['BACKBONE']['FOCAL']['USE_POSTLN_IN_MODULATION'], scaling_modulator=scaling_modulator, 
use_layerscale=cfg['BACKBONE']['FOCAL']['USE_LAYERSCALE'], use_checkpoint=use_checkpoint, ) self._out_features = cfg['BACKBONE']['FOCAL']['OUT_FEATURES'] self._out_feature_strides = { "res2": 4, "res3": 8, "res4": 16, "res5": 32, } self._out_feature_channels = { "res2": self.num_features[0], "res3": self.num_features[1], "res4": self.num_features[2], "res5": self.num_features[3], } def forward(self, x): """ Args: x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``. Returns: dict[str->Tensor]: names and the corresponding features """ assert ( x.dim() == 4 ), f"SwinTransformer takes an input of shape (N, C, H, W). Got {x.shape} instead!" outputs = {} y = super().forward(x) for k in y.keys(): if k in self._out_features: outputs[k] = y[k] return outputs def output_shape(self): return { name: ShapeSpec( channels=self._out_feature_channels[name], stride=self._out_feature_strides[name] ) for name in self._out_features } @property def size_divisibility(self): return 32 @register_backbone def get_focal_backbone(cfg): focal = D2FocalNet(cfg['MODEL'], 224) if cfg['MODEL']['BACKBONE']['LOAD_PRETRAINED'] is True: filename = cfg['MODEL']['BACKBONE']['PRETRAINED'] logger.info(f'=> init from {filename}') with PathManager.open(filename, "rb") as f: ckpt = torch.load(f)['model'] focal.load_weights(ckpt, cfg['MODEL']['BACKBONE']['FOCAL'].get('PRETRAINED_LAYERS', ['*']), cfg['VERBOSE']) return focal ================================================ FILE: llava/model/semsam/backbone/focal_dw.py ================================================ # -------------------------------------------------------- # FocalNet for Semantic Segmentation # Copyright (c) 2022 Microsoft # Licensed under The MIT License [see LICENSE for details] # Written by Jianwei Yang # -------------------------------------------------------- import math import time import numpy as np import logging import torch import torch.nn as nn import torch.nn.functional as F import 
torch.utils.checkpoint as checkpoint from timm.models.layers import DropPath, to_2tuple, trunc_normal_ from detectron2.utils.file_io import PathManager from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec from .registry import register_backbone logger = logging.getLogger(__name__) class Mlp(nn.Module): """ Multilayer perceptron.""" def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc1 = nn.Linear(in_features, hidden_features) self.act = act_layer() self.fc2 = nn.Linear(hidden_features, out_features) self.drop = nn.Dropout(drop) def forward(self, x): x = self.fc1(x) x = self.act(x) x = self.drop(x) x = self.fc2(x) x = self.drop(x) return x class FocalModulation(nn.Module): """ Focal Modulation Args: dim (int): Number of input channels. proj_drop (float, optional): Dropout ratio of output. Default: 0.0 focal_level (int): Number of focal levels focal_window (int): Focal window size at focal level 1 focal_factor (int, default=2): Step to increase the focal window use_postln (bool, default=False): Whether use post-modulation layernorm """ def __init__(self, dim, proj_drop=0., focal_level=2, focal_window=7, focal_factor=2, use_postln=False, use_postln_in_modulation=False, scaling_modulator=False): super().__init__() self.dim = dim # specific args for focalv3 self.focal_level = focal_level self.focal_window = focal_window self.focal_factor = focal_factor self.use_postln_in_modulation = use_postln_in_modulation self.scaling_modulator = scaling_modulator self.f = nn.Linear(dim, 2*dim+(self.focal_level+1), bias=True) self.h = nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0, groups=1, bias=True) self.act = nn.GELU() self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) self.focal_layers = nn.ModuleList() if self.use_postln_in_modulation: self.ln = nn.LayerNorm(dim) for 
k in range(self.focal_level): kernel_size = self.focal_factor*k + self.focal_window self.focal_layers.append( nn.Sequential( nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=1, groups=dim, padding=kernel_size//2, bias=False), nn.GELU(), ) ) def forward(self, x): """ Forward function. Args: x: input features with shape of (B, H, W, C) """ B, nH, nW, C = x.shape x = self.f(x) x = x.permute(0, 3, 1, 2).contiguous() q, ctx, gates = torch.split(x, (C, C, self.focal_level+1), 1) ctx_all = 0 for l in range(self.focal_level): ctx = self.focal_layers[l](ctx) ctx_all = ctx_all + ctx*gates[:, l:l+1] ctx_global = self.act(ctx.mean(2, keepdim=True).mean(3, keepdim=True)) ctx_all = ctx_all + ctx_global*gates[:,self.focal_level:] if self.scaling_modulator: ctx_all = ctx_all / (self.focal_level + 1) x_out = q * self.h(ctx_all) x_out = x_out.permute(0, 2, 3, 1).contiguous() if self.use_postln_in_modulation: x_out = self.ln(x_out) x_out = self.proj(x_out) x_out = self.proj_drop(x_out) return x_out class FocalModulationBlock(nn.Module): """ Focal Modulation Block. Args: dim (int): Number of input channels. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. drop (float, optional): Dropout rate. Default: 0.0 drop_path (float, optional): Stochastic depth rate. Default: 0.0 act_layer (nn.Module, optional): Activation layer. Default: nn.GELU norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm focal_level (int): number of focal levels focal_window (int): focal kernel size at level 1 """ def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, focal_level=2, focal_window=9, use_postln=False, use_postln_in_modulation=False, scaling_modulator=False, use_layerscale=False, layerscale_value=1e-4): super().__init__() self.dim = dim self.mlp_ratio = mlp_ratio self.focal_window = focal_window self.focal_level = focal_level self.use_postln = use_postln self.use_layerscale = use_layerscale self.dw1 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim) self.norm1 = norm_layer(dim) self.modulation = FocalModulation( dim, focal_window=self.focal_window, focal_level=self.focal_level, proj_drop=drop, use_postln_in_modulation=use_postln_in_modulation, scaling_modulator=scaling_modulator ) self.dw2 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) self.H = None self.W = None self.gamma_1 = 1.0 self.gamma_2 = 1.0 if self.use_layerscale: self.gamma_1 = nn.Parameter(layerscale_value * torch.ones((dim)), requires_grad=True) self.gamma_2 = nn.Parameter(layerscale_value * torch.ones((dim)), requires_grad=True) def forward(self, x): """ Forward function. Args: x: Input feature, tensor size (B, H*W, C). H, W: Spatial resolution of the input feature. 
""" B, L, C = x.shape H, W = self.H, self.W assert L == H * W, "input feature has wrong size" x = x.view(B, H, W, C).permute(0, 3, 1, 2).contiguous() x = x + self.dw1(x) x = x.permute(0, 2, 3, 1).contiguous().view(B, L, C) shortcut = x if not self.use_postln: x = self.norm1(x) x = x.view(B, H, W, C) # FM x = self.modulation(x).view(B, H * W, C) x = shortcut + self.drop_path(self.gamma_1 * x) if self.use_postln: x = self.norm1(x) x = x.view(B, H, W, C).permute(0, 3, 1, 2).contiguous() x = x + self.dw2(x) x = x.permute(0, 2, 3, 1).contiguous().view(B, L, C) if not self.use_postln: x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) else: x = x + self.drop_path(self.gamma_2 * self.mlp(x)) x = self.norm2(x) return x class BasicLayer(nn.Module): """ A basic focal modulation layer for one stage. Args: dim (int): Number of feature channels depth (int): Depths of this stage. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. drop (float, optional): Dropout rate. Default: 0.0 drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None focal_level (int): Number of focal levels focal_window (int): Focal window size at focal level 1 use_conv_embed (bool): Use overlapped convolution for patch embedding or now. Default: False use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False """ def __init__(self, dim, depth, mlp_ratio=4., drop=0., drop_path=0., norm_layer=nn.LayerNorm, downsample=None, focal_window=9, focal_level=2, use_conv_embed=False, use_postln=False, use_postln_in_modulation=False, scaling_modulator=False, use_layerscale=False, use_checkpoint=False, use_pre_norm=False, ): super().__init__() self.depth = depth self.use_checkpoint = use_checkpoint # build blocks self.blocks = nn.ModuleList([ FocalModulationBlock( dim=dim, mlp_ratio=mlp_ratio, drop=drop, drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, focal_window=focal_window, focal_level=focal_level, use_postln=use_postln, use_postln_in_modulation=use_postln_in_modulation, scaling_modulator=scaling_modulator, use_layerscale=use_layerscale, norm_layer=norm_layer) for i in range(depth)]) # patch merging layer if downsample is not None: self.downsample = downsample( patch_size=2, in_chans=dim, embed_dim=2*dim, use_conv_embed=use_conv_embed, norm_layer=norm_layer, is_stem=False, use_pre_norm=use_pre_norm ) else: self.downsample = None def forward(self, x, H, W): """ Forward function. Args: x: Input feature, tensor size (B, H*W, C). H, W: Spatial resolution of the input feature. """ for blk in self.blocks: blk.H, blk.W = H, W if self.use_checkpoint: x = checkpoint.checkpoint(blk, x) else: x = blk(x) if self.downsample is not None: x_reshaped = x.transpose(1, 2).view(x.shape[0], x.shape[-1], H, W) x_down = self.downsample(x_reshaped) x_down = x_down.flatten(2).transpose(1, 2) Wh, Ww = (H + 1) // 2, (W + 1) // 2 return x, H, W, x_down, Wh, Ww else: return x, H, W, x, H, W # class PatchEmbed(nn.Module): # r""" Image to Patch Embedding # Args: # img_size (int): Image size. Default: 224. # patch_size (int): Patch token size. Default: 4. # in_chans (int): Number of input image channels. Default: 3. # embed_dim (int): Number of linear projection output channels. Default: 96. # norm_layer (nn.Module, optional): Normalization layer. 
Default: None # """ # def __init__(self, img_size=(224, 224), patch_size=4, in_chans=3, embed_dim=96, # use_conv_embed=False, norm_layer=None, is_stem=False, use_pre_norm=False): # super().__init__() # patch_size = to_2tuple(patch_size) # patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] # self.img_size = img_size # self.patch_size = patch_size # self.patches_resolution = patches_resolution # self.num_patches = patches_resolution[0] * patches_resolution[1] # self.in_chans = in_chans # self.embed_dim = embed_dim # self.use_pre_norm = use_pre_norm # if use_conv_embed: # # if we choose to use conv embedding, then we treat the stem and non-stem differently # if is_stem: # kernel_size = 7; padding = 3; stride = 4 # else: # kernel_size = 3; padding = 1; stride = 2 # self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding) # else: # self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) # if self.use_pre_norm: # if norm_layer is not None: # self.norm = norm_layer(in_chans) # else: # self.norm = None # else: # if norm_layer is not None: # self.norm = norm_layer(embed_dim) # else: # self.norm = None # def forward(self, x): # B, C, H, W = x.shape # # FIXME look at relaxing size constraints # assert H == self.img_size[0] and W == self.img_size[1], \ # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
# if self.use_pre_norm: # if self.norm is not None: # x = x.flatten(2).transpose(1, 2) # B Ph*Pw C # x = self.norm(x).transpose(1, 2).view(B, C, H, W) # x = self.proj(x).flatten(2).transpose(1, 2) # else: # x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C # if self.norm is not None: # x = self.norm(x) # return x # def flops(self): # Ho, Wo = self.patches_resolution # flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) # if self.norm is not None: # flops += Ho * Wo * self.embed_dim # return flops class PatchEmbed(nn.Module): """ Image to Patch Embedding Args: patch_size (int): Patch token size. Default: 4. in_chans (int): Number of input image channels. Default: 3. embed_dim (int): Number of linear projection output channels. Default: 96. norm_layer (nn.Module, optional): Normalization layer. Default: None use_conv_embed (bool): Whether use overlapped convolution for patch embedding. Default: False is_stem (bool): Is the stem block or not. """ def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None, use_conv_embed=False, is_stem=False, use_pre_norm=False): super().__init__() patch_size = to_2tuple(patch_size) self.patch_size = patch_size self.in_chans = in_chans self.embed_dim = embed_dim self.use_pre_norm = use_pre_norm if use_conv_embed: # if we choose to use conv embedding, then we treat the stem and non-stem differently if is_stem: kernel_size = 7; padding = 3; stride = 4 else: kernel_size = 3; padding = 1; stride = 2 self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding) else: self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) if self.use_pre_norm: if norm_layer is not None: self.norm = norm_layer(in_chans) else: self.norm = None else: if norm_layer is not None: self.norm = norm_layer(embed_dim) else: self.norm = None def forward(self, x): """Forward function.""" B, C, H, W = x.size() if W % self.patch_size[1] != 0: 
x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1])) if H % self.patch_size[0] != 0: x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) if self.use_pre_norm: if self.norm is not None: x = x.flatten(2).transpose(1, 2) # B Ph*Pw C x = self.norm(x).transpose(1, 2).view(B, C, H, W) x = self.proj(x) else: x = self.proj(x) # B C Wh Ww if self.norm is not None: Wh, Ww = x.size(2), x.size(3) x = x.flatten(2).transpose(1, 2) x = self.norm(x) x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww) return x class FocalNet(nn.Module): """ FocalNet backbone. Args: pretrain_img_size (int): Input image size for training the pretrained model, used in absolute postion embedding. Default 224. patch_size (int | tuple(int)): Patch size. Default: 4. in_chans (int): Number of input image channels. Default: 3. embed_dim (int): Number of linear projection output channels. Default: 96. depths (tuple[int]): Depths of each Swin Transformer stage. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. drop_rate (float): Dropout rate. drop_path_rate (float): Stochastic depth rate. Default: 0.2. norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. patch_norm (bool): If True, add normalization after patch embedding. Default: True. out_indices (Sequence[int]): Output from which stages. frozen_stages (int): Stages to be frozen (stop grad and set eval mode). -1 means not freezing any parameters. focal_levels (Sequence[int]): Number of focal levels at four stages focal_windows (Sequence[int]): Focal window sizes at first focal level at four stages use_conv_embed (bool): Whether use overlapped convolution for patch embedding use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
""" def __init__(self, pretrain_img_size=1600, patch_size=4, in_chans=3, embed_dim=96, depths=[2, 2, 6, 2], mlp_ratio=4., drop_rate=0., drop_path_rate=0.2, norm_layer=nn.LayerNorm, patch_norm=True, out_indices=[0, 1, 2, 3], frozen_stages=-1, focal_levels=[2,2,2,2], focal_windows=[9,9,9,9], use_pre_norms=[False, False, False, False], use_conv_embed=False, use_postln=False, use_postln_in_modulation=False, scaling_modulator=False, use_layerscale=False, use_checkpoint=False, ): super().__init__() self.pretrain_img_size = pretrain_img_size self.num_layers = len(depths) self.embed_dim = embed_dim self.patch_norm = patch_norm self.out_indices = out_indices self.frozen_stages = frozen_stages # split image into non-overlapping patches self.patch_embed = PatchEmbed( patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, norm_layer=norm_layer if self.patch_norm else None, use_conv_embed=use_conv_embed, is_stem=True, use_pre_norm=False) self.pos_drop = nn.Dropout(p=drop_rate) # stochastic depth dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule # build layers self.layers = nn.ModuleList() for i_layer in range(self.num_layers): layer = BasicLayer( dim=int(embed_dim * 2 ** i_layer), depth=depths[i_layer], mlp_ratio=mlp_ratio, drop=drop_rate, drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], norm_layer=norm_layer, downsample=PatchEmbed if (i_layer < self.num_layers - 1) else None, focal_window=focal_windows[i_layer], focal_level=focal_levels[i_layer], use_pre_norm=use_pre_norms[i_layer], use_conv_embed=use_conv_embed, use_postln=use_postln, use_postln_in_modulation=use_postln_in_modulation, scaling_modulator=scaling_modulator, use_layerscale=use_layerscale, use_checkpoint=use_checkpoint) self.layers.append(layer) num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] self.num_features = num_features # self.norm = norm_layer(num_features[-1]) # add a norm layer for each output for i_layer 
in self.out_indices: layer = norm_layer(num_features[i_layer]) layer_name = f'norm{i_layer}' self.add_module(layer_name, layer) self._freeze_stages() def _freeze_stages(self): if self.frozen_stages >= 0: self.patch_embed.eval() for param in self.patch_embed.parameters(): param.requires_grad = False if self.frozen_stages >= 2: self.pos_drop.eval() for i in range(0, self.frozen_stages - 1): m = self.layers[i] m.eval() for param in m.parameters(): param.requires_grad = False def init_weights(self, pretrained=None): """Initialize the weights in backbone. Args: pretrained (str, optional): Path to pre-trained weights. Defaults to None. """ def _init_weights(m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) if isinstance(pretrained, str): self.apply(_init_weights) logger = get_root_logger() load_checkpoint(self, pretrained, strict=False, logger=logger) elif pretrained is None: self.apply(_init_weights) else: raise TypeError('pretrained must be a str or None') def load_weights(self, pretrained_dict=None, pretrained_layers=[], verbose=True): model_dict = self.state_dict() missed_dict = [k for k in model_dict.keys() if k not in pretrained_dict] logger.info(f'=> Missed keys {missed_dict}') unexpected_dict = [k for k in pretrained_dict.keys() if k not in model_dict] logger.info(f'=> Unexpected keys {unexpected_dict}') pretrained_dict = { k: v for k, v in pretrained_dict.items() if k in model_dict.keys() } need_init_state_dict = {} for k, v in pretrained_dict.items(): need_init = ( ( k.split('.')[0] in pretrained_layers or pretrained_layers[0] == '*' ) and 'relative_position_index' not in k and 'attn_mask' not in k ) if need_init: # if verbose: # logger.info(f'=> init {k} from {pretrained}') if ('pool_layers' in k) or ('focal_layers' in k) and v.size() != model_dict[k].size(): 
table_pretrained = v table_current = model_dict[k] fsize1 = table_pretrained.shape[2] fsize2 = table_current.shape[2] # NOTE: different from interpolation used in self-attention, we use padding or clipping for focal conv if fsize1 < fsize2: table_pretrained_resized = torch.zeros(table_current.shape) table_pretrained_resized[:, :, (fsize2-fsize1)//2:-(fsize2-fsize1)//2, (fsize2-fsize1)//2:-(fsize2-fsize1)//2] = table_pretrained v = table_pretrained_resized elif fsize1 > fsize2: table_pretrained_resized = table_pretrained[:, :, (fsize1-fsize2)//2:-(fsize1-fsize2)//2, (fsize1-fsize2)//2:-(fsize1-fsize2)//2] v = table_pretrained_resized if ("modulation.f" in k or "pre_conv" in k): table_pretrained = v table_current = model_dict[k] if table_pretrained.shape != table_current.shape: if len(table_pretrained.shape) == 2: dim = table_pretrained.shape[1] assert table_current.shape[1] == dim L1 = table_pretrained.shape[0] L2 = table_current.shape[0] if L1 < L2: table_pretrained_resized = torch.zeros(table_current.shape) # copy for linear project table_pretrained_resized[:2*dim] = table_pretrained[:2*dim] # copy for global token gating table_pretrained_resized[-1] = table_pretrained[-1] # copy for first multiple focal levels table_pretrained_resized[2*dim:2*dim+(L1-2*dim-1)] = table_pretrained[2*dim:-1] # reassign pretrained weights v = table_pretrained_resized elif L1 > L2: raise NotImplementedError elif len(table_pretrained.shape) == 1: dim = table_pretrained.shape[0] L1 = table_pretrained.shape[0] L2 = table_current.shape[0] if L1 < L2: table_pretrained_resized = torch.zeros(table_current.shape) # copy for linear project table_pretrained_resized[:dim] = table_pretrained[:dim] # copy for global token gating table_pretrained_resized[-1] = table_pretrained[-1] # copy for first multiple focal levels # table_pretrained_resized[dim:2*dim+(L1-2*dim-1)] = table_pretrained[2*dim:-1] # reassign pretrained weights v = table_pretrained_resized elif L1 > L2: raise NotImplementedError 
need_init_state_dict[k] = v self.load_state_dict(need_init_state_dict, strict=False) def forward(self, x): """Forward function.""" tic = time.time() x = self.patch_embed(x) Wh, Ww = x.size(2), x.size(3) x = x.flatten(2).transpose(1, 2) x = self.pos_drop(x) outs = {} for i in range(self.num_layers): layer = self.layers[i] x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww) if i in self.out_indices: norm_layer = getattr(self, f'norm{i}') x_out = norm_layer(x_out) out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous() outs["res{}".format(i + 2)] = out if len(self.out_indices) == 0: outs["res5"] = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous() toc = time.time() return outs def train(self, mode=True): """Convert the model into training mode while keep layers freezed.""" super(FocalNet, self).train(mode) self._freeze_stages() class D2FocalNet(FocalNet, Backbone): def __init__(self, cfg, input_shape): pretrain_img_size = cfg['BACKBONE']['FOCAL']['PRETRAIN_IMG_SIZE'] patch_size = cfg['BACKBONE']['FOCAL']['PATCH_SIZE'] in_chans = 3 embed_dim = cfg['BACKBONE']['FOCAL']['EMBED_DIM'] depths = cfg['BACKBONE']['FOCAL']['DEPTHS'] mlp_ratio = cfg['BACKBONE']['FOCAL']['MLP_RATIO'] drop_rate = cfg['BACKBONE']['FOCAL']['DROP_RATE'] drop_path_rate = cfg['BACKBONE']['FOCAL']['DROP_PATH_RATE'] norm_layer = nn.LayerNorm patch_norm = cfg['BACKBONE']['FOCAL']['PATCH_NORM'] use_checkpoint = cfg['BACKBONE']['FOCAL']['USE_CHECKPOINT'] out_indices = cfg['BACKBONE']['FOCAL']['OUT_INDICES'] scaling_modulator = cfg['BACKBONE']['FOCAL'].get('SCALING_MODULATOR', False) super().__init__( pretrain_img_size, patch_size, in_chans, embed_dim, depths, mlp_ratio, drop_rate, drop_path_rate, norm_layer, patch_norm, out_indices, focal_levels=cfg['BACKBONE']['FOCAL']['FOCAL_LEVELS'], focal_windows=cfg['BACKBONE']['FOCAL']['FOCAL_WINDOWS'], use_conv_embed=cfg['BACKBONE']['FOCAL']['USE_CONV_EMBED'], use_postln=cfg['BACKBONE']['FOCAL']['USE_POSTLN'], 
use_postln_in_modulation=cfg['BACKBONE']['FOCAL']['USE_POSTLN_IN_MODULATION'], scaling_modulator=scaling_modulator, use_layerscale=cfg['BACKBONE']['FOCAL']['USE_LAYERSCALE'], use_checkpoint=use_checkpoint, ) self._out_features = cfg['BACKBONE']['FOCAL']['OUT_FEATURES'] self._out_feature_strides = { "res2": 4, "res3": 8, "res4": 16, "res5": 32, } self._out_feature_channels = { "res2": self.num_features[0], "res3": self.num_features[1], "res4": self.num_features[2], "res5": self.num_features[3], } def forward(self, x): """ Args: x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``. Returns: dict[str->Tensor]: names and the corresponding features """ assert ( x.dim() == 4 ), f"SwinTransformer takes an input of shape (N, C, H, W). Got {x.shape} instead!" outputs = {} y = super().forward(x) for k in y.keys(): if k in self._out_features: outputs[k] = y[k] return outputs def output_shape(self): return { name: ShapeSpec( channels=self._out_feature_channels[name], stride=self._out_feature_strides[name] ) for name in self._out_features } @property def size_divisibility(self): return 32 @register_backbone def get_focal_backbone(cfg): focal = D2FocalNet(cfg['MODEL'], 224) if cfg['MODEL']['BACKBONE']['LOAD_PRETRAINED'] is True: filename = cfg['MODEL']['BACKBONE']['PRETRAINED'] logger.info(f'=> init from {filename}') with PathManager.open(filename, "rb") as f: ckpt = torch.load(f)['model'] focal.load_weights(ckpt, cfg['MODEL']['BACKBONE']['FOCAL'].get('PRETRAINED_LAYERS', ['*']), cfg['VERBOSE']) return focal ================================================ FILE: llava/model/semsam/backbone/registry.py ================================================ _model_entrypoints = {} def register_backbone(fn): module_name_split = fn.__module__.split('.') model_name = module_name_split[-1] _model_entrypoints[model_name] = fn return fn def model_entrypoints(model_name): return _model_entrypoints[model_name] def is_model(model_name): return model_name in 
_model_entrypoints ================================================ FILE: llava/model/semsam/backbone/swin.py ================================================ # -------------------------------------------------------- # Swin Transformer # Copyright (c) 2021 Microsoft # Licensed under The MIT License [see LICENSE for details] # Written by Ze Liu, Yutong Lin, Yixuan Wei # -------------------------------------------------------- # Copyright (c) Facebook, Inc. and its affiliates. # Modified by Bowen Cheng from https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation/blob/main/mmseg/models/backbones/swin_transformer.py import logging import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint as checkpoint from timm.models.layers import DropPath, to_2tuple, trunc_normal_ from detectron2.modeling import Backbone, ShapeSpec from detectron2.utils.file_io import PathManager from .registry import register_backbone logger = logging.getLogger(__name__) class Mlp(nn.Module): """Multilayer perceptron.""" def __init__( self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0 ): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc1 = nn.Linear(in_features, hidden_features) self.act = act_layer() self.fc2 = nn.Linear(hidden_features, out_features) self.drop = nn.Dropout(drop) def forward(self, x): x = self.fc1(x) x = self.act(x) x = self.drop(x) x = self.fc2(x) x = self.drop(x) return x def window_partition(x, window_size): """ Args: x: (B, H, W, C) window_size (int): window size Returns: windows: (num_windows*B, window_size, window_size, C) """ B, H, W, C = x.shape x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) return windows def window_reverse(windows, window_size, H, W): """ Args: windows: 
(num_windows*B, window_size, window_size, C) window_size (int): Window size H (int): Height of image W (int): Width of image Returns: x: (B, H, W, C) """ B = int(windows.shape[0] / (H * W / window_size / window_size)) x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) return x class WindowAttention(nn.Module): """Window based multi-head self attention (W-MSA) module with relative position bias. It supports both of shifted and non-shifted window. Args: dim (int): Number of input channels. window_size (tuple[int]): The height and width of the window. num_heads (int): Number of attention heads. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 """ def __init__( self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0.0, proj_drop=0.0, ): super().__init__() self.dim = dim self.window_size = window_size # Wh, Ww self.num_heads = num_heads head_dim = dim // num_heads self.scale = qk_scale or head_dim ** -0.5 # define a parameter table of relative position bias self.relative_position_bias_table = nn.Parameter( torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads) ) # 2*Wh-1 * 2*Ww-1, nH # get pair-wise relative position index for each token inside the window coords_h = torch.arange(self.window_size[0]) coords_w = torch.arange(self.window_size[1]) coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 relative_coords[:, :, 1] += self.window_size[1] - 1 relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww self.register_buffer("relative_position_index", relative_position_index) self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) trunc_normal_(self.relative_position_bias_table, std=0.02) self.softmax = nn.Softmax(dim=-1) def forward(self, x, mask=None): """Forward function. 
Args: x: input features with shape of (num_windows*B, N, C) mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None """ B_, N, C = x.shape qkv = ( self.qkv(x) .reshape(B_, N, 3, self.num_heads, C // self.num_heads) .permute(2, 0, 3, 1, 4) ) q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) q = q * self.scale attn = q @ k.transpose(-2, -1) relative_position_bias = self.relative_position_bias_table[ self.relative_position_index.view(-1) ].view( self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1 ) # Wh*Ww,Wh*Ww,nH relative_position_bias = relative_position_bias.permute( 2, 0, 1 ).contiguous() # nH, Wh*Ww, Wh*Ww attn = attn + relative_position_bias.unsqueeze(0) if mask is not None: nW = mask.shape[0] attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) attn = attn.view(-1, self.num_heads, N, N) attn = self.softmax(attn) else: attn = self.softmax(attn) attn = self.attn_drop(attn) x = (attn @ v).transpose(1, 2).reshape(B_, N, C) x = self.proj(x) x = self.proj_drop(x) return x class SwinTransformerBlock(nn.Module): """Swin Transformer Block. Args: dim (int): Number of input channels. num_heads (int): Number of attention heads. window_size (int): Window size. shift_size (int): Shift size for SW-MSA. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. drop (float, optional): Dropout rate. Default: 0.0 attn_drop (float, optional): Attention dropout rate. Default: 0.0 drop_path (float, optional): Stochastic depth rate. Default: 0.0 act_layer (nn.Module, optional): Activation layer. Default: nn.GELU norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm """ def __init__( self, dim, num_heads, window_size=7, shift_size=0, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, drop=0.0, attn_drop=0.0, drop_path=0.0, act_layer=nn.GELU, norm_layer=nn.LayerNorm, ): super().__init__() self.dim = dim self.num_heads = num_heads self.window_size = window_size self.shift_size = shift_size self.mlp_ratio = mlp_ratio assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" self.norm1 = norm_layer(dim) self.attn = WindowAttention( dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, ) self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp( in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop ) self.H = None self.W = None def forward(self, x, mask_matrix): """Forward function. Args: x: Input feature, tensor size (B, H*W, C). H, W: Spatial resolution of the input feature. mask_matrix: Attention mask for cyclic shift. 
""" B, L, C = x.shape H, W = self.H, self.W assert L == H * W, "input feature has wrong size" # HACK model will not upsampling # if min([H, W]) <= self.window_size: # if window size is larger than input resolution, we don't partition windows # self.shift_size = 0 # self.window_size = min([H,W]) shortcut = x x = self.norm1(x) x = x.view(B, H, W, C) # pad feature maps to multiples of window size pad_l = pad_t = 0 pad_r = (self.window_size - W % self.window_size) % self.window_size pad_b = (self.window_size - H % self.window_size) % self.window_size x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) _, Hp, Wp, _ = x.shape # cyclic shift if self.shift_size > 0: shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) attn_mask = mask_matrix else: shifted_x = x attn_mask = None # partition windows x_windows = window_partition( shifted_x, self.window_size ) # nW*B, window_size, window_size, C x_windows = x_windows.view( -1, self.window_size * self.window_size, C ) # nW*B, window_size*window_size, C # W-MSA/SW-MSA attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C # merge windows attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C # reverse cyclic shift if self.shift_size > 0: x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) else: x = shifted_x if pad_r > 0 or pad_b > 0: x = x[:, :H, :W, :].contiguous() x = x.view(B, H * W, C) # FFN x = shortcut + self.drop_path(x) x = x + self.drop_path(self.mlp(self.norm2(x))) return x class PatchMerging(nn.Module): """Patch Merging Layer Args: dim (int): Number of input channels. norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm """ def __init__(self, dim, norm_layer=nn.LayerNorm): super().__init__() self.dim = dim self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) self.norm = norm_layer(4 * dim) def forward(self, x, H, W): """Forward function. Args: x: Input feature, tensor size (B, H*W, C). H, W: Spatial resolution of the input feature. """ B, L, C = x.shape assert L == H * W, "input feature has wrong size" x = x.view(B, H, W, C) # padding pad_input = (H % 2 == 1) or (W % 2 == 1) if pad_input: x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C x = self.norm(x) x = self.reduction(x) return x class BasicLayer(nn.Module): """A basic Swin Transformer layer for one stage. Args: dim (int): Number of feature channels depth (int): Depths of this stage. num_heads (int): Number of attention head. window_size (int): Local window size. Default: 7. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. drop (float, optional): Dropout rate. Default: 0.0 attn_drop (float, optional): Attention dropout rate. Default: 0.0 drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
""" def __init__( self, dim, depth, num_heads, window_size=7, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, drop=0.0, attn_drop=0.0, drop_path=0.0, norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, ): super().__init__() self.window_size = window_size self.shift_size = window_size // 2 self.depth = depth self.use_checkpoint = use_checkpoint # build blocks self.blocks = nn.ModuleList( [ SwinTransformerBlock( dim=dim, num_heads=num_heads, window_size=window_size, shift_size=0 if (i % 2 == 0) else window_size // 2, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, attn_drop=attn_drop, drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer, ) for i in range(depth) ] ) # patch merging layer if downsample is not None: self.downsample = downsample(dim=dim, norm_layer=norm_layer) else: self.downsample = None def forward(self, x, H, W): """Forward function. Args: x: Input feature, tensor size (B, H*W, C). H, W: Spatial resolution of the input feature. 
""" # calculate attention mask for SW-MSA Hp = int(np.ceil(H / self.window_size)) * self.window_size Wp = int(np.ceil(W / self.window_size)) * self.window_size img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 h_slices = ( slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None), ) w_slices = ( slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None), ) cnt = 0 for h in h_slices: for w in w_slices: img_mask[:, h, w, :] = cnt cnt += 1 mask_windows = window_partition( img_mask, self.window_size ) # nW, window_size, window_size, 1 mask_windows = mask_windows.view(-1, self.window_size * self.window_size) attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill( attn_mask == 0, float(0.0) ).type(x.dtype) for blk in self.blocks: blk.H, blk.W = H, W if self.use_checkpoint: x = checkpoint.checkpoint(blk, x, attn_mask) else: x = blk(x, attn_mask) if self.downsample is not None: x_down = self.downsample(x, H, W) Wh, Ww = (H + 1) // 2, (W + 1) // 2 return x, H, W, x_down, Wh, Ww else: return x, H, W, x, H, W class PatchEmbed(nn.Module): """Image to Patch Embedding Args: patch_size (int): Patch token size. Default: 4. in_chans (int): Number of input image channels. Default: 3. embed_dim (int): Number of linear projection output channels. Default: 96. norm_layer (nn.Module, optional): Normalization layer. 
Default: None """ def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): super().__init__() patch_size = to_2tuple(patch_size) self.patch_size = patch_size self.in_chans = in_chans self.embed_dim = embed_dim self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) if norm_layer is not None: self.norm = norm_layer(embed_dim) else: self.norm = None def forward(self, x): """Forward function.""" # padding _, _, H, W = x.size() if W % self.patch_size[1] != 0: x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1])) if H % self.patch_size[0] != 0: x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) x = self.proj(x) # B C Wh Ww if self.norm is not None: Wh, Ww = x.size(2), x.size(3) x = x.flatten(2).transpose(1, 2) x = self.norm(x) x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww) return x class SwinTransformer(nn.Module): """Swin Transformer backbone. A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - https://arxiv.org/pdf/2103.14030 Args: pretrain_img_size (int): Input image size for training the pretrained model, used in absolute postion embedding. Default 224. patch_size (int | tuple(int)): Patch size. Default: 4. in_chans (int): Number of input image channels. Default: 3. embed_dim (int): Number of linear projection output channels. Default: 96. depths (tuple[int]): Depths of each Swin Transformer stage. num_heads (tuple[int]): Number of attention head of each stage. window_size (int): Window size. Default: 7. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. drop_rate (float): Dropout rate. attn_drop_rate (float): Attention dropout rate. Default: 0. drop_path_rate (float): Stochastic depth rate. Default: 0.2. norm_layer (nn.Module): Normalization layer. 
Default: nn.LayerNorm. ape (bool): If True, add absolute position embedding to the patch embedding. Default: False. patch_norm (bool): If True, add normalization after patch embedding. Default: True. out_indices (Sequence[int]): Output from which stages. frozen_stages (int): Stages to be frozen (stop grad and set eval mode). -1 means not freezing any parameters. use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. """ def __init__( self, pretrain_img_size=224, patch_size=4, in_chans=3, embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.2, norm_layer=nn.LayerNorm, ape=False, patch_norm=True, out_indices=(0, 1, 2, 3), frozen_stages=-1, use_checkpoint=False, ): super().__init__() self.pretrain_img_size = pretrain_img_size self.num_layers = len(depths) self.embed_dim = embed_dim self.ape = ape self.patch_norm = patch_norm self.out_indices = out_indices self.frozen_stages = frozen_stages # split image into non-overlapping patches self.patch_embed = PatchEmbed( patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, norm_layer=norm_layer if self.patch_norm else None, ) # absolute position embedding if self.ape: pretrain_img_size = to_2tuple(pretrain_img_size) patch_size = to_2tuple(patch_size) patches_resolution = [ pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1], ] self.absolute_pos_embed = nn.Parameter( torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1]) ) trunc_normal_(self.absolute_pos_embed, std=0.02) self.pos_drop = nn.Dropout(p=drop_rate) # stochastic depth dpr = [ x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) ] # stochastic depth decay rule # build layers self.layers = nn.ModuleList() for i_layer in range(self.num_layers): layer = BasicLayer( dim=int(embed_dim * 2 ** i_layer), depth=depths[i_layer], num_heads=num_heads[i_layer], 
window_size=window_size, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])], norm_layer=norm_layer, downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, use_checkpoint=use_checkpoint, ) self.layers.append(layer) num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] self.num_features = num_features # add a norm layer for each output for i_layer in out_indices: layer = norm_layer(num_features[i_layer]) layer_name = f"norm{i_layer}" self.add_module(layer_name, layer) self._freeze_stages() def _freeze_stages(self): if self.frozen_stages >= 0: self.patch_embed.eval() for param in self.patch_embed.parameters(): param.requires_grad = False if self.frozen_stages >= 1 and self.ape: self.absolute_pos_embed.requires_grad = False if self.frozen_stages >= 2: self.pos_drop.eval() for i in range(0, self.frozen_stages - 1): m = self.layers[i] m.eval() for param in m.parameters(): param.requires_grad = False def init_weights(self, pretrained=None): """Initialize the weights in backbone. Args: pretrained (str, optional): Path to pre-trained weights. Defaults to None. 
""" def _init_weights(m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=0.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) def load_weights(self, pretrained_dict=None, pretrained_layers=[], verbose=True): model_dict = self.state_dict() pretrained_dict = { k: v for k, v in pretrained_dict.items() if k in model_dict.keys() } need_init_state_dict = {} for k, v in pretrained_dict.items(): need_init = ( ( k.split('.')[0] in pretrained_layers or pretrained_layers[0] == '*' ) and 'relative_position_index' not in k and 'attn_mask' not in k ) if need_init: # if verbose: # logger.info(f'=> init {k} from {pretrained}') if 'relative_position_bias_table' in k and v.size() != model_dict[k].size(): relative_position_bias_table_pretrained = v relative_position_bias_table_current = model_dict[k] L1, nH1 = relative_position_bias_table_pretrained.size() L2, nH2 = relative_position_bias_table_current.size() if nH1 != nH2: logger.info(f"Error in loading {k}, passing") else: if L1 != L2: logger.info( '=> load_pretrained: resized variant: {} to {}' .format((L1, nH1), (L2, nH2)) ) S1 = int(L1 ** 0.5) S2 = int(L2 ** 0.5) relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate( relative_position_bias_table_pretrained.permute(1, 0).view(1, nH1, S1, S1), size=(S2, S2), mode='bicubic') v = relative_position_bias_table_pretrained_resized.view(nH2, L2).permute(1, 0) if 'absolute_pos_embed' in k and v.size() != model_dict[k].size(): absolute_pos_embed_pretrained = v absolute_pos_embed_current = model_dict[k] _, L1, C1 = absolute_pos_embed_pretrained.size() _, L2, C2 = absolute_pos_embed_current.size() if C1 != C1: logger.info(f"Error in loading {k}, passing") else: if L1 != L2: logger.info( '=> load_pretrained: resized variant: {} to {}' .format((1, L1, C1), (1, L2, C2)) ) S1 = int(L1 ** 0.5) S2 = int(L2 ** 0.5) 
absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.reshape(-1, S1, S1, C1) absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.permute(0, 3, 1, 2) absolute_pos_embed_pretrained_resized = torch.nn.functional.interpolate( absolute_pos_embed_pretrained, size=(S2, S2), mode='bicubic') v = absolute_pos_embed_pretrained_resized.permute(0, 2, 3, 1).flatten(1, 2) need_init_state_dict[k] = v self.load_state_dict(need_init_state_dict, strict=False) def forward(self, x): """Forward function.""" x = self.patch_embed(x) Wh, Ww = x.size(2), x.size(3) if self.ape: # interpolate the position embedding to the corresponding size absolute_pos_embed = F.interpolate( self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic" ) x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C else: x = x.flatten(2).transpose(1, 2) x = self.pos_drop(x) outs = {} for i in range(self.num_layers): layer = self.layers[i] x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww) if i in self.out_indices: norm_layer = getattr(self, f"norm{i}") x_out = norm_layer(x_out) out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous() outs["res{}".format(i + 2)] = out if len(self.out_indices) == 0: outs["res5"] = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous() return outs def train(self, mode=True): """Convert the model into training mode while keep layers freezed.""" super(SwinTransformer, self).train(mode) self._freeze_stages() class D2SwinTransformer(SwinTransformer, Backbone): def __init__(self, cfg, pretrain_img_size, patch_size, in_chans, embed_dim, depths, num_heads, window_size, mlp_ratio, qkv_bias, qk_scale, drop_rate, attn_drop_rate, drop_path_rate, norm_layer, ape, patch_norm, out_indices, use_checkpoint): super().__init__( pretrain_img_size, patch_size, in_chans, embed_dim, depths, num_heads, window_size, mlp_ratio, qkv_bias, qk_scale, drop_rate, attn_drop_rate, drop_path_rate, norm_layer, ape, patch_norm, out_indices, 
use_checkpoint=use_checkpoint, ) self._out_features = cfg['OUT_FEATURES'] self._out_feature_strides = { "res2": 4, "res3": 8, "res4": 16, "res5": 32, } self._out_feature_channels = { "res2": self.num_features[0], "res3": self.num_features[1], "res4": self.num_features[2], "res5": self.num_features[3], } def forward(self, x): """ Args: x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``. Returns: dict[str->Tensor]: names and the corresponding features """ assert ( x.dim() == 4 ), f"SwinTransformer takes an input of shape (N, C, H, W). Got {x.shape} instead!" outputs = {} y = super().forward(x) for k in y.keys(): if k in self._out_features: outputs[k] = y[k] return outputs def output_shape(self): feature_names = list(set(self._out_feature_strides.keys()) & set(self._out_features)) return { name: ShapeSpec( channels=self._out_feature_channels[name], stride=self._out_feature_strides[name] ) for name in feature_names } @property def size_divisibility(self): return 32 @register_backbone def get_swin_backbone(cfg): swin_cfg = cfg['MODEL']['BACKBONE']['SWIN'] pretrain_img_size = swin_cfg['PRETRAIN_IMG_SIZE'] patch_size = swin_cfg['PATCH_SIZE'] in_chans = 3 embed_dim = swin_cfg['EMBED_DIM'] depths = swin_cfg['DEPTHS'] num_heads = swin_cfg['NUM_HEADS'] window_size = swin_cfg['WINDOW_SIZE'] mlp_ratio = swin_cfg['MLP_RATIO'] qkv_bias = swin_cfg['QKV_BIAS'] qk_scale = swin_cfg['QK_SCALE'] drop_rate = swin_cfg['DROP_RATE'] attn_drop_rate = swin_cfg['ATTN_DROP_RATE'] drop_path_rate = swin_cfg['DROP_PATH_RATE'] norm_layer = nn.LayerNorm ape = swin_cfg['APE'] patch_norm = swin_cfg['PATCH_NORM'] use_checkpoint = swin_cfg['USE_CHECKPOINT'] out_indices = swin_cfg.get('OUT_INDICES', [0,1,2,3]) swin = D2SwinTransformer( swin_cfg, pretrain_img_size, patch_size, in_chans, embed_dim, depths, num_heads, window_size, mlp_ratio, qkv_bias, qk_scale, drop_rate, attn_drop_rate, drop_path_rate, norm_layer, ape, patch_norm, out_indices, 
use_checkpoint=use_checkpoint, ) if cfg['MODEL']['BACKBONE']['LOAD_PRETRAINED'] is True: filename = cfg['MODEL']['BACKBONE']['PRETRAINED'] with PathManager.open(filename, "rb") as f: # ckpt = torch.load(f, map_location=cfg['device'])['model'] ckpt = torch.load(f, map_location='cpu')['model'] swin.load_weights(ckpt, swin_cfg.get('PRETRAINED_LAYERS', ['*']), cfg['VERBOSE']) return swin ================================================ FILE: llava/model/semsam/backbone/swin_new.py ================================================ # -------------------------------------------------------- # Swin Transformer # Copyright (c) 2021 Microsoft # Licensed under The MIT License [see LICENSE for details] # Written by Ze Liu, Yutong Lin, Yixuan Wei # -------------------------------------------------------- # Copyright (c) Facebook, Inc. and its affiliates. # Modified by Bowen Cheng from https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation/blob/main/mmseg/models/backbones/swin_transformer.py import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint as checkpoint from timm.models.layers import DropPath, to_2tuple, trunc_normal_ from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec class Mlp(nn.Module): """Multilayer perceptron.""" def __init__( self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0 ): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc1 = nn.Linear(in_features, hidden_features) self.act = act_layer() self.fc2 = nn.Linear(hidden_features, out_features) self.drop = nn.Dropout(drop) def forward(self, x): x = self.fc1(x) x = self.act(x) x = self.drop(x) x = self.fc2(x) x = self.drop(x) return x def window_partition(x, window_size): """ Args: x: (B, H, W, C) window_size (int): window size Returns: windows: (num_windows*B, window_size, window_size, C) """ B, H, W, C = x.shape x = 
x.view(B, H // window_size, window_size, W // window_size, window_size, C) windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) return windows def window_reverse(windows, window_size, H, W): """ Args: windows: (num_windows*B, window_size, window_size, C) window_size (int): Window size H (int): Height of image W (int): Width of image Returns: x: (B, H, W, C) """ B = int(windows.shape[0] / (H * W / window_size / window_size)) x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) return x class WindowAttention(nn.Module): """Window based multi-head self attention (W-MSA) module with relative position bias. It supports both of shifted and non-shifted window. Args: dim (int): Number of input channels. window_size (tuple[int]): The height and width of the window. num_heads (int): Number of attention heads. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0
    """

    def __init__(
        self,
        dim,
        window_size,
        num_heads,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # An explicit qk_scale overrides 1/sqrt(head_dim); because of `or`, a falsy
        # qk_scale (None or 0) falls back to the default.
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias:
        # one learnable row per relative (dy, dx) offset, one column per head.
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
        )  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        # torch.meshgrid without `indexing=` uses "ij" ordering; recent PyTorch
        # versions emit a UserWarning about the missing argument.
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        # Flatten the 2-D offset row-major: index = dy * (2*Ww - 1) + dx.
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        # Fixed (non-learnable) lookup table kept as a buffer so it moves with the module.
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=0.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """Forward function.

        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None

        Returns:
            Tensor of shape (num_windows*B, N, C).
        """
        B_, N, C = x.shape
        # Project to q, k, v and split heads -> (3, B_, nH, N, C // nH).
        qkv = (
            self.qkv(x)
            .reshape(B_, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = q @ k.transpose(-2, -1)

        # Gather the learned bias for every (query, key) token pair in the window.
        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)
        ].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
        )  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(
            2, 0, 1
        ).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            # Broadcast the per-window mask over batch and heads. This view assumes
            # B_ is a multiple of nW with each image's nW windows stored consecutively.
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        # Weighted sum of values, then merge heads back to (B_, N, C).
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class SwinTransformerBlock(nn.Module):
    """Swin Transformer Block.

    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer.
Default: nn.LayerNorm
    """

    def __init__(
        self,
        dim,
        num_heads,
        window_size=7,
        shift_size=0,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim,
            window_size=to_2tuple(self.window_size),
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )

        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop
        )

        # Spatial size of the incoming tokens; must be assigned before forward()
        # (BasicLayer.forward sets blk.H, blk.W on every call).
        self.H = None
        self.W = None

    def forward(self, x, mask_matrix):
        """Forward function.

        Args:
            x: Input feature, tensor size (B, H*W, C).
            mask_matrix: Attention mask for cyclic shift; only used when
                ``shift_size > 0``.

        H, W (the spatial resolution of ``x``) are read from ``self.H`` / ``self.W``.

        Returns:
            Tensor of shape (B, H*W, C).
        """
        B, L, C = x.shape
        H, W = self.H, self.W
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # pad feature maps to multiples of window size
        pad_l = pad_t = 0
        pad_r = (self.window_size - W % self.window_size) % self.window_size
        pad_b = (self.window_size - H % self.window_size) % self.window_size
        # F.pad lists the last dim first: C -> (0, 0), W -> (pad_l, pad_r), H -> (pad_t, pad_b).
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        _, Hp, Wp, _ = x.shape

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
            attn_mask = mask_matrix
        else:
            shifted_x = x
            attn_mask = None

        # partition windows
        x_windows = window_partition(
            shifted_x, self.window_size
        )  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(
            -1, self.window_size * self.window_size, C
        )  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x

        # Drop the padding added above.
        if pad_r > 0 or pad_b > 0:
            x = x[:, :H, :W, :].contiguous()

        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x


class PatchMerging(nn.Module):
    """Patch Merging Layer

    Concatenates each 2x2 neighbourhood (4*C channels), normalizes it and projects
    it to 2*dim channels, halving the spatial resolution (rounding up).

    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x, H, W):
        """Forward function.

        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.

        Returns:
            Tensor of shape (B, ceil(H/2)*ceil(W/2), 2*dim).
        """
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        x = x.view(B, H, W, C)

        # padding: make H and W even so the 2x2 strided slices line up
        pad_input = (H % 2 == 1) or (W % 2 == 1)
        if pad_input:
            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x


class BasicLayer(nn.Module):
    """A basic Swin Transformer layer for one stage.

    Args:
        dim (int): Number of feature channels
        depth (int): Depths of this stage.
        num_heads (int): Number of attention head.
        window_size (int): Local window size. Default: 7.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | list[float], optional): Stochastic depth rate. Default: 0.0.
            Only a ``list`` is indexed per block; any other value (including a
            tuple) is passed unchanged to every block.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(
        self,
        dim,
        depth,
        num_heads,
        window_size=7,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        norm_layer=nn.LayerNorm,
        downsample=None,
        use_checkpoint=False,
    ):
        super().__init__()
        self.window_size = window_size
        self.shift_size = window_size // 2
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks: even-indexed blocks use regular windows (W-MSA),
        # odd-indexed blocks use windows shifted by window_size // 2 (SW-MSA).
        self.blocks = nn.ModuleList(
            [
                SwinTransformerBlock(
                    dim=dim,
                    num_heads=num_heads,
                    window_size=window_size,
                    shift_size=0 if (i % 2 == 0) else window_size // 2,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop,
                    attn_drop=attn_drop,
                    drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                    norm_layer=norm_layer,
                )
                for i in range(depth)
            ]
        )

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x, H, W):
        """Forward function.

        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.

        Returns:
            tuple: ``(x, H, W, x_down, Wh, Ww)`` where ``x_down`` is the downsampled
            output at ``((H + 1) // 2, (W + 1) // 2)``, or ``(x, H, W, x, H, W)``
            when this stage has no downsample layer.
        """

        # calculate attention mask for SW-MSA
        Hp = int(np.ceil(H / self.window_size)) * self.window_size
        Wp = int(np.ceil(W / self.window_size)) * self.window_size
        img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # 1 Hp Wp 1
        h_slices = (
            slice(0, -self.window_size),
            slice(-self.window_size, -self.shift_size),
            slice(-self.shift_size, None),
        )
        w_slices = (
            slice(0, -self.window_size),
            slice(-self.window_size, -self.shift_size),
            slice(-self.shift_size, None),
        )
        # Give each of the 3x3 regions a distinct integer label.
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = window_partition(
            img_mask, self.window_size
        )  # nW, window_size, window_size, 1
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        # Pairs of tokens with different labels get a -100 bias (suppressed by the
        # softmax in WindowAttention); same-label pairs get 0.
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
            attn_mask == 0, float(0.0)
        )

        for blk in self.blocks:
            blk.H, blk.W = H, W
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x, attn_mask)
            else:
                x = blk(x, attn_mask)
        if self.downsample is not None:
            x_down = self.downsample(x, H, W)
            Wh, Ww = (H + 1) // 2, (W + 1) // 2
            return x, H, W, x_down, Wh, Ww
        else:
            return x, H, W, x, H, W


class PatchEmbed(nn.Module):
    """Image to Patch Embedding

    Args:
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer.
Default: None """ def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): super().__init__() patch_size = to_2tuple(patch_size) self.patch_size = patch_size self.in_chans = in_chans self.embed_dim = embed_dim self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) if norm_layer is not None: self.norm = norm_layer(embed_dim) else: self.norm = None def forward(self, x): """Forward function.""" # padding _, _, H, W = x.size() if W % self.patch_size[1] != 0: x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1])) if H % self.patch_size[0] != 0: x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) x = self.proj(x) # B C Wh Ww if self.norm is not None: Wh, Ww = x.size(2), x.size(3) x = x.flatten(2).transpose(1, 2) x = self.norm(x) x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww) return x class SwinTransformer(nn.Module): """Swin Transformer backbone. A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - https://arxiv.org/pdf/2103.14030 Args: pretrain_img_size (int): Input image size for training the pretrained model, used in absolute postion embedding. Default 224. patch_size (int | tuple(int)): Patch size. Default: 4. in_chans (int): Number of input image channels. Default: 3. embed_dim (int): Number of linear projection output channels. Default: 96. depths (tuple[int]): Depths of each Swin Transformer stage. num_heads (tuple[int]): Number of attention head of each stage. window_size (int): Window size. Default: 7. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. drop_rate (float): Dropout rate. attn_drop_rate (float): Attention dropout rate. Default: 0. drop_path_rate (float): Stochastic depth rate. Default: 0.2. norm_layer (nn.Module): Normalization layer. 
Default: nn.LayerNorm. ape (bool): If True, add absolute position embedding to the patch embedding. Default: False. patch_norm (bool): If True, add normalization after patch embedding. Default: True. out_indices (Sequence[int]): Output from which stages. frozen_stages (int): Stages to be frozen (stop grad and set eval mode). -1 means not freezing any parameters. use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. """ def __init__( self, pretrain_img_size=224, patch_size=4, in_chans=3, embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.2, norm_layer=nn.LayerNorm, ape=False, patch_norm=True, out_indices=(0, 1, 2, 3), frozen_stages=-1, use_checkpoint=False, ): super().__init__() self.pretrain_img_size = pretrain_img_size self.num_layers = len(depths) self.embed_dim = embed_dim self.ape = ape self.patch_norm = patch_norm self.out_indices = out_indices self.frozen_stages = frozen_stages # split image into non-overlapping patches self.patch_embed = PatchEmbed( patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, norm_layer=norm_layer if self.patch_norm else None, ) # absolute position embedding if self.ape: pretrain_img_size = to_2tuple(pretrain_img_size) patch_size = to_2tuple(patch_size) patches_resolution = [ pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1], ] self.absolute_pos_embed = nn.Parameter( torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1]) ) trunc_normal_(self.absolute_pos_embed, std=0.02) self.pos_drop = nn.Dropout(p=drop_rate) # stochastic depth dpr = [ x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) ] # stochastic depth decay rule # build layers self.layers = nn.ModuleList() for i_layer in range(self.num_layers): layer = BasicLayer( dim=int(embed_dim * 2 ** i_layer), depth=depths[i_layer], num_heads=num_heads[i_layer], 
window_size=window_size, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])], norm_layer=norm_layer, downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, use_checkpoint=use_checkpoint, ) self.layers.append(layer) num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] self.num_features = num_features # add a norm layer for each output for i_layer in out_indices: layer = norm_layer(num_features[i_layer]) layer_name = f"norm{i_layer}" self.add_module(layer_name, layer) self._freeze_stages() def _freeze_stages(self): if self.frozen_stages >= 0: self.patch_embed.eval() for param in self.patch_embed.parameters(): param.requires_grad = False if self.frozen_stages >= 1 and self.ape: self.absolute_pos_embed.requires_grad = False if self.frozen_stages >= 2: self.pos_drop.eval() for i in range(0, self.frozen_stages - 1): m = self.layers[i] m.eval() for param in m.parameters(): param.requires_grad = False def init_weights(self, pretrained=None): """Initialize the weights in backbone. Args: pretrained (str, optional): Path to pre-trained weights. Defaults to None. 
""" def _init_weights(m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=0.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) def forward(self, x): """Forward function.""" x = self.patch_embed(x) Wh, Ww = x.size(2), x.size(3) if self.ape: # interpolate the position embedding to the corresponding size absolute_pos_embed = F.interpolate( self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic" ) x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C else: x = x.flatten(2).transpose(1, 2) x = self.pos_drop(x) outs = {} for i in range(self.num_layers): layer = self.layers[i] x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww) if i in self.out_indices: norm_layer = getattr(self, f"norm{i}") x_out = norm_layer(x_out) out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous() outs["res{}".format(i + 2)] = out return outs def train(self, mode=True): """Convert the model into training mode while keep layers freezed.""" super(SwinTransformer, self).train(mode) self._freeze_stages() @BACKBONE_REGISTRY.register() class D2SwinTransformer(SwinTransformer, Backbone): def __init__(self, cfg, input_shape): pretrain_img_size = cfg.MODEL.SWIN.PRETRAIN_IMG_SIZE patch_size = cfg.MODEL.SWIN.PATCH_SIZE in_chans = 3 embed_dim = cfg.MODEL.SWIN.EMBED_DIM depths = cfg.MODEL.SWIN.DEPTHS num_heads = cfg.MODEL.SWIN.NUM_HEADS window_size = cfg.MODEL.SWIN.WINDOW_SIZE mlp_ratio = cfg.MODEL.SWIN.MLP_RATIO qkv_bias = cfg.MODEL.SWIN.QKV_BIAS qk_scale = cfg.MODEL.SWIN.QK_SCALE drop_rate = cfg.MODEL.SWIN.DROP_RATE attn_drop_rate = cfg.MODEL.SWIN.ATTN_DROP_RATE drop_path_rate = cfg.MODEL.SWIN.DROP_PATH_RATE norm_layer = nn.LayerNorm ape = cfg.MODEL.SWIN.APE patch_norm = cfg.MODEL.SWIN.PATCH_NORM use_checkpoint = cfg.MODEL.SWIN.USE_CHECKPOINT super().__init__( pretrain_img_size, patch_size, in_chans, embed_dim, depths, num_heads, window_size, 
mlp_ratio, qkv_bias, qk_scale, drop_rate, attn_drop_rate, drop_path_rate, norm_layer, ape, patch_norm, use_checkpoint=use_checkpoint, ) self._out_features = cfg.MODEL.SWIN.OUT_FEATURES self._out_feature_strides = { "res2": 4, "res3": 8, "res4": 16, "res5": 32, } self._out_feature_channels = { "res2": self.num_features[0], "res3": self.num_features[1], "res4": self.num_features[2], "res5": self.num_features[3], } def forward(self, x): """ Args: x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``. Returns: dict[str->Tensor]: names and the corresponding features """ assert ( x.dim() == 4 ), f"SwinTransformer takes an input of shape (N, C, H, W). Got {x.shape} instead!" outputs = {} y = super().forward(x) for k in y.keys(): if k in self._out_features: outputs[k] = y[k] return outputs def output_shape(self): return { name: ShapeSpec( channels=self._out_feature_channels[name], stride=self._out_feature_strides[name] ) for name in self._out_features } @property def size_divisibility(self): return 32 ================================================ FILE: llava/model/semsam/body/__init__.py ================================================ from .build import build_openseed_head ================================================ FILE: llava/model/semsam/body/build.py ================================================ from .registry import model_entrypoints from .registry import is_model from .openseed_head import * def build_openseed_head(config, *args, **kwargs): model_name = config['MODEL']['HEAD'] if not is_model(model_name): raise ValueError(f'Unkown model: {model_name}') body = model_entrypoints(model_name)(config, *args, **kwargs) return body ================================================ FILE: llava/model/semsam/body/decoder/__init__.py ================================================ from .build import build_decoder from .idino_decoder_no_iou_token_partwhole_all_llm import * ================================================ FILE: 
llava/model/semsam/body/decoder/build.py ================================================ from .registry import model_entrypoints from .registry import is_model def build_decoder(config, *args, **kwargs): model_name = config['MODEL']['DECODER']['NAME'] if not is_model(model_name): raise ValueError(f'Unkown model: {model_name}') return model_entrypoints(model_name)(config, *args, **kwargs) ================================================ FILE: llava/model/semsam/body/decoder/idino_decoder_no_iou_token_partwhole_all_llm.py ================================================ # ------------------------------------------------------------------------ # DINO # Copyright (c) 2022 IDEA. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------ # Modified from Mask2Former https://github.com/facebookresearch/Mask2Former by Feng Li and Hao Zhang. import logging import fvcore.nn.weight_init as weight_init import torch from torch import nn from torch.nn import functional as F from detectron2.layers import Conv2d from detectron2.utils.registry import Registry from detectron2.structures import BitMasks from timm.models.layers import trunc_normal_ from .registry import register_decoder from .utils.dino_decoder import TransformerDecoder, DeformableTransformerDecoderLayer from .utils import MLP, gen_encoder_output_proposals, inverse_sigmoid from ...utils import box_ops from ...utils import configurable class MaskDINODecoder(nn.Module): @configurable def __init__( self, lang_encoder: nn.Module, in_channels, mask_classification=True, *, num_classes: int, hidden_dim: int, dim_proj: int, num_queries: int, nheads: int, dim_feedforward: int, dec_layers: int, mask_dim: int, enforce_input_project: bool, two_stage: bool, dn: str, noise_scale:float, dn_num:int, initialize_box_type:bool, initial_pred:bool, learn_tgt: bool, total_num_feature_levels: int = 4, dropout: float = 0.0, 
activation: str = 'relu', nhead: int = 8, dec_n_points: int = 4, return_intermediate_dec: bool = True, query_dim: int = 4, dec_layer_share: bool = False, semantic_ce_loss: bool = False, num_mask_tokens: int = 3, ): """ NOTE: this interface is experimental. Args: in_channels: channels of the input features mask_classification: whether to add mask classifier or not num_classes: number of classes hidden_dim: Transformer feature dimension num_queries: number of queries nheads: number of heads dim_feedforward: feature dimension in feedforward network enc_layers: number of Transformer encoder layers dec_layers: number of Transformer decoder layers pre_norm: whether to use pre-LayerNorm or not mask_dim: mask feature dimension enforce_input_project: add input project 1x1 conv even if input channels and hidden dim is identical d_model: transformer dimension dropout: dropout rate activation: activation function nhead: num heads in multi-head attention dec_n_points: number of sampling points in decoder return_intermediate_dec: return the intermediate results of decoder query_dim: 4 -> (x, y, w, h) dec_layer_share: whether to share each decoder layer semantic_ce_loss: use ce loss for semantic segmentation """ super().__init__() assert mask_classification, "Only support mask classification model" self.mask_classification = mask_classification self.num_feature_levels = total_num_feature_levels self.initial_pred = initial_pred # define Transformer decoder here self.dn=dn self.learn_tgt = learn_tgt self.noise_scale=noise_scale self.dn_num=dn_num self.num_heads = nheads self.num_layers = dec_layers self.two_stage=two_stage self.initialize_box_type = initialize_box_type self.total_num_feature_levels = total_num_feature_levels self.num_queries = num_queries self.semantic_ce_loss = semantic_ce_loss interactive_only = True # learnable query features if num_queries>0 and not interactive_only: if not two_stage or self.learn_tgt: self.query_feat = nn.Embedding(num_queries, hidden_dim) if 
not two_stage and initialize_box_type == 'no': self.query_embed = nn.Embedding(num_queries, 4) # if two_stage: # self.enc_output = nn.Linear(hidden_dim, hidden_dim) # self.enc_output_norm = nn.LayerNorm(hidden_dim) self.input_proj = nn.ModuleList() for _ in range(self.num_feature_levels): if in_channels != hidden_dim or enforce_input_project: self.input_proj.append(Conv2d(in_channels, hidden_dim, kernel_size=1)) weight_init.c2_xavier_fill(self.input_proj[-1]) else: self.input_proj.append(nn.Sequential()) self.num_classes=num_classes # output FFNs assert self.mask_classification, "why not class embedding?" # self.label_enc=nn.Embedding(505, hidden_dim) # this is a hack for o365+coco (365+133=498) self.dim_proj = dim_proj self.lang_encoder = lang_encoder # if lang_encoder is not None: self.lang_mapper = nn.Parameter(torch.empty(dim_proj, hidden_dim)) trunc_normal_(self.lang_mapper, std=.02) self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3) # init decoder self.decoder_norm = decoder_norm = nn.LayerNorm(hidden_dim) decoder_layer = DeformableTransformerDecoderLayer(hidden_dim, dim_feedforward, dropout, activation, self.num_feature_levels, nhead, dec_n_points) self.decoder = TransformerDecoder(decoder_layer, self.num_layers, decoder_norm, return_intermediate=return_intermediate_dec, d_model=hidden_dim, query_dim=query_dim, num_feature_levels=self.num_feature_levels, dec_layer_share=dec_layer_share, ) self.hidden_dim = hidden_dim self._bbox_embed = _bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3) nn.init.constant_(_bbox_embed.layers[-1].weight.data, 0) nn.init.constant_(_bbox_embed.layers[-1].bias.data, 0) box_embed_layerlist = [_bbox_embed for i in range(self.num_layers)] # share box prediction each layer self.bbox_embed = nn.ModuleList(box_embed_layerlist) self.decoder.bbox_embed = self.bbox_embed # whole category classification self.class_embed = nn.Parameter(torch.empty(hidden_dim, dim_proj)) trunc_normal_(self.class_embed, std=.02) # part category 
classification self.class_embed_part = nn.Parameter(torch.empty(hidden_dim, dim_proj)) trunc_normal_(self.class_embed_part, std=.02) # FIXME iou head; iou prediction: 1. iou token to predict 3 score. 2. predict each iou score from query tokens # FIXME seems we only need to stack these tokens in batch dimension to reduce self attention burden. self.num_mask_tokens = num_mask_tokens # sam uses 4 to handle multi prompts self.iou_token = 0 # FIXME hack to remove iou token self.num_all_tokens = self.num_mask_tokens + self.iou_token # sam uses 4 to handle multi prompts self.iou_prediction_head = MLP(hidden_dim, hidden_dim, 1, 3) # self.iou_token = nn.Embedding(self.iou_token, hidden_dim) self.mask_tokens = nn.Embedding(self.num_mask_tokens, hidden_dim) self.pb_embedding=nn.Embedding(2,hidden_dim) self.label_enc=nn.Embedding(2,hidden_dim) self.prediction_switch = None @classmethod def from_config(cls, cfg, in_channels, lang_encoder, mask_classification, extra): ret = {} ret["in_channels"] = in_channels ret["lang_encoder"] = lang_encoder ret["mask_classification"] = mask_classification enc_cfg = cfg['MODEL']['ENCODER'] dec_cfg = cfg['MODEL']['DECODER'] ret["num_classes"] = enc_cfg['NUM_CLASSES'] ret["hidden_dim"] = dec_cfg['HIDDEN_DIM'] ret["dim_proj"] = cfg['MODEL']['DIM_PROJ'] ret["num_queries"] = dec_cfg['NUM_OBJECT_QUERIES'] # Transformer parameters: ret["num_mask_tokens"] = dec_cfg.get('NUM_MASK_TOKENS', 3) ret["nheads"] = dec_cfg['NHEADS'] ret["dim_feedforward"] = dec_cfg['DIM_FEEDFORWARD'] ret["dec_layers"] = dec_cfg['DEC_LAYERS'] ret["enforce_input_project"] = dec_cfg['ENFORCE_INPUT_PROJ'] ret["mask_dim"] = enc_cfg['MASK_DIM'] ret["two_stage"] = dec_cfg['TWO_STAGE'] ret["initialize_box_type"] = dec_cfg['INITIALIZE_BOX_TYPE'] # ['no', 'bitmask', 'mask2box'] ret["dn"] = dec_cfg['DN'] ret["noise_scale"] = dec_cfg['DN_NOISE_SCALE'] ret["dn_num"] = dec_cfg['DN_NUM'] ret["initial_pred"] = dec_cfg['INITIAL_PRED'] ret["learn_tgt"] = dec_cfg['LEARN_TGT'] 
ret["total_num_feature_levels"] = dec_cfg['TOTAL_NUM_FEATURE_LEVELS'] ret["num_mask_tokens"] = dec_cfg.get('NUM_INTERACTIVE_TOKENS', 3) ret["semantic_ce_loss"] = dec_cfg['TEST']['SEMANTIC_ON'] and dec_cfg['SEMANTIC_CE_LOSS'] and not dec_cfg['TEST']['PANOPTIC_ON'] return ret def prepare_for_dn(self, targets, tgt, refpoint_emb, batch_size): """ modified from dn-detr. You can refer to dn-detr https://github.com/IDEA-Research/DN-DETR/blob/main/models/dn_dab_deformable_detr/dn_components.py for more details :param dn_args: scalar, noise_scale :param tgt: original tgt (content) in the matching part :param refpoint_emb: positional anchor queries in the matching part :param batch_size: bs """ if self.training: scalar, noise_scale = self.dn_num, self.noise_scale known = [(torch.ones_like(t['labels'])).cuda() for t in targets] know_idx = [torch.nonzero(t) for t in known] known_num = [sum(k) for k in known] # use fix number of dn queries if max(known_num) > 0: scalar = scalar // (int(max(known_num))) else: scalar = 0 if scalar == 0: input_query_label = None input_query_bbox = None attn_mask = None mask_dict = None return input_query_label, input_query_bbox, attn_mask, mask_dict # can be modified to selectively denosie some label or boxes; also known label prediction unmask_bbox = unmask_label = torch.cat(known) labels = torch.cat([t['labels'] for t in targets]) # use languge as denosing content queries. 
# if task == 'det': # labels = labels # o365 start from 133 class boxes = torch.cat([t['boxes'] for t in targets]) batch_idx = torch.cat([torch.full_like(t['labels'].long(), i) for i, t in enumerate(targets)]) # known known_indice = torch.nonzero(unmask_label + unmask_bbox) known_indice = known_indice.view(-1) # noise known_indice = known_indice.repeat(scalar, 1).view(-1) known_labels = labels.repeat(scalar, 1).view(-1) known_bid = batch_idx.repeat(scalar, 1).view(-1) known_bboxs = boxes.repeat(scalar, 1) known_labels_expaned = known_labels.clone() known_bbox_expand = known_bboxs.clone() if noise_scale > 0: diff = torch.zeros_like(known_bbox_expand) diff[:, :2] = known_bbox_expand[:, 2:] / 2 diff[:, 2:] = known_bbox_expand[:, 2:] known_bbox_expand += torch.mul((torch.rand_like(known_bbox_expand) * 2 - 1.0), diff).cuda() * noise_scale known_bbox_expand = known_bbox_expand.clamp(min=0.0, max=1.0) m = known_labels_expaned.long().to('cuda') # import ipdb; ipdb.set_trace() input_label_embed = torch.gather(self.lang_encoder.default_text_embeddings, 0, m[:, None].repeat(1, self.dim_proj)) @ self.lang_mapper input_bbox_embed = inverse_sigmoid(known_bbox_expand) single_pad = int(max(known_num)) pad_size = int(single_pad * scalar) padding_label = input_label_embed.new_zeros(pad_size, self.hidden_dim) padding_bbox = input_bbox_embed.new_zeros(pad_size, 4) if not refpoint_emb is None: input_query_label = torch.cat([padding_label, tgt], dim=0).repeat(batch_size, 1, 1) input_query_bbox = torch.cat([padding_bbox, refpoint_emb], dim=0).repeat(batch_size, 1, 1) else: input_query_label = padding_label.repeat(batch_size, 1, 1) input_query_bbox = padding_bbox.repeat(batch_size, 1, 1) # map map_known_indice = input_label_embed.new_tensor([]) if len(known_num): map_known_indice = torch.cat( [input_label_embed.new_tensor(range(num)) for num in known_num]) # [1,2, 1,2,3] map_known_indice = torch.cat([map_known_indice + single_pad * i for i in range(scalar)]).long() if len(known_bid): 
input_query_label[(known_bid.long(), map_known_indice)] = input_label_embed input_query_bbox[(known_bid.long(), map_known_indice)] = input_bbox_embed tgt_size = pad_size + self.num_queries attn_mask = input_label_embed.new_ones(tgt_size, tgt_size) < 0 # match query cannot see the reconstruct attn_mask[pad_size:, :pad_size] = True # reconstruct cannot see each other for i in range(scalar): if i == 0: attn_mask[single_pad * i:single_pad * (i + 1), single_pad * (i + 1):pad_size] = True if i == scalar - 1: attn_mask[single_pad * i:single_pad * (i + 1), :single_pad * i] = True else: attn_mask[single_pad * i:single_pad * (i + 1), single_pad * (i + 1):pad_size] = True attn_mask[single_pad * i:single_pad * (i + 1), :single_pad * i] = True mask_dict = { 'known_indice': torch.as_tensor(known_indice).long(), 'batch_idx': torch.as_tensor(batch_idx).long(), 'map_known_indice': torch.as_tensor(map_known_indice).long(), 'known_lbs_bboxes': (known_labels, known_bboxs), 'know_idx': know_idx, 'pad_size': pad_size, 'scalar': scalar, } else: if not refpoint_emb is None: input_query_label = tgt.repeat(batch_size, 1, 1) input_query_bbox = refpoint_emb.repeat(batch_size, 1, 1) else: input_query_label = None input_query_bbox = None attn_mask = None mask_dict = None # 100*batch*256 if not input_query_bbox is None: input_query_label = input_query_label input_query_bbox = input_query_bbox return input_query_label, input_query_bbox, attn_mask, mask_dict def prepare_for_dn_o3(self, targets, tgt, refpoint_emb, batch_size): """ modified from dn-detr. 
You can refer to dn-detr https://github.com/IDEA-Research/DN-DETR/blob/main/models/dn_dab_deformable_detr/dn_components.py for more details :param dn_args: scalar, noise_scale :param tgt: original tgt (content) in the matching part :param refpoint_emb: positional anchor queries in the matching part :param batch_size: bs """ if self.training: scalar, noise_scale = self.dn_num, self.noise_scale known = [(torch.ones_like(t['labels'])).cuda() for t in targets] know_idx = [torch.nonzero(t) for t in known] known_num = [sum(k) for k in known] # use fix number of dn queries if max(known_num) > 0: scalar = 1 else: scalar = 0 if scalar == 0: input_query_label = None input_query_bbox = None attn_mask = None mask_dict = None return input_query_label, input_query_bbox, attn_mask, mask_dict # can be modified to selectively denosie some label or boxes; also known label prediction unmask_bbox = unmask_label = torch.cat(known) labels = torch.cat([t['labels'] for t in targets]) # use languge as denosing content queries. 
# if task == 'det': # labels = labels # o365 start from 133 class boxes = torch.cat([t['boxes'] for t in targets]) batch_idx = torch.cat([torch.full_like(t['labels'].long(), i) for i, t in enumerate(targets)]) # known known_indice = torch.nonzero(unmask_label + unmask_bbox) known_indice = known_indice.view(-1) # noise known_indice = known_indice.repeat(scalar, 1).view(-1) known_labels = labels.repeat(scalar, 1).view(-1) known_bid = batch_idx.repeat(scalar, 1).view(-1) known_bboxs = boxes.repeat(scalar, 1) known_labels_expaned = known_labels.clone() known_bbox_expand = known_bboxs.clone() if noise_scale > 0: diff = torch.zeros_like(known_bbox_expand) diff[:, :2] = known_bbox_expand[:, 2:] / 2 diff[:, 2:] = known_bbox_expand[:, 2:] known_bbox_expand += torch.mul((torch.rand_like(known_bbox_expand) * 2 - 1.0), diff).cuda() * noise_scale known_bbox_expand = known_bbox_expand.clamp(min=0.0, max=1.0) m = known_labels_expaned.long().to('cuda') # import ipdb; ipdb.set_trace() input_label_embed = self.pb_embedding(torch.ones_like(m)) input_bbox_embed = inverse_sigmoid(known_bbox_expand) single_pad = int(max(known_num)) pad_size = int(single_pad * scalar) padding_label = input_label_embed.new_zeros(pad_size, self.hidden_dim) padding_bbox = input_bbox_embed.new_zeros(pad_size, 4) if not refpoint_emb is None: input_query_label = torch.cat([padding_label, tgt], dim=0).repeat(batch_size, 1, 1) input_query_bbox = torch.cat([padding_bbox, refpoint_emb], dim=0).repeat(batch_size, 1, 1) else: input_query_label = padding_label.repeat(batch_size, 1, 1) input_query_bbox = padding_bbox.repeat(batch_size, 1, 1) # map map_known_indice = input_label_embed.new_tensor([]) if len(known_num): map_known_indice = torch.cat( [input_label_embed.new_tensor(range(num)) for num in known_num]) # [1,2, 1,2,3] map_known_indice = torch.cat([map_known_indice + single_pad * i for i in range(scalar)]).long() if len(known_bid): input_query_label[(known_bid.long(), map_known_indice)] = input_label_embed 
input_query_bbox[(known_bid.long(), map_known_indice)] = input_bbox_embed tgt_size = pad_size + self.num_queries attn_mask = input_label_embed.new_ones(tgt_size, tgt_size) < 0 # match query cannot see the reconstruct attn_mask[pad_size:, :pad_size] = True # reconstruct cannot see each other for i in range(scalar): if i == 0: attn_mask[single_pad * i:single_pad * (i + 1), single_pad * (i + 1):pad_size] = True if i == scalar - 1: attn_mask[single_pad * i:single_pad * (i + 1), :single_pad * i] = True else: attn_mask[single_pad * i:single_pad * (i + 1), single_pad * (i + 1):pad_size] = True attn_mask[single_pad * i:single_pad * (i + 1), :single_pad * i] = True mask_dict = { 'known_indice': torch.as_tensor(known_indice).long(), 'batch_idx': torch.as_tensor(batch_idx).long(), 'map_known_indice': torch.as_tensor(map_known_indice).long(), 'known_lbs_bboxes': (known_labels, known_bboxs), 'know_idx': know_idx, 'pad_size': pad_size, 'scalar': scalar, } else: if not refpoint_emb is None: input_query_label = tgt.repeat(batch_size, 1, 1) input_query_bbox = refpoint_emb.repeat(batch_size, 1, 1) else: input_query_label = None input_query_bbox = None attn_mask = None mask_dict = None # 100*batch*256 if not input_query_bbox is None: input_query_label = input_query_label input_query_bbox = input_query_bbox return input_query_label, input_query_bbox, attn_mask, mask_dict def prepare_for_dn_mo(self, targets, tgt, refpoint_emb, batch_size): # if self.training: scalar, noise_scale = self.dn_num,self.noise_scale known = [(torch.ones_like(t['boxes'])).cuda() for t in targets] know_idx = [torch.nonzero(t) for t in known] known_num = [k.sum() for k in known] if max(known_num)>0: scalar=1 # FIXME this is wrong attention mask!!! 
else: scalar=0 if scalar==0: input_query_label = None input_query_bbox = None attn_mask = None mask_dict = None # return input_query_label, input_query_bbox, attn_mask, mask_dict pb_labels = torch.stack([t['pb'] for t in targets]) # FIXME this is for future content-based interaction; pool content features as label embedding labels = torch.zeros_like(pb_labels).long() boxes = torch.stack([t['boxes_dn'] for t in targets]) box_start = [t['box_start'] for t in targets] known_labels = labels known_pb_labels = pb_labels known_bboxs = boxes known_labels_expaned = known_labels.clone() known_pb_labels_expaned = known_pb_labels.clone() known_bbox_expand = known_bboxs.clone() ############ noise on the label # if noise_scale > 0: # p = torch.rand_like(known_labels_expaned.float()) # chosen_indice = torch.nonzero(p < (noise_scale * 0.5)).view(-1) # half of bbox prob # new_label = torch.randint_like(chosen_indice, 0, self.num_classes) # randomly put a new one here # known_labels_expaned.scatter_(0, chosen_indice, new_label) if noise_scale > 0 and self.training: diff = torch.zeros_like(known_bbox_expand) diff[:, :, :2] = known_bbox_expand[:, :, 2:] / 2 diff[:, :, 2:] = known_bbox_expand[:, :, 2:] sc = 0.01 for i, st in enumerate(box_start): diff[i, :st] = diff[i, :st] * sc known_bbox_expand += torch.mul((torch.rand_like(known_bbox_expand) * 2 - 1.0), diff).cuda() * noise_scale # known_bbox_expand+=(torch.rand_like(known_bbox_expand)*2-1.0)*torch.tensor([[1,1,0.1,0.1]]).cuda()*noise_scale known_bbox_expand = known_bbox_expand.clamp(min=0.0, max=1.0) m = known_labels_expaned.long().to('cuda') m_pb = known_pb_labels_expaned.long().to('cuda') input_label_embed = self.label_enc(m)+self.pb_embedding(m_pb) input_bbox_embed = inverse_sigmoid(known_bbox_expand) input_label_embed = input_label_embed.repeat_interleave(self.num_all_tokens,1) + self.mask_tokens.weight.unsqueeze(0).repeat(input_label_embed.shape[0], input_label_embed.shape[1], 1) input_bbox_embed = 
input_bbox_embed.repeat_interleave(self.num_all_tokens,1) single_pad = self.num_all_tokens # NOTE scalar is modified to 100, each click cannot see each other scalar = int(input_label_embed.shape[1]/self.num_all_tokens) pad_size = input_label_embed.shape[1] if input_label_embed.shape[1]>0: input_query_label = input_label_embed input_query_bbox = input_bbox_embed tgt_size = pad_size attn_mask = torch.ones(tgt_size, tgt_size).to('cuda') < 0 # match query cannot see the reconstruct attn_mask[pad_size:, :pad_size] = True # reconstruct cannot see each other for i in range(scalar): if i == 0: attn_mask[single_pad * i:single_pad * (i + 1), single_pad * (i + 1):pad_size] = True if i == scalar - 1: attn_mask[single_pad * i:single_pad * (i + 1), :single_pad * i] = True else: attn_mask[single_pad * i:single_pad * (i + 1), single_pad * (i + 1):pad_size] = True attn_mask[single_pad * i:single_pad * (i + 1), :single_pad * i] = True mask_dict = { 'known_lbs_bboxes': (known_labels, known_bboxs), # 'know_idx': know_idx, 'pad_size': pad_size, 'scalar': scalar, } # 100*batch*256 if not input_query_bbox is None: input_query_label = input_query_label input_query_bbox = input_query_bbox return input_query_label,input_query_bbox,attn_mask,mask_dict def prepare_for_dn_mo_infer(self, targets, tgt, refpoint_emb, batch_size): known = [(torch.ones_like(t['points'])).cuda() for t in targets] known_num = [k.sum() for k in known] assert max(known_num)>0 pb_labels = torch.stack([t['pb'] for t in targets]) # FIXME this is for future content-based interaction; pool content features as label embedding labels = torch.zeros_like(pb_labels).long() boxes = torch.stack([t['points'] for t in targets]) known_labels = labels known_pb_labels = pb_labels known_bboxs = boxes known_labels_expaned = known_labels.clone() known_pb_labels_expaned = known_pb_labels.clone() known_bbox_expand = known_bboxs.clone() m = known_labels_expaned.long().to('cuda') m_pb = known_pb_labels_expaned.long().to('cuda') 
input_label_embed = self.label_enc(m)+self.pb_embedding(m_pb) input_bbox_embed = inverse_sigmoid(known_bbox_expand) input_label_embed = input_label_embed.repeat_interleave(self.num_all_tokens,1) + self.mask_tokens.weight.unsqueeze(0).repeat(input_label_embed.shape[0], input_label_embed.shape[1], 1) input_bbox_embed = input_bbox_embed.repeat_interleave(self.num_all_tokens,1) scalar = int(input_label_embed.shape[1]/self.num_all_tokens) pad_size = input_label_embed.shape[1] if input_label_embed.shape[1]>0: input_query_label = input_label_embed input_query_bbox = input_bbox_embed attn_mask = None mask_dict = { 'known_lbs_bboxes': (known_labels, known_bboxs), # 'know_idx': know_idx, 'pad_size': pad_size, 'scalar': scalar, } return input_query_label,input_query_bbox,attn_mask,mask_dict def dn_post_process(self,outputs_class,outputs_coord,mask_dict,outputs_mask): """ post process of dn after output from the transformer put the dn part in the mask_dict """ assert mask_dict['pad_size'] > 0 output_known_class = outputs_class[:, :, :mask_dict['pad_size'], :] outputs_class = outputs_class[:, :, mask_dict['pad_size']:, :] output_known_coord = outputs_coord[:, :, :mask_dict['pad_size'], :] outputs_coord = outputs_coord[:, :, mask_dict['pad_size']:, :] output_known_mask = None if outputs_mask is not None: output_known_mask = outputs_mask[:, :, :mask_dict['pad_size'], :] outputs_mask = outputs_mask[:, :, mask_dict['pad_size']:, :] out = {'pred_logits': output_known_class[-1], 'pred_boxes': output_known_coord[-1],'pred_masks': None if output_known_mask is None else output_known_mask[-1]} out['aux_outputs'] = self._set_aux_loss(output_known_class, output_known_mask,output_known_coord) mask_dict['output_known_lbs_bboxes']=out return outputs_class, outputs_coord, outputs_mask def get_valid_ratio(self, mask): _, H, W = mask.shape valid_H = torch.sum(~mask[:, :, 0], 1) valid_W = torch.sum(~mask[:, 0, :], 1) valid_ratio_h = valid_H.float() / H valid_ratio_w = valid_W.float() / W 
valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1) return valid_ratio def pred_box(self, reference, hs, ref0=None): """ :param reference: reference box coordinates from each decoder layer :param hs: content :param ref0: whether there are prediction from the first layer """ if ref0 is None: outputs_coord_list = [] else: outputs_coord_list = [ref0] for dec_lid, (layer_ref_sig, layer_bbox_embed, layer_hs) in enumerate(zip(reference[:-1], self.bbox_embed, hs)): layer_delta_unsig = layer_bbox_embed(layer_hs) # layer_outputs_unsig = layer_delta_unsig + inverse_sigmoid(layer_ref_sig) new_layer_ref_sig = layer_ref_sig.view(layer_ref_sig.shape[0], -1, self.num_all_tokens, layer_ref_sig.shape[-1]) new_layer_ref_sig = new_layer_ref_sig[:, :, :self.num_mask_tokens].reshape(new_layer_ref_sig.shape[0], -1, new_layer_ref_sig.shape[-1]) layer_outputs_unsig = layer_delta_unsig + inverse_sigmoid(new_layer_ref_sig) layer_outputs_unsig = layer_outputs_unsig.sigmoid() outputs_coord_list.append(layer_outputs_unsig) outputs_coord_list = torch.stack(outputs_coord_list) return outputs_coord_list def pred_box_old(self, reference, hs, ref0=None): """ :param reference: reference box coordinates from each decoder layer :param hs: content :param ref0: whether there are prediction from the first layer """ if ref0 is None: outputs_coord_list = [] else: outputs_coord_list = [ref0] for dec_lid, (layer_ref_sig, layer_bbox_embed, layer_hs) in enumerate(zip(reference[:-1], self.bbox_embed, hs)): layer_delta_unsig = layer_bbox_embed(layer_hs) layer_outputs_unsig = layer_delta_unsig + inverse_sigmoid(layer_ref_sig) layer_outputs_unsig = layer_outputs_unsig.sigmoid() outputs_coord_list.append(layer_outputs_unsig) outputs_coord_list = torch.stack(outputs_coord_list) return outputs_coord_list def forward(self, x, mask_features, masks, targets=None, target_queries=None, target_vlp=None, task='seg', extra={}): """ task: seg/det TODO add sam """ # task = 'sam' prediction_switch = extra 
self.prediction_switch = prediction_switch assert len(x) == self.num_feature_levels do_seg = (task != 'det') # if task is det, not do segmentation training size_list = [] # disable mask, it does not affect performance enable_mask = 0 if masks is not None: for src in x: if src.size(2) % 32 or src.size(3) % 32: enable_mask = 1 if enable_mask == 0: masks = [torch.zeros((src.size(0), src.size(2), src.size(3)), device=src.device, dtype=torch.bool) for src in x] src_flatten = [] mask_flatten = [] spatial_shapes = [] for i in range(self.num_feature_levels): idx=self.num_feature_levels-1-i bs, c , h, w=x[idx].shape size_list.append(x[i].shape[-2:]) spatial_shapes.append(x[idx].shape[-2:]) src_flatten.append(self.input_proj[idx](x[idx]).flatten(2).transpose(1, 2)) mask_flatten.append(masks[i].flatten(1)) src_flatten = torch.cat(src_flatten, 1) # bs, \sum{hxw}, c mask_flatten = torch.cat(mask_flatten, 1) # bs, \sum{hxw} spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device) level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1) predictions_class = [] predictions_class_part = [] predictions_mask = [] predictions_iou_score = [] tgt_mask = None mask_dict = None if self.dn != "no": assert targets is not None if task=='demo': input_query_label, input_query_bbox, tgt_mask, mask_dict = \ self.prepare_for_dn_mo_infer(targets, None, None, x[0].shape[0]) else: input_query_label, input_query_bbox, tgt_mask, mask_dict = \ self.prepare_for_dn_mo(targets, None, None, x[0].shape[0]) tgt=input_query_label refpoint_embed=input_query_bbox if tgt is None: tgt = torch.zeros(bs, self.num_queries, self.hidden_dim).cuda() refpoint_embed = torch.zeros(bs, self.num_queries, 4).cuda() # import pdb;pdb.set_trace() refpoint_embed=refpoint_embed.to(tgt.dtype) hs, references = self.decoder( tgt=tgt.transpose(0, 1), memory=src_flatten.transpose(0, 
1), memory_key_padding_mask=mask_flatten, pos=None, refpoints_unsigmoid=refpoint_embed.transpose(0, 1), level_start_index=level_start_index, spatial_shapes=spatial_shapes, valid_ratios=valid_ratios, tgt_mask=tgt_mask ) new_hs = [] feats=[] for i, output in enumerate(hs): outputs_class, outputs_mask, iou_score, decoder_output_mask,decoder_output = self.idno_forward_prediction_heads(output.transpose(0, 1), mask_features, (self.training or (i == len(hs)-1)) and do_seg) outputs_class_whole, outputs_class_part = outputs_class predictions_class.append(outputs_class_whole) predictions_class_part.append(outputs_class_part) predictions_mask.append(outputs_mask) feats.append(decoder_output) if iou_score is not None: predictions_iou_score.append(iou_score) new_hs.append(decoder_output_mask) if new_hs is not None: hs = new_hs # iteratively box prediction out_boxes = self.pred_box(references, hs) out_boxes[-1] = out_boxes[-1] + 0.0 * (self.label_enc.weight.sum() + self.pb_embedding.weight.sum() + self.mask_tokens.weight.sum() + self.lang_mapper.sum()+iou_score.sum()) if mask_dict is not None: if predictions_mask is None: predictions_class[-1] = predictions_class[-1] for i in range(self.mask_embed.num_layers): predictions_class[-1] = predictions_class[-1] + 0.0 * (self.mask_embed.layers[i].weight[0][0] + self.mask_embed.layers[i].bias[0]) # avoid no mask loss predictions_class[-1] = predictions_class[-1] + 0.0 * mask_features[0][0][0][0] # avoid no mask loss if do_seg: predictions_mask = list(predictions_mask) elif self.training: # this is to insure self.label_enc participate in the model for i in range(self.mask_embed.num_layers): predictions_class[-1] = predictions_class[-1] + 0.0 * ( self.mask_embed.layers[i].weight[0][0] + self.mask_embed.layers[i].bias[ 0]) # avoid no mask loss predictions_class[-1] = predictions_class[-1] + 0.0 * mask_features[0][0][0][0] # avoid no mask loss out = { 'pred_logits': predictions_class[-1], 'obj_features': feats[-1], 'pred_logits_part': 
predictions_class_part[-1], 'pred_masks': None if not do_seg else predictions_mask[-1], 'pred_boxes':out_boxes[-1], 'pred_ious': predictions_iou_score[-1], 'aux_outputs': self._set_aux_loss( predictions_class if self.mask_classification else None, predictions_mask, out_boxes, predictions_iou_score, predictions_class_part ) } return out, mask_dict def forward_o365(self, x, mask_features, masks, targets=None, target_queries=None, target_vlp=None, task='seg', extra={}): """ task: seg/det TODO add sam """ # task = 'sam' prediction_switch = extra self.prediction_switch = prediction_switch assert len(x) == self.num_feature_levels do_seg = False # if task is det, not do segmentation training size_list = [] # disable mask, it does not affect performance enable_mask = 0 if masks is not None: for src in x: if src.size(2) % 32 or src.size(3) % 32: enable_mask = 1 if enable_mask == 0: masks = [torch.zeros((src.size(0), src.size(2), src.size(3)), device=src.device, dtype=torch.bool) for src in x] src_flatten = [] mask_flatten = [] spatial_shapes = [] for i in range(self.num_feature_levels): idx=self.num_feature_levels-1-i bs, c , h, w=x[idx].shape size_list.append(x[i].shape[-2:]) spatial_shapes.append(x[idx].shape[-2:]) src_flatten.append(self.input_proj[idx](x[idx]).flatten(2).transpose(1, 2)) mask_flatten.append(masks[i].flatten(1)) src_flatten = torch.cat(src_flatten, 1) # bs, \sum{hxw}, c mask_flatten = torch.cat(mask_flatten, 1) # bs, \sum{hxw} spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device) level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1) predictions_class = [] # predictions_class_part = [] predictions_mask = [] # predictions_iou_score = [] tgt_mask = None mask_dict = None # if self.dn != "no": assert targets is not None input_query_label, input_query_bbox, tgt_mask, mask_dict = \ 
self.prepare_for_dn_o3(targets, None, None, x[0].shape[0]) tgt=input_query_label refpoint_embed=input_query_bbox if tgt is None: tgt = torch.zeros(bs, self.num_queries, self.hidden_dim).cuda() refpoint_embed = torch.zeros(bs, self.num_queries, 4).cuda() hs, references = self.decoder( tgt=tgt.transpose(0, 1), memory=src_flatten.transpose(0, 1), memory_key_padding_mask=mask_flatten, pos=None, refpoints_unsigmoid=refpoint_embed.transpose(0, 1), level_start_index=level_start_index, spatial_shapes=spatial_shapes, valid_ratios=valid_ratios, tgt_mask=tgt_mask ) # new_hs = [] for i, output in enumerate(hs): outputs_class, outputs_mask = self.forward_prediction_heads(output.transpose(0, 1), mask_features, (self.training or (i == len(hs)-1)) and do_seg) outputs_class_whole = outputs_class predictions_class.append(outputs_class_whole) # predictions_class_part.append(outputs_class_part) predictions_mask.append(outputs_mask) # if iou_score is not None: # predictions_iou_score.append(iou_score) # new_hs.append(decoder_output_mask) # if new_hs is not None: # hs = new_hs # iteratively box prediction out_boxes = self.pred_box_old(references, hs) out_boxes[-1] = out_boxes[-1] + 0.0 * (self.label_enc.weight.sum() + self.pb_embedding.weight.sum() + self.mask_tokens.weight.sum() + self.lang_mapper.sum()) if mask_dict is not None: if predictions_mask is None: predictions_class[-1] = predictions_class[-1] for i in range(self.mask_embed.num_layers): predictions_class[-1] = predictions_class[-1] + 0.0 * (self.mask_embed.layers[i].weight[0][0] + self.mask_embed.layers[i].bias[0]) # avoid no mask loss predictions_class[-1] = predictions_class[-1] + 0.0 * mask_features[0][0][0][0] # avoid no mask loss if do_seg: predictions_mask = list(predictions_mask) elif self.training: # this is to insure self.label_enc participate in the model for i in range(self.mask_embed.num_layers): predictions_class[-1] = predictions_class[-1] + 0.0 * ( self.mask_embed.layers[i].weight[0][0] + 
self.mask_embed.layers[i].bias[ 0]) # avoid no mask loss predictions_class[-1] = predictions_class[-1] + 0.0 * mask_features[0][0][0][0] # avoid no mask loss out = { 'pred_logits': predictions_class[-1], # 'pred_logits_part': predictions_class_part[-1], 'pred_masks': None if not do_seg else predictions_mask[-1], 'pred_boxes':out_boxes[-1], # 'pred_ious': predictions_iou_score[-1], 'aux_outputs': self._set_aux_loss( predictions_class if self.mask_classification else None, predictions_mask, out_boxes ) } return out, mask_dict def forward_prediction_heads(self, output, mask_features, pred_mask=True): decoder_output = self.decoder_norm(output) decoder_output = decoder_output.transpose(0, 1) class_embed = decoder_output @ self.class_embed outputs_class = self.lang_encoder.compute_similarity(class_embed, name='whole') outputs_mask = None if pred_mask: mask_embed = self.mask_embed(decoder_output) outputs_mask = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features) return outputs_class, outputs_mask def idno_forward_prediction_heads(self, output, mask_features, pred_mask=True): decoder_output = self.decoder_norm(output) decoder_output = decoder_output.transpose(0, 1) decoder_output = decoder_output + 0.0 * (self.class_embed_part.sum() + self.class_embed.sum()) out = decoder_output.view(decoder_output.shape[0], -1, self.num_all_tokens, decoder_output.shape[-1]) decoder_output_mask = out[:, :, :self.num_mask_tokens].reshape(decoder_output.shape[0], -1, decoder_output.shape[-1]) # decoder_output_iou = out[:, :, -1].view(decoder_output.shape[0], -1, decoder_output.shape[-1]) decoder_output_iou = decoder_output_mask outputs_mask = outputs_class_whole = outputs_class_part = None if self.prediction_switch['whole']: class_embed_whole = decoder_output @ self.class_embed outputs_class_whole = self.lang_encoder.compute_similarity(class_embed_whole, name='whole') if self.prediction_switch['part']: class_embed_part = decoder_output @ self.class_embed_part outputs_class_part = 
self.lang_encoder.compute_similarity(class_embed_part, name='part') outputs_class = (outputs_class_whole, outputs_class_part) if self.prediction_switch['seg']: mask_embed = self.mask_embed(decoder_output_mask) if mask_embed.dtype==torch.float16 and mask_features.dtype==torch.float32: mask_embed=mask_embed.to(torch.float32) if mask_embed.dtype==torch.float32 and mask_features.dtype==torch.float16: mask_features=mask_features.to(torch.float32) outputs_mask = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features.to(mask_embed.dtype)) iou_score = self.iou_prediction_head(decoder_output_iou).squeeze(-1).view(decoder_output.shape[0], -1, self.num_mask_tokens) # outputs_mask = outputs_mask + 0.0 * iou_score.sum() # TODO add iou prediction head return outputs_class, outputs_mask, iou_score, decoder_output_mask,decoder_output @torch.jit.unused def _set_aux_loss(self, outputs_class=None, outputs_seg_masks=None, out_boxes=None, predictions_iou_score=None, predictions_class_part=None): # this is a workaround to make torchscript happy, as torchscript # doesn't support dictionary with non-homogeneous values, such # as a dict having both a Tensor and a list. 
# if self.mask_classification: if out_boxes is None: return [ {"pred_logits": a, "pred_masks": b} for a, b in zip(outputs_class[:-1], outputs_seg_masks[:-1]) ] elif outputs_seg_masks is None: return [ {"pred_logits": a, "pred_boxes": c} for a, c in zip(outputs_class[:-1], out_boxes[:-1]) ] elif predictions_iou_score is None: return [ {"pred_logits": a, "pred_masks": b, "pred_boxes":c} for a, b, c in zip(outputs_class[:-1], outputs_seg_masks[:-1], out_boxes[:-1]) ] else: return [ {"pred_logits": a, "pred_masks": b, "pred_boxes":c, "pred_ious":d, "pred_logits_part": e} for a, b, c, d, e in zip(outputs_class[:-1], outputs_seg_masks[:-1],out_boxes[:-1], predictions_iou_score[:-1], predictions_class_part[:-1]) ] @register_decoder def get_maskdino_transformer_decoder(cfg, in_channels, lang_encoder, mask_classification, extra): return MaskDINODecoder(cfg, in_channels, lang_encoder, mask_classification, extra) ================================================ FILE: llava/model/semsam/body/decoder/modules.py ================================================ from typing import Optional import torch from torch import nn, Tensor from torch.nn import functional as F from timm.models.layers import trunc_normal_ from detectron2.layers import Conv2d import fvcore.nn.weight_init as weight_init class SelfAttentionLayer(nn.Module): def __init__(self, d_model, nhead, dropout=0.0, activation="relu", normalize_before=False): super().__init__() self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) self.norm = nn.LayerNorm(d_model) self.dropout = nn.Dropout(dropout) self.activation = _get_activation_fn(activation) self.normalize_before = normalize_before self._reset_parameters() def _reset_parameters(self): for p in self.parameters(): if p.dim() > 1: nn.init.xavier_uniform_(p) def with_pos_embed(self, tensor, pos: Optional[Tensor]): return tensor if pos is None else tensor + pos def forward_post(self, tgt, tgt_mask: Optional[Tensor] = None, tgt_key_padding_mask: 
Optional[Tensor] = None, query_pos: Optional[Tensor] = None): q = k = self.with_pos_embed(tgt, query_pos) tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0] tgt = tgt + self.dropout(tgt2) tgt = self.norm(tgt) return tgt def forward_pre(self, tgt, tgt_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, query_pos: Optional[Tensor] = None): tgt2 = self.norm(tgt) q = k = self.with_pos_embed(tgt2, query_pos) tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0] tgt = tgt + self.dropout(tgt2) return tgt def forward(self, tgt, tgt_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, query_pos: Optional[Tensor] = None): if self.normalize_before: return self.forward_pre(tgt, tgt_mask, tgt_key_padding_mask, query_pos) return self.forward_post(tgt, tgt_mask, tgt_key_padding_mask, query_pos) class CrossAttentionLayer(nn.Module): def __init__(self, d_model, nhead, dropout=0.0, activation="relu", normalize_before=False): super().__init__() self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) self.norm = nn.LayerNorm(d_model) self.dropout = nn.Dropout(dropout) self.activation = _get_activation_fn(activation) self.normalize_before = normalize_before self._reset_parameters() def _reset_parameters(self): for p in self.parameters(): if p.dim() > 1: nn.init.xavier_uniform_(p) def with_pos_embed(self, tensor, pos: Optional[Tensor]): return tensor if pos is None else tensor + pos def forward_post(self, tgt, memory, memory_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None, query_pos: Optional[Tensor] = None): tgt2, avg_attn = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), key=self.with_pos_embed(memory, pos), value=memory, attn_mask=memory_mask, key_padding_mask=memory_key_padding_mask) tgt = tgt + self.dropout(tgt2) tgt = self.norm(tgt) 
return tgt, avg_attn def forward_pre(self, tgt, memory, memory_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None, query_pos: Optional[Tensor] = None): tgt2 = self.norm(tgt) tgt2, avg_attn = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), key=self.with_pos_embed(memory, pos), value=memory, attn_mask=memory_mask, key_padding_mask=memory_key_padding_mask) tgt = tgt + self.dropout(tgt2) return tgt, avg_attn def forward(self, tgt, memory, memory_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None, query_pos: Optional[Tensor] = None): if self.normalize_before: return self.forward_pre(tgt, memory, memory_mask, memory_key_padding_mask, pos, query_pos) return self.forward_post(tgt, memory, memory_mask, memory_key_padding_mask, pos, query_pos) class FFNLayer(nn.Module): def __init__(self, d_model, dim_feedforward=2048, dropout=0.0, activation="relu", normalize_before=False): super().__init__() # Implementation of Feedforward model self.linear1 = nn.Linear(d_model, dim_feedforward) self.dropout = nn.Dropout(dropout) self.linear2 = nn.Linear(dim_feedforward, d_model) self.norm = nn.LayerNorm(d_model) self.activation = _get_activation_fn(activation) self.normalize_before = normalize_before self._reset_parameters() def _reset_parameters(self): for p in self.parameters(): if p.dim() > 1: nn.init.xavier_uniform_(p) def with_pos_embed(self, tensor, pos: Optional[Tensor]): return tensor if pos is None else tensor + pos def forward_post(self, tgt): tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) tgt = tgt + self.dropout(tgt2) tgt = self.norm(tgt) return tgt def forward_pre(self, tgt): tgt2 = self.norm(tgt) tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) tgt = tgt + self.dropout(tgt2) return tgt def forward(self, tgt): if self.normalize_before: return self.forward_pre(tgt) return self.forward_post(tgt) def 
_get_activation_fn(activation): """Return an activation function given a string""" if activation == "relu": return F.relu if activation == "gelu": return F.gelu if activation == "glu": return F.glu raise RuntimeError(F"activation should be relu/gelu, not {activation}.") class MLP(nn.Module): """ Very simple multi-layer perceptron (also called FFN)""" def __init__(self, input_dim, hidden_dim, output_dim, num_layers): super().__init__() self.num_layers = num_layers h = [hidden_dim] * (num_layers - 1) self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) def forward(self, x): for i, layer in enumerate(self.layers): x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) return x ================================================ FILE: llava/model/semsam/body/decoder/registry.py ================================================ _model_entrypoints = {} def register_decoder(fn): module_name_split = fn.__module__.split('.') model_name = module_name_split[-1] _model_entrypoints[model_name] = fn return fn def model_entrypoints(model_name): return _model_entrypoints[model_name] def is_model(model_name): return model_name in _model_entrypoints ================================================ FILE: llava/model/semsam/body/decoder/utils/__init__.py ================================================ from .utils import * ================================================ FILE: llava/model/semsam/body/decoder/utils/dino_decoder.py ================================================ # ------------------------------------------------------------------------ # DINO # Copyright (c) 2022 IDEA. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------ # Modified from DINO https://github.com/IDEA-Research/DINO by Feng Li and Hao Zhang. 
# ------------------------------------------------------------------------ from typing import Optional, List, Union import torch from torch import nn, Tensor from torch.cuda.amp import autocast from .utils import MLP, _get_clones, _get_activation_fn, gen_sineembed_for_position, inverse_sigmoid from ...encoder.ops.modules import MSDeformAttn from torch.utils.checkpoint import checkpoint class TransformerDecoder(nn.Module): def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False, d_model=256, query_dim=4, modulate_hw_attn=True, num_feature_levels=1, deformable_decoder=True, decoder_query_perturber=None, dec_layer_number=None, # number of queries each layer in decoder rm_dec_query_scale=True, dec_layer_share=False, dec_layer_dropout_prob=None, task_switch=None, ): super().__init__() if num_layers > 0: self.layers = _get_clones(decoder_layer, num_layers, layer_share=dec_layer_share) else: self.layers = [] self.num_layers = num_layers self.norm = norm self.return_intermediate = return_intermediate assert return_intermediate, "support return_intermediate only" self.query_dim = query_dim assert query_dim in [2, 4], "query_dim should be 2/4 but {}".format(query_dim) self.num_feature_levels = num_feature_levels self.ref_point_head = MLP(query_dim // 2 * d_model, d_model, d_model, 2) if not deformable_decoder: self.query_pos_sine_scale = MLP(d_model, d_model, d_model, 2) else: self.query_pos_sine_scale = None if rm_dec_query_scale: self.query_scale = None else: raise NotImplementedError self.query_scale = MLP(d_model, d_model, d_model, 2) self.bbox_embed = None self.class_embed = None self.d_model = d_model self.modulate_hw_attn = modulate_hw_attn self.deformable_decoder = deformable_decoder if not deformable_decoder and modulate_hw_attn: self.ref_anchor_head = MLP(d_model, d_model, 2, 2) else: self.ref_anchor_head = None self.decoder_query_perturber = decoder_query_perturber self.box_pred_damping = None self.dec_layer_number = dec_layer_number if 
dec_layer_number is not None: assert isinstance(dec_layer_number, list) assert len(dec_layer_number) == num_layers # assert dec_layer_number[0] == self.dec_layer_dropout_prob = dec_layer_dropout_prob if dec_layer_dropout_prob is not None: assert isinstance(dec_layer_dropout_prob, list) assert len(dec_layer_dropout_prob) == num_layers for i in dec_layer_dropout_prob: assert 0.0 <= i <= 1.0 self.task_switch = task_switch self._reset_parameters() def _reset_parameters(self): for p in self.parameters(): if p.dim() > 1: nn.init.xavier_uniform_(p) for m in self.modules(): if isinstance(m, MSDeformAttn): m._reset_parameters() def forward(self, tgt, memory, tgt_mask: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None, refpoints_unsigmoid: Optional[Tensor] = None, # num_queries, bs, 2 # for memory level_start_index: Optional[Tensor] = None, # num_levels spatial_shapes: Optional[Tensor] = None, # bs, num_levels, 2 valid_ratios: Optional[Tensor] = None, # misc extra: Optional[Tensor] = {}, # extra information ): """ Input: - tgt: nq, bs, d_model - memory: hw, bs, d_model - pos: hw, bs, d_model - refpoints_unsigmoid: nq, bs, 2/4 - valid_ratios/spatial_shapes: bs, nlevel, 2 """ output = tgt intermediate = [] reference_points = refpoints_unsigmoid.sigmoid() ref_points = [reference_points] if 'lang_refpoint_embed' in extra.keys() and 'grounding_tokens' in extra.keys(): reference_points = torch.cat((reference_points, extra['lang_refpoint_embed'].transpose(0,1).sigmoid()), dim=0) output = torch.cat((output, extra['grounding_tokens']), dim=0) for layer_id, layer in enumerate(self.layers): # preprocess ref points if self.training and self.decoder_query_perturber is not None and layer_id != 0: reference_points = self.decoder_query_perturber(reference_points) reference_points_input = reference_points[:, :, None] \ * torch.cat([valid_ratios, 
valid_ratios], -1)[None, :].to(reference_points.dtype) # nq, bs, nlevel, 4 # print('reference_points_input', reference_points_input.dtype) # print('memory', memory.dtype) # reference_points_input=reference_points_input.to(memory.dtype) query_sine_embed = gen_sineembed_for_position(reference_points_input[:, :, 0, :], dim=output.shape[-1]//2) # nq, bs, 256*2 # import pdb; pdb.set_trace() # query_sine_embed = query_sine_embed.to(self.ref_point_head.layers[0].weight.dtype) raw_query_pos = self.ref_point_head(query_sine_embed) # nq, bs, 256 pos_scale = self.query_scale(output) if self.query_scale is not None else 1 query_pos = pos_scale * raw_query_pos output = layer( output, query_pos, query_sine_embed, tgt_key_padding_mask, reference_points_input, memory, memory_key_padding_mask, level_start_index, spatial_shapes, pos, tgt_mask, memory_mask, self.task_switch, extra, ) # grounding language token reference point will not update and saved if (self.task_switch is not None) and (extra is not None) and (self.task_switch['grounding']) and ('grounding_len' in extra) and extra['task']=='seg': _reference_points = reference_points[-extra['grounding_len']:] reference_points = reference_points[:-extra['grounding_len']] _output = output[-extra['grounding_len']:] output = output[:-extra['grounding_len']] # iter update if self.bbox_embed is not None: reference_before_sigmoid = inverse_sigmoid(reference_points) # import pdb; pdb.set_trace() output= output.to(query_sine_embed.dtype) delta_unsig = self.bbox_embed[layer_id](output) outputs_unsig = delta_unsig + reference_before_sigmoid new_reference_points = outputs_unsig.sigmoid() reference_points = new_reference_points.detach() # if layer_id != self.num_layers - 1: ref_points.append(new_reference_points) intermediate.append(self.norm(output)) # add back grounding language token if (self.task_switch is not None) and (extra is not None) and (self.task_switch['grounding']) and ('grounding_len' in extra) and extra['task']=='seg': 
reference_points = torch.cat((reference_points, _reference_points)) output = torch.cat((output, _output)) return [ [itm_out.transpose(0, 1) for itm_out in intermediate], [itm_refpoint.transpose(0, 1) for itm_refpoint in ref_points] ] class DeformableTransformerDecoderLayer(nn.Module): def __init__(self, d_model=256, d_ffn=1024, dropout=0.1, activation="relu", n_levels=4, n_heads=8, n_points=4, use_deformable_box_attn=False, key_aware_type=None, ): super().__init__() # cross attention if use_deformable_box_attn: raise NotImplementedError else: self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points) self.dropout1 = nn.Dropout(dropout) self.norm1 = nn.LayerNorm(d_model) # self attention self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout) self.dropout2 = nn.Dropout(dropout) self.norm2 = nn.LayerNorm(d_model) # ffn self.linear1 = nn.Linear(d_model, d_ffn) self.activation = _get_activation_fn(activation) self.dropout3 = nn.Dropout(dropout) self.linear2 = nn.Linear(d_ffn, d_model) self.dropout4 = nn.Dropout(dropout) self.norm3 = nn.LayerNorm(d_model) self.key_aware_type = key_aware_type self.key_aware_proj = None def rm_self_attn_modules(self): self.self_attn = None self.dropout2 = None self.norm2 = None @staticmethod def with_pos_embed(tensor, pos): return tensor if pos is None else tensor + pos def forward_ffn(self, tgt): tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt)))) tgt = tgt + self.dropout4(tgt2) tgt = self.norm3(tgt) return tgt @autocast(enabled=True) def forward(self, # for tgt tgt: Optional[Tensor], # nq, bs, d_model tgt_query_pos: Optional[Tensor] = None, # pos for query. MLP(Sine(pos)) tgt_query_sine_embed: Optional[Tensor] = None, # pos for query. 
Sine(pos) tgt_key_padding_mask: Optional[Tensor] = None, tgt_reference_points: Optional[Tensor] = None, # nq, bs, 4 # for memory memory: Optional[Tensor] = None, # hw, bs, d_model memory_key_padding_mask: Optional[Tensor] = None, memory_level_start_index: Optional[Tensor] = None, # num_levels memory_spatial_shapes: Optional[Tensor] = None, # bs, num_levels, 2 memory_pos: Optional[Tensor] = None, # pos for memory # sa self_attn_mask: Optional[Tensor] = None, # mask used for self-attention cross_attn_mask: Optional[Tensor] = None, # mask used for cross-attention # misc task_switch: Optional[Tensor] = {}, # extra information extra: Optional[Tensor] = {}, # extra information ): """ Input: - tgt/tgt_query_pos: nq, bs, d_model - """ # self attention # import pdb;pdb.set_trace() if self.self_attn is not None: q = k = self.with_pos_embed(tgt, tgt_query_pos) tgt2 = self.self_attn(q, k, tgt, attn_mask=self_attn_mask)[0] tgt = tgt + self.dropout2(tgt2) tgt = self.norm2(tgt) # exclude grounding token for cross attention if (task_switch is not None) and (extra is not None) and (task_switch['grounding']) and ('grounding_len' in extra) and extra['task']=='seg': _grounding_lang_tokens = tgt[-extra['grounding_len']:,] _grounding_lang_pos = tgt_query_pos[-extra['grounding_len']:,] _grounding_ref_points = tgt_reference_points[-extra['grounding_len']:,] tgt = tgt[:-extra['grounding_len'],] tgt_query_pos = tgt_query_pos[:-extra['grounding_len'],] tgt_reference_points = tgt_reference_points[:-extra['grounding_len'],] # cross attention if self.key_aware_type is not None: if self.key_aware_type == 'mean': tgt = tgt + memory.mean(0, keepdim=True) elif self.key_aware_type == 'proj_mean': tgt = tgt + self.key_aware_proj(memory).mean(0, keepdim=True) else: raise NotImplementedError("Unknown key_aware_type: {}".format(self.key_aware_type)) # import pdb;pdb.set_trace() tgt2 = self.cross_attn(self.with_pos_embed(tgt, tgt_query_pos).transpose(0, 1), tgt_reference_points.transpose(0, 
1).contiguous(), memory.transpose(0, 1), memory_spatial_shapes, memory_level_start_index, memory_key_padding_mask).transpose(0, 1) # TODO: check whether add grounding lang token to cross attention is better # import pdb;pdb.set_trace() tgt = tgt + self.dropout1(tgt2) # add back grounding token for self attention if (task_switch is not None) and (extra is not None) and (task_switch['grounding']) and ('grounding_len' in extra) and extra['task']=='seg': tgt = torch.cat((tgt, _grounding_lang_tokens)) tgt = self.norm1(tgt) tgt = self.forward_ffn(tgt) # ffn return tgt ================================================ FILE: llava/model/semsam/body/decoder/utils/utils.py ================================================ import torch import copy from torch import nn, Tensor import os import math import torch.nn.functional as F from torch import nn class MLP(nn.Module): """ Very simple multi-layer perceptron (also called FFN)""" def __init__(self, input_dim, hidden_dim, output_dim, num_layers): super().__init__() self.num_layers = num_layers h = [hidden_dim] * (num_layers - 1) self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) def forward(self, x): for i, layer in enumerate(self.layers): x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) return x def inverse_sigmoid(x, eps=1e-5): x = x.clamp(min=0, max=1) x1 = x.clamp(min=eps) x2 = (1 - x).clamp(min=eps) return torch.log(x1/x2) def gen_encoder_output_proposals(memory:Tensor, memory_padding_mask:Tensor, spatial_shapes:Tensor): """ Input: - memory: bs, \sum{hw}, d_model - memory_padding_mask: bs, \sum{hw} - spatial_shapes: nlevel, 2 Output: - output_memory: bs, \sum{hw}, d_model - output_proposals: bs, \sum{hw}, 4 """ N_, S_, C_ = memory.shape base_scale = 4.0 proposals = [] _cur = 0 for lvl, (H_, W_) in enumerate(spatial_shapes): mask_flatten_ = memory_padding_mask[:, _cur:(_cur + H_ * W_)].view(N_, H_, W_, 1) valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1) valid_W = 
torch.sum(~mask_flatten_[:, 0, :, 0], 1) grid_y, grid_x = torch.meshgrid(torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device), torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device)) grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(N_, 1, 1, 2) grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale wh = torch.ones_like(grid) * 0.05 * (2.0 ** lvl) proposal = torch.cat((grid, wh), -1).view(N_, -1, 4) proposals.append(proposal) _cur += (H_ * W_) output_proposals = torch.cat(proposals, 1) output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True) output_proposals = torch.log(output_proposals / (1 - output_proposals)) output_proposals = output_proposals.masked_fill(memory_padding_mask.unsqueeze(-1), float('inf')) output_proposals = output_proposals.masked_fill(~output_proposals_valid, float('inf')) output_memory = memory output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0)) output_memory = output_memory.masked_fill(~output_proposals_valid, float(0)) return output_memory, output_proposals def gen_sineembed_for_position(pos_tensor, dim=128): # n_query, bs, _ = pos_tensor.size() # sineembed_tensor = torch.zeros(n_query, bs, 256) scale = 2 * math.pi dim_t = torch.arange(dim, dtype=torch.float32, device=pos_tensor.device) dim_t = 10000 ** (2 * (dim_t // 2) / dim) x_embed = pos_tensor[:, :, 0] * scale y_embed = pos_tensor[:, :, 1] * scale pos_x = x_embed[:, :, None] / dim_t pos_y = y_embed[:, :, None] / dim_t pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2) pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2) if pos_tensor.size(-1) == 2: pos = torch.cat((pos_y, pos_x), dim=2) elif pos_tensor.size(-1) == 4: w_embed = pos_tensor[:, :, 2] * scale pos_w = w_embed[:, :, None] / dim_t pos_w = 
torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2) h_embed = pos_tensor[:, :, 3] * scale pos_h = h_embed[:, :, None] / dim_t pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2) pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2) else: raise ValueError("Unknown pos_tensor shape(-1):{}".format(pos_tensor.size(-1))) return pos.to(pos_tensor.dtype) def _get_activation_fn(activation): """Return an activation function given a string""" if activation == "relu": return F.relu if activation == "gelu": return F.gelu if activation == "glu": return F.glu if activation == "prelu": return nn.PReLU() if activation == "selu": return F.selu raise RuntimeError(F"activation should be relu/gelu, not {activation}.") def _get_clones(module, N, layer_share=False): if layer_share: return nn.ModuleList([module for i in range(N)]) else: return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) ================================================ FILE: llava/model/semsam/body/encoder/__init__.py ================================================ from .build import build_encoder ================================================ FILE: llava/model/semsam/body/encoder/build.py ================================================ from .registry import model_entrypoints from .registry import is_model from .transformer_encoder_fpn import * from .encoder_deform import * def build_encoder(config, *args, **kwargs): model_name = config['MODEL']['ENCODER']['NAME'] if not is_model(model_name): raise ValueError(f'Unkown model: {model_name}') return model_entrypoints(model_name)(config, *args, **kwargs) ================================================ FILE: llava/model/semsam/body/encoder/encoder_deform.py ================================================ # ------------------------------------------------------------------------ # DINO # Copyright (c) 2022 IDEA. All Rights Reserved. 
# Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------ # Modified by Feng Li and Hao Zhang. import logging import numpy as np from typing import Callable, Dict, List, Optional, Tuple, Union import fvcore.nn.weight_init as weight_init import torch from torch import nn from torch.nn import functional as F from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_ from torch.cuda.amp import autocast from detectron2.layers import Conv2d, ShapeSpec, get_norm # from detectron2.modeling import SEM_SEG_HEADS_REGISTRY from .registry import register_encoder from ...utils import configurable from ...modules import PositionEmbeddingSine from ..transformer_blocks import _get_clones, _get_activation_fn from .ops.modules import MSDeformAttn from torch.utils import checkpoint # MSDeformAttn Transformer encoder in deformable detr class MSDeformAttnTransformerEncoderOnly(nn.Module): def __init__(self, d_model=256, nhead=8, num_encoder_layers=6, dim_feedforward=1024, dropout=0.1, activation="relu", num_feature_levels=4, enc_n_points=4,): super().__init__() self.d_model = d_model self.nhead = nhead encoder_layer = MSDeformAttnTransformerEncoderLayer(d_model, dim_feedforward, dropout, activation, num_feature_levels, nhead, enc_n_points) self.encoder = MSDeformAttnTransformerEncoder(encoder_layer, num_encoder_layers) self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model)) self._reset_parameters() def _reset_parameters(self): for p in self.parameters(): if p.dim() > 1: nn.init.xavier_uniform_(p) for m in self.modules(): if isinstance(m, MSDeformAttn): m._reset_parameters() normal_(self.level_embed) def get_valid_ratio(self, mask): _, H, W = mask.shape valid_H = torch.sum(~mask[:, :, 0], 1) valid_W = torch.sum(~mask[:, 0, :], 1) valid_ratio_h = valid_H.float() / H valid_ratio_w = valid_W.float() / W valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1) 
return valid_ratio def forward(self, srcs, masks, pos_embeds, use_ckpt=False): enable_mask=0 if masks is not None: for src in srcs: if src.size(2)%32 or src.size(3)%32: enable_mask = 1 if enable_mask==0: masks = [torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) for x in srcs] # prepare input for encoder src_flatten = [] mask_flatten = [] lvl_pos_embed_flatten = [] spatial_shapes = [] for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)): bs, c, h, w = src.shape spatial_shape = (h, w) spatial_shapes.append(spatial_shape) src = src.flatten(2).transpose(1, 2) mask = mask.flatten(1) pos_embed = pos_embed.flatten(2).transpose(1, 2) lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1) lvl_pos_embed_flatten.append(lvl_pos_embed) src_flatten.append(src) mask_flatten.append(mask) src_flatten = torch.cat(src_flatten, 1) mask_flatten = torch.cat(mask_flatten, 1) lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device) level_start_index = torch.cat((spatial_shapes.new_zeros((1, )), spatial_shapes.prod(1).cumsum(0)[:-1])) valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1) # encoder memory = self.encoder(src_flatten, spatial_shapes, level_start_index, valid_ratios, lvl_pos_embed_flatten, mask_flatten, use_ckpt=use_ckpt) return memory, spatial_shapes, level_start_index class MSDeformAttnTransformerEncoderLayer(nn.Module): def __init__(self, d_model=256, d_ffn=1024, dropout=0.1, activation="relu", n_levels=4, n_heads=8, n_points=4): super().__init__() # self attention self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points) self.dropout1 = nn.Dropout(dropout) self.norm1 = nn.LayerNorm(d_model) # ffn self.linear1 = nn.Linear(d_model, d_ffn) self.activation = _get_activation_fn(activation) self.dropout2 = nn.Dropout(dropout) self.linear2 = nn.Linear(d_ffn, d_model) self.dropout3 = 
nn.Dropout(dropout) self.norm2 = nn.LayerNorm(d_model) @staticmethod def with_pos_embed(tensor, pos): return tensor if pos is None else tensor + pos def forward_ffn(self, src): src2 = self.linear2(self.dropout2(self.activation(self.linear1(src)))) src = src + self.dropout3(src2) src = self.norm2(src) return src def forward(self, src, pos, reference_points, spatial_shapes, level_start_index, padding_mask=None): # self attention src2 = self.self_attn(self.with_pos_embed(src, pos), reference_points, src, spatial_shapes, level_start_index, padding_mask) src = src + self.dropout1(src2) src = self.norm1(src) # ffn src = self.forward_ffn(src) return src class MSDeformAttnTransformerEncoder(nn.Module): def __init__(self, encoder_layer, num_layers): super().__init__() self.layers = _get_clones(encoder_layer, num_layers) self.num_layers = num_layers @staticmethod def get_reference_points(spatial_shapes, valid_ratios, device): reference_points_list = [] for lvl, (H_, W_) in enumerate(spatial_shapes): ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device), torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device)) ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_) ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_) ref = torch.stack((ref_x, ref_y), -1) reference_points_list.append(ref) reference_points = torch.cat(reference_points_list, 1) reference_points = reference_points[:, :, None] * valid_ratios[:, None] return reference_points def forward(self, src, spatial_shapes, level_start_index, valid_ratios, pos=None, padding_mask=None, use_ckpt=False): output = src reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=src.device) for _, layer in enumerate(self.layers): use_ckpt = False if use_ckpt: output = checkpoint.checkpoint(layer,output, pos, reference_points, spatial_shapes, level_start_index, padding_mask) else: output = layer(output, pos, 
reference_points, spatial_shapes, level_start_index, padding_mask) return output class MaskDINOEncoder(nn.Module): """ This is the multi-scale encoder in detection models, also named as pixel decoder in segmentation models. """ @configurable def __init__( self, input_shape: Dict[str, ShapeSpec], *, transformer_dropout: float, transformer_nheads: int, transformer_dim_feedforward: int, transformer_enc_layers: int, conv_dim: int, mask_dim: int, norm: Optional[Union[str, Callable]] = None, # deformable transformer encoder args transformer_in_features: List[str], common_stride: int, num_feature_levels: int, total_num_feature_levels: int, feature_order: str, use_ckpt=False, ): """ NOTE: this interface is experimental. Args: input_shape: shapes (channels and stride) of the input features transformer_dropout: dropout probability in transformer transformer_nheads: number of heads in transformer transformer_dim_feedforward: dimension of feedforward network transformer_enc_layers: number of transformer encoder layers conv_dims: number of output channels for the intermediate conv layers. mask_dim: number of output channels for the final conv layer. norm (str or callable): normalization for all conv layers num_feature_levels: feature scales used total_num_feature_levels: total feautre scales used (include the downsampled features) feature_order: 'low2high' or 'high2low', i.e., 'low2high' means low-resolution features are put in the first. 
""" super().__init__() self.use_ckpt = use_ckpt transformer_input_shape = { k: v for k, v in input_shape.items() if k in transformer_in_features } # this is the input shape of pixel decoder input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride) self.in_features = [k for k, v in input_shape] # starting from "res2" to "res5" self.feature_strides = [v.stride for k, v in input_shape] self.feature_channels = [v.channels for k, v in input_shape] self.feature_order = feature_order if feature_order == "low2high": transformer_input_shape = sorted(transformer_input_shape.items(), key=lambda x: -x[1].stride) else: transformer_input_shape = sorted(transformer_input_shape.items(), key=lambda x: x[1].stride) self.transformer_in_features = [k for k, v in transformer_input_shape] # starting from "res2" to "res5" transformer_in_channels = [v.channels for k, v in transformer_input_shape] self.transformer_feature_strides = [v.stride for k, v in transformer_input_shape] # to decide extra FPN layers self.maskdino_num_feature_levels = num_feature_levels # always use 3 scales self.total_num_feature_levels = total_num_feature_levels self.common_stride = common_stride self.transformer_num_feature_levels = len(self.transformer_in_features) self.low_resolution_index = transformer_in_channels.index(max(transformer_in_channels)) self.high_resolution_index = 0 if self.feature_order == 'low2high' else -1 if self.transformer_num_feature_levels > 1: input_proj_list = [] for in_channels in transformer_in_channels[::-1]: input_proj_list.append(nn.Sequential( nn.Conv2d(in_channels, conv_dim, kernel_size=1), nn.GroupNorm(32, conv_dim), )) # input projectino for downsample in_channels = max(transformer_in_channels) for _ in range(self.total_num_feature_levels - self.transformer_num_feature_levels): # exclude the res2 input_proj_list.append(nn.Sequential( nn.Conv2d(in_channels, conv_dim, kernel_size=3, stride=2, padding=1), nn.GroupNorm(32, conv_dim), )) in_channels = conv_dim 
self.input_proj = nn.ModuleList(input_proj_list) else: self.input_proj = nn.ModuleList([ nn.Sequential( nn.Conv2d(transformer_in_channels[-1], conv_dim, kernel_size=1), nn.GroupNorm(32, conv_dim), )]) for proj in self.input_proj: nn.init.xavier_uniform_(proj[0].weight, gain=1) nn.init.constant_(proj[0].bias, 0) self.transformer = MSDeformAttnTransformerEncoderOnly( d_model=conv_dim, dropout=transformer_dropout, nhead=transformer_nheads, dim_feedforward=transformer_dim_feedforward, num_encoder_layers=transformer_enc_layers, num_feature_levels=self.total_num_feature_levels, ) N_steps = conv_dim // 2 self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True) self.mask_dim = mask_dim # use 1x1 conv instead self.mask_features = Conv2d( conv_dim, mask_dim, kernel_size=1, stride=1, padding=0, ) weight_init.c2_xavier_fill(self.mask_features) # extra fpn levels stride = min(self.transformer_feature_strides) self.num_fpn_levels = max(int(np.log2(stride) - np.log2(self.common_stride)), 1) lateral_convs = [] output_convs = [] use_bias = norm == "" for idx, in_channels in enumerate(self.feature_channels[:self.num_fpn_levels]): lateral_norm = get_norm(norm, conv_dim) output_norm = get_norm(norm, conv_dim) lateral_conv = Conv2d( in_channels, conv_dim, kernel_size=1, bias=use_bias, norm=lateral_norm ) output_conv = Conv2d( conv_dim, conv_dim, kernel_size=3, stride=1, padding=1, bias=use_bias, norm=output_norm, activation=F.relu, ) weight_init.c2_xavier_fill(lateral_conv) weight_init.c2_xavier_fill(output_conv) self.add_module("adapter_{}".format(idx + 1), lateral_conv) self.add_module("layer_{}".format(idx + 1), output_conv) lateral_convs.append(lateral_conv) output_convs.append(output_conv) # Place convs into top-down order (from low to high resolution) # to make the top-down computation in forward clearer. 
self.lateral_convs = lateral_convs[::-1] self.output_convs = output_convs[::-1] @classmethod def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec], *args, **kwargs): enc_cfg = cfg['MODEL']['ENCODER'] dec_cfg = cfg['MODEL']['DECODER'] ret = {} ret["input_shape"] = { k: v for k, v in input_shape.items() if k in enc_cfg['IN_FEATURES'] } ret["conv_dim"] = enc_cfg['CONVS_DIM'] ret["mask_dim"] = enc_cfg['MASK_DIM'] ret["norm"] = enc_cfg['NORM'] ret["transformer_dropout"] = dec_cfg['DROPOUT'] ret["transformer_nheads"] = dec_cfg['NHEADS'] ret["transformer_dim_feedforward"] = dec_cfg['DIM_FEEDFORWARD'] # deformable transformer encoder ret[ "transformer_enc_layers" ] = enc_cfg['TRANSFORMER_ENC_LAYERS'] # a separate config ret["transformer_in_features"] = enc_cfg['DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES'] # ['res3', 'res4', 'res5'] ret["common_stride"] = enc_cfg['COMMON_STRIDE'] ret["total_num_feature_levels"] = enc_cfg['TOTAL_NUM_FEATURE_LEVELS'] ret["num_feature_levels"] = enc_cfg['NUM_FEATURE_LEVELS'] ret["feature_order"] = enc_cfg['FEATURE_ORDER'] ret["use_ckpt"] = enc_cfg.get('USE_CKPT', False) return ret @autocast(enabled=True) def forward_features(self, features, masks): """ :param features: multi-scale features from the backbone :param masks: image mask :return: enhanced multi-scale features and mask feature (1/4 resolution) for the decoder to produce binary mask """ # backbone features srcs = [] pos = [] # additional downsampled features srcsl = [] posl = [] # import pdb; pdb.set_trace() if self.total_num_feature_levels > self.transformer_num_feature_levels: smallest_feat = features[self.transformer_in_features[self.low_resolution_index]]#.float() _len_srcs = self.transformer_num_feature_levels for l in range(_len_srcs, self.total_num_feature_levels): if l == _len_srcs: src = self.input_proj[l](smallest_feat) else: src = self.input_proj[l](srcsl[-1]) srcsl.append(src) posl.append(self.pe_layer(src)) srcsl = srcsl[::-1] # Reverse feature maps for idx, f in 
enumerate(self.transformer_in_features[::-1]): x = features[f]#.float() # deformable detr does not support half precision srcs.append(self.input_proj[idx](x)) pos.append(self.pe_layer(x)) srcs.extend(srcsl) if self.feature_order == 'low2high' else srcsl.extend(srcs) pos.extend(posl) if self.feature_order == 'low2high' else posl.extend(pos) if self.feature_order != 'low2high': srcs = srcsl pos = posl y, spatial_shapes, level_start_index = self.transformer(srcs, masks, pos, use_ckpt=self.use_ckpt) bs = y.shape[0] split_size_or_sections = [None] * self.total_num_feature_levels for i in range(self.total_num_feature_levels): if i < self.total_num_feature_levels - 1: split_size_or_sections[i] = level_start_index[i + 1] - level_start_index[i] else: split_size_or_sections[i] = y.shape[1] - level_start_index[i] y = torch.split(y, split_size_or_sections, dim=1) out = [] multi_scale_features = [] num_cur_levels = 0 for i, z in enumerate(y): out.append(z.transpose(1, 2).view(bs, -1, spatial_shapes[i][0], spatial_shapes[i][1])) # append `out` with extra FPN levels # Reverse feature maps into top-down order (from low to high resolution) for idx, f in enumerate(self.in_features[:self.num_fpn_levels][::-1]): x = features[f]#.float() lateral_conv = self.lateral_convs[idx] output_conv = self.output_convs[idx] cur_fpn = lateral_conv(x) # Following FPN implementation, we use nearest upsampling here y = cur_fpn + F.interpolate(out[self.high_resolution_index], size=cur_fpn.shape[-2:], mode="bilinear", align_corners=False) y = output_conv(y) out.append(y) for o in out: if num_cur_levels < self.total_num_feature_levels: multi_scale_features.append(o) num_cur_levels += 1 return self.mask_features(out[-1]), out[0], multi_scale_features @register_encoder def get_maskdino_encoder_deform(cfg, input_shape): """ Build a pixel decoder from `cfg.MODEL.MASK_FORMER.PIXEL_DECODER_NAME`. 
""" model = MaskDINOEncoder(cfg, input_shape) forward_features = getattr(model, "forward_features", None) if not callable(forward_features): raise ValueError( "Only SEM_SEG_HEADS with forward_features method can be used as pixel decoder. " f"Please implement forward_features for {name} to only return mask features." ) return model ================================================ FILE: llava/model/semsam/body/encoder/ops/functions/__init__.py ================================================ # ------------------------------------------------------------------------------------------------ # Deformable DETR # Copyright (c) 2020 SenseTime. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------------------------------ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 # ------------------------------------------------------------------------------------------------ # Copyright (c) Facebook, Inc. and its affiliates. # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR from .ms_deform_attn_func import MSDeformAttnFunction ================================================ FILE: llava/model/semsam/body/encoder/ops/functions/ms_deform_attn_func.py ================================================ # ------------------------------------------------------------------------------------------------ # Deformable DETR # Copyright (c) 2020 SenseTime. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------------------------------ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 # ------------------------------------------------------------------------------------------------ # Copyright (c) Facebook, Inc. and its affiliates. 
# Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR from __future__ import absolute_import from __future__ import print_function from __future__ import division import torch import torch.nn.functional as F from torch.autograd import Function from torch.autograd.function import once_differentiable try: import MultiScaleDeformableAttention as MSDA except ModuleNotFoundError as e: info_string = ( "\n\nPlease compile MultiScaleDeformableAttention CUDA op with the following commands:\n" "\t`cd mask2former/modeling/pixel_decoder/ops`\n" "\t`sh make.sh`\n" ) raise ModuleNotFoundError(info_string) class MSDeformAttnFunction(Function): @staticmethod def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step): ctx.im2col_step = im2col_step output = MSDA.ms_deform_attn_forward( value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step) ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights) return output @staticmethod @once_differentiable def backward(ctx, grad_output): value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors grad_value, grad_sampling_loc, grad_attn_weight = \ MSDA.ms_deform_attn_backward( value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step) return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights): # for debug and test only, # need to use cuda version instead N_, S_, M_, D_ = value.shape _, Lq_, M_, L_, P_, _ = sampling_locations.shape value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) sampling_grids = 2 * sampling_locations - 1 sampling_value_list = [] for lid_, (H_, W_) in 
enumerate(value_spatial_shapes): # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_ value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_) # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2 sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1) # N_*M_, D_, Lq_, P_ sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_, mode='bilinear', padding_mode='zeros', align_corners=False) sampling_value_list.append(sampling_value_l_) # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_) attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_) output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_) return output.transpose(1, 2).contiguous() ================================================ FILE: llava/model/semsam/body/encoder/ops/make.sh ================================================ #!/usr/bin/env bash # ------------------------------------------------------------------------------------------------ # Deformable DETR # Copyright (c) 2020 SenseTime. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------------------------------ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 # ------------------------------------------------------------------------------------------------ # Copyright (c) Facebook, Inc. and its affiliates. 
# Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR python setup.py build install --user ================================================ FILE: llava/model/semsam/body/encoder/ops/modules/__init__.py ================================================ # ------------------------------------------------------------------------------------------------ # Deformable DETR # Copyright (c) 2020 SenseTime. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------------------------------ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 # ------------------------------------------------------------------------------------------------ # Copyright (c) Facebook, Inc. and its affiliates. # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR from .ms_deform_attn import MSDeformAttn ================================================ FILE: llava/model/semsam/body/encoder/ops/modules/ms_deform_attn.py ================================================ # ------------------------------------------------------------------------------------------------ # Deformable DETR # Copyright (c) 2020 SenseTime. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------------------------------ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 # ------------------------------------------------------------------------------------------------ # Copyright (c) Facebook, Inc. and its affiliates. 
# Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR from __future__ import absolute_import from __future__ import print_function from __future__ import division import warnings import math import torch from torch import nn import torch.nn.functional as F from torch.nn.init import xavier_uniform_, constant_ from ..functions import MSDeformAttnFunction from ..functions.ms_deform_attn_func import ms_deform_attn_core_pytorch def _is_power_of_2(n): if (not isinstance(n, int)) or (n < 0): raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))) return (n & (n-1) == 0) and n != 0 class MSDeformAttn(nn.Module): def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4): """ Multi-Scale Deformable Attention Module :param d_model hidden dimension :param n_levels number of feature levels :param n_heads number of attention heads :param n_points number of sampling points per attention head per feature level """ super().__init__() if d_model % n_heads != 0: raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads)) _d_per_head = d_model // n_heads # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation if not _is_power_of_2(_d_per_head): warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 " "which is more efficient in our CUDA implementation.") self.im2col_step = 128 self.d_model = d_model self.n_levels = n_levels self.n_heads = n_heads self.n_points = n_points self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2) self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points) self.value_proj = nn.Linear(d_model, d_model) self.output_proj = nn.Linear(d_model, d_model) self._reset_parameters() def _reset_parameters(self): constant_(self.sampling_offsets.weight.data, 0.) 
thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1) for i in range(self.n_points): grid_init[:, :, i, :] *= i + 1 with torch.no_grad(): self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) constant_(self.attention_weights.weight.data, 0.) constant_(self.attention_weights.bias.data, 0.) xavier_uniform_(self.value_proj.weight.data) constant_(self.value_proj.bias.data, 0.) xavier_uniform_(self.output_proj.weight.data) constant_(self.output_proj.bias.data, 0.) def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None): """ :param query (N, Length_{query}, C) :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C) :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}] :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements :return output (N, Length_{query}, C) """ N, Len_q, _ = query.shape N, Len_in, _ = input_flatten.shape assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in # input_flatten=input_flatten.to(self.value_proj.bias.data.dtype) value = self.value_proj(input_flatten) if input_padding_mask is not None: value = value.masked_fill(input_padding_mask[..., None], float(0)) value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads) sampling_offsets = 
self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2) attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points) attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points) # N, Len_q, n_heads, n_levels, n_points, 2 if reference_points.shape[-1] == 2: offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1) sampling_locations = reference_points[:, :, None, :, None, :] \ + sampling_offsets / offset_normalizer[None, None, None, :, None, :] elif reference_points.shape[-1] == 4: sampling_locations = reference_points[:, :, None, :, None, :2] \ + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 else: raise ValueError( 'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1])) # try: # print(value.dtype) convert=False # import pdb; pdb.set_trace() dtype=value.dtype if value.dtype== torch.bfloat16 or value.dtype== torch.float16: value = value.float() attention_weights = attention_weights.float() sampling_locations = sampling_locations.float() convert=True # value = value.float() # attention_weights = attention_weights.float() # sampling_locations = sampling_locations.float() # convert=True output = MSDeformAttnFunction.apply( value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step) if convert: output = output.to(dtype) # except: # # CPU # output = ms_deform_attn_core_pytorch(value, input_spatial_shapes, sampling_locations, attention_weights) # # For FLOPs calculation only # output = ms_deform_attn_core_pytorch(value, input_spatial_shapes, sampling_locations, attention_weights) output = self.output_proj(output) return output ================================================ FILE: llava/model/semsam/body/encoder/ops/setup.py 
================================================ # ------------------------------------------------------------------------------------------------ # Deformable DETR # Copyright (c) 2020 SenseTime. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------------------------------ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 # ------------------------------------------------------------------------------------------------ # Copyright (c) Facebook, Inc. and its affiliates. # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR import os import glob import torch from torch.utils.cpp_extension import CUDA_HOME from torch.utils.cpp_extension import CppExtension from torch.utils.cpp_extension import CUDAExtension from setuptools import find_packages from setuptools import setup requirements = ["torch", "torchvision"] def get_extensions(): this_dir = os.path.dirname(os.path.abspath(__file__)) extensions_dir = os.path.join(this_dir, "src") main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu")) sources = main_file + source_cpu extension = CppExtension extra_compile_args = {"cxx": []} define_macros = [] # Force cuda since torch ask for a device, not if cuda is in fact available. if (os.environ.get('FORCE_CUDA') or torch.cuda.is_available()) and CUDA_HOME is not None: extension = CUDAExtension sources += source_cuda define_macros += [("WITH_CUDA", None)] extra_compile_args["nvcc"] = [ "-DCUDA_HAS_FP16=1", "-D__CUDA_NO_HALF_OPERATORS__", "-D__CUDA_NO_HALF_CONVERSIONS__", "-D__CUDA_NO_HALF2_OPERATORS__", ] else: if CUDA_HOME is None: raise NotImplementedError('CUDA_HOME is None. 
Please set environment variable CUDA_HOME.') else: raise NotImplementedError('No CUDA runtime is found. Please set FORCE_CUDA=1 or test it by running torch.cuda.is_available().') sources = [os.path.join(extensions_dir, s) for s in sources] include_dirs = [extensions_dir] ext_modules = [ extension( "MultiScaleDeformableAttention", sources, include_dirs=include_dirs, define_macros=define_macros, extra_compile_args=extra_compile_args, ) ] return ext_modules setup( name="MultiScaleDeformableAttention", version="1.0", author="Weijie Su", url="https://github.com/fundamentalvision/Deformable-DETR", description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention", packages=find_packages(exclude=("configs", "tests",)), ext_modules=get_extensions(), cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, ) ================================================ FILE: llava/model/semsam/body/encoder/ops/src/cpu/ms_deform_attn_cpu.cpp ================================================ /*! ************************************************************************************************** * Deformable DETR * Copyright (c) 2020 SenseTime. All Rights Reserved. * Licensed under the Apache License, Version 2.0 [see LICENSE for details] ************************************************************************************************** * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 ************************************************************************************************** */ /*! * Copyright (c) Facebook, Inc. and its affiliates. 
* Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR */ #include #include #include at::Tensor ms_deform_attn_cpu_forward( const at::Tensor &value, const at::Tensor &spatial_shapes, const at::Tensor &level_start_index, const at::Tensor &sampling_loc, const at::Tensor &attn_weight, const int im2col_step) { AT_ERROR("Not implement on cpu"); } std::vector ms_deform_attn_cpu_backward( const at::Tensor &value, const at::Tensor &spatial_shapes, const at::Tensor &level_start_index, const at::Tensor &sampling_loc, const at::Tensor &attn_weight, const at::Tensor &grad_output, const int im2col_step) { AT_ERROR("Not implement on cpu"); } ================================================ FILE: llava/model/semsam/body/encoder/ops/src/cpu/ms_deform_attn_cpu.h ================================================ /*! ************************************************************************************************** * Deformable DETR * Copyright (c) 2020 SenseTime. All Rights Reserved. * Licensed under the Apache License, Version 2.0 [see LICENSE for details] ************************************************************************************************** * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 ************************************************************************************************** */ /*! * Copyright (c) Facebook, Inc. and its affiliates. 
* Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR */ #pragma once #include at::Tensor ms_deform_attn_cpu_forward( const at::Tensor &value, const at::Tensor &spatial_shapes, const at::Tensor &level_start_index, const at::Tensor &sampling_loc, const at::Tensor &attn_weight, const int im2col_step); std::vector ms_deform_attn_cpu_backward( const at::Tensor &value, const at::Tensor &spatial_shapes, const at::Tensor &level_start_index, const at::Tensor &sampling_loc, const at::Tensor &attn_weight, const at::Tensor &grad_output, const int im2col_step); ================================================ FILE: llava/model/semsam/body/encoder/ops/src/cuda/ms_deform_attn_cuda.cu ================================================ /*! ************************************************************************************************** * Deformable DETR * Copyright (c) 2020 SenseTime. All Rights Reserved. * Licensed under the Apache License, Version 2.0 [see LICENSE for details] ************************************************************************************************** * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 ************************************************************************************************** */ /*! * Copyright (c) Facebook, Inc. and its affiliates. 
* Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
*/

#include <vector>
#include "cuda/ms_deform_im2col_cuda.cuh"

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>


// Forward pass: validates inputs, then launches ms_deformable_im2col_cuda
// once per chunk of im2col_step_ batch items. The output is allocated as
// (batch, num_query, num_heads, channels) and returned reshaped to
// (batch, num_query, num_heads * channels).
at::Tensor ms_deform_attn_cuda_forward(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const int im2col_step)
{
    TORCH_CHECK(value.is_contiguous(), "value tensor has to be contiguous");
    TORCH_CHECK(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
    TORCH_CHECK(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
    TORCH_CHECK(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
    TORCH_CHECK(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");

    TORCH_CHECK(value.is_cuda(), "value must be a CUDA tensor");
    TORCH_CHECK(spatial_shapes.is_cuda(), "spatial_shapes must be a CUDA tensor");
    TORCH_CHECK(level_start_index.is_cuda(), "level_start_index must be a CUDA tensor");
    TORCH_CHECK(sampling_loc.is_cuda(), "sampling_loc must be a CUDA tensor");
    TORCH_CHECK(attn_weight.is_cuda(), "attn_weight must be a CUDA tensor");

    // value is indexed as (batch, spatial_size, num_heads, channels).
    const int batch = value.size(0);
    const int spatial_size = value.size(1);
    const int num_heads = value.size(2);
    const int channels = value.size(3);

    const int num_levels = spatial_shapes.size(0);

    const int num_query = sampling_loc.size(1);
    const int num_point = sampling_loc.size(4);

    const int im2col_step_ = std::min(batch, im2col_step);

    // TORCH_CHECK concatenates its message arguments (no printf formatting).
    TORCH_CHECK(batch % im2col_step_ == 0,
                "batch(", batch, ") must be divisible by im2col_step(", im2col_step_, ")");

    auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());

    const int batch_n = im2col_step_;
    auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
    auto per_value_size = spatial_size * num_heads * channels;
    auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
    auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
    for (int n = 0; n < batch/im2col_step_; ++n)
    {
        auto columns = output_n.select(0, n);
        AT_DISPATCH_FLOATING_TYPES(value.scalar_type(), "ms_deform_attn_forward_cuda", ([&] {
            ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
                value.data_ptr<scalar_t>() + n * im2col_step_ * per_value_size,
                spatial_shapes.data_ptr<int64_t>(),
                level_start_index.data_ptr<int64_t>(),
                sampling_loc.data_ptr<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
                attn_weight.data_ptr<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
                batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
                columns.data_ptr<scalar_t>());

        }));
    }

    output = output.view({batch, num_query, num_heads*channels});

    return output;
}


// Backward pass: same validation and chunking as the forward. Returns
// {grad_value, grad_sampling_loc, grad_attn_weight}, each shaped like the
// corresponding input (allocated with zeros_like).
std::vector<at::Tensor> ms_deform_attn_cuda_backward(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const at::Tensor &grad_output,
    const int im2col_step)
{

    TORCH_CHECK(value.is_contiguous(), "value tensor has to be contiguous");
    TORCH_CHECK(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
    TORCH_CHECK(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
    TORCH_CHECK(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
    TORCH_CHECK(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
    TORCH_CHECK(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");

    TORCH_CHECK(value.is_cuda(), "value must be a CUDA tensor");
    TORCH_CHECK(spatial_shapes.is_cuda(), "spatial_shapes must be a CUDA tensor");
    TORCH_CHECK(level_start_index.is_cuda(), "level_start_index must be a CUDA tensor");
    TORCH_CHECK(sampling_loc.is_cuda(), "sampling_loc must be a CUDA tensor");
    TORCH_CHECK(attn_weight.is_cuda(), "attn_weight must be a CUDA tensor");
    TORCH_CHECK(grad_output.is_cuda(), "grad_output must be a CUDA tensor");

    const int batch = value.size(0);
    const int spatial_size = value.size(1);
    const int num_heads = value.size(2);
    const int channels = value.size(3);

    const int num_levels = spatial_shapes.size(0);

    const int num_query = sampling_loc.size(1);
    const int num_point = sampling_loc.size(4);

    const int im2col_step_ = std::min(batch, im2col_step);

    TORCH_CHECK(batch % im2col_step_ == 0,
                "batch(", batch, ") must be divisible by im2col_step(", im2col_step_, ")");

    auto grad_value = at::zeros_like(value);
    auto grad_sampling_loc = at::zeros_like(sampling_loc);
    auto grad_attn_weight = at::zeros_like(attn_weight);

    const int batch_n = im2col_step_;
    auto per_value_size = spatial_size * num_heads * channels;
    auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
    auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
    auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});

    for (int n = 0; n < batch/im2col_step_; ++n)
    {
        auto grad_output_g = grad_output_n.select(0, n);
        AT_DISPATCH_FLOATING_TYPES(value.scalar_type(), "ms_deform_attn_backward_cuda", ([&] {
            ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
                                    grad_output_g.data_ptr<scalar_t>(),
                                    value.data_ptr<scalar_t>() + n * im2col_step_ * per_value_size,
                                    spatial_shapes.data_ptr<int64_t>(),
                                    level_start_index.data_ptr<int64_t>(),
                                    sampling_loc.data_ptr<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
                                    attn_weight.data_ptr<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
                                    batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
                                    grad_value.data_ptr<scalar_t>() + n * im2col_step_ * per_value_size,
                                    grad_sampling_loc.data_ptr<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
                                    grad_attn_weight.data_ptr<scalar_t>() + n * im2col_step_ * per_attn_weight_size);

        }));
    }

    return {
        grad_value, grad_sampling_loc, grad_attn_weight
    };
}

================================================
FILE: llava/model/semsam/body/encoder/ops/src/cuda/ms_deform_attn_cuda.h
================================================
/*!
************************************************************************************************** * Deformable DETR * Copyright (c) 2020 SenseTime. All Rights Reserved. * Licensed under the Apache License, Version 2.0 [see LICENSE for details] ************************************************************************************************** * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 ************************************************************************************************** */ /*! * Copyright (c) Facebook, Inc. and its affiliates. * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR */ #pragma once #include at::Tensor ms_deform_attn_cuda_forward( const at::Tensor &value, const at::Tensor &spatial_shapes, const at::Tensor &level_start_index, const at::Tensor &sampling_loc, const at::Tensor &attn_weight, const int im2col_step); std::vector ms_deform_attn_cuda_backward( const at::Tensor &value, const at::Tensor &spatial_shapes, const at::Tensor &level_start_index, const at::Tensor &sampling_loc, const at::Tensor &attn_weight, const at::Tensor &grad_output, const int im2col_step); ================================================ FILE: llava/model/semsam/body/encoder/ops/src/cuda/ms_deform_im2col_cuda.cuh ================================================ /*! ************************************************************************** * Deformable DETR * Copyright (c) 2020 SenseTime. All Rights Reserved. * Licensed under the Apache License, Version 2.0 [see LICENSE for details] ************************************************************************** * Modified from DCN (https://github.com/msracver/Deformable-ConvNets) * Copyright (c) 2018 Microsoft ************************************************************************** */ /*! * Copyright (c) Facebook, Inc. and its affiliates. 
* Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
*/

// Multi-scale deformable attention CUDA kernels:
//   * ms_deformable_im2col_*  - forward: gather bilinearly-sampled values weighted by
//                               attention weights.
//   * ms_deformable_col2im_*  - backward: scatter gradients to value, sampling
//                               locations and attention weights.
//   * ms_deformable_{im2col,col2im}_cuda - host-side launchers.
//
// Layouts implied by the index arithmetic below (nothing in this file checks them):
//   data_value          : [batch][spatial_size][num_heads][channels]; all levels are
//                         concatenated along spatial_size, level l starting at
//                         data_level_start_index[l]; inside a level the position is
//                         h * width + w.
//   data_spatial_shapes : [num_levels][2] = (height, width) per level.
//   data_sampling_loc   : [batch][num_query][num_heads][num_levels][num_point][2] = (x, y);
//                         the kernels scale x by the level width and y by its height,
//                         so the locations are presumably normalized to [0, 1].
//   data_attn_weight    : [batch][num_query][num_heads][num_levels][num_point].
//   data_col / grad_col : [batch][num_query][num_heads][channels] (channel fastest).
//
// NOTE(review): angle-bracketed text appears to have been stripped from this copy.
// The six `#include` directives have no header names, every `template` keyword has
// lost its parameter list (the blocksize-aware kernels also use a compile-time
// `blockSize`), and every kernel launch shows an empty `<<>>` configuration. The file
// cannot compile as written; restore those fragments from the upstream
// Deformable-DETR sources before building.

#include
#include
#include
#include
#include
#include

// Grid-stride loop: each thread starts at its global index and advances by the total
// number of launched threads (blockDim.x * gridDim.x) until `n` is reached.
#define CUDA_KERNEL_LOOP(i, n) \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
      i < (n); \
      i += blockDim.x * gridDim.x)

// Threads per block for the forward launcher; also the cap on threads per block in
// the backward launcher.
const int CUDA_NUM_THREADS = 1024;

// Number of blocks needed to cover N work items with `num_threads` threads per block
// (ceiling division).
inline int GET_BLOCKS(const int N, const int num_threads)
{
  return (N + num_threads - 1) / num_threads;
}

// Bilinearly samples channel `c` of head `m` at the fractional position (h, w) of one
// feature level. `bottom_data` points at the start of the level; the indexing implies a
// [height][width][nheads][channels] layout. Corners outside [0, height-1] x [0, width-1]
// contribute zero.
template __device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data,
    const int &height, const int &width, const int &nheads, const int &channels,
    const scalar_t &h, const scalar_t &w, const int &m, const int &c)
{
  const int h_low = floor(h);
  const int w_low = floor(w);
  const int h_high = h_low + 1;
  const int w_high = w_low + 1;

  // (lh, lw): fractional offsets from the low corner; (hh, hw): their complements.
  const scalar_t lh = h - h_low;
  const scalar_t lw = w - w_low;
  const scalar_t hh = 1 - lh, hw = 1 - lw;

  const int w_stride = nheads * channels;
  const int h_stride = width * w_stride;
  const int h_low_ptr_offset = h_low * h_stride;
  const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
  const int w_low_ptr_offset = w_low * w_stride;
  const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
  const int base_ptr = m * channels + c;

  // v1..v4 are the values at (h_low, w_low), (h_low, w_high), (h_high, w_low),
  // (h_high, w_high); each stays 0 when its corner is out of bounds.
  scalar_t v1 = 0;
  if (h_low >= 0 && w_low >= 0)
  {
    const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
    v1 = bottom_data[ptr1];
  }
  scalar_t v2 = 0;
  if (h_low >= 0 && w_high <= width - 1)
  {
    const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
    v2 = bottom_data[ptr2];
  }
  scalar_t v3 = 0;
  if (h_high <= height - 1 && w_low >= 0)
  {
    const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
    v3 = bottom_data[ptr3];
  }
  scalar_t v4 = 0;
  if (h_high <= height - 1 && w_high <= width - 1)
  {
    const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
    v4 = bottom_data[ptr4];
  }

  const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;

  const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
  return val;
}

// Backward of ms_deform_attn_im2col_bilinear for one (head m, channel c) sample.
//   top_grad    - incoming gradient of the sampled output element.
//   attn_weight - attention weight that multiplied this sample in the forward pass.
// Effects:
//   * grad_value: atomicAdd of (corner weight * top_grad * attn_weight) at each
//     in-bounds corner, because many threads can hit the same corner.
//   * *grad_attn_weight = top_grad * sampled value (plain store, overwrites).
//   * grad_sampling_loc[0], [1] = gradient w.r.t. the x / y location (plain stores).
//     The `width` / `height` factors are the chain-rule term of the
//     `loc * size - 0.5` mapping the calling kernels apply before sampling.
template __device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data,
    const int &height, const int &width, const int &nheads, const int &channels,
    const scalar_t &h, const scalar_t &w, const int &m, const int &c,
    const scalar_t &top_grad,
    const scalar_t &attn_weight,
    scalar_t* &grad_value,
    scalar_t* grad_sampling_loc,
    scalar_t* grad_attn_weight)
{
  const int h_low = floor(h);
  const int w_low = floor(w);
  const int h_high = h_low + 1;
  const int w_high = w_low + 1;

  const scalar_t lh = h - h_low;
  const scalar_t lw = w - w_low;
  const scalar_t hh = 1 - lh, hw = 1 - lw;

  const int w_stride = nheads * channels;
  const int h_stride = width * w_stride;
  const int h_low_ptr_offset = h_low * h_stride;
  const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
  const int w_low_ptr_offset = w_low * w_stride;
  const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
  const int base_ptr = m * channels + c;

  const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
  const scalar_t top_grad_value = top_grad * attn_weight;
  // Accumulate d(val)/dh and d(val)/dw from the in-bounds corners.
  scalar_t grad_h_weight = 0, grad_w_weight = 0;

  scalar_t v1 = 0;
  if (h_low >= 0 && w_low >= 0)
  {
    const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
    v1 = bottom_data[ptr1];
    grad_h_weight -= hw * v1;
    grad_w_weight -= hh * v1;
    atomicAdd(grad_value+ptr1, w1*top_grad_value);
  }
  scalar_t v2 = 0;
  if (h_low >= 0 && w_high <= width - 1)
  {
    const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
    v2 = bottom_data[ptr2];
    grad_h_weight -= lw * v2;
    grad_w_weight += hh * v2;
    atomicAdd(grad_value+ptr2, w2*top_grad_value);
  }
  scalar_t v3 = 0;
  if (h_high <= height - 1 && w_low >= 0)
  {
    const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
    v3 = bottom_data[ptr3];
    grad_h_weight += hw * v3;
    grad_w_weight -= lh * v3;
    atomicAdd(grad_value+ptr3, w3*top_grad_value);
  }
  scalar_t v4 = 0;
  if (h_high <= height - 1 && w_high <= width - 1)
  {
    const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
    v4 = bottom_data[ptr4];
    grad_h_weight += lw * v4;
    grad_w_weight += lh * v4;
    atomicAdd(grad_value+ptr4, w4*top_grad_value);
  }

  const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
  *grad_attn_weight = top_grad * val;
  *grad_sampling_loc = width * grad_w_weight * top_grad_value;
  *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value;
}

// Same as ms_deform_attn_col2im_bilinear, except that the attention-weight and
// sampling-location gradients are accumulated with atomicAdd instead of stored. Many
// threads can therefore target the same output slot directly in global memory.
template __device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data,
    const int &height, const int &width, const int &nheads, const int &channels,
    const scalar_t &h, const scalar_t &w, const int &m, const int &c,
    const scalar_t &top_grad,
    const scalar_t &attn_weight,
    scalar_t* &grad_value,
    scalar_t* grad_sampling_loc,
    scalar_t* grad_attn_weight)
{
  const int h_low = floor(h);
  const int w_low = floor(w);
  const int h_high = h_low + 1;
  const int w_high = w_low + 1;

  const scalar_t lh = h - h_low;
  const scalar_t lw = w - w_low;
  const scalar_t hh = 1 - lh, hw = 1 - lw;

  const int w_stride = nheads * channels;
  const int h_stride = width * w_stride;
  const int h_low_ptr_offset = h_low * h_stride;
  const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
  const int w_low_ptr_offset = w_low * w_stride;
  const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
  const int base_ptr = m * channels + c;

  const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
  const scalar_t top_grad_value = top_grad * attn_weight;
  scalar_t grad_h_weight = 0, grad_w_weight = 0;

  scalar_t v1 = 0;
  if (h_low >= 0 && w_low >= 0)
  {
    const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
    v1 = bottom_data[ptr1];
    grad_h_weight -= hw * v1;
    grad_w_weight -= hh * v1;
    atomicAdd(grad_value+ptr1, w1*top_grad_value);
  }
  scalar_t v2 = 0;
  if (h_low >= 0 && w_high <= width - 1)
  {
    const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
    v2 = bottom_data[ptr2];
    grad_h_weight -= lw * v2;
    grad_w_weight += hh * v2;
    atomicAdd(grad_value+ptr2, w2*top_grad_value);
  }
  scalar_t v3 = 0;
  if (h_high <= height - 1 && w_low >= 0)
  {
    const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
    v3 = bottom_data[ptr3];
    grad_h_weight += hw * v3;
    grad_w_weight -= lh * v3;
    atomicAdd(grad_value+ptr3, w3*top_grad_value);
  }
  scalar_t v4 = 0;
  if (h_high <= height - 1 && w_high <= width - 1)
  {
    const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
    v4 = bottom_data[ptr4];
    grad_h_weight += lw * v4;
    grad_w_weight += lh * v4;
    atomicAdd(grad_value+ptr4, w4*top_grad_value);
  }

  const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
  atomicAdd(grad_attn_weight, top_grad * val);
  atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value);
  atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value);
}

// Forward kernel. Each loop iteration produces one output element `index`, decomposed
// as (b_col, q_col, m_col, c_col) with the channel fastest. For every level and point,
// it reads the (x, y) location and attention weight and maps the location to pixel
// coordinates with `loc * size - 0.5`. Samples whose coordinates fall outside
// (-1, size) on either axis are skipped. The weighted bilinear samples are summed into
// data_col[index]. (q_col is computed but not used.)
template __global__ void ms_deformable_im2col_gpu_kernel(const int n,
    const scalar_t *data_value,
    const int64_t *data_spatial_shapes,
    const int64_t *data_level_start_index,
    const scalar_t *data_sampling_loc,
    const scalar_t *data_attn_weight,
    const int batch_size,
    const int spatial_size,
    const int num_heads,
    const int channels,
    const int num_levels,
    const int num_query,
    const int num_point,
    scalar_t *data_col)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    // (b, q, m) triple flattened; indexes the per-sample weights/locations below.
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    scalar_t *data_col_ptr = data_col + index;
    int data_weight_ptr = sampling_index * num_levels * num_point;
    // Locations are (x, y) pairs, hence twice the weight offset.
    int data_loc_w_ptr = data_weight_ptr << 1;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
    scalar_t col = 0;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride);
      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;

        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight;
        }

        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
      }
    }
    *data_col_ptr = col;
  }
}

// Backward kernel with a static shared-memory cache sized by the compile-time
// `blockSize`. Each thread writes its partial location/weight gradients into the
// cache, and thread 0 then sums all blockSize entries serially and stores the result.
// grad_value is updated with atomicAdd inside ms_deform_attn_col2im_bilinear.
//
// NOTE(review): assumptions this kernel relies on, visible from the code:
//   * The block-wide sum mixes every thread's partials, so all threads of a block
//     presumably work on the same (b, q, m) sample, i.e. blockDim.x == blockSize ==
//     channels. The launcher uses num_threads == channels when channels <= 1024;
//     confirm once the stripped launch configuration is restored.
//   * __syncthreads() sits inside the grid-stride loop, so every thread of a block
//     must take the same number of iterations.
//   * grad_sampling_loc / grad_attn_weight are advanced in place and never reset
//     between loop iterations. Results are only correct when each thread runs at most
//     one iteration (grid large enough to cover n).
//   * The inner serial-reduction loop variable `tid` shadows the outer `tid`.
template __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n,
    const scalar_t *grad_col,
    const scalar_t *data_value,
    const int64_t *data_spatial_shapes,
    const int64_t *data_level_start_index,
    const scalar_t *data_sampling_loc,
    const scalar_t *data_attn_weight,
    const int batch_size,
    const int spatial_size,
    const int num_heads,
    const int channels,
    const int num_levels,
    const int num_query,
    const int num_point,
    scalar_t *grad_value,
    scalar_t *grad_sampling_loc,
    scalar_t *grad_attn_weight)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    // Per-thread (x, y) location gradients and attention-weight gradient.
    __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
    __shared__ scalar_t cache_grad_attn_weight[blockSize];
    unsigned int tid = threadIdx.x;
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    const int grad_sampling_ptr = data_weight_ptr;
    // Move the output pointers to this sample's first (level, point) slot.
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        // Out-of-range samples must still contribute zeros to the block reduction.
        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
        *(cache_grad_attn_weight+threadIdx.x)=0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          ms_deform_attn_col2im_bilinear(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
        }

        __syncthreads();
        if (tid == 0)
        {
          scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
          int sid=2;
          for (unsigned int tid = 1; tid < blockSize; ++tid)
          {
            _grad_w += cache_grad_sampling_loc[sid];
            _grad_h += cache_grad_sampling_loc[sid + 1];
            _grad_a += cache_grad_attn_weight[tid];
            sid += 2;
          }

          *grad_sampling_loc = _grad_w;
          *(grad_sampling_loc + 1) = _grad_h;
          *grad_attn_weight = _grad_a;
        }
        // Keep the cache intact until thread 0 has finished reading it.
        __syncthreads();

        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}

// Same as ..._blocksize_aware_reduce_v1, but the cache is summed with a shared-memory
// tree reduction that halves the active range each step (s = blockSize/2, /4, ...).
// The halving presumes blockSize is a power of two; the launcher selects this kernel
// for channels 64, 128, 256, 512 and 1024. The same NOTE(review) assumptions apply:
// one (b, q, m) sample per block, and at most one grid-stride iteration per thread.
template __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n,
    const scalar_t *grad_col,
    const scalar_t *data_value,
    const int64_t *data_spatial_shapes,
    const int64_t *data_level_start_index,
    const scalar_t *data_sampling_loc,
    const scalar_t *data_attn_weight,
    const int batch_size,
    const int spatial_size,
    const int num_heads,
    const int channels,
    const int num_levels,
    const int num_query,
    const int num_point,
    scalar_t *grad_value,
    scalar_t *grad_sampling_loc,
    scalar_t *grad_attn_weight)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
    __shared__ scalar_t cache_grad_attn_weight[blockSize];
    unsigned int tid = threadIdx.x;
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
        *(cache_grad_attn_weight+threadIdx.x)=0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          ms_deform_attn_col2im_bilinear(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
        }

        __syncthreads();

        // Tree reduction: thread tid folds in element tid + s.
        for (unsigned int s=blockSize/2; s>0; s>>=1)
        {
          if (tid < s) {
            const unsigned int xid1 = tid << 1;
            const unsigned int xid2 = (tid + s) << 1;
            cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
            cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
            cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
          }
          __syncthreads();
        }

        if (tid == 0)
        {
          *grad_sampling_loc = cache_grad_sampling_loc[0];
          *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
          *grad_attn_weight = cache_grad_attn_weight[0];
        }
        __syncthreads();

        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}

// Like ..._blocksize_aware_reduce_v1 (serial sum by thread 0), but the cache lives in
// dynamically sized shared memory (`extern __shared__`): 2 * blockDim.x location slots
// followed by blockDim.x weight slots. The launch therefore needs at least
// 3 * blockDim.x * sizeof(scalar_t) bytes of dynamic shared memory. The launcher uses
// this kernel for channel counts below 64 that match no explicit switch case.
template __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n,
    const scalar_t *grad_col,
    const scalar_t *data_value,
    const int64_t *data_spatial_shapes,
    const int64_t *data_level_start_index,
    const scalar_t *data_sampling_loc,
    const scalar_t *data_attn_weight,
    const int batch_size,
    const int spatial_size,
    const int num_heads,
    const int channels,
    const int num_levels,
    const int num_query,
    const int num_point,
    scalar_t *grad_value,
    scalar_t *grad_sampling_loc,
    scalar_t *grad_attn_weight)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    extern __shared__ int _s[];
    scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
    scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
    unsigned int tid = threadIdx.x;
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
        *(cache_grad_attn_weight+threadIdx.x)=0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          ms_deform_attn_col2im_bilinear(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
        }

        __syncthreads();
        if (tid == 0)
        {
          scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
          int sid=2;
          // Serial sum over the actual block size (inner `tid` shadows the outer one).
          for (unsigned int tid = 1; tid < blockDim.x; ++tid)
          {
            _grad_w += cache_grad_sampling_loc[sid];
            _grad_h += cache_grad_sampling_loc[sid + 1];
            _grad_a += cache_grad_attn_weight[tid];
            sid += 2;
          }

          *grad_sampling_loc = _grad_w;
          *(grad_sampling_loc + 1) = _grad_h;
          *grad_attn_weight = _grad_a;
        }
        __syncthreads();

        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}

// Dynamic shared-memory cache as in ..._shm_reduce_v1, summed with a tree reduction
// that tolerates a non-power-of-two blockDim.x. `spre` is the number of active
// entries before each step. When it is odd, `tid + 2*s < spre` holds only for
// tid == 0, so thread 0 also folds in the leftover entry at index 2*s. The launcher
// uses this kernel for channel counts between 64 and 1024 that match no explicit
// switch case.
template __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n,
    const scalar_t *grad_col,
    const scalar_t *data_value,
    const int64_t *data_spatial_shapes,
    const int64_t *data_level_start_index,
    const scalar_t *data_sampling_loc,
    const scalar_t *data_attn_weight,
    const int batch_size,
    const int spatial_size,
    const int num_heads,
    const int channels,
    const int num_levels,
    const int num_query,
    const int num_point,
    scalar_t *grad_value,
    scalar_t *grad_sampling_loc,
    scalar_t *grad_attn_weight)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    extern __shared__ int _s[];
    scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
    scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
    unsigned int tid = threadIdx.x;
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
        *(cache_grad_attn_weight+threadIdx.x)=0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          ms_deform_attn_col2im_bilinear(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
        }

        __syncthreads();

        for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
        {
          if (tid < s) {
            const unsigned int xid1 = tid << 1;
            const unsigned int xid2 = (tid + s) << 1;
            cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
            cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
            cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
            // Odd active count: pick up the unpaired last entry.
            if (tid + (s << 1) < spre)
            {
              cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
              cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
              cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
            }
          }
          __syncthreads();
        }

        if (tid == 0)
        {
          *grad_sampling_loc = cache_grad_sampling_loc[0];
          *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
          *grad_attn_weight = cache_grad_attn_weight[0];
        }
        __syncthreads();

        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}

// Same reduction as ..._shm_reduce_v2, but thread 0 accumulates its block's result into
// global memory with atomicAdd. The launcher uses this kernel when channels > 1024 and
// channels is a multiple of 1024, where one sample's channels span several blocks that
// each add a partial sum. Because the outputs are accumulated rather than stored,
// grad_sampling_loc / grad_attn_weight presumably must be zero-initialized by the
// caller (as must grad_value in every backward kernel) — TODO confirm in the .cu.
template __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n,
    const scalar_t *grad_col,
    const scalar_t *data_value,
    const int64_t *data_spatial_shapes,
    const int64_t *data_level_start_index,
    const scalar_t *data_sampling_loc,
    const scalar_t *data_attn_weight,
    const int batch_size,
    const int spatial_size,
    const int num_heads,
    const int channels,
    const int num_levels,
    const int num_query,
    const int num_point,
    scalar_t *grad_value,
    scalar_t *grad_sampling_loc,
    scalar_t *grad_attn_weight)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    extern __shared__ int _s[];
    scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
    scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
    unsigned int tid = threadIdx.x;
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
        *(cache_grad_attn_weight+threadIdx.x)=0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          ms_deform_attn_col2im_bilinear(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
        }

        __syncthreads();

        for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
        {
          if (tid < s) {
            const unsigned int xid1 = tid << 1;
            const unsigned int xid2 = (tid + s) << 1;
            cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
            cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
            cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
            if (tid + (s << 1) < spre)
            {
              cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
              cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
              cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
            }
          }
          __syncthreads();
        }

        // Several blocks share this output slot, so accumulate instead of store.
        if (tid == 0)
        {
          atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]);
          atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]);
          atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]);
        }
        __syncthreads();

        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}

// Fallback backward kernel with no shared memory. Every thread atomically adds its own
// location / weight gradient contributions straight into global memory through
// ms_deform_attn_col2im_bilinear_gm. The launcher uses it when channels > 1024 and
// channels is not a multiple of 1024.
template __global__ void ms_deformable_col2im_gpu_kernel_gm(const int n,
    const scalar_t *grad_col,
    const scalar_t *data_value,
    const int64_t *data_spatial_shapes,
    const int64_t *data_level_start_index,
    const scalar_t *data_sampling_loc,
    const scalar_t *data_attn_weight,
    const int batch_size,
    const int spatial_size,
    const int num_heads,
    const int channels,
    const int num_levels,
    const int num_query,
    const int num_point,
    scalar_t *grad_value,
    scalar_t *grad_sampling_loc,
    scalar_t *grad_attn_weight)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          ms_deform_attn_col2im_bilinear_gm(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            grad_sampling_loc, grad_attn_weight);
        }
        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}

// Host launcher for the forward kernel on `stream`. It uses one work item per output
// element (batch * num_query * num_heads * channels) and CUDA_NUM_THREADS threads per
// block. num_kernels and num_actual_kernels hold the same value; the latter
// presumably sized the grid in the (stripped) launch configuration. Launch errors are
// only printed, never raised to the caller.
template void ms_deformable_im2col_cuda(cudaStream_t stream,
    const scalar_t* data_value,
    const int64_t* data_spatial_shapes,
    const int64_t* data_level_start_index,
    const scalar_t* data_sampling_loc,
    const scalar_t* data_attn_weight,
    const int batch_size,
    const int spatial_size,
    const int num_heads,
    const int channels,
    const int num_levels,
    const int num_query,
    const int num_point,
    scalar_t* data_col)
{
  const int num_kernels = batch_size * num_query * num_heads * channels;
  const int num_actual_kernels = batch_size * num_query * num_heads * channels;
  const int num_threads = CUDA_NUM_THREADS;
  ms_deformable_im2col_gpu_kernel
      <<>>(
      num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight,
      batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col);

  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess)
  {
    printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
  }
}

// Host launcher for the backward pass. Threads per block = min(channels,
// CUDA_NUM_THREADS). The gradient-reduction kernel is chosen by `channels`:
//   > 1024, multiple of 1024 : shm_reduce_v2_multi_blocks (atomic per-block partials)
//   > 1024, otherwise        : gm (per-thread atomics, no shared memory)
//   1, 2, 4, 8, 16, 32       : shm_blocksize_aware_reduce_v1 (serial block sum)
//   64, 128, 256, 512, 1024  : shm_blocksize_aware_reduce_v2 (tree block sum)
//   other, < 64              : shm_reduce_v1 (dynamic shared memory, serial sum)
//   other, 64..1024          : shm_reduce_v2 (dynamic shared memory, tree sum)
// Launch errors are only printed, never raised to the caller.
template void ms_deformable_col2im_cuda(cudaStream_t stream,
    const scalar_t* grad_col,
    const scalar_t* data_value,
    const int64_t * data_spatial_shapes,
    const int64_t * data_level_start_index,
    const scalar_t * data_sampling_loc,
    const scalar_t * data_attn_weight,
    const int batch_size,
    const int spatial_size,
    const int num_heads,
    const int channels,
    const int num_levels,
    const int num_query,
    const int num_point,
    scalar_t* grad_value,
    scalar_t* grad_sampling_loc,
    scalar_t* grad_attn_weight)
{
  const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels;
  const int num_kernels = batch_size * num_query * num_heads * channels;
  const int num_actual_kernels = batch_size * num_query * num_heads * channels;
  if (channels > 1024)
  {
    if ((channels & 1023) == 0)
    {
      ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks
          <<>>(
              num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
              data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads, channels,
              num_levels, num_query, num_point, grad_value, grad_sampling_loc, grad_attn_weight);
    }
    else
    {
      ms_deformable_col2im_gpu_kernel_gm
          <<>>(
              num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
              data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads, channels,
              num_levels, num_query, num_point, grad_value, grad_sampling_loc, grad_attn_weight);
    }
  }
  else{
    switch(channels)
    {
      case 1:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
            <<>>(
                num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
                data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads, channels,
                num_levels, num_query, num_point, grad_value, grad_sampling_loc, grad_attn_weight);
        break;
      case 2:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
            <<>>(
                num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
                data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads, channels,
                num_levels, num_query, num_point, grad_value, grad_sampling_loc, grad_attn_weight);
        break;
      case 4:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
            <<>>(
                num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
                data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads, channels,
                num_levels, num_query, num_point, grad_value, grad_sampling_loc, grad_attn_weight);
        break;
      case 8:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
            <<>>(
                num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
                data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads, channels,
                num_levels, num_query, num_point, grad_value, grad_sampling_loc, grad_attn_weight);
        break;
      case 16:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
            <<>>(
                num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
                data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads, channels,
                num_levels, num_query, num_point, grad_value, grad_sampling_loc, grad_attn_weight);
        break;
      case 32:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
            <<>>(
                num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
                data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads, channels,
                num_levels, num_query, num_point, grad_value, grad_sampling_loc, grad_attn_weight);
        break;
      case 64:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2
            <<>>(
                num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
                data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads, channels,
                num_levels, num_query, num_point, grad_value, grad_sampling_loc, grad_attn_weight);
        break;
      case 128:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2
            <<>>(
                num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
                data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads, channels,
                num_levels, num_query, num_point, grad_value, grad_sampling_loc, grad_attn_weight);
        break;
      case 256:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2
            <<>>(
                num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
                data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads, channels,
                num_levels, num_query, num_point, grad_value, grad_sampling_loc, grad_attn_weight);
        break;
      case 512:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2
            <<>>(
                num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
                data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads, channels,
                num_levels, num_query, num_point, grad_value, grad_sampling_loc, grad_attn_weight);
        break;
      case 1024:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2
            <<>>(
                num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
                data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads, channels,
                num_levels, num_query, num_point, grad_value, grad_sampling_loc, grad_attn_weight);
        break;
      default:
        if (channels < 64)
        {
          ms_deformable_col2im_gpu_kernel_shm_reduce_v1
              <<>>(
                  num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
                  data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads, channels,
                  num_levels, num_query, num_point, grad_value, grad_sampling_loc, grad_attn_weight);
        }
        else
        {
          ms_deformable_col2im_gpu_kernel_shm_reduce_v2
              <<>>(
                  num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
                  data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads, channels,
                  num_levels, num_query, num_point, grad_value, grad_sampling_loc, grad_attn_weight);
        }
    }
  }
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess)
  {
    printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
  }
}

================================================
FILE: llava/model/semsam/body/encoder/ops/src/ms_deform_attn.h
================================================
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/

/*!
 * Copyright (c) Facebook, Inc. and its affiliates.
 * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
 */

#pragma once

#include "cpu/ms_deform_attn_cpu.h"

#ifdef WITH_CUDA
#include "cuda/ms_deform_attn_cuda.h"
#endif

// Forward pass of multi-scale deformable attention.
// Dispatches to ms_deform_attn_cuda_forward when `value` is a CUDA tensor and
// the extension was built with WITH_CUDA; a CUDA tensor without WITH_CUDA, or
// any non-CUDA tensor, raises via AT_ERROR (no CPU path is wired up here).
// NOTE(review): Tensor::type() is deprecated in recent PyTorch releases;
// value.is_cuda() is the modern equivalent — consider switching.
at::Tensor
ms_deform_attn_forward(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const int im2col_step)
{
    if (value.type().is_cuda())
    {
#ifdef WITH_CUDA
        return ms_deform_attn_cuda_forward(
            value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
#else
        AT_ERROR("Not compiled with GPU support");
#endif
    }
    AT_ERROR("Not implemented on the CPU");
}

// Backward pass: same CUDA-only dispatch as ms_deform_attn_forward, passing
// grad_output through to ms_deform_attn_cuda_backward.
// NOTE(review): the std::vector element type appears to have been stripped in
// this dump (presumably std::vector<at::Tensor>) — confirm against the source tree.
std::vector
ms_deform_attn_backward(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const at::Tensor &grad_output,
    const int im2col_step)
{
    if (value.type().is_cuda())
    {
#ifdef WITH_CUDA
        return ms_deform_attn_cuda_backward(
            value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
#else
        AT_ERROR("Not compiled with GPU support");
#endif
    }
    AT_ERROR("Not implemented on the CPU");
}
================================================ FILE: llava/model/semsam/body/encoder/ops/src/vision.cpp ================================================
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/

/*!
 * Copyright (c) Facebook, Inc. and its affiliates.
 * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
 */

#include "ms_deform_attn.h"

// Python bindings: exposes the forward/backward dispatchers declared in
// ms_deform_attn.h in the extension module named TORCH_EXTENSION_NAME.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward");
  m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward");
}
================================================ FILE: llava/model/semsam/body/encoder/ops/test.py ================================================
# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import time
import torch
import torch.nn as nn
from torch.autograd import gradcheck

from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch


# Small problem sizes for the checks below. From how the tensors are built,
# value is (N, S, M, D) and sampling locations are (N, Lq, M, L, P, 2).
# Presumably N=batch, M=heads, D=channels, Lq=queries, L=levels, P=points per
# level — confirm against functions/ms_deform_attn_func.
N, M, D = 1, 2, 2
Lq, L, P = 2, 2, 2
# One (H, W) row per level.
shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
# Start offset of each level in the flattened spatial axis: 0 followed by the
# running sum of H*W, dropping the final total.
level_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1]))
# Total number of flattened spatial positions over all levels.
S = sum([(H*W).item() for H, W in shapes])


torch.manual_seed(3)


@torch.no_grad()
def check_forward_equal_with_pytorch_double():
    """Compare the CUDA op with the pure-PyTorch reference in float64.

    Prints whether the outputs are allclose (default tolerances) together with
    the maximum absolute and relative errors.
    """
    value = torch.rand(N, S, M, D).cuda() * 0.01
    sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
    attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
    # Normalize so the weights sum to 1 over the last two axes (L and P).
    attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
    im2col_step = 2
    output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu()
    output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu()
    fwdok = torch.allclose(output_cuda, output_pytorch)
    max_abs_err = (output_cuda - output_pytorch).abs().max()
    # NOTE(review): divides by |output_pytorch|; an exactly-zero reference value yields inf/nan.
    max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()

    print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')


@torch.no_grad()
def check_forward_equal_with_pytorch_float():
    """Compare the CUDA op with the pure-PyTorch reference in float32.

    Uses looser tolerances (rtol=1e-2, atol=1e-3) than the float64 check and
    prints the result with the maximum absolute and relative errors.
    """
    value = torch.rand(N, S, M, D).cuda() * 0.01
    sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
    attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
    attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
    im2col_step = 2
    output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations,
                                                 attention_weights).detach().cpu()
    output_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step).detach().cpu()
    fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3)
    max_abs_err = (output_cuda - output_pytorch).abs().max()
    max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()

    print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')


def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True):
    """Run torch.autograd.gradcheck on the CUDA op with ``channels`` channels.

    The ``grad_*`` flags set ``requires_grad`` on the corresponding float32
    tensors; gradcheck itself receives float64 copies made with ``.double()``.
    Prints the gradcheck result.
    """
    value = torch.rand(N, S, M, channels).cuda() * 0.01
    sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
    attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
    attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
    im2col_step = 2
    func = MSDeformAttnFunction.apply

    value.requires_grad = grad_value
    sampling_locations.requires_grad = grad_sampling_loc
    attention_weights.requires_grad = grad_attn_weight

    gradok = gradcheck(func, (value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step))

    print(f'* {gradok} check_gradient_numerical(D={channels})')


if __name__ == '__main__':
    check_forward_equal_with_pytorch_double()
    check_forward_equal_with_pytorch_float()

    # Channel counts presumably chosen to hit different size-dependent branches
    # of the CUDA backward kernel dispatch — confirm in ms_deform_im2col_cuda.
    for channels in [30, 32, 64, 71, 1025, 2048, 3096]:
        check_gradient_numerical(channels, True, True, True)
================================================ FILE: llava/model/semsam/body/encoder/registry.py ================================================
# Registry: last component of the defining module's name -> builder function.
_model_entrypoints = {}


def register_encoder(fn):
    """Register ``fn`` under the last dotted component of ``fn.__module__``.

    Returns ``fn`` unchanged so this works as a decorator.
    """
    module_name_split = fn.__module__.split('.')
    model_name = module_name_split[-1]
    _model_entrypoints[model_name] = fn
    return fn


def model_entrypoints(model_name):
    """Return the builder registered under ``model_name`` (KeyError if absent)."""
    return _model_entrypoints[model_name]


def is_model(model_name):
    """Return True if a builder is registered under ``model_name``."""
    return model_name in _model_entrypoints
================================================ FILE:
llava/model/semsam/body/encoder/transformer_encoder_fpn.py ================================================
# Copyright (c) Facebook, Inc. and its affiliates.
import logging
import numpy as np
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_
from torch.cuda.amp import autocast

import fvcore.nn.weight_init as weight_init
from detectron2.layers import Conv2d, DeformConv, ShapeSpec, get_norm

from .registry import register_encoder
from ..transformer_blocks import TransformerEncoder, TransformerEncoderLayer, _get_clones, _get_activation_fn
from ...modules import PositionEmbeddingSine
from ...utils import configurable


# This is a modified FPN decoder.
class BasePixelDecoder(nn.Module):
    def __init__(
        self,
        input_shape: Dict[str, ShapeSpec],
        *,
        conv_dim: int,
        mask_dim: int,
        mask_on: bool,
        norm: Optional[Union[str, Callable]] = None,
    ):
        """
        NOTE: this interface is experimental.
        Args:
            input_shape: shapes (channels and stride) of the input features
            conv_dim: number of output channels for the intermediate conv layers.
            mask_dim: number of output channels for the final conv layer.
            mask_on: if True, also build the ``mask_features`` conv.
            norm (str or callable): normalization for all conv layers
        """
        super().__init__()

        # Sorted by ascending stride, so the last entry is the coarsest level.
        input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
        self.in_features = [k for k, v in input_shape]  # starting from "res2" to "res5"
        feature_channels = [v.channels for k, v in input_shape]

        lateral_convs = []
        output_convs = []

        # Bias is enabled only when norm is the empty string; norm=None
        # (the default) therefore also gives use_bias=False.
        use_bias = norm == ""
        for idx, in_channels in enumerate(feature_channels):
            if idx == len(self.in_features) - 1:
                # Largest-stride level: no lateral conv, only a 3x3 output conv
                # applied directly to the backbone feature.
                output_norm = get_norm(norm, conv_dim)
                output_conv = Conv2d(
                    in_channels,
                    conv_dim,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias=use_bias,
                    norm=output_norm,
                    activation=F.relu,
                )
                weight_init.c2_xavier_fill(output_conv)
                self.add_module("layer_{}".format(idx + 1), output_conv)

                lateral_convs.append(None)
                output_convs.append(output_conv)
            else:
                # Other levels: 1x1 lateral conv to conv_dim, then a 3x3 output conv.
                lateral_norm = get_norm(norm, conv_dim)
                output_norm = get_norm(norm, conv_dim)

                lateral_conv = Conv2d(
                    in_channels, conv_dim, kernel_size=1, bias=use_bias, norm=lateral_norm
                )
                output_conv = Conv2d(
                    conv_dim,
                    conv_dim,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias=use_bias,
                    norm=output_norm,
                    activation=F.relu,
                )
                weight_init.c2_xavier_fill(lateral_conv)
                weight_init.c2_xavier_fill(output_conv)
                self.add_module("adapter_{}".format(idx + 1), lateral_conv)
                self.add_module("layer_{}".format(idx + 1), output_conv)

                lateral_convs.append(lateral_conv)
                output_convs.append(output_conv)
        # Place convs into top-down order (from low to high resolution)
        # to make the top-down computation in forward clearer.
 # After reversal, index 0 is the coarsest level (the one whose lateral conv is None).
        self.lateral_convs = lateral_convs[::-1]
        self.output_convs = output_convs[::-1]

        self.mask_on = mask_on
        if self.mask_on:
            self.mask_dim = mask_dim
            self.mask_features = Conv2d(
                conv_dim,
                mask_dim,
                kernel_size=3,
                stride=1,
                padding=1,
            )
            weight_init.c2_xavier_fill(self.mask_features)

        # Cap on how many FPN outputs forward_features collects into multi_scale_features.
        self.maskformer_num_feature_levels = 3  # always use 3 scales

    @classmethod
    def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
        """Build constructor kwargs from ``cfg['MODEL']['ENCODER']``.

        Only the features named in ``IN_FEATURES`` are kept from ``input_shape``.
        NOTE(review): ``mask_on`` (required by ``__init__``) is not set here;
        TransformerEncoderPixelDecoder.from_config adds it.
        """
        enc_cfg = cfg['MODEL']['ENCODER']
        ret = {}
        ret["input_shape"] = {
            k: v for k, v in input_shape.items() if k in enc_cfg['IN_FEATURES']
        }
        ret["conv_dim"] = enc_cfg['CONVS_DIM']
        ret["mask_dim"] = enc_cfg['MASK_DIM']
        ret["norm"] = enc_cfg['NORM']
        return ret

    def forward_features(self, features):
        """Run the top-down FPN pass over ``features`` (indexed by the names in in_features).

        Returns:
            ``(mask_features, None, multi_scale_features)``. ``mask_features`` is
            None when ``mask_on`` is False. ``multi_scale_features`` holds at most
            ``maskformer_num_feature_levels`` outputs, lowest resolution first.
        """
        multi_scale_features = []
        num_cur_levels = 0
        # Reverse feature maps into top-down order (from low to high resolution)
        for idx, f in enumerate(self.in_features[::-1]):
            x = features[f]
            lateral_conv = self.lateral_convs[idx]
            output_conv = self.output_convs[idx]
            if lateral_conv is None:
                # Only idx 0 has no lateral conv, so `y` is always bound before
                # the else-branch upsamples it.
                y = output_conv(x)
            else:
                cur_fpn = lateral_conv(x)
                # Following FPN implementation, we use nearest upsampling here
                y = cur_fpn + F.interpolate(y, size=cur_fpn.shape[-2:], mode="nearest")
                y = output_conv(y)
            if num_cur_levels < self.maskformer_num_feature_levels:
                multi_scale_features.append(y)
                num_cur_levels += 1

        # Mask features come from the last (highest-resolution) FPN output.
        mask_features = self.mask_features(y) if self.mask_on else None
        return mask_features, None, multi_scale_features

    def forward(self, features, targets=None):
        """Log a warning and delegate to ``forward_features``; ``targets`` is unused."""
        logger = logging.getLogger(__name__)
        logger.warning("Calling forward() may cause unpredicted behavior of PixelDecoder module.")
        return self.forward_features(features)


class TransformerEncoderOnly(nn.Module):
    """Transformer encoder stack applied to a flattened (N, C, H, W) feature map."""

    def __init__(
        self,
        d_model=512,
        nhead=8,
        num_encoder_layers=6,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=False,
    ):
        super().__init__()

        encoder_layer = TransformerEncoderLayer(
            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
        )
        # A final LayerNorm is only added for the pre-norm variant.
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead

    def _reset_parameters(self):
        # Xavier-uniform init for every parameter with more than one dimension;
        # 1-D parameters keep their default initialization.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, src, mask, pos_embed):
        """Encode ``src`` and return the result reshaped back to ``(bs, c, h, w)``.

        ``pos_embed`` is flattened the same way as ``src``. ``mask``, if given,
        is flattened from dim 1 and used as the key padding mask (presumably
        ``(bs, h, w)`` — confirm with callers).
        """
        # flatten NxCxHxW to HWxNxC
        bs, c, h, w = src.shape
        src = src.flatten(2).permute(2, 0, 1)
        pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
        if mask is not None:
            mask = mask.flatten(1)

        memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
        return memory.permute(1, 2, 0).view(bs, c, h, w)


# This is a modified FPN decoder with extra Transformer encoder that processes the lowest-resolution feature map.
class TransformerEncoderPixelDecoder(BasePixelDecoder):
    @configurable
    def __init__(
        self,
        input_shape: Dict[str, ShapeSpec],
        *,
        transformer_dropout: float,
        transformer_nheads: int,
        transformer_dim_feedforward: int,
        transformer_enc_layers: int,
        transformer_pre_norm: bool,
        conv_dim: int,
        mask_dim: int,
        mask_on: int,
        norm: Optional[Union[str, Callable]] = None,
    ):
        """
        NOTE: this interface is experimental.
        Args:
            input_shape: shapes (channels and stride) of the input features
            transformer_dropout: dropout probability in transformer
            transformer_nheads: number of heads in transformer
            transformer_dim_feedforward: dimension of feedforward network
            transformer_enc_layers: number of transformer encoder layers
            transformer_pre_norm: whether to use pre-layernorm or not
            conv_dim: number of output channels for the intermediate conv layers.
            mask_dim: number of output channels for the final conv layer.
            mask_on: whether to build the mask_features conv (used as a boolean).
            norm (str or callable): normalization for all conv layers
        """
        super().__init__(input_shape, conv_dim=conv_dim, mask_dim=mask_dim, norm=norm, mask_on=mask_on)

        input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
        self.in_features = [k for k, v in input_shape]  # starting from "res2" to "res5"
        feature_strides = [v.stride for k, v in input_shape]  # NOTE(review): unused
        feature_channels = [v.channels for k, v in input_shape]

        # Channels of the largest-stride (coarsest) feature, which feeds the transformer.
        in_channels = feature_channels[len(self.in_features) - 1]
        self.input_proj = Conv2d(in_channels, conv_dim, kernel_size=1)
        weight_init.c2_xavier_fill(self.input_proj)
        self.transformer = TransformerEncoderOnly(
            d_model=conv_dim,
            dropout=transformer_dropout,
            nhead=transformer_nheads,
            dim_feedforward=transformer_dim_feedforward,
            num_encoder_layers=transformer_enc_layers,
            normalize_before=transformer_pre_norm,
        )
        # Presumably half of conv_dim per spatial axis — confirm in PositionEmbeddingSine.
        N_steps = conv_dim // 2
        self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)

        # update layer
        use_bias = norm == ""
        output_norm = get_norm(norm, conv_dim)
        output_conv = Conv2d(
            conv_dim,
            conv_dim,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=use_bias,
            norm=output_norm,
            activation=F.relu,
        )
        weight_init.c2_xavier_fill(output_conv)
        # Replace the coarsest level's output conv registered by BasePixelDecoder
        # (module "layer_<number of levels>") and keep output_convs[0] in sync.
        delattr(self, "layer_{}".format(len(self.in_features)))
        self.add_module("layer_{}".format(len(self.in_features)), output_conv)
        self.output_convs[0] = output_conv

    @classmethod
    def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
        """Extend BasePixelDecoder.from_config with transformer and mask_on settings.

        Transformer hyper-parameters come from ``cfg['MODEL']['DECODER']``,
        except the layer count, which is ``ENCODER.TRANSFORMER_ENC_LAYERS``.
        """
        enc_cfg = cfg['MODEL']['ENCODER']
        dec_cfg = cfg['MODEL']['DECODER']
        ret = super().from_config(cfg, input_shape)
        ret["transformer_dropout"] = dec_cfg['DROPOUT']
        ret["transformer_nheads"] = dec_cfg['NHEADS']
        ret["transformer_dim_feedforward"] = dec_cfg['DIM_FEEDFORWARD']
        ret["transformer_enc_layers"] = enc_cfg['TRANSFORMER_ENC_LAYERS']  # a separate config
        ret["transformer_pre_norm"] = dec_cfg['PRE_NORM']

        ret['mask_on'] = cfg['MODEL']['DECODER']['MASK']
        return ret

    def forward_features(self, features):
        """Top-down FPN pass whose coarsest level goes through the Transformer encoder.

        Returns ``(mask_features, transformer_encoder_features, multi_scale_features)``.
        NOTE(review): MaskDINOHead.layers in openseed_head.py calls
        ``forward_features(features, mask)``; this signature takes no ``mask``,
        so confirm this decoder is never used on that path.
        """
        multi_scale_features = []
        num_cur_levels = 0
        # Reverse feature maps into top-down order (from low
to high resolution) for idx, f in enumerate(self.in_features[::-1]): x = features[f] lateral_conv = self.lateral_convs[idx] output_conv = self.output_convs[idx] if lateral_conv is None: transformer = self.input_proj(x) pos = self.pe_layer(x) transformer = self.transformer(transformer, None, pos) y = output_conv(transformer) # save intermediate feature as input to Transformer decoder transformer_encoder_features = transformer else: cur_fpn = lateral_conv(x) # Following FPN implementation, we use nearest upsampling here y = cur_fpn + F.interpolate(y, size=cur_fpn.shape[-2:], mode="nearest") y = output_conv(y) if num_cur_levels < self.maskformer_num_feature_levels: multi_scale_features.append(y) num_cur_levels += 1 mask_features = self.mask_features(y) if self.mask_on else None return mask_features, transformer_encoder_features, multi_scale_features def forward(self, features, targets=None): logger = logging.getLogger(__name__) logger.warning("Calling forward() may cause unpredicted behavior of PixelDecoder module.") return self.forward_features(features) @register_encoder def get_transformer_encoder_fpn(cfg, input_shape): """ Build a pixel decoder from `cfg.MODEL.MASK_FORMER.PIXEL_DECODER_NAME`. """ model = TransformerEncoderPixelDecoder(cfg, input_shape) forward_features = getattr(model, "forward_features", None) if not callable(forward_features): raise ValueError( "Only SEM_SEG_HEADS with forward_features method can be used as pixel decoder. " f"Please implement forward_features for {name} to only return mask features." ) return model ================================================ FILE: llava/model/semsam/body/openseed_head.py ================================================ # ------------------------------------------------------------------------ # Copyright (c) 2022 IDEA. All Rights Reserved. 
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from Mask2Former https://github.com/facebookresearch/Mask2Former by Feng Li and Hao Zhang.
# ------------------------------------------------------------------------------
import logging
from typing import Callable, Dict, List, Optional, Tuple, Union

from torch import nn

from detectron2.layers import Conv2d, ShapeSpec, get_norm
from detectron2.modeling import SEM_SEG_HEADS_REGISTRY

from .registry import register_body
from .encoder import build_encoder
from .decoder import build_decoder
from ..utils import configurable


class MaskDINOHead(nn.Module):
    """Head that runs a pixel decoder followed by a transformer predictor."""

    @configurable
    def __init__(
        self,
        input_shape: Dict[str, ShapeSpec],
        *,
        num_classes: int,
        pixel_decoder: nn.Module,
        loss_weight: float = 1.0,
        ignore_value: int = -1,
        transformer_predictor: nn.Module,
    ):
        """
        Args:
            input_shape: shapes (channels and stride) of the input features
            num_classes: number of classes to predict
            pixel_decoder: the pixel decoder module
            loss_weight: loss weight
            ignore_value: category id to be ignored during training.
            transformer_predictor: the transformer decoder that makes prediction
        """
        super().__init__()
        input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
        self.in_features = [k for k, v in input_shape]
        self.ignore_value = ignore_value
        self.common_stride = 4
        self.loss_weight = loss_weight

        self.pixel_decoder = pixel_decoder
        self.predictor = transformer_predictor

        self.num_classes = num_classes

    @classmethod
    def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec], lang_encoder: nn.Module, extra: dict):
        """Build constructor kwargs from ``cfg['MODEL']['ENCODER']``.

        The pixel decoder comes from ``build_encoder``. The predictor comes from
        ``build_decoder``, called with ``CONVS_DIM`` input channels and ``lang_encoder``.
        """
        enc_cfg = cfg['MODEL']['ENCODER']
        dec_cfg = cfg['MODEL']['DECODER']
        transformer_predictor_in_channels = enc_cfg['CONVS_DIM']

        return {
            "input_shape": {
                k: v for k, v in input_shape.items() if k in enc_cfg['IN_FEATURES']
            },
            "ignore_value": enc_cfg['IGNORE_VALUE'],
            "num_classes": enc_cfg.get('NUM_CLASSES', None),
            "pixel_decoder": build_encoder(cfg, input_shape),
            "loss_weight": enc_cfg['LOSS_WEIGHT'],
            "transformer_predictor": build_decoder(
                cfg,
                transformer_predictor_in_channels,
                lang_encoder,
                mask_classification=True,
                extra=extra,
            ),
        }

    def forward(self, features, mask=None, targets=None, target_queries=None, target_vlp=None, task='seg', extra={}):
        """Forward to :meth:`layers`; ``prediction_switch`` is left at its default."""
        # NOTE(review): `extra={}` is a shared mutable default; safe only while
        # nothing downstream mutates it.
        return self.layers(features, mask, targets=targets, target_queries=target_queries, target_vlp=target_vlp, task=task, extra=extra)

    def layers(self, features, mask=None,targets=None, target_queries=None, target_vlp=None, prediction_switch=None, task='seg', extra={}):
        """Run the pixel decoder, then the predictor, and return the predictions.

        ``task == 'teacher'`` routes to ``predictor.forward_teacher``; any other
        task calls the predictor directly. ``prediction_switch`` is not used here.
        """
        mask_features, transformer_encoder_features, multi_scale_features = self.pixel_decoder.forward_features(features, mask)
        if task == 'teacher':
            predictions = self.predictor.forward_teacher(multi_scale_features, mask_features, mask, targets=targets,
                                                         target_queries=target_queries, target_vlp=target_vlp,
                                                         task=task, extra=extra)
        else:
            predictions = self.predictor(multi_scale_features, mask_features, mask, targets=targets,
                                         target_queries=target_queries, target_vlp=target_vlp, task=task,
                                         extra=extra)
        return predictions


@register_body
def get_maskdino_head(cfg, input_shape, lang_encoder, extra):
    """Registered body builder (key: last component of this module's name) for MaskDINOHead."""
    return MaskDINOHead(cfg, input_shape, lang_encoder, extra)
================================================ FILE: llava/model/semsam/body/registry.py ================================================
# Registry: last component of the defining module's name -> builder function.
_model_entrypoints = {}


def register_body(fn):
    """Register ``fn`` under the last dotted component of ``fn.__module__``.

    Returns ``fn`` unchanged so this works as a decorator.
    """
    module_name_split = fn.__module__.split('.')
    model_name = module_name_split[-1]
    _model_entrypoints[model_name] = fn
    return fn


def model_entrypoints(model_name):
    """Return the builder registered under ``model_name`` (KeyError if absent)."""
    return _model_entrypoints[model_name]


def is_model(model_name):
    """Return True if a builder is registered under ``model_name``."""
    return model_name in _model_entrypoints
================================================ FILE: llava/model/semsam/body/transformer_blocks.py ================================================
# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/transformer.py
"""
Transformer class.

Copy-paste from torch.nn.Transformer with modifications:
    * positional encodings are passed in MHattention
    * extra LN at the end of encoder is removed
    * decoder returns a stack of activations from all decoding layers
"""
import copy
from typing import List, Optional

import torch
import torch.nn.functional as F
from torch import Tensor, nn


class Transformer(nn.Module):
    """Encoder-decoder transformer built from the layer classes in this module."""

    def __init__(
        self,
        d_model=512,
        nhead=8,
        num_encoder_layers=6,
        num_decoder_layers=6,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=False,
        return_intermediate_dec=False,
    ):
        super().__init__()

        encoder_layer = TransformerEncoderLayer(
            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
        )
        # Final encoder LayerNorm only for the pre-norm variant.
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

        decoder_layer = TransformerDecoderLayer(
            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
        )
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = TransformerDecoder(
# (continued) remaining arguments of the TransformerDecoder(...) call in Transformer.__init__
            decoder_layer,
            num_decoder_layers,
            decoder_norm,
            return_intermediate=return_intermediate_dec,
        )

        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead

    def _reset_parameters(self):
        # Xavier-uniform init for parameters with more than one dimension.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, src, mask, query_embed, pos_embed):
        """Encode ``src`` and decode ``query_embed`` against the result.

        Returns ``(hs.transpose(1, 2), memory)``, with memory reshaped back to
        ``(bs, c, h, w)``.
        """
        # flatten NxCxHxW to HWxNxC
        bs, c, h, w = src.shape
        src = src.flatten(2).permute(2, 0, 1)
        pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
        # Repeat queries across the batch (presumably (Q, d) -> (Q, bs, d)).
        query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
        if mask is not None:
            mask = mask.flatten(1)

        # Decoder targets start at zero; the queries enter only via query_pos.
        tgt = torch.zeros_like(query_embed)
        memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
        hs = self.decoder(
            tgt, memory, memory_key_padding_mask=mask, pos=pos_embed, query_pos=query_embed
        )
        return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w)


class TransformerEncoder(nn.Module):
    """Stack of ``num_layers`` deep-copied encoder layers plus an optional final norm."""

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(
        self,
        src,
        mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        output = src

        for layer in self.layers:
            output = layer(
                output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pos=pos
            )

        if self.norm is not None:
            output = self.norm(output)

        return output


class TransformerDecoder(nn.Module):
    """Stack of ``num_layers`` deep-copied decoder layers.

    The output always has a leading layer axis. With ``return_intermediate`` it
    holds every layer's normalized output; otherwise it holds only the final output.
    """

    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate

    def forward(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        output = tgt

        intermediate = []

        for layer in self.layers:
            output = layer(
                output,
                memory,
                tgt_mask=tgt_mask,
                memory_mask=memory_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask,
                pos=pos,
                query_pos=query_pos,
            )
            if self.return_intermediate:
                # NOTE(review): assumes self.norm is not None whenever return_intermediate is set.
                intermediate.append(self.norm(output))

        if self.norm is not None:
            output = self.norm(output)
            if self.return_intermediate:
                # Swap the last entry for the final output (the same normalized value).
                intermediate.pop()
                intermediate.append(output)

        if self.return_intermediate:
            return torch.stack(intermediate)

        return output.unsqueeze(0)


class TransformerEncoderLayer(nn.Module):
    """Self-attention plus feed-forward block.

    ``normalize_before`` selects pre-norm (``forward_pre``) or post-norm (``forward_post``).
    """

    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=False,
    ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        # Positional encodings are added to queries/keys only; values use the raw input.
        return tensor if pos is None else tensor + pos

    def forward_post(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        q = k = self.with_pos_embed(src, pos)
        src2 = self.self_attn(
            q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
        )[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

    def forward_pre(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        src2 = self.norm1(src)
        q = k = self.with_pos_embed(src2, pos)
        src2 = self.self_attn(
            q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
        )[0]
        src = src + self.dropout1(src2)
        src2 = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
        src = src + self.dropout2(src2)
        return src

    def forward(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        if self.normalize_before:
            return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
        return self.forward_post(src, src_mask, src_key_padding_mask, pos)


class TransformerDecoderLayer(nn.Module):
    """Self-attention, cross-attention over ``memory``, then feed-forward.

    ``normalize_before`` selects pre-norm (``forward_pre``) or post-norm (``forward_post``).
    """

    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=False,
    ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        q = k = self.with_pos_embed(tgt, query_pos)
        tgt2 = self.self_attn(
            q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
        )[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        tgt2 = self.multihead_attn(
            query=self.with_pos_embed(tgt, query_pos),
            key=self.with_pos_embed(memory,
pos), value=memory, attn_mask=memory_mask, key_padding_mask=memory_key_padding_mask, )[0] tgt = tgt + self.dropout2(tgt2) tgt = self.norm2(tgt) tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) tgt = tgt + self.dropout3(tgt2) tgt = self.norm3(tgt) return tgt def forward_pre( self, tgt, memory, tgt_mask: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None, query_pos: Optional[Tensor] = None, ): tgt2 = self.norm1(tgt) q = k = self.with_pos_embed(tgt2, query_pos) tgt2 = self.self_attn( q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask )[0] tgt = tgt + self.dropout1(tgt2) tgt2 = self.norm2(tgt) tgt2 = self.multihead_attn( query=self.with_pos_embed(tgt2, query_pos), key=self.with_pos_embed(memory, pos), value=memory, attn_mask=memory_mask, key_padding_mask=memory_key_padding_mask, )[0] tgt = tgt + self.dropout2(tgt2) tgt2 = self.norm3(tgt) tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) tgt = tgt + self.dropout3(tgt2) return tgt def forward( self, tgt, memory, tgt_mask: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None, query_pos: Optional[Tensor] = None, ): if self.normalize_before: return self.forward_pre( tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos, ) return self.forward_post( tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos, ) def _get_clones(module, N): return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) def _get_activation_fn(activation): """Return an activation function given a string""" if activation == "relu": return F.relu if activation == "gelu": return F.gelu if activation == "glu": return F.glu raise 
RuntimeError(f"activation should be relu/gelu, not {activation}.") ================================================ FILE: llava/model/semsam/language/LangEncoder/__init__.py ================================================ from __future__ import absolute_import from __future__ import division from __future__ import print_function from .build import build_lang_encoder from .build import build_tokenizer from .transformer import * ================================================ FILE: llava/model/semsam/language/LangEncoder/build.py ================================================ import os from transformers import CLIPTokenizer, CLIPTokenizerFast from transformers import AutoTokenizer from .registry import lang_encoders from .registry import is_lang_encoder def build_lang_encoder(config_encoder, tokenizer, verbose, **kwargs): model_name = config_encoder['NAME'] if not is_lang_encoder(model_name): raise ValueError(f'Unkown model: {model_name}') return lang_encoders(model_name)(config_encoder, tokenizer, verbose, **kwargs) def build_tokenizer(config_encoder): tokenizer = None os.environ['TOKENIZERS_PARALLELISM'] = 'true' if config_encoder['TOKENIZER'] == 'clip': pretrained_tokenizer = config_encoder.get( 'PRETRAINED_TOKENIZER', 'openai/clip-vit-base-patch32' ) tokenizer = CLIPTokenizer.from_pretrained(pretrained_tokenizer) tokenizer.add_special_tokens({'cls_token': tokenizer.eos_token}) elif config_encoder['TOKENIZER'] == 'clip-fast': pretrained_tokenizer = config_encoder.get( 'PRETRAINED_TOKENIZER', 'openai/clip-vit-base-patch32' ) tokenizer = CLIPTokenizerFast.from_pretrained(pretrained_tokenizer, from_slow=True) else: tokenizer = AutoTokenizer.from_pretrained(config_encoder['TOKENIZER']) return tokenizer ================================================ FILE: llava/model/semsam/language/LangEncoder/registry.py ================================================ _lang_encoders = {} def register_lang_encoder(fn): module_name_split = fn.__module__.split('.') model_name = 
module_name_split[-1] _lang_encoders[model_name] = fn return fn def lang_encoders(model_name): return _lang_encoders[model_name] def is_lang_encoder(model_name): return model_name in _lang_encoders ================================================ FILE: llava/model/semsam/language/LangEncoder/transformer.py ================================================ from collections import OrderedDict from typing import Tuple, Union import logging import os import numpy as np import torch import torch.nn.functional as F from torch import nn from timm.models.layers import DropPath, trunc_normal_ from .registry import register_lang_encoder from detectron2.utils.comm import is_main_process from utils.model import register_norm_module logger = logging.getLogger(__name__) @register_norm_module class LayerNorm(nn.Module): def __init__(self, hidden_size, eps=1e-12): """Construct a layernorm module in the TF style (epsilon inside the square root). """ super(LayerNorm, self).__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.bias = nn.Parameter(torch.zeros(hidden_size)) self.variance_epsilon = eps def forward(self, x): pdtype = x.dtype x = x.float() u = x.mean(-1, keepdim=True) s = (x - u).pow(2).mean(-1, keepdim=True) x = (x - u) / torch.sqrt(s + self.variance_epsilon) return self.weight * x.to(pdtype) + self.bias class QuickGELU(nn.Module): def forward(self, x: torch.Tensor): return x * torch.sigmoid(1.702 * x) class ResidualAttentionBlock(nn.Module): def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, drop_path: float = 0.0): super().__init__() self.attn = nn.MultiheadAttention(d_model, n_head) self.ln_1 = LayerNorm(d_model) self.mlp = nn.Sequential(OrderedDict([ ("c_fc", nn.Linear(d_model, d_model * 4)), ("gelu", QuickGELU()), ("c_proj", nn.Linear(d_model * 4, d_model)) ])) self.ln_2 = LayerNorm(d_model) self.attn_mask = attn_mask self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() def attention(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None): self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) \ if self.attn_mask is not None else None return self.attn( x, x, x, key_padding_mask=key_padding_mask, need_weights=False, attn_mask=self.attn_mask )[0] def forward(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None): x = x + self.drop_path(self.attention(self.ln_1(x), key_padding_mask=key_padding_mask)) x = x + self.drop_path(self.mlp(self.ln_2(x))) return x class Transformer(nn.Module): def __init__(self, context_length: int, vocab_size: int, width: int, layers: int, heads: int, drop_path: float = 0.0, autogressive: bool =True): super().__init__() self.token_embedding = nn.Embedding(vocab_size, width) self.context_length = context_length self.positional_embedding = nn.Parameter( torch.empty(self.context_length, width) ) self.width = width self.layers = layers self.autogressive = autogressive attn_mask = self.build_attention_mask() if autogressive else None dpr = [x.item() for x in torch.linspace(0, drop_path, layers)] # stochastic depth decay rule self.resblocks = nn.ModuleList( [ ResidualAttentionBlock(width, heads, attn_mask, dpr[i]) for i in range(layers) ] ) self.ln_final = LayerNorm(width) trunc_normal_(self.positional_embedding, std=.02) # nn.init.normal_(self.token_embedding, std=.02) trunc_normal_(self.token_embedding.weight, std=.02) self.apply(self._init_weights) @property def dim_out(self): return self.width def build_attention_mask(self): # lazily create causal attention mask, with full attention between the vision tokens # pytorch uses additive attention mask; fill with -inf mask = torch.empty(self.context_length, self.context_length) mask.fill_(float("-inf")) mask.triu_(1) # zero out the lower diagonal return mask def _init_weights(self, m): if isinstance(m, (nn.Linear, nn.Conv2d)): if is_main_process(): logger.info('=> init weight of Linear/Conv2d from trunc norm') 
trunc_normal_(m.weight, std=0.02) if m.bias is not None: if is_main_process(): logger.info('=> init bias of Linear/Conv2d to zeros') nn.init.constant_(m.bias, 0) elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)): nn.init.constant_(m.bias, 0) def load_pretrained(self, pretrained='', pretrained_layers=[], verbose=True): if os.path.isfile(pretrained): pretrained_dict = torch.load(pretrained, map_location='cpu') logging.info(f'=> loading pretrained model {pretrained}') model_dict = self.state_dict() stripped_key = lambda x: x[13:] if x.startswith('lang_encoder.') else x pretrained_dict = { stripped_key(k): v for k, v in pretrained_dict.items() if stripped_key(k) in model_dict.keys() } need_init_state_dict = {} for k, v in pretrained_dict.items(): need_init = ( k.split('.')[0] in pretrained_layers or pretrained_layers[0] == '*' ) if need_init: if verbose: logger.info(f'=> init {k} from {pretrained}') if 'positional_embedding' in k and v.size() != model_dict[k].size(): positional_embedding_pretrained = v positional_embedding_current = model_dict[k] L1, nH1 = positional_embedding_pretrained.size() L2, nH2 = positional_embedding_current.size() if nH1 != nH2: logger.info(f"Error in loading {k}, passing") else: if L1 != L2: logger.info( '=> load_pretrained: resized variant: {} to {}' .format((L1, nH1), (L2, nH2)) ) posemb = positional_embedding_pretrained.float() posemb_grid = posemb.unsqueeze(dim=0).permute(0, 2, 1) posemb_grid = torch.nn.functional.interpolate(posemb_grid, size=L2, mode='linear') posemb_grid = posemb_grid.permute(0, 2, 1).squeeze(dim=0) v = posemb_grid need_init_state_dict[k] = v self.load_state_dict(need_init_state_dict, strict=False) @torch.jit.ignore def no_weight_decay(self): return { 'positional_embedding', 'token_embedding', } def forward(self, input_ids, attention_mask=None): key_padding_mask = (attention_mask == 0) if (not self.autogressive and attention_mask is not None) else None # key_padding_mask = (input_ids == 0) if not self.autogressive 
else None x = self.token_embedding(input_ids) # [batch_size, n_ctx, d_model] x = x + self.positional_embedding x = x.permute(1, 0, 2) # NLD -> LND for block in self.resblocks: x = block(x, key_padding_mask) x = x.permute(1, 0, 2) # LND -> NLD x = self.ln_final(x) return {'last_hidden_state': x} @register_lang_encoder def lang_encoder(config_encoder, tokenizer, verbose, **kwargs): transformer = Transformer( context_length=config_encoder['CONTEXT_LENGTH'], vocab_size=tokenizer.vocab_size, width=config_encoder['WIDTH'], layers=config_encoder['LAYERS'], heads=config_encoder['HEADS'], autogressive=config_encoder.get('AUTOGRESSIVE', True) ) if config_encoder.get('LOAD_PRETRAINED', False): transformer.load_pretrained(config_encoder['PRETRAINED'], config_encoder.get('PRETRAINED_LAYERS', ['*'])) return transformer ================================================ FILE: llava/model/semsam/language/__init__.py ================================================ # from .vlpencoder import * # from .encoder import * # from .fixencoder import * # from .loss import * # from .modeling_llama_os import LlamaForCausalLM # # from .modeling_llama_os_lora import LlamaForCausalLMLora # from .llama_encoder import * # from .build import build_language_encoder ================================================ FILE: llava/model/semsam/language/build.py ================================================ from .registry import model_entrypoints from .registry import is_model def build_language_encoder(config, **kwargs): model_name = config['MODEL']['TEXT']['ARCH'] if model_name=='noencoder': return None if not is_model(model_name): raise ValueError(f'Unkown model: {model_name}') return model_entrypoints(model_name)(config, **kwargs) ================================================ FILE: llava/model/semsam/language/encoder.py ================================================ import torch from torch import nn from torch.nn import functional as F from timm.models.layers import trunc_normal_ from .registry 
import register_model
from ..utils import configurable
from .LangEncoder import build_tokenizer, build_lang_encoder
from utils.prompt_engineering import prompt_engineering, get_prompt_templates


class LanguageEncoder(nn.Module):
    """Text encoder: tokenizer -> ``lang_encoder`` -> linear projection ``lang_proj``.

    ``get_text_embeddings`` caches class-name embeddings on the instance under
    the attribute ``'<name>_text_embeddings'``; ``compute_similarity`` reads
    them back. In this copy the ``@torch.no_grad()`` decorators are commented
    out, so the training branch of ``get_text_embeddings`` and
    ``forward_language`` run with autograd enabled.
    """

    @configurable
    def __init__(
        self,
        tokenizer,
        tokenizer_type,
        lang_encoder,
        lang_projection,
        max_token_num,
    ):
        super().__init__()
        self.tokenizer = tokenizer
        self.tokenizer_type = tokenizer_type
        self.lang_encoder = lang_encoder
        self.lang_proj = lang_projection
        self.max_token_num = max_token_num
        # Learnable scalar; compute_similarity multiplies scores by its exp().
        self.logit_scale = nn.Parameter(torch.ones([]))

    @classmethod
    def from_config(cls, cfg):
        """Build constructor kwargs from ``cfg['MODEL']['TEXT']`` and ``cfg['MODEL']['DIM_PROJ']``."""
        # build up text encoder
        tokenizer = build_tokenizer(cfg['MODEL']['TEXT'])
        tokenizer_type = cfg['MODEL']['TEXT']['TOKENIZER']
        lang_encoder = build_lang_encoder(cfg['MODEL']['TEXT'], tokenizer, cfg['VERBOSE'])
        max_token_num = cfg['MODEL']['TEXT']['CONTEXT_LENGTH']

        dim_lang = cfg['MODEL']['TEXT']['WIDTH']
        dim_projection = cfg['MODEL']['DIM_PROJ']
        # (WIDTH, DIM_PROJ) matrix; forward_language applies it as x @ lang_proj.
        lang_projection = nn.Parameter(torch.empty(dim_lang, dim_projection))
        trunc_normal_(lang_projection, std=.02)

        return {
            "tokenizer": tokenizer,
            "tokenizer_type": tokenizer_type,
            "lang_encoder": lang_encoder,
            "lang_projection": lang_projection,
            "max_token_num": max_token_num,
        }

    # @torch.no_grad()
    def get_text_embeddings(self, class_names, name='default', is_eval=False, add_bgd=False, prompt=True, norm=True):
        """Encode ``class_names`` and store the result as ``self.<name>_text_embeddings``.

        Suffixes ``-other``, ``-merged`` and ``-stuff`` are stripped from names
        before they are wrapped in prompts.

        * ``is_eval=False``: with ``prompt`` each name goes through
          ``prompt_engineering``; without it, names are used verbatim. All
          texts are tokenized to ``max_token_num`` and encoded in one batch.
        * ``is_eval=True``: under ``torch.no_grad()``, each name is formatted
          into every template from ``get_prompt_templates()``. Those embeddings
          are averaged and the mean is L2-normalised.

        ``add_bgd`` appends an extra "A background in coco." entry (in the
        training branch only when ``prompt`` is True). Inputs are moved with
        ``.cuda()``, so a CUDA device is required.
        """
        if not is_eval:
            if prompt:
                # randomly sample one template
                arbitary_concepts = [
                    prompt_engineering(class_names[label].replace('-other','').replace('-merged','').replace('-stuff',''), topk=10000, suffix='.') \
                    for label in range(len(class_names))
                ]
                if add_bgd:
                    arbitary_concepts.append("A background in coco.")
            else:
                arbitary_concepts = class_names

            input_ids = []
            attention_masks = []
            for txt in arbitary_concepts:
                tokens = self.tokenizer(
                    txt, padding='max_length', truncation=True, max_length=self.max_token_num, return_tensors='pt'
                )
                # Drop the batch dim of 1 so the per-text rows can be stacked.
                tokens['input_ids'].squeeze_()
                tokens['attention_mask'].squeeze_()

                input_ids.append(tokens['input_ids'])
                attention_masks.append(tokens['attention_mask'])

            arbitary_tokens = torch.stack(input_ids)
            arbitary_attention_masks = torch.stack(attention_masks)

            text_emb = self.forward_language((arbitary_tokens.cuda(), arbitary_attention_masks.cuda()), norm=norm)
            setattr(self, '{}_text_embeddings'.format(name), text_emb)
        else:
            with torch.no_grad():
                def extract_mean_emb(txts):
                    # Encode all template variants of one class and return
                    # their L2-normalised mean embedding.
                    tokens = self.tokenizer(
                        txts, padding='max_length', truncation=True, max_length=self.max_token_num, return_tensors='pt'
                    )
                    clss_embedding = self.forward_language((tokens['input_ids'].cuda(), tokens['attention_mask'].cuda()), norm=norm)
                    clss_embedding = clss_embedding.mean(dim=0)
                    clss_embedding /= clss_embedding.norm()
                    return clss_embedding

                templates = get_prompt_templates()
                clss_embeddings = []
                for clss in class_names:
                    txts = [template.format(clss.replace('-other','').replace('-merged','').replace('-stuff','')) for template in templates]
                    clss_embeddings.append(extract_mean_emb(txts))

                if add_bgd:
                    txts = ["A background in coco."]
                    clss_embeddings.append(extract_mean_emb(txts))

                text_emb = torch.stack(clss_embeddings, dim=0)
                setattr(self, '{}_text_embeddings'.format(name), text_emb)

    # @torch.no_grad()
    def forward_language(self, texts, norm=True):
        """Encode ``texts = (input_ids, attention_mask)`` into projected embeddings.

        For the ``'clip'`` tokenizer, each row takes the hidden state at the
        position of its largest token id. This is presumably CLIP's
        end-of-text token; confirm. Other tokenizers use position 0. The
        result is projected by ``lang_proj`` and, if ``norm``, L2-normalised
        (with a 1e-7 epsilon).
        """
        x = self.lang_encoder(*texts)
        x = x['last_hidden_state']

        if self.tokenizer_type == 'clip':
            x = x[torch.arange(x.size(0)), texts[0].argmax(dim=-1)]
        else:
            x = x[:, 0]

        x = x @ self.lang_proj
        if norm:
            x = x / (x.norm(dim=-1, keepdim=True) + 1e-7)
        return x

    def compute_similarity(self, v_emb, name='default'):
        """Return ``exp(logit_scale) * normalize(v_emb) @ text_embeddings^T`` for cache ``name``.

        NOTE(review): the ``unsqueeze(0).transpose(1, 2)`` suggests ``v_emb``
        is ``(batch, num_queries, dim)`` and the cached text embeddings are
        ``(num_classes, dim)``. Confirm against callers.
        """
        v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
        t_emb = getattr(self, '{}_text_embeddings'.format(name))
        output = self.logit_scale.exp() * v_emb @ t_emb.unsqueeze(0).transpose(1, 2)
        return output


@register_model
def get_language_model(cfg, **kwargs):
    """Registry entry point: build a ``LanguageEncoder`` from ``cfg``."""
    return LanguageEncoder(cfg)

================================================
FILE: llava/model/semsam/language/fixencoder.py
================================================
import torch
from torch import nn
from torch.nn import functional as F

from timm.models.layers import trunc_normal_

from .registry import register_model
from ..utils import configurable
from .LangEncoder import build_tokenizer, build_lang_encoder
from utils.prompt_engineering import prompt_engineering, get_prompt_templates


class LanguageEncoder(nn.Module):
    """Frozen variant of the text encoder: tokenizer -> ``lang_encoder`` -> ``lang_proj``.

    Identical in structure to the trainable encoder, except that
    ``get_text_embeddings``, ``forward_language`` and ``compute_similarity``
    are wrapped in ``@torch.no_grad()``, so no gradients reach these
    parameters through them. Class-name embeddings are cached under the
    attribute ``'<name>_text_embeddings'``.
    """

    @configurable
    def __init__(
        self,
        tokenizer,
        tokenizer_type,
        lang_encoder,
        lang_projection,
        max_token_num,
    ):
        super().__init__()
        self.tokenizer = tokenizer
        self.tokenizer_type = tokenizer_type
        self.lang_encoder = lang_encoder
        self.lang_proj = lang_projection
        self.max_token_num = max_token_num
        # Scalar parameter; compute_similarity multiplies scores by its exp().
        self.logit_scale = nn.Parameter(torch.ones([]))

    @classmethod
    def from_config(cls, cfg):
        """Build constructor kwargs from ``cfg['MODEL']['TEXT']`` and ``cfg['MODEL']['DIM_PROJ']``."""
        # build up text encoder
        tokenizer = build_tokenizer(cfg['MODEL']['TEXT'])
        tokenizer_type = cfg['MODEL']['TEXT']['TOKENIZER']
        lang_encoder = build_lang_encoder(cfg['MODEL']['TEXT'], tokenizer, cfg['VERBOSE'])
        max_token_num = cfg['MODEL']['TEXT']['CONTEXT_LENGTH']

        dim_lang = cfg['MODEL']['TEXT']['WIDTH']
        dim_projection = cfg['MODEL']['DIM_PROJ']
        # (WIDTH, DIM_PROJ) matrix; forward_language applies it as x @ lang_proj.
        lang_projection = nn.Parameter(torch.empty(dim_lang, dim_projection))
        trunc_normal_(lang_projection, std=.02)

        return {
            "tokenizer": tokenizer,
            "tokenizer_type": tokenizer_type,
            "lang_encoder": lang_encoder,
            "lang_projection": lang_projection,
            "max_token_num": max_token_num,
        }

    @torch.no_grad()
    def get_text_embeddings(self, class_names, name='default', is_eval=False, add_bgd=False, prompt=True, norm=True):
        """Encode ``class_names`` (without gradients) into ``self.<name>_text_embeddings``.

        * ``is_eval=False``: names go through ``prompt_engineering`` when
          ``prompt`` is set (verbatim otherwise) and are encoded in one batch.
        * ``is_eval=True``: each name is formatted into every template from
          ``get_prompt_templates()``, and the per-class mean embedding is
          L2-normalised.

        ``add_bgd`` appends "A background in coco." (in the training branch
        only when ``prompt`` is True). Inputs are moved with ``.cuda()``.
        """
        if not is_eval:
            if prompt:
                # randomly sample one template
                arbitary_concepts = [
                    prompt_engineering(class_names[label].replace('-other','').replace('-merged','').replace('-stuff',''), topk=10000, suffix='.') \
                    for label in range(len(class_names))
                ]
                if add_bgd:
                    arbitary_concepts.append("A background in coco.")
            else:
                arbitary_concepts = class_names

            input_ids = []
            attention_masks = []
            for txt in arbitary_concepts:
                tokens = self.tokenizer(
                    txt, padding='max_length', truncation=True,
                    max_length=self.max_token_num, return_tensors='pt'
                )
                # Drop the batch dim of 1 so the per-text rows can be stacked.
                tokens['input_ids'].squeeze_()
                tokens['attention_mask'].squeeze_()

                input_ids.append(tokens['input_ids'])
                attention_masks.append(tokens['attention_mask'])

            arbitary_tokens = torch.stack(input_ids)
            arbitary_attention_masks = torch.stack(attention_masks)

            text_emb = self.forward_language((arbitary_tokens.cuda(), arbitary_attention_masks.cuda()), norm=norm)
            setattr(self, '{}_text_embeddings'.format(name), text_emb)
        else:
            # Redundant with the method decorator, but harmless.
            with torch.no_grad():
                def extract_mean_emb(txts):
                    # L2-normalised mean embedding over all template variants.
                    tokens = self.tokenizer(
                        txts, padding='max_length', truncation=True, max_length=self.max_token_num, return_tensors='pt'
                    )
                    clss_embedding = self.forward_language((tokens['input_ids'].cuda(), tokens['attention_mask'].cuda()), norm=norm)
                    clss_embedding = clss_embedding.mean(dim=0)
                    clss_embedding /= clss_embedding.norm()
                    return clss_embedding

                templates = get_prompt_templates()
                clss_embeddings = []
                for clss in class_names:
                    txts = [template.format(clss.replace('-other','').replace('-merged','').replace('-stuff','')) for template in templates]
                    clss_embeddings.append(extract_mean_emb(txts))

                if add_bgd:
                    txts = ["A background in coco."]
                    clss_embeddings.append(extract_mean_emb(txts))

                text_emb = torch.stack(clss_embeddings, dim=0)
                setattr(self, '{}_text_embeddings'.format(name), text_emb)

    @torch.no_grad()
    def forward_language(self, texts, norm=True):
        """Encode ``texts = (input_ids, attention_mask)`` into projected embeddings, without gradients.

        With the ``'clip'`` tokenizer each row takes the hidden state at its
        largest token id. This is presumably the end-of-text token; confirm.
        Otherwise position 0 is used.
        """
        x = self.lang_encoder(*texts)
        x = x['last_hidden_state']

        if self.tokenizer_type == 'clip':
            x = x[torch.arange(x.size(0)), texts[0].argmax(dim=-1)]
        else:
            x = x[:, 0]

        x = x @ self.lang_proj
        if norm:
            x = x / (x.norm(dim=-1, keepdim=True) + 1e-7)
        return x

    @torch.no_grad() # FIXME hack to freeze all parameters
    def compute_similarity(self, v_emb, name='default'):
        """Return ``exp(logit_scale) * normalize(v_emb) @ text_embeddings^T`` for cache ``name``.

        NOTE(review): the broadcast suggests ``v_emb`` is
        ``(batch, num_queries, dim)``. Confirm against callers.
        """
        v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
        t_emb = getattr(self, '{}_text_embeddings'.format(name))
        output = self.logit_scale.exp() * v_emb @ t_emb.unsqueeze(0).transpose(1, 2)
        return output


@register_model
def get_language_model(cfg, **kwargs):
    """Registry entry point: build the frozen ``LanguageEncoder`` from ``cfg``."""
    return LanguageEncoder(cfg)

================================================
FILE: llava/model/semsam/language/llama_encoder.py
================================================
# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
#    Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.

import os
import copy
from dataclasses import dataclass, field
import json
import logging
import pathlib
from typing import Dict, Optional, Sequence

import torch

import transformers
from torch.utils.data import Dataset
from transformers import Trainer

from llava import conversation as conversation_lib

from PIL import Image
import torch.nn as nn
# from openseed.BaseModel import BaseModel
# from openseed import build_model
import torch
from torch import nn
from torch.nn import functional as F

from timm.models.layers import trunc_normal_

from .registry import register_model
from ..utils import configurable
from .LangEncoder import build_tokenizer, build_lang_encoder
from utils.prompt_engineering import prompt_engineering, get_prompt_templates
from openseed.language import LlamaForCausalLM
# TODO: import and use code from ../data/dataset.py

# Label value excluded from the loss (targets are masked with it).
IGNORE_INDEX = -100
DEFAULT_PAD_TOKEN = "[PAD]"
# NOTE(review): the special-token literals below are empty strings in this
# copy. They look like angle-bracketed tokens stripped during extraction.
# An empty DEFAULT_IMAGE_TOKEN would make
# ``str.replace(DEFAULT_IMAGE_TOKEN, ...)`` insert the replacement between
# every character. Verify the values against the repository.
DEFAULT_EOS_TOKEN = ""
DEFAULT_BOS_TOKEN = ""
DEFAULT_UNK_TOKEN = ""
DEFAULT_IMAGE_TOKEN = ""
DEFAULT_IMAGE_PATCH_TOKEN = ""
DEFAULT_IM_START_TOKEN = "" DEFAULT_IM_END_TOKEN = "" DEFAULT_OBJECT_START_TOKEN = "" DEFAULT_OBJECT_END_TOKEN = "" ENC_LENS=[140*64,140*16,140*4,140] ENC_ID=-1 @dataclass class ModelArguments: model_name_or_path: Optional[str] = field(default="facebook/opt-125m") freeze_backbone: bool = field(default=False) dbg: bool = field(default=False) tune_mm_mlp_adapter: bool = field(default=False) config_file: Optional[str] = field(default="") os_weights:Optional[str]=field(default="") vision_tower: Optional[str] = field(default=None) mm_vision_select_layer: Optional[int] = field(default=-1) # default to the last layer pretrain_mm_mlp_adapter: Optional[str] = field(default=None) pretrain_obj_mlp_adapter: Optional[str] = field(default=None) mm_use_im_start_end: bool = field(default=False) @dataclass class DataArguments: data_path: str = field(default=None, metadata={"help": "Path to the training data."}) lazy_preprocess: bool = False is_multimodal: bool = False image_token_len: int = 0 image_folder: Optional[str] = field(default=None) image_aspect_ratio: str = 'square' @dataclass class TrainingArguments(transformers.TrainingArguments): cache_dir: Optional[str] = field(default=None) optim: str = field(default="adamw_torch") remove_unused_columns: bool = field(default=False) # dbg: bool = field(default=False) model_max_length: int = field( default=512, metadata={ "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)." 
}, ) def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str): """Collects the state dict and dump to disk.""" state_dict = trainer.model.state_dict() if trainer.args.should_save: cpu_state_dict = { key: value.cpu() for key, value in state_dict.items() } del state_dict trainer._save(output_dir, state_dict=cpu_state_dict) # noqa def smart_tokenizer_and_embedding_resize( special_tokens_dict: Dict, tokenizer: transformers.PreTrainedTokenizer, model: transformers.PreTrainedModel, ): """Resize tokenizer and embedding. Note: This is the unoptimized version that may make your embedding size not be divisible by 64. """ num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) model.resize_token_embeddings(len(tokenizer)) if num_new_tokens > 0: input_embeddings = model.get_input_embeddings().weight.data output_embeddings = model.get_output_embeddings().weight.data input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( dim=0, keepdim=True) output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( dim=0, keepdim=True) input_embeddings[-num_new_tokens:] = input_embeddings_avg output_embeddings[-num_new_tokens:] = output_embeddings_avg def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict: """Tokenize a list of strings.""" tokenized_list = [ tokenizer( text, return_tensors="pt", padding="longest", max_length=tokenizer.model_max_length, truncation=True, ) for text in strings ] input_ids = labels = [ tokenized.input_ids[0] for tokenized in tokenized_list ] input_ids_lens = labels_lens = [ tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list ] return dict( input_ids=input_ids, labels=labels, input_ids_lens=input_ids_lens, labels_lens=labels_lens, ) def _mask_targets(target, tokenized_lens, speakers): # cur_idx = 0 cur_idx = tokenized_lens[0] tokenized_lens = tokenized_lens[1:] target[:cur_idx] = IGNORE_INDEX for tokenized_len, speaker in 
zip(tokenized_lens, speakers): if speaker == "human": target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX cur_idx += tokenized_len def _add_speaker_and_signal(header, source, get_conversation=True): """Add speaker and start/end signal on each round.""" BEGIN_SIGNAL = "### " END_SIGNAL = "\n" conversation = header for sentence in source: from_str = sentence["from"] if from_str.lower() == "human": from_str = conversation_lib.default_conversation.roles[0] elif from_str.lower() == "gpt": from_str = conversation_lib.default_conversation.roles[1] else: from_str = 'unknown' sentence["value"] = (BEGIN_SIGNAL + from_str + ": " + sentence["value"] + END_SIGNAL) if get_conversation: conversation += sentence["value"] conversation += BEGIN_SIGNAL return conversation def preprocess_multimodal( sources: Sequence[str], multimodal_cfg: dict, cur_token_len: int, ) -> Dict: is_multimodal = multimodal_cfg['is_multimodal'] # image_token_len = multimodal_cfg['image_token_len'] image_token_len = cur_token_len if not is_multimodal: return sources for source in sources: for sentence in source: replace_token = DEFAULT_IMAGE_PATCH_TOKEN * image_token_len if multimodal_cfg['use_im_start_end']: replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token) return sources def preprocess( sources: Sequence[str], tokenizer: transformers.PreTrainedTokenizer, ) -> Dict: """ Given a list of sources, each is a conversation list. This transform: 1. Add signal '### ' at the beginning each sentence, with end signal '\n'; 2. Concatenate conversations together; 3. Tokenize the concatenated conversation; 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX. 
""" # add end signal and concatenate together conversations = [] for source in sources: header = f"{conversation_lib.default_conversation.system}\n\n" conversation = _add_speaker_and_signal(header, source) conversations.append(conversation) # tokenize conversations conversations_tokenized = _tokenize_fn(conversations, tokenizer) input_ids = conversations_tokenized["input_ids"] targets = copy.deepcopy(input_ids) for target, source in zip(targets, sources): tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source], tokenizer)["input_ids_lens"] speakers = [sentence["from"] for sentence in source] _mask_targets(target, tokenized_lens, speakers) return dict(input_ids=input_ids, labels=targets) class SupervisedDataset(Dataset): """Dataset for supervised fine-tuning.""" def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer): super(SupervisedDataset, self).__init__() logging.warning("Loading data...") list_data_dict = json.load(open(data_path, "r")) logging.warning("Formatting inputs...") sources = [example["conversations"] for example in list_data_dict] data_dict = preprocess(sources, tokenizer) self.input_ids = data_dict["input_ids"] self.labels = data_dict["labels"] def __len__(self): return len(self.input_ids) def __getitem__(self, i) -> Dict[str, torch.Tensor]: return dict(input_ids=self.input_ids[i], labels=self.labels[i]) class LazySupervisedDataset(Dataset): """Dataset for supervised fine-tuning.""" def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer, multimodal_cfg: dict): super(LazySupervisedDataset, self).__init__() logging.warning("Loading data...") list_data_dict = json.load(open(data_path, "r")) logging.warning("Formatting inputs...Skip in lazy mode") self.tokenizer = tokenizer self.list_data_dict = list_data_dict self.multimodal_cfg = multimodal_cfg def __len__(self): return len(self.list_data_dict) def __getitem__(self, i) -> Dict[str, torch.Tensor]: sources = self.list_data_dict[i] if 
isinstance(i, int): sources = [sources] assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME try: if 'image' in sources[0]: image_file = self.list_data_dict[i]['image'] image_folder = self.multimodal_cfg['image_folder'] processor = self.multimodal_cfg['image_processor'] image = Image.open(os.path.join(image_folder, image_file)) if self.multimodal_cfg['image_aspect_ratio'] == 'keep': max_hw, min_hw = max(image.size), min(image.size) aspect_ratio = max_hw / min_hw max_len, min_len = 1333, 800 shortest_edge = int(min(max_len / aspect_ratio, min_len)) try: image = processor.preprocess(image, return_tensors='pt', do_center_crop=False, size={"shortest_edge": shortest_edge}, do_rescale=False, do_normalize=False)['pixel_values'][0] except Exception: return self.__getitem__(i + 1) else: # try: image = processor.preprocess(image, return_tensors='pt', do_rescale=False, do_normalize=False, do_center_crop=False,size=(640,64*14))[ 'pixel_values'][0] # except Exception: # return self.__getitem__(i+1) # FIXME: cur_token_len should be num_queries when using det # cur_token_len = (image.shape[1]//14) * (image.shape[2]//14) # FIXME: 14 is hardcoded patch size cur_token_len = ENC_LENS[ENC_ID] sources = preprocess_multimodal( copy.deepcopy([e["conversations"] for e in sources]), self.multimodal_cfg, cur_token_len) else: sources = copy.deepcopy([e["conversations"] for e in sources]) except Exception: return self.__getitem__(i + 1) data_dict = preprocess( sources, self.tokenizer) if isinstance(i, int): data_dict = dict(input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0]) # image exist in the data if 'image' in self.list_data_dict[i]: data_dict['image'] = image elif self.multimodal_cfg['is_multimodal']: # image does not exist in the data, but the model is multimodal crop_size = self.multimodal_cfg['image_processor'].crop_size data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width']) return data_dict @dataclass class 
DataCollatorForSupervisedDataset(object): """Collate examples for supervised fine-tuning.""" tokenizer: transformers.PreTrainedTokenizer def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels")) input_ids = torch.nn.utils.rnn.pad_sequence( input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id) labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX) batch = dict( input_ids=input_ids, labels=labels, attention_mask=input_ids.ne(self.tokenizer.pad_token_id), ) if 'image' in instances[0]: images = [instance['image'] for instance in instances] if all(x is not None and x.shape == images[0].shape for x in images): batch['images'] = torch.stack(images) else: batch['images'] = images return batch def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict: """Make dataset and collator for supervised fine-tuning.""" dataset_cls = (LazySupervisedDataset if data_args.lazy_preprocess else SupervisedDataset) train_dataset = dataset_cls(tokenizer=tokenizer, data_path=data_args.data_path, multimodal_cfg=dict( is_multimodal=data_args.is_multimodal, image_token_len=data_args.image_token_len, image_folder=data_args.image_folder, image_aspect_ratio=data_args.image_aspect_ratio, use_im_start_end=getattr(data_args, 'mm_use_im_start_end', False), image_processor=getattr(data_args, 'image_processor', None))) data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator) # from detectron2.config import get_cfg, CfgNode from detectron2.config import LazyConfig, instantiate # from detectron2.utils.logger import setup_logger # from detectron2.engine import default_setup def setup(config_file): """ Create configs and perform basic setups. 
""" cfg = LazyConfig.load(config_file) # cfg = LazyConfig.apply_overrides(cfg, args.opts) # cfg.freeze() # default_setup(cfg, args) # setup_logger(output=cfg.OUTPUT_DIR, distributed_rank=comm.get_rank(), name="maskdino") return cfg @register_model def get_language_model(cfg, **kwargs): llama_cfg = cfg['MODEL']['LLAMA'] if llama_cfg['load_fp16']: return LlamaForCausalLM.from_pretrained( llama_cfg['model_name_or_path'], cache_dir=llama_cfg['cache_dir'], torch_dtype=torch.float16 ) else: return LlamaForCausalLM.from_pretrained( llama_cfg['model_name_or_path'], cache_dir=llama_cfg['cache_dir'], # torch_dtype=torch.float16 ) def train(): parser = transformers.HfArgumentParser( (ModelArguments, DataArguments, TrainingArguments)) model_args, data_args, training_args = parser.parse_args_into_dataclasses() if model_args.dbg: training_args._n_gpu = 1 ENC_ID=model_args.mm_vision_select_layer model = LlamaForCausalLM.from_pretrained( model_args.model_name_or_path, cache_dir=training_args.cache_dir, ) cfg = setup(model_args.config_file) if model_args.freeze_backbone: model.model.requires_grad_(False) tokenizer = transformers.AutoTokenizer.from_pretrained( model_args.model_name_or_path, cache_dir=training_args.cache_dir, model_max_length=training_args.model_max_length, padding_side="right", use_fast=False, ) if tokenizer.pad_token is None: smart_tokenizer_and_embedding_resize( special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN), tokenizer=tokenizer, model=model, ) if "llama" in model_args.model_name_or_path: tokenizer.add_special_tokens({ "eos_token": DEFAULT_EOS_TOKEN, "bos_token": DEFAULT_BOS_TOKEN, "unk_token": DEFAULT_UNK_TOKEN, }) if model_args.vision_tower is not None: model.config.mm_vision_tower = model_args.vision_tower from transformers import CLIPVisionModel from llava.train.image_processing_gptv import CLIPImageProcessor dtype = torch.float32 if training_args.fp16: dtype = torch.float16 if training_args.bf16: dtype = torch.bfloat16 openseed_vision = BaseModel(cfg, 
build_model(cfg)).cuda() # if not model_args.dbg: checkpoint = torch.load(model_args.os_weights, map_location='cpu') model_dict = openseed_vision.state_dict() pretrained_dict = {"model."+k: v for k, v in checkpoint.items() if "model."+k in model_dict} model_dict.update(pretrained_dict) openseed_vision.load_state_dict(model_dict) # openseed_vision.stat if not hasattr(model.model, 'vision_tower'): vision_tower = CLIPVisionModel.from_pretrained(model_args.vision_tower) else: vision_tower = model.model.vision_tower[0] image_processor = CLIPImageProcessor.from_pretrained(model_args.vision_tower) image_processor.size['shortest_edge'] = 800 vision_config = vision_tower.config vision_tower=openseed_vision vision_tower.config=vision_config # vision_tower.num_queries=300 # vision_tower.idx=model_args.enc_idx vision_tower.idx=ENC_ID vision_tower.num_enc_tokens=ENC_LENS[vision_tower.idx] vision_tower.dim_queries=256 # num_patches = (vision_config.image_size // vision_config.patch_size) ** 2 num_patches=vision_tower.num_enc_tokens data_args.image_token_len = num_patches data_args.image_processor = image_processor data_args.is_multimodal = True vision_tower.requires_grad_(False) # model.model.vision_tower = vision_tower # HACK: for FSDP vision_tower.to(device=training_args.device) model.model.vision_tower = [vision_tower] model.config.use_mm_proj = True model.config.mm_hidden_size = vision_config.hidden_size=vision_tower.dim_queries model.config.mm_vision_select_layer = model_args.mm_vision_select_layer if not hasattr(model.model, 'mm_projector') or model.model.mm_projector.weight.shape[1]!=vision_config.hidden_size: mm_projector = nn.Linear(vision_config.hidden_size, model.config.hidden_size) model.model.mm_projector = mm_projector else: mm_projector = model.model.mm_projector if not hasattr(model.model, 'obj_projector'): obj_projector = nn.Linear(vision_config.hidden_size+4, model.config.hidden_size) model.model.obj_projector = obj_projector else: obj_projector = 
model.model.obj_projector if model_args.pretrain_mm_mlp_adapter is not None: mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu') mm_projector.load_state_dict({k.split('.')[-1]: v for k, v in mm_projector_weights.items()}) if model_args.pretrain_obj_mlp_adapter is not None: obj_projector_weights = torch.load(model_args.pretrain_obj_mlp_adapter, map_location='cpu') obj_projector.load_state_dict({k.split('.')[-1]: v for k, v in obj_projector_weights.items()}) model.config.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter if model_args.tune_mm_mlp_adapter: model.requires_grad_(False) for p in mm_projector.parameters(): p.requires_grad = True for p in obj_projector.parameters(): p.requires_grad = True model.config.mm_use_im_start_end = model_args.mm_use_im_start_end data_args.mm_use_im_start_end = model_args.mm_use_im_start_end tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) vision_config.use_im_start_end = model_args.mm_use_im_start_end if model_args.mm_use_im_start_end: num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_OBJECT_START_TOKEN,DEFAULT_OBJECT_END_TOKEN], special_tokens=True) model.resize_token_embeddings(len(tokenizer)) vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN]) if num_new_tokens > 0: input_embeddings = model.get_input_embeddings().weight.data output_embeddings = model.get_output_embeddings().weight.data input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( dim=0, keepdim=True) output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( dim=0, keepdim=True) input_embeddings[-num_new_tokens:] = input_embeddings_avg output_embeddings[-num_new_tokens:] = output_embeddings_avg if model_args.tune_mm_mlp_adapter: model.model.orig_embeds_params = [model.get_input_embeddings().weight.data.clone().to(device=training_args.device)] for p in 
model.get_input_embeddings().parameters(): p.requires_grad = True for p in model.get_output_embeddings().parameters(): p.requires_grad = False if model_args.pretrain_mm_mlp_adapter: mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu') embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight'] assert input_embeddings.shape == embed_tokens_weight.shape assert num_new_tokens == 2 input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:] vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0] data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args) trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module) if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")): trainer.train(resume_from_checkpoint=True) else: trainer.train() trainer.save_state() safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir) if __name__ == "__main__": train() ================================================ FILE: llava/model/semsam/language/loss.py ================================================ import pickle from distutils import log import torch import torch.nn.functional as F import torch.distributed as dist from einops import rearrange, repeat from timm.loss import SoftTargetCrossEntropy soft_cross_entropy = SoftTargetCrossEntropy() def is_dist_initialized(): return torch.distributed.is_initialized() def get_world_size(): if is_dist_initialized(): return torch.distributed.get_world_size() return 1 def get_rank(): if is_dist_initialized(): return dist.get_rank() return 0 def all_gather_grad(x): if get_world_size() > 1: all_x = [torch.zeros_like(x) for _ in range(get_world_size())] torch.distributed.all_gather(all_x, x) all_x[torch.distributed.get_rank()] = x x = torch.cat(all_x, dim=0) return x def vl_multilabel_contrastive_loss(image_feat, text_feat, temperature=1): """ Args: image_feat 
(torch.Tensor): shape [B, L1, C] # B: batch_size, L1: 1, C: 256 text_feat (torch.Tensor): shape [B, L2, C] # B:batch_size, L2: number of selected nouns, C: 256 Returns: """ # [B, L1, C], L1 = 1 # image_feat = F.normalize(image_feat, dim=-1) # [B, L2, C] # text_feat = F.normalize(text_feat, dim=-1) # HACK: normalize outside # [B, L1, L2] dist_per_img = image_feat @ rearrange(text_feat, 'b l c -> b c l') # [B, L2, L1] dist_per_text = text_feat @ rearrange(image_feat, 'b l c -> b c l') batch = image_feat.shape[0] img_len = image_feat.shape[1] text_len = text_feat.shape[1] # [B, L1, L2] pos_labels_batch_img = rearrange(torch.ones_like(dist_per_text) / dist_per_text.size(1), 'b l2 l1 -> b l1 l2') # [B, L2, L1] pos_labels_batch_text = rearrange(torch.ones_like(dist_per_img) / dist_per_img.size(1), 'b l1 l2 -> b l2 l1') image_x = rearrange(image_feat, 'b l c -> (b l) c') text_x = rearrange(text_feat, 'b l c -> (b l) c') logits_per_img = image_x @ all_gather_grad(text_x).t() logits_per_text = text_x @ all_gather_grad(image_x).t() # get label globally # [B, L1, B, L2, W] labels_per_img = F.one_hot( torch.ones(batch, img_len, batch, text_len, dtype=torch.long, device=image_x.device) * get_rank(), num_classes=get_world_size()).to(image_x.dtype) labels_per_img *= rearrange(pos_labels_batch_img, 'b l1 l2 -> b l1 1 l2 1') * repeat( torch.eye(batch, dtype=image_x.dtype, device=image_x.device), 'b1 b2 -> b1 1 b2 1 1') # [BxL1, WxBxL2] labels_per_img = rearrange(labels_per_img, 'b1 l1 b2 l2 w -> (b1 l1) (w b2 l2)') # [B, L2, B, L1, W] labels_per_text = F.one_hot( torch.ones(batch, text_len, batch, img_len, dtype=torch.long, device=text_x.device) * get_rank(), num_classes=get_world_size()).to(text_x.dtype) labels_per_text *= rearrange(pos_labels_batch_text, 'b l2 l1 -> b l2 1 l1 1') * repeat( torch.eye(batch, dtype=text_x.dtype, device=image_x.device), 'b2 b1 -> b2 1 b1 1 1') # [BxL2, WxBxL1] labels_per_text = rearrange(labels_per_text, 'b2 l2 b1 l1 w -> (b2 l2) (w b1 l1)') 
logit_scale = temperature.exp().clamp(max=100) loss_img = soft_cross_entropy(logit_scale * logits_per_img, labels_per_img) loss_text = soft_cross_entropy(logit_scale * logits_per_text, labels_per_text) loss = 0.5 * (loss_img + loss_text) return loss def vl_contrastive_loss(image_feat, text_feat, temperature=1): # if image_id or text_id is None, it should be None across all GPUs # image_feat = F.normalize(image_feat, dim=1) # text_feat = F.normalize(text_feat, dim=1) # handle normalization outside # add the following 4 lines image_feat = all_gather_grad(image_feat) text_feat = all_gather_grad(text_feat) logits = torch.matmul(image_feat, text_feat.t()) logit_scale = temperature.exp().clamp(max=100) gt = torch.arange(logits.shape[0], device=logits.device) loss1 = F.cross_entropy(logit_scale * logits, gt) loss2 = F.cross_entropy(logit_scale * logits.t(), gt) return (loss1 + loss2) / 2 # scale it up by the number of GPUs def all_gather_pickle(data, device): """ Run all_gather on arbitrary picklable data (not necessarily tensors) Args: data: any picklable object Returns: list[data]: list of data gathered from each rank """ world_size = get_world_size() if world_size == 1: return [data] # serialized to a Tensor buffer = pickle.dumps(data) storage = torch.ByteStorage.from_buffer(buffer) tensor = torch.ByteTensor(storage).to(device) # obtain Tensor size of each rank local_size = torch.LongTensor([tensor.numel()]).cuda() size_list = [torch.LongTensor([0]).cuda() for _ in range(world_size)] dist.all_gather(size_list, local_size) size_list = [int(size.item()) for size in size_list] max_size = max(size_list) # receiving Tensor from all ranks # we pad the tensor because torch all_gather does not support # gathering tensors of different shapes tensor_list = [] for _ in size_list: tensor_list.append(torch.ByteTensor(size=(max_size,)).cuda()) if local_size != max_size: padding = torch.ByteTensor(size=(max_size - local_size,)).cuda() tensor = torch.cat((tensor, padding), dim=0) 
dist.all_gather(tensor_list, tensor) data_list = [] for size, tensor in zip(size_list, tensor_list): buffer = tensor.cpu().numpy().tobytes()[:size] data_list.append(pickle.loads(buffer)) return data_list def all_gather_arbitary_tensor(tensor): if get_world_size() > 1: device = tensor.device tensor_batch = all_gather_pickle(tensor.cpu(), device) tensor_batch = [x.to(device) for x in tensor_batch] tensor_batch[torch.distributed.get_rank()] = tensor tensor_batch = torch.cat(tensor_batch, dim=0) else: tensor_batch = tensor return tensor_batch def ql_contrastive_loss(image_feat, text_feat, temperature=1): # add the following 4 lines image_feat = all_gather_arbitary_tensor(image_feat) text_feat = all_gather_arbitary_tensor(text_feat) logits = torch.matmul(image_feat, text_feat.t()) logit_scale = temperature.exp().clamp(max=100) gt = torch.arange(logits.shape[0], device=logits.device) loss1 = F.cross_entropy(logit_scale * logits, gt) loss2 = F.cross_entropy(logit_scale * logits.t(), gt) return (loss1 + loss2) / 2 # scale it up by the number of GPUs def vl_similarity(image_feat, text_feat, temperature=1): # Only support single GPU for now. 
logits = torch.matmul(image_feat, text_feat.t()) logits = temperature.exp().clamp(max=100) * logits return logits def ql_multi_contrastive_loss(image_feat, text_feat, text_hash, temperature=1): # add the following 4 lines image_feat = all_gather_arbitary_tensor(image_feat) text_feat = all_gather_arbitary_tensor(text_feat) text_hash_batch = all_gather_pickle(text_hash, text_feat.device) text_hash_all = torch.cat(text_hash_batch) text_hash_all_unique = torch.unique(text_hash_all).tolist() gt = torch.zeros((image_feat.shape[0], len(text_hash_all_unique)), device=text_feat.device) text_hash_all = text_hash_all.tolist() text_feat_unique = torch.stack([text_feat[text_hash_all.index(txt)] for txt in text_hash_all_unique]) for idx, txt in enumerate(text_hash_all): gt[idx][text_hash_all_unique.index(txt)] = 1 logits = torch.matmul(image_feat, text_feat_unique.t()) logits = logits*temperature.exp().clamp(max=100) loss_img = soft_cross_entropy(logits, gt) loss_text = soft_cross_entropy(logits.t(), gt.t() / gt.t().sum(-1, keepdim=True)) loss = 0.7 * loss_img + 0.3 * loss_text return loss def image_text_contrastive_loss_queue(image_feat_inp, text_feat_inp, lang_enc, training): # add the following 4 lines image_feat = all_gather_grad(image_feat_inp.contiguous()) text_feat = all_gather_grad(text_feat_inp.contiguous()) image_feat = image_feat / (image_feat.norm(dim=-1, keepdim=True) + 1e-7) text_feat = text_feat / (text_feat.norm(dim=-1, keepdim=True) + 1e-7) temperature = lang_enc.logit_scale logits = torch.matmul(image_feat, text_feat.t()) logit_scale = temperature.exp().clamp(max=100) gt = torch.arange(logits.shape[0], device=logits.device) loss1 = F.cross_entropy(logit_scale * logits, gt) loss2 = F.cross_entropy(logit_scale * logits.t(), gt) return (loss1 + loss2) / 2 # scale it up by the number of GPUs ================================================ FILE: llava/model/semsam/language/misc.py ================================================ import random import torch import 
nltk nltk.data.path.append('/mnt/data/nltk_data') import numpy as np from utils.constants import IMAGENET_DEFAULT_TEMPLATES def vl_similarity(image_feat, text_feat, temperature=1): # Only support single GPU for now. logits = torch.matmul(image_feat, text_feat.t()) logits = temperature.exp().clamp(max=100) * logits return logits def get_tag(tokenized, tags): if not isinstance(tags, (list, tuple)): tags = [tags] ret = [] for (word, pos) in nltk.pos_tag(tokenized): for tag in tags: if pos == tag: ret.append(word) return ret def get_noun_phrase(tokenized): # Taken from Su Nam Kim Paper... grammar = r""" NBAR: {*} # Nouns and Adjectives, terminated with Nouns NP: {} {