Repository: UX-Decoder/LLaVA-Grounding Branch: main Commit: 668b7cc3d536 Files: 257 Total size: 2.6 MB Directory structure: gitextract_hcvjbjkn/ ├── .gitignore ├── LICENSE ├── README.md ├── configs/ │ ├── openseed/ │ │ ├── openseed_swint_lang_joint.yaml │ │ ├── openseed_swint_lang_joint_2st.yaml │ │ └── openseed_swint_lang_joint_2st_visual_prompt.yaml │ └── semsam/ │ └── visual_prompt_encoder.yaml ├── datasets_os/ │ ├── __init__.py │ ├── build.py │ ├── custom_dataset_dataloader.py │ ├── dataset_mappers/ │ │ ├── __init__.py │ │ ├── coco_instance_new_baseline_dataset_mapper.py │ │ ├── coco_instruct_grounding_dataset_interactive_mapper.py │ │ ├── coco_instruct_grounding_dataset_mapper.py │ │ ├── coco_interactive_panoptic_new_baseline_dataset_mapper.py │ │ ├── coco_panoptic_interactive_dataset_mapper.py │ │ ├── coco_panoptic_new_baseline_dataset_mapper.py │ │ ├── flickr_instance_new_baseline_dataset_mapper.py │ │ ├── flickr_instance_new_baseline_dataset_mapper_.py │ │ ├── flickr_instance_new_baseline_dataset_mapper_end.py │ │ ├── flickr_new_baseline_dataset_mapper.py │ │ ├── inference_mapper_with_gt.py │ │ ├── llava_dataset_mapper.py │ │ ├── refcoco_dataset_mapper.py │ │ └── vg_instance_new_baseline_dataset_mapper.py │ ├── refer.py │ ├── registration/ │ │ ├── __init__.py │ │ ├── register_coco_instruct_grounding_dataset.py │ │ ├── register_coco_panoptic_annos_grounding_interactive.py │ │ ├── register_flickr_dataset.py │ │ └── register_vg_dataset.py │ └── semseg_loader.py ├── docs/ │ └── MODEL_ZOO.md ├── gradio_demo/ │ ├── LLaVA_G_Demo.py │ └── __init__.py ├── llava/ │ ├── __init__.py │ ├── constants.py │ ├── conversation.py │ ├── eval/ │ │ ├── LLaVA_G_Eval.py │ │ ├── eval_gpt_review.py │ │ ├── eval_gpt_review_bench.py │ │ ├── eval_gpt_review_visual.py │ │ ├── eval_gpt_review_visual2.py │ │ ├── eval_science_qa.py │ │ ├── eval_science_qa_gpt4.py │ │ ├── eval_science_qa_gpt4_requery.py │ │ ├── generate_webpage_data_from_table.py │ │ ├── llava_mapper.py │ │ ├── model_qa.py 
│ │ ├── model_vqa.py │ │ ├── model_vqa_science.py │ │ ├── qa_baseline_gpt35.py │ │ ├── run_llava.py │ │ ├── summarize_gpt_review.py │ │ └── webpage/ │ │ ├── index.html │ │ ├── script.js │ │ └── styles.css │ ├── mm_utils.py │ ├── model/ │ │ ├── __init__.py │ │ ├── apply_delta.py │ │ ├── builder.py │ │ ├── consolidate.py │ │ ├── language_model/ │ │ │ ├── llava_llama.py │ │ │ ├── llava_llama_gd.py │ │ │ ├── llava_mpt.py │ │ │ └── mpt/ │ │ │ ├── adapt_tokenizer.py │ │ │ ├── attention.py │ │ │ ├── blocks.py │ │ │ ├── configuration_mpt.py │ │ │ ├── custom_embedding.py │ │ │ ├── flash_attn_triton.py │ │ │ ├── hf_prefixlm_converter.py │ │ │ ├── meta_init_context.py │ │ │ ├── modeling_mpt.py │ │ │ ├── norm.py │ │ │ └── param_init_fns.py │ │ ├── llava_arch.py │ │ ├── make_delta.py │ │ ├── multimodal_encoder/ │ │ │ ├── builder.py │ │ │ └── clip_encoder.py │ │ ├── openseed/ │ │ │ ├── BaseModel.py │ │ │ ├── __init__.py │ │ │ ├── architectures/ │ │ │ │ ├── __init__.py │ │ │ │ ├── build.py │ │ │ │ ├── openseed_model.py │ │ │ │ ├── openseed_model_decouple_train.py │ │ │ │ └── registry.py │ │ │ ├── backbone/ │ │ │ │ ├── __init__.py │ │ │ │ ├── backbone.py │ │ │ │ ├── build.py │ │ │ │ ├── focal.py │ │ │ │ ├── focal_dw.py │ │ │ │ ├── registry.py │ │ │ │ └── swin.py │ │ │ ├── body/ │ │ │ │ ├── __init__.py │ │ │ │ ├── build.py │ │ │ │ ├── decoder/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── build.py │ │ │ │ │ ├── modules.py │ │ │ │ │ ├── openseed_decoder.py │ │ │ │ │ ├── openseed_decoder_decouple.py │ │ │ │ │ ├── registry.py │ │ │ │ │ └── utils/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── dino_decoder.py │ │ │ │ │ └── utils.py │ │ │ │ ├── encoder/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── build.py │ │ │ │ │ ├── encoder_deform.py │ │ │ │ │ ├── ops/ │ │ │ │ │ │ ├── functions/ │ │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ │ └── ms_deform_attn_func.py │ │ │ │ │ │ ├── make.sh │ │ │ │ │ │ ├── modules/ │ │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ │ └── ms_deform_attn.py │ │ │ │ │ │ ├── setup.py │ │ │ │ │ │ 
├── src/ │ │ │ │ │ │ │ ├── cpu/ │ │ │ │ │ │ │ │ ├── ms_deform_attn_cpu.cpp │ │ │ │ │ │ │ │ └── ms_deform_attn_cpu.h │ │ │ │ │ │ │ ├── cuda/ │ │ │ │ │ │ │ │ ├── ms_deform_attn_cuda.cu │ │ │ │ │ │ │ │ ├── ms_deform_attn_cuda.h │ │ │ │ │ │ │ │ └── ms_deform_im2col_cuda.cuh │ │ │ │ │ │ │ ├── ms_deform_attn.h │ │ │ │ │ │ │ └── vision.cpp │ │ │ │ │ │ └── test.py │ │ │ │ │ ├── registry.py │ │ │ │ │ └── transformer_encoder_fpn.py │ │ │ │ ├── openseed_head.py │ │ │ │ ├── registry.py │ │ │ │ └── transformer_blocks.py │ │ │ ├── language/ │ │ │ │ ├── LangEncoder/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── build.py │ │ │ │ │ ├── registry.py │ │ │ │ │ └── transformer.py │ │ │ │ ├── __init__.py │ │ │ │ ├── build.py │ │ │ │ ├── encoder.py │ │ │ │ ├── registry.py │ │ │ │ └── vlpencoder.py │ │ │ ├── modules/ │ │ │ │ ├── __init__.py │ │ │ │ ├── attention.py │ │ │ │ ├── criterion.py │ │ │ │ ├── matcher.py │ │ │ │ ├── point_features.py │ │ │ │ ├── position_encoding.py │ │ │ │ └── postprocessing.py │ │ │ └── utils/ │ │ │ ├── __init__.py │ │ │ ├── box_ops.py │ │ │ ├── config.py │ │ │ └── misc.py │ │ ├── semsam/ │ │ │ ├── BaseModel.py │ │ │ ├── __init__.py │ │ │ ├── architectures/ │ │ │ │ ├── __init__.py │ │ │ │ ├── build.py │ │ │ │ ├── idino_model_partwhole_all_llm_ref_feats_all_det_pretrainv1.py │ │ │ │ └── registry.py │ │ │ ├── backbone/ │ │ │ │ ├── __init__.py │ │ │ │ ├── backbone.py │ │ │ │ ├── build.py │ │ │ │ ├── focal.py │ │ │ │ ├── focal_dw.py │ │ │ │ ├── registry.py │ │ │ │ ├── swin.py │ │ │ │ └── swin_new.py │ │ │ ├── body/ │ │ │ │ ├── __init__.py │ │ │ │ ├── build.py │ │ │ │ ├── decoder/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── build.py │ │ │ │ │ ├── idino_decoder_no_iou_token_partwhole_all_llm.py │ │ │ │ │ ├── modules.py │ │ │ │ │ ├── registry.py │ │ │ │ │ └── utils/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── dino_decoder.py │ │ │ │ │ └── utils.py │ │ │ │ ├── encoder/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── build.py │ │ │ │ │ ├── encoder_deform.py │ │ │ │ │ ├── ops/ │ │ │ │ │ │ ├── 
functions/ │ │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ │ └── ms_deform_attn_func.py │ │ │ │ │ │ ├── make.sh │ │ │ │ │ │ ├── modules/ │ │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ │ └── ms_deform_attn.py │ │ │ │ │ │ ├── setup.py │ │ │ │ │ │ ├── src/ │ │ │ │ │ │ │ ├── cpu/ │ │ │ │ │ │ │ │ ├── ms_deform_attn_cpu.cpp │ │ │ │ │ │ │ │ └── ms_deform_attn_cpu.h │ │ │ │ │ │ │ ├── cuda/ │ │ │ │ │ │ │ │ ├── ms_deform_attn_cuda.cu │ │ │ │ │ │ │ │ ├── ms_deform_attn_cuda.h │ │ │ │ │ │ │ │ └── ms_deform_im2col_cuda.cuh │ │ │ │ │ │ │ ├── ms_deform_attn.h │ │ │ │ │ │ │ └── vision.cpp │ │ │ │ │ │ └── test.py │ │ │ │ │ ├── registry.py │ │ │ │ │ └── transformer_encoder_fpn.py │ │ │ │ ├── openseed_head.py │ │ │ │ ├── registry.py │ │ │ │ └── transformer_blocks.py │ │ │ ├── language/ │ │ │ │ ├── LangEncoder/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── build.py │ │ │ │ │ ├── registry.py │ │ │ │ │ └── transformer.py │ │ │ │ ├── __init__.py │ │ │ │ ├── build.py │ │ │ │ ├── encoder.py │ │ │ │ ├── fixencoder.py │ │ │ │ ├── llama_encoder.py │ │ │ │ ├── loss.py │ │ │ │ ├── misc.py │ │ │ │ ├── modeling_llama_os.py │ │ │ │ ├── registry.py │ │ │ │ └── vlpencoder.py │ │ │ ├── modules/ │ │ │ │ ├── __init__.py │ │ │ │ ├── attention.py │ │ │ │ ├── criterion_id_llm.py │ │ │ │ ├── hooks.py │ │ │ │ ├── matcher.py │ │ │ │ ├── point_features.py │ │ │ │ ├── position_encoding.py │ │ │ │ └── postprocessing.py │ │ │ └── utils/ │ │ │ ├── __init__.py │ │ │ ├── box_ops.py │ │ │ ├── config.py │ │ │ └── misc.py │ │ └── utils.py │ ├── serve/ │ │ ├── __init__.py │ │ ├── cli.py │ │ ├── controller.py │ │ ├── gradio_web_server.py │ │ ├── register_worker.py │ │ └── test_message.py │ ├── train/ │ │ ├── llama_flash_attn_monkey_patch.py │ │ ├── llava_trainer.py │ │ ├── llava_trainer_gd.py │ │ ├── llava_trainer_joint_train.py │ │ ├── train.py │ │ ├── train_grounding_1st.py │ │ ├── train_joint_1st.py │ │ ├── train_joint_2st.py │ │ ├── train_joint_2st_interactive_refcoco_coco_instruction.py │ │ └── train_mem.py │ └── utils.py ├── 
pyproject.toml ├── scripts/ │ ├── convert_sqa_to_llava.py │ ├── convert_sqa_to_llava_base_prompt.py │ ├── finetune.sh │ ├── finetune_visual_prompt.sh │ ├── merge_lora_weights.py │ └── pretrain_joint.sh └── utils/ ├── Config.py ├── __init__.py ├── arguments.py ├── constants.py ├── constants_ori.py ├── dist.py ├── distributed.py ├── misc.py ├── model.py ├── nms.py ├── prompt_engineering.py ├── utils.py └── visualizer.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ *.err *.out *.pyc wandb /data_preparation/vis_results/ /data_preparation/vis_results_new/ /LLAVA_Stage1_Pretrained/ /work_dirs/ /llava.egg-info/ /data_preparation/data/ /vis_results/ model_worker* /playground/ *.jsonl *.pth gradio_demo/tmp_files llava_bench_results symmary_results eval_gpt4 vis_results_pdf_precision vis_results_pdf_recall output/ datasets/ output datasets *.log *.json __pycache__/ */__pycache__ */*/__pycache__ */*/*/__pycache__ */*/*/*/__pycache__ gradio_demo/examples/*.mp4 ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. 
For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. 
For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. 
If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. 
You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. 
Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: README.md ================================================ 🌋 LLaVA-Grounding: Grounded Visual Chat with Large Multimodal Models ======== [[Project Page](https://llava-vl.github.io/llava-grounding)] [[Arxiv](https://arxiv.org/abs/2312.02949)] [[Demo](https://llava-grounding.deepdataspace.com/ )] [[Model Zoo](https://github.com/UX-Decoder/LLaVA-Grounding/blob/main/docs/MODEL_ZOO.md)] ## :fire: News [2024/1/14] Our training code is released. [2023/12/6] Our paper is available on arXiv. ## Contents - [🌋 LLaVA-Grounding: Grounded Visual Chat with Large Multimodal Models](#-llava-grounding-grounded-visual-chat-with-large-multimodal-models) - [:fire: News](#fire-news) - [Contents](#contents) - [Install](#install) - [LLaVA-Grounding Weights](#llava-grounding-weights) - [Demo](#demo) - [Training data](#training-data) - [Flickr30k](#flickr30k) - [COCO](#coco) - [LLaVA](#llava) - [Training](#training) - [Citation](#citation) ### Install 1. Clone this repository and navigate to the LLaVA-Grounding folder: ```shell git clone https://github.com/UX-Decoder/LLaVA-Grounding.git cd LLaVA-Grounding ``` 2. Install required packages: ``` conda create -n llava python=3.10 -y conda activate llava pip install --upgrade pip # enable PEP 660 support pip install -e . ``` 3. 
Install additional packages for training cases ``` pip install -e ".[train]" pip install flash-attn --no-build-isolation ``` 4. Install packages necessary for [OpenSeeD](https://github.com/IDEA-Research/OpenSeeD) and [Semantic-SAM](https://github.com/UX-Decoder/Semantic-SAM). ### LLaVA-Grounding Weights Please check out our [Model Zoo](https://github.com/UX-Decoder/LLaVA-Grounding/blob/main/docs/MODEL_ZOO.md) for all public LLaVA-Grounding checkpoints, and the instructions on how to use the weights. ### Demo After downloading the model weights, simply run the following commands to launch the demo on your own machine. ```shell CUDA_VISIBLE_DEVICES=0 python gradio_demo/LLaVA_G_Demo.py --path_vision_cfg path_to_vision_cfg --path_inter_cfg path_to_inter_cfg --model_path path_to_ckpt_dir # for example, after downloading weights into checkpoints/llava_grounding CUDA_VISIBLE_DEVICES=0 python gradio_demo/LLaVA_G_Demo.py --path_vision_cfg configs/openseed/openseed_swint_lang_joint_2st_visual_prompt.yaml --path_inter_cfg configs/semsam/visual_prompt_encoder.yaml --model_path checkpoints/llava_grounding ``` Please refer to our [Online Demo](https://llava-grounding.deepdataspace.com/) for a more detailed user guide. 
### Training data ```text data ├── flickr30k_entities │ ├── train/ │ ├── val/ │ ├── annotations │ ├──final_flickr_separateGT_train.json │ ├──final_flickr_separateGT_val.json ├── coco │ ├── train2014/ │ ├── train2017/ │ ├── panoptic_train2017/ │ ├── panoptic_semseg_train2017/ │ ├── annotations │ │ ├──instances_train2017.json │ │ ├──instances_train2017_gvc.json │ │ ├──grounded_visual_chat_data.json │ │ ├──instances_train2014_filter.json │ │ ├──panoptic_train2017_filter.json │ │ ├──grounding_train2017.json ├── llava │ ├── annotations │ ├── cap600k_brackets_all.json │ ├── llava_instruct_150k.json │ ├── llava_instruct_150k_visual_prompt.json ``` #### Flickr30k Please refer to [MDETR's pre-processed flickr30k data](https://github.com/ashkamath/mdetr/blob/main/.github/flickr.md). #### COCO Please download coco train2014 and train2017 images and panoptic segmentation and semantic segmentation data. Other annotations can be downloaded [here](https://github.com/UX-Decoder/LLaVA-Grounding/releases/tag/train_data). #### LLaVA The processed annotations can be downloaded [here](https://github.com/UX-Decoder/LLaVA-Grounding/releases/tag/train_data). 
### Training Stage 1 ```shell bash scripts/pretrain_joint.sh ``` Stage 2 ```shell bash scripts/finetune.sh ``` Stage 3 ```shell bash scripts/finetune_visual_prompt.sh ``` ### Citation If you find LLaVA-Grounding useful for your research and applications, please cite using this BibTeX: ```bibtex @misc{zhang2023llavagrounding, title={LLaVA-Grounding: Grounded Visual Chat with Large Multimodal Models}, author={Hao Zhang and Hongyang Li and Feng Li and Tianhe Ren and Xueyan Zou and Shilong Liu and Shijia Huang and Jianfeng Gao and Lei Zhang and Chunyuan Li and Jianwei Yang}, year={2023}, booktitle={arXiv} } @misc{liu2023llava, title={Visual Instruction Tuning}, author={Liu, Haotian and Li, Chunyuan and Wu, Qingyang and Lee, Yong Jae}, publisher={arXiv:2304.08485}, year={2023} } ``` ================================================ FILE: configs/openseed/openseed_swint_lang_joint.yaml ================================================ # -------------------------------------------------------- # X-Decoder -- Generalized Decoding for Pixel, Image, and Language # Copyright (c) 2022 Microsoft # Licensed under The MIT License [see LICENSE for details] # Written by Xueyan Zou (xueyan@cs.wisc.edu) # -------------------------------------------------------- ################## # Task settings ################## WEIGHT: '' PORT: 53711 VERBOSE: true #OUTPUT_DIR: '../../data/output/test' inference_only: true OUTPUT_DIR: '../../data/output/test' clip: true # misc LOADER: JOINT: True KEY_DATASET: 'flickr' # model MODEL: NAME: openseed_model HEAD: openseed_head MASK_ON: false KEYPOINT_ON: false LOAD_PROPOSALS: false DIM_PROJ: 4096 BACKBONE_DIM: 768 BACKGROUND: False WEIGHTS: '' TEXT: ARCH: encoder NAME: transformer TOKENIZER: clip CONTEXT_LENGTH: 18 # 18 WIDTH: 512 HEADS: 8 LAYERS: 12 AUTOGRESSIVE: True BACKBONE: NAME: swin PRETRAINED: 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth' LOAD_PRETRAINED: true SWIN: PRETRAIN_IMG_SIZE: 224 
PATCH_SIZE: 4 EMBED_DIM: 96 DEPTHS: [ 2, 2, 6, 2 ] NUM_HEADS: [ 3, 6, 12, 24 ] WINDOW_SIZE: 7 MLP_RATIO: 4.0 QKV_BIAS: true QK_SCALE: ~ DROP_RATE: 0.0 ATTN_DROP_RATE: 0.0 DROP_PATH_RATE: 0.3 APE: false PATCH_NORM: true USE_CHECKPOINT: false OUT_FEATURES: [ 'res2', 'res3', 'res4', 'res5' ] ENCODER: NAME: encoder_deform IGNORE_VALUE: 255 NUM_CLASSES: 133 LOSS_WEIGHT: 1.0 CONVS_DIM: 256 MASK_DIM: 256 NORM: "GN" IN_FEATURES: [ "res2", "res3", "res4", "res5" ] DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES: [ "res3", "res4", "res5" ] COMMON_STRIDE: 4 TRANSFORMER_ENC_LAYERS: 6 TOTAL_NUM_FEATURE_LEVELS: 4 NUM_FEATURE_LEVELS: 3 FEATURE_ORDER: "low2high" DECODER: NAME: openseed_decoder TRANSFORMER_IN_FEATURE: "multi_scale_pixel_decoder" MASK: True BOX: True GROUNDING: ENABLED: False MAX_LEN: 5 TEXT_WEIGHT: 2.0 CLASS_WEIGHT: 0.5 CAPTION: ENABLED: False PHRASE_PROB: 0.0 SIM_THRES: 0.95 CAPTIONING: ENABLED: False STEP: 50 RETRIEVAL: ENABLED: False DIM_IMG: 768 ENSEMBLE: True OPENIMAGE: ENABLED: False NEGATIVE_SAMPLES: 5 GROUNDING: ENABLED: False MAX_LEN: 5 DEEP_SUPERVISION: True NO_OBJECT_WEIGHT: 0.1 CLASS_WEIGHT: 4.0 MASK_WEIGHT: 5.0 DICE_WEIGHT: 5.0 BOX_WEIGHT: 5.0 GIOU_WEIGHT: 2.0 COST_CLASS_WEIGHT: 4.0 COST_DICE_WEIGHT: 5.0 COST_MASK_WEIGHT: 5.0 COST_BOX_WEIGHT: 5.0 COST_GIOU_WEIGHT: 2.0 HIDDEN_DIM: 256 NUM_OBJECT_QUERIES: 300 NHEADS: 8 DROPOUT: 0.0 DIM_FEEDFORWARD: 2048 ENC_LAYERS: 0 PRE_NORM: False ENFORCE_INPUT_PROJ: False SIZE_DIVISIBILITY: 32 DEC_LAYERS: 9 # 9 decoder layers, add one for the loss on learnable query TRAIN_NUM_POINTS: 12544 OVERSAMPLE_RATIO: 3.0 IMPORTANCE_SAMPLE_RATIO: 0.75 TWO_STAGE: True INITIALIZE_BOX_TYPE: 'no' DN: seg DN_NOISE_SCALE: 0.4 DN_NUM: 100 INITIAL_PRED: True LEARN_TGT: False TOTAL_NUM_FEATURE_LEVELS: 4 SEMANTIC_CE_LOSS: False PANO_BOX_LOSS: False COCO: True O365: False TEST: SEMANTIC_ON: True INSTANCE_ON: True PANOPTIC_ON: True OVERLAP_THRESHOLD: 0.8 OBJECT_MASK_THRESHOLD: 0.25 SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE: false 
TEST_FOUCUS_ON_BOX: False PANO_TRANSFORM_EVAL: True PANO_TEMPERATURE: 0.06 TEST: EVAL_PERIOD: 500000 PRECISE_BN: NUM_ITER: 1 ENABLED: False AUG: ENABLED: False SAM: INPUT: MIN_SIZE_TEST: 800 MAX_SIZE_TEST: 1333 IMAGE_SIZE: 1024 MIN_SCALE: 0.99 MAX_SCALE: 1.01 DATASET_MAPPER_NAME: "sam" IGNORE_VALUE: 255 COLOR_AUG_SSD: False SIZE_DIVISIBILITY: 32 RANDOM_FLIP: "horizontal" MASK_FORMAT: "polygon" FORMAT: "RGB" CROP: ENABLED: True DATASET: DATASET: 'sam' TEST: DETECTIONS_PER_IMAGE: 100 NAME: coco_eval IOU_TYPE: ['bbox', 'segm'] USE_MULTISCALE: false BATCH_SIZE_TOTAL: 8 MODEL_FILE: '' AUG: ENABLED: False TRAIN: BATCH_SIZE_TOTAL: 1 BATCH_SIZE_PER_GPU: 1 SHUFFLE: true DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 4 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True COCO: INPUT: MIN_SIZE_TEST: 800 MAX_SIZE_TEST: 1333 IMAGE_SIZE: 1024 MIN_SCALE: 1.0 MAX_SCALE: 1.0 DATASET_MAPPER_NAME: "coco_ref_panoptic_lsj" IGNORE_VALUE: 255 COLOR_AUG_SSD: False SIZE_DIVISIBILITY: 32 RANDOM_FLIP: "horizontal" MASK_FORMAT: "polygon" FORMAT: "RGB" CROP: ENABLED: True DATASET: DATASET: 'coco' TEST: DETECTIONS_PER_IMAGE: 100 NAME: coco_eval IOU_TYPE: ['bbox', 'segm'] USE_MULTISCALE: false BATCH_SIZE_TOTAL: 1 MODEL_FILE: '' AUG: ENABLED: False TRAIN: BATCH_SIZE_TOTAL: 1 BATCH_SIZE_PER_GPU: 1 SHUFFLE: true DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 16 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True VLP: INPUT: IMAGE_SIZE: 224 DATASET_MAPPER_NAME: "vlpretrain" IGNORE_VALUE: 255 COLOR_AUG_SSD: False SIZE_DIVISIBILITY: 32 MASK_FORMAT: "polygon" FORMAT: "RGB" CROP: ENABLED: True TRAIN: BATCH_SIZE_TOTAL: 2 BATCH_SIZE_PER_GPU: 2 TEST: BATCH_SIZE_TOTAL: 256 DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 16 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True INPUT: PIXEL_MEAN: [123.675, 116.280, 103.530] PIXEL_STD: [58.395, 57.120, 57.375] DATASETS: TRAIN: 
["flickr_train","coco_2017_train_panoptic_ref_full_with_sem_seg_caption_grounding"] TEST: ["flickr_val"] CLASS_CONCAT: false SIZE_DIVISIBILITY: 32 PROPOSAL_FILES_TRAIN: [] DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 16 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True # Detectron2 training config for optimizer and lr scheduler SOLVER: BASE_LR_END: 0.0 MOMENTUM: 0.9 NESTEROV: False CHECKPOINT_PERIOD: 5000 IMS_PER_BATCH: 1 REFERENCE_WORLD_SIZE: 0 BIAS_LR_FACTOR: 1.0 WEIGHT_DECAY_BIAS: None # original BASE_LR: 0.0001 STEPS: [327778, 355092] MAX_ITER: 368750 GAMMA: 0.1 WARMUP_FACTOR: 1.0 WARMUP_ITERS: 10 WARMUP_METHOD: "linear" WEIGHT_DECAY: 0.05 OPTIMIZER: "ADAMW" LR_SCHEDULER_NAME: "WarmupMultiStepLR" LR_MULTIPLIER: backbone: 0.1 lang_encoder: 0.1 WEIGHT_DECAY_NORM: 0.0 WEIGHT_DECAY_EMBED: 0.0 CLIP_GRADIENTS: ENABLED: True CLIP_TYPE: "full_model" CLIP_VALUE: 0.01 NORM_TYPE: 2.0 AMP: ENABLED: True # Evaluation Dataset ADE20K: INPUT: MIN_SIZE_TRAIN: [320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 1024, 1088, 1152, 1216, 1280] MIN_SIZE_TRAIN_SAMPLING: "choice" MIN_SIZE_TEST: 640 MAX_SIZE_TRAIN: 2560 MAX_SIZE_TEST: 2560 MASK_FORMAT: "polygon" CROP: ENABLED: True TYPE: "absolute" SIZE: [640, 640] SINGLE_CATEGORY_MAX_AREA: 1.0 IGNORE_VALUE: 255 COLOR_AUG_SSD: True SIZE_DIVISIBILITY: 640 # used in dataset mapper DATASET_MAPPER_NAME: "mask_former_panoptic" FORMAT: "RGB" DATASET: DATASET: 'ade' TRAIN: ASPECT_RATIO_GROUPING: true BATCH_SIZE_TOTAL: 16 BATCH_SIZE_PER_GPU: 2 SHUFFLE: true TEST: DETECTIONS_PER_IMAGE: 100 NAME: coco_eval IOU_TYPE: ['bbox', 'segm'] USE_MULTISCALE: false BATCH_SIZE_TOTAL: 8 MODEL_FILE: '' AUG: ENABLED: False DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 16 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True REF: INPUT: PIXEL_MEAN: [123.675, 116.280, 103.530] PIXEL_STD: [58.395, 57.120, 57.375] MIN_SIZE_TEST: 512 MAX_SIZE_TEST: 1024 FORMAT: "RGB" 
DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 16 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: False TEST: BATCH_SIZE_TOTAL: 8 SUN: INPUT: PIXEL_MEAN: [123.675, 116.280, 103.530] PIXEL_STD: [58.395, 57.120, 57.375] MIN_SIZE_TEST: 512 MAX_SIZE_TEST: 1024 DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 16 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: False TEST: BATCH_SIZE_TOTAL: 8 SCAN: INPUT: PIXEL_MEAN: [123.675, 116.280, 103.530] PIXEL_STD: [58.395, 57.120, 57.375] MIN_SIZE_TEST: 512 MAX_SIZE_TEST: 1024 DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 16 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: False TEST: BATCH_SIZE_TOTAL: 8 BDD: INPUT: PIXEL_MEAN: [123.675, 116.280, 103.530] PIXEL_STD: [58.395, 57.120, 57.375] MIN_SIZE_TEST: 800 MAX_SIZE_TEST: 1333 DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 16 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: False TEST: BATCH_SIZE_TOTAL: 8 CITY: INPUT: MIN_SIZE_TRAIN: [ 512, 614, 716, 819, 921, 1024, 1126, 1228, 1331, 1433, 1536, 1638, 1740, 1843, 1945, 2048 ] MIN_SIZE_TRAIN_SAMPLING: "choice" MIN_SIZE_TEST: 1024 MAX_SIZE_TRAIN: 4096 MAX_SIZE_TEST: 2048 CROP: ENABLED: True TYPE: "absolute" SIZE: [ 512, 1024 ] SINGLE_CATEGORY_MAX_AREA: 1.0 IGNORE_VALUE: 255 COLOR_AUG_SSD: True SIZE_DIVISIBILITY: -1 FORMAT: "RGB" DATASET_MAPPER_NAME: "mask_former_panoptic" MASK_FORMAT: "polygon" TEST: EVAL_PERIOD: 5000 BATCH_SIZE_TOTAL: 1 AUG: ENABLED: False MIN_SIZES: [ 512, 768, 1024, 1280, 1536, 1792 ] MAX_SIZE: 4096 FLIP: True DATALOADER: FILTER_EMPTY_ANNOTATIONS: True NUM_WORKERS: 16 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True TRAIN: ASPECT_RATIO_GROUPING: true BATCH_SIZE_TOTAL: 2 BATCH_SIZE_PER_GPU: 2 SHUFFLE: true PSACAL_PART: INPUT: MIN_SIZE_TEST: 800 MAX_SIZE_TEST: 1333 IMAGE_SIZE: 1024 MIN_SCALE: 0.1 MAX_SCALE: 2.0 
DATASET_MAPPER_NAME: "pascal_part_lsj" IGNORE_VALUE: 255 COLOR_AUG_SSD: False SIZE_DIVISIBILITY: 32 RANDOM_FLIP: "horizontal" MASK_FORMAT: "polygon" FORMAT: "RGB" CROP: ENABLED: True MODEL: MASK_ON: True KEYPOINT_ON: False LOAD_PROPOSALS: False # DATASET: # DATASET: 'coco' TEST: DETECTIONS_PER_IMAGE: 100 NAME: coco_eval IOU_TYPE: ['bbox', 'segm'] USE_MULTISCALE: false BATCH_SIZE_TOTAL: 8 MODEL_FILE: '' AUG: ENABLED: False TRAIN: BATCH_SIZE_TOTAL: 1 BATCH_SIZE_PER_GPU: 1 SHUFFLE: true DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 16 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True llava: INPUT: MIN_SIZE_TEST: 800 MAX_SIZE_TEST: 1333 IMAGE_SIZE: 1024 MIN_SCALE: 1.0 MAX_SCALE: 1.0 DATASET_MAPPER_NAME: "llava" IGNORE_VALUE: 255 COLOR_AUG_SSD: False SIZE_DIVISIBILITY: 32 RANDOM_FLIP: "horizontal" MASK_FORMAT: "polygon" FORMAT: "RGB" CROP: ENABLED: True TEST: DETECTIONS_PER_IMAGE: 100 NAME: coco_eval IOU_TYPE: ['bbox', 'segm'] USE_MULTISCALE: false BATCH_SIZE_TOTAL: 1 MODEL_FILE: '' AUG: ENABLED: False TRAIN: BATCH_SIZE_TOTAL: 1 BATCH_SIZE_PER_GPU: 1 SHUFFLE: true DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 16 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True flickr: INPUT: MIN_SIZE_TEST: 800 MAX_SIZE_TEST: 1333 IMAGE_SIZE: 1024 MIN_SCALE: 1.0 MAX_SCALE: 1.0 DATASET_MAPPER_NAME: "flickr" IGNORE_VALUE: 255 COLOR_AUG_SSD: False SIZE_DIVISIBILITY: 32 RANDOM_FLIP: "horizontal" MASK_FORMAT: "polygon" FORMAT: "RGB" CROP: ENABLED: True TEST: DETECTIONS_PER_IMAGE: 100 NAME: coco_eval IOU_TYPE: ['bbox', 'segm'] USE_MULTISCALE: false BATCH_SIZE_TOTAL: 1 MODEL_FILE: '' AUG: ENABLED: False TRAIN: BATCH_SIZE_TOTAL: 1 BATCH_SIZE_PER_GPU: 1 SHUFFLE: true DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 16 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True vg: INPUT: MIN_SIZE_TEST: 800 MAX_SIZE_TEST: 1333 IMAGE_SIZE: 1024 MIN_SCALE: 1.0 MAX_SCALE: 1.0 
DATASET_MAPPER_NAME: "vg" IGNORE_VALUE: 255 COLOR_AUG_SSD: False SIZE_DIVISIBILITY: 32 RANDOM_FLIP: "horizontal" MASK_FORMAT: "polygon" FORMAT: "RGB" CROP: ENABLED: True TEST: DETECTIONS_PER_IMAGE: 100 NAME: coco_eval IOU_TYPE: ['bbox', 'segm'] USE_MULTISCALE: false BATCH_SIZE_TOTAL: 1 MODEL_FILE: '' AUG: ENABLED: False TRAIN: BATCH_SIZE_TOTAL: 1 BATCH_SIZE_PER_GPU: 1 SHUFFLE: true DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 16 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True ================================================ FILE: configs/openseed/openseed_swint_lang_joint_2st.yaml ================================================ # -------------------------------------------------------- # X-Decoder -- Generalized Decoding for Pixel, Image, and Language # Copyright (c) 2022 Microsoft # Licensed under The MIT License [see LICENSE for details] # Written by Xueyan Zou (xueyan@cs.wisc.edu) # -------------------------------------------------------- ################## # Task settings ################## WEIGHT: '' PORT: 53711 detach_seg: False VERBOSE: true #OUTPUT_DIR: '../../data/output/test' inference_only: true OUTPUT_DIR: '../../data/output/test' clip: true # misc LOADER: JOINT: True KEY_DATASET: 'flickr' # model MODEL: NAME: openseed_model HEAD: openseed_head MASK_ON: false KEYPOINT_ON: false LOAD_PROPOSALS: false DIM_PROJ: 4096 BACKBONE_DIM: 768 BACKGROUND: False WEIGHTS: '' TEXT: ARCH: encoder NAME: transformer TOKENIZER: clip CONTEXT_LENGTH: 18 # 18 WIDTH: 512 HEADS: 8 LAYERS: 12 AUTOGRESSIVE: True BACKBONE: NAME: swin PRETRAINED: 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth' LOAD_PRETRAINED: true SWIN: PRETRAIN_IMG_SIZE: 224 PATCH_SIZE: 4 EMBED_DIM: 96 DEPTHS: [ 2, 2, 6, 2 ] NUM_HEADS: [ 3, 6, 12, 24 ] WINDOW_SIZE: 7 MLP_RATIO: 4.0 QKV_BIAS: true QK_SCALE: ~ DROP_RATE: 0.0 ATTN_DROP_RATE: 0.0 DROP_PATH_RATE: 0.3 APE: false PATCH_NORM: true USE_CHECKPOINT: false 
OUT_FEATURES: [ 'res2', 'res3', 'res4', 'res5' ] ENCODER: NAME: encoder_deform IGNORE_VALUE: 255 NUM_CLASSES: 133 LOSS_WEIGHT: 1.0 CONVS_DIM: 256 MASK_DIM: 256 NORM: "GN" IN_FEATURES: [ "res2", "res3", "res4", "res5" ] DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES: [ "res3", "res4", "res5" ] COMMON_STRIDE: 4 TRANSFORMER_ENC_LAYERS: 6 TOTAL_NUM_FEATURE_LEVELS: 4 NUM_FEATURE_LEVELS: 3 FEATURE_ORDER: "low2high" DECODER: NAME: openseed_decoder TRANSFORMER_IN_FEATURE: "multi_scale_pixel_decoder" MASK: True BOX: True COCO_ONLY: True GROUNDING: ENABLED: False MAX_LEN: 5 TEXT_WEIGHT: 2.0 CLASS_WEIGHT: 0.5 CAPTION: ENABLED: False PHRASE_PROB: 0.0 SIM_THRES: 0.95 CAPTIONING: ENABLED: False STEP: 50 RETRIEVAL: ENABLED: False DIM_IMG: 768 ENSEMBLE: True OPENIMAGE: ENABLED: False NEGATIVE_SAMPLES: 5 GROUNDING: ENABLED: False MAX_LEN: 5 DEEP_SUPERVISION: True NO_OBJECT_WEIGHT: 0.1 CLASS_WEIGHT: 4.0 MASK_WEIGHT: 5.0 DICE_WEIGHT: 5.0 BOX_WEIGHT: 5.0 GIOU_WEIGHT: 2.0 LLM_WEIGHT: 1.0 WEIGHT_MULTIPLIER: 1.0 COST_CLASS_WEIGHT: 4.0 COST_DICE_WEIGHT: 5.0 COST_MASK_WEIGHT: 5.0 COST_BOX_WEIGHT: 5.0 COST_GIOU_WEIGHT: 2.0 HIDDEN_DIM: 256 NUM_OBJECT_QUERIES: 300 NHEADS: 8 DROPOUT: 0.0 DIM_FEEDFORWARD: 2048 ENC_LAYERS: 0 PRE_NORM: False ENFORCE_INPUT_PROJ: False SIZE_DIVISIBILITY: 32 DEC_LAYERS: 9 # 9 decoder layers, add one for the loss on learnable query TRAIN_NUM_POINTS: 12544 OVERSAMPLE_RATIO: 3.0 IMPORTANCE_SAMPLE_RATIO: 0.75 TWO_STAGE: True INITIALIZE_BOX_TYPE: 'no' DN: seg DN_NOISE_SCALE: 0.4 DN_NUM: 100 INITIAL_PRED: True LEARN_TGT: False TOTAL_NUM_FEATURE_LEVELS: 4 SEMANTIC_CE_LOSS: False PANO_BOX_LOSS: False COCO: True O365: False TEST: SEMANTIC_ON: True INSTANCE_ON: True PANOPTIC_ON: True OVERLAP_THRESHOLD: 0.8 OBJECT_MASK_THRESHOLD: 0.25 SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE: false TEST_FOUCUS_ON_BOX: False PANO_TRANSFORM_EVAL: True PANO_TEMPERATURE: 0.06 TEST: EVAL_PERIOD: 500000 PRECISE_BN: NUM_ITER: 1 ENABLED: False AUG: ENABLED: False SAM: INPUT: MIN_SIZE_TEST: 800 
MAX_SIZE_TEST: 1333 IMAGE_SIZE: 1024 MIN_SCALE: 0.99 MAX_SCALE: 1.01 DATASET_MAPPER_NAME: "sam" IGNORE_VALUE: 255 COLOR_AUG_SSD: False SIZE_DIVISIBILITY: 32 RANDOM_FLIP: "horizontal" MASK_FORMAT: "polygon" FORMAT: "RGB" CROP: ENABLED: True DATASET: DATASET: 'sam' TEST: DETECTIONS_PER_IMAGE: 100 NAME: coco_eval IOU_TYPE: ['bbox', 'segm'] USE_MULTISCALE: false BATCH_SIZE_TOTAL: 8 MODEL_FILE: '' AUG: ENABLED: False TRAIN: BATCH_SIZE_TOTAL: 1 BATCH_SIZE_PER_GPU: 1 SHUFFLE: true DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 4 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True COCO: INPUT: MIN_SIZE_TEST: 800 MAX_SIZE_TEST: 1333 IMAGE_SIZE: 1024 MIN_SCALE: 1.0 MAX_SCALE: 1.0 DATASET_MAPPER_NAME: "coco_ref_panoptic_lsj" IGNORE_VALUE: 255 COLOR_AUG_SSD: False SIZE_DIVISIBILITY: 32 RANDOM_FLIP: "horizontal" MASK_FORMAT: "polygon" FORMAT: "RGB" CROP: ENABLED: True DATASET: DATASET: 'coco' TEST: DETECTIONS_PER_IMAGE: 100 NAME: coco_eval IOU_TYPE: ['bbox', 'segm'] USE_MULTISCALE: false BATCH_SIZE_TOTAL: 1 MODEL_FILE: '' AUG: ENABLED: False TRAIN: BATCH_SIZE_TOTAL: 1 BATCH_SIZE_PER_GPU: 1 SHUFFLE: true DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 4 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True VLP: INPUT: IMAGE_SIZE: 224 DATASET_MAPPER_NAME: "vlpretrain" IGNORE_VALUE: 255 COLOR_AUG_SSD: False SIZE_DIVISIBILITY: 32 MASK_FORMAT: "polygon" FORMAT: "RGB" CROP: ENABLED: True TRAIN: BATCH_SIZE_TOTAL: 2 BATCH_SIZE_PER_GPU: 2 TEST: BATCH_SIZE_TOTAL: 256 DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 4 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True INPUT: PIXEL_MEAN: [123.675, 116.280, 103.530] PIXEL_STD: [58.395, 57.120, 57.375] DATASETS: TRAIN: ["coco_instruct_train_v3","flickr_train"] DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 4 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True # Detectron2 training 
config for optimizer and lr scheduler SOLVER: BASE_LR_END: 0.0 MOMENTUM: 0.9 NESTEROV: False CHECKPOINT_PERIOD: 5000 IMS_PER_BATCH: 1 REFERENCE_WORLD_SIZE: 0 BIAS_LR_FACTOR: 1.0 WEIGHT_DECAY_BIAS: None # original BASE_LR: 0.0001 STEPS: [327778, 355092] MAX_ITER: 368750 GAMMA: 0.1 WARMUP_FACTOR: 1.0 WARMUP_ITERS: 10 WARMUP_METHOD: "linear" WEIGHT_DECAY: 0.05 OPTIMIZER: "ADAMW" LR_SCHEDULER_NAME: "WarmupMultiStepLR" LR_MULTIPLIER: backbone: 0.1 lang_encoder: 0.1 WEIGHT_DECAY_NORM: 0.0 WEIGHT_DECAY_EMBED: 0.0 CLIP_GRADIENTS: ENABLED: True CLIP_TYPE: "full_model" CLIP_VALUE: 0.01 NORM_TYPE: 2.0 AMP: ENABLED: True # Evaluation Dataset ADE20K: INPUT: MIN_SIZE_TRAIN: [320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 1024, 1088, 1152, 1216, 1280] MIN_SIZE_TRAIN_SAMPLING: "choice" MIN_SIZE_TEST: 640 MAX_SIZE_TRAIN: 2560 MAX_SIZE_TEST: 2560 MASK_FORMAT: "polygon" CROP: ENABLED: True TYPE: "absolute" SIZE: [640, 640] SINGLE_CATEGORY_MAX_AREA: 1.0 IGNORE_VALUE: 255 COLOR_AUG_SSD: True SIZE_DIVISIBILITY: 640 # used in dataset mapper DATASET_MAPPER_NAME: "mask_former_panoptic" FORMAT: "RGB" DATASET: DATASET: 'ade' TRAIN: ASPECT_RATIO_GROUPING: true BATCH_SIZE_TOTAL: 16 BATCH_SIZE_PER_GPU: 2 SHUFFLE: true TEST: DETECTIONS_PER_IMAGE: 100 NAME: coco_eval IOU_TYPE: ['bbox', 'segm'] USE_MULTISCALE: false BATCH_SIZE_TOTAL: 8 MODEL_FILE: '' AUG: ENABLED: False DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 4 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True REF: INPUT: PIXEL_MEAN: [123.675, 116.280, 103.530] PIXEL_STD: [58.395, 57.120, 57.375] MIN_SIZE_TEST: 512 MAX_SIZE_TEST: 1024 FORMAT: "RGB" DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 4 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: False TEST: BATCH_SIZE_TOTAL: 8 SUN: INPUT: PIXEL_MEAN: [123.675, 116.280, 103.530] PIXEL_STD: [58.395, 57.120, 57.375] MIN_SIZE_TEST: 512 MAX_SIZE_TEST: 1024 DATALOADER: FILTER_EMPTY_ANNOTATIONS: False 
NUM_WORKERS: 4 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: False TEST: BATCH_SIZE_TOTAL: 8 SCAN: INPUT: PIXEL_MEAN: [123.675, 116.280, 103.530] PIXEL_STD: [58.395, 57.120, 57.375] MIN_SIZE_TEST: 512 MAX_SIZE_TEST: 1024 DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 4 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: False TEST: BATCH_SIZE_TOTAL: 8 BDD: INPUT: PIXEL_MEAN: [123.675, 116.280, 103.530] PIXEL_STD: [58.395, 57.120, 57.375] MIN_SIZE_TEST: 800 MAX_SIZE_TEST: 1333 DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 4 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: False TEST: BATCH_SIZE_TOTAL: 8 CITY: INPUT: MIN_SIZE_TRAIN: [ 512, 614, 716, 819, 921, 1024, 1126, 1228, 1331, 1433, 1536, 1638, 1740, 1843, 1945, 2048 ] MIN_SIZE_TRAIN_SAMPLING: "choice" MIN_SIZE_TEST: 1024 MAX_SIZE_TRAIN: 4096 MAX_SIZE_TEST: 2048 CROP: ENABLED: True TYPE: "absolute" SIZE: [ 512, 1024 ] SINGLE_CATEGORY_MAX_AREA: 1.0 IGNORE_VALUE: 255 COLOR_AUG_SSD: True SIZE_DIVISIBILITY: -1 FORMAT: "RGB" DATASET_MAPPER_NAME: "mask_former_panoptic" MASK_FORMAT: "polygon" TEST: EVAL_PERIOD: 5000 BATCH_SIZE_TOTAL: 1 AUG: ENABLED: False MIN_SIZES: [ 512, 768, 1024, 1280, 1536, 1792 ] MAX_SIZE: 4096 FLIP: True DATALOADER: FILTER_EMPTY_ANNOTATIONS: True NUM_WORKERS: 4 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True TRAIN: ASPECT_RATIO_GROUPING: true BATCH_SIZE_TOTAL: 2 BATCH_SIZE_PER_GPU: 2 SHUFFLE: true PSACAL_PART: INPUT: MIN_SIZE_TEST: 800 MAX_SIZE_TEST: 1333 IMAGE_SIZE: 1024 MIN_SCALE: 0.1 MAX_SCALE: 2.0 DATASET_MAPPER_NAME: "pascal_part_lsj" IGNORE_VALUE: 255 COLOR_AUG_SSD: False SIZE_DIVISIBILITY: 32 RANDOM_FLIP: "horizontal" MASK_FORMAT: "polygon" FORMAT: "RGB" CROP: ENABLED: True MODEL: MASK_ON: True KEYPOINT_ON: False LOAD_PROPOSALS: False # DATASET: # DATASET: 'coco' TEST: DETECTIONS_PER_IMAGE: 100 NAME: coco_eval IOU_TYPE: ['bbox', 'segm'] 
USE_MULTISCALE: false BATCH_SIZE_TOTAL: 8 MODEL_FILE: '' AUG: ENABLED: False TRAIN: BATCH_SIZE_TOTAL: 1 BATCH_SIZE_PER_GPU: 1 SHUFFLE: true DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 4 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True llava: INPUT: MIN_SIZE_TEST: 800 MAX_SIZE_TEST: 1333 IMAGE_SIZE: 1024 MIN_SCALE: 1.0 MAX_SCALE: 1.0 DATASET_MAPPER_NAME: "llava" IGNORE_VALUE: 255 COLOR_AUG_SSD: False SIZE_DIVISIBILITY: 32 RANDOM_FLIP: "horizontal" MASK_FORMAT: "polygon" FORMAT: "RGB" CROP: ENABLED: True TEST: DETECTIONS_PER_IMAGE: 100 NAME: coco_eval IOU_TYPE: ['bbox', 'segm'] USE_MULTISCALE: false BATCH_SIZE_TOTAL: 1 MODEL_FILE: '' AUG: ENABLED: False TRAIN: BATCH_SIZE_TOTAL: 1 BATCH_SIZE_PER_GPU: 1 SHUFFLE: true DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 4 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True flickr: INPUT: MIN_SIZE_TEST: 800 MAX_SIZE_TEST: 1333 IMAGE_SIZE: 1024 MIN_SCALE: 1.0 MAX_SCALE: 1.0 DATASET_MAPPER_NAME: "flickr" IGNORE_VALUE: 255 COLOR_AUG_SSD: False SIZE_DIVISIBILITY: 32 RANDOM_FLIP: "horizontal" MASK_FORMAT: "polygon" FORMAT: "RGB" CROP: ENABLED: True TEST: DETECTIONS_PER_IMAGE: 100 NAME: coco_eval IOU_TYPE: ['bbox', 'segm'] USE_MULTISCALE: false BATCH_SIZE_TOTAL: 1 MODEL_FILE: '' AUG: ENABLED: False TRAIN: BATCH_SIZE_TOTAL: 1 BATCH_SIZE_PER_GPU: 1 SHUFFLE: true DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 4 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True coco_instruct: INPUT: MIN_SIZE_TEST: 800 MAX_SIZE_TEST: 1333 IMAGE_SIZE: 1024 MIN_SCALE: 1.0 MAX_SCALE: 1.0 DATASET_MAPPER_NAME: "coco_instruct" IGNORE_VALUE: 255 COLOR_AUG_SSD: False SIZE_DIVISIBILITY: 32 RANDOM_FLIP: "horizontal" MASK_FORMAT: "polygon" FORMAT: "RGB" CROP: ENABLED: True TEST: DETECTIONS_PER_IMAGE: 100 NAME: coco_eval IOU_TYPE: ['bbox', 'segm'] USE_MULTISCALE: false BATCH_SIZE_TOTAL: 1 MODEL_FILE: '' AUG: ENABLED: False TRAIN: 
BATCH_SIZE_TOTAL: 1 BATCH_SIZE_PER_GPU: 1 SHUFFLE: true DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 4 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True vg: INPUT: MIN_SIZE_TEST: 800 MAX_SIZE_TEST: 1333 IMAGE_SIZE: 1024 MIN_SCALE: 1.0 MAX_SCALE: 1.0 DATASET_MAPPER_NAME: "vg" IGNORE_VALUE: 255 COLOR_AUG_SSD: False SIZE_DIVISIBILITY: 32 RANDOM_FLIP: "horizontal" MASK_FORMAT: "polygon" FORMAT: "RGB" CROP: ENABLED: True TEST: DETECTIONS_PER_IMAGE: 100 NAME: coco_eval IOU_TYPE: ['bbox', 'segm'] USE_MULTISCALE: false BATCH_SIZE_TOTAL: 1 MODEL_FILE: '' AUG: ENABLED: False TRAIN: BATCH_SIZE_TOTAL: 1 BATCH_SIZE_PER_GPU: 1 SHUFFLE: true DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 4 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True ================================================ FILE: configs/openseed/openseed_swint_lang_joint_2st_visual_prompt.yaml ================================================ # -------------------------------------------------------- # X-Decoder -- Generalized Decoding for Pixel, Image, and Language # Copyright (c) 2022 Microsoft # Licensed under The MIT License [see LICENSE for details] # Written by Xueyan Zou (xueyan@cs.wisc.edu) # -------------------------------------------------------- ################## # Task settings ################## WEIGHT: '' PORT: 53711 detach_seg: False VERBOSE: true #OUTPUT_DIR: '../../data/output/test' inference_only: true OUTPUT_DIR: '../../data/output/test' clip: true # misc LOADER: JOINT: True KEY_DATASET: 'flickr' # model MODEL: NAME: openseed_model HEAD: openseed_head MASK_ON: false KEYPOINT_ON: false LOAD_PROPOSALS: false DIM_PROJ: 4096 BACKBONE_DIM: 768 BACKGROUND: False WEIGHTS: '' TEXT: ARCH: encoder NAME: transformer TOKENIZER: clip CONTEXT_LENGTH: 18 # 18 WIDTH: 512 HEADS: 8 LAYERS: 12 AUTOGRESSIVE: True BACKBONE: NAME: swin PRETRAINED: 
'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth' LOAD_PRETRAINED: true SWIN: PRETRAIN_IMG_SIZE: 224 PATCH_SIZE: 4 EMBED_DIM: 96 DEPTHS: [ 2, 2, 6, 2 ] NUM_HEADS: [ 3, 6, 12, 24 ] WINDOW_SIZE: 7 MLP_RATIO: 4.0 QKV_BIAS: true QK_SCALE: ~ DROP_RATE: 0.0 ATTN_DROP_RATE: 0.0 DROP_PATH_RATE: 0.3 APE: false PATCH_NORM: true USE_CHECKPOINT: false OUT_FEATURES: [ 'res2', 'res3', 'res4', 'res5' ] ENCODER: NAME: encoder_deform IGNORE_VALUE: 255 NUM_CLASSES: 133 LOSS_WEIGHT: 1.0 CONVS_DIM: 256 MASK_DIM: 256 NORM: "GN" IN_FEATURES: [ "res2", "res3", "res4", "res5" ] DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES: [ "res3", "res4", "res5" ] COMMON_STRIDE: 4 TRANSFORMER_ENC_LAYERS: 6 TOTAL_NUM_FEATURE_LEVELS: 4 NUM_FEATURE_LEVELS: 3 FEATURE_ORDER: "low2high" DECODER: NAME: openseed_decoder TRANSFORMER_IN_FEATURE: "multi_scale_pixel_decoder" MASK: True BOX: True COCO_ONLY: True GROUNDING: ENABLED: False MAX_LEN: 5 TEXT_WEIGHT: 2.0 CLASS_WEIGHT: 0.5 CAPTION: ENABLED: False PHRASE_PROB: 0.0 SIM_THRES: 0.95 CAPTIONING: ENABLED: False STEP: 50 RETRIEVAL: ENABLED: False DIM_IMG: 768 ENSEMBLE: True OPENIMAGE: ENABLED: False NEGATIVE_SAMPLES: 5 GROUNDING: ENABLED: False MAX_LEN: 5 DEEP_SUPERVISION: True NO_OBJECT_WEIGHT: 0.1 CLASS_WEIGHT: 4.0 MASK_WEIGHT: 5.0 DICE_WEIGHT: 5.0 BOX_WEIGHT: 5.0 GIOU_WEIGHT: 2.0 LLM_WEIGHT: 1.0 WEIGHT_MULTIPLIER: 1.0 COST_CLASS_WEIGHT: 4.0 COST_DICE_WEIGHT: 5.0 COST_MASK_WEIGHT: 5.0 COST_BOX_WEIGHT: 5.0 COST_GIOU_WEIGHT: 2.0 HIDDEN_DIM: 256 NUM_OBJECT_QUERIES: 300 NHEADS: 8 DROPOUT: 0.0 DIM_FEEDFORWARD: 2048 ENC_LAYERS: 0 PRE_NORM: False ENFORCE_INPUT_PROJ: False SIZE_DIVISIBILITY: 32 DEC_LAYERS: 9 # 9 decoder layers, add one for the loss on learnable query TRAIN_NUM_POINTS: 12544 OVERSAMPLE_RATIO: 3.0 IMPORTANCE_SAMPLE_RATIO: 0.75 TWO_STAGE: True INITIALIZE_BOX_TYPE: 'no' DN: seg DN_NOISE_SCALE: 0.4 DN_NUM: 100 INITIAL_PRED: True LEARN_TGT: False TOTAL_NUM_FEATURE_LEVELS: 4 SEMANTIC_CE_LOSS: False 
PANO_BOX_LOSS: False COCO: True O365: False TEST: SEMANTIC_ON: True INSTANCE_ON: True PANOPTIC_ON: True OVERLAP_THRESHOLD: 0.8 OBJECT_MASK_THRESHOLD: 0.25 SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE: false TEST_FOUCUS_ON_BOX: False PANO_TRANSFORM_EVAL: True PANO_TEMPERATURE: 0.06 TEST: EVAL_PERIOD: 500000 PRECISE_BN: NUM_ITER: 1 ENABLED: False AUG: ENABLED: False SAM: INPUT: MIN_SIZE_TEST: 800 MAX_SIZE_TEST: 1333 IMAGE_SIZE: 1024 MIN_SCALE: 0.99 MAX_SCALE: 1.01 DATASET_MAPPER_NAME: "sam" IGNORE_VALUE: 255 COLOR_AUG_SSD: False SIZE_DIVISIBILITY: 32 RANDOM_FLIP: "horizontal" MASK_FORMAT: "polygon" FORMAT: "RGB" CROP: ENABLED: True DATASET: DATASET: 'sam' TEST: DETECTIONS_PER_IMAGE: 100 NAME: coco_eval IOU_TYPE: ['bbox', 'segm'] USE_MULTISCALE: false BATCH_SIZE_TOTAL: 8 MODEL_FILE: '' AUG: ENABLED: False TRAIN: BATCH_SIZE_TOTAL: 1 BATCH_SIZE_PER_GPU: 1 SHUFFLE: true DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 4 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True COCO: INPUT: MIN_SIZE_TEST: 800 MAX_SIZE_TEST: 1333 IMAGE_SIZE: 1024 MIN_SCALE: 1.0 MAX_SCALE: 1.0 DATASET_MAPPER_NAME: "coco_ref_panoptic_lsj" IGNORE_VALUE: 255 COLOR_AUG_SSD: False SIZE_DIVISIBILITY: 32 RANDOM_FLIP: "horizontal" MASK_FORMAT: "polygon" FORMAT: "RGB" CROP: ENABLED: True DATASET: DATASET: 'coco' TEST: DETECTIONS_PER_IMAGE: 100 NAME: coco_eval IOU_TYPE: ['bbox', 'segm'] USE_MULTISCALE: false BATCH_SIZE_TOTAL: 1 MODEL_FILE: '' AUG: ENABLED: False TRAIN: BATCH_SIZE_TOTAL: 1 BATCH_SIZE_PER_GPU: 1 SHUFFLE: true DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 4 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True VLP: INPUT: IMAGE_SIZE: 224 DATASET_MAPPER_NAME: "vlpretrain" IGNORE_VALUE: 255 COLOR_AUG_SSD: False SIZE_DIVISIBILITY: 32 MASK_FORMAT: "polygon" FORMAT: "RGB" CROP: ENABLED: True TRAIN: BATCH_SIZE_TOTAL: 2 BATCH_SIZE_PER_GPU: 2 TEST: BATCH_SIZE_TOTAL: 256 DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 4 
LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True INPUT: PIXEL_MEAN: [123.675, 116.280, 103.530] PIXEL_STD: [58.395, 57.120, 57.375] DATASETS: TRAIN: ["coco_interactive_refcoco","coco_interactive","flickr_train"] DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 4 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True # Detectron2 training config for optimizer and lr scheduler SOLVER: BASE_LR_END: 0.0 MOMENTUM: 0.9 NESTEROV: False CHECKPOINT_PERIOD: 5000 IMS_PER_BATCH: 1 REFERENCE_WORLD_SIZE: 0 BIAS_LR_FACTOR: 1.0 WEIGHT_DECAY_BIAS: None # original BASE_LR: 0.0001 STEPS: [327778, 355092] MAX_ITER: 368750 GAMMA: 0.1 WARMUP_FACTOR: 1.0 WARMUP_ITERS: 10 WARMUP_METHOD: "linear" WEIGHT_DECAY: 0.05 OPTIMIZER: "ADAMW" LR_SCHEDULER_NAME: "WarmupMultiStepLR" LR_MULTIPLIER: backbone: 0.1 lang_encoder: 0.1 WEIGHT_DECAY_NORM: 0.0 WEIGHT_DECAY_EMBED: 0.0 CLIP_GRADIENTS: ENABLED: True CLIP_TYPE: "full_model" CLIP_VALUE: 0.01 NORM_TYPE: 2.0 AMP: ENABLED: True # Evaluation Dataset ADE20K: INPUT: MIN_SIZE_TRAIN: [320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 1024, 1088, 1152, 1216, 1280] MIN_SIZE_TRAIN_SAMPLING: "choice" MIN_SIZE_TEST: 640 MAX_SIZE_TRAIN: 2560 MAX_SIZE_TEST: 2560 MASK_FORMAT: "polygon" CROP: ENABLED: True TYPE: "absolute" SIZE: [640, 640] SINGLE_CATEGORY_MAX_AREA: 1.0 IGNORE_VALUE: 255 COLOR_AUG_SSD: True SIZE_DIVISIBILITY: 640 # used in dataset mapper DATASET_MAPPER_NAME: "mask_former_panoptic" FORMAT: "RGB" DATASET: DATASET: 'ade' TRAIN: ASPECT_RATIO_GROUPING: true BATCH_SIZE_TOTAL: 16 BATCH_SIZE_PER_GPU: 2 SHUFFLE: true TEST: DETECTIONS_PER_IMAGE: 100 NAME: coco_eval IOU_TYPE: ['bbox', 'segm'] USE_MULTISCALE: false BATCH_SIZE_TOTAL: 8 MODEL_FILE: '' AUG: ENABLED: False DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 4 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True REF: INPUT: PIXEL_MEAN: [123.675, 116.280, 103.530] PIXEL_STD: [58.395, 57.120, 
57.375] MIN_SIZE_TEST: 512 MAX_SIZE_TEST: 1024 FORMAT: "RGB" DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 4 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: False TEST: BATCH_SIZE_TOTAL: 8 SUN: INPUT: PIXEL_MEAN: [123.675, 116.280, 103.530] PIXEL_STD: [58.395, 57.120, 57.375] MIN_SIZE_TEST: 512 MAX_SIZE_TEST: 1024 DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 4 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: False TEST: BATCH_SIZE_TOTAL: 8 SCAN: INPUT: PIXEL_MEAN: [123.675, 116.280, 103.530] PIXEL_STD: [58.395, 57.120, 57.375] MIN_SIZE_TEST: 512 MAX_SIZE_TEST: 1024 DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 4 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: False TEST: BATCH_SIZE_TOTAL: 8 BDD: INPUT: PIXEL_MEAN: [123.675, 116.280, 103.530] PIXEL_STD: [58.395, 57.120, 57.375] MIN_SIZE_TEST: 800 MAX_SIZE_TEST: 1333 DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 4 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: False TEST: BATCH_SIZE_TOTAL: 8 CITY: INPUT: MIN_SIZE_TRAIN: [ 512, 614, 716, 819, 921, 1024, 1126, 1228, 1331, 1433, 1536, 1638, 1740, 1843, 1945, 2048 ] MIN_SIZE_TRAIN_SAMPLING: "choice" MIN_SIZE_TEST: 1024 MAX_SIZE_TRAIN: 4096 MAX_SIZE_TEST: 2048 CROP: ENABLED: True TYPE: "absolute" SIZE: [ 512, 1024 ] SINGLE_CATEGORY_MAX_AREA: 1.0 IGNORE_VALUE: 255 COLOR_AUG_SSD: True SIZE_DIVISIBILITY: -1 FORMAT: "RGB" DATASET_MAPPER_NAME: "mask_former_panoptic" MASK_FORMAT: "polygon" TEST: EVAL_PERIOD: 5000 BATCH_SIZE_TOTAL: 1 AUG: ENABLED: False MIN_SIZES: [ 512, 768, 1024, 1280, 1536, 1792 ] MAX_SIZE: 4096 FLIP: True DATALOADER: FILTER_EMPTY_ANNOTATIONS: True NUM_WORKERS: 4 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True TRAIN: ASPECT_RATIO_GROUPING: true BATCH_SIZE_TOTAL: 2 BATCH_SIZE_PER_GPU: 2 SHUFFLE: true PSACAL_PART: INPUT: MIN_SIZE_TEST: 800 MAX_SIZE_TEST: 1333 
IMAGE_SIZE: 1024 MIN_SCALE: 0.1 MAX_SCALE: 2.0 DATASET_MAPPER_NAME: "pascal_part_lsj" IGNORE_VALUE: 255 COLOR_AUG_SSD: False SIZE_DIVISIBILITY: 32 RANDOM_FLIP: "horizontal" MASK_FORMAT: "polygon" FORMAT: "RGB" CROP: ENABLED: True MODEL: MASK_ON: True KEYPOINT_ON: False LOAD_PROPOSALS: False # DATASET: # DATASET: 'coco' TEST: DETECTIONS_PER_IMAGE: 100 NAME: coco_eval IOU_TYPE: ['bbox', 'segm'] USE_MULTISCALE: false BATCH_SIZE_TOTAL: 8 MODEL_FILE: '' AUG: ENABLED: False TRAIN: BATCH_SIZE_TOTAL: 1 BATCH_SIZE_PER_GPU: 1 SHUFFLE: true DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 4 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True llava: INPUT: MIN_SIZE_TEST: 800 MAX_SIZE_TEST: 1333 IMAGE_SIZE: 1024 MIN_SCALE: 1.0 MAX_SCALE: 1.0 DATASET_MAPPER_NAME: "llava" IGNORE_VALUE: 255 COLOR_AUG_SSD: False SIZE_DIVISIBILITY: 32 RANDOM_FLIP: "horizontal" MASK_FORMAT: "polygon" FORMAT: "RGB" CROP: ENABLED: True TEST: DETECTIONS_PER_IMAGE: 100 NAME: coco_eval IOU_TYPE: ['bbox', 'segm'] USE_MULTISCALE: false BATCH_SIZE_TOTAL: 1 MODEL_FILE: '' AUG: ENABLED: False TRAIN: BATCH_SIZE_TOTAL: 1 BATCH_SIZE_PER_GPU: 1 SHUFFLE: true DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 4 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True flickr: INPUT: MIN_SIZE_TEST: 800 MAX_SIZE_TEST: 1333 IMAGE_SIZE: 1024 MIN_SCALE: 1.0 MAX_SCALE: 1.0 DATASET_MAPPER_NAME: "flickr" IGNORE_VALUE: 255 COLOR_AUG_SSD: False SIZE_DIVISIBILITY: 32 RANDOM_FLIP: "horizontal" MASK_FORMAT: "polygon" FORMAT: "RGB" CROP: ENABLED: True TEST: DETECTIONS_PER_IMAGE: 100 NAME: coco_eval IOU_TYPE: ['bbox', 'segm'] USE_MULTISCALE: false BATCH_SIZE_TOTAL: 1 MODEL_FILE: '' AUG: ENABLED: False TRAIN: BATCH_SIZE_TOTAL: 1 BATCH_SIZE_PER_GPU: 1 SHUFFLE: true DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 4 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True coco_instruct: INPUT: MIN_SIZE_TEST: 800 MAX_SIZE_TEST: 
1333 IMAGE_SIZE: 1024 MIN_SCALE: 1.0 MAX_SCALE: 1.0 DATASET_MAPPER_NAME: "coco_instruct" IGNORE_VALUE: 255 COLOR_AUG_SSD: False SIZE_DIVISIBILITY: 32 RANDOM_FLIP: "horizontal" MASK_FORMAT: "polygon" FORMAT: "RGB" CROP: ENABLED: True TEST: DETECTIONS_PER_IMAGE: 100 NAME: coco_eval IOU_TYPE: ['bbox', 'segm'] USE_MULTISCALE: false BATCH_SIZE_TOTAL: 1 MODEL_FILE: '' AUG: ENABLED: False TRAIN: BATCH_SIZE_TOTAL: 1 BATCH_SIZE_PER_GPU: 1 SHUFFLE: true DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 4 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True coco_interactive: INPUT: MIN_SIZE_TEST: 800 MAX_SIZE_TEST: 1333 IMAGE_SIZE: 1024 MIN_SCALE: 1.0 MAX_SCALE: 1.0 DATASET_MAPPER_NAME: "coco_interactive" IGNORE_VALUE: 255 COLOR_AUG_SSD: False SIZE_DIVISIBILITY: 32 RANDOM_FLIP: "horizontal" MASK_FORMAT: "polygon" FORMAT: "RGB" CROP: ENABLED: True TEST: DETECTIONS_PER_IMAGE: 100 NAME: coco_eval IOU_TYPE: ['bbox', 'segm'] USE_MULTISCALE: false BATCH_SIZE_TOTAL: 1 MODEL_FILE: '' AUG: ENABLED: False TRAIN: BATCH_SIZE_TOTAL: 1 BATCH_SIZE_PER_GPU: 1 SHUFFLE: true DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 4 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True vg: INPUT: MIN_SIZE_TEST: 800 MAX_SIZE_TEST: 1333 IMAGE_SIZE: 1024 MIN_SCALE: 1.0 MAX_SCALE: 1.0 DATASET_MAPPER_NAME: "vg" IGNORE_VALUE: 255 COLOR_AUG_SSD: False SIZE_DIVISIBILITY: 32 RANDOM_FLIP: "horizontal" MASK_FORMAT: "polygon" FORMAT: "RGB" CROP: ENABLED: True TEST: DETECTIONS_PER_IMAGE: 100 NAME: coco_eval IOU_TYPE: ['bbox', 'segm'] USE_MULTISCALE: false BATCH_SIZE_TOTAL: 1 MODEL_FILE: '' AUG: ENABLED: False TRAIN: BATCH_SIZE_TOTAL: 1 BATCH_SIZE_PER_GPU: 1 SHUFFLE: true DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 4 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True ================================================ FILE: configs/semsam/visual_prompt_encoder.yaml 
================================================ # -------------------------------------------------------- # X-Decoder -- Generalized Decoding for Pixel, Image, and Language # Copyright (c) 2022 Microsoft # Licensed under The MIT License [see LICENSE for details] # Written by Xueyan Zou (xueyan@cs.wisc.edu) # -------------------------------------------------------- ################## # Task settings ################## WEIGHT: '' PORT: 53711 VERBOSE: true #OUTPUT_DIR: '../../data/output/test' inference_only: true OUTPUT_DIR: '../../data/output/test' # misc LOADER: JOINT: True KEY_DATASET: 'coco' # model MODEL: NAME: idino_model_partwhole_all_llm_ref_feats_all_det_pretrainv1 HEAD: openseed_head MASK_ON: false KEYPOINT_ON: false LOAD_PROPOSALS: false DIM_PROJ: 512 BACKBONE_DIM: 768 BACKGROUND: False WEIGHTS: None LLAMA: model_name_or_path: '/comp_robot/liushilong/data/LLAVA/LLAVA_7b' cache_dir: None model_max_length: 2048 hidden_size: 4096 tune_mm_mlp_adapter: True im_width: 16 load_fp16: False lora_r: 0 lora_alpha: 16 lora_dropout: 0.05 TEXT: ARCH: llama_encoder BACKBONE: NAME: swin PRETRAINED: 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth' LOAD_PRETRAINED: true SWIN: PRETRAIN_IMG_SIZE: 224 PATCH_SIZE: 4 EMBED_DIM: 96 DEPTHS: [ 2, 2, 6, 2 ] NUM_HEADS: [ 3, 6, 12, 24 ] WINDOW_SIZE: 7 MLP_RATIO: 4.0 QKV_BIAS: true QK_SCALE: ~ DROP_RATE: 0.0 ATTN_DROP_RATE: 0.0 DROP_PATH_RATE: 0.3 APE: false PATCH_NORM: true USE_CHECKPOINT: false OUT_FEATURES: [ 'res2', 'res3', 'res4', 'res5' ] ENCODER: NAME: encoder_deform IGNORE_VALUE: 255 NUM_CLASSES: 1 LOSS_WEIGHT: 1.0 CONVS_DIM: 256 MASK_DIM: 256 NORM: "GN" IN_FEATURES: [ "res2", "res3", "res4", "res5" ] DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES: [ "res3", "res4", "res5" ] COMMON_STRIDE: 4 TRANSFORMER_ENC_LAYERS: 6 TOTAL_NUM_FEATURE_LEVELS: 4 NUM_FEATURE_LEVELS: 3 FEATURE_ORDER: "low2high" DECODER: NAME: idino_decoder_no_iou_token_partwhole_all_llm TRANSFORMER_IN_FEATURE: 
"multi_scale_pixel_decoder" MASK: True BOX: True PART: True pretrain: True match_loss: True GROUNDING: ENABLED: True MAX_LEN: 5 TEXT_WEIGHT: 2.0 CLASS_WEIGHT: 0.5 CAPTION: ENABLED: True PHRASE_PROB: 0.0 SIM_THRES: 0.95 CAPTIONING: ENABLED: True STEP: 50 RETRIEVAL: ENABLED: True DIM_IMG: 768 ENSEMBLE: True OPENIMAGE: ENABLED: False NEGATIVE_SAMPLES: 5 GROUNDING: ENABLED: False MAX_LEN: 5 DEEP_SUPERVISION: True NO_OBJECT_WEIGHT: 0.1 CLASS_WEIGHT: 4.0 MASK_WEIGHT: 5.0 DICE_WEIGHT: 5.0 BOX_WEIGHT: 5.0 GIOU_WEIGHT: 2.0 IOU_WEIGHT: 1.0 LLAMA_WEIGHT: 5.0 llama_det_weight: 2.0 llama_ref_weight: 1.0 llama_region_cap_weight: 1.0 llama_img_cap_weight: 1.0 llama_gd_weight: 20.0 llama_gd_text_weight: 2.0 REFER_WEIGHT: 5.0 COST_CLASS_WEIGHT: 4.0 COST_DICE_WEIGHT: 5.0 COST_MASK_WEIGHT: 5.0 COST_BOX_WEIGHT: 5.0 COST_GIOU_WEIGHT: 2.0 HIDDEN_DIM: 256 NUM_OBJECT_QUERIES: 0 NHEADS: 8 DROPOUT: 0.0 DIM_FEEDFORWARD: 2048 ENC_LAYERS: 0 PRE_NORM: False ENFORCE_INPUT_PROJ: False SIZE_DIVISIBILITY: 32 DEC_LAYERS: 9 # 9 decoder layers, add one for the loss on learnable query TRAIN_NUM_POINTS: 12544 OVERSAMPLE_RATIO: 3.0 IMPORTANCE_SAMPLE_RATIO: 0.75 TWO_STAGE: False INITIALIZE_BOX_TYPE: 'no' DN: seg DN_NOISE_SCALE: 0.4 DN_NUM: 100 INITIAL_PRED: False LEARN_TGT: False TOTAL_NUM_FEATURE_LEVELS: 4 SEMANTIC_CE_LOSS: False PANO_BOX_LOSS: False COCO: True O365: False SAM: True PASCAL: True RE_POINT: True NUM_INTERACTIVE_TOKENS: 3 TEST: SEMANTIC_ON: True INSTANCE_ON: True PANOPTIC_ON: True OVERLAP_THRESHOLD: 0.8 OBJECT_MASK_THRESHOLD: 0.25 SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE: false TEST_FOUCUS_ON_BOX: False PANO_TRANSFORM_EVAL: True PANO_TEMPERATURE: 0.06 TEST: EVAL_PERIOD: 500000 PRECISE_BN: NUM_ITER: 1 ENABLED: False AUG: ENABLED: False SAM: INPUT: MIN_SIZE_TEST: 800 MAX_SIZE_TEST: 1333 IMAGE_SIZE: 1024 MIN_SCALE: 0.99 MAX_SCALE: 1.01 DATASET_MAPPER_NAME: "sam" IGNORE_VALUE: 255 COLOR_AUG_SSD: False SIZE_DIVISIBILITY: 32 RANDOM_FLIP: "horizontal" MASK_FORMAT: "polygon" FORMAT: "RGB" CROP: 
ENABLED: True DATASET: DATASET: 'sam' TEST: DETECTIONS_PER_IMAGE: 100 NAME: coco_eval IOU_TYPE: ['bbox', 'segm'] USE_MULTISCALE: false BATCH_SIZE_TOTAL: 8 MODEL_FILE: '' AUG: ENABLED: False TRAIN: BATCH_SIZE_TOTAL: 1 BATCH_SIZE_PER_GPU: 1 SHUFFLE: true DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 4 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True COCO: INPUT: MIN_SIZE_TEST: 800 MAX_SIZE_TEST: 1333 IMAGE_SIZE: 1024 MIN_SCALE: 1.0 MAX_SCALE: 1.0 DATASET_MAPPER_NAME: "coco_interactive_panoptic_lsj" IGNORE_VALUE: 255 COLOR_AUG_SSD: False SIZE_DIVISIBILITY: 32 RANDOM_FLIP: "horizontal" MASK_FORMAT: "polygon" FORMAT: "RGB" CROP: ENABLED: True DATASET: DATASET: 'coco' TEST: DETECTIONS_PER_IMAGE: 100 NAME: coco_eval IOU_TYPE: ['bbox', 'segm'] USE_MULTISCALE: false BATCH_SIZE_TOTAL: 1 MODEL_FILE: '' AUG: ENABLED: False TRAIN: BATCH_SIZE_TOTAL: 1 BATCH_SIZE_PER_GPU: 1 SHUFFLE: true DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 2 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True VLP: INPUT: IMAGE_SIZE: 224 DATASET_MAPPER_NAME: "vlpretrain" IGNORE_VALUE: 255 COLOR_AUG_SSD: False SIZE_DIVISIBILITY: 32 MASK_FORMAT: "polygon" FORMAT: "RGB" CROP: ENABLED: True TRAIN: BATCH_SIZE_TOTAL: 2 BATCH_SIZE_PER_GPU: 2 TEST: BATCH_SIZE_TOTAL: 256 DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 16 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True INPUT: PIXEL_MEAN: [123.675, 116.280, 103.530] PIXEL_STD: [58.395, 57.120, 57.375] DATASETS: TRAIN: ["coco_2017_train_panoptic_filtrefgumdval_with_sem_seg_caption_grounding","mapillary_vistas_panoptic_train","ade20k_panoptic_train","sam_train","pascal_part_train","paco_train","partimagenet_train"]#,"sam_train","pascal_part_train"]#,"paco_train","partimagenet_train"] DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 16 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True # Detectron2 
training config for optimizer and lr scheduler SOLVER: BASE_LR_END: 0.0 MOMENTUM: 0.9 NESTEROV: False CHECKPOINT_PERIOD: 5000 IMS_PER_BATCH: 1 REFERENCE_WORLD_SIZE: 0 BIAS_LR_FACTOR: 1.0 WEIGHT_DECAY_BIAS: None # original BASE_LR: 0.0001 STEPS: [327778, 355092] MAX_ITER: 368750 GAMMA: 0.1 WARMUP_FACTOR: 1.0 WARMUP_ITERS: 10 WARMUP_METHOD: "linear" WEIGHT_DECAY: 0.05 OPTIMIZER: "ADAMW" LR_SCHEDULER_NAME: "WarmupMultiStepLR" LR_MULTIPLIER: backbone: 0.1 lang_encoder: 0.1 WEIGHT_DECAY_NORM: 0.0 WEIGHT_DECAY_EMBED: 0.0 CLIP_GRADIENTS: ENABLED: True CLIP_TYPE: "full_model" CLIP_VALUE: 0.01 NORM_TYPE: 2.0 AMP: ENABLED: True # Evaluation Dataset ADE20K: INPUT: MIN_SIZE_TEST: 800 MAX_SIZE_TEST: 1333 IMAGE_SIZE: 1024 MIN_SCALE: 1.0 MAX_SCALE: 1.0 DATASET_MAPPER_NAME: "coco_interactive_panoptic_lsj" IGNORE_VALUE: 255 COLOR_AUG_SSD: False SIZE_DIVISIBILITY: 32 RANDOM_FLIP: "horizontal" MASK_FORMAT: "polygon" FORMAT: "RGB" CROP: ENABLED: True DATASET: DATASET: 'ade' TRAIN: ASPECT_RATIO_GROUPING: true BATCH_SIZE_TOTAL: 1 BATCH_SIZE_PER_GPU: 1 SHUFFLE: true TEST: DETECTIONS_PER_IMAGE: 100 NAME: coco_eval IOU_TYPE: ['bbox', 'segm'] USE_MULTISCALE: false BATCH_SIZE_TOTAL: 8 MODEL_FILE: '' AUG: ENABLED: False DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 8 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 8 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True REF: INPUT: PIXEL_MEAN: [123.675, 116.280, 103.530] PIXEL_STD: [58.395, 57.120, 57.375] MIN_SIZE_TEST: 512 MAX_SIZE_TEST: 1024 FORMAT: "RGB" DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 0 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: False TEST: BATCH_SIZE_TOTAL: 8 SUN: INPUT: PIXEL_MEAN: [123.675, 116.280, 103.530] PIXEL_STD: [58.395, 57.120, 57.375] MIN_SIZE_TEST: 512 MAX_SIZE_TEST: 1024 DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 0 LOAD_PROPOSALS: 
False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: False TEST: BATCH_SIZE_TOTAL: 8 SCAN: INPUT: PIXEL_MEAN: [123.675, 116.280, 103.530] PIXEL_STD: [58.395, 57.120, 57.375] MIN_SIZE_TEST: 512 MAX_SIZE_TEST: 1024 DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 0 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: False TEST: BATCH_SIZE_TOTAL: 8 BDD: INPUT: PIXEL_MEAN: [123.675, 116.280, 103.530] PIXEL_STD: [58.395, 57.120, 57.375] MIN_SIZE_TEST: 800 MAX_SIZE_TEST: 1333 DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 0 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: False TEST: BATCH_SIZE_TOTAL: 8 CITY: INPUT: MIN_SIZE_TRAIN: [ 512, 614, 716, 819, 921, 1024, 1126, 1228, 1331, 1433, 1536, 1638, 1740, 1843, 1945, 2048 ] MIN_SIZE_TRAIN_SAMPLING: "choice" MIN_SIZE_TEST: 1024 MAX_SIZE_TRAIN: 4096 MAX_SIZE_TEST: 2048 CROP: ENABLED: True TYPE: "absolute" SIZE: [ 512, 1024 ] SINGLE_CATEGORY_MAX_AREA: 1.0 IGNORE_VALUE: 255 COLOR_AUG_SSD: True SIZE_DIVISIBILITY: -1 FORMAT: "RGB" DATASET_MAPPER_NAME: "mask_former_panoptic" MASK_FORMAT: "polygon" TEST: EVAL_PERIOD: 5000 BATCH_SIZE_TOTAL: 1 AUG: ENABLED: False MIN_SIZES: [ 512, 768, 1024, 1280, 1536, 1792 ] MAX_SIZE: 4096 FLIP: True DATALOADER: FILTER_EMPTY_ANNOTATIONS: True NUM_WORKERS: 2 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True TRAIN: ASPECT_RATIO_GROUPING: true BATCH_SIZE_TOTAL: 2 BATCH_SIZE_PER_GPU: 2 SHUFFLE: true PSACAL_PART: INPUT: MIN_SIZE_TEST: 800 MAX_SIZE_TEST: 1333 IMAGE_SIZE: 1024 MIN_SCALE: 1.0 MAX_SCALE: 1.0 DATASET_MAPPER_NAME: "pascal_part_lsj" IGNORE_VALUE: 255 COLOR_AUG_SSD: False SIZE_DIVISIBILITY: 32 RANDOM_FLIP: "horizontal" MASK_FORMAT: "polygon" FORMAT: "RGB" CROP: ENABLED: True MODEL: MASK_ON: True KEYPOINT_ON: False LOAD_PROPOSALS: False # DATASET: # DATASET: 'coco' TEST: DETECTIONS_PER_IMAGE: 100 NAME: coco_eval IOU_TYPE: ['bbox', 'segm'] USE_MULTISCALE: false BATCH_SIZE_TOTAL: 8 
MODEL_FILE: '' AUG: ENABLED: False TRAIN: BATCH_SIZE_TOTAL: 1 BATCH_SIZE_PER_GPU: 1 SHUFFLE: true DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 2 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True llava: INPUT: MIN_SIZE_TEST: 800 MAX_SIZE_TEST: 1333 IMAGE_SIZE: 1024 MIN_SCALE: 1.0 MAX_SCALE: 1.0 DATASET_MAPPER_NAME: "llava" IGNORE_VALUE: 255 COLOR_AUG_SSD: False SIZE_DIVISIBILITY: 32 RANDOM_FLIP: "horizontal" MASK_FORMAT: "polygon" FORMAT: "RGB" CROP: ENABLED: True TEST: DETECTIONS_PER_IMAGE: 100 NAME: coco_eval IOU_TYPE: ['bbox', 'segm'] USE_MULTISCALE: false BATCH_SIZE_TOTAL: 1 MODEL_FILE: '' AUG: ENABLED: False TRAIN: BATCH_SIZE_TOTAL: 1 BATCH_SIZE_PER_GPU: 1 SHUFFLE: true DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 2 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True flickr: INPUT: MIN_SIZE_TEST: 800 MAX_SIZE_TEST: 1333 IMAGE_SIZE: 1024 MIN_SCALE: 1.0 MAX_SCALE: 1.0 DATASET_MAPPER_NAME: "flickr" IGNORE_VALUE: 255 COLOR_AUG_SSD: False SIZE_DIVISIBILITY: 32 RANDOM_FLIP: "horizontal" MASK_FORMAT: "polygon" FORMAT: "RGB" CROP: ENABLED: True TEST: DETECTIONS_PER_IMAGE: 100 NAME: coco_eval IOU_TYPE: ['bbox', 'segm'] USE_MULTISCALE: false BATCH_SIZE_TOTAL: 1 MODEL_FILE: '' AUG: ENABLED: False TRAIN: BATCH_SIZE_TOTAL: 1 BATCH_SIZE_PER_GPU: 1 SHUFFLE: true DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 2 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True part: INPUT: MIN_SIZE_TEST: 800 MAX_SIZE_TEST: 1333 IMAGE_SIZE: 1024 MIN_SCALE: 1.0 MAX_SCALE: 1.0 DATASET_MAPPER_NAME: "part" IGNORE_VALUE: 255 COLOR_AUG_SSD: False SIZE_DIVISIBILITY: 32 RANDOM_FLIP: "horizontal" MASK_FORMAT: "polygon" FORMAT: "RGB" CROP: ENABLED: True TEST: DETECTIONS_PER_IMAGE: 100 NAME: coco_eval IOU_TYPE: ['bbox', 'segm'] USE_MULTISCALE: false BATCH_SIZE_TOTAL: 1 MODEL_FILE: '' AUG: ENABLED: False TRAIN: BATCH_SIZE_TOTAL: 1 BATCH_SIZE_PER_GPU: 1 SHUFFLE: true DATALOADER: 
FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 2 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True vg: INPUT: MIN_SIZE_TEST: 800 MAX_SIZE_TEST: 1333 IMAGE_SIZE: 1024 MIN_SCALE: 1.0 MAX_SCALE: 1.0 DATASET_MAPPER_NAME: "vg" IGNORE_VALUE: 255 COLOR_AUG_SSD: False SIZE_DIVISIBILITY: 32 RANDOM_FLIP: "horizontal" MASK_FORMAT: "polygon" FORMAT: "RGB" CROP: ENABLED: True TEST: DETECTIONS_PER_IMAGE: 100 NAME: coco_eval IOU_TYPE: ['bbox', 'segm'] USE_MULTISCALE: false BATCH_SIZE_TOTAL: 1 MODEL_FILE: '' AUG: ENABLED: False TRAIN: BATCH_SIZE_TOTAL: 1 BATCH_SIZE_PER_GPU: 1 SHUFFLE: true DATALOADER: FILTER_EMPTY_ANNOTATIONS: False NUM_WORKERS: 2 LOAD_PROPOSALS: False SAMPLER_TRAIN: "TrainingSampler" ASPECT_RATIO_GROUPING: True ================================================ FILE: datasets_os/__init__.py ================================================ from . import registration from .build import * ================================================ FILE: datasets_os/build.py ================================================ # Copyright (c) Facebook, Inc. and its affiliates. 
class JointLoader(torchdata.IterableDataset):
    """Iterate several data loaders in lockstep.

    Every loader is stored as an attribute named after the part of its key
    that precedes the first underscore (``"coco_instruct"`` -> ``"coco"``).
    Each iteration step yields one dict mapping those short names to the
    batch produced by the corresponding loader.

    Args:
        loaders (dict): maps a loader key to an iterable loader.
        key_dataset (str): short name of the loader whose length is reported
            by ``len(self)``.
    """

    def __init__(self, loaders, key_dataset):
        short_names = []
        for loader_key, loader in loaders.items():
            # Only the prefix before the first "_" is kept as the public name.
            short_name = "{}".format(loader_key.split('_')[0])
            setattr(self, short_name, loader)
            short_names.append(short_name)
        self.dataset_names = short_names
        self.key_dataset = key_dataset

    def __iter__(self):
        streams = [getattr(self, short_name) for short_name in self.dataset_names]
        # zip() stops as soon as the shortest loader is exhausted.
        for joint_batch in zip(*streams):
            yield dict(zip(self.dataset_names, joint_batch))

    def __len__(self):
        key_loader = getattr(self, self.key_dataset)
        return len(key_loader)
def get_detection_dataset_dicts(
    dataset_names, filter_empty=True, proposal_files=None
):
    """
    Load the registered datasets and merge them into one list of dataset dicts.

    Args:
        dataset_names (str or list[str]): a dataset name or a list of dataset names
        filter_empty (bool): whether to drop images that have no non-crowd
            annotation (only applied when the first record has "annotations")
        proposal_files (list[str]): if given, one precomputed proposal file
            per entry of `dataset_names`.

    Returns:
        list[dict]: the concatenated dataset dicts of all requested datasets.
    """
    names = [dataset_names] if isinstance(dataset_names, str) else dataset_names
    assert len(names)

    # Load everything first, then make sure no dataset came back empty.
    per_dataset = [DatasetCatalog.get(name) for name in names]
    for name, records in zip(names, per_dataset):
        assert len(records), "Dataset '{}' is empty!".format(name)

    if proposal_files is not None:
        assert len(names) == len(proposal_files)
        # Attach the precomputed proposals to each dataset separately.
        per_dataset = [
            load_proposals_into_dataset(records, proposal_path)
            for records, proposal_path in zip(per_dataset, proposal_files)
        ]

    merged = list(itertools.chain.from_iterable(per_dataset))

    # The first record decides whether the merged data carries annotations.
    has_annotations = "annotations" in merged[0]
    if has_annotations and filter_empty:
        merged = filter_images_with_only_crowd_annotations(merged, names)

    assert len(merged), "No valid data found in {}.".format(",".join(names))
    return merged
""" if isinstance(dataset_name, str): dataset_name = [dataset_name] dataset = get_detection_dataset_dicts( dataset_name, filter_empty=False, proposal_files=None, ) # import ipdb;ipdb.set_trace() if mapper is None: if isinstance(cfg, (DictConfig)): cfg = OmegaConf.to_container(copy.deepcopy(cfg)) mapper_cfg = CfgNode({'INPUT': cfg['INPUT'], 'MODEL': cfg['MODEL'], 'DATASETS': cfg['DATASETS']}) mapper = DatasetMapper(mapper_cfg, False) assert cfg['TEST']['BATCH_SIZE_TOTAL'] % get_world_size() == 0, "Evaluation total batchsize is not divisible by gpu number" batch_size = cfg['TEST']['BATCH_SIZE_TOTAL'] // get_world_size() return { "dataset": dataset, "mapper": mapper, "num_workers": cfg['DATALOADER']['NUM_WORKERS'], "sampler": InferenceSampler(len(dataset)), "batch_size": batch_size, } @configurable(from_config=_test_loader_from_config) def build_detection_test_loader( dataset: Union[List[Any], torchdata.Dataset], *, mapper: Callable[[Dict[str, Any]], Any], sampler: Optional[torchdata.Sampler] = None, batch_size: int = 1, num_workers: int = 0, collate_fn: Optional[Callable[[List[Any]], Any]] = None, ) -> torchdata.DataLoader: """ Similar to `build_detection_train_loader`, with default batch size = 1, and sampler = :class:`InferenceSampler`. This sampler coordinates all workers to produce the exact set of all samples. Args: dataset: a list of dataset dicts, or a pytorch dataset (either map-style or iterable). They can be obtained by using :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`. mapper: a callable which takes a sample (dict) from dataset and returns the format to be consumed by the model. When using cfg, the default choice is ``DatasetMapper(cfg, is_train=False)``. sampler: a sampler that produces indices to be applied on ``dataset``. Default to :class:`InferenceSampler`, which splits the dataset across all workers. Sampler must be None if `dataset` is iterable. batch_size: the batch size of the data loader to be created. 
Default to 1 image per worker since this is the standard when reporting inference time in papers. num_workers: number of parallel data loading workers collate_fn: same as the argument of `torch.utils.data.DataLoader`. Defaults to do no collation and return a list of data. Returns: DataLoader: a torch DataLoader, that loads the given detection dataset, with test-time transformation and batching. Examples: :: data_loader = build_detection_test_loader( DatasetRegistry.get("my_test"), mapper=DatasetMapper(...)) # or, instantiate with a CfgNode: data_loader = build_detection_test_loader(cfg, "my_test") """ if isinstance(dataset, list): dataset = DatasetFromList(dataset, copy=False) if mapper is not None: dataset = MapDataset(dataset, mapper) if isinstance(dataset, torchdata.IterableDataset): assert sampler is None, "sampler must be None if dataset is IterableDataset" else: if sampler is None: sampler = InferenceSampler(len(dataset)) return torchdata.DataLoader( dataset, batch_size=batch_size, sampler=sampler, drop_last=False, num_workers=num_workers, collate_fn=trivial_batch_collator if collate_fn is None else collate_fn, ) def _train_loader_from_config(cfg, dataset_name, mapper, *, dataset=None, sampler=None): cfg_datasets = cfg['DATASETS'] cfg_dataloader = cfg['DATALOADER'] if dataset is None: dataset = get_detection_dataset_dicts( dataset_name, filter_empty=cfg_dataloader['FILTER_EMPTY_ANNOTATIONS'], proposal_files=cfg_datasets['PROPOSAL_FILES_TRAIN'] if cfg_dataloader['LOAD_PROPOSALS'] else None, ) if mapper is None: mapper = DatasetMapper(cfg, True) if sampler is None: sampler_name = cfg_dataloader['SAMPLER_TRAIN'] logger = logging.getLogger(__name__) logger.info("Using training sampler {}".format(sampler_name)) sampler = TrainingSampler(len(dataset)) return { "dataset": dataset, "sampler": sampler, "mapper": mapper, "total_batch_size": cfg['TRAIN']['BATCH_SIZE_TOTAL'], "aspect_ratio_grouping": cfg_dataloader['ASPECT_RATIO_GROUPING'], "num_workers": 
cfg_dataloader['NUM_WORKERS'], } @configurable(from_config=_train_loader_from_config) def build_detection_train_loader( dataset, *, mapper, sampler=None, total_batch_size, aspect_ratio_grouping=True, num_workers=0 ): """ Build a dataloader for object detection with some default features. This interface is experimental. Args: dataset (list or torch.utils.data.Dataset): a list of dataset dicts, or a map-style pytorch dataset. They can be obtained by using :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`. mapper (callable): a callable which takes a sample (dict) from dataset and returns the format to be consumed by the model. When using cfg, the default choice is ``DatasetMapper(cfg, is_train=True)``. sampler (torch.utils.data.sampler.Sampler or None): a sampler that produces indices to be applied on ``dataset``. Default to :class:`TrainingSampler`, which coordinates a random shuffle sequence across all workers. total_batch_size (int): total batch size across all workers. Batching simply puts data into a list. aspect_ratio_grouping (bool): whether to group images with similar aspect ratio for efficiency. When enabled, it requires each element in dataset be a dict with keys "width" and "height". num_workers (int): number of parallel data loading workers Returns: torch.utils.data.DataLoader: a dataloader. Each output from it is a ``list[mapped_element]`` of length ``total_batch_size / num_workers``, where ``mapped_element`` is produced by the ``mapper``. 
""" if isinstance(dataset, list): dataset = DatasetFromList(dataset, copy=False) if mapper is not None: dataset = MapDataset(dataset, mapper) if sampler is None: sampler = TrainingSampler(len(dataset)) assert isinstance(sampler, torch.utils.data.sampler.Sampler) return build_batch_data_loader( dataset, sampler, total_batch_size, aspect_ratio_grouping=aspect_ratio_grouping, num_workers=num_workers, ) def get_config_from_name(cfg, dataset_name): # adjust config according to dataset if 'sam' in dataset_name: cfg.update(cfg['SAM']) return cfg elif 'flickr' in dataset_name: cfg.update(cfg['flickr']) return cfg elif 'coco_instruct' in dataset_name: cfg.update(cfg['coco_instruct']) return cfg elif 'coco_interactive' in dataset_name: cfg.update(cfg['coco_interactive']) return cfg elif 'lisa' in dataset_name: cfg.update(cfg['LISA_REF']) return cfg elif 'llava' in dataset_name: cfg.update(cfg['llava']) return cfg elif 'vg' in dataset_name: cfg.update(cfg['vg']) return cfg elif 'part' in dataset_name and 'pascal_part' not in dataset_name and 'partimagenet' not in dataset_name: cfg.update(cfg['part']) return cfg elif 'pascal' in dataset_name or 'paco' in dataset_name or 'partimagenet' in dataset_name : cfg.update(cfg['PSACAL_PART']) return cfg elif 'coco' in dataset_name and 'refonly' in dataset_name: # if 'COCO' in cfg.keys(): cfg.update(cfg['COCO_REF']) return cfg elif 'coco' in dataset_name: if 'COCO' in cfg.keys(): cfg.update(cfg['COCO']) return cfg elif "mapillary" in dataset_name: if 'MAPILLARY' in cfg.keys(): cfg.update(cfg['MAPILLARY']) return cfg elif 'ade' in dataset_name: if 'ADE20K' in cfg.keys(): cfg.update(cfg['ADE20K']) return cfg elif 'imagenet' in dataset_name: if 'IMAGENET' in cfg.keys(): cfg.update(cfg['IMAGENET']) return cfg elif 'vlp' in dataset_name: cfg.update(cfg['VLP']) return cfg elif 'sun' in dataset_name: cfg.update(cfg['SUN']) return cfg elif 'object365' in dataset_name: cfg.update(cfg['OBJECT365']) return cfg elif 'scan' in dataset_name: 
def build_train_dataloader(cfg,tokenizer=None,data_args=None,preprocess=None,llava_cap_loader=None ):
    """
    Build the training loader(s) for every dataset in ``cfg['DATASETS']['TRAIN']``.

    Args:
        cfg: config mapping; it is deep-copied, so the caller's object is not
            modified by the per-dataset overrides applied here.
        tokenizer, data_args, preprocess: forwarded unchanged to the dataset
            mappers created below.
        llava_cap_loader: optional ready-made loader; when given it is added
            to the result under the key ``'llava_cap'``.

    Returns:
        The single loader when exactly one loader was built and
        ``cfg['LOADER']['JOINT']`` is falsy or absent; otherwise a
        ``JointLoader`` wrapping all loaders.
    """
    dataset_names = cfg['DATASETS']['TRAIN']
    loaders = {}
    # One copy for the whole loop: the overrides applied by
    # get_config_from_name accumulate from one dataset to the next.
    cfg = copy.deepcopy(cfg)
    for dataset_name in dataset_names:
        cfg = get_config_from_name(cfg, dataset_name)
        mapper_name = cfg['INPUT']['DATASET_MAPPER_NAME']
        # Loaders are stored under fixed keys per mapper type, so a second
        # dataset with the same mapper replaces the earlier loader.
        if mapper_name =="flickr":
            mapper=FlickrNewBaselineDatasetMapper(cfg,True,tokenizer=tokenizer,data_args=data_args,preprocess=preprocess)
            loaders['flickr'] = build_detection_train_loader(cfg, dataset_name=dataset_name, mapper=mapper)
        elif mapper_name =="coco_instruct":
            mapper=COCOInstructGroundingDatasetMapper(cfg,True,tokenizer=tokenizer,data_args=data_args,preprocess=preprocess)
            loaders['coco_instruct'] = build_detection_train_loader(cfg, dataset_name=dataset_name, mapper=mapper)
        elif mapper_name =="coco_interactive":
            # RefCOCO variants are detected from the dataset name and kept
            # under a separate loader key.
            if "refcoco" in dataset_name:
                refcoco=True
            else:
                refcoco=False
            mapper=COCOInterGroundingDatasetMapper(cfg,True,tokenizer=tokenizer,data_args=data_args,preprocess=preprocess,refcoco=refcoco)
            if refcoco:
                loaders['interactiveref'] = build_detection_train_loader(cfg, dataset_name=dataset_name, mapper=mapper)
            else:
                loaders['interactive'] = build_detection_train_loader(cfg, dataset_name=dataset_name, mapper=mapper)
        elif mapper_name =="vg":
            mapper=VGNewBaselineDatasetMapper(cfg,True,tokenizer=tokenizer,data_args=data_args,preprocess=preprocess)
            loaders['vg'] = build_detection_train_loader(cfg, dataset_name=dataset_name, mapper=mapper)
        elif mapper_name == "coco_ref_panoptic_lsj":
            # NOTE(review): the train flag is read from the 'Train' key
            # (default True) - confirm this key spelling is intended.
            mapper = COCOPanopticInteractiveDatasetMapper(cfg, cfg.get('Train',True),tokenizer=tokenizer,data_args=data_args,preprocess=preprocess)
            loaders['refcoco'] = build_detection_train_loader(cfg, dataset_name=dataset_name, mapper=mapper)
        else:
            # Unknown mapper name: pass mapper=None and key the loader by the
            # full dataset name.
            mapper = None
            loaders[dataset_name] = build_detection_train_loader(cfg, dataset_name=dataset_name, mapper=mapper)
    if llava_cap_loader is not None:
        loaders['llava_cap'] = llava_cap_loader
    if len(loaders) == 1 and not cfg['LOADER'].get('JOINT', False):
        for k, v in loaders.items():
            print("number of iterations per epoch: ", v, len(loaders[k]))
        return list(loaders.values())[0]
    else:
        # 'coco' is the default key dataset used for the joint loader's length.
        return JointLoader(loaders, key_dataset=cfg['LOADER'].get('KEY_DATASET', 'coco'))
def _custom_train_loader_from_config(cfg, mapper=None, *, dataset=None, sampler=None):
    """
    Turn ``cfg`` into keyword arguments for ``build_custom_train_loader``.

    Args:
        cfg: config node read through attribute access (``cfg.DATALOADER``,
            ``cfg.DATASETS``, ``cfg.MODEL``, ``cfg.SOLVER``).
        mapper: sample mapper; ``DatasetMapper(cfg, True)`` is used when None.
        dataset: accepted for signature compatibility only. The dataset dicts
            are always loaded from ``cfg.DATASETS.TRAIN``.
        sampler: training sampler; when None it is built according to
            ``cfg.DATALOADER.SAMPLER_TRAIN``.

    Returns:
        dict: keyword arguments for ``build_custom_train_loader``.

    Raises:
        ValueError: if no sampler is given and the configured sampler name is
            not one of "TrainingSampler", "MultiDatasetSampler" or
            "RepeatFactorTrainingSampler".
    """
    sampler_name = cfg.DATALOADER.SAMPLER_TRAIN

    # Multi-dataset samplers need each record tagged with its source dataset;
    # both loaders take exactly the same arguments.
    if 'MultiDataset' in sampler_name:
        load_dataset_dicts = get_detection_dataset_dicts_with_source
    else:
        load_dataset_dicts = get_detection_dataset_dicts
    dataset_dicts = load_dataset_dicts(
        cfg.DATASETS.TRAIN,
        filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,
        min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
        if cfg.MODEL.KEYPOINT_ON else 0,
        proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN
        if cfg.MODEL.LOAD_PROPOSALS else None,
    )

    if mapper is None:
        mapper = DatasetMapper(cfg, True)

    if sampler is not None:
        pass
    elif sampler_name == "TrainingSampler":
        # Size the sampler from the dicts that are actually returned below.
        # The previous code used len(dataset), which raises TypeError because
        # `dataset` defaults to None.
        sampler = TrainingSampler(len(dataset_dicts))
    elif sampler_name == "MultiDatasetSampler":
        sampler = MultiDatasetSampler(
            dataset_dicts,
            dataset_ratio=cfg.DATALOADER.DATASET_RATIO,
            use_rfs=cfg.DATALOADER.USE_RFS,
            dataset_ann=cfg.DATALOADER.DATASET_ANN,
            repeat_threshold=cfg.DATALOADER.REPEAT_THRESHOLD,
        )
    elif sampler_name == "RepeatFactorTrainingSampler":
        repeat_factors = RepeatFactorTrainingSampler.repeat_factors_from_category_frequency(
            dataset_dicts, cfg.DATALOADER.REPEAT_THRESHOLD
        )
        sampler = RepeatFactorTrainingSampler(repeat_factors)
    else:
        raise ValueError("Unknown training sampler: {}".format(sampler_name))

    return {
        "dataset": dataset_dicts,
        "sampler": sampler,
        "mapper": mapper,
        "total_batch_size": cfg.SOLVER.IMS_PER_BATCH,
        "aspect_ratio_grouping": cfg.DATALOADER.ASPECT_RATIO_GROUPING,
        "num_workers": cfg.DATALOADER.NUM_WORKERS,
        'multi_dataset_grouping': cfg.DATALOADER.MULTI_DATASET_GROUPING,
        'use_diff_bs_size': cfg.DATALOADER.USE_DIFF_BS_SIZE,
        'dataset_bs': cfg.DATALOADER.DATASET_BS,
        'num_datasets': len(cfg.DATASETS.TRAIN)
    }
def get_detection_dataset_dicts_with_source(
    dataset_names, filter_empty=True, min_keypoints=0, proposal_files=None
):
    """
    Load several datasets and tag every record with the index of its dataset.

    Each record gets a ``'dataset_source'`` entry holding the position of its
    dataset in `dataset_names`. The records are modified in place.

    Args:
        dataset_names (list[str]): names of the registered datasets to load.
        filter_empty (bool): drop images with only crowd annotations.
        min_keypoints (int): when positive, drop images with too few keypoints.
        proposal_files: must be None; proposals are not supported here.

    Returns:
        list[dict]: the concatenated, tagged dataset dicts.
    """
    assert len(dataset_names)
    per_dataset = [DatasetCatalog.get(name) for name in dataset_names]
    # Every dataset is checked before any record is touched.
    for name, records in zip(dataset_names, per_dataset):
        assert len(records), "Dataset '{}' is empty!".format(name)

    for source_id, (name, records) in enumerate(zip(dataset_names, per_dataset)):
        for record in records:
            record['dataset_source'] = source_id

        if "annotations" not in records[0]:
            continue
        try:
            class_names = MetadataCatalog.get(name).thing_classes
            check_metadata_consistency("thing_classes", name)
            print_instances_class_histogram(records, class_names)
        except AttributeError:
            # class names are not available for this dataset
            pass

    assert proposal_files is None

    merged = list(itertools.chain.from_iterable(per_dataset))

    has_instances = "annotations" in merged[0]
    if has_instances and filter_empty:
        merged = filter_images_with_only_crowd_annotations(merged)
    if has_instances and min_keypoints > 0:
        merged = filter_images_with_few_keypoints(merged, min_keypoints)

    return merged
def repeat_factors_from_tag_frequency(dataset_dicts, repeat_thresh):
    """
    Compute one repeat factor per image from its ``'pos_category_ids'``.

    A category's frequency is its number of occurrences across all images
    divided by the number of images. Its repeat factor is
    ``max(1.0, sqrt(repeat_thresh / frequency))``. An image gets the largest
    factor among its categories, or 1.0 when it lists none.

    Args:
        dataset_dicts (list[dict]): records with a ``'pos_category_ids'`` entry.
        repeat_thresh (float): frequency threshold below which categories
            are repeated.

    Returns:
        torch.Tensor: float32 tensor with one repeat factor per record.
    """
    num_images = len(dataset_dicts)

    # Count every occurrence of every category id.
    occurrences = defaultdict(int)
    for record in dataset_dicts:
        for tag in record['pos_category_ids']:
            occurrences[tag] += 1

    # Category-level repeat factor derived from the per-image frequency.
    tag_repeat = {}
    for tag, count in occurrences.items():
        frequency = count / num_images
        tag_repeat[tag] = max(1.0, math.sqrt(repeat_thresh / frequency))

    # Image-level repeat factor: the most demanding category wins.
    per_image = [
        max({tag_repeat[tag] for tag in record['pos_category_ids']}, default=1.0)
        for record in dataset_dicts
    ]
    return torch.tensor(per_image, dtype=torch.float32)
# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/d2/detr/dataset_mapper.py import copy import logging import numpy as np import torch from detectron2.data import detection_utils as utils from detectron2.data import transforms as T from detectron2.data.transforms import TransformGen from detectron2.structures import BitMasks, Instances from pycocotools import mask as coco_mask from llava.model.openseed.utils import configurable __all__ = ["COCOInstanceNewBaselineDatasetMapper"] def convert_coco_poly_to_mask(segmentations, height, width): masks = [] for polygons in segmentations: rles = coco_mask.frPyObjects(polygons, height, width) mask = coco_mask.decode(rles) if len(mask.shape) < 3: mask = mask[..., None] mask = torch.as_tensor(mask, dtype=torch.uint8) mask = mask.any(dim=2) masks.append(mask) if masks: masks = torch.stack(masks, dim=0) else: masks = torch.zeros((0, height, width), dtype=torch.uint8) return masks def build_transform_gen(cfg, is_train): """ Create a list of default :class:`Augmentation` from config. Now it includes resizing and flipping. Returns: list[Augmentation] """ assert is_train, "Only support training augmentation" cfg_input = cfg['INPUT'] image_size = cfg_input['IMAGE_SIZE'] min_scale = cfg_input['MIN_SCALE'] max_scale = cfg_input['MAX_SCALE'] augmentation = [] if cfg_input['RANDOM_FLIP'] != "none": augmentation.append( T.RandomFlip( horizontal=cfg_input['RANDOM_FLIP'] == "horizontal", vertical=cfg_input['RANDOM_FLIP'] == "vertical", ) ) augmentation.extend([ T.ResizeScale( min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size ), T.FixedSizeCrop(crop_size=(image_size, image_size)), ]) return augmentation # This is specifically designed for the COCO dataset. class COCOInstanceNewBaselineDatasetMapper: """ A callable which takes a dataset dict in Detectron2 Dataset format, and map it into a format used by MaskFormer. 
This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation. The callable currently does the following: 1. Read the image from "file_name" 2. Applies geometric transforms to the image and annotation 3. Find and applies suitable cropping to the image and annotation 4. Prepare image and annotation to Tensors """ @configurable def __init__( self, is_train=True, *, tfm_gens, image_format, ): """ NOTE: this interface is experimental. Args: is_train: for training or inference augmentations: a list of augmentations or deterministic transforms to apply tfm_gens: data augmentation image_format: an image format supported by :func:`detection_utils.read_image`. """ self.tfm_gens = tfm_gens logging.getLogger(__name__).info( "[COCOInstanceNewBaselineDatasetMapper] Full TransformGens used in training: {}".format(str(self.tfm_gens)) ) self.img_format = image_format self.is_train = is_train @classmethod def from_config(cls, cfg, is_train=True): # Build augmentation tfm_gens = build_transform_gen(cfg, is_train) ret = { "is_train": is_train, "tfm_gens": tfm_gens, "image_format": cfg['INPUT']['FORMAT'], } return ret def __call__(self, dataset_dict): """ Args: dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format. 
Returns: dict: a format that builtin models in detectron2 accept """ dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below image = utils.read_image(dataset_dict["file_name"], format=self.img_format) utils.check_image_size(dataset_dict, image) # TODO: get padding mask # by feeding a "segmentation mask" to the same transforms padding_mask = np.ones(image.shape[:2]) image, transforms = T.apply_transform_gens(self.tfm_gens, image) # the crop transformation has default padding value 0 for segmentation padding_mask = transforms.apply_segmentation(padding_mask) padding_mask = ~ padding_mask.astype(bool) image_shape = image.shape[:2] # h, w # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory, # but not efficient on large generic data structures due to the use of pickle & mp.Queue. # Therefore it's important to use torch.Tensor. dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1))) dataset_dict["padding_mask"] = torch.as_tensor(np.ascontiguousarray(padding_mask)) if not self.is_train: # USER: Modify this if you want to keep them for some reason. dataset_dict.pop("annotations", None) return dataset_dict if "annotations" in dataset_dict: # USER: Modify this if you want to keep them for some reason. for anno in dataset_dict["annotations"]: # Let's always keep mask # if not self.mask_on: # anno.pop("segmentation", None) anno.pop("keypoints", None) # USER: Implement additional transformations if you have other types of data annos = [ utils.transform_instance_annotations(obj, transforms, image_shape) for obj in dataset_dict.pop("annotations") if obj.get("iscrowd", 0) == 0 ] # NOTE: does not support BitMask due to augmentation # Current BitMask cannot handle empty objects instances = utils.annotations_to_instances(annos, image_shape) # After transforms such as cropping are applied, the bounding box may no longer # tightly bound the object. 
As an example, imagine a triangle object # [(0,0), (2,0), (0,2)] cropped by a box [(1,0),(2,2)] (XYXY format). The tight # bounding box of the cropped triangle should be [(1,0),(2,1)], which is not equal to # the intersection of original bounding box and the cropping box. instances.gt_boxes = instances.gt_masks.get_bounding_boxes() # Need to filter empty instances first (due to augmentation) instances = utils.filter_empty_instances(instances) # Generate masks from polygon h, w = instances.image_size # image_size_xyxy = torch.as_tensor([w, h, w, h], dtype=torch.float) if hasattr(instances, 'gt_masks'): gt_masks = instances.gt_masks gt_masks = convert_coco_poly_to_mask(gt_masks.polygons, h, w) instances.gt_masks = gt_masks dataset_dict["instances"] = instances return dataset_dict ================================================ FILE: datasets_os/dataset_mappers/coco_instruct_grounding_dataset_interactive_mapper.py ================================================ # Copyright (c) Facebook, Inc. and its affiliates. 
# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/d2/detr/dataset_mapper.py import copy import logging import random import numpy as np import torch import PIL.Image as Image from detectron2.data import detection_utils as utils from detectron2.data import transforms as T from detectron2.data.transforms import TransformGen from detectron2.structures import BitMasks, Instances from pycocotools import mask as coco_mask from llava.model.openseed.utils import configurable from detectron2.structures import Boxes, BoxMode, PolygonMasks, RotatedBoxes from llava import conversation as conversation_lib from llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN # from llava.train.train_hao_seg_flickr import ,preprocess __all__ = ["COCOInstanceNewBaselineDatasetMapper"] def convert_coco_poly_to_mask(segmentations, height, width): masks = [] for polygons in segmentations: rles = coco_mask.frPyObjects(polygons, height, width) mask = coco_mask.decode(rles) if len(mask.shape) < 3: mask = mask[..., None] mask = torch.as_tensor(mask, dtype=torch.uint8) mask = mask.any(dim=2) masks.append(mask) if masks: masks = torch.stack(masks, dim=0) else: masks = torch.zeros((0, height, width), dtype=torch.uint8) return masks def preprocess_multimodal( sources, data_args ): is_multimodal = data_args.is_multimodal if not is_multimodal: return sources for source in sources: for sentence in source: if DEFAULT_IMAGE_TOKEN in sentence['value']: sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip() sentence['value'] = DEFAULT_IMAGE_TOKEN + '\n' + sentence['value'] sentence['value'] = sentence['value'].strip() if "mmtag" in conversation_lib.default_conversation.version: sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '' + DEFAULT_IMAGE_TOKEN + '') replace_token = DEFAULT_IMAGE_TOKEN if data_args.mm_use_im_start_end: replace_token = DEFAULT_IM_START_TOKEN + 
replace_token + DEFAULT_IM_END_TOKEN sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token) return sources def build_transform_gen(cfg, is_train): """ Create a list of default :class:`Augmentation` from config. Now it includes resizing and flipping. Returns: list[Augmentation] """ if is_train: cfg_input = cfg['INPUT'] image_size = cfg_input['IMAGE_SIZE'] min_scale = cfg_input['MIN_SCALE'] max_scale = cfg_input['MAX_SCALE'] augmentation = [] # if cfg_input['RANDOM_FLIP'] != "none": # augmentation.append( # T.RandomFlip( # horizontal=cfg_input['RANDOM_FLIP'] == "horizontal", # vertical=cfg_input['RANDOM_FLIP'] == "vertical", # ) # ) augmentation.extend([ T.ResizeScale( min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size ), T.FixedSizeCrop(crop_size=(image_size, image_size)), ]) else: cfg_input = cfg['INPUT'] image_size = cfg_input['IMAGE_SIZE'] min_scale = cfg_input['MIN_SCALE'] max_scale = cfg_input['MAX_SCALE'] augmentation = [] # if cfg_input['RANDOM_FLIP'] != "none": # augmentation.append( # T.RandomFlip( # horizontal=cfg_input['RANDOM_FLIP'] == "horizontal", # vertical=cfg_input['RANDOM_FLIP'] == "vertical", # ) # ) augmentation.extend([ T.ResizeScale( min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size ), T.FixedSizeCrop(crop_size=(image_size, image_size)), ]) return augmentation # This is specifically designed for the COCO dataset. class COCOInstanceNewBaselineDatasetMapper: """ A callable which takes a dataset dict in Detectron2 Dataset format, and map it into a format used by MaskFormer. This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation. The callable currently does the following: 1. Read the image from "file_name" 2. Applies geometric transforms to the image and annotation 3. Find and applies suitable cropping to the image and annotation 4. 
Prepare image and annotation to Tensors """ @configurable def __init__( self, is_train=True, *, tfm_gens, image_format, tokenizer, data_args, preprocess, refcoco=None, max_sampled=5, ): """ NOTE: this interface is experimental. Args: is_train: for training or inference augmentations: a list of augmentations or deterministic transforms to apply tfm_gens: data augmentation image_format: an image format supported by :func:`detection_utils.read_image`. """ self.tfm_gens = tfm_gens logging.getLogger(__name__).info( "[COCOInstanceNewBaselineDatasetMapper] Full TransformGens used in training: {}".format(str(self.tfm_gens)) ) self.img_format = image_format self.is_train = is_train self.tokenizer = tokenizer self.processor = data_args.image_processor self.data_args = data_args self.preprocess = preprocess self.refcoco=refcoco self.max_sampled=max_sampled @classmethod def from_config(cls, cfg, is_train=True,tokenizer=None,data_args=None,preprocess=None,refcoco=None): # Build augmentation tfm_gens = build_transform_gen(cfg, is_train) ret = { "is_train": is_train, "tfm_gens": tfm_gens, "image_format": cfg['INPUT']['FORMAT'], "tokenizer": tokenizer, "data_args": data_args, "preprocess": preprocess, "refcoco":refcoco, } return ret def __call__(self, dataset_dict): """ Args: dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format. 
Returns: dict: a format that builtin models in detectron2 accept """ dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below image = utils.read_image(dataset_dict["file_name"], format=self.img_format) utils.check_image_size(dataset_dict, image) #########llava image processing if self.data_args.image_aspect_ratio == 'pad': def expand2square(pil_img, background_color): width, height = pil_img.size if width == height: return pil_img elif width > height: result = Image.new(pil_img.mode, (width, width), background_color) result.paste(pil_img, (0, (width - height) // 2)) return result else: result = Image.new(pil_img.mode, (height, height), background_color) result.paste(pil_img, ((height - width) // 2, 0)) return result image_clip = expand2square(image, tuple(int(x * 255) for x in self.processor.image_mean)) image_clip = self.processor.preprocess(image_clip, return_tensors='pt')['pixel_values'][0] else: image_clip = self.processor.preprocess(image, return_tensors='pt')['pixel_values'][0] dataset_dict["image_clip"] = image_clip ################## # TODO: get padding mask # by feeding a "segmentation mask" to the same transforms padding_mask = np.ones(image.shape[:2]) image, transforms = T.apply_transform_gens(self.tfm_gens, image) dataset_dict["image_ori"]=image # the crop transformation has default padding value 0 for segmentation padding_mask = transforms.apply_segmentation(padding_mask) padding_mask = ~ padding_mask.astype(bool) image_shape = image.shape[:2] # h, w dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1))) dataset_dict["padding_mask"] = torch.as_tensor(np.ascontiguousarray(padding_mask)) num_conversations = len(dataset_dict['conversations']) if self.refcoco: max_sampled=min(self.max_sampled,num_conversations) sample_num=random.randint(1,max_sampled) sampled_convs=random.sample(dataset_dict['conversations'], k=sample_num) grounding_list=[] selected_conversation=[] 
sampled_convs[0][0][0]['value']='\n'+sampled_convs[0][0][0]['value'] for conv,gd in sampled_convs: grounding_list.extend(gd) conv[1]['value']=random.choice(conv[1]['value']) selected_conversation.extend(conv) else: rd = np.random.choice(num_conversations) selected_conversation, grounding_list = dataset_dict['conversations'][rd] dataset_dict['conversation'] = [selected_conversation] sources = preprocess_multimodal( copy.deepcopy(dataset_dict['conversation']), self.data_args) data_dict_conversation = self.preprocess( sources, self.tokenizer, has_image=True) data_dict_conversation = dict(input_ids=data_dict_conversation["input_ids"][0], labels=data_dict_conversation["labels"][0]) dataset_dict.update(data_dict_conversation) dataset_dict['tokenizer'] = self.tokenizer # num_segs = sum([conv['value'].count('') for conv in selected_conversation]) # grounding_list= assert "grounding_info" in dataset_dict and len(dataset_dict['grounding_info'])>0 anno_id2id=dict() for id,obj in enumerate(dataset_dict['grounding_info']): obj["bbox_mode"] = BoxMode.XYWH_ABS anno_id2id[obj['id']]=id # id2class=[[] for _ in range(len(dataset_dict['grounding_info']))] annos = [ utils.transform_instance_annotations(obj, transforms, image_shape) for obj in dataset_dict["grounding_info"] ] # assert "segmentation" in annos[0] instances = utils.annotations_to_instances(annos, image_shape,mask_format="bitmask") h, w = instances.image_size # image_size_xyxy = torch.as_tensor([w, h, w, h], dtype=torch.float) if hasattr(instances, 'gt_masks'): gt_masks = instances.gt_masks # gt_masks = convert_coco_poly_to_mask(gt_masks.polygons, h, w) instances.gt_masks = gt_masks.tensor num_objs=(data_dict_conversation['input_ids']==1273).sum() grounding_list=[gd for gd in grounding_list if gd is not None] merged_grounding_list=sum(grounding_list,[]) # assert num_objs==len(merged_grounding_list) if num_objslen(merged_grounding_list): 
merged_grounding_list=merged_grounding_list+[merged_grounding_list[-1]]*(num_objs-len(merged_grounding_list)) merged_grounding_list=[anno_id2id[annid] for annid in merged_grounding_list] dataset_dict['grounding_index']=merged_grounding_list dataset_dict["instances"] = instances # if grounding_list is None: # dataset_dict['grounding']=False # grounding_mask=[False for _ in range(num_segs)] # dataset_dict['grounding_mask']=grounding_mask # else: # grounding_mask=[True if g is not None else False for g in grounding_list] # dataset_dict['grounding_mask']=grounding_mask # new_grounding_list=[g for g in grounding_list if g is not None] # if sum(grounding_mask)==0: # dataset_dict['grounding']=False # else: # dataset_dict['grounding']=True # if dataset_dict['grounding']: # # assert num_segs == len(grounding_list) # for grounding_id,grounding in enumerate(new_grounding_list): # if grounding is not None: # for annid in grounding: # id2class[anno_id2id[annid]].append(grounding_id) # # instances.gt_classes=id2class # dataset_dict["instances"] = instances # else: # dataset_dict['grounding'] = False # grounding_mask = [False for _ in range(num_segs)] # dataset_dict['grounding_mask'] = grounding_mask return dataset_dict ================================================ FILE: datasets_os/dataset_mappers/coco_instruct_grounding_dataset_mapper.py ================================================ # Copyright (c) Facebook, Inc. and its affiliates. 
# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/d2/detr/dataset_mapper.py import copy import logging import numpy as np import torch import PIL.Image as Image from detectron2.data import detection_utils as utils from detectron2.data import transforms as T from detectron2.data.transforms import TransformGen from detectron2.structures import BitMasks, Instances from pycocotools import mask as coco_mask from llava.model.openseed.utils import configurable from detectron2.structures import Boxes, BoxMode, PolygonMasks, RotatedBoxes from llava import conversation as conversation_lib from llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN # from llava.train.train_hao_seg_flickr import ,preprocess __all__ = ["COCOInstanceNewBaselineDatasetMapper"] suffix=[ "Please also provide the boxes and masks for the noun phrases in the response." , "Kindly ensure that the response includes the relevant boxes and masks for each noun phrase." , "Additionally, include the boxes and masks that match each noun phrase in the response." , "Please provide the boxes and masks that correspond to every noun phrase in your response." , "It’s important to have the boxes and masks that align with each noun phrase in the response." , "Make sure to include the appropriate boxes and masks for each noun phrase in your response." , "In your response, include the boxes and masks that pertain to each noun phrase." , "Also, supply the boxes and masks that are linked to each noun phrase in the response." , "Additionally, please furnish the boxes and masks that correspond to each noun phrase in the response." , "Don’t forget to provide the boxes and masks associated with each noun phrase in your response." 
, "Ensure that each noun phrase in the response has its respective boxes and masks.", ] def convert_coco_poly_to_mask(segmentations, height, width): masks = [] for polygons in segmentations: rles = coco_mask.frPyObjects(polygons, height, width) mask = coco_mask.decode(rles) if len(mask.shape) < 3: mask = mask[..., None] mask = torch.as_tensor(mask, dtype=torch.uint8) mask = mask.any(dim=2) masks.append(mask) if masks: masks = torch.stack(masks, dim=0) else: masks = torch.zeros((0, height, width), dtype=torch.uint8) return masks def preprocess_multimodal( sources, data_args ): is_multimodal = data_args.is_multimodal if not is_multimodal: return sources for source in sources: for sentence in source: if DEFAULT_IMAGE_TOKEN in sentence['value']: sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip() sentence['value'] = DEFAULT_IMAGE_TOKEN + '\n' + sentence['value'] sentence['value'] = sentence['value'].strip() if "mmtag" in conversation_lib.default_conversation.version: sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '' + DEFAULT_IMAGE_TOKEN + '') replace_token = DEFAULT_IMAGE_TOKEN if data_args.mm_use_im_start_end: replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token) return sources def build_transform_gen(cfg, is_train): """ Create a list of default :class:`Augmentation` from config. Now it includes resizing and flipping. 
Returns: list[Augmentation] """ if is_train: cfg_input = cfg['INPUT'] image_size = cfg_input['IMAGE_SIZE'] min_scale = cfg_input['MIN_SCALE'] max_scale = cfg_input['MAX_SCALE'] augmentation = [] # if cfg_input['RANDOM_FLIP'] != "none": # augmentation.append( # T.RandomFlip( # horizontal=cfg_input['RANDOM_FLIP'] == "horizontal", # vertical=cfg_input['RANDOM_FLIP'] == "vertical", # ) # ) augmentation.extend([ T.ResizeScale( min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size ), T.FixedSizeCrop(crop_size=(image_size, image_size)), ]) else: cfg_input = cfg['INPUT'] image_size = cfg_input['IMAGE_SIZE'] min_scale = cfg_input['MIN_SCALE'] max_scale = cfg_input['MAX_SCALE'] augmentation = [] # if cfg_input['RANDOM_FLIP'] != "none": # augmentation.append( # T.RandomFlip( # horizontal=cfg_input['RANDOM_FLIP'] == "horizontal", # vertical=cfg_input['RANDOM_FLIP'] == "vertical", # ) # ) augmentation.extend([ T.ResizeScale( min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size ), T.FixedSizeCrop(crop_size=(image_size, image_size)), ]) return augmentation # This is specifically designed for the COCO dataset. class COCOInstanceNewBaselineDatasetMapper: """ A callable which takes a dataset dict in Detectron2 Dataset format, and map it into a format used by MaskFormer. This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation. The callable currently does the following: 1. Read the image from "file_name" 2. Applies geometric transforms to the image and annotation 3. Find and applies suitable cropping to the image and annotation 4. Prepare image and annotation to Tensors """ @configurable def __init__( self, is_train=True, *, tfm_gens, image_format, tokenizer, data_args, preprocess, replace_suffix=False, ): """ NOTE: this interface is experimental. 
Args: is_train: for training or inference augmentations: a list of augmentations or deterministic transforms to apply tfm_gens: data augmentation image_format: an image format supported by :func:`detection_utils.read_image`. """ self.tfm_gens = tfm_gens logging.getLogger(__name__).info( "[COCOInstanceNewBaselineDatasetMapper] Full TransformGens used in training: {}".format(str(self.tfm_gens)) ) self.img_format = image_format self.is_train = is_train self.tokenizer = tokenizer self.processor = data_args.image_processor self.data_args = data_args self.preprocess = preprocess self.replace_suffix=replace_suffix @classmethod def from_config(cls, cfg, is_train=True,tokenizer=None,data_args=None,preprocess=None): # Build augmentation tfm_gens = build_transform_gen(cfg, is_train) ret = { "is_train": is_train, "tfm_gens": tfm_gens, "image_format": cfg['INPUT']['FORMAT'], "tokenizer": tokenizer, "data_args": data_args, "preprocess": preprocess, "replace_suffix": cfg['MODEL'].get('REPLACE_SUFFIX', False), } return ret def __call__(self, dataset_dict): """ Args: dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format. 
Returns: dict: a format that builtin models in detectron2 accept """ dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below image = utils.read_image(dataset_dict["file_name"], format=self.img_format) utils.check_image_size(dataset_dict, image) #########llava image processing if self.data_args.image_aspect_ratio == 'pad': def expand2square(pil_img, background_color): width, height = pil_img.size if width == height: return pil_img elif width > height: result = Image.new(pil_img.mode, (width, width), background_color) result.paste(pil_img, (0, (width - height) // 2)) return result else: result = Image.new(pil_img.mode, (height, height), background_color) result.paste(pil_img, ((height - width) // 2, 0)) return result image_clip = expand2square(image, tuple(int(x * 255) for x in self.processor.image_mean)) image_clip = self.processor.preprocess(image_clip, return_tensors='pt')['pixel_values'][0] else: image_clip = self.processor.preprocess(image, return_tensors='pt')['pixel_values'][0] dataset_dict["image_clip"] = image_clip ################## # TODO: get padding mask # by feeding a "segmentation mask" to the same transforms padding_mask = np.ones(image.shape[:2]) image, transforms = T.apply_transform_gens(self.tfm_gens, image) dataset_dict["image_ori"]=image # the crop transformation has default padding value 0 for segmentation padding_mask = transforms.apply_segmentation(padding_mask) padding_mask = ~ padding_mask.astype(bool) image_shape = image.shape[:2] # h, w dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1))) dataset_dict["padding_mask"] = torch.as_tensor(np.ascontiguousarray(padding_mask)) num_conversations = len(dataset_dict['conversations']) rd = np.random.choice(num_conversations) selected_conversation, grounding_list = dataset_dict['conversations'][rd] dataset_dict['conversation'] = [selected_conversation] sources = preprocess_multimodal( copy.deepcopy(dataset_dict['conversation']), self.data_args) 
if self.replace_suffix: for conv in sources[0]: sf=np.random.choice(suffix) if conv['from'] == 'human': conv['value'] = conv['value'].replace('(with grounding)', sf, 1) data_dict_conversation = self.preprocess( sources, self.tokenizer, has_image=True) data_dict_conversation = dict(input_ids=data_dict_conversation["input_ids"][0], labels=data_dict_conversation["labels"][0]) dataset_dict.update(data_dict_conversation) dataset_dict['tokenizer'] = self.tokenizer num_segs = sum([conv['value'].count('') for conv in selected_conversation]) # grounding_list= if "grounding_info" in dataset_dict and len(dataset_dict['grounding_info'])>0: anno_id2id=dict() for id,obj in enumerate(dataset_dict['grounding_info']): obj["bbox_mode"] = BoxMode.XYWH_ABS anno_id2id[obj['id']]=id id2class=[[] for _ in range(len(dataset_dict['grounding_info']))] annos = [ utils.transform_instance_annotations(obj, transforms, image_shape) for obj in dataset_dict["grounding_info"] ] # assert "segmentation" in annos[0] instances = utils.annotations_to_instances(annos, image_shape,mask_format="bitmask") h, w = instances.image_size # image_size_xyxy = torch.as_tensor([w, h, w, h], dtype=torch.float) if hasattr(instances, 'gt_masks'): gt_masks = instances.gt_masks # gt_masks = convert_coco_poly_to_mask(gt_masks.polygons, h, w) instances.gt_masks = gt_masks.tensor if grounding_list is None: dataset_dict['grounding']=False grounding_mask=[False for _ in range(num_segs)] dataset_dict['grounding_mask']=grounding_mask else: grounding_mask=[True if g is not None else False for g in grounding_list] dataset_dict['grounding_mask']=grounding_mask new_grounding_list=[g for g in grounding_list if g is not None] if sum(grounding_mask)==0: dataset_dict['grounding']=False else: dataset_dict['grounding']=True if dataset_dict['grounding']: # assert num_segs == len(grounding_list) for grounding_id,grounding in enumerate(new_grounding_list): if grounding is not None: for annid in grounding: 
id2class[anno_id2id[annid]].append(grounding_id) instances.gt_classes=id2class dataset_dict["instances"] = instances else: dataset_dict['grounding'] = False grounding_mask = [False for _ in range(num_segs)] dataset_dict['grounding_mask'] = grounding_mask return dataset_dict ================================================ FILE: datasets_os/dataset_mappers/coco_interactive_panoptic_new_baseline_dataset_mapper.py ================================================ # ------------------------------------------------------------------------ # Copyright (c) 2022 IDEA. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------ # Modified from Mask2Former https://github.com/facebookresearch/Mask2Former by Feng Li. import copy import logging import numpy as np import torch from detectron2.config import configurable from detectron2.data import detection_utils as utils from detectron2.data import transforms as T from detectron2.data.transforms import TransformGen from detectron2.structures import BitMasks, Boxes, Instances __all__ = ["COCOInteractivePanopticNewBaselineDatasetMapper"] def filter_empty_instances_by_box( instances, by_box=True, by_mask=False, box_threshold=1e-5, return_mask=False ): assert by_box or by_mask r = [] if by_box: r.append(instances.gt_boxes.nonempty(threshold=box_threshold)) if instances.has("gt_masks") and by_mask: r.append(instances.gt_masks.nonempty()) # TODO: can also filter visible keypoints if not r: return instances m = r[0] for x in r[1:]: m = m & x if return_mask: return instances[m], m return instances[m] def build_transform_gen(cfg, is_train): """ Create a list of default :class:`Augmentation` from config. Now it includes resizing and flipping. 
# This is specifically designed for the COCO dataset.
class COCOInteractivePanopticNewBaselineDatasetMapper:
    """
    Callable that maps one Detectron2-format dataset dict to model inputs.

    For each record it:

    1. reads the image named by "file_name" and applies ``tfm_gens`` to it;
    2. stores the transformed image as an array ("image_ori") and as a
       CHW tensor ("image");
    3. when "pan_seg_file_name" is present, applies the same transforms to the
       panoptic map and builds one class id, binary mask and box per
       non-crowd segment;
    4. filters the result through ``filter_empty_instances_by_box`` and stores
       it under "instances".
    """

    @configurable
    def __init__(
        self,
        is_train=True,
        *,
        tfm_gens,
        image_format,
    ):
        """
        NOTE: this interface is experimental.

        Args:
            is_train: for training or inference
            tfm_gens: data augmentation
            image_format: an image format supported by :func:`detection_utils.read_image`.
        """
        self.tfm_gens = tfm_gens
        self.img_format = image_format
        self.is_train = is_train

        # Tag the message with this class's own name. The previous hard-coded
        # tag named a different mapper, which made the logs misleading.
        logging.getLogger(__name__).info(
            "[%s] Full TransformGens used in training: %s",
            type(self).__name__,
            self.tfm_gens,
        )

    @classmethod
    def from_config(cls, cfg, is_train=True):
        """Translate ``cfg`` into the keyword arguments of ``__init__``."""
        # Build augmentation
        tfm_gens = build_transform_gen(cfg, is_train)

        return {
            "is_train": is_train,
            "tfm_gens": tfm_gens,
            "image_format": cfg.INPUT.FORMAT,
        }

    def __call__(self, dataset_dict):
        """
        Args:
            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.

        Returns:
            dict: a format that builtin models in detectron2 accept
        """
        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
        image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
        utils.check_image_size(dataset_dict, image)

        image, transforms = T.apply_transform_gens(self.tfm_gens, image)
        image_shape = image.shape[:2]  # h, w
        dataset_dict["image_ori"] = image

        # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
        # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
        # Therefore it's important to use torch.Tensor.
        dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))

        if "pan_seg_file_name" not in dataset_dict:
            # No panoptic ground truth for this record: image-only output.
            return dataset_dict

        pan_seg_gt = utils.read_image(dataset_dict.pop("pan_seg_file_name"), "RGB")
        segments_info = dataset_dict["segments_info"]

        # apply the same transformation to panoptic segmentation
        pan_seg_gt = transforms.apply_segmentation(pan_seg_gt)

        # Imported lazily so the module can be loaded without panopticapi.
        from panopticapi.utils import rgb2id

        pan_seg_gt = rgb2id(pan_seg_gt)

        instances = Instances(image_shape)
        classes = []
        masks = []
        for segment_info in segments_info:
            class_id = segment_info["category_id"]
            if not segment_info["iscrowd"]:
                classes.append(class_id)
                masks.append(pan_seg_gt == segment_info["id"])

        classes = np.array(classes)
        instances.gt_classes = torch.tensor(classes, dtype=torch.int64)
        if not masks:
            # Some image does not have annotation (all ignored)
            instances.gt_masks = torch.zeros((0, pan_seg_gt.shape[-2], pan_seg_gt.shape[-1]))
            instances.gt_boxes = Boxes(torch.zeros((0, 4)))
        else:
            masks = BitMasks(
                torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks])
            )
            instances.gt_masks = masks.tensor
            instances.gt_boxes = masks.get_bounding_boxes()

        dataset_dict["instances"] = filter_empty_instances_by_box(instances)

        return dataset_dict
def filter_empty_instances_by_box(
    instances, by_box=True, by_mask=False, box_threshold=1e-5, return_mask=False
):
    """
    Keep only the instances whose box and/or mask is non-empty.

    Args:
        instances: object exposing ``gt_boxes``, ``has(name)`` and indexing
            by a keep-mask; ``gt_masks`` is consulted only when present.
        by_box: filter on ``gt_boxes.nonempty(threshold=box_threshold)``.
        by_mask: filter on ``gt_masks.nonempty()`` when ``gt_masks`` exists.
        box_threshold: threshold forwarded to ``gt_boxes.nonempty``.
        return_mask: also return the combined keep-mask.

    Returns:
        ``instances[keep]``, or ``(instances[keep], keep)`` when
        ``return_mask`` is set. If no criterion produced a mask, the input
        ``instances`` is returned unchanged (without a mask, even when
        ``return_mask`` is set).
    """
    assert by_box or by_mask

    criteria = []
    if by_box:
        criteria.append(instances.gt_boxes.nonempty(threshold=box_threshold))
    if instances.has("gt_masks") and by_mask:
        criteria.append(instances.gt_masks.nonempty())

    # TODO: can also filter visible keypoints

    if len(criteria) == 0:
        return instances

    # AND all collected criteria together, left to right.
    keep = None
    for flags in criteria:
        keep = flags if keep is None else keep & flags

    kept = instances[keep]
    return (kept, keep) if return_mask else kept
# This is specifically designed for the COCO dataset.
class COCOPanopticInteractiveDatasetMapper:
    """
    Callable that maps one Detectron2-format dataset dict to a grounding
    training sample.

    For each record ``__call__``:

    1. reads the image from "file_name" and runs ``data_args.image_processor``
       on it to produce "image_clip";
    2. applies ``tfm_gens`` and stores the result as "image_ori" (array) and
       "image" (CHW tensor);
    3. when "pan_seg_file_name" is present, builds per-segment masks/boxes from
       the panoptic map, then selects grounding targets either from
       "grounding_info" (referring sentences) or from the segment class names;
    4. stores the selected targets under "instances" and builds one
       question/answer pair per label, tokenized via ``preprocess`` into
       "input_ids" and "labels".
    """

    @configurable
    def __init__(
        self,
        is_train=True,
        *,
        tfm_gens,
        image_format,
        caption_thres,
        # lvis,
        # lvis_thres,
        max_grounding_num,
        tokenizer,
        data_args,
        preprocess,
        # shape_sampler,
    ):
        """
        NOTE: this interface is experimental.

        Args:
            is_train: for training or inference
            tfm_gens: data augmentation
            image_format: an image format supported by :func:`detection_utils.read_image`.
            caption_thres: stored on the instance; not read by any method of this class.
            max_grounding_num: the number of grounding targets drawn per
                training sample is at most ``max_grounding_num - 1``.
            tokenizer: passed to ``preprocess`` and attached to every output dict.
            data_args: must expose ``image_processor`` and ``image_aspect_ratio``.
            preprocess: callable invoked as
                ``preprocess(conversations, tokenizer, has_image=True)``; its
                result is indexed with "input_ids" and "labels".
        """
        self.tfm_gens = tfm_gens
        # NOTE(review): the tag below names COCOPanopticNewBaselineDatasetMapper,
        # not this class — looks like a copy/paste leftover.
        logging.getLogger(__name__).info(
            "[COCOPanopticNewBaselineDatasetMapper] Full TransformGens used in training: {}".format(
                str(self.tfm_gens)
            )
        )

        self.img_format = image_format
        self.is_train = is_train
        self.caption_thres = caption_thres
        # Grounding is always on; there is no constructor switch for it.
        self.grounding = True
        # self.lvis = lvis
        # self.lvis_thres = lvis_thres
        self.max_grounding_num = max_grounding_num
        # Loaded eagerly at construction time from the path registered under the
        # 'logistic' metadata entry; not read by any method of this class.
        self.caption_similarity = torch.load(MetadataCatalog.get('logistic').get('caption_similarity_pth'))
        self.tokenizer = tokenizer
        self.processor = data_args.image_processor
        self.data_args = data_args
        self.preprocess = preprocess
        # self.shape_sampler = shape_sampler

    @classmethod
    def from_config(cls, cfg, is_train=True,tokenizer=None,data_args=None,preprocess=None):
        """Translate ``cfg`` plus the LLaVA-side objects into ``__init__`` kwargs."""
        # Build augmentation
        tfm_gens = build_transform_gen(cfg, is_train)
        # shape_sampler = build_shape_sampler(cfg)

        ret = {
            "is_train": is_train,
            "tfm_gens": tfm_gens,
            "image_format": cfg['INPUT']['FORMAT'],
            "caption_thres": cfg['MODEL']['DECODER']['CAPTION']['SIM_THRES'],
            "max_grounding_num": cfg['MODEL']['DECODER']['GROUNDING']['MAX_LEN'],
            "tokenizer": tokenizer,
            "data_args": data_args,
            "preprocess": preprocess,
            # "shape_sampler": shape_sampler,
        }
        return ret

    def __call__(self, dataset_dict):
        """
        Args:
            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.

        Returns:
            dict: a deep copy of the input extended with "image_clip",
            "image_ori", "image" and, when panoptic ground truth is present,
            "instances", "conversation", "input_ids", "labels" and "tokenizer".
        """
        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
        image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
        utils.check_image_size(dataset_dict, image)
        #########llava image processing
        if self.data_args.image_aspect_ratio == 'pad':
            # NOTE(review): expand2square uses the PIL API (.size as a
            # (width, height) pair, .mode, Image.paste) but is handed the value
            # returned by utils.read_image — confirm that value is a PIL image
            # on this path, or that the 'pad' setting is never used here.
            def expand2square(pil_img, background_color):
                # Pad the shorter side symmetrically so the image becomes square.
                width, height = pil_img.size
                if width == height:
                    return pil_img
                elif width > height:
                    result = Image.new(pil_img.mode, (width, width), background_color)
                    result.paste(pil_img, (0, (width - height) // 2))
                    return result
                else:
                    result = Image.new(pil_img.mode, (height, height), background_color)
                    result.paste(pil_img, ((height - width) // 2, 0))
                    return result

            # Background colour is the processor's per-channel mean scaled to 0-255.
            image_clip = expand2square(image, tuple(int(x * 255) for x in self.processor.image_mean))
            image_clip = self.processor.preprocess(image_clip, return_tensors='pt')['pixel_values'][0]
        else:
            image_clip = self.processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
        dataset_dict["image_clip"] = image_clip
        ##################

        image, transforms = T.apply_transform_gens(self.tfm_gens, image)
        image_shape = image.shape[:2]  # h, w
        dataset_dict["image_ori"]=image
        # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
        # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
        # Therefore it's important to use torch.Tensor.
        dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))

        if "pan_seg_file_name" in dataset_dict:
            pan_seg_gt = utils.read_image(dataset_dict.pop("pan_seg_file_name"), "RGB")
            segments_info = dataset_dict["segments_info"]

            # apply the same transformation to panoptic segmentation
            pan_seg_gt = transforms.apply_segmentation(pan_seg_gt)

            from panopticapi.utils import rgb2id

            pan_seg_gt = rgb2id(pan_seg_gt)

            instances = Instances(image_shape)
            classes = []
            masks = []
            # Crowd segments are skipped; every other segment contributes one
            # class id and one boolean mask (pixels equal to the segment id).
            for segment_info in segments_info:
                class_id = segment_info["category_id"]
                if not segment_info["iscrowd"]:
                    classes.append(class_id)
                    masks.append(pan_seg_gt == segment_info["id"])
            # is_things = [COCO_CATEGORIES[idx]['isthing'] for idx in classes]
            classes = np.array(classes)
            # is_things = np.array(is_things)
            instances.gt_classes = torch.tensor(classes, dtype=torch.int64)
            # instances.is_things = torch.tensor(is_things, dtype=torch.int64)

            if len(masks) == 0:
                # Some image does not have annotation (all ignored)
                masks = BitMasks(torch.zeros((0, pan_seg_gt.shape[-2], pan_seg_gt.shape[-1])))
                instances.gt_masks = masks
                instances.gt_boxes = Boxes(torch.zeros((0, 4)))
            else:
                masks = BitMasks(
                    torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks])
                )
                instances.gt_masks = masks
                instances.gt_boxes = masks.get_bounding_boxes()

            if self.grounding:
                grounding_anno = dataset_dict['grounding_info']
                if self.is_train:
                    # random.randint is inclusive on both ends, so this draws
                    # from [1, max_grounding_num - 1].
                    grounding_len = random.randint(1, self.max_grounding_num - 1)
                else:
                    grounding_len = 1
                if len(grounding_anno) > 0:
                    # Referring-expression mode: one mask and one sentence per annotation.
                    masks_grd = []
                    texts_grd = []
                    mode = 'text'  # not read again within this method
                    random.shuffle(grounding_anno)
                    for ann in grounding_anno:
                        rle = coco_mask.frPyObjects(
                            ann['segmentation'], dataset_dict['height'], dataset_dict['width'])
                        m = coco_mask.decode(rle)
                        # sometimes there are multiple binary map (corresponding to multiple segs)
                        m = np.sum(m, axis=2)>0
                        m = m.astype(np.uint8)  # convert to np.uint8
                        # Apply the image transforms to the mask, then re-binarize.
                        m = transforms.apply_segmentation(m[:, :, None])[:, :, 0]==1
                        masks_grd += [m]
                        # random select a sentence of a single annotation.
                        rand_index = random.randint(0, len(ann['sentences']) - 1)
                        texts_grd += [ann['sentences'][rand_index]['raw'].lower()]
                    # Keep the first max_len (already shuffled) annotations, in a
                    # random order; texts and masks are indexed identically.
                    max_len = min(grounding_len, len(texts_grd))
                    indices = np.random.permutation(max_len)
                    texts_grd = list(np.array(texts_grd)[indices])
                    masks_grd = torch.tensor(np.stack(masks_grd)[indices])
                    hash_grd = np.array([hash(txt) for txt in texts_grd])  # not read again in this branch
                    # Each kept sentence is its own class: target i has label [i].
                    gt_classes = list(range(len(texts_grd)))
                    gt_classes = [[lb] for lb in gt_classes]
                    label_set=texts_grd
                else:
                    # Class-name mode: fall back to the panoptic segments.
                    # Only allowed in training and only when segments exist.
                    assert self.is_train
                    masks_grd = instances.gt_masks.tensor
                    mode = 'class'  # not read again within this method
                    assert len(masks_grd) > 0
                    texts_grd = np.array([COCO_CATEGORIES[idx]['name'] for idx in classes])
                    hash_grd = np.array([hash(txt) for txt in texts_grd])
                    # Pick up to grounding_len distinct class names (by hash) and
                    # keep every segment belonging to one of them.
                    unique_hash_grd = np.unique(hash_grd)
                    np.random.shuffle(unique_hash_grd)
                    max_len = min(grounding_len,len(unique_hash_grd))
                    indices = np.random.permutation(max_len)
                    selected_unique_hash_grd = unique_hash_grd[indices]
                    selected_mask = np.in1d(hash_grd, selected_unique_hash_grd)
                    texts_grd = texts_grd[selected_mask]
                    hash_grd = hash_grd[selected_mask]
                    masks_grd = masks_grd[selected_mask]
                    # Strip the suffixes from the class names before using them as text.
                    texts_grd = [
                        text.replace('-other', '').replace('-merged', '').replace('-stuff', '')
                        for text in texts_grd]
                    # Labels are deduplicated; each segment points at its label's index.
                    label_set=list(set(texts_grd))
                    gt_classes=[[label_set.index(lb)] for lb in texts_grd]
                # The grounding targets replace the panoptic instances in the output.
                instances_gd = Instances(image_shape)
                instances_gd.gt_masks = BitMasks(masks_grd)
                instances_gd.gt_boxes = BitMasks(masks_grd).get_bounding_boxes()
                instances_gd.gt_masks=instances_gd.gt_masks.tensor
                instances_gd.gt_classes=gt_classes
                dataset_dict["instances"] = instances_gd
                # One human/gpt turn pair per label; only the first question
                # carries the leading "\n " prefix.
                # NOTE(review): the leading "\n " and the ' .' answer look like
                # they may have lost a placeholder token — confirm against the
                # tokens that `preprocess` expects.
                conversations=[]
                for i in range(len(label_set)):
                    if i==0:
                        question={'from': 'human', 'value': f"\n Please detect the object according to the text {label_set[i]} (referring)."}
                    else:
                        question={'from': 'human', 'value': f"Please detect the object according to the text {label_set[i]} (referring)."}
                    answer={'from': 'gpt', 'value': ' .'}
                    conversations.append(question)
                    conversations.append(answer)
                dataset_dict['conversation'] = [conversations]
                data_dict_conversation = self.preprocess(
                    dataset_dict['conversation'],
                    self.tokenizer,
                    has_image=True)
                # preprocess returns batched fields; take the single conversation.
                data_dict_conversation = dict(input_ids=data_dict_conversation["input_ids"][0],
                                              labels=data_dict_conversation["labels"][0])
                dataset_dict.update(data_dict_conversation)
                dataset_dict['tokenizer']=self.tokenizer

        return dataset_dict
# This is specifically designed for the COCO dataset.
class COCOPanopticNewBaselineDatasetMapper:
    """
    Dataset mapper for COCO panoptic segmentation.

    Given one dataset dict in Detectron2 Dataset format, it reads the image
    from "file_name", applies the configured transforms to the image and (when
    available) to the panoptic map, and converts both into tensors: the image
    under "image" and the per-segment classes, masks and boxes under
    "instances".
    """

    @configurable
    def __init__(
        self,
        is_train=True,
        *,
        tfm_gens,
        image_format,
    ):
        """
        NOTE: this interface is experimental.

        Args:
            is_train: for training or inference
            tfm_gens: data augmentation
            image_format: an image format supported by :func:`detection_utils.read_image`.
        """
        self.tfm_gens = tfm_gens
        message = "[COCOPanopticNewBaselineDatasetMapper] Full TransformGens used in training: {}".format(
            str(self.tfm_gens)
        )
        logging.getLogger(__name__).info(message)

        self.img_format = image_format
        self.is_train = is_train

    @classmethod
    def from_config(cls, cfg, is_train=True):
        """Build the ``__init__`` keyword arguments from ``cfg``."""
        return {
            "is_train": is_train,
            # Build augmentation
            "tfm_gens": build_transform_gen(cfg, is_train),
            "image_format": cfg.INPUT.FORMAT,
        }

    def __call__(self, dataset_dict):
        """
        Args:
            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.

        Returns:
            dict: a format that builtin models in detectron2 accept
        """
        # Work on a private copy: the record is modified below.
        record = copy.deepcopy(dataset_dict)

        raw_image = utils.read_image(record["file_name"], format=self.img_format)
        utils.check_image_size(record, raw_image)

        image, transforms = T.apply_transform_gens(self.tfm_gens, raw_image)
        height_width = image.shape[:2]  # h, w

        # Tensors travel through the dataloader's shared memory cheaply, while
        # generic Python structures are pickled, so hand over a torch.Tensor.
        chw_image = np.ascontiguousarray(image.transpose(2, 0, 1))
        record["image"] = torch.as_tensor(chw_image)

        if "pan_seg_file_name" not in record:
            return record

        pan_seg_gt = utils.read_image(record.pop("pan_seg_file_name"), "RGB")
        segments_info = record["segments_info"]

        # The panoptic map must go through the same geometric transforms.
        pan_seg_gt = transforms.apply_segmentation(pan_seg_gt)

        from panopticapi.utils import rgb2id

        pan_seg_gt = rgb2id(pan_seg_gt)

        class_ids = []
        segment_masks = []
        for segment in segments_info:
            category = segment["category_id"]
            if segment["iscrowd"]:
                continue
            class_ids.append(category)
            segment_masks.append(pan_seg_gt == segment["id"])

        instances = Instances(height_width)
        instances.gt_classes = torch.tensor(np.array(class_ids), dtype=torch.int64)
        if segment_masks:
            stacked = torch.stack(
                [torch.from_numpy(np.ascontiguousarray(mask.copy())) for mask in segment_masks]
            )
            bit_masks = BitMasks(stacked)
            instances.gt_masks = bit_masks.tensor
            instances.gt_boxes = bit_masks.get_bounding_boxes()
        else:
            # Every segment was ignored: emit empty masks and boxes.
            instances.gt_masks = torch.zeros((0, pan_seg_gt.shape[-2], pan_seg_gt.shape[-1]))
            instances.gt_boxes = Boxes(torch.zeros((0, 4)))

        record["instances"] = instances
        return record
def convert_coco_poly_to_mask(segmentations, height, width):
    """
    Rasterize COCO polygon segmentations into one mask per object.

    Args:
        segmentations: iterable with one polygon list per object.
        height: mask height in pixels.
        width: mask width in pixels.

    Returns:
        A tensor stacking one (height, width) mask per object along dim 0;
        an empty ``(0, height, width)`` uint8 tensor when there are no objects.
    """
    per_object = []
    for polygons in segmentations:
        decoded = coco_mask.decode(coco_mask.frPyObjects(polygons, height, width))
        if len(decoded.shape) < 3:
            # A single decoded map has no trailing axis; add one.
            decoded = decoded[..., None]
        as_tensor = torch.as_tensor(decoded, dtype=torch.uint8)
        # Merge all parts of the object: a pixel is set if any part covers it.
        per_object.append(as_tensor.any(dim=2))

    if not per_object:
        return torch.zeros((0, height, width), dtype=torch.uint8)
    return torch.stack(per_object, dim=0)
# This is specifically designed for the COCO dataset.
class COCOInstanceNewBaselineDatasetMapper:
    """
    Callable that maps one Detectron2-format dataset dict with phrase-level
    "grounding_info" to a grounded-caption training sample.

    For each record ``__call__``:

    1. reads the image from "file_name" and runs ``data_args.image_processor``
       on it to produce "image_clip";
    2. applies ``tfm_gens`` to the image and to an all-ones mask, storing
       "image_ori", "image" (CHW tensor) and "padding_mask";
    3. when "grounding_info" is present, transforms the annotations, assigns
       each object the index (or indices) of the caption span(s) it refers to,
       and stores them under "instances";
    4. rebuilds the caption around the referenced spans according to
       ``gd_mode`` and tokenizes a one-turn conversation via ``preprocess``
       into "input_ids" and "labels".
    """

    @configurable
    def __init__(
        self,
        is_train=True,
        *,
        tfm_gens,
        image_format,
        tokenizer,
        data_args,
        preprocess,
        gd_mode="inter",
    ):
        """
        NOTE: this interface is experimental.

        Args:
            is_train: for training or inference
            tfm_gens: data augmentation
            image_format: an image format supported by :func:`detection_utils.read_image`.
            tokenizer: passed to ``preprocess`` and attached to every output dict.
            data_args: must expose ``image_processor`` and ``image_aspect_ratio``.
            preprocess: callable invoked as
                ``preprocess(conversations, tokenizer, has_image=True)``; its
                result is indexed with "input_ids" and "labels".
            gd_mode: caption layout. Values containing 'end' must be exactly
                'end' or 'end_num'; any other value uses the default layout.
        """
        self.tfm_gens = tfm_gens
        logging.getLogger(__name__).info(
            "[COCOInstanceNewBaselineDatasetMapper] Full TransformGens used in training: {}".format(str(self.tfm_gens))
        )

        self.img_format = image_format
        self.is_train = is_train
        self.tokenizer = tokenizer
        self.processor = data_args.image_processor
        self.data_args = data_args
        self.preprocess = preprocess
        self.gd_mode= gd_mode

    @classmethod
    def from_config(cls, cfg, is_train=True,tokenizer=None,data_args=None,preprocess=None):
        """Translate ``cfg`` plus the LLaVA-side objects into ``__init__`` kwargs."""
        # Build augmentation
        tfm_gens = build_transform_gen(cfg, is_train)

        ret = {
            "is_train": is_train,
            "tfm_gens": tfm_gens,
            "image_format": cfg['INPUT']['FORMAT'],
            "tokenizer": tokenizer,
            "data_args": data_args,
            "preprocess": preprocess,
            # Read by attribute, unlike the item access used for the other
            # entries, so cfg must support both styles.
            "gd_mode": cfg.flickr.get("gd_mode", "inter"),
        }
        return ret

    def __call__(self, dataset_dict):
        """
        Args:
            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.

        Returns:
            dict: a deep copy of the input extended with "image_clip",
            "image_ori", "image", "padding_mask" and, when "grounding_info" is
            present, "instances", "conversation", "input_ids", "labels" and
            "tokenizer".
        """
        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
        image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
        utils.check_image_size(dataset_dict, image)
        #########llava image processing
        if self.data_args.image_aspect_ratio == 'pad':
            # NOTE(review): expand2square uses the PIL API (.size as a
            # (width, height) pair, .mode, Image.paste) but is handed the value
            # returned by utils.read_image — confirm that value is a PIL image
            # on this path, or that the 'pad' setting is never used here.
            def expand2square(pil_img, background_color):
                # Pad the shorter side symmetrically so the image becomes square.
                width, height = pil_img.size
                if width == height:
                    return pil_img
                elif width > height:
                    result = Image.new(pil_img.mode, (width, width), background_color)
                    result.paste(pil_img, (0, (width - height) // 2))
                    return result
                else:
                    result = Image.new(pil_img.mode, (height, height), background_color)
                    result.paste(pil_img, ((height - width) // 2, 0))
                    return result

            # Background colour is the processor's per-channel mean scaled to 0-255.
            image_clip = expand2square(image, tuple(int(x * 255) for x in self.processor.image_mean))
            image_clip = self.processor.preprocess(image_clip, return_tensors='pt')['pixel_values'][0]
        else:
            image_clip = self.processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
        dataset_dict["image_clip"] = image_clip
        ##################
        # TODO: get padding mask
        # by feeding a "segmentation mask" to the same transforms
        padding_mask = np.ones(image.shape[:2])

        image, transforms = T.apply_transform_gens(self.tfm_gens, image)
        dataset_dict["image_ori"]=image
        # the crop transformation has default padding value 0 for segmentation
        padding_mask = transforms.apply_segmentation(padding_mask)
        # After inversion, True marks pixels that were padded in.
        padding_mask = ~ padding_mask.astype(bool)

        image_shape = image.shape[:2]  # h, w

        # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
        # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
        # Therefore it's important to use torch.Tensor.
        dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
        dataset_dict["padding_mask"] = torch.as_tensor(np.ascontiguousarray(padding_mask))

        # if not self.is_train:
        #     # USER: Modify this if you want to keep them for some reason.
        #     dataset_dict.pop("annotations", None)
        #     return dataset_dict

        if "grounding_info" in dataset_dict:
            for obj in dataset_dict['grounding_info']:
                obj["bbox_mode"] = BoxMode.XYWH_ABS
                # Only the first positive span is used for the 'tokens' text.
                obj['tokens']=dataset_dict['caption'][obj['tokens_positive'][0][0]:obj['tokens_positive'][0][1]]
            # USER: Implement additional transformations if you have other types of data
            annos = [
                utils.transform_instance_annotations(obj, transforms, image_shape)
                for obj in dataset_dict["grounding_info"]
            ]
            # NOTE: does not support BitMask due to augmentation
            # Current BitMask cannot handle empty objects
            assert len(annos)>0
            assert "segmentation" in annos[0]
            instances = utils.annotations_to_instances(annos, image_shape,mask_format="bitmask")
            # After transforms such as cropping are applied, the bounding box may no longer
            # tightly bound the object. As an example, imagine a triangle object
            # [(0,0), (2,0), (0,2)] cropped by a box [(1,0),(2,2)] (XYXY format). The tight
            # bounding box of the cropped triangle should be [(1,0),(2,1)], which is not equal to
            # the intersection of original bounding box and the cropping box.
            # instances.gt_boxes = instances.gt_masks.get_bounding_boxes()
            # Need to filter empty instances first (due to augmentation)
            # instances = utils.filter_empty_instances(instances)
            # Generate masks from polygon
            h, w = instances.image_size  # h and w are not read again below
            # image_size_xyxy = torch.as_tensor([w, h, w, h], dtype=torch.float)
            if hasattr(instances, 'gt_masks'):
                gt_masks = instances.gt_masks
                # gt_masks = convert_coco_poly_to_mask(gt_masks.polygons, h, w)
                instances.gt_masks = gt_masks.tensor

            # span_set: span start -> indices of the objects referring to it.
            # end_dict: span start -> span end (asserted consistent across objects).
            # gt_classes: per object, the list of span starts it refers to.
            span_set = dict()
            end_dict = dict()
            gt_classes= []
            for i, info in enumerate(dataset_dict['grounding_info']):
                gt_classes.append([])
                # if len(info['tokens_positive'])>1:
                #     print("multi class")
                for j in range(len(info['tokens_positive'])):
                    if info['tokens_positive'][j][0] in span_set:
                        span_set[info['tokens_positive'][j][0]].append(i)
                    else:
                        span_set[info['tokens_positive'][j][0]] = [i]
                    if info['tokens_positive'][j][0] in end_dict:
                        assert end_dict[info['tokens_positive'][j][0]] == info['tokens_positive'][j][1]
                    else:
                        end_dict[info['tokens_positive'][j][0]] = info['tokens_positive'][j][1]
                    gt_classes[-1].append(info['tokens_positive'][j][0])
            # From here on end_dict is a list of (start, end) pairs sorted by start.
            end_dict = sorted(end_dict.items())
            # Replace every span start by its rank among the sorted spans.
            start2id = dict()
            for i, (s, e) in enumerate(end_dict):
                start2id[s] = i
            gt_classes= [[start2id[s] for s in gt_class] for gt_class in gt_classes]
            instances.gt_classes = gt_classes
            dataset_dict["instances"] = instances
            # span_list = sorted(span_set.items())
            # for k, v in span_set:
            #     for i in range(len(v)):
            #         v[i] = positive_new_ids[v[i]]
            # Split the caption so that odd-indexed pieces are the referenced
            # spans and even-indexed pieces are the text between them.
            cap_pieces = []
            last_e = 0
            for s, e in end_dict:
                cap_pieces.append(dataset_dict['caption'][last_e:s])
                cap_pieces.append(dataset_dict['caption'][s:e])
                last_e = e
            cap_pieces.append(dataset_dict['caption'][last_e:])
            new_cap = []
            # NOTE(review): the empty-string literals wrapped around `piece`
            # and passed to `count` below look like marker tokens that were
            # lost — confirm against upstream. As written, the wrapping is a
            # no-op and str.count("") returns len(new_cap) + 1.
            if 'end' in self.gd_mode:
                k=1
                for i, piece in enumerate(cap_pieces):
                    if i % 2 == 1:
                        if self.gd_mode == 'end':
                            piece = '' + piece + ''
                        else:
                            assert self.gd_mode == 'end_num'
                            # Prefix each span with its running 1-based number.
                            piece = f' {k} ' + piece + ''
                            k+=1
                    new_cap.append(piece)
                new_cap = "".join(new_cap)
                # Append a numbered tail "1: ; 2: ; ..." to the caption.
                tail = [f'{i + 1}: ' for i in range(new_cap.count(""))]
                tail = '; '.join(tail)
                new_cap += f' {tail}.'
            else:
                for i, piece in enumerate(cap_pieces):
                    if i % 2 == 1:
                        piece = '' + piece + ''
                    new_cap.append(piece)
                new_cap = "".join(new_cap)
            # gt_ids = []
            # for s, e in end_dict:
            #     if len(span_set[s]) > 1:
            #         return dataset_dict
            #     gt_ids.append(span_set[s][0] + 1)
            # ground_annos = dict()
            # ground_annos['gt_ids'] = gt_ids
            # ground_annos['gt_anno_ids'] = [dataset_dict['grounding_info'][gt_id_ - 1]['id'] for gt_id_ in gt_ids]
            # ground_annos['caption'] = new_cap
            # Single-turn conversation: fixed instruction, rebuilt caption as the answer.
            question={'from': 'human', 'value': "\nPresent a compact description of the photo's key features.\n(with grounding)"}
            answer={'from': 'gpt', 'value': new_cap}
            dataset_dict['conversation'] = [[question, answer]]
            # sources = preprocess_multimodal(
            #     copy.deepcopy(dataset_dict['conversation']),
            #     self.data_args)
            data_dict_conversation = self.preprocess(
                dataset_dict['conversation'],
                self.tokenizer,
                has_image=True)
            # preprocess returns batched fields; take the single conversation.
            data_dict_conversation = dict(input_ids=data_dict_conversation["input_ids"][0],
                                          labels=data_dict_conversation["labels"][0])
            dataset_dict.update(data_dict_conversation)
            dataset_dict['tokenizer']=self.tokenizer

        return dataset_dict
def convert_coco_poly_to_mask(segmentations, height, width):
    """
    Turn COCO polygon segmentations into a stack of per-object masks.

    Args:
        segmentations: iterable with one polygon list per object.
        height: mask height in pixels.
        width: mask width in pixels.

    Returns:
        One (height, width) mask per object stacked along dim 0, or an empty
        ``(0, height, width)`` uint8 tensor if ``segmentations`` is empty.
    """

    def _rasterize(polygons):
        # Decode the polygons of one object and collapse its parts into one mask.
        rles = coco_mask.frPyObjects(polygons, height, width)
        bitmap = coco_mask.decode(rles)
        if len(bitmap.shape) < 3:
            bitmap = bitmap[..., None]
        return torch.as_tensor(bitmap, dtype=torch.uint8).any(dim=2)

    object_masks = [_rasterize(polygons) for polygons in segmentations]
    if object_masks:
        return torch.stack(object_masks, dim=0)
    return torch.zeros((0, height, width), dtype=torch.uint8)
def build_transform_gen(cfg, is_train):
    """
    Create the list of default :class:`Augmentation` from config.

    The pipeline is ``T.ResizeScale`` (scale drawn between ``MIN_SCALE`` and
    ``MAX_SCALE``, targeting ``IMAGE_SIZE`` x ``IMAGE_SIZE``) followed by
    ``T.FixedSizeCrop`` to ``IMAGE_SIZE`` x ``IMAGE_SIZE``. Random flipping is
    not applied.

    Args:
        cfg: mapping whose ``'INPUT'`` entry provides ``'IMAGE_SIZE'``,
            ``'MIN_SCALE'`` and ``'MAX_SCALE'``.
        is_train: kept for interface compatibility. The training and inference
            pipelines are identical, so it does not affect the result.

    Returns:
        list[Augmentation]
    """
    # The training and inference branches used to be exact copies of each
    # other, so a single code path serves both.
    cfg_input = cfg['INPUT']
    image_size = cfg_input['IMAGE_SIZE']
    min_scale = cfg_input['MIN_SCALE']
    max_scale = cfg_input['MAX_SCALE']

    return [
        T.ResizeScale(
            min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size
        ),
        T.FixedSizeCrop(crop_size=(image_size, image_size)),
    ]
def __call__(self, dataset_dict):
    """
    Build the model inputs for one image: the CLIP-preprocessed image, the
    augmented image with its padding mask, the grounded instances and one
    randomly chosen, tokenized conversation.

    Args:
        dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
            Besides "file_name" this mapper reads "grounding_info" (annotation
            dicts, each with an "id") and "conversations" (a sequence of
            ``(conversation, grounding_list)`` pairs).

    Returns:
        dict: a deep copy of the input extended with "image_clip",
        "image_ori", "image" and "padding_mask"; when "grounding_info" is
        present also "grounding", "instances", "conversation", "input_ids",
        "labels" and "tokenizer".
    """
    dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
    image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
    utils.check_image_size(dataset_dict, image)

    # ---- LLaVA image processing ----
    if self.data_args.image_aspect_ratio == 'pad':
        def expand2square(pil_img, background_color):
            """Pad a PIL image to a centered square filled with background_color."""
            width, height = pil_img.size
            if width == height:
                return pil_img
            side = max(width, height)
            result = Image.new(pil_img.mode, (side, side), background_color)
            if width > height:
                result.paste(pil_img, (0, (width - height) // 2))
            else:
                result.paste(pil_img, ((height - width) // 2, 0))
            return result

        # ``image`` is a numpy array here, so its ``.size`` is an element
        # count and cannot be unpacked as (width, height). Convert to a PIL
        # image before handing it to the PIL-based padding helper.
        background = tuple(int(x * 255) for x in self.processor.image_mean)
        clip_input = expand2square(Image.fromarray(image), background)
    else:
        clip_input = image
    dataset_dict["image_clip"] = self.processor.preprocess(
        clip_input, return_tensors='pt')['pixel_values'][0]

    # ---- detection-style augmentation ----
    # The padding mask is obtained by pushing an all-ones "segmentation"
    # through the same transforms as the image.
    padding_mask = np.ones(image.shape[:2])
    image, transforms = T.apply_transform_gens(self.tfm_gens, image)
    dataset_dict["image_ori"] = image
    # the crop transformation has default padding value 0 for segmentation
    padding_mask = transforms.apply_segmentation(padding_mask)
    padding_mask = ~ padding_mask.astype(bool)

    image_shape = image.shape[:2]  # h, w
    dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
    dataset_dict["padding_mask"] = torch.as_tensor(np.ascontiguousarray(padding_mask))

    if "grounding_info" in dataset_dict:
        grounding_info = dataset_dict['grounding_info']

        # Map annotation id -> position in grounding_info.
        anno_id2id = dict()
        for index, obj in enumerate(grounding_info):
            obj["bbox_mode"] = BoxMode.XYWH_ABS
            anno_id2id[obj['id']] = index
        # id2class[k] collects the grounding slots that refer to annotation k.
        id2class = [[] for _ in range(len(grounding_info))]

        annos = [
            utils.transform_instance_annotations(obj, transforms, image_shape)
            for obj in grounding_info
        ]
        # NOTE: does not support BitMask due to augmentation
        # Current BitMask cannot handle empty objects
        assert len(annos) > 0
        assert "segmentation" in annos[0]
        instances = utils.annotations_to_instances(annos, image_shape, mask_format="bitmask")

        if hasattr(instances, 'gt_masks'):
            # Store the raw mask tensor instead of the mask wrapper object.
            instances.gt_masks = instances.gt_masks.tensor

        # Pick one conversation at random for this sample.
        num_conversations = len(dataset_dict['conversations'])
        rd = np.random.choice(num_conversations)
        selected_conversation, grounding_list = dataset_dict['conversations'][rd]

        # A conversation is grounded when at least one slot is not None.
        has_grounding = grounding_list is not None and any(
            g is not None for g in grounding_list)
        dataset_dict['grounding'] = has_grounding

        if has_grounding:
            # NOTE(review): the counted literal was an empty string in the
            # reviewed text (its angle-bracket token appears to have been
            # stripped), which made the assert below unsatisfiable. It is
            # restored as '<seg>' - confirm against the tokenizer's special
            # tokens.
            num_segs = sum(conv['value'].count('<seg>') for conv in selected_conversation)
            assert num_segs == len(grounding_list)
            for grounding_id, grounding in enumerate(grounding_list):
                if grounding is not None:
                    for annid in grounding:
                        id2class[anno_id2id[annid]].append(grounding_id)

        instances.gt_classes = id2class
        dataset_dict["instances"] = instances

        dataset_dict['conversation'] = [selected_conversation]
        sources = preprocess_multimodal(
            copy.deepcopy(dataset_dict['conversation']),
            self.data_args)
        data_dict_conversation = self.preprocess(
            sources,
            self.tokenizer,
            has_image=True)
        # Only one conversation was passed, so keep its first (only) entry.
        dataset_dict.update(dict(
            input_ids=data_dict_conversation["input_ids"][0],
            labels=data_dict_conversation["labels"][0]))
        dataset_dict['tokenizer'] = self.tokenizer

    return dataset_dict
def build_transform_gen(cfg, is_train):
    """
    Create the list of default :class:`Augmentation` from config.

    The pipeline is a scale-jittered resize followed by a fixed-size crop to
    ``IMAGE_SIZE`` x ``IMAGE_SIZE``. The ``RANDOM_FLIP`` setting is not
    applied by this mapper.

    Args:
        cfg: mapping with an ``'INPUT'`` section providing ``'IMAGE_SIZE'``,
            ``'MIN_SCALE'`` and ``'MAX_SCALE'``.
        is_train: accepted for interface compatibility; training and
            inference use the same augmentation list.

    Returns:
        list[Augmentation]
    """
    # Both modes previously ran two identical copies of this code.
    cfg_input = cfg['INPUT']
    image_size = cfg_input['IMAGE_SIZE']
    min_scale = cfg_input['MIN_SCALE']
    max_scale = cfg_input['MAX_SCALE']

    return [
        T.ResizeScale(
            min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size
        ),
        T.FixedSizeCrop(crop_size=(image_size, image_size)),
    ]
def __call__(self, dataset_dict):
    """
    Build the model inputs for one captioned image: the CLIP-preprocessed
    image, the augmented image with its padding mask, the grounded instances
    and a tokenized question/answer pair whose answer is the caption with its
    grounded phrases marked and a numbered list of segmentation slots appended.

    Args:
        dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
            Besides "file_name" this mapper reads "caption" and
            "grounding_info" (annotation dicts, each with "tokens_positive":
            a list of character spans into the caption).

    Returns:
        dict: a deep copy of the input extended with "image_clip",
        "image_ori", "image" and "padding_mask"; when "grounding_info" is
        present also "instances", "conversation", "input_ids", "labels" and
        "tokenizer".
    """
    dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
    image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
    utils.check_image_size(dataset_dict, image)

    # ---- LLaVA image processing ----
    if self.data_args.image_aspect_ratio == 'pad':
        def expand2square(pil_img, background_color):
            """Pad a PIL image to a centered square filled with background_color."""
            width, height = pil_img.size
            if width == height:
                return pil_img
            side = max(width, height)
            result = Image.new(pil_img.mode, (side, side), background_color)
            if width > height:
                result.paste(pil_img, (0, (width - height) // 2))
            else:
                result.paste(pil_img, ((height - width) // 2, 0))
            return result

        # ``image`` is a numpy array here, so its ``.size`` is an element
        # count and cannot be unpacked as (width, height). Convert to a PIL
        # image before handing it to the PIL-based padding helper.
        background = tuple(int(x * 255) for x in self.processor.image_mean)
        clip_input = expand2square(Image.fromarray(image), background)
    else:
        clip_input = image
    dataset_dict["image_clip"] = self.processor.preprocess(
        clip_input, return_tensors='pt')['pixel_values'][0]

    # ---- detection-style augmentation ----
    # The padding mask is obtained by pushing an all-ones "segmentation"
    # through the same transforms as the image.
    padding_mask = np.ones(image.shape[:2])
    image, transforms = T.apply_transform_gens(self.tfm_gens, image)
    dataset_dict["image_ori"] = image
    # the crop transformation has default padding value 0 for segmentation
    padding_mask = transforms.apply_segmentation(padding_mask)
    padding_mask = ~ padding_mask.astype(bool)

    image_shape = image.shape[:2]  # h, w

    # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
    # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
    # Therefore it's important to use torch.Tensor.
    dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
    dataset_dict["padding_mask"] = torch.as_tensor(np.ascontiguousarray(padding_mask))

    if "grounding_info" in dataset_dict:
        caption = dataset_dict['caption']
        grounding_info = dataset_dict['grounding_info']

        for obj in grounding_info:
            obj["bbox_mode"] = BoxMode.XYWH_ABS
            # Text of the first span this annotation is grounded to.
            first_span = obj['tokens_positive'][0]
            obj['tokens'] = caption[first_span[0]:first_span[1]]

        annos = [
            utils.transform_instance_annotations(obj, transforms, image_shape)
            for obj in grounding_info
        ]
        # NOTE: does not support BitMask due to augmentation
        # Current BitMask cannot handle empty objects
        assert len(annos) > 0
        assert "segmentation" in annos[0]
        instances = utils.annotations_to_instances(annos, image_shape, mask_format="bitmask")

        if hasattr(instances, 'gt_masks'):
            # Store the raw mask tensor instead of the mask wrapper object.
            instances.gt_masks = instances.gt_masks.tensor

        # span start -> span end over all annotations, plus the list of span
        # starts each annotation refers to.
        span_end = dict()
        starts_per_instance = []
        for info in grounding_info:
            starts = []
            for span in info['tokens_positive']:
                start, end = span[0], span[1]
                known_end = span_end.setdefault(start, end)
                # Two spans that start at the same character must be the same span.
                assert known_end == end
                starts.append(start)
            starts_per_instance.append(starts)

        # Phrases are numbered by the order of their start offset.
        spans = sorted(span_end.items())
        start2id = {start: index for index, (start, _) in enumerate(spans)}
        instances.gt_classes = [
            [start2id[start] for start in starts] for starts in starts_per_instance
        ]
        dataset_dict["instances"] = instances

        # NOTE(review): the marker literals used below ('<g_s>', '<g_e>',
        # '<seg>', '<image>') were empty strings in the reviewed text; the
        # angle-bracket tokens appear to have been stripped. They are restored
        # here - confirm against the tokenizer's special tokens.
        pieces = []
        cursor = 0
        for start, end in spans:
            pieces.append(caption[cursor:start])
            pieces.append('<g_s>' + caption[start:end] + '<g_e>')
            cursor = end
        pieces.append(caption[cursor:])
        # One numbered segmentation slot per grounded phrase, listed after the
        # caption; sized from the span list so it always matches gt_classes.
        tail = '; '.join(f'{i + 1}: <seg>' for i in range(len(spans)))
        new_cap = "".join(pieces) + f' {tail}.'

        question = {'from': 'human', 'value': "<image>\nPresent a compact description of the photo's key features.\n(with grounding)"}
        answer = {'from': 'gpt', 'value': new_cap}
        dataset_dict['conversation'] = [[question, answer]]

        data_dict_conversation = self.preprocess(
            dataset_dict['conversation'],
            self.tokenizer,
            has_image=True)
        # Only one conversation was passed, so keep its first (only) entry.
        dataset_dict.update(dict(
            input_ids=data_dict_conversation["input_ids"][0],
            labels=data_dict_conversation["labels"][0]))
        dataset_dict['tokenizer'] = self.tokenizer

    return dataset_dict
def filter_empty_instances_by_box(
    instances, by_box=True, by_mask=False, box_threshold=1e-5, return_mask=False
):
    """
    Drop the instances whose box (and, optionally, mask) is empty.

    Args:
        instances: object exposing ``gt_boxes``, ``has(name)`` and mask-based
            indexing; ``gt_masks`` is consulted only when present.
        by_box: filter on ``gt_boxes.nonempty(threshold=box_threshold)``.
        by_mask: also filter on ``gt_masks.nonempty()`` when masks exist.
        box_threshold: threshold forwarded to the box emptiness check.
        return_mask: when True, also return the keep-mask that was applied.

    Returns:
        The filtered instances, or ``(filtered, keep_mask)`` when
        ``return_mask`` is set. If no criterion produced a mask the input is
        returned unchanged.
    """
    assert by_box or by_mask

    criteria = []
    if by_box:
        criteria.append(instances.gt_boxes.nonempty(threshold=box_threshold))
    masks_present = instances.has("gt_masks")
    if masks_present and by_mask:
        criteria.append(instances.gt_masks.nonempty())

    # TODO: can also filter visible keypoints

    if len(criteria) == 0:
        return instances

    # An instance survives only if every criterion keeps it.
    keep, *remaining = criteria
    for criterion in remaining:
        keep = keep & criterion

    kept = instances[keep]
    if return_mask:
        return kept, keep
    return kept
def __call__(self, dataset_dict):
    """
    Load one image, apply the configured transforms and, in training mode,
    turn its panoptic annotation into per-segment instances.

    Args:
        dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.

    Returns:
        dict: a format that builtin models in detectron2 accept
    """
    record = copy.deepcopy(dataset_dict)  # the caller's dict must stay untouched
    image = utils.read_image(record["file_name"], format=self.img_format)
    utils.check_image_size(record, image)

    image, transforms = T.apply_transform_gens(self.tfm_gens, image)
    image_shape = image.shape[:2]  # h, w
    record["image_ori"] = image

    # Tensors travel through the dataloader's shared memory efficiently;
    # generic Python structures do not, so store the image as a tensor.
    chw_image = np.ascontiguousarray(image.transpose(2, 0, 1))
    record["image"] = torch.as_tensor(chw_image)

    if not self.is_train:
        # USER: Modify this if you want to keep them for some reason.
        record.pop("annotations", None)
        return record

    if "pan_seg_file_name" not in record:
        return record

    pan_seg_gt = utils.read_image(record.pop("pan_seg_file_name"), "RGB")
    segments_info = record["segments_info"]

    # The panoptic map goes through the same transforms as the image.
    pan_seg_gt = transforms.apply_segmentation(pan_seg_gt)

    from panopticapi.utils import rgb2id

    pan_seg_gt = rgb2id(pan_seg_gt)

    instances = Instances(image_shape)
    class_ids = []
    segment_masks = []
    for segment_info in segments_info:
        class_id = segment_info["category_id"]
        if segment_info["iscrowd"]:
            continue
        class_ids.append(class_id)
        segment_masks.append(pan_seg_gt == segment_info["id"])

    instances.gt_classes = torch.tensor(np.array(class_ids), dtype=torch.int64)
    if segment_masks:
        stacked = torch.stack(
            [torch.from_numpy(np.ascontiguousarray(x.copy())) for x in segment_masks]
        )
        bit_masks = BitMasks(stacked)
        instances.gt_masks = bit_masks.tensor
        instances.gt_boxes = bit_masks.get_bounding_boxes()
    else:
        # Some image does not have annotation (all ignored)
        instances.gt_masks = torch.zeros((0, pan_seg_gt.shape[-2], pan_seg_gt.shape[-1]))
        instances.gt_boxes = Boxes(torch.zeros((0, 4)))

    record["instances"] = filter_empty_instances_by_box(instances)
    return record
@configurable
def __init__(
    self,
    is_train: bool,
    *,
    augmentations: List[Union[T.Augmentation, T.Transform]],
    image_format: str,
    use_instance_mask: bool = False,
    use_keypoint: bool = False,
    instance_mask_format: str = "polygon",
    keypoint_hflip_indices: Optional[np.ndarray] = None,
    precomputed_proposal_topk: Optional[int] = None,
    recompute_boxes: bool = False,
):
    """
    NOTE: this interface is experimental.

    Args:
        is_train: whether the mapper is used for training or inference
        augmentations: augmentations or deterministic transforms to apply;
            they are wrapped in a single :class:`T.AugmentationList`
        image_format: an image format supported by :func:`detection_utils.read_image`.
        use_instance_mask: whether to process instance segmentation annotations, if available
        use_keypoint: whether to process keypoint annotations if available
        instance_mask_format: one of "polygon" or "bitmask"; the format
            instance segmentation masks are converted into
        keypoint_hflip_indices: see :func:`detection_utils.create_keypoint_hflip_indices`
        precomputed_proposal_topk: if given, pre-computed proposals are loaded
            from the dataset dict and the top k per image are kept
        recompute_boxes: whether to overwrite box annotations with tight
            boxes computed from the instance masks; requires
            ``use_instance_mask``
    """
    # Tight boxes can only be recomputed when masks are being processed.
    assert use_instance_mask or not recompute_boxes, "recompute_boxes requires instance masks"

    # General settings.
    self.is_train = is_train
    self.image_format = image_format
    self.augmentations = T.AugmentationList(augmentations)

    # Annotation handling.
    self.use_instance_mask = use_instance_mask
    self.instance_mask_format = instance_mask_format
    self.recompute_boxes = recompute_boxes
    self.use_keypoint = use_keypoint
    self.keypoint_hflip_indices = keypoint_hflip_indices
    self.proposal_topk = precomputed_proposal_topk

    if is_train:
        mode = "training"
    else:
        mode = "inference"
    logging.getLogger(__name__).info(f"[DatasetMapper] Augmentations used in {mode}: {augmentations}")
def __call__(self, dataset_dict):
    """
    Map one dataset record to model inputs while keeping its ground truth.

    The image (and the semantic map, if any) is augmented; pre-computed
    proposals, the panoptic map and the instance annotations are then
    transformed with the same transforms. Annotations are kept regardless of
    ``self.is_train``.

    Args:
        dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.

    Returns:
        dict: a format that builtin models in detectron2 accept
    """
    dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
    # USER: Write your own image loading if it's not from a file
    image = utils.read_image(dataset_dict["file_name"], format=self.image_format)
    utils.check_image_size(dataset_dict, image)

    # USER: Remove if you don't do semantic/panoptic segmentation.
    if "sem_seg_file_name" in dataset_dict:
        sem_seg_gt = utils.read_image(dataset_dict.pop("sem_seg_file_name"), "L").squeeze(2)
    else:
        sem_seg_gt = None

    # Augment image and semantic map together; ``transforms`` records what
    # was applied so the other annotations below can follow the same geometry.
    aug_input = T.AugInput(image, sem_seg=sem_seg_gt)
    transforms = self.augmentations(aug_input)
    image, sem_seg_gt = aug_input.image, aug_input.sem_seg

    image_shape = image.shape[:2]  # h, w
    # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
    # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
    # Therefore it's important to use torch.Tensor.
    dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
    if sem_seg_gt is not None:
        dataset_dict["sem_seg"] = torch.as_tensor(sem_seg_gt.astype("long"))

    # USER: Remove if you don't use pre-computed proposals.
    # Most users would not need this feature.
    if self.proposal_topk is not None:
        utils.transform_proposals(
            dataset_dict, image_shape, transforms, proposal_topk=self.proposal_topk
        )

    # The usual inference-time early return is disabled below, so the ground
    # truth stays in the record in both modes.
    # if not self.is_train:
    #     # USER: Modify this if you want to keep them for some reason.
    #     dataset_dict.pop("annotations", None)
    #     dataset_dict.pop("sem_seg_file_name", None)
    #     return dataset_dict

    if "pan_seg_file_name" in dataset_dict:
        pan_seg_gt = utils.read_image(dataset_dict.pop("pan_seg_file_name"), "RGB")
        segments_info = dataset_dict["segments_info"]

        # apply the same transformation to panoptic segmentation
        pan_seg_gt = transforms.apply_segmentation(pan_seg_gt)

        from panopticapi.utils import rgb2id

        pan_seg_gt = rgb2id(pan_seg_gt)

        instances = Instances(image_shape)
        classes = []
        masks = []
        # One class id and one boolean mask per non-crowd segment.
        for segment_info in segments_info:
            class_id = segment_info["category_id"]
            if not segment_info["iscrowd"]:
                classes.append(class_id)
                masks.append(pan_seg_gt == segment_info["id"])

        classes = np.array(classes)
        instances.gt_classes = torch.tensor(classes, dtype=torch.int64)
        if len(masks) == 0:
            # Some image does not have annotation (all ignored)
            instances.gt_masks = torch.zeros((0, pan_seg_gt.shape[-2], pan_seg_gt.shape[-1]))
            instances.gt_boxes = Boxes(torch.zeros((0, 4)))
        else:
            masks = BitMasks(
                torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks])
            )
            instances.gt_masks = masks.tensor
            instances.gt_boxes = masks.get_bounding_boxes()

        # Empty instances are intentionally not filtered out here.
        dataset_dict["instances"] = instances
        # dataset_dict["instances"] = filter_empty_instances_by_box(instances)

    # When instance annotations exist, _transform_annotations stores its
    # result under "instances", replacing any panoptic instances set above.
    if "annotations" in dataset_dict:
        self._transform_annotations(dataset_dict, transforms, image_shape)

    return dataset_dict
def convert_coco_poly_to_mask(segmentations, height, width):
    """
    Decode COCO polygon segmentations into one stacked mask tensor.

    Args:
        segmentations: iterable with one polygon collection per instance, in
            the form accepted by ``pycocotools.mask.frPyObjects``.
        height: image height used for decoding.
        width: image width used for decoding.

    Returns:
        torch.Tensor: the per-instance masks stacked along a new first
        dimension, or an empty ``(0, height, width)`` uint8 tensor when there
        are no segmentations.
    """
    per_instance = []
    for polygons in segmentations:
        decoded = coco_mask.decode(coco_mask.frPyObjects(polygons, height, width))
        if len(decoded.shape) < 3:
            # Make sure there is a trailing axis to reduce over.
            decoded = decoded[..., None]
        # Merge everything along the last axis into a single mask.
        merged = torch.as_tensor(decoded, dtype=torch.uint8).any(dim=2)
        per_instance.append(merged)

    if not per_instance:
        return torch.zeros((0, height, width), dtype=torch.uint8)
    return torch.stack(per_instance, dim=0)
Returns: list[Augmentation] """ if is_train: cfg_input = cfg['INPUT'] image_size = cfg_input['IMAGE_SIZE'] min_scale = cfg_input['MIN_SCALE'] max_scale = cfg_input['MAX_SCALE'] augmentation = [] # if cfg_input['RANDOM_FLIP'] != "none": # augmentation.append( # T.RandomFlip( # horizontal=cfg_input['RANDOM_FLIP'] == "horizontal", # vertical=cfg_input['RANDOM_FLIP'] == "vertical", # ) # ) augmentation.extend([ T.ResizeScale( min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size ), T.FixedSizeCrop(crop_size=(image_size, image_size)), ]) else: cfg_input = cfg['INPUT'] image_size = cfg_input['IMAGE_SIZE'] min_scale = cfg_input['MIN_SCALE'] max_scale = cfg_input['MAX_SCALE'] augmentation = [] # if cfg_input['RANDOM_FLIP'] != "none": # augmentation.append( # T.RandomFlip( # horizontal=cfg_input['RANDOM_FLIP'] == "horizontal", # vertical=cfg_input['RANDOM_FLIP'] == "vertical", # ) # ) augmentation.extend([ T.ResizeScale( min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size ), T.FixedSizeCrop(crop_size=(image_size, image_size)), ]) return augmentation # This is specifically designed for the COCO dataset. class COCOInstanceNewBaselineDatasetMapper: """ A callable which takes a dataset dict in Detectron2 Dataset format, and map it into a format used by MaskFormer. This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation. The callable currently does the following: 1. Read the image from "file_name" 2. Applies geometric transforms to the image and annotation 3. Find and applies suitable cropping to the image and annotation 4. Prepare image and annotation to Tensors """ @configurable def __init__( self, is_train=True, *, tfm_gens, image_format, ): """ NOTE: this interface is experimental. 
    def __call__(self, dataset_dict):
        """
        Load one image, apply the configured transforms, and attach tensors
        and (when grounding annotations are present) an ``Instances`` object.

        Args:
            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
                Reads "file_name"; when "grounding_info" is present, also
                reads "caption" and each entry's "tokens_positive".

        Returns:
            dict: a deep copy of the input with "image_ori", "image",
            "padding_mask" and, if "grounding_info" was present, "instances".

        Raises:
            AssertionError: if "grounding_info" is present but empty, or its
                first transformed annotation has no "segmentation".
        """
        # NOTE(review): the nesting below was reconstructed from a flattened
        # listing; confirm that everything from `annos = [...]` through
        # `dataset_dict["instances"] = instances` sits inside the
        # `if "grounding_info" in dataset_dict:` branch.
        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
        image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
        utils.check_image_size(dataset_dict, image)

        # TODO: get padding mask
        # by feeding a "segmentation mask" to the same transforms
        padding_mask = np.ones(image.shape[:2])

        image, transforms = T.apply_transform_gens(self.tfm_gens, image)
        # the crop transformation has default padding value 0 for segmentation
        padding_mask = transforms.apply_segmentation(padding_mask)
        # Invert so that True marks padded pixels (the all-ones mask becomes 0
        # wherever the crop padded).
        padding_mask = ~ padding_mask.astype(bool)
        # Keep the transformed image array (before the channel transpose below).
        dataset_dict["image_ori"]=image
        image_shape = image.shape[:2]  # h, w

        # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
        # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
        # Therefore it's important to use torch.Tensor.
        dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
        dataset_dict["padding_mask"] = torch.as_tensor(np.ascontiguousarray(padding_mask))

        # if not self.is_train:
        #     # USER: Modify this if you want to keep them for some reason.
        #     dataset_dict.pop("annotations", None)
        #     return dataset_dict

        if "grounding_info" in dataset_dict:
            for obj in dataset_dict['grounding_info']:
                obj["bbox_mode"] = BoxMode.XYWH_ABS
                # Slice the caption with the first (start, end) pair of
                # "tokens_positive" to recover the phrase for this object.
                obj['tokens']=dataset_dict['caption'][obj['tokens_positive'][0][0]:obj['tokens_positive'][0][1]]
            # USER: Implement additional transformations if you have other types of data
            annos = [
                utils.transform_instance_annotations(obj, transforms, image_shape)
                for obj in dataset_dict["grounding_info"]
            ]
            # NOTE: does not support BitMask due to augmentation
            # Current BitMask cannot handle empty objects
            assert len(annos)>0
            assert "segmentation" in annos[0]
            instances = utils.annotations_to_instances(annos, image_shape,mask_format="bitmask")
            # After transforms such as cropping are applied, the bounding box may no longer
            # tightly bound the object. As an example, imagine a triangle object
            # [(0,0), (2,0), (0,2)] cropped by a box [(1,0),(2,2)] (XYXY format). The tight
            # bounding box of the cropped triangle should be [(1,0),(2,1)], which is not equal to
            # the intersection of original bounding box and the cropping box.
            instances.gt_boxes = instances.gt_masks.get_bounding_boxes()
            # Need to filter empty instances first (due to augmentation)
            instances = utils.filter_empty_instances(instances)
            # Generate masks from polygon
            h, w = instances.image_size
            # image_size_xyxy = torch.as_tensor([w, h, w, h], dtype=torch.float)
            if hasattr(instances, 'gt_masks'):
                gt_masks = instances.gt_masks
                # gt_masks = convert_coco_poly_to_mask(gt_masks.polygons, h, w)
                # Replace the mask container with its underlying tensor.
                instances.gt_masks = gt_masks.tensor
            dataset_dict["instances"] = instances

        return dataset_dict
class RefCOCODatasetMapper:
    """
    Callable that turns one Detectron2-format dataset dict with
    ``grounding_info`` into model inputs plus a ``groundings`` entry.

    In inference mode the image is opened with PIL, resized with the
    torchvision ``Resize`` built in ``__init__`` and every sentence of every
    annotation is kept. In training mode the image goes through the
    configured transform generators, masks are transformed the same way, and
    one sentence per annotation is sampled at random.
    """

    @configurable
    def __init__(
        self,
        is_train=True,
        tfm_gens=None,
        image_format=None,
        min_size_test=None,
        max_size_test=None,
        mean=None,
        std=None,
    ):
        """
        NOTE: this interface is experimental.

        Args:
            is_train: for training or inference
            tfm_gens: data augmentation applied in training mode
            image_format: an image format supported by :func:`detection_utils.read_image`.
            min_size_test: size handed to the inference-time ``Resize``
            max_size_test: stored on the instance; not used by this class
            mean: per-channel pixel mean, stored as a (C, 1, 1) tensor
            std: per-channel pixel std, stored as a (C, 1, 1) tensor
        """
        self.tfm_gens = tfm_gens
        self.img_format = image_format
        self.is_train = is_train

        self.min_size_test = min_size_test
        self.max_size_test = max_size_test
        self.pixel_mean = torch.tensor(mean)[:, None, None]
        self.pixel_std = torch.tensor(std)[:, None, None]

        # Inference-time preprocessing: a single bicubic resize.
        resize = transforms.Resize(self.min_size_test, interpolation=Image.BICUBIC)
        self.transform = transforms.Compose([resize])

    @classmethod
    def from_config(cls, cfg, is_train=True):
        """Translate ``cfg`` into the keyword arguments of ``__init__``."""
        # Augmentations are only built for training.
        tfm_gens = build_transform_gen(cfg, is_train) if is_train else None
        input_cfg = cfg['INPUT']
        return {
            "is_train": is_train,
            "tfm_gens": tfm_gens,
            "image_format": input_cfg.get('FORMAT', 'RGB'),
            "min_size_test": input_cfg['MIN_SIZE_TEST'],
            "max_size_test": input_cfg['MAX_SIZE_TEST'],
            "mean": input_cfg['PIXEL_MEAN'],
            "std": input_cfg['PIXEL_STD'],
        }

    @staticmethod
    def _decode_mask(ann, height, width):
        """Decode one annotation's segmentation and sum its parts into a uint8 map."""
        rle = mask.frPyObjects(ann['segmentation'], height, width)
        decoded = mask.decode(rle)
        # sometimes there are multiple binary map (corresponding to multiple segs)
        return np.sum(decoded, axis=2).astype(np.uint8)

    def _map_eval(self, dataset_dict):
        """Inference path: PIL load + resize, keep all sentences and boxes."""
        pil_image = Image.open(dataset_dict['file_name']).convert('RGB')
        dataset_dict['width'] = pil_image.size[0]
        dataset_dict['height'] = pil_image.size[1]

        resized = self.transform(pil_image)
        image = torch.from_numpy(np.asarray(resized).copy())
        dataset_dict["image_ori"] = image
        dataset_dict['image'] = image.permute(2, 0, 1)

        annotations = dataset_dict['grounding_info']
        assert len(annotations) > 0

        masks, texts, boxes = [], [], []
        for ann in annotations:
            masks.append(
                self._decode_mask(ann, dataset_dict['height'], dataset_dict['width'])
            )
            texts.append([sent['raw'].lower() for sent in ann['sentences']])
            boxes.append(ann['bbox'])  # xywh

        dataset_dict["groundings"] = {
            'masks': torch.from_numpy(np.stack(masks)),
            'texts': texts,
            'boxes': torch.tensor(boxes),
        }
        return dataset_dict

    def _map_train(self, dataset_dict):
        """Training path: detectron2 transforms, one random sentence per annotation."""
        image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
        utils.check_image_size(dataset_dict, image)

        image, tfms = T.apply_transform_gens(self.tfm_gens, image)
        dataset_dict["image_ori"] = image
        dataset_dict["image"] = torch.as_tensor(
            np.ascontiguousarray(image.transpose(2, 0, 1))
        )

        annotations = dataset_dict['grounding_info']
        assert len(annotations) > 0

        masks, texts, hashes = [], [], []
        for ann in annotations:
            raw_mask = self._decode_mask(
                ann, dataset_dict['height'], dataset_dict['width']
            )
            # Apply the same geometric transforms to the mask as to the image.
            masks.append(tfms.apply_segmentation(raw_mask[:, :, None])[:, :, 0])

            sentences = ann['sentences']
            picked = sentences[random.randint(0, len(sentences) - 1)]['raw'].lower()
            texts.append(picked)
            hashes.append(hash(picked))

        dataset_dict["groundings"] = {
            'masks': torch.from_numpy(np.stack(masks)),
            'texts': texts,
            'hash': hashes,
            'mode': 'text',
        }
        return dataset_dict

    def __call__(self, dataset_dict):
        """
        Args:
            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.

        Returns:
            dict: a format that builtin models in detectron2 accept
        """
        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
        if self.is_train == False:
            return self._map_eval(dataset_dict)
        return self._map_train(dataset_dict)
def convert_coco_poly_to_mask(segmentations, height, width):
    """
    Rasterize a sequence of COCO polygon segmentations into one mask tensor.

    Each entry of ``segmentations`` is passed to ``coco_mask.frPyObjects`` /
    ``coco_mask.decode``; the decoded parts of one entry are merged with
    ``any`` over the last axis, and the per-entry masks are stacked along a
    new leading axis.

    Args:
        segmentations: iterable of polygon lists, one per object.
        height: image height used for rasterization.
        width: image width used for rasterization.

    Returns:
        A tensor with one mask per entry, or an empty
        ``(0, height, width)`` uint8 tensor when there are no entries.
    """
    def _rasterize(polygons):
        decoded = coco_mask.decode(coco_mask.frPyObjects(polygons, height, width))
        if decoded.ndim < 3:
            # A single decoded map has no part axis; add one so `any` can reduce it.
            decoded = decoded[..., None]
        return torch.as_tensor(decoded, dtype=torch.uint8).any(dim=2)

    per_object = [_rasterize(polygons) for polygons in segmentations]
    if not per_object:
        return torch.zeros((0, height, width), dtype=torch.uint8)
    return torch.stack(per_object, dim=0)
class COCOInstanceNewBaselineDatasetMapper:
    """
    A callable which takes a dataset dict in Detectron2 Dataset format and
    produces image tensors, a sampled set of grounding boxes, and a tokenized
    "detect the object according to the text ..." conversation.

    The callable currently does the following:

    1. Read the image from "file_name"
    2. Run the LLaVA image processor to produce "image_clip"
    3. Apply the configured transform generators to the image and annotations
    4. Randomly sample annotations and build one question/answer pair per
       sampled caption
    5. Tokenize the conversation with the injected ``preprocess`` callable
    """

    @configurable
    def __init__(
        self,
        is_train=True,
        *,
        tfm_gens,
        image_format,
        max_grounding_num,
        tokenizer,
        data_args,
        preprocess,
    ):
        """
        NOTE: this interface is experimental.

        Args:
            is_train: for training or inference
            tfm_gens: data augmentation
            image_format: an image format supported by :func:`detection_utils.read_image`.
            max_grounding_num: bound used when sampling how many annotations
                to keep in training (see ``__call__``)
            tokenizer: passed through to ``preprocess`` and stored in the output dict
            data_args: object providing ``image_processor`` and ``image_aspect_ratio``
            preprocess: callable invoked as
                ``preprocess(conversations, tokenizer, has_image=True)``
        """
        self.tfm_gens = tfm_gens
        logging.getLogger(__name__).info(
            "[COCOInstanceNewBaselineDatasetMapper] Full TransformGens used in training: {}".format(str(self.tfm_gens))
        )

        self.img_format = image_format
        self.is_train = is_train
        self.max_grounding_num = max_grounding_num
        self.tokenizer = tokenizer
        self.processor = data_args.image_processor
        self.data_args = data_args
        self.preprocess = preprocess

    @classmethod
    def from_config(cls, cfg, is_train=True,tokenizer=None,data_args=None,preprocess=None):
        """Translate ``cfg`` plus the injected helpers into ``__init__`` keyword arguments."""
        # Build augmentation
        tfm_gens = build_transform_gen(cfg, is_train)

        ret = {
            "is_train": is_train,
            "tfm_gens": tfm_gens,
            "image_format": cfg['INPUT']['FORMAT'],
            "max_grounding_num": cfg['MODEL']['DECODER']['GROUNDING']['MAX_LEN'],
            "tokenizer": tokenizer,
            "data_args": data_args,
            "preprocess": preprocess,
        }
        return ret

    def __call__(self, dataset_dict):
        """
        Args:
            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
                Must contain "file_name" and a non-empty "annotations" list
                whose entries each carry a "caption".

        Returns:
            dict: a deep copy of the input extended with "image_clip",
            "image_ori", "image", "padding_mask", "instances",
            "conversation", "input_ids", "labels" and "tokenizer".

        Raises:
            AssertionError: if "annotations" is missing or empty.
        """
        # NOTE(review): statement nesting was reconstructed from a flattened
        # listing; confirm against the checked-in file.
        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
        image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
        utils.check_image_size(dataset_dict, image)

        #########llava image processing
        if self.data_args.image_aspect_ratio == 'pad':
            # NOTE(review): `Image` does not appear among this module's visible
            # imports, and `image` here is whatever utils.read_image returned
            # rather than an explicitly constructed PIL image — confirm this
            # branch works before enabling image_aspect_ratio == 'pad'.
            def expand2square(pil_img, background_color):
                # Pad the shorter side symmetrically so the result is square.
                width, height = pil_img.size
                if width == height:
                    return pil_img
                elif width > height:
                    result = Image.new(pil_img.mode, (width, width), background_color)
                    result.paste(pil_img, (0, (width - height) // 2))
                    return result
                else:
                    result = Image.new(pil_img.mode, (height, height), background_color)
                    result.paste(pil_img, ((height - width) // 2, 0))
                    return result

            # Background colour is the processor's per-channel mean scaled to 0-255.
            image_clip = expand2square(image, tuple(int(x * 255) for x in self.processor.image_mean))
            image_clip = self.processor.preprocess(image_clip, return_tensors='pt')['pixel_values'][0]
        else:
            image_clip = self.processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
        dataset_dict["image_clip"] = image_clip
        ##################

        # TODO: get padding mask
        # by feeding a "segmentation mask" to the same transforms
        padding_mask = np.ones(image.shape[:2])

        image, transforms = T.apply_transform_gens(self.tfm_gens, image)
        # the crop transformation has default padding value 0 for segmentation
        padding_mask = transforms.apply_segmentation(padding_mask)
        # Invert so that True marks padded pixels.
        padding_mask = ~ padding_mask.astype(bool)
        dataset_dict["image_ori"]=image
        image_shape = image.shape[:2]  # h, w

        # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
        # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
        # Therefore it's important to use torch.Tensor.
        dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
        dataset_dict["padding_mask"] = torch.as_tensor(np.ascontiguousarray(padding_mask))

        # if not self.is_train:
        #     # USER: Modify this if you want to keep them for some reason.
        #     dataset_dict.pop("annotations", None)
        #     return dataset_dict

        assert "annotations" in dataset_dict
        for obj in dataset_dict['annotations']:
            obj["bbox_mode"] = BoxMode.XYWH_ABS

        # USER: Implement additional transformations if you have other types of data
        annos = [
            utils.transform_instance_annotations(obj, transforms, image_shape)
            for obj in dataset_dict["annotations"]
        ]
        # NOTE: does not support BitMask due to augmentation
        # Current BitMask cannot handle empty objects
        assert len(annos)>0
        # assert "segmentation" in annos[0]
        instances = utils.annotations_to_instances(annos, image_shape,mask_format="bitmask")
        # Captions are taken from the untransformed annotations, in the same order.
        instances.captions=[ann['caption'] for ann in dataset_dict["annotations"]]
        # After transforms such as cropping are applied, the bounding box may no longer
        # tightly bound the object. As an example, imagine a triangle object
        # [(0,0), (2,0), (0,2)] cropped by a box [(1,0),(2,2)] (XYXY format). The tight
        # bounding box of the cropped triangle should be [(1,0),(2,1)], which is not equal to
        # the intersection of original bounding box and the cropping box.
        # instances.gt_boxes = instances.gt_masks.get_bounding_boxes()
        # Need to filter empty instances first (due to augmentation)
        # instances = utils.filter_empty_instances(instances)
        # Generate masks from polygon
        h, w = instances.image_size
        # image_size_xyxy = torch.as_tensor([w, h, w, h], dtype=torch.float)
        if hasattr(instances, 'gt_masks'):
            gt_masks = instances.gt_masks
            # gt_masks = convert_coco_poly_to_mask(gt_masks.polygons, h, w)
            instances.gt_masks = gt_masks.tensor
        # dataset_dict["instances"] = instances

        # Sample which annotations become grounding targets: a random count in
        # training, exactly one otherwise, chosen from a shuffled index list.
        num_instances = len(instances)
        indices = list(range(num_instances))
        import random
        if self.is_train:
            grounding_len = random.randint(1, self.max_grounding_num - 1)
        else:
            grounding_len = 1
        random.shuffle(indices)
        indices = indices[:grounding_len]
        texts_grd = [instances.captions[i] for i in indices]
        # Each sampled caption gets its own single-element class list: [[0], [1], ...].
        gt_classes = list(range(len(texts_grd)))
        gt_classes = [[lb] for lb in gt_classes]
        label_set = texts_grd
        grounding_instances = Instances(image_size=(h, w))
        grounding_instances.gt_boxes = instances.gt_boxes[indices]
        grounding_instances.gt_classes = gt_classes
        dataset_dict["instances"]=grounding_instances

        # Build one human question / gpt answer pair per sampled caption; only
        # the first question carries the leading "\n " prefix.
        conversations=[]
        for i in range(len(label_set)):
            if i==0:
                question={'from': 'human', 'value': f"\n Please detect the object according to the text {label_set[i]} (referring)."}
            else:
                question={'from': 'human', 'value': f"Please detect the object according to the text {label_set[i]} (referring)."}
            answer={'from': 'gpt', 'value': ' .'}
            conversations.append(question)
            conversations.append(answer)
        dataset_dict['conversation'] = [conversations]
        data_dict_conversation = self.preprocess(
            dataset_dict['conversation'],
            self.tokenizer,
            has_image=True)
        # Keep only the first (and only) conversation's token ids and labels.
        data_dict_conversation = dict(input_ids=data_dict_conversation["input_ids"][0],
                                      labels=data_dict_conversation["labels"][0])
        dataset_dict.update(data_dict_conversation)
        dataset_dict['tokenizer']=self.tokenizer

        return dataset_dict
class REFER:
    """
    In-memory index over a referring-expression dataset (refclef, refcoco,
    refcoco+, refcocog).

    After construction the instance exposes lookup tables (``Refs``, ``Anns``,
    ``Imgs``, ``Cats``, ``Sents``, ``imgToRefs``, ``imgToAnns``, ``refToAnn``,
    ``annToRef``, ``catToRefs``, ``sentToRef``, ``sentToTokens``) and the
    ``get*`` / ``load*`` / ``show*`` query helpers below.
    """

    def __init__(self, data_root, dataset='refcoco', splitBy='unc'):
        """
        Load refs and instance annotations for ``dataset`` and build the index.

        Args:
            data_root: folder containing refclef, refcoco, refcoco+ and refcocog.
            dataset: dataset name, e.g. 'refcoco'.
            splitBy: split provider, e.g. 'unc' or 'google'; selects
                ``refs(<splitBy>).p``.

        Exits the process via ``sys.exit()`` when ``dataset`` is unknown.
        """
        print('loading dataset {} into memory...'.format(dataset))
        self.ROOT_DIR = osp.abspath(osp.dirname(__file__))
        self.DATA_DIR = osp.join(data_root, dataset)
        if dataset in ['refcoco', 'refcoco+', 'refcocog']:
            self.IMAGE_DIR = osp.join(data_root, 'images/mscoco/images/train2014')
        elif dataset == 'refclef':
            self.IMAGE_DIR = osp.join(data_root, 'images/saiapr_tc-12')
        else:
            print('No refer dataset is called [{}]'.format(dataset))
            sys.exit()

        # load refs from data/dataset/refs(dataset).json
        tic = time.time()
        ref_file = osp.join(self.DATA_DIR, 'refs('+splitBy+').p')
        self.data = {}
        self.data['dataset'] = dataset
        # SECURITY: pickle executes arbitrary code on load; only point
        # data_root at trusted dataset files.
        with open(ref_file, 'rb') as f:
            self.data['refs'] = pickle.load(f)

        # load annotations from data/dataset/instances.json
        instances_file = osp.join(self.DATA_DIR, 'instances.json')
        with open(instances_file, 'r') as f:
            instances = json.load(f)
        self.data['images'] = instances['images']
        self.data['annotations'] = instances['annotations']
        self.data['categories'] = instances['categories']

        # create index
        self.createIndex()
        print('DONE (t={:.2f}s)'.format(time.time()-tic))

    def createIndex(self):
        """
        Build the lookup tables from ``self.data``.

        Mappings created:
            1)  Refs:         {ref_id: ref}
            2)  Anns:         {ann_id: ann}
            3)  Imgs:         {image_id: image}
            4)  Cats:         {category_id: category_name}
            5)  Sents:        {sent_id: sent}
            6)  imgToRefs:    {image_id: refs}
            7)  imgToAnns:    {image_id: anns}
            8)  refToAnn:     {ref_id: ann}
            9)  annToRef:     {ann_id: ref}
            10) catToRefs:    {category_id: refs}
            11) sentToRef:    {sent_id: ref}
            12) sentToTokens: {sent_id: tokens}
        """
        print('creating index...')
        # fetch info from instances
        Anns, Imgs, Cats, imgToAnns = {}, {}, {}, {}
        for ann in self.data['annotations']:
            Anns[ann['id']] = ann
            # append in place instead of rebuilding the list on every insert
            imgToAnns.setdefault(ann['image_id'], []).append(ann)
        for img in self.data['images']:
            Imgs[img['id']] = img
        for cat in self.data['categories']:
            Cats[cat['id']] = cat['name']

        # fetch info from refs
        Refs, imgToRefs, refToAnn, annToRef, catToRefs = {}, {}, {}, {}, {}
        Sents, sentToRef, sentToTokens = {}, {}, {}
        for ref in self.data['refs']:
            # ids
            ref_id = ref['ref_id']
            ann_id = ref['ann_id']
            category_id = ref['category_id']
            image_id = ref['image_id']

            # add mapping related to ref
            Refs[ref_id] = ref
            imgToRefs.setdefault(image_id, []).append(ref)
            catToRefs.setdefault(category_id, []).append(ref)
            refToAnn[ref_id] = Anns[ann_id]
            annToRef[ann_id] = ref

            # add mapping of sent
            for sent in ref['sentences']:
                Sents[sent['sent_id']] = sent
                sentToRef[sent['sent_id']] = ref
                sentToTokens[sent['sent_id']] = sent['tokens']

        # create class members
        self.Refs = Refs
        self.Anns = Anns
        self.Imgs = Imgs
        self.Cats = Cats
        self.Sents = Sents
        self.imgToRefs = imgToRefs
        self.imgToAnns = imgToAnns
        self.refToAnn = refToAnn
        self.annToRef = annToRef
        self.catToRefs = catToRefs
        self.sentToRef = sentToRef
        self.sentToTokens = sentToTokens
        print('index created.')

    @staticmethod
    def _as_list(value):
        """Wrap a single id in a list; lists are returned unchanged."""
        return value if isinstance(value, list) else [value]

    def getRefIds(self, image_ids=[], cat_ids=[], ref_ids=[], split=''):
        """
        Return ref ids matching every given filter.

        Args:
            image_ids: image id or list of image ids; unknown ids contribute nothing.
            cat_ids: category id or list of category ids.
            ref_ids: ref id or list of ref ids.
            split: '' (no filter), 'train', 'val', 'test', 'testA'/'testB'/'testC',
                or 'testAB'/'testBC'/'testAC'.

        Exits the process via ``sys.exit()`` on an unknown split name.
        """
        image_ids = self._as_list(image_ids)
        cat_ids = self._as_list(cat_ids)
        ref_ids = self._as_list(ref_ids)

        if len(image_ids) == len(cat_ids) == len(ref_ids) == len(split) == 0:
            refs = self.data['refs']
        else:
            if image_ids:
                # imgToRefs maps each image to a LIST of refs, so flatten;
                # the per-image lists must not be treated as refs themselves.
                refs = list(itertools.chain.from_iterable(
                    self.imgToRefs.get(image_id, []) for image_id in image_ids))
            else:
                refs = self.data['refs']
            if cat_ids:
                refs = [ref for ref in refs if ref['category_id'] in cat_ids]
            if ref_ids:
                refs = [ref for ref in refs if ref['ref_id'] in ref_ids]
            if split:
                if split in ['testA', 'testB', 'testC']:
                    # we also consider testAB, testBC, ...
                    refs = [ref for ref in refs if split[-1] in ref['split']]
                elif split in ['testAB', 'testBC', 'testAC']:
                    # rarely used I guess...
                    refs = [ref for ref in refs if ref['split'] == split]
                elif split == 'test':
                    refs = [ref for ref in refs if 'test' in ref['split']]
                elif split == 'train' or split == 'val':
                    refs = [ref for ref in refs if ref['split'] == split]
                else:
                    print('No such split [{}]'.format(split))
                    sys.exit()
        return [ref['ref_id'] for ref in refs]

    def getAnnIds(self, image_ids=[], cat_ids=[], ref_ids=[]):
        """
        Return annotation ids matching every given filter.

        Args:
            image_ids: image id or list of image ids; unknown ids contribute nothing.
            cat_ids: category id or list of category ids.
            ref_ids: ref id or list of ref ids; keeps only annotations
                referenced by these refs.
        """
        image_ids = self._as_list(image_ids)
        cat_ids = self._as_list(cat_ids)
        ref_ids = self._as_list(ref_ids)

        if len(image_ids) == len(cat_ids) == len(ref_ids) == 0:
            return [ann['id'] for ann in self.data['annotations']]

        if image_ids:
            lists = [self.imgToAnns[image_id] for image_id in image_ids
                     if image_id in self.imgToAnns]  # list of [anns]
            anns = list(itertools.chain.from_iterable(lists))
        else:
            anns = self.data['annotations']
        if cat_ids:
            anns = [ann for ann in anns if ann['category_id'] in cat_ids]
        ann_ids = [ann['id'] for ann in anns]
        if ref_ids:
            # Apply the ref filter to the result (it used to be computed and
            # then discarded); order of ann_ids is preserved.
            wanted = {self.Refs[ref_id]['ann_id'] for ref_id in ref_ids}
            ann_ids = [ann_id for ann_id in ann_ids if ann_id in wanted]
        return ann_ids

    def getImgIds(self, ref_ids=[]):
        """Return a list of image ids for ``ref_ids``, or all image ids when empty."""
        ref_ids = self._as_list(ref_ids)
        if ref_ids:
            return list({self.Refs[ref_id]['image_id'] for ref_id in ref_ids})
        return list(self.Imgs.keys())

    def getCatIds(self):
        """Return a list of all category ids."""
        return list(self.Cats.keys())

    def loadRefs(self, ref_ids=[]):
        """Return the refs for a list of ref ids, or ``[ref]`` for a single int id."""
        if isinstance(ref_ids, list):
            return [self.Refs[ref_id] for ref_id in ref_ids]
        elif isinstance(ref_ids, int):
            return [self.Refs[ref_ids]]

    def loadAnns(self, ann_ids=[]):
        """Return the anns for a list of ann ids, or ``[ann]`` for a single int/str id."""
        if isinstance(ann_ids, list):
            return [self.Anns[ann_id] for ann_id in ann_ids]
        elif isinstance(ann_ids, (int, str)):  # `unicode` no longer exists in Python 3
            return [self.Anns[ann_ids]]

    def loadImgs(self, image_ids=[]):
        """Return the images for a list of image ids, or ``[image]`` for a single int id."""
        if isinstance(image_ids, list):
            return [self.Imgs[image_id] for image_id in image_ids]
        elif isinstance(image_ids, int):
            return [self.Imgs[image_ids]]

    def loadCats(self, cat_ids=[]):
        """Return category names for a list of category ids, or ``[name]`` for a single int id."""
        if isinstance(cat_ids, list):
            return [self.Cats[cat_id] for cat_id in cat_ids]
        elif isinstance(cat_ids, int):
            return [self.Cats[cat_ids]]

    def getRefBox(self, ref_id):
        """Return the bounding box ``[x, y, w, h]`` of the annotation behind ``ref_id``."""
        ann = self.refToAnn[ref_id]
        return ann['bbox']  # [x, y, w, h]

    def showRef(self, ref, seg_box='seg'):
        """
        Draw the image of ``ref`` on the current matplotlib axes, print its
        sentences, and overlay either the segmentation or the bounding box.

        Args:
            ref: a ref dict as returned by ``loadRefs``.
            seg_box: 'seg' to draw the segmentation, 'box' to draw the bbox.
        """
        ax = plt.gca()
        # show image
        image = self.Imgs[ref['image_id']]
        I = io.imread(osp.join(self.IMAGE_DIR, image['file_name']))
        ax.imshow(I)
        # show refer expression
        for sid, sent in enumerate(ref['sentences']):
            print('{}. {}'.format(sid+1, sent['sent']))
        # show segmentations
        if seg_box == 'seg':
            ann_id = ref['ann_id']
            ann = self.Anns[ann_id]
            polygons = []
            color = []
            c = 'none'
            if isinstance(ann['segmentation'][0], list):
                # polygon used for refcoco*
                for seg in ann['segmentation']:
                    # integer division: reshape rejects a float dimension
                    poly = np.array(seg).reshape((len(seg) // 2, 2))
                    polygons.append(Polygon(poly, closed=True, alpha=0.4))
                    color.append(c)
                p = PatchCollection(polygons, facecolors=color, edgecolors=(
                    1, 1, 0, 0), linewidths=3, alpha=1)
                ax.add_collection(p)  # thick yellow polygon
                p = PatchCollection(polygons, facecolors=color, edgecolors=(
                    1, 0, 0, 0), linewidths=1, alpha=1)
                ax.add_collection(p)  # thin red polygon
            else:
                # mask used for refclef
                rle = ann['segmentation']
                m = mask.decode(rle)
                img = np.ones((m.shape[0], m.shape[1], 3))
                color_mask = np.array([2.0, 166.0, 101.0])/255
                for i in range(3):
                    img[:, :, i] = color_mask[i]
                ax.imshow(np.dstack((img, m*0.5)))
        # show bounding-box
        elif seg_box == 'box':
            bbox = self.getRefBox(ref['ref_id'])
            box_plot = Rectangle(
                (bbox[0], bbox[1]), bbox[2], bbox[3], fill=False, edgecolor='green', linewidth=3)
            ax.add_patch(box_plot)

    def getMask(self, ref):
        """
        Return ``{'mask': uint8 array, 'area': number}`` for the object of ``ref``.

        Polygon segmentations are converted with ``mask.frPyObjects`` using the
        image's height and width; anything else is decoded as given.
        """
        ann = self.refToAnn[ref['ref_id']]
        image = self.Imgs[ref['image_id']]
        if isinstance(ann['segmentation'][0], list):  # polygon
            rle = mask.frPyObjects(
                ann['segmentation'], image['height'], image['width'])
        else:
            rle = ann['segmentation']
        m = mask.decode(rle)
        # sometimes there are multiple binary map (corresponding to multiple segs)
        # NOTE(review): this assumes decode returned a 3-D (h, w, parts) array — confirm
        # for the non-polygon (refclef) path.
        m = np.sum(m, axis=2)
        m = m.astype(np.uint8)  # convert to np.uint8
        # compute area
        area = sum(mask.area(rle))  # should be close to ann['area']
        return {'mask': m, 'area': area}

    def showMask(self, ref):
        """Show the mask of ``ref`` on the current matplotlib axes."""
        M = self.getMask(ref)
        msk = M['mask']
        ax = plt.gca()
        ax.imshow(msk)
= refer.getRefIds() print(len(ref_ids)) print(len(refer.Imgs)) print(len(refer.imgToRefs)) ref_ids = refer.getRefIds(split='train') print('There are {} training referred objects.' % len(ref_ids)) for ref_id in ref_ids: ref = refer.loadRefs(ref_id)[0] if len(ref['sentences']) < 2: continue pprint(ref) print('The label is {}.'.format(refer.Cats[ref['category_id']])) # plt.figure() # refer.showRef(ref, seg_box='box') # plt.show() # plt.figure() # refer.showMask(ref) # plt.show() ================================================ FILE: datasets_os/registration/__init__.py ================================================ # Copyright (c) Facebook, Inc. and its affiliates. from . import ( register_coco_panoptic_annos_grounding_interactive, register_coco_instruct_grounding_dataset, register_flickr_dataset, # register_vg_dataset, ) ================================================ FILE: datasets_os/registration/register_coco_instruct_grounding_dataset.py ================================================ # Copyright (c) Facebook, Inc. and its affiliates. 
def load_coco_json(image_root, annot_json, conversation, metadata):
    """Pair COCO images and grounding annotations with conversation records.

    Args:
        image_root (str): directory that each image's ``file_name`` is
            joined to.
        annot_json (str): COCO-style JSON file providing the ``images`` and
            ``annotations`` lists.
        conversation (str): JSON file holding a list of records; each record
            carries ``id`` (the image id), ``conversations`` and optionally
            ``gd_ls`` / ``q_gd_ls``.
        metadata (dict): not used by the loader; kept so the signature stays
            compatible with the registration helper.

    Returns:
        list[dict]: one dict per conversation record with the keys
        ``file_name``, ``image_id``, ``grounding_info`` and
        ``conversations``.

    Raises:
        KeyError: a conversation record names an image id that is missing
            from ``annot_json``.
        AssertionError: no record was produced, or the first image file does
            not exist.
    """
    with PathManager.open(annot_json) as f:
        json_info = json.load(f)

    # Ids are cast to int on the lookup side below, so cast the keys as well;
    # otherwise string ids in the annotation file could never be matched.
    imgid2image = {int(image["id"]): image for image in json_info["images"]}

    # build dictionary for grounding
    grd_dict = collections.defaultdict(list)
    for grd_ann in json_info["annotations"]:
        segm = grd_ann.get("segmentation", None)
        if segm:  # either list[list[float]] or dict(RLE)
            if isinstance(segm, dict) and isinstance(segm["counts"], list):
                # convert uncompressed RLE to compressed RLE
                segm = mask_util.frPyObjects(segm, *segm["size"])
            grd_ann["segmentation"] = segm
        grd_dict[int(grd_ann["image_id"])].append(grd_ann)

    # Same opener as the annotation file, so both paths are resolved alike.
    with PathManager.open(conversation) as f:
        data = json.load(f)

    conv_dict = collections.defaultdict(list)
    for d in data:
        # ``q_gd_ls`` wins over ``gd_ls``; None when neither key is present.
        # The record itself is left untouched.
        grounding = d["q_gd_ls"] if "q_gd_ls" in d else d.get("gd_ls")
        conv_dict[int(d["id"])].append((d["conversations"], grounding))

    ret = []
    # One output entry per conversation record: an image referenced by k
    # records appears k times, each time carrying all k conversations.
    for d in data:
        image_id = int(d["id"])
        if image_id not in imgid2image:
            raise KeyError(
                f"image id {image_id} from {conversation} is missing in {annot_json}"
            )
        image = imgid2image[image_id]
        ret.append(
            {
                "file_name": os.path.join(image_root, image["file_name"]),
                "image_id": image_id,
                "grounding_info": grd_dict[image_id],
                "conversations": conv_dict[image_id],
            }
        )
    assert len(ret), f"No images found in {image_root}!"
    assert PathManager.isfile(ret[0]["file_name"]), ret[0]["file_name"]
    return ret
def load_coco_panoptic_json(json_file, image_dir, gt_dir, semseg_dir, grounding_file, meta):
    """Assemble panoptic records enriched with grounding annotations.

    Args:
        json_file (str): panoptic JSON file; its ``annotations`` list drives
            the output.
        image_dir (str): directory of the images; the annotation file name
            with its extension replaced by ".jpg" is looked up here.
        gt_dir (str): directory of the panoptic label files.
        semseg_dir (str): directory of the semantic label files.
        grounding_file (str): JSON file whose ``annotations`` are grouped by
            ``image_id`` and attached to the matching record.
        meta (dict): must provide ``thing_dataset_id_to_contiguous_id`` and
            ``stuff_dataset_id_to_contiguous_id``.

    Returns:
        list[dict]: one dict per panoptic annotation with the keys
        ``file_name``, ``image_id``, ``grounding_info``,
        ``pan_seg_file_name``, ``sem_seg_file_name`` and ``segments_info``.
    """
    with PathManager.open(json_file) as f:
        json_info = json.load(f)

    with PathManager.open(grounding_file) as f:
        grounding_info = json.load(f)

    def remap_category(segment):
        # Rewrites the segment in place: dataset category id -> contiguous id,
        # and records whether the id came from the "thing" table.
        dataset_id = segment["category_id"]
        is_thing = dataset_id in meta["thing_dataset_id_to_contiguous_id"]
        if is_thing:
            table = meta["thing_dataset_id_to_contiguous_id"]
        else:
            table = meta["stuff_dataset_id_to_contiguous_id"]
        segment["category_id"] = table[dataset_id]
        segment["isthing"] = is_thing
        return segment

    # group grounding annotations by image id
    grounding_by_image = collections.defaultdict(list)
    for entry in grounding_info['annotations']:
        grounding_by_image[int(entry["image_id"])].append(entry)

    records = []
    for ann in json_info["annotations"]:
        image_id = int(ann["image_id"])
        # Images are assumed to share the label's base name with a ".jpg"
        # extension (COCO convention), as in the original loader.
        base_name = os.path.splitext(ann["file_name"])[0]
        image_file = os.path.join(image_dir, base_name + ".jpg")
        label_file = os.path.join(gt_dir, ann["file_name"])
        sem_label_file = os.path.join(semseg_dir, ann["file_name"])
        segments_info = [remap_category(seg) for seg in ann["segments_info"]]
        records.append(
            {
                "file_name": image_file,
                "image_id": image_id,
                # .get() does not insert into the defaultdict, matching the
                # original membership test.
                "grounding_info": grounding_by_image.get(image_id, []),
                "pan_seg_file_name": label_file,
                "sem_seg_file_name": sem_label_file,
                "segments_info": segments_info,
            }
        )
    assert len(records), f"No images found in {image_dir}!"
    first = records[0]
    assert PathManager.isfile(first["file_name"]), first["file_name"]
    assert PathManager.isfile(first["pan_seg_file_name"]), first["pan_seg_file_name"]
    assert PathManager.isfile(first["sem_seg_file_name"]), first["sem_seg_file_name"]
    return records
def load_flickr_json(image_root, annot_json, metadata):
    """Load Flickr grounding annotations into per-image records.

    Args:
        image_root (str): directory that each image's ``file_name`` is
            joined to.
        annot_json (str): JSON file with ``images`` (each having ``id``,
            ``caption`` and ``file_name``) and ``annotations`` (each having
            ``image_id``).
        metadata (dict): not used by the loader.

    Returns:
        list[dict]: one dict per image with the keys ``file_name``,
        ``image_id``, ``grounding_info`` and ``caption``.
    """
    with PathManager.open(annot_json) as f:
        json_info = json.load(f)

    # group grounding annotations by image id
    annos_by_image = collections.defaultdict(list)
    for anno in json_info['annotations']:
        annos_by_image[int(anno["image_id"])].append(anno)

    def to_record(image):
        # Field access order mirrors the original loop body.
        image_id = int(image["id"])
        caption = image['caption']
        image_file = os.path.join(image_root, image['file_name'])
        return {
            "file_name": image_file,
            "image_id": image_id,
            "grounding_info": annos_by_image[image_id],
            "caption": caption,
        }

    records = [to_record(image) for image in json_info["images"]]
    assert len(records), f"No images found in {image_root}!"
    assert PathManager.isfile(records[0]["file_name"]), records[0]["file_name"]
    return records
def load_semseg(filename, loader_type):
    """Load a semantic-segmentation label map.

    Args:
        filename: path of the label file.
        loader_type: ``'PIL'`` to read an image file, or ``'MAT'`` to read
            the ``LabelMap`` variable of a MATLAB file.

    Returns:
        numpy.ndarray: the label map (integer dtype for the PIL loader).

    Raises:
        ValueError: if ``loader_type`` is neither ``'PIL'`` nor ``'MAT'``.
    """
    if loader_type == 'PIL':
        # ``np.int`` was only an alias of the builtin ``int`` and has been
        # removed from NumPy; ``int`` yields the same dtype.
        with Image.open(filename) as img:
            return np.array(img, dtype=int)
    if loader_type == 'MAT':
        return scipy.io.loadmat(filename)['LabelMap']
    raise ValueError(
        f"Unknown loader_type {loader_type!r}; expected 'PIL' or 'MAT'."
    )
def get_image_name(dir_save="./gradio_demo/tmp_files", prefix="click_img_"):
    """Return the next unused ``<prefix><index>.jpg`` path inside ``dir_save``.

    The index is one more than the largest numeric index among the existing
    files named ``<prefix><digits>.jpg``, or 0 when there is none. Files that
    share the prefix but carry a non-numeric index are ignored instead of
    aborting with ``ValueError``.

    Args:
        dir_save: directory to scan; created if it does not exist, because
            callers write to the returned path right away.
        prefix: file-name prefix that identifies one image series.

    Returns:
        str: ``os.path.join(dir_save, prefix + str(index) + ".jpg")``.
    """
    suffix = ".jpg"
    os.makedirs(dir_save, exist_ok=True)

    used = []
    for name in os.listdir(dir_save):
        if not (name.startswith(prefix) and name.endswith(suffix)):
            continue
        stem = name[len(prefix):-len(suffix)]
        # isdecimal() accepts exactly the digit strings that int() can parse.
        if stem.isdecimal():
            used.append(int(stem))

    next_index = max(used) + 1 if used else 0
    return os.path.join(dir_save, f"{prefix}{next_index}{suffix}")
def filter_empty_box_mask(text, boxes_image, masks_image):
    """Drop all-zero boxes (and their masks) and tidy the matching text pieces.

    ``text`` is split into pieces and the pieces are paired, in order, with
    the entries of ``boxes_image`` and ``masks_image``. For each pair:

    * no all-zero row: box, mask and text piece are kept as they are;
    * some all-zero rows: only the non-zero rows of box and mask are kept;
    * only all-zero rows: box and mask are dropped and the text piece is
      cleaned with the regular expression below.

    The last text piece is always appended to the output text.

    NOTE(review): the separator and pattern literals look truncated in this
    copy of the file (marker tokens may have been stripped); they are
    reproduced exactly as found and should be confirmed against upstream.

    Returns:
        tuple: ``(text, boxes, masks)``; the inputs are returned unchanged
        when ``boxes_image`` is empty.
    """
    import re

    if len(boxes_image) == 0:
        return text, boxes_image, masks_image

    pieces = text.split(" ")
    kept_text, kept_boxes, kept_masks = [], [], []
    for boxes, masks, piece in zip(boxes_image, masks_image, pieces):
        piece += " "
        valid = torch.where(boxes.abs().sum(dim=1) > 0)
        n_valid = len(valid[0])
        if n_valid < boxes.shape[0]:  # at least one empty box
            if n_valid == 0:
                # every box is empty: keep only the cleaned-up text
                stripped = re.sub(r"| ", '', piece)
                kept_text.append(" " + " ".join(stripped.split()))
                continue
            boxes = boxes[valid]
            masks = masks[valid]
        kept_boxes.append(boxes)
        kept_masks.append(masks)
        kept_text.append(piece)
    kept_text.append(pieces[-1])
    return "".join(kept_text), kept_boxes, kept_masks
def generate_distinct_colors(count):
    """Return ``count`` fully saturated RGB colors with evenly spaced hues.

    The hues are shuffled with a fixed seed, so the same ``count`` always
    yields the same color order. A private ``random.Random(0)`` instance is
    used for the shuffle: it produces the same order as seeding the module
    generator with 0, but leaves the process-wide random state untouched.

    Args:
        count (int): number of colors; 0 gives an empty list.

    Returns:
        list[tuple[int, int, int]]: RGB triples with components in 0..255.
    """
    import colorsys
    import random

    hues = [i / count for i in range(count)]
    random.Random(0).shuffle(hues)
    return [
        tuple(int(channel * 255) for channel in colorsys.hsv_to_rgb(hue, 1, 1))
        for hue in hues
    ]
def find_start_idxes(sentence, word):
    """Return every index at which ``word`` starts inside ``sentence``.

    Overlapping occurrences are reported. Unlike the previous helper this
    also finds an occurrence that ends exactly at the end of ``sentence``,
    returns ``[0]`` when ``sentence`` equals ``word``, and returns ``[]``
    (instead of failing an assertion) when ``sentence`` is shorter than
    ``word``.
    """
    window_size = len(word)
    return [
        start
        for start in range(len(sentence) - window_size + 1)
        if sentence.startswith(word, start)
    ]


def post_process_text_response(text):
    """Rewrite a model response for display, one colored span per object.

    When the segmentation marker occurs in ``text``, the text is split into
    pieces; in each piece containing the phrase marker, the words before it
    are whitespace-normalised and the remainder is passed through
    ``add_color_to_text`` with a running object index. Otherwise ``text`` is
    returned unchanged.

    Reads the module-level ``colors`` list (one entry per object index).

    NOTE(review): the marker, separator, pattern and markup literals look
    truncated in this copy of the file (tag-like tokens may have been
    stripped); they are kept exactly as found and should be confirmed
    against upstream.
    """
    import re

    def add_color_to_text(obj_id, text):
        color = colors[obj_id]
        text = f"{text}"
        return text

    def format_sentence(splitted_sentence):
        return " ".join(splitted_sentence)

    def extract_text(sentence):
        pattern = r"|"
        return re.sub(pattern, '', sentence)

    if len(find_start_idxes(text, "")) == 0:
        return text

    text_pure = ""
    count_obj = 0
    for subtext in text.split(" "):
        if "" in subtext:
            start_idx = find_start_idxes(subtext, "")[0]
            leading_words = format_sentence(subtext[:start_idx].split())
            text_pure = format_sentence([text_pure, leading_words])
            text_pure += add_color_to_text(count_obj, extract_text(subtext[start_idx:]))
            count_obj += 1
        else:
            text_pure = format_sentence([text_pure, format_sentence(subtext.split())])
    return text_pure
enumerate(gd_results_per_image): bboxes = gd_result.cpu().tolist() for bbox in bboxes: bbox = [int(bbox[0]*width), int(bbox[1]*height), int(bbox[2]*width), int(bbox[3]*height)] cv2.rectangle(image, (bbox[0], bbox[1]), (bbox[2], bbox[3]), colors[gd_id][::-1], 2) path_save = get_image_name(prefix="grounding_img_") cv2.imwrite(path_save, image) return (path_save, ) def post_process_masks(path_ori_image, mask_inter, path_gd_image, masks_gd, loc_inter, inter_type): def unresize_mask(mask, width, height): import torch.nn.functional as F if width >= height: # then the height dimension is padded, the y coordinates should be divided by ratio mask = F.interpolate((mask[None, ...]).float(), size=[width, width], mode="nearest")[0] mask = mask[:, :height] elif width < height: # then the height dimension is padded, the y coordinates should be divided by ratio mask = F.interpolate((mask[None, ...]).float(), size=[height, height], mode="nearest")[0] mask = mask[:, :, :width] return mask def unnormalize_inter(mask, loc): height, width, _ = mask.shape loc_x_mean, loc_y_mean, loc_w, loc_h = loc if height >= width: loc_x_mean = loc_x_mean / (width/height) loc_w = loc_w / (width/height) else: loc_y_mean = loc_y_mean / (height/width) loc_h = loc_h / (height/width) return [loc_x_mean-loc_w/2, loc_y_mean-loc_h/2, loc_x_mean+loc_w/2, loc_y_mean+loc_h/2] image = cv2.imread(path_ori_image) gd_image = cv2.imread(path_gd_image) returns = [] if not (mask_inter is None): mask_ = (mask_inter[0][..., None] > 0).float().cpu().numpy() mask_ = cv2.resize(mask_, (max(image.shape[0], image.shape[1]), max(image.shape[0], image.shape[1]))) mask_ = mask_[:image.shape[0], :image.shape[1]] mask_ = mask_[..., None] * np.array([155, 155, 155])[None, None, :] image = (image * 0.5 + mask_ * 0.5).astype(np.uint8) if inter_type.lower() == "box": loc_inter_unnormalized = [unnormalize_inter(mask_, loc_inter[0])] thickness = max(int(max(image.shape) / 1000 * 5), 1) print(thickness) cv2.rectangle(image, 
def mask2point(mask, inter_type):
    """Convert a sketch mask into a normalised ``[cx, cy, w, h]`` prompt.

    Pixels whose first channel equals 255 count as drawn. Coordinates are
    first normalised by the image width/height; the coordinate along the
    shorter image side is then scaled by ``short / long``.

    Args:
        mask: array of shape (H, W, C); only channel 0 is read.
        inter_type: "click" (centroid of the drawn pixels with a fixed
            0.006 x 0.006 extent) or "box" (bounding box of the drawn
            pixels). Case-insensitive.

    Returns:
        torch.Tensor: shape (1, 4). All values are NaN when nothing is
        drawn, which the caller's NaN check treats as "no interaction"
        (box mode previously raised on an empty mask).

    Raises:
        ValueError: if ``inter_type`` is neither "click" nor "box".
    """
    mode = inter_type.lower()
    if mode not in ("click", "box"):
        raise ValueError(
            f"Unknown interaction type {inter_type!r}; expected 'Click' or 'Box'."
        )

    height, width = mask.shape[:2]
    ys, xs = np.where(mask[..., 0] == 255)
    if xs.size == 0:
        # Nothing drawn: min()/max() of an empty array would raise.
        return torch.full((1, 4), float("nan"))

    if mode == "click":
        loc_x = xs.mean() / width
        loc_y = ys.mean() / height
        if height >= width:
            loc_x = loc_x * (width / height)
        else:
            loc_y = loc_y * (height / width)
        return torch.tensor([[loc_x, loc_y, 0.006, 0.006]])

    loc_x_min = xs.min() / width
    loc_x_max = xs.max() / width
    loc_y_min = ys.min() / height
    loc_y_max = ys.max() / height
    if height >= width:
        loc_x_min = loc_x_min * (width / height)
        loc_x_max = loc_x_max * (width / height)
    else:
        loc_y_min = loc_y_min * (height / width)
        loc_y_max = loc_y_max * (height / width)
    # Separate names: the image size is no longer shadowed by the box size.
    box_w = loc_x_max - loc_x_min
    box_h = loc_y_max - loc_y_min
    return torch.tensor(
        [[(loc_x_min + loc_x_max) / 2, (loc_y_min + loc_y_max) / 2, box_w, box_h]]
    )
def clear_response(history):
    """Remove the latest exchange so that it can be regenerated.

    Walks backwards from the end of ``history`` (never looking at its first
    entry) until an entry whose first element is not None is found; that
    element is the question. The entry and everything after it are removed.
    If no such entry is found, the walk stops at the second entry, as before.

    Args:
        history: list of ``[user, bot]`` pairs, or None.

    Returns:
        tuple: ``(trimmed_history, question)``. With fewer than two entries
        there is nothing to regenerate and ``(history, None)`` is returned
        (an empty list stands in for a None history); previously this case
        failed on an unbound loop variable.
    """
    if history is None:
        return [], None
    if len(history) < 2:
        return history, None

    for offset in range(1, len(history)):
        # stop at the most recent entry that carries a user message
        if history[-offset][0] is not None:
            break
    question = history[-offset][0]
    return history[:-offset], question
Append '(with grounding)' if you want to do grounding.", container=False, ) with gr.Blocks() as demo: # Informations title_markdown = (""" # LLaVA-Grounding: Grounded Visual Chat with Large Multimodal Models [[Project Page](https://llava-vl.github.io/llava-grounding)] [[Arxiv](https://arxiv.org/abs/2312.02949)] [[Demo](http://llava-grounding.xyzou.net:6084)] [[Model](https://huggingface.co/Haozhangcx/llava_grounding_gd_vp)] """) tips_markdown = (""" **Tips for better results** 1. Adjust 'Threshold' according to the results or change your expression and click 'Regenerate' may help get better results. 2. Set temporature to 0.0 get reproducible results. """) tos_markdown = (""" **Terms of use:** By using this service, users are required to agree to the following terms: The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research. Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator. For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality. """) learn_more_markdown = (""" **License:** The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation. 
""") models = [ "LLaVA-Grounding-7B", ] interactions = [ "Click", "Box" ] cur_dir = os.path.dirname(os.path.abspath(__file__)) gr.Markdown(title_markdown) with gr.Row(): with gr.Column(min_width=300, scale=0.4): model_selector = gr.Dropdown( choices=models, value=models[0] if len(models) > 0 else "", interactive=True, show_label=False, container=False) img = gr.Image( type="numpy", # label="Image", height=220, tool="sketch", interactive=True ) img_upload_btn = gr.Button("Submit Image") with gr.Row(): inter_upload_btn = gr.Button("Submit Interaction") interaction_selector = gr.Dropdown( choices=interactions, value=interactions[0] if len(interactions) > 0 else "", interactive=True, show_label=False, container=False, ) with gr.Row(): temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, interactive=True, label="Temperature") threshold = gr.Slider(minimum=0.0, maximum=1.0, value=0.05, step=0.1, interactive=True, label="Threshold") # possibly a bug in gradio: https://github.com/gradio-app/gradio/issues/3623 gr.Examples(examples=[ [f"{cur_dir}/examples/meeting.jpg", "Describe the scene in detail. (with grounding)"], [f"{cur_dir}/examples/pizza.jpg", "Describe the scene in detail. (with grounding)"], ], inputs=[img, txt], label="Grounded Description Examples: ") gr.Examples(examples=[ [f"{cur_dir}/examples/cow_motor.jpg", "Where is the object % and what is it doing?"], [f"{cur_dir}/examples/dog_sleep.jpg", "What is the object % doing and why?"], ], inputs=[img, txt], label="Visual Prompt Examples (Please draw clicks or boxes on the woman and dog for the two examples, respectively.): ") with gr.Column(): chatbot = gr.Chatbot( [], elem_id="chatbot", bubble_full_width=False, height=598 # avatar_images=(None, (os.path.join(os.path.dirname(__file__), "avatar.png"))), ) with gr.Row(): with gr.Column(scale=8): txt.render() with gr.Column(scale=1, min_width=60): submit_btn = gr.Button(value="Send") #TODO: Enable these buttons. 
with gr.Row(): upvote_btn = gr.Button(value="👍 Upvote", interactive=True) downvote_btn = gr.Button(value="👎 Downvote", interactive=True) flag_btn = gr.Button(value="⚠️ Flag", interactive=True) #stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=True) regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=True) clear_btn = gr.Button(value="🗑️ Clear history", interactive=True) gr.Markdown(tips_markdown) gr.Markdown(tos_markdown) gr.Markdown(learn_more_markdown) if os.path.isfile("gradio_demo/examples/demo_grounding.mp4"): # only online demo gr.Markdown("-----------------------------------") gr.Markdown("## User's Guidance") with gr.Row(): with gr.Column(): gr.Markdown("### Grounded Visual Chat") gr.Video(value="gradio_demo/examples/demo_grounding.mp4") with gr.Column(): gr.Markdown("### Visual Prompt (Click)") gr.Video(value="gradio_demo/examples/demo_inter_click.mp4") with gr.Column(): gr.Markdown("### Visual Prompt (Box)") gr.Video(value="gradio_demo/examples/demo_inter_box.mp4") txt.submit(add_text, [chatbot, txt, img, threshold, temperature, interaction_selector], [chatbot, txt], queue=False).then( bot, chatbot, chatbot, api_name="bot_text_response" ) submit_btn.click(fn=add_text, inputs=[chatbot, txt, img, threshold, temperature, interaction_selector], outputs=[chatbot, txt]).then( bot, chatbot, chatbot, api_name="submit_text" ) img_upload_btn.click(fn=add_image, inputs=[chatbot, img], outputs=[chatbot], api_name="upload_image") inter_upload_btn.click(fn=add_interaction_click, inputs=[chatbot, img, interaction_selector], outputs=[chatbot], api_name="upload_inter") # buttons clear_btn.click(fn=clear_history, inputs=[chatbot, txt, img], outputs=[chatbot, txt, img], api_name="clear_all") regenerate_btn.click(fn=clear_response, inputs=[chatbot], outputs=[chatbot, txt], api_name="clear_last_response").then( add_text, [chatbot, txt, img, threshold, temperature, interaction_selector], [chatbot, txt], queue=False).then( bot, chatbot, chatbot, 
if __name__ == "__main__":
    # Entry point of the gradio demo: parse CLI options, build the inference
    # backend and start the web server.
    import argparse

    argparser = argparse.ArgumentParser()
    # NOTE: "0.0.0.0" exposes the demo on every network interface.
    argparser.add_argument("--server_name", default="0.0.0.0", type=str)
    # Parse the port as an int so an invalid value is rejected by argparse
    # instead of failing later inside demo.launch().
    argparser.add_argument("--port", default=12124, type=int)
    argparser.add_argument("--model_path", default="", type=str)
    # The defaults point at the config files that ship in this repository's
    # configs/ directory; the previous defaults referenced files that are not
    # part of the repository tree.
    argparser.add_argument("--path_vision_cfg", default="configs/openseed/openseed_swint_lang_joint_2st_visual_prompt.yaml", type=str)
    argparser.add_argument("--path_inter_cfg", default="configs/semsam/visual_prompt_encoder.yaml", type=str)
    args = argparser.parse_args()

    # Module-level globals assigned here; presumably read by the UI callbacks
    # defined earlier in this file -- verify before renaming.
    model_path = args.model_path
    colors = generate_distinct_colors(20)
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs("./gradio_demo/tmp_files", exist_ok=True)
    our_chatbot = InferenceDemo(args.model_path, args.path_vision_cfg, args.path_inter_cfg)
    demo.launch(server_name=args.server_name, server_port=args.port)
class SeparatorStyle(Enum):
    """Separator conventions used when a conversation is rendered as a prompt."""

    SINGLE = auto()
    TWO = auto()
    MPT = auto()
    PLAIN = auto()
    LLAMA_2 = auto()


@dataclasses.dataclass
class Conversation:
    """Keeps a conversation history and renders it as a model prompt.

    A message is a ``[role, content]`` pair.  ``content`` is a string, ``None``
    (an empty slot the model is expected to fill) or a
    ``(text, image, image_process_mode)`` tuple for a user turn with an image.
    """

    system: str
    roles: List[str]
    messages: List[List[str]]
    # Number of leading (few-shot) messages skipped by get_images() and
    # to_gradio_chatbot().
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    # Second separator; only read by the TWO, PLAIN and LLAMA_2 styles.
    sep2: str = None
    version: str = "Unknown"

    skip_next: bool = False

    @staticmethod
    def _text_of(message):
        """Return the text part of a message that may carry an image tuple."""
        if isinstance(message, tuple):
            text, _, _ = message
            return text
        return message

    @staticmethod
    def _expand_to_square(pil_img, background_color=(122, 116, 104)):
        """Pad ``pil_img`` with ``background_color`` to a centred square."""
        from PIL import Image

        width, height = pil_img.size
        if width == height:
            return pil_img
        side = max(width, height)
        result = Image.new(pil_img.mode, (side, side), background_color)
        if width > height:
            result.paste(pil_img, (0, (width - height) // 2))
        else:
            result.paste(pil_img, ((height - width) // 2, 0))
        return result

    @staticmethod
    def _shrink_for_display(image):
        """Resize ``image`` keeping its aspect ratio.

        The short edge is capped at 400 px and the long edge at 800 px; the
        short edge is never enlarged beyond its current size.
        """
        max_hw, min_hw = max(image.size), min(image.size)
        aspect_ratio = max_hw / min_hw
        max_len, min_len = 800, 400
        shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
        longest_edge = int(shortest_edge * aspect_ratio)
        W, H = image.size
        if H > W:
            H, W = longest_edge, shortest_edge
        else:
            H, W = shortest_edge, longest_edge
        return image.resize((W, H))

    def get_prompt(self):
        """Render the conversation as one prompt string.

        Raises:
            ValueError: if ``sep_style`` is not a known ``SeparatorStyle``.
        """
        messages = self.messages
        if len(messages) > 0 and isinstance(messages[0][1], tuple):
            # Work on a shallow copy so the stored history keeps the image tuple.
            messages = self.messages.copy()
            init_role, init_msg = messages[0]
            # NOTE(review): replacing "" with "" is a no-op; the literal looks
            # like it lost an image-placeholder tag -- confirm against
            # llava/constants.py before relying on it.
            init_msg = init_msg[0].replace("", "").strip()
            if 'mmtag' in self.version:
                messages[0] = (init_role, init_msg)
                messages.insert(0, (self.roles[0], ""))
                messages.insert(1, (self.roles[1], "Received."))
            else:
                messages[0] = (init_role, "\n" + init_msg)

        if self.sep_style == SeparatorStyle.SINGLE:
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    ret += role + ": " + self._text_of(message) + self.sep
                else:
                    ret += role + ":"
        elif self.sep_style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(messages):
                if message:
                    ret += role + ": " + self._text_of(message) + seps[i % 2]
                else:
                    ret += role + ":"
        elif self.sep_style == SeparatorStyle.MPT:
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    ret += role + self._text_of(message) + self.sep
                else:
                    ret += role
        elif self.sep_style == SeparatorStyle.LLAMA_2:
            # NOTE(review): the "<>" markers look truncated -- confirm the
            # intended system-prompt delimiters.
            wrap_sys = lambda msg: f"<>\n{msg}\n<>\n\n"
            wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
            ret = ""
            for i, (role, message) in enumerate(messages):
                if i == 0:
                    assert message, "first message should not be none"
                    assert role == self.roles[0], "first message should come from user"
                if message:
                    message = self._text_of(message)
                    if i == 0:
                        message = wrap_sys(self.system) + message
                    if i % 2 == 0:
                        message = wrap_inst(message)
                        ret += self.sep + message
                    else:
                        ret += " " + message + " " + self.sep2
            # Drop exactly one leading separator.  str.lstrip() would strip any
            # run of the separator's *characters* and could eat prompt text.
            if self.sep and ret.startswith(self.sep):
                ret = ret[len(self.sep):]
        elif self.sep_style == SeparatorStyle.PLAIN:
            seps = [self.sep, self.sep2]
            ret = self.system
            for i, (role, message) in enumerate(messages):
                if message:
                    ret += self._text_of(message) + seps[i % 2]
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

        return ret

    def append_message(self, role, message):
        """Append a ``[role, message]`` pair to the history."""
        self.messages.append([role, message])

    def _has_images(self):
        """Tell whether any user turn after ``offset`` carries an image tuple."""
        return any(
            i % 2 == 0 and isinstance(msg, tuple)
            for i, (_, msg) in enumerate(self.messages[self.offset:])
        )

    def get_images(self, return_pil=False):
        """Collect the images attached to user turns after ``offset``.

        Each image is pre-processed according to its ``image_process_mode``
        ("Pad", "Crop" or "Resize") and then shrunk for display.

        Returns:
            A list of PIL images if ``return_pil`` is true, otherwise a list of
            base64-encoded PNG strings.

        Raises:
            ValueError: on an unknown ``image_process_mode``.
        """
        import base64
        from io import BytesIO

        images = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 != 0 or not isinstance(msg, tuple):
                continue
            msg, image, image_process_mode = msg
            if image_process_mode == "Pad":
                image = self._expand_to_square(image)
            elif image_process_mode == "Crop":
                pass
            elif image_process_mode == "Resize":
                image = image.resize((336, 336))
            else:
                raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
            image = self._shrink_for_display(image)
            if return_pil:
                images.append(image)
            else:
                buffered = BytesIO()
                image.save(buffered, format="PNG")
                images.append(base64.b64encode(buffered.getvalue()).decode())
        return images

    def to_gradio_chatbot(self):
        """Convert the history after ``offset`` to gradio ``[user, bot]`` pairs."""
        import base64
        from io import BytesIO

        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                if isinstance(msg, tuple):
                    msg, image, image_process_mode = msg
                    image = self._shrink_for_display(image)
                    buffered = BytesIO()
                    image.save(buffered, format="JPEG")
                    img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                    # Embed the encoded image inline.  Previously the base64
                    # payload was computed and then discarded, so the chat only
                    # showed the bare text "user upload image".
                    img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
                    ret.append([img_str, None])
                    # NOTE(review): no-op replace; see the note in get_prompt().
                    msg = msg.replace('', '').strip()
                    if len(msg) > 0:
                        ret.append([msg, None])
                else:
                    ret.append([msg, None])
            else:
                # An assistant turn fills the reply slot of the previous entry.
                ret[-1][-1] = msg
        return ret

    def copy(self):
        """Return a copy whose message list can be mutated independently.

        ``skip_next`` is not carried over; the copy gets the default value.
        """
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            version=self.version)

    def dict(self):
        """Return a plain-dict snapshot of the conversation.

        If any user turn carries an image, image tuples are reduced to their
        text.  Image presence is detected by inspecting the messages instead
        of decoding, resizing and re-encoding every image via get_images().
        """
        if self._has_images():
            messages = [[x, y[0] if isinstance(y, tuple) else y] for x, y in self.messages]
        else:
            messages = self.messages
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
        }
# Built-in conversation templates.  Each template is a Conversation whose
# system prompt, roles and separators define one prompt format; callers are
# expected to .copy() a template before appending messages to it.
# NOTE(review): several separators/markers below are empty strings (e.g.
# sep2="") or look truncated ("visual content") -- they may have lost markup;
# confirm the intended values before changing any literal, since these strings
# are part of the prompt the model sees.

# Template with a two-message few-shot example (offset=2 skips it in the UI).
conv_vicuna_v0 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("Human", "Assistant"),
    messages=(
        ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
        ("Assistant",
            "Renewable energy sources are those that can be replenished naturally in a relatively "
            "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
            "Non-renewable energy sources, on the other hand, are finite and will eventually be "
            "depleted, such as coal, oil, and natural gas. Here are some key differences between "
            "renewable and non-renewable energy sources:\n"
            "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
            "energy sources are finite and will eventually run out.\n"
            "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
            "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
            "and other negative effects.\n"
            "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
            "have lower operational costs than non-renewable sources.\n"
            "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
            "locations than non-renewable sources.\n"
            "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
            "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
            "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
            "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
    ),
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

# Two-separator template registered under "v1" and "vicuna_v1".
conv_vicuna_v1 = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions.",
    roles=("USER", "ASSISTANT"),
    version="v1",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="",
)

# LLAMA_2-style template with a safety-oriented system prompt.
conv_llama_2 = Conversation(
    system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
    roles=("USER", "ASSISTANT"),
    version="llama_v2",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.LLAMA_2,
    sep="",
    sep2="",
)

# LLAMA_2-style template with a vision-assistant system prompt.
conv_llava_llama_2 = Conversation(
    system="You are a helpful language and vision assistant. "
           "You are able to understand the visual content that the user provides, "
           "and assist the user with a variety of tasks using natural language.",
    roles=("USER", "ASSISTANT"),
    version="llama_v2",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.LLAMA_2,
    sep="",
    sep2="",
)

# MPT-style template: the role strings already contain the turn markers.
conv_mpt = Conversation(
    system="""<|im_start|>system A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
    roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
    version="mpt",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.MPT,
    sep="<|im_end|>",
)

# Plain template: no system prompt and empty role names.
conv_llava_plain = Conversation(
    system="",
    roles=("", ""),
    messages=(
    ),
    offset=0,
    sep_style=SeparatorStyle.PLAIN,
    sep="\n",
)

# Single-separator template with a short greeting as few-shot example.
conv_llava_v0 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("Human", "Assistant"),
    messages=(
        ("Human", "Hi!"),
        ("Assistant", "Hi there! How can I help you today?")
    ),
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

# "mmtag" variant: get_prompt() checks for 'mmtag' in the version string.
# NOTE(review): the adjacent literals join without a space ("language.The").
conv_llava_v0_mmtag = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
           "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
           "The visual content will be provided with the following format: visual content.",
    roles=("Human", "Assistant"),
    messages=(
    ),
    offset=0,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
    version="v0_mmtag",
)

# Two-separator template registered under "llava_v1".
conv_llava_v1 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("USER", "ASSISTANT"),
    version="v1",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="",
)

# "mmtag" variant of the two-separator template.
conv_llava_v1_mmtag = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
           "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
           "The visual content will be provided with the following format: visual content.",
    roles=("USER", "ASSISTANT"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="",
    version="v1_mmtag",
)

# Template used when no other one is selected; this module-level name is
# reassigned elsewhere (e.g. by the evaluator) to switch the active template.
default_conversation = conv_vicuna_v0
# Registry mapping a template name to its Conversation object.
conv_templates = {
    "default": conv_vicuna_v0,
    "v0": conv_vicuna_v0,
    "v1": conv_vicuna_v1,
    "vicuna_v1": conv_vicuna_v1,
    "llama_2": conv_llama_2,

    "plain": conv_llava_plain,
    "v0_plain": conv_llava_plain,
    "llava_v0": conv_llava_v0,
    "v0_mmtag": conv_llava_v0_mmtag,
    "llava_v1": conv_llava_v1,
    "v1_mmtag": conv_llava_v1_mmtag,
    "llava_llama_2": conv_llava_llama_2,

    "mpt": conv_mpt,
}


if __name__ == "__main__":
    # Print the prompt rendered from the default template.
    print(default_conversation.get_prompt())
def load_jsonl_file(path_jsonl):
    """Read a JSON Lines file and return its records as a list.

    Uses the standard ``json`` module (already imported by this file) instead
    of importing the third-party ``jsonlines`` package at call time.  Blank
    lines are skipped.
    """
    with open(path_jsonl, "r", encoding="utf-8") as reader:
        return [json.loads(line) for line in reader if line.strip()]


def save_jsonl_file(data, path_save):
    """Write every item of ``data`` as one JSON document per line.

    Non-ASCII characters are written as-is (``ensure_ascii=False``).
    """
    with open(path_save, "w", encoding="utf-8") as writer:
        writer.writelines(json.dumps(item, ensure_ascii=False) + "\n" for item in data)


def load_benchmark(image_root, path_benchmark):
    """Build evaluation samples from a JSONL benchmark file.

    Every record must provide ``image`` (a file name whose stem is an integer
    id), ``text`` (the question) and ``question_id``.  The question is wrapped
    into a two-turn conversation whose assistant reply is a placeholder, and
    " (with grounding)" is appended to request grounded output.

    Returns:
        A list of dicts with ``file_name``, ``image_id``, ``conversations``
        and ``question_id``.

    Raises:
        AssertionError: if the benchmark is empty or the first image file
            does not exist.
    """
    ret = []
    for d in load_jsonl_file(path_benchmark):
        image_name = d["image"]
        image_id = int(image_name.split(".")[0])
        # Image files on disk are expected to carry the "COCO_val2014_" prefix.
        image_file = os.path.join(image_root, "COCO_val2014_" + image_name)
        conv = [
            {
                "from": "human",
                "value": DEFAULT_IMAGE_TOKEN + " " + d["text"] + " (with grounding)"
            },
            {
                "from": "gpt",
                "value": "Placeholder."
            }
        ]
        ret.append(
            {
                "file_name": image_file,
                "image_id": image_id,
                "conversations": [[conv, None]],
                "question_id": d["question_id"]
            }
        )
    assert len(ret), f"No images found in {image_root}!"
    # Only the first entry is checked, as a cheap sanity test of image_root.
    assert PathManager.isfile(ret[0]["file_name"]), ret[0]["file_name"]
    return ret
def preprocess_v1(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False
) -> Dict:
    """Turn conversations into prompt token ids plus masked label ids.

    Each source is a list of ``{"from": "human"|"gpt", "value": str}`` turns.
    The turns are rendered with the current default conversation template,
    and the prompt is then cut at the first "ASSISTANT: " so that it ends
    with a bare "ASSISTANT:" tag.

    Returns:
        dict with ``input_ids`` and ``labels`` (a clone of ``input_ids`` with
        masked positions set to ``IGNORE_INDEX``).
    """
    conv = conversation_lib.default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    # Apply prompt templates
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            # Turns must strictly alternate, starting with the human role.
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conv_prompt = conv.get_prompt()
        # Drop the first assistant reply and everything after it; the prompt
        # then ends with the bare assistant tag.
        conv_prompt = conv_prompt.split("ASSISTANT: ")[0] + "ASSISTANT:"
        conversations.append(conv_prompt)

    # Tokenize conversations

    if has_image:
        # torch.stack needs equally long token sequences.
        # NOTE(review): presumably callers pass one conversation at a time -- confirm.
        input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
    else:
        input_ids = tokenizer(
            conversations,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ).input_ids

    targets = input_ids.clone()

    # The masking below relies on the two-separator prompt layout.
    assert conv.sep_style == conversation_lib.SeparatorStyle.TWO

    # Mask targets
    sep = conv.sep + conv.roles[1] + ": "
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        rounds = conversation.split(conv.sep2)
        # Position 0 is always masked.
        # NOTE(review): presumably the BOS token -- confirm for the tokenizer in use.
        cur_len = 1
        target[:cur_len] = IGNORE_INDEX
        for i, rou in enumerate(rounds):
            if rou == "":
                break

            parts = rou.split(sep)
            # A round without exactly one instruction/answer split ends the masking loop.
            # NOTE(review): the prompt was cut to end with "ASSISTANT:" (no trailing
            # space), so this split may never yield two parts -- verify.
            if len(parts) != 2:
                break
            parts[0] += sep

            if has_image:
                round_len = len(tokenizer_image_token(rou, tokenizer))
                instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
            else:
                round_len = len(tokenizer(rou).input_ids)
                instruction_len = len(tokenizer(parts[0]).input_ids) - 2

            # Mask the instruction part of the round.
            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX

            cur_len += round_len
        # Everything after the processed rounds is masked as well.
        target[cur_len:] = IGNORE_INDEX

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                # Token accounting disagrees: mask the whole sample and warn.
                target[:] = IGNORE_INDEX
                print(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" (ignored)"
                )

    return dict(
        input_ids=input_ids,
        labels=targets,
    )
class Evaluator_MM:
    """Loads a grounding-capable LLaVA checkpoint and runs single-sample evaluation.

    Construction builds the language model, attaches the vision (segmentation)
    model as ``self.model.seg_model``, loads the checkpoint weights, switches
    the default conversation template to "v1" and creates the dataset mapper
    that prepares inputs.
    """

    def __init__(self, model_path, path_vision_model_cfg=None):
        self.model_name = self._derive_model_name(model_path)
        print("1. Constructing model...")
        self.tokenizer, self.model, self.image_processor, self.context_len = self.construct_model(
            model_path=model_path,
            model_name=self.model_name,
        )
        print(" Continue...")
        self.construct_vision_model(path_vision_model_cfg)
        print("Done.")
        self.image_processor = self.model.get_vision_tower().image_processor
        print("2. Loading Parameters...")
        self.load_parameters(model_path)
        print("Done.")
        self.model.eval()
        conversation_lib.default_conversation = conversation_lib.conv_templates["v1"]
        self.data_mapper = LLAVAInstanceNewBaselineDatasetMapper(
            self.cfg_vision_model, False,
            tokenizer=self.tokenizer,
            image_processor=self.image_processor,
            preprocess=preprocess_v1)

    @staticmethod
    def _derive_model_name(model_path):
        """Derive the model name from a checkpoint path.

        The last path component is used; for ``.../<run>/checkpoint-N`` the
        run directory is prepended.  Trailing slashes are ignored: without
        that, ``"dir/llava_x/"`` produced an empty name and the model was
        silently loaded through the non-LLaVA code path.
        """
        parts = model_path.rstrip("/").split("/")
        if parts[-1].startswith('checkpoint-') and len(parts) >= 2:
            return parts[-2] + "_" + parts[-1]
        return parts[-1]

    @staticmethod
    def _select_dataset_cfg(cfg, dataset_name="flickr"):
        """Merge the dataset-specific section of ``cfg`` into ``cfg`` itself.

        The rules are evaluated in order and the first matching one wins.
        Rules flagged as optional leave ``cfg`` untouched when their section
        is absent.

        Returns:
            The same ``cfg`` object, updated in place.

        Raises:
            ValueError: if no rule matches ``dataset_name``.
        """
        # (predicate on the dataset name, config section, section is required)
        rules = (
            (lambda n: 'sam' in n, 'SAM', True),
            (lambda n: 'flickr' in n, 'flickr', True),
            (lambda n: 'coco_instruct_train' in n, 'coco_instruct', True),
            (lambda n: 'lisa' in n, 'LISA_REF', True),
            (lambda n: 'llava' in n, 'llava', True),
            (lambda n: 'vg' in n, 'vg', True),
            (lambda n: 'part' in n and 'pascal_part' not in n and 'partimagenet' not in n, 'part', True),
            # 'PSACAL_PART' is kept verbatim as the lookup key.
            # NOTE(review): looks like a typo of PASCAL -- verify against the config files.
            (lambda n: 'pascal' in n or 'paco' in n or 'partimagenet' in n, 'PSACAL_PART', True),
            (lambda n: 'coco' in n and 'refonly' in n, 'COCO_REF', True),
            (lambda n: 'refcoco' in n or 'flickr_val' in n, 'REF', True),
            (lambda n: 'coco' in n, 'COCO', False),
            (lambda n: 'mapillary' in n, 'MAPILLARY', False),
            (lambda n: 'ade' in n, 'ADE20K', False),
            (lambda n: 'imagenet' in n, 'IMAGENET', False),
            (lambda n: 'vlp' in n, 'VLP', True),
            (lambda n: 'sun' in n, 'SUN', True),
            (lambda n: 'object365' in n, 'OBJECT365', True),
            (lambda n: 'scan' in n, 'SCAN', True),
            (lambda n: 'cityscape' in n, 'CITY', True),
            (lambda n: 'bdd' in n, 'BDD', True),
        )
        for matches, section, required in rules:
            if matches(dataset_name):
                if required or section in cfg.keys():
                    cfg.update(cfg[section])
                return cfg
        # Explicit exception instead of `assert False`, which is stripped under -O.
        raise ValueError("dataset not support.")

    def construct_model(self, model_path, model_base=None, model_name=None, load_8bit=False, load_4bit=False, device_map="auto"):
        """Load tokenizer and language model for ``model_path``.

        The loading strategy is picked from substrings of ``model_name``
        ('llava', 'lora', 'mpt') and from whether ``model_base`` is given.

        Returns:
            ``(tokenizer, model, image_processor, context_len)``;
            ``image_processor`` is ``None`` for non-LLaVA models and
            ``context_len`` falls back to 2048.
        """
        import shutil
        from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
        from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN

        if model_name is None:
            # The default used to crash on `.lower()`; fall back to the path.
            model_name = self._derive_model_name(model_path)
        lowered = model_name.lower()
        is_llava = 'llava' in lowered
        is_mpt = 'mpt' in lowered

        kwargs = {"device_map": device_map}
        if load_8bit:
            kwargs['load_in_8bit'] = True
        elif load_4bit:
            kwargs['load_in_4bit'] = True
            kwargs['quantization_config'] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type='nf4'
            )
        else:
            kwargs['torch_dtype'] = torch.float16

        if is_llava:
            # Load LLaVA model
            if 'lora' in lowered and model_base is not None:
                lora_cfg_pretrained = AutoConfig.from_pretrained(model_path)
                tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
                print('Loading LLaVA from base model...')
                model = LlavaLlamaForCausalLM_gd.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
                token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features
                if model.lm_head.weight.shape[0] != token_num:
                    # Re-allocate head and embedding so their shapes match the config.
                    model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
                    model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))

                print('Loading additional LLaVA weights...')
                local_weights = os.path.join(model_path, 'non_lora_trainables.bin')
                if os.path.exists(local_weights):
                    non_lora_trainables = torch.load(local_weights, map_location='cpu')
                else:
                    # this is probably from HF Hub
                    from huggingface_hub import hf_hub_download
                    cache_file = hf_hub_download(
                        repo_id=model_path,
                        filename='non_lora_trainables.bin',
                        subfolder=None)
                    non_lora_trainables = torch.load(cache_file, map_location='cpu')
                # Strip the 'base_model.' prefix (11 characters) from the keys.
                non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()}
                if any(k.startswith('model.model.') for k in non_lora_trainables):
                    # Strip one redundant leading 'model.' (6 characters).
                    non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
                model.load_state_dict(non_lora_trainables, strict=False)

                from peft import PeftModel
                print('Loading LoRA weights...')
                model = PeftModel.from_pretrained(model, model_path)
                print('Merging LoRA weights...')
                model = model.merge_and_unload()
                print('Model is loaded...')
            elif model_base is not None:
                # this may be mm projector only
                print('Loading LLaVA from base model...')
                if is_mpt:
                    if not os.path.isfile(os.path.join(model_path, 'configuration_mpt.py')):
                        shutil.copyfile(os.path.join(model_base, 'configuration_mpt.py'), os.path.join(model_path, 'configuration_mpt.py'))
                    cfg_pretrained = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
                else:
                    cfg_pretrained = AutoConfig.from_pretrained(model_path)
                tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=is_mpt)
                model = LlavaLlamaForCausalLM_gd.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)

                mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
                mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
                model.load_state_dict(mm_projector_weights, strict=False)
            else:
                tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=is_mpt)
                model = LlavaLlamaForCausalLM_gd.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
        else:
            # Load language model
            if model_base is not None:
                # PEFT model
                from peft import PeftModel
                tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
                model = AutoModelForCausalLM.from_pretrained(model_base, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="auto")
                print(f"Loading LoRA weights from {model_path}")
                model = PeftModel.from_pretrained(model, model_path)
                print("Merging weights")
                model = model.merge_and_unload()
                print('Convert to FP16...')
                model.to(torch.float16)
            else:
                tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=is_mpt)
                if is_mpt:
                    model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs)
                else:
                    model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)

        image_processor = None

        if is_llava:
            mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
            mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
            if mm_use_im_patch_token:
                tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
            if mm_use_im_start_end:
                tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
            model.resize_token_embeddings(len(tokenizer))

            vision_tower = model.get_vision_tower()
            if not vision_tower.is_loaded:
                vision_tower.load_model()
            vision_tower.to(device='cuda', dtype=torch.float16)
            image_processor = vision_tower.image_processor

        if hasattr(model.config, "max_sequence_length"):
            context_len = model.config.max_sequence_length
        else:
            context_len = 2048

        return tokenizer, model, image_processor, context_len

    def construct_vision_model(self, path_vision_model_cfg):
        """Build the vision model from its config and attach it as ``seg_model``."""
        from detectron2.config import LazyConfig
        from llava.model.openseed import build_model
        from llava.model.openseed.BaseModel import BaseModel

        self.cfg_vision_model = LazyConfig.load(path_vision_model_cfg)
        vision_model = BaseModel(self.cfg_vision_model, build_model(self.cfg_vision_model))
        vision_model.eval()
        self.model.seg_model = vision_model
        self.model.seg_model.model = self.model.seg_model.model.to(self.model.device)
        # Configure the dataset-mapper settings (flickr section by default).
        self.cfg_vision_model = self._select_dataset_cfg(self.cfg_vision_model)

    def load_parameters(self, path_model):
        """Load every ``pytorch_model*.bin`` shard in ``path_model`` into the model."""
        print("Loading Whole Model ...")
        loaded_dict = dict()
        # Sorted so shards are merged in a deterministic order.
        for model_file in sorted(os.listdir(path_model)):
            if model_file.endswith('.bin') and model_file.startswith('pytorch_model'):
                loaded_dict.update(torch.load(os.path.join(path_model, model_file), map_location='cpu'))
        self.model.load_state_dict(loaded_dict, strict=True)

    @torch.inference_mode()
    def evaluate_sample(self, input_data, get_box=True, get_mask=False):
        """Run the model on one prepared sample.

        Returns:
            A list starting with the generated text, followed by the boxes if
            ``get_box`` and the masks if ``get_mask``.
        """
        text, boxes, masks = self.model.forward_eval(input_data)
        returns = [text,]
        if get_box:
            returns.append(boxes)
        if get_mask:
            returns.append(masks)
        return returns
load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type='nf4' ) else: kwargs['torch_dtype'] = torch.float16 if 'llava' in model_name.lower(): # Load LLaVA model if 'lora' in model_name.lower() and model_base is not None: lora_cfg_pretrained = AutoConfig.from_pretrained(model_path) tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) print('Loading LLaVA from base model...') model = LlavaLlamaForCausalLM_joint_2st_it_only_ref_instr.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs) token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features if model.lm_head.weight.shape[0] != token_num: model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype)) model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype)) print('Loading additional LLaVA weights...') if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')): non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu') else: # this is probably from HF Hub from huggingface_hub import hf_hub_download def load_from_hf(repo_id, filename, subfolder=None): cache_file = hf_hub_download( repo_id=repo_id, filename=filename, subfolder=subfolder) return torch.load(cache_file, map_location='cpu') non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin') non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()} if any(k.startswith('model.model.') for k in non_lora_trainables): non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()} model.load_state_dict(non_lora_trainables, strict=False) from peft import PeftModel print('Loading LoRA weights...') model = PeftModel.from_pretrained(model, model_path) 
print('Merging LoRA weights...') model = model.merge_and_unload() print('Model is loaded...') elif model_base is not None: # this may be mm projector only print('Loading LLaVA from base model...') if 'mpt' in model_name.lower(): if not os.path.isfile(os.path.join(model_path, 'configuration_mpt.py')): shutil.copyfile(os.path.join(model_base, 'configuration_mpt.py'), os.path.join(model_path, 'configuration_mpt.py')) tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True) cfg_pretrained = AutoConfig.from_pretrained(model_path, trust_remote_code=True) model = LlavaLlamaForCausalLM_joint_2st_it_only_ref_instr.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs) else: tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) cfg_pretrained = AutoConfig.from_pretrained(model_path) model = LlavaLlamaForCausalLM_joint_2st_it_only_ref_instr.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs) mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu') mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()} model.load_state_dict(mm_projector_weights, strict=False) else: if 'mpt' in model_name.lower(): tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True) model = LlavaLlamaForCausalLM_joint_2st_it_only_ref_instr.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) else: tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) model = LlavaLlamaForCausalLM_joint_2st_it_only_ref_instr.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) else: # Load language model if model_base is not None: # PEFT model from peft import PeftModel tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) model = AutoModelForCausalLM.from_pretrained(model_base, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="auto") print(f"Loading LoRA weights from {model_path}") 
model = PeftModel.from_pretrained(model, model_path) print(f"Merging weights") model = model.merge_and_unload() print('Convert to FP16...') model.to(torch.float16) else: use_fast = False if 'mpt' in model_name.lower(): tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True) model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs) else: tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) image_processor = None if 'llava' in model_name.lower(): mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False) mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True) if mm_use_im_patch_token: tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) if mm_use_im_start_end: tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) model.resize_token_embeddings(len(tokenizer)) vision_tower = model.get_vision_tower() if not vision_tower.is_loaded: vision_tower.load_model() vision_tower.to(device='cuda', dtype=torch.float16) image_processor = vision_tower.image_processor if hasattr(model.config, "max_sequence_length"): context_len = model.config.max_sequence_length else: context_len = 2048 return tokenizer, model, image_processor, context_len def construct_vision_model(self, path_vision_model_cfg): from detectron2.config import LazyConfig from llava.model.openseed import build_model from llava.model.openseed.BaseModel import BaseModel def get_config_from_name(cfg, dataset_name="flickr"): # adjust config according to dataset, flickr by default if 'sam' in dataset_name: cfg.update(cfg['SAM']) return cfg elif 'flickr' in dataset_name: cfg.update(cfg['flickr']) return cfg elif 'coco_instruct_train' in dataset_name: cfg.update(cfg['coco_instruct']) return cfg elif 'lisa' in dataset_name: cfg.update(cfg['LISA_REF']) return cfg 
elif 'llava' in dataset_name: cfg.update(cfg['llava']) return cfg elif 'vg' in dataset_name: cfg.update(cfg['vg']) return cfg elif 'part' in dataset_name and 'pascal_part' not in dataset_name and 'partimagenet' not in dataset_name: cfg.update(cfg['part']) return cfg elif 'pascal' in dataset_name or 'paco' in dataset_name or 'partimagenet' in dataset_name : cfg.update(cfg['PSACAL_PART']) return cfg elif 'coco' in dataset_name and 'refonly' in dataset_name: # if 'COCO' in cfg.keys(): cfg.update(cfg['COCO_REF']) return cfg elif 'refcoco' in dataset_name or "flickr_val" in dataset_name: cfg.update(cfg['REF']) return cfg elif 'coco' in dataset_name: if 'COCO' in cfg.keys(): cfg.update(cfg['COCO']) return cfg elif "mapillary" in dataset_name: if 'MAPILLARY' in cfg.keys(): cfg.update(cfg['MAPILLARY']) return cfg elif 'ade' in dataset_name: if 'ADE20K' in cfg.keys(): cfg.update(cfg['ADE20K']) return cfg elif 'imagenet' in dataset_name: if 'IMAGENET' in cfg.keys(): cfg.update(cfg['IMAGENET']) return cfg elif 'vlp' in dataset_name: cfg.update(cfg['VLP']) return cfg elif 'sun' in dataset_name: cfg.update(cfg['SUN']) return cfg elif 'object365' in dataset_name: cfg.update(cfg['OBJECT365']) return cfg elif 'scan' in dataset_name: cfg.update(cfg['SCAN']) return cfg elif 'cityscape' in dataset_name: cfg.update(cfg['CITY']) return cfg elif 'bdd' in dataset_name: cfg.update(cfg['BDD']) return cfg else: assert False, "dataset not support." 
@torch.inference_mode()
def evaluate_sample(self, input_data):
    """Run the wrapped model on ``input_data`` and return its four outputs.

    The call is forwarded to ``self.model.forward_eval`` and executes under
    ``torch.inference_mode()``.

    Returns:
        A 4-tuple ``(text, boxes, masks, mask_inter)`` exactly as produced by
        ``forward_eval``.
    """
    outputs = self.model.forward_eval(input_data)
    # Unpack explicitly so that an unexpected number of outputs fails here.
    text, boxes, masks, mask_inter = outputs
    return text, boxes, masks, mask_inter
= find_start_idxes(subtext, "")[0] text_pure = format_sentence([text_pure, format_sentence(subtext[:start_idx].split())]) text_boxes = format_sentence([text_boxes, format_sentence(subtext[:start_idx].split())]) text_ = extract_text(subtext[start_idx:]) text_pure = format_sentence([text_pure, format_sentence(text_.split())]) if number >= len(boxes): print("Error, There should be a wrong prediction.") text_boxes = format_sentence([text_boxes, format_sentence(text_.split())]) number += 1 continue text_boxes = format_sentence([text_boxes, format_sentence(text_.split()) + multiboxes_to_str(boxes[number].cpu().tolist())]) boxes_pure.append(multiboxes_to_str(boxes[number].cpu().tolist())) number += 1 else: text_pure = format_sentence([text_pure, format_sentence(subtext.split())]) text_boxes = format_sentence([text_boxes, format_sentence(subtext.split())]) return { "question_id": question_id, "text": text_pure, "text_boxes": text_boxes, "boxes": boxes_pure, } else: return { "question_id": question_id, "text": text, "text_boxes": text, "boxes": [] } def evaluate_(path_benchmarks, dir_image, evaluator, matching_threshold): def unresize_box(box, width, height, size): # ori_size = max(width, height) # ratio = ori_size / size ratio = min(width, height) / max(width, height) if width > height: # then the height dimension is padded, the y coordinates should be divided by ratio box[:, 1] = box[:, 1] / ratio box[:, 3] = box[:, 3] / ratio elif width < height: # then the height dimension is padded, the y coordinates should be divided by ratio box[:, 0] = box[:, 0] / ratio box[:, 2] = box[:, 2] / ratio return box def filter_empty_box(text, boxes_image): def extract_text(sentence): # Use regular expression to find and extract the text and number import re if " " in sentence: pattern = r"| " cleaned_text = re.sub(pattern, '', sentence) return cleaned_text else: cleaned_text = re.sub(r' \d+', '', sentence) cleaned_text = re.sub(r' ', '', cleaned_text) return cleaned_text has_gd = True if 
"" in text else False if len(boxes_image) == 0: return text, boxes_image else: if has_gd: sub_texts = text.split(" ") sub_texts_filtered = [] boxes_image_filtered = [] for box_per_gd, text_per_gd in zip(boxes_image, sub_texts): text_per_gd += " " ind_nonempty_box = torch.where(box_per_gd.abs().sum(dim=1)>0) if len(ind_nonempty_box[0]) < box_per_gd.shape[0]: # empty box encountered if len(ind_nonempty_box[0]) == 0: text_per_gd = " " + " ".join(extract_text(text_per_gd).split()) sub_texts_filtered.append(text_per_gd) # box is desperated continue else: box_per_gd = box_per_gd[ind_nonempty_box] boxes_image_filtered.append(box_per_gd) sub_texts_filtered.append(text_per_gd) else: boxes_image_filtered.append(box_per_gd) sub_texts_filtered.append(text_per_gd) sub_texts_filtered.append(sub_texts[-1]) text_filtered = "".join(sub_texts_filtered) return text_filtered, boxes_image_filtered else: text_filtered = " ".join(extract_text(text).split()) boxes_image_filtered = [] return text_filtered, boxes_image_filtered def debug(image, boxes, prefix): import cv2 # image = cv2.imread(path_image) def transform_str2numpy(boxes_str): boxes_str = boxes_str.replace(";", "];[") boxes_list = [] for box_str in boxes_str.split(";"): box_list = [float(aa) for aa in box_str[1:-1].split(",")] boxes_list.append(box_list) return boxes_list boxes_ = [] for box in boxes: boxes_.extend(transform_str2numpy(box)) boxes = torch.tensor(boxes_) image = image[..., ::-1] image = np.ascontiguousarray(image, dtype=np.uint8) height,width,_ = image.shape for box in boxes: box = (box.cpu() * torch.tensor([width, height, width, height])).int().squeeze() box = box.tolist() image = cv2.rectangle(image, (box[0], box[1]), (box[2], box[3]), (0, 255, 0)) cv2.imwrite(f"{prefix}_debug.jpg", image) datas = load_benchmark(dir_image, path_benchmarks)[:20] #! use first 20 samples for debug. 
def evaluate(args=None):
    """Build an ``Evaluator_MM`` from parsed arguments and run the benchmark.

    Args:
        args: Namespace-like object providing ``model_path``,
            ``vision_model_cfg``, ``path_benchmark``, ``image_root`` and
            ``matching_threshold``.

    Returns:
        Whatever ``evaluate_`` returns for the configured benchmark.
    """
    # The evaluator is constructed first, before any benchmark argument is read.
    model_kwargs = {
        "model_path": args.model_path,
        "path_vision_model_cfg": args.vision_model_cfg,
    }
    mm_evaluator = Evaluator_MM(**model_kwargs)
    return evaluate_(
        args.path_benchmark,
        dir_image=args.image_root,
        evaluator=mm_evaluator,
        matching_threshold=args.matching_threshold,
    )
def parse_score(review):
    """Parse the two scores found on the first line of a review.

    The first line is expected to hold two numbers separated by a comma
    and/or whitespace (``"8 9"``, ``"8,9"``, ``"8, 9"``).

    Args:
        review: Full review text returned by the judge model.

    Returns:
        ``[score1, score2]`` as floats, or ``[-1, -1]`` when the first line
        does not contain exactly two parseable numbers.
    """
    try:
        score_pair = review.split('\n')[0]
        score_pair = score_pair.replace(',', ' ')
        # Split on runs of whitespace: splitting on a single ' ' produced an
        # empty token for "8, 9" and rejected an otherwise valid score line.
        sp = score_pair.split()
        if len(sp) == 2:
            return [float(sp[0]), float(sp[1])]
        print('error', review)
        return [-1, -1]
    except Exception as e:
        # Best-effort parser: report the problem and return the sentinel.
        print(e)
        print('error', review)
        return [-1, -1]
# Delay between retries of a failed judge request, in seconds.
NUM_SECONDS_TO_SLEEP = 0.5

# Azure OpenAI endpoint configuration.
openai.api_type = "azure"
openai.api_base = "https://xdecoder.openai.azure.com/"
openai.api_version = "2023-03-15-preview"
# SECURITY: the API key must never be committed to source control. It is read
# from the OPENAI_API_KEY environment variable, which the caller has to export
# before running this script. The key that used to be hard-coded here should
# be treated as leaked and revoked.
openai.api_key = os.getenv("OPENAI_API_KEY")
if __name__ == '__main__':
    # Command-line driver: for every (question, answer1, answer2) triple ask
    # the judge model for a review and append it to the output JSONL file.
    # Reviews already present in the output file are skipped, so an
    # interrupted run can be resumed.
    parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')
    parser.add_argument('-q', '--question')
    parser.add_argument('-c', '--context')
    parser.add_argument('-a', '--answer-list', nargs='+', default=[])
    parser.add_argument('-r', '--rule')
    parser.add_argument('-o', '--output')
    parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
    args = parser.parse_args()

    if len(args.answer_list) < 2:
        parser.error('--answer-list needs two answer files')

    # Use the expanded path both for the resume check and for writing, so a
    # path containing "~" refers to the same file in both places.
    output_path = os.path.expanduser(args.output)

    with open(os.path.expanduser(args.rule), 'r') as f_rule:
        rule_dict = json.load(f_rule)

    cur_reviews = []
    if os.path.isfile(output_path):
        with open(output_path) as f_prev:
            cur_reviews = [json.loads(line) for line in f_prev]

    with open(os.path.expanduser(args.context)) as f_ctx:
        context_list = [json.loads(line) for line in f_ctx]
    image_to_context = {context['image']: context for context in context_list}

    with open(os.path.expanduser(args.question)) as f_q, \
            open(os.path.expanduser(args.answer_list[0])) as f_ans1, \
            open(os.path.expanduser(args.answer_list[1])) as f_ans2, \
            open(output_path, 'a') as review_file:
        for idx, (ques_js, ans1_js, ans2_js) in enumerate(zip(f_q, f_ans1, f_ans2)):
            ques = json.loads(ques_js)
            ans1 = json.loads(ans1_js)
            ans2 = json.loads(ans2_js)

            inst = image_to_context[ques['image']]
            caption = inst['caption']
            # A plain string caption is used as-is; joining it would insert a
            # newline between every character.
            cap_str = caption if isinstance(caption, str) else '\n'.join(caption)

            category = 'llava_bench_' + ques['category']
            if category not in rule_dict:
                # Explicit raise: an assert would be stripped under ``-O``.
                raise ValueError(f"Visual QA category not found in rule file: {category}.")
            rule = rule_dict[category]
            prompt = rule['prompt']
            role = rule['role']
            content = (f'[Context]\n{cap_str}\n\n'
                       f'[Question]\n{ques["text"]}\n\n'
                       f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n'
                       f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n'
                       f'[System]\n{prompt}\n\n')
            cur_js = {
                'id': idx + 1,
                'question_id': ques['question_id'],
                # Fall back to question_id lazily: dict.get() evaluates its
                # default eagerly, which raised KeyError for answers that
                # carry no 'answer_id'.
                'answer1_id': ans1['answer_id'] if 'answer_id' in ans1 else ans1['question_id'],
                'answer2_id': ans2['answer_id'] if 'answer_id' in ans2 else ans2['question_id'],
                'category': category
            }
            if idx >= len(cur_reviews):
                review = get_eval(content, args.max_tokens)
                scores = parse_score(review)
                cur_js['content'] = review
                cur_js['tuple'] = scores
                review_file.write(json.dumps(cur_js) + '\n')
                # Flush per review so progress survives an interruption.
                review_file.flush()
            else:
                print(f'Skipping {idx} as we already have it.')
            print(idx + 1)
def parse_score(review):
    """Parse the two scores found on the first line of a review.

    The first line is expected to hold two numbers separated by a comma
    and/or whitespace (``"8 9"``, ``"8,9"``, ``"8, 9"``).

    Args:
        review: Full review text returned by the judge model.

    Returns:
        ``[score1, score2]`` as floats, or ``[-1, -1]`` when the first line
        does not contain exactly two parseable numbers.
    """
    try:
        score_pair = review.split('\n')[0]
        score_pair = score_pair.replace(',', ' ')
        # Split on runs of whitespace: splitting on a single ' ' produced an
        # empty token for "8, 9" and rejected an otherwise valid score line.
        sp = score_pair.split()
        if len(sp) == 2:
            return [float(sp[0]), float(sp[1])]
        print('error', review)
        return [-1, -1]
    except Exception as e:
        # Best-effort parser: report the problem and return the sentinel.
        print(e)
        print('error', review)
        return [-1, -1]
# Delay between retries of a failed judge request, in seconds.
NUM_SECONDS_TO_SLEEP = 0.5

# Azure OpenAI endpoint configuration.
openai.api_type = "azure"
openai.api_base = 'https://azureopenaifiahmedeastus.openai.azure.com/'
openai.api_version = '2023-03-15-preview'
# SECURITY: the API key must never be committed to source control. It is read
# from the OPENAI_API_KEY environment variable, which the caller has to export
# before running this script. The key that used to be hard-coded here should
# be treated as leaked and revoked.
openai.api_key = os.getenv("OPENAI_API_KEY")
if __name__ == '__main__':
    # Command-line driver: for every (question, answer1, answer2) triple ask
    # the judge model for a review and append it to the output JSONL file.
    # Reviews already present in the output file are skipped, so an
    # interrupted run can be resumed.
    parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')
    parser.add_argument('-q', '--question')
    parser.add_argument('-c', '--context')
    parser.add_argument('-a', '--answer-list', nargs='+', default=[])
    parser.add_argument('-r', '--rule')
    parser.add_argument('-o', '--output')
    parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
    args = parser.parse_args()

    if len(args.answer_list) < 2:
        parser.error('--answer-list needs two answer files')

    # Use the expanded path both for the resume check and for writing, so a
    # path containing "~" refers to the same file in both places.
    output_path = os.path.expanduser(args.output)

    with open(os.path.expanduser(args.rule), 'r') as f_rule:
        rule_dict = json.load(f_rule)

    cur_reviews = []
    if os.path.isfile(output_path):
        with open(output_path) as f_prev:
            cur_reviews = [json.loads(line) for line in f_prev]

    with open(os.path.expanduser(args.context)) as f_ctx:
        context_list = [json.loads(line) for line in f_ctx]
    image_to_context = {context['image']: context for context in context_list}

    with open(os.path.expanduser(args.question)) as f_q, \
            open(os.path.expanduser(args.answer_list[0])) as f_ans1, \
            open(os.path.expanduser(args.answer_list[1])) as f_ans2, \
            open(output_path, 'a') as review_file:
        for idx, (ques_js, ans1_js, ans2_js) in enumerate(zip(f_q, f_ans1, f_ans2)):
            ques = json.loads(ques_js)
            ans1 = json.loads(ans1_js)
            ans2 = json.loads(ans2_js)

            inst = image_to_context[ques['image']]
            cap_str = '\n'.join(inst['captions'])
            box_str = '\n'.join([f'{instance["category"]}: {instance["bbox"]}' for instance in inst['instances']])

            category = ques['category']
            if category not in rule_dict:
                # Explicit raise: an assert would be stripped under ``-O``.
                raise ValueError(f"Visual QA category not found in rule file: {category}.")
            rule = rule_dict[category]
            prompt = rule['prompt']
            role = rule['role']
            content = (f'[Context]\n{cap_str}\n\n{box_str}\n\n'
                       f'[Question]\n{ques["text"]}\n\n'
                       f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n'
                       f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n'
                       f'[System]\n{prompt}\n\n')
            cur_js = {
                'id': idx + 1,
                'question_id': ques['question_id'],
                # Fall back to question_id lazily: dict.get() evaluates its
                # default eagerly, which raised KeyError for answers that
                # carry no 'answer_id'.
                'answer1_id': ans1['answer_id'] if 'answer_id' in ans1 else ans1['question_id'],
                'answer2_id': ans2['answer_id'] if 'answer_id' in ans2 else ans2['question_id'],
                'category': category
            }
            if idx >= len(cur_reviews):
                review = get_eval(content, args.max_tokens)
                scores = parse_score(review)
                cur_js['content'] = review
                cur_js['tuple'] = scores
                review_file.write(json.dumps(cur_js) + '\n')
                # Flush per review so progress survives an interruption.
                review_file.flush()
            else:
                print(f'Skipping {idx} as we already have it.')
            print(idx + 1)
if __name__ == "__main__":
    # Score ScienceQA predictions for one split and write two JSON reports:
    # a per-question analysis (--output-file) and a summary (--output-result).
    args = get_args()

    base_dir = args.base_dir
    with open(os.path.join(base_dir, "pid_splits.json")) as f:
        split_indices = json.load(f)[args.split]
    with open(os.path.join(base_dir, "problems.json")) as f:
        problems = json.load(f)
    with open(args.result_file) as f:
        predictions = [json.loads(line) for line in f]
    predictions = {pred['question_id']: pred for pred in predictions}
    split_problems = {idx: problems[idx] for idx in split_indices}

    results = {'correct': [], 'incorrect': []}
    sqa_results = {
        'acc': None,
        'correct': None,
        'count': None,
        'results': {},
        'outputs': {},
    }

    # Compiled once instead of once per problem.
    answer_pattern = re.compile(r'The answer is ([A-Z]).')

    for prob_id, prob in split_problems.items():
        if prob_id not in predictions:
            continue
        pred = predictions[prob_id]
        pred_text = pred['text']

        res = answer_pattern.findall(pred_text)
        # Only an unambiguous single match counts as a parsed answer.
        answer = res[0] if len(res) == 1 else "FAILED"

        pred_idx = get_pred_idx(answer, prob['choices'], args.options)

        analysis = {
            'question_id': prob_id,
            'parsed_ans': answer,
            'ground_truth': args.options[prob['answer']],
            'question': pred['prompt'],
            'pred': pred_text,
            # NOTE(review): this literal is empty in the visible source, which
            # makes the flag always True; it presumably held an image
            # placeholder token - confirm against the upstream file.
            'is_multimodal': '' in pred['prompt'],
        }

        # Store the index that is actually scored. get_pred_idx() falls back
        # to a random choice for unparsed answers, so calling it a second time
        # could record a different index than the one used for accuracy.
        sqa_results['results'][prob_id] = pred_idx
        sqa_results['outputs'][prob_id] = pred_text

        if pred_idx == prob['answer']:
            results['correct'].append(analysis)
        else:
            results['incorrect'].append(analysis)

    correct = len(results['correct'])
    total = len(results['correct']) + len(results['incorrect'])
    # No scored prediction at all: report 0% instead of dividing by zero.
    accuracy = correct / total * 100 if total else 0.0

    print(f'Total: {total}, Correct: {correct}, Accuracy: {accuracy:.2f}%')

    sqa_results['acc'] = accuracy
    sqa_results['correct'] = correct
    sqa_results['count'] = total

    with open(args.output_file, 'w') as f:
        json.dump(results, f, indent=2)
    with open(args.output_result, 'w') as f:
        json.dump(sqa_results, f, indent=2)
if __name__ == "__main__":
    # Compare GPT-4 answers with our model's answers on one ScienceQA split.
    # When GPT-4 produced no parseable answer, our prediction is used instead.
    args = get_args()

    base_dir = args.base_dir
    with open(os.path.join(base_dir, "pid_splits.json")) as f:
        split_indices = json.load(f)[args.split]
    with open(os.path.join(base_dir, "problems.json")) as f:
        problems = json.load(f)
    with open(args.our_result) as f:
        our_predictions = [json.loads(line) for line in f]
    our_predictions = {pred['question_id']: pred for pred in our_predictions}
    split_problems = {idx: problems[idx] for idx in split_indices}

    with open(args.gpt4_result) as f:
        gpt4_predictions = json.load(f)['outputs']

    results = defaultdict(int)

    # Compiled once instead of once per problem.
    answer_pattern = re.compile(r'The answer is ([A-Z]).')

    for prob_id, prob in split_problems.items():
        # Only problems answered by both systems are scored.
        if prob_id not in our_predictions:
            continue
        if prob_id not in gpt4_predictions:
            continue
        our_pred = our_predictions[prob_id]['text']
        gpt4_pred = gpt4_predictions[prob_id]

        our_res = answer_pattern.findall(our_pred)
        our_answer = our_res[0] if len(our_res) == 1 else "FAILED"
        gpt4_res = answer_pattern.findall(gpt4_pred)
        gpt4_answer = gpt4_res[0] if len(gpt4_res) == 1 else "FAILED"

        our_pred_idx = get_pred_idx(our_answer, prob['choices'], args.options)
        gpt4_pred_idx = get_pred_idx(gpt4_answer, prob['choices'], args.options)

        if gpt4_answer == 'FAILED':
            results['gpt4_failed'] += 1
            # Fall back to our own prediction when GPT-4 gave no answer.
            gpt4_pred_idx = our_pred_idx

        if gpt4_pred_idx == prob['answer']:
            results['correct'] += 1
        else:
            results['incorrect'] += 1

        if gpt4_pred_idx == prob['answer'] or our_pred_idx == prob['answer']:
            results['correct_upperbound'] += 1

    correct = results['correct']
    total = results['correct'] + results['incorrect']

    def percent(count):
        # Nothing scored: report 0% instead of dividing by zero.
        return count / total * 100 if total else 0.0

    upper = results['correct_upperbound']
    gpt4_failed = results['gpt4_failed']
    print(f'Total: {total}, Correct: {correct}, Accuracy: {percent(correct):.2f}%')
    print(f'Total: {total}, Correct (upper): {upper}, Accuracy: {percent(upper):.2f}%')
    print(f'Total: {total}, GPT-4 NO-ANS (RANDOM): {gpt4_failed}, Percentage: {percent(gpt4_failed):.2f}%')
def get_pred_idx(prediction, choices, options):
    """
    Map a predicted option letter (e.g. 'C') to its index (e.g. 2).

    Only the first ``len(choices)`` entries of ``options`` count as valid
    letters. Any other prediction (for example "FAILED") is replaced by a
    randomly chosen index into ``choices``.
    """
    valid_letters = options[:len(choices)]
    if prediction not in valid_letters:
        # Unparseable or out-of-range answer: fall back to a random guess.
        return random.choice(range(len(choices)))
    return options.index(prediction)
else: gpt4_answer = "FAILED" our_pred_idx = get_pred_idx(our_answer, prob['choices'], args.options) gpt4_pred_idx = get_pred_idx(gpt4_answer, prob['choices'], args.options) requery_pred_idx = get_pred_idx(requery_answer, prob['choices'], args.options) results['total'] += 1 if gpt4_answer == 'FAILED': results['gpt4_failed'] += 1 if gpt4_pred_idx == prob['answer']: results['gpt4_correct'] += 1 if our_pred_idx == prob['answer']: results['gpt4_ourvisual_correct'] += 1 elif gpt4_pred_idx == prob['answer']: results['gpt4_correct'] += 1 results['gpt4_ourvisual_correct'] += 1 if our_pred_idx == prob['answer']: results['our_correct'] += 1 if requery_answer == 'FAILED': sqa_results['results'][prob_id] = our_pred_idx if our_pred_idx == prob['answer']: results['requery_correct'] += 1 else: sqa_results['results'][prob_id] = requery_pred_idx if requery_pred_idx == prob['answer']: results['requery_correct'] += 1 else: print(f""" Question ({args.options[prob['answer']]}): {our_predictions[prob_id]['prompt']} Our ({our_answer}): {our_pred} GPT-4 ({gpt4_answer}): {gpt4_pred} Requery ({requery_answer}): {requery_pred} print("=====================================") """) if gpt4_pred_idx == prob['answer'] or our_pred_idx == prob['answer']: results['correct_upperbound'] += 1 total = results['total'] print(f'Total: {total}, Our-Correct: {results["our_correct"]}, Accuracy: {results["our_correct"] / total * 100:.2f}%') print(f'Total: {total}, GPT-4-Correct: {results["gpt4_correct"]}, Accuracy: {results["gpt4_correct"] / total * 100:.2f}%') print(f'Total: {total}, GPT-4 NO-ANS (RANDOM): {results["gpt4_failed"]}, Percentage: {results["gpt4_failed"] / total * 100:.2f}%') print(f'Total: {total}, GPT-4-OursVisual-Correct: {results["gpt4_ourvisual_correct"]}, Accuracy: {results["gpt4_ourvisual_correct"] / total * 100:.2f}%') print(f'Total: {total}, Requery-Correct: {results["requery_correct"]}, Accuracy: {results["requery_correct"] / total * 100:.2f}%') print(f'Total: {total}, Correct upper: 
def read_jsonl(path: str, key: str=None):
    """Read a JSON-lines file.

    Args:
        path: File to read; ``~`` is expanded.
        key: Optional field name. When given, the records are sorted by this
            field and returned as a dict mapping ``record[key]`` to the record.

    Returns:
        A list of decoded records in file order when ``key`` is ``None``,
        otherwise a dict keyed by ``key``.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If ``key`` is given and missing from a record.
    """
    data = []
    with open(os.path.expanduser(path)) as f:
        for line in f:
            # Lines read from a file keep their trailing newline, so a blank
            # line is "\n" (truthy). Strip first so blank lines are skipped
            # instead of being handed to json.loads.
            line = line.strip()
            if not line:
                continue
            data.append(json.loads(line))
    if key is not None:
        data.sort(key=lambda x: x[key])
        data = {item[key]: item for item in data}
    return data
review_llama = read_jsonl('table/review/review_llama-13b_vicuna-13b.jsonl', key='question_id') records = [] for qid in questions.keys(): r = { 'id': qid, 'category': questions[qid]['category'], 'question': questions[qid]['text'], 'answers': { # 'alpaca': alpaca_answers[qid]['text'], # 'llama': llama_answers[qid]['text'], # 'bard': bard_answers[qid]['text'], # 'gpt35': gpt35_answers[qid]['text'], 'vicuna': vicuna_answers[qid]['text'], 'ours': ours_answers[qid]['text'], }, 'evaluations': { # 'alpaca': review_alpaca[qid]['text'], # 'llama': review_llama[qid]['text'], # 'bard': review_bard[qid]['text'], 'vicuna': review_vicuna[qid]['content'], # 'gpt35': review_gpt35[qid]['text'], }, 'scores': { 'vicuna': review_vicuna[qid]['tuple'], # 'alpaca': review_alpaca[qid]['score'], # 'llama': review_llama[qid]['score'], # 'bard': review_bard[qid]['score'], # 'gpt35': review_gpt35[qid]['score'], }, } # cleanup data cleaned_evals = {} for k, v in r['evaluations'].items(): v = v.strip() lines = v.split('\n') # trim the first line if it's a pair of numbers if re.match(r'\d+[, ]+\d+', lines[0]): lines = lines[1:] v = '\n'.join(lines) cleaned_evals[k] = v.replace('Assistant 1', "**Assistant 1**").replace('Assistant 2', '**Assistant 2**') r['evaluations'] = cleaned_evals records.append(r) # Reorder the records, this is optional for r in records: if r['id'] <= 20: r['id'] += 60 else: r['id'] -= 20 for r in records: if r['id'] <= 50: r['id'] += 10 elif 50 < r['id'] <= 60: r['id'] -= 50 for r in records: if r['id'] == 7: r['id'] = 1 elif r['id'] < 7: r['id'] += 1 records.sort(key=lambda x: x['id']) # Write to file with open('webpage/data.json', 'w') as f: json.dump({'questions': records, 'models': models}, f, indent=2) ================================================ FILE: llava/eval/llava_mapper.py ================================================ # Copyright (c) Facebook, Inc. and its affiliates. 
def convert_coco_poly_to_mask(segmentations, height, width):
    """Rasterize COCO polygon segmentations into a stacked mask tensor.

    Each entry of ``segmentations`` is converted with pycocotools
    (``frPyObjects`` followed by ``decode``) and its last axis is reduced with
    ``any`` so every entry yields one ``(height, width)`` mask.

    Returns:
        The per-entry masks stacked along dim 0, or an empty
        ``(0, height, width)`` uint8 tensor when ``segmentations`` is empty.
    """
    per_instance = []
    for polygons in segmentations:
        decoded = coco_mask.decode(coco_mask.frPyObjects(polygons, height, width))
        if len(decoded.shape) < 3:
            # Add a trailing axis so the reduction over dim 2 below is valid.
            decoded = decoded[..., None]
        per_instance.append(torch.as_tensor(decoded, dtype=torch.uint8).any(dim=2))
    if not per_instance:
        return torch.zeros((0, height, width), dtype=torch.uint8)
    return torch.stack(per_instance, dim=0)
def build_transform_gen(cfg, is_train):
    """
    Create the list of default :class:`Augmentation` from config.

    The pipeline is a ``ResizeScale`` towards a square ``IMAGE_SIZE`` target
    followed by a ``FixedSizeCrop`` of the same size. Random flipping is not
    applied (it was commented out in the original implementation).

    The training and inference branches were byte-identical duplicates, so a
    single code path is used; ``is_train`` is kept for interface
    compatibility and does not influence the result.

    Args:
        cfg: Mapping with an ``'INPUT'`` section that provides
            ``'IMAGE_SIZE'``, ``'MIN_SCALE'`` and ``'MAX_SCALE'``.
        is_train: Unused.

    Returns:
        list[Augmentation]
    """
    cfg_input = cfg['INPUT']
    image_size = cfg_input['IMAGE_SIZE']
    min_scale = cfg_input['MIN_SCALE']
    max_scale = cfg_input['MAX_SCALE']

    return [
        T.ResizeScale(
            min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size
        ),
        T.FixedSizeCrop(crop_size=(image_size, image_size)),
    ]
def __call__(self, dataset_dict):
    """
    Map one dataset record to model inputs (image tensors, tokenized
    conversation and, when available, grounding instances).

    Args:
        dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
            This method reads the keys "file_name" and "conversations", and
            optionally "grounding_info".

    Returns:
        list: a single-element list holding the updated copy of ``dataset_dict``.
            Keys written here: "image_clip", "image_ori", "image",
            "padding_mask", "conversation", "input_ids", "labels",
            "tokenizer", "grounding", "grounding_mask" and, in the
            grounding branch, "instances".
    """
    dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
    image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
    utils.check_image_size(dataset_dict, image)

    #########llava image processing
    # The untransformed image goes through the image processor passed to the
    # constructor; only the first entry of 'pixel_values' is kept.
    image_clip = self.processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
    dataset_dict["image_clip"] = image_clip
    ##################

    # TODO: get padding mask
    # by feeding a "segmentation mask" to the same transforms
    padding_mask = np.ones(image.shape[:2])

    image, transforms = T.apply_transform_gens(self.tfm_gens, image)
    dataset_dict["image_ori"]=image
    # the crop transformation has default padding value 0 for segmentation
    padding_mask = transforms.apply_segmentation(padding_mask)
    # Invert so that True marks padded pixels (value 0 after the transforms).
    padding_mask = ~ padding_mask.astype(bool)

    image_shape = image.shape[:2]  # h, w

    # transpose(2, 0, 1) moves the last image axis to the front before the
    # array is wrapped as a tensor.
    dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
    dataset_dict["padding_mask"] = torch.as_tensor(np.ascontiguousarray(padding_mask))

    num_conversations = len(dataset_dict['conversations'])
    # NOTE(review): `rd` is drawn but never used below; it only advances the
    # NumPy random state.
    rd = np.random.choice(num_conversations)
    # selected_conversation, grounding_list = dataset_dict['conversations'][rd]
    # dataset_dict['conversation'] = [selected_conversation]
    # Keep the first element of every entry. The commented-out line above
    # suggests entries are (conversation, grounding_list) pairs -- TODO confirm.
    selected_conversation = [aa[0] for aa in dataset_dict['conversations']]
    dataset_dict['conversation'] = selected_conversation
    sources = preprocess_multimodal(
        copy.deepcopy(dataset_dict['conversation']),
        True)
    #! Debug here
    # sources = copy.deepcopy(dataset_dict['conversation'])
    # `self.preprocess` is the tokenization callable passed to the constructor.
    data_dict_conversation = self.preprocess(
        sources,
        self.tokenizer,
        has_image=True)
    # Only the first element of "input_ids" / "labels" is kept.
    data_dict_conversation = dict(input_ids=data_dict_conversation["input_ids"][0],
                                  labels=data_dict_conversation["labels"][0])
    dataset_dict.update(data_dict_conversation)
    dataset_dict['tokenizer'] = self.tokenizer
    num_segs = 1 # sum([conv['value'].count('') for conv in selected_conversation])
    # grounding_list=

    if "grounding_info" in dataset_dict and len(dataset_dict['grounding_info'])>0:
        # Map each annotation's own 'id' to its position in grounding_info.
        anno_id2id=dict()
        for id,obj in enumerate(dataset_dict['grounding_info']):
            obj["bbox_mode"] = BoxMode.XYWH_ABS
            anno_id2id[obj['id']]=id
        # One (initially empty) list of grounding ids per annotation.
        id2class=[[] for _ in range(len(dataset_dict['grounding_info']))]
        # Apply the same geometric transforms to the annotations as to the image.
        annos = [
            utils.transform_instance_annotations(obj, transforms, image_shape)
            for obj in dataset_dict["grounding_info"]
        ]
        # assert "segmentation" in annos[0]
        instances = utils.annotations_to_instances(annos, image_shape,mask_format="bitmask")

        h, w = instances.image_size  # NOTE(review): h and w are not used below.
        # image_size_xyxy = torch.as_tensor([w, h, w, h], dtype=torch.float)
        if hasattr(instances, 'gt_masks'):
            gt_masks = instances.gt_masks
            # gt_masks = convert_coco_poly_to_mask(gt_masks.polygons, h, w)
            # Replace the mask container by its underlying `.tensor`.
            instances.gt_masks = gt_masks.tensor

        # NOTE(review): `grounding_list` is never assigned in this method (its
        # only assignment is in the commented-out line above). Unless a
        # module-level name exists, reaching this line raises NameError --
        # TODO confirm intended source of grounding_list.
        if grounding_list is None:
            dataset_dict['grounding']=False
            grounding_mask=[False for _ in range(num_segs)]
            dataset_dict['grounding_mask']=grounding_mask
        else:
            # True for every entry of grounding_list that is not None.
            grounding_mask=[True if g is not None else False for g in grounding_list]
            dataset_dict['grounding_mask']=grounding_mask
            new_grounding_list=[g for g in grounding_list if g is not None]
            if sum(grounding_mask)==0:
                dataset_dict['grounding']=False
            else:
                dataset_dict['grounding']=True
        if dataset_dict['grounding']:
            # assert num_segs == len(grounding_list)
            # Record, per annotation, the indices of the groundings that
            # reference it (indices are positions in new_grounding_list).
            for grounding_id,grounding in enumerate(new_grounding_list):
                if grounding is not None:
                    for annid in grounding:
                        id2class[anno_id2id[annid]].append(grounding_id)

            # gt_classes is set to a list of lists here, not a tensor.
            instances.gt_classes=id2class
        dataset_dict["instances"] = instances
    else:
        # No grounding annotations: mark every segment slot as not grounded.
        dataset_dict['grounding'] = False
        grounding_mask = [False for _ in range(num_segs)]
        dataset_dict['grounding_mask'] = grounding_mask
    return [dataset_dict]
def split_list(lst, n):
    """Split a list into at most n (roughly) equal-sized chunks.

    The chunk size is ``ceil(len(lst) / n)``, so fewer than ``n`` chunks may be
    returned when the length does not divide evenly (e.g. 5 items with
    ``n=4`` give 3 chunks). An empty list yields an empty list of chunks.

    Args:
        lst: Sequence to split.
        n: Requested number of chunks; must be non-zero.

    Returns:
        list: Consecutive slices of ``lst``.

    Raises:
        ZeroDivisionError: If ``n`` is 0.
    """
    # Ceiling division. Clamp to 1 because a chunk size of 0 (empty input)
    # would be passed to range() as a zero step, which raises ValueError.
    chunk_size = max(1, math.ceil(len(lst) / n))
    return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
get_model_name_from_path(model_path) tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name) questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")] questions = get_chunk(questions, args.num_chunks, args.chunk_idx) answers_file = os.path.expanduser(args.answers_file) os.makedirs(os.path.dirname(answers_file), exist_ok=True) ans_file = open(answers_file, "w") for line in tqdm(questions): idx = line["question_id"] image_file = line["image"] qs = line["text"] cur_prompt = qs if model.config.mm_use_im_start_end: qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs else: qs = DEFAULT_IMAGE_TOKEN + '\n' + qs conv = conv_templates[args.conv_mode].copy() conv.append_message(conv.roles[0], qs) conv.append_message(conv.roles[1], None) prompt = conv.get_prompt() input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() try: image = Image.open(os.path.join(args.image_folder, "COCO_val2014_"+image_file)) except Exception: image = Image.open(os.path.join(args.image_folder, image_file)) image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 keywords = [stop_str] stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) # import pdb; pdb.set_trace() with torch.inference_mode(): output_ids = model.generate( input_ids, images=image_tensor.unsqueeze(0).half().cuda(), do_sample=True, temperature=args.temperature, top_p=args.top_p, num_beams=args.num_beams, # no_repeat_ngram_size=3, max_new_tokens=2048, use_cache=True) input_token_len = input_ids.shape[1] n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() if n_diff_input_output > 0: print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids') outputs = 
if __name__ == "__main__":
    # Command-line options as (flag, type, default), registered in this order.
    option_specs = [
        ("--model-path", str, "facebook/opt-350m"),
        ("--model-base", str, None),
        ("--image-folder", str, ""),
        ("--question-file", str, "tables/question.jsonl"),
        ("--answers-file", str, "answer.jsonl"),
        ("--conv-mode", str, "llava_v1"),
        ("--num-chunks", int, 1),
        ("--chunk-idx", int, 0),
        ("--temperature", float, 0.2),
        ("--top_p", float, None),
        ("--num_beams", int, 1),
    ]
    parser = argparse.ArgumentParser()
    for flag, value_type, default in option_specs:
        parser.add_argument(flag, type=value_type, default=default)
    args = parser.parse_args()

    # Run the VQA evaluation with the parsed options.
    eval_model(args)
def _decode_new_tokens(tokenizer, input_ids, output_ids, stop_str):
    """Decode the tokens generated after the prompt and strip a trailing stop string."""
    input_token_len = input_ids.shape[1]
    n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
    if n_diff_input_output > 0:
        print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
    outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
    outputs = outputs.strip()
    if outputs.endswith(stop_str):
        outputs = outputs[:-len(stop_str)]
    return outputs.strip()


def eval_model(args):
    """Run the model on a chunk of ScienceQA-style questions and write answers as JSON lines.

    Args:
        args: Parsed command-line namespace. Reads ``model_path``,
            ``model_base``, ``question_file``, ``num_chunks``, ``chunk_idx``,
            ``answers_file``, ``image_folder``, ``conv_mode`` and
            ``answer_prompter``.

    Each output line is a JSON object with the keys ``question_id``,
    ``prompt``, ``text``, ``answer_id``, ``model_id`` and ``metadata``.
    """
    # Model
    disable_torch_init()
    model_path = os.path.expanduser(args.model_path)
    model_name = get_model_name_from_path(model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)

    # Close the question file deterministically instead of leaking the handle.
    with open(os.path.expanduser(args.question_file), "r") as question_fh:
        questions = json.load(question_fh)
    questions = get_chunk(questions, args.num_chunks, args.chunk_idx)

    answers_file = os.path.expanduser(args.answers_file)
    answers_dir = os.path.dirname(answers_file)
    # A bare file name (such as the default "answer.jsonl") has an empty
    # dirname, and os.makedirs('') raises FileNotFoundError.
    if answers_dir:
        os.makedirs(answers_dir, exist_ok=True)

    # The context manager closes the answers file even if generation fails midway.
    with open(answers_file, "w") as ans_file:
        for line in tqdm(questions):
            idx = line["id"]
            question = line['conversations'][0]
            # NOTE(review): the original replaced/prepended an empty string
            # here (a no-op and a bare newline). The image placeholder token is
            # assumed to be what was intended -- TODO confirm.
            qs = question['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip()
            cur_prompt = qs

            if 'image' in line:
                image_file = line["image"]
                image = Image.open(os.path.join(args.image_folder, image_file))
                image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
                images = image_tensor.unsqueeze(0).half().cuda()
                if getattr(model.config, 'mm_use_im_start_end', False):
                    qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
                else:
                    qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
                cur_prompt = DEFAULT_IMAGE_TOKEN + '\n' + cur_prompt
            else:
                images = None

            conv = conv_templates[args.conv_mode].copy()
            conv.append_message(conv.roles[0], qs)
            conv.append_message(conv.roles[1], None)
            prompt = conv.get_prompt()

            input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()

            stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
            keywords = [stop_str]
            stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)

            with torch.inference_mode():
                output_ids = model.generate(
                    input_ids,
                    images=images,
                    do_sample=True,
                    temperature=0.2,
                    max_new_tokens=1024,
                    use_cache=True,
                    stopping_criteria=[stopping_criteria])

            outputs = _decode_new_tokens(tokenizer, input_ids, output_ids, stop_str)

            # prompt for answer
            if args.answer_prompter:
                # Second pass: feed the reasoning back and ask for the final answer.
                outputs_reasoning = outputs
                input_ids = tokenizer_image_token(prompt + outputs_reasoning + ' ###\nANSWER:', tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()

                with torch.inference_mode():
                    output_ids = model.generate(
                        input_ids,
                        images=images,
                        do_sample=True,
                        temperature=0.2,
                        max_new_tokens=64,
                        use_cache=True,
                        stopping_criteria=[stopping_criteria])

                outputs = _decode_new_tokens(tokenizer, input_ids, output_ids, stop_str)
                outputs = outputs_reasoning + '\n The answer is ' + outputs

            ans_id = shortuuid.uuid()
            ans_file.write(json.dumps({"question_id": idx,
                                       "prompt": cur_prompt,
                                       "text": outputs,
                                       "answer_id": ans_id,
                                       "model_id": model_name,
                                       "metadata": {}}) + "\n")
            # Flush per record so partial results survive an interrupted run.
            ans_file.flush()
def get_answer(question_id: int, question: str, max_tokens: int):
    """Ask the chat model one question, retrying up to three times.

    Args:
        question_id: Identifier copied into the returned record.
        question: User message sent to the model.
        max_tokens: Forwarded as ``max_tokens`` to the completion call.

    Returns:
        dict with ``answer_id``, ``question_id``, ``model_id`` and ``text``.
        ``text`` is the model reply, or ``'#ERROR#'`` when every attempt failed.
    """
    answer = {
        'answer_id': shortuuid.uuid(),
        'question_id': question_id,
        'model_id': MODEL_ID,
    }
    chat = [
        {'role': 'system', 'content': 'You are a helpful assistant.'},
        {'role': 'user', 'content': question},
    ]

    attempts_left = 3
    while attempts_left > 0:
        attempts_left -= 1
        try:
            response = openai.ChatCompletion.create(
                model=MODEL,
                messages=chat,
                max_tokens=max_tokens,
            )
            answer['text'] = response['choices'][0]['message']['content']
        except Exception as err:
            # Best effort: report, mark the record as failed, wait, then retry.
            print('[ERROR]', err)
            answer['text'] = '#ERROR#'
            time.sleep(1)
        else:
            return answer
    return answer
def eval_model(args):
    """Answer one query about one image with a LLaVA checkpoint and print the reply.

    Reads ``model_path``, ``model_base``, ``query``, ``image_file`` and
    ``conv_mode`` from ``args``. When ``args.conv_mode`` is unset (or equals the
    mode inferred from the model name) it is overwritten with the inferred mode.
    """
    disable_torch_init()

    model_name = get_model_name_from_path(args.model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(
        args.model_path, args.model_base, model_name)

    # The image placeholder goes on its own line in front of the user query,
    # optionally wrapped in the start/end markers.
    if model.config.mm_use_im_start_end:
        image_prefix = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
    else:
        image_prefix = DEFAULT_IMAGE_TOKEN
    qs = image_prefix + '\n' + args.query

    # Infer the conversation template from the model name; the first matching
    # marker wins, so the order of this table matters.
    lowered_name = model_name.lower()
    conv_mode = "llava_v0"
    for marker, mode in (('llama-2', "llava_llama_2"), ("v1", "llava_v1"), ("mpt", "mpt")):
        if marker in lowered_name:
            conv_mode = mode
            break

    if args.conv_mode is None or conv_mode == args.conv_mode:
        args.conv_mode = conv_mode
    else:
        # An explicit --conv-mode overrides the inferred one; just tell the user.
        print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode))

    conv = conv_templates[args.conv_mode].copy()
    user_role, assistant_role = conv.roles[0], conv.roles[1]
    conv.append_message(user_role, qs)
    conv.append_message(assistant_role, None)
    prompt = conv.get_prompt()

    image = load_image(args.image_file)
    pixel_values = image_processor.preprocess(image, return_tensors='pt')['pixel_values']
    image_tensor = pixel_values.half().cuda()

    token_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
    input_ids = token_ids.unsqueeze(0).cuda()

    if conv.sep_style != SeparatorStyle.TWO:
        stop_str = conv.sep
    else:
        stop_str = conv.sep2
    stopping_criteria = KeywordsStoppingCriteria([stop_str], tokenizer, input_ids)

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            do_sample=True,
            temperature=0.2,
            max_new_tokens=1024,
            use_cache=True,
            stopping_criteria=[stopping_criteria])

    # Sanity check: the generated sequence should start with the prompt ids.
    input_token_len = input_ids.shape[1]
    echoed_prompt = output_ids[:, :input_token_len]
    n_diff_input_output = (input_ids != echoed_prompt).sum().item()
    if n_diff_input_output > 0:
        print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')

    # Decode only the newly generated tail and drop a trailing stop string.
    decoded = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)
    outputs = decoded[0].strip()
    if outputs.endswith(stop_str):
        outputs = outputs[:-len(stop_str)]
    print(outputs.strip())
scores['all'].append(review['tuple']) else: if 'tuple' in review: scores['all'].append(review['tuple']) else: scores['all'].append(review['score']) for k, v in sorted(scores.items()): stats = np.asarray(v).mean(0).tolist() stats = [round(x, 3) for x in stats] # print(k, stats, round(stats[1]/stats[0]*100, 1)) print(k, round(stats[1]/stats[0]*100, 1)) print('=================================') ================================================ FILE: llava/eval/webpage/index.html ================================================ Who's GPT-4's favorite? Battles between State-of-the-Art Chatbots

Who's GPT-4's favorite? Battles between State-of-the-Art Chatbots

other logo
vicuna logo
Assistant #2 (Vicuna, our model)
GPT-4 Evaluation
This website is co-authored with GPT-4.
/**
 * Return `str` with its first character upper-cased.
 * Falsy or empty input is returned unchanged.
 */
function capitalizeFirstChar(str) {
    if (str && str.length !== 0) {
        const head = str.charAt(0).toUpperCase();
        const tail = str.slice(1);
        return head + tail;
    }
    return str;
}
/**
 * Point the "other model" logo at the figure registered for the model that is
 * currently chosen in the model dropdown.
 */
function updateModelSelect() {
    const select = document.getElementById('model-select');
    // Keep the path in a local binding: assigning to an undeclared name would
    // create an implicit global and throws a ReferenceError in strict mode.
    const imgPath = modelFigureMapping[select.value];
    document.getElementById('other-model-figure').src = imgPath;
}
/**
 * Render both answers, the evaluation text and the score summary for the
 * question stored under `index`, comparing the model chosen in the model
 * dropdown against the Vicuna answer, then reset the expandable cards.
 */
function displayAnswers(index) {
    const question = questionMapping[index];
    const otherModel = document.getElementById('model-select').value;
    // render the answers with markdown
    document.getElementById('other-model-answer').innerHTML = text2Markdown(question.answers[otherModel]);
    document.getElementById('our-model-answer').innerHTML = text2Markdown(question.answers.vicuna);

    // Display evaluation.
    // `score` and `scoreText` are declared locally so they do not leak into
    // the global scope (undeclared assignment also throws in strict mode).
    const score = question.scores[otherModel];
    const scoreText = modelNameMapping[otherModel] + " " + score[0] + "/10, Vicuna-13b " + score[1] + "/10";
    document.getElementById('evaluation-header').textContent = "GPT-4 Evaluation" + " (Score: " + scoreText + ")";
    document.getElementById('evaluation-result').innerHTML = text2Markdown(question.evaluations[otherModel]);

    // Update model names
    let assistant1_title = "Assistant #1";
    let assistant2_title = "Assistant #2 (Vicuna-13b, our model)";
    // Update scores/labels.
    let assistant1_score_label = score[0].toString() + '/10';
    let assistant2_score_label = score[1].toString() + '/10';

    const colorRed = '#fa9';
    const colorBlue = '#8ef';
    const colorYellow = '#fe7';
    let otherModelHeaderColor = '';
    let ourModelHeaderColor = '';
    // Update the winner: a tie marks both sides, otherwise only the higher score.
    if (score[0] == score[1]) {
        assistant1_title = '🏆 ' + assistant1_title;
        assistant1_score_label = '🏆 ' + assistant1_score_label;
        assistant2_title = '🏆 ' + assistant2_title;
        assistant2_score_label = '🏆 ' + assistant2_score_label;
        otherModelHeaderColor = colorYellow;
        ourModelHeaderColor = colorYellow;
    } else if (score[0] > score[1]) {
        assistant1_title = '🏆 ' + assistant1_title;
        assistant1_score_label = '🏆 ' + assistant1_score_label;
        otherModelHeaderColor = colorBlue;
        ourModelHeaderColor = colorRed;
    } else if (score[0] < score[1]) {
        assistant2_title = '🏆 ' + assistant2_title;
        assistant2_score_label = '🏆 ' + assistant2_score_label;
        otherModelHeaderColor = colorRed;
        ourModelHeaderColor = colorBlue;
    }

    document.getElementById('other-model-header-bg').style.backgroundColor = otherModelHeaderColor;
    document.getElementById('our-model-header').style.backgroundColor = ourModelHeaderColor;

    document.getElementById('other-model-header').textContent = assistant1_title;
    document.getElementById('our-model-header').textContent = assistant2_title;

    document.getElementById('other-score-label').textContent = assistant1_score_label;
    document.getElementById('our-score-label').textContent = assistant2_score_label;

    // Reset the expanded state and update expand buttons visibility for both
    // cards after displaying answers.
    document.querySelectorAll('.expandable-card').forEach(card => {
        card.classList.remove('expanded');
        updateExpandButtonVisibility(card);
        const expandBtn = card.querySelector('.expand-btn');
        expandBtn.innerHTML = 'keyboard_arrow_down Show more';
    });
}
/**
 * Sync the question and category dropdowns with `currentQuestionIndex` and
 * display that question. The question list is rebuilt only when the new
 * question belongs to a different category than the one currently selected.
 */
function switchQuestionAndCategory() {
    document.getElementById('question-select').value = currentQuestionIndex;
    // Local bindings: undeclared assignments would create implicit globals
    // and throw in strict mode.
    const oldCategory = document.getElementById('category-select').value;
    const newCategory = questionMapping[currentQuestionIndex].category;
    if (oldCategory != newCategory) {
        document.getElementById('category-select').value = newCategory;
        updateQuestionSelect(currentQuestionIndex);
    }
    displayQuestion(currentQuestionIndex);
}
less : more; }); }); ================================================ FILE: llava/eval/webpage/styles.css ================================================ body { font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; background-color: #f8f9fa; } .navbar-dark .navbar-nav .nav-link { color: #f1cf68; font-size: 1.1rem; padding: 0.5rem 0.6rem; } .card-header { font-weight: bold; } .card { box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); transition: 0.3s; } .card:hover { box-shadow: 0 8px 16px rgba(0, 0, 0, 0.2); } button { transition: background-color 0.3s; } button:hover { background-color: #007bff; } @media (max-width: 767px) { .form-row .form-group { margin-bottom: 10px; } } /* Extra styles */ .expandable-card .card-text-container { max-height: 200px; overflow-y: hidden; position: relative; } .expandable-card.expanded .card-text-container { max-height: none; } .expand-btn { position: relative; display: none; background-color: rgba(255, 255, 255, 0.8); color: #510c75; border-color: transparent; } .expand-btn:hover { background-color: rgba(200, 200, 200, 0.8); text-decoration: none; border-color: transparent; color: #510c75; } .expand-btn:focus { outline: none; text-decoration: none; } .expandable-card:not(.expanded) .card-text-container:after { content: ""; position: absolute; bottom: 0; left: 0; width: 100%; height: 90px; background: linear-gradient(rgba(255, 255, 255, 0.2), rgba(255, 255, 255, 1)); } .expandable-card:not(.expanded) .expand-btn { margin-top: -40px; } .card-body { padding-bottom: 5px; } .vertical-flex-layout { justify-content: center; align-items: center; height: 100%; display: flex; flex-direction: column; gap: 5px; } .figure-img { max-width: 100%; height: auto; } .adjustable-font-size { font-size: calc(0.5rem + 2vw); } ================================================ FILE: llava/mm_utils.py ================================================ from PIL import Image from io import BytesIO import base64 import torch from transformers import 
def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
    """Tokenize ``prompt``, replacing each image placeholder with ``image_token_index``.

    The prompt is split on the literal ``'<image>'`` placeholder, every text
    chunk is tokenized separately, and ``image_token_index`` is spliced in
    between consecutive chunks. A leading BOS token is kept exactly once.

    Args:
        prompt: Prompt text, possibly containing ``'<image>'`` placeholders.
        tokenizer: Callable returning an object with ``input_ids``; must expose
            ``bos_token_id``.
        image_token_index: Id inserted where a placeholder stood.
        return_tensors: ``None`` for a plain list, ``'pt'`` for a 1-D
            ``torch.long`` tensor.

    Returns:
        The token ids as a list, or as a tensor when ``return_tensors == 'pt'``.

    Raises:
        ValueError: If ``return_tensors`` is neither ``None`` nor ``'pt'``.
    """
    # Split on the image placeholder. An empty separator is not usable here:
    # str.split('') raises ValueError for every prompt.
    # NOTE(review): the literal is expected to equal DEFAULT_IMAGE_TOKEN in
    # llava.constants -- confirm.
    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]

    def insert_separator(X, sep):
        # Interleave `sep` between the elements of X (no trailing separator).
        return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]

    input_ids = []
    offset = 0
    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
        # Every chunk starts with BOS; keep the first one and skip the rest.
        offset = 1
        input_ids.append(prompt_chunks[0][0])

    # The separator is repeated (offset + 1) times so that slicing off `offset`
    # leading items still leaves exactly one image token.
    for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
        input_ids.extend(x[offset:])

    if return_tensors is not None:
        if return_tensors == 'pt':
            return torch.tensor(input_ids, dtype=torch.long)
        raise ValueError(f'Unsupported tensor type: {return_tensors}')
    return input_ids
class KeywordsStoppingCriteria(StoppingCriteria):
    """Stop generation once any of the given keywords has been produced.

    Two checks are made on every call: an exact match of the keyword's token
    ids against the tail of the sequence, and a substring search in the decoded
    text of the last few generated tokens.
    """

    def __init__(self, keywords, tokenizer, input_ids):
        """Pre-tokenize the keywords and remember the prompt length.

        Args:
            keywords: Iterable of stop strings.
            tokenizer: Tokenizer used both to encode the keywords and to decode
                generated ids; must expose ``bos_token_id`` and ``batch_decode``.
            input_ids: Prompt ids; only ``input_ids.shape[1]`` is read.
        """
        self.keywords = keywords
        self.keyword_ids = []
        for keyword in keywords:
            cur_keyword_ids = tokenizer(keyword).input_ids
            # Drop a leading BOS so the ids can match in the middle of a sequence.
            if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
                cur_keyword_ids = cur_keyword_ids[1:]
            self.keyword_ids.append(torch.tensor(cur_keyword_ids))
        self.tokenizer = tokenizer
        self.start_len = input_ids.shape[1]

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when a keyword appears at the end of ``output_ids``.

        Raises:
            AssertionError: If the batch size is not 1.
        """
        assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)"  # TODO
        # Decode at most the last three generated tokens for the text check.
        offset = min(output_ids.shape[1] - self.start_len, 3)
        self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
        for keyword_id in self.keyword_ids:
            # torch.equal yields a single bool and is False on a size mismatch.
            # An element-wise `==` inside `if` raises for multi-token keywords
            # (ambiguous truth value) and for sequences shorter than the keyword.
            if torch.equal(output_ids[0, -keyword_id.shape[0]:], keyword_id):
                return True
        outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
        for keyword in self.keywords:
            if keyword in outputs:
                return True
        return False
if __name__ == "__main__":
    # Command-line entry point: all three paths are mandatory string options.
    parser = argparse.ArgumentParser()
    for option in ("--base-model-path", "--target-model-path", "--delta-path"):
        parser.add_argument(option, type=str, required=True)

    args = parser.parse_args()

    apply_delta(
        base_model_path=args.base_model_path,
        target_model_path=args.target_model_path,
        delta_path=args.delta_path,
    )
def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto"):
    """Load a tokenizer and model (LLaVA or plain language model) from a checkpoint.

    The loading strategy is chosen from substrings of ``model_name``
    ('llava', 'lora', 'mpt') and from whether ``model_base`` is given.

    Args:
        model_path: Checkpoint directory (or hub repo id) of the model to load.
        model_base: Optional base model; when set, ``model_path`` is treated as
            LoRA weights or projector-only weights on top of it.
        model_name: Name used only for the substring checks above.
        load_8bit: Pass ``load_in_8bit=True`` to ``from_pretrained``.
        load_4bit: Load with a 4-bit NF4 ``BitsAndBytesConfig`` (ignored when
            ``load_8bit`` is set).
        device_map: Forwarded to ``from_pretrained`` as ``device_map``.

    Returns:
        ``(tokenizer, model, image_processor, context_len)``. ``image_processor``
        is ``None`` unless 'llava' is in the model name; ``context_len`` is
        ``model.config.max_sequence_length`` when present, else 2048.
    """
    kwargs = {"device_map": device_map}

    # Precision / quantization: 8-bit, 4-bit NF4, or plain fp16 (the default).
    if load_8bit:
        kwargs['load_in_8bit'] = True
    elif load_4bit:
        kwargs['load_in_4bit'] = True
        kwargs['quantization_config'] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4'
        )
    else:
        kwargs['torch_dtype'] = torch.float16

    if 'llava' in model_name.lower():
        # Load LLaVA model
        if 'lora' in model_name.lower() and model_base is not None:
            # LoRA checkpoint: base weights + non-LoRA trainables + LoRA adapters.
            lora_cfg_pretrained = AutoConfig.from_pretrained(model_path)
            tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
            print('Loading LLaVA from base model...')
            model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
            token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features
            # Re-allocate head and embedding when the loaded weight has a
            # different number of rows than the configured output size.
            if model.lm_head.weight.shape[0] != token_num:
                model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
                model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))

            print('Loading additional LLaVA weights...')
            if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):
                non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')
            else:
                # this is probably from HF Hub
                from huggingface_hub import hf_hub_download
                def load_from_hf(repo_id, filename, subfolder=None):
                    # Download (or reuse the cached copy of) one file and load it on CPU.
                    cache_file = hf_hub_download(
                        repo_id=repo_id,
                        filename=filename,
                        subfolder=subfolder)
                    return torch.load(cache_file, map_location='cpu')
                non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin')
            # Strip the 'base_model.' prefix (11 chars), then a 'model.' prefix
            # (6 chars) if keys are still nested as 'model.model.*'.
            non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()}
            if any(k.startswith('model.model.') for k in non_lora_trainables):
                non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
            model.load_state_dict(non_lora_trainables, strict=False)

            from peft import PeftModel
            print('Loading LoRA weights...')
            model = PeftModel.from_pretrained(model, model_path)
            print('Merging LoRA weights...')
            model = model.merge_and_unload()
            print('Model is loaded...')
        elif model_base is not None:
            # this may be mm projector only
            print('Loading LLaVA from base model...')
            if 'mpt' in model_name.lower():
                # Ensure configuration_mpt.py is present in model_path; it is
                # copied from model_base when missing.
                if not os.path.isfile(os.path.join(model_path, 'configuration_mpt.py')):
                    shutil.copyfile(os.path.join(model_base, 'configuration_mpt.py'), os.path.join(model_path, 'configuration_mpt.py'))
                tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
                cfg_pretrained = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
                model = LlavaMPTForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
            else:
                tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
                cfg_pretrained = AutoConfig.from_pretrained(model_path)
                model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)

            # Overlay the projector weights (cast to fp16) on the base model.
            mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
            mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
            model.load_state_dict(mm_projector_weights, strict=False)
        else:
            # Full LLaVA checkpoint stored at model_path.
            if 'mpt' in model_name.lower():
                tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
                model = LlavaMPTForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
            else:
                tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
                model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
    else:
        # Load language model
        if model_base is not None:
            # PEFT model
            # Note: this branch does not use `kwargs`; it always loads fp16
            # with device_map="auto".
            from peft import PeftModel
            tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
            model = AutoModelForCausalLM.from_pretrained(model_base, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="auto")
            print(f"Loading LoRA weights from {model_path}")
            model = PeftModel.from_pretrained(model, model_path)
            print(f"Merging weights")
            model = model.merge_and_unload()
            print('Convert to FP16...')
            model.to(torch.float16)
        else:
            # NOTE(review): `use_fast` is assigned here but never read below.
            use_fast = False
            if 'mpt' in model_name.lower():
                tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
                model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs)
            else:
                tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
                model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)

    image_processor = None

    if 'llava' in model_name.lower():
        # Register the multimodal special tokens the config asks for, then grow
        # the embedding table to the new vocabulary size.
        mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
        mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
        if mm_use_im_patch_token:
            tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
        if mm_use_im_start_end:
            tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
        model.resize_token_embeddings(len(tokenizer))

        vision_tower = model.get_vision_tower()
        if not vision_tower.is_loaded:
            vision_tower.load_model()
        # The vision tower is always moved to 'cuda' in fp16, regardless of
        # `device_map`.
        vision_tower.to(device='cuda', dtype=torch.float16)
        image_processor = vision_tower.image_processor

    if hasattr(model.config, "max_sequence_length"):
        context_len = model.config.max_sequence_length
    else:
        # Fallback context length when the config does not declare one.
        context_len = 2048

    return tokenizer, model, image_processor, context_len
def consolidate_ckpt(src_path, dst_path):
    """Re-save the checkpoint at ``src_path`` (model and tokenizer) into ``dst_path``.

    ``auto_upgrade`` is run on the source first; the model is loaded in fp16
    and the tokenizer with the slow implementation.
    """
    print("Loading model")
    auto_upgrade(src_path)

    load_options = {"torch_dtype": torch.float16, "low_cpu_mem_usage": True}
    model = AutoModelForCausalLM.from_pretrained(src_path, **load_options)
    tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False)

    # Model first, then tokenizer -- same order as they were loaded.
    for artifact in (model, tokenizer):
        artifact.save_pretrained(dst_path)
class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
    """LLaMA causal language model whose ``forward`` also accepts ``images``.

    Text and image inputs are merged by
    ``self.prepare_inputs_labels_for_multimodal`` before the decoder runs.
    """

    config_class = LlavaConfig

    def __init__(self, config):
        # The lookup starts *after* LlamaForCausalLM in the MRO, so that
        # class's own __init__ is not run; backbone and head are built here.
        super(LlamaForCausalLM, self).__init__(config)
        self.model = LlavaLlamaModel(config)

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        """Return the wrapped ``LlavaLlamaModel`` backbone."""
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the decoder and, when ``labels`` is given, the next-token loss.

        Args:
            input_ids, attention_mask, past_key_values, labels, images:
                passed to ``prepare_inputs_labels_for_multimodal``, whose
                return values replace them (and ``inputs_embeds``).
            inputs_embeds: overwritten by the multimodal preparation step.
            use_cache, output_attentions, output_hidden_states, return_dict:
                forwarded to the backbone; the last three fall back to the
                values stored on ``self.config`` when ``None``.

        Returns:
            ``CausalLMOutputWithPast`` when ``return_dict`` resolves to true;
            otherwise the tuple ``(logits, *backbone_extras)``, prefixed with
            the loss when labels were supplied.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        input_ids, attention_mask, past_key_values, inputs_embeds, labels = \
            self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images)

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model/pipeline parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        """Build the keyword arguments for a single generation step.

        With a non-empty cache only the last token id is kept. ``inputs_embeds``
        is used instead of ``input_ids`` only when no cache exists yet.
        ``use_cache`` and ``images`` are read from ``kwargs``.
        """
        if past_key_values:
            input_ids = input_ids[:, -1:]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
                "images": kwargs.get("images", None),
            }
        )
        return model_inputs
class DataCollatorForSupervisedDataset(object):
    """Pad tokenized samples into one batch for supervised fine-tuning.

    The tokenizer is not stored on the instance; it is passed on every call.
    """

    def __call__(self, instances, tokenizer):
        """Collate ``instances`` into padded batch tensors.

        Args:
            instances: sequence of mappings, each with ``input_ids`` and
                ``labels`` sequence tensors and optionally ``image_clip``.
            tokenizer: object exposing ``pad_token_id`` and
                ``model_max_length``.

        Returns:
            dict with ``input_ids``, ``labels`` and ``attention_mask``, each
            truncated to ``model_max_length`` columns. When the first
            instance has an ``image_clip`` key, ``images`` is added: a
            stacked tensor if every image is non-None and shares one shape,
            otherwise the plain list.
        """
        pad_batch = torch.nn.utils.rnn.pad_sequence
        pad_id = tokenizer.pad_token_id
        limit = tokenizer.model_max_length

        token_rows = [sample["input_ids"] for sample in instances]
        label_rows = [sample["labels"] for sample in instances]

        token_batch = pad_batch(token_rows, batch_first=True, padding_value=pad_id)[:, :limit]
        label_batch = pad_batch(label_rows, batch_first=True, padding_value=IGNORE_INDEX)[:, :limit]

        collated = {
            "input_ids": token_batch,
            "labels": label_batch,
            # Positions that hold the pad id are masked out.
            "attention_mask": token_batch.ne(pad_id),
        }

        if "image_clip" not in instances[0]:
            return collated

        clips = [sample["image_clip"] for sample in instances]
        # The shape comparison is only reached for non-None entries.
        stackable = all(clip is not None and clip.shape == clips[0].shape for clip in clips)
        collated["images"] = torch.stack(clips) if stackable else clips
        return collated
def prepare_inputs_for_generation(
    self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
    """Assemble the keyword arguments for one decoding step.

    When a non-empty cache is supplied only the newest token id is kept.
    ``inputs_embeds`` replaces ``input_ids`` solely on the first step, i.e.
    when embeddings are given and ``past_key_values`` is ``None``.
    ``use_cache`` and ``images`` are taken from ``kwargs`` (``None`` when
    absent).
    """
    if past_key_values:
        input_ids = input_ids[:, -1:]

    # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
    first_step_with_embeds = past_key_values is None and inputs_embeds is not None
    if first_step_with_embeds:
        lead_key, lead_value = "inputs_embeds", inputs_embeds
    else:
        lead_key, lead_value = "input_ids", input_ids

    return {
        lead_key: lead_value,
        "past_key_values": past_key_values,
        "use_cache": kwargs.get("use_cache"),
        "attention_mask": attention_mask,
        "images": kwargs.get("images", None),
    }
def forward_inner(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        seg_inputs: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
    """Run the decoder, pass grounding-token features to ``self.seg_model``, merge the losses.

    For each sample the positions where ``labels == 32002`` are located. The
    decoder hidden states at those positions are gathered per dataset group
    and stored in ``seg_inputs`` under ``'<group>_text_embeddings'`` as a
    ``(features, valid_mask)`` tuple, after which ``self.seg_model(seg_inputs)``
    is called. Its result is used as a dict of losses.

    Slicing by group length relies on the batch order built by ``forward``
    (``vg`` then ``refcoco``, or ``flickr`` then ``coco``).

    Args:
        input_ids, attention_mask, past_key_values, labels, images: passed to
            ``prepare_inputs_labels_for_multimodal``; the ids it returns are
            discarded and the decoder runs on the returned embeddings.
        inputs_embeds: overwritten by the multimodal preparation step.
        labels: iterated unconditionally when locating token 32002, so
            ``None`` raises before the later ``labels is not None`` check.
        seg_inputs: used here as a dict keyed by group name (``'vg'``,
            ``'refcoco'``, ``'flickr'``, ``'coco'``) despite the tensor
            annotation; it is mutated in place.
        use_cache, output_attentions, output_hidden_states, return_dict:
            forwarded to the decoder.

    Returns:
        ``CausalLMOutputWithPast`` whose ``loss`` is the dict from
        ``self.seg_model`` extended with ``'llava'`` (LM cross-entropy) and
        ``'loss_total'``. When ``return_dict`` is false, the usual tuple is
        returned instead and the grounding losses are not included.
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    _, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images)

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=None,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict
    )
    # Stays empty in the refcoco branch, which keeps the coco-only paths below inactive there.
    ground_idx_coco=[]
    # Per-sample positions whose label equals 32002.
    # NOTE(review): 32002 is presumably the id of the grounding token - confirm against the tokenizer.
    ground_idx = [torch.argwhere(lb == 32002)[:, 0] for lb in labels]
    if 'refcoco' in seg_inputs:
        if 'vg' in seg_inputs:
            # The vg samples come first in the batch.
            vg_len=len(seg_inputs['vg'])
            ground_idx_flickr = ground_idx[:vg_len]
            # Pad ragged index lists with -1, remember which slots are real,
            # then point padded slots at position 0 so gather stays in range.
            padded_ground_idx_flickr = torch.nn.utils.rnn.pad_sequence(ground_idx_flickr,
                                                                       batch_first=True,
                                                                       padding_value=-1)
            padded_mask_flickr = padded_ground_idx_flickr != -1
            padded_ground_idx_flickr[padded_ground_idx_flickr == -1] = 0
            # ground_idx=[[-1] if len(idx)==0 else idx for idx in ground_idx]
            hidden_states = outputs[0]
            hidden_states_flickr = hidden_states[:vg_len]
            # Index tensor is expanded over the hidden dimension to gather whole vectors.
            ground_hs_flickr = torch.gather(hidden_states_flickr, 1,
                                            padded_ground_idx_flickr[..., None].repeat(1, 1,
                                                                                       hidden_states_flickr.shape[
                                                                                           -1]))
            seg_inputs['vg_text_embeddings'] = (ground_hs_flickr, padded_mask_flickr)
        # In this branch ``flickr_len`` holds the number of refcoco samples.
        flickr_len = len(seg_inputs['refcoco'])
        ##########flickr
        # if self.seg_model.model.coco_only:
        # ``vg_len`` only exists when 'vg' is present; the conditional keeps it from being read otherwise.
        ground_idx_flickr = ground_idx[vg_len:vg_len+flickr_len] if 'vg' in seg_inputs else ground_idx[:flickr_len]
        padded_ground_idx_flickr = torch.nn.utils.rnn.pad_sequence(ground_idx_flickr,
                                                                   batch_first=True,
                                                                   padding_value=-1)
        padded_mask_flickr = padded_ground_idx_flickr != -1
        padded_ground_idx_flickr[padded_ground_idx_flickr == -1] = 0
        # ground_idx=[[-1] if len(idx)==0 else idx for idx in ground_idx]
        hidden_states = outputs[0]
        hidden_states_flickr = hidden_states[vg_len:vg_len+flickr_len] if 'vg' in seg_inputs else hidden_states[:flickr_len]
        ground_hs_flickr = torch.gather(hidden_states_flickr, 1,
                                        padded_ground_idx_flickr[..., None].repeat(1, 1,
                                                                                   hidden_states_flickr.shape[
                                                                                       -1]))
        seg_inputs['refcoco_text_embeddings'] = (ground_hs_flickr, padded_mask_flickr)
        # seg_inputs['flickr']=seg_inputs['refcoco']
    else:
        flickr_len=len(seg_inputs['flickr'])
        ground_idx = [torch.argwhere(lb == 32002)[:, 0] for lb in labels]
        # NOTE(review): ``zero_mask`` is never read in this function.
        zero_mask = [0 if len(idx) == 0 else 1 for idx in ground_idx]
        ##########flickr
        # if self.seg_model.model.coco_only:
        ground_idx_flickr=ground_idx[:flickr_len]
        padded_ground_idx_flickr = torch.nn.utils.rnn.pad_sequence(ground_idx_flickr,
                                                            batch_first=True,
                                                            padding_value=-1)
        padded_mask_flickr=padded_ground_idx_flickr!=-1
        padded_ground_idx_flickr[padded_ground_idx_flickr==-1]=0
        # ground_idx=[[-1] if len(idx)==0 else idx for idx in ground_idx]
        hidden_states = outputs[0]
        hidden_states_flickr=hidden_states[:flickr_len]
        ground_hs_flickr=torch.gather(hidden_states_flickr,1,padded_ground_idx_flickr[...,None].repeat(1,1,hidden_states_flickr.shape[-1]))
        seg_inputs['flickr_text_embeddings']=(ground_hs_flickr,padded_mask_flickr)
        ##########coco
        # Everything after the flickr samples is treated as coco.
        ground_idx_coco = ground_idx[flickr_len:]
        if len(ground_idx_coco)>0:
            for i,(idx,data) in enumerate(zip(ground_idx_coco,seg_inputs['coco'])):
                # Keep only the token positions selected by the sample's own mask.
                mask=data['grounding_mask']
                ground_idx_coco[i]=idx[mask[:len(idx)]]
            padded_ground_idx_coco = torch.nn.utils.rnn.pad_sequence(ground_idx_coco,
                                                                       batch_first=True,
                                                                       padding_value=-1)
            padded_mask_coco = padded_ground_idx_coco != -1
            padded_ground_idx_coco[padded_ground_idx_coco == -1] = 0
            hidden_states = outputs[0]
            hidden_states_coco = hidden_states[flickr_len:]
            ground_hs_coco = torch.gather(hidden_states_coco, 1,
                                            padded_ground_idx_coco[..., None].repeat(1, 1,
                                                                                       hidden_states_coco.shape[
                                                                                           -1]))
            seg_inputs['coco_text_embeddings'] = (ground_hs_coco, padded_mask_coco)
    ground_loss=self.seg_model(seg_inputs)
    # With ``coco_only`` set and coco samples present, the LM loss is computed on the coco slice only.
    if self.seg_model.model.coco_only and len(ground_idx_coco)>0:
        logits = self.lm_head(hidden_states_coco)
    else:
        logits = self.lm_head(hidden_states)

    loss = None
    if labels is not None:
        # Shift so that tokens < n predict n
        shift_logits = logits[..., :-1, :].contiguous()
        if self.seg_model.model.coco_only and len(ground_idx_coco) > 0:
            # Labels are sliced the same way as the logits above.
            shift_labels = labels[..., 1:][flickr_len:].contiguous()
        else:
            shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        loss_fct = CrossEntropyLoss()
        shift_logits = shift_logits.view(-1, self.config.vocab_size)
        shift_labels = shift_labels.view(-1)
        # Enable model/pipeline parallelism
        shift_labels = shift_labels.to(shift_logits.device)
        loss = loss_fct(shift_logits, shift_labels)

    if not return_dict:
        # NOTE(review): this path returns without the grounding losses.
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output
    ground_loss['llava']=loss
    # Sum of every entry, including the 'llava' term just added.
    ground_loss['loss_total']=sum(ground_loss.values())
    return CausalLMOutputWithPast(
        loss=ground_loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
def auto_regressive_generate(self,
                             attention_mask,
                             past_key_values,
                             inputs_embeds,
                             output_attentions,
                             tokenizer,
                             return_dict,
                             temporature=0.0,
                             max_new_tokens=1000,
                             ):
    """Decode one token at a time, collecting hidden states at segmentation tokens.

    The first step feeds ``inputs_embeds``; later steps feed the previously
    chosen token id together with the key/value cache. Whenever the chosen
    id equals the segmentation token id, the last-layer hidden state of that
    position is recorded.

    Args:
        attention_mask: not forwarded on the first step and replaced by an
            all-ones mask of length ``cached_len + 1`` on later steps.
        past_key_values: initial cache handed to the first decoder call.
        inputs_embeds: prompt embeddings; indexing ``logits[0]`` means only
            the first batch row is decoded.
        output_attentions: forwarded to ``self.model``.
        tokenizer: provides ``encode`` (segmentation id) and ``decode`` (stop
            check).
        return_dict: forwarded to ``self.model``. The result is read through
            ``.hidden_states`` / ``.past_key_values``, so a plain tuple
            result would not work.
        temporature: sampling temperature (spelling kept for callers);
            values below ``1e-4`` select greedy argmax decoding.
        max_new_tokens: upper bound on the number of decoding steps.

    Returns:
        ``(output_ids, seg_token_list)``: the generated ids as a list of
        ints and the list of recorded hidden-state tensors.
    """
    # NOTE(review): both string literals in this function are empty in this
    # copy of the file. They look like angle-bracket tokens that were
    # stripped by an export step - restore the intended segmentation and
    # end-of-sequence markers from the upstream file before relying on this.
    seg_token = tokenizer.encode("")[1]
    seg_token_list = []
    output_ids = []
    # One device for everything created inside the loop.
    device = inputs_embeds.device
    for step in range(max_new_tokens):
        if step == 0:
            results = self.model(
                input_ids=None,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                use_cache=True,
                output_attentions=output_attentions,
                output_hidden_states=True,
                return_dict=return_dict
            )
        else:
            # The mask spans every cached position plus the token fed now.
            attention_mask = cur_hidden.new_ones(
                1, past_key_values[0][0].shape[-2] + 1, device=device)
            results = self.model(
                input_ids=torch.as_tensor([[cur_id]], device=device),
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                use_cache=True,
                output_attentions=output_attentions,
                output_hidden_states=True,
                return_dict=return_dict
            )
        cur_hidden = results.hidden_states[-1][:, -1:]  # last layer last token
        logits = self.lm_head(results[0])
        cur_logits = logits[0][-1]

        if temporature < 1e-4:
            cur_id = int(torch.argmax(cur_logits))
        else:
            probs = torch.softmax(cur_logits / temporature, dim=-1)
            cur_id = int(torch.multinomial(probs, num_samples=1))

        past_key_values = results.past_key_values
        if cur_id == seg_token:
            seg_token_list.append(cur_hidden)
        output_ids.append(cur_id)
        if tokenizer.decode(output_ids).find("") != -1:
            break
    return output_ids, seg_token_list
output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, images: Optional[torch.FloatTensor] = None, seg_inputs: Optional[torch.FloatTensor] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict _, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images) # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict ) ground_idx_coco=[] # if 'refcoco' in seg_inputs: flickr_len = len(seg_inputs['flickr']) ground_idx = [torch.argwhere(lb == 32002)[:, 0] for lb in labels] ##########flickr # if self.seg_model.model.coco_only: ground_idx_flickr = ground_idx[:flickr_len] padded_ground_idx_flickr = torch.nn.utils.rnn.pad_sequence(ground_idx_flickr, batch_first=True, padding_value=-1) padded_mask_flickr = padded_ground_idx_flickr != -1 padded_ground_idx_flickr[padded_ground_idx_flickr == -1] = 0 # ground_idx=[[-1] if len(idx)==0 else idx for idx in ground_idx] hidden_states = outputs[0] hidden_states_flickr = hidden_states[:flickr_len] ground_hs_flickr = torch.gather(hidden_states_flickr, 1, padded_ground_idx_flickr[..., None].repeat(1, 1, hidden_states_flickr.shape[ -1])) seg_inputs['flickr_text_embeddings'] = (ground_hs_flickr, padded_mask_flickr) # seg_inputs['flickr']=seg_inputs['refcoco'] # else: 
################################################# ################################################# refcoco_len=len(seg_inputs['refcoco']) ground_idx = [torch.argwhere(lb == 32002)[:, 0] for lb in labels] ##########flickr ground_idx_refcoco=ground_idx[flickr_len:flickr_len+refcoco_len] padded_ground_idx_refcoco = torch.nn.utils.rnn.pad_sequence(ground_idx_refcoco, batch_first=True, padding_value=-1) padded_mask_refcoco=padded_ground_idx_refcoco!=-1 padded_ground_idx_refcoco[padded_ground_idx_refcoco==-1]=0 # ground_idx=[[-1] if len(idx)==0 else idx for idx in ground_idx] # hidden_states = outputs[0] hidden_states_refcoco=hidden_states[flickr_len:flickr_len+refcoco_len] ground_hs_refcoco=torch.gather(hidden_states_refcoco,1,padded_ground_idx_refcoco[...,None].repeat(1,1,hidden_states_refcoco.shape[-1])) seg_inputs['refcoco_text_embeddings']=(ground_hs_refcoco,padded_mask_refcoco) ground_loss=self.seg_model(seg_inputs) # if self.seg_model.model.coco_only and len(ground_idx_coco)>0: # logits = self.lm_head(hidden_states_coco) # else: logits = self.lm_head(hidden_states) loss = None if labels is not None: # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() if self.seg_model.model.coco_only and len(ground_idx_coco) > 0: shift_labels = labels[..., 1:][flickr_len:].contiguous() else: shift_labels = labels[..., 1:].contiguous() # Flatten the tokens loss_fct = CrossEntropyLoss() shift_logits = shift_logits.view(-1, self.config.vocab_size) shift_labels = shift_labels.view(-1) # Enable model/pipeline parallelism shift_labels = shift_labels.to(shift_logits.device) loss = loss_fct(shift_logits, shift_labels) if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output ground_loss['llava']=loss ground_loss['loss_total']=sum(ground_loss.values()) return CausalLMOutputWithPast( loss=ground_loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, 
class LlavaLlamaForCausalLM_joint_2st(LlavaLlamaForCausalLM_gd):
    """Joint variant trained on ``flickr``, ``coco`` and ``llava`` groups in one batch."""

    def forward(self,**batched_inputs):
        """Collate the three groups into one batch and run ``forward_inner``.

        Args:
            **batched_inputs: must contain ``'flickr'``, ``'coco'`` and
                ``'llava'`` lists of samples. Each ``llava`` sample gets an
                ``'image_clip'`` entry copied from its ``'image'`` entry
                (in place). The tokenizer is taken from the first coco
                sample.

        Returns:
            Whatever ``forward_inner`` returns.
        """
        collator=DataCollatorForSupervisedDataset()
        assert 'coco' in batched_inputs and 'flickr' in batched_inputs and 'llava' in batched_inputs
        for data in batched_inputs['llava']:
            data['image_clip']=data['image']
        # Batch order is flickr, coco, llava; forward_inner slices by it.
        llava_inputs = collator( batched_inputs['flickr']+batched_inputs['coco']+batched_inputs['llava'],
                                 tokenizer=batched_inputs['coco'][0]['tokenizer'])
        llava_inputs['seg_inputs']=batched_inputs
        return self.forward_inner(**llava_inputs)

    def forward_inner(
            self,
            input_ids: torch.LongTensor = None,
            attention_mask: Optional[torch.Tensor] = None,
            past_key_values: Optional[List[torch.FloatTensor]] = None,
            inputs_embeds: Optional[torch.FloatTensor] = None,
            labels: Optional[torch.LongTensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            images: Optional[torch.FloatTensor] = None,
            seg_inputs: Optional[torch.FloatTensor] = None,
            return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the decoder, feed grounding-token features to ``self.seg_model``, merge losses.

        Positions where ``labels == 32002`` are located per sample; decoder
        hidden states at those positions are gathered for the flickr and
        coco slices and stored in ``seg_inputs`` as
        ``'flickr_text_embeddings'`` / ``'coco_text_embeddings'``
        (``(features, valid_mask)`` tuples). When
        ``self.seg_model.model.detach_seg`` is set, the features are taken
        from a detached copy of the hidden states.

        Args:
            input_ids, attention_mask, past_key_values, labels, images:
                passed to ``prepare_inputs_labels_for_multimodal``; the
                decoder then runs on the returned embeddings.
            inputs_embeds: overwritten by the multimodal preparation step.
            labels: iterated unconditionally, so it must not be ``None``.
            seg_inputs: dict of sample lists keyed by group name (despite
                the tensor annotation); mutated in place.
            use_cache, output_attentions, output_hidden_states, return_dict:
                forwarded to the decoder.

        Returns:
            ``CausalLMOutputWithPast`` whose ``loss`` is a dict holding
            ``'llava'``, ``'loss_total'`` and the seg-model entries whose key
            ends in ``'_0'``. ``'loss_total'`` is summed before the other
            entries are dropped, so it still includes them. With
            ``return_dict`` false the plain tuple is returned without the
            grounding losses.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        _, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(
            input_ids, attention_mask, past_key_values, labels, images)

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=None,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )
        flickr_len = len(seg_inputs['flickr'])
        # NOTE(review): 32002 is presumably the grounding token id - confirm.
        ground_idx = [torch.argwhere(lb == 32002)[:, 0] for lb in labels]

        ########## flickr
        ground_idx_flickr = ground_idx[:flickr_len]
        # Pad ragged index lists with -1, record the valid slots, then point
        # the padded slots at position 0 so gather stays in range.
        padded_ground_idx_flickr = torch.nn.utils.rnn.pad_sequence(ground_idx_flickr,
                                                                   batch_first=True,
                                                                   padding_value=-1)
        padded_mask_flickr = padded_ground_idx_flickr != -1
        padded_ground_idx_flickr[padded_ground_idx_flickr == -1] = 0
        if self.seg_model.model.detach_seg:
            hidden_states = outputs[0].detach()
        else:
            hidden_states = outputs[0]
        hidden_states_flickr = hidden_states[:flickr_len]
        ground_hs_flickr = torch.gather(
            hidden_states_flickr, 1,
            padded_ground_idx_flickr[..., None].repeat(1, 1, hidden_states_flickr.shape[-1]))
        seg_inputs['flickr_text_embeddings'] = (ground_hs_flickr, padded_mask_flickr)

        ########## coco
        coco_len = len(seg_inputs['coco'])
        ground_idx_coco = ground_idx[flickr_len:flickr_len+coco_len]
        if len(ground_idx_coco) > 0:
            for i, (idx, data) in enumerate(zip(ground_idx_coco, seg_inputs['coco'])):
                # Keep only the token positions selected by the sample's mask.
                mask = data['grounding_mask']
                ground_idx_coco[i] = idx[mask[:len(idx)]]
            padded_ground_idx_coco = torch.nn.utils.rnn.pad_sequence(ground_idx_coco,
                                                                     batch_first=True,
                                                                     padding_value=-1)
            padded_mask_coco = padded_ground_idx_coco != -1
            padded_ground_idx_coco[padded_ground_idx_coco == -1] = 0
            hidden_states_coco = hidden_states[flickr_len:flickr_len+coco_len]
            ground_hs_coco = torch.gather(
                hidden_states_coco, 1,
                padded_ground_idx_coco[..., None].repeat(1, 1, hidden_states_coco.shape[-1]))
            seg_inputs['coco_text_embeddings'] = (ground_hs_coco, padded_mask_coco)

        ground_loss = self.seg_model(seg_inputs)

        # The LM head always uses the non-detached hidden states.
        hidden_states_ = outputs[0]
        if self.seg_model.model.coco_only and len(ground_idx_coco) > 0:
            # Skip the flickr samples for the language-model loss.
            logits = self.lm_head(hidden_states_[flickr_len:])
        else:
            logits = self.lm_head(hidden_states_)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            if self.seg_model.model.coco_only and len(ground_idx_coco) > 0:
                shift_labels = labels[..., 1:][flickr_len:].contiguous()
            else:
                shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model/pipeline parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        ground_loss['llava']=loss
        ground_loss['loss_total']=sum(ground_loss.values())
        # Report only the '_0' entries plus the two aggregate keys.
        for key in list(ground_loss.keys()):
            if not key.endswith('_0') and key!='llava' and key !='loss_total':
                ground_loss.pop(key)
        return CausalLMOutputWithPast(
            loss=ground_loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
def forward(self, **batched_inputs):
    """Run the training forward pass on interactive data and merge the losses.

    The ``'interactive'`` samples are collated and pushed through
    ``forward_inner``; every key of the resulting loss dict is prefixed with
    ``'coco.'``.  When ``'interactiveref'`` samples are present, a second
    pass is run on them and its loss keys are prefixed with ``'refcoco.'``.

    Args:
        **batched_inputs: Must contain ``'interactive'`` (a non-empty
            sequence whose first element has a ``'tokenizer'`` entry) and
            may contain ``'interactiveref'``.

    Returns:
        The output object of the first pass.  Its ``loss`` is a dict holding
        the prefixed losses of every pass plus ``'loss_total'``, the sum of
        the per-pass ``loss_total`` entries.
    """
    collate = DataCollatorForSupervisedDataset()

    # --- first pass: the 'interactive' samples -------------------------
    model_inputs = collate(
        batched_inputs['interactive'],
        tokenizer=batched_inputs['interactive'][0]['tokenizer'],
    )
    model_inputs['seg_inputs'] = batched_inputs
    coco_out = self.forward_inner(**model_inputs)
    coco_losses = {'coco.' + name: value for name, value in coco_out.loss.items()}
    # Written through both the attribute and the mapping interface.
    coco_out.loss = coco_losses
    coco_out['loss'] = coco_losses

    if 'interactiveref' not in batched_inputs:
        coco_out.loss['loss_total'] = coco_out.loss['coco.loss_total']
        return coco_out

    # --- second pass: the 'interactiveref' samples ---------------------
    model_inputs = collate(
        batched_inputs['interactiveref'],
        tokenizer=batched_inputs['interactive'][0]['tokenizer'],
    )
    # forward_inner reads seg_inputs['interactive'], so the reference samples
    # are swapped into that slot before the second call.
    batched_inputs['interactive'] = batched_inputs['interactiveref']
    model_inputs['seg_inputs'] = batched_inputs
    ref_out = self.forward_inner(**model_inputs)
    ref_losses = {'refcoco.' + name: value for name, value in ref_out.loss.items()}
    ref_out.loss = ref_losses
    ref_out['loss'] = ref_losses

    coco_out.loss.update(ref_out.loss)
    coco_out.loss['loss_total'] = (
        coco_out.loss['coco.loss_total'] + coco_out.loss['refcoco.loss_total']
    )
    return coco_out
def forward_eval(self, batched_inputs):
    """Collate evaluation samples and dispatch to the matching eval routine.

    Args:
        batched_inputs: Sequence of sample dicts.  The first sample must
            provide ``'points'`` and ``'tokenizer'``; an optional
            ``'temporature'`` entry sets the sampling temperature
            (0 when absent).

    Returns:
        Whatever ``forward_inner_eval_interactive`` returns when the first
        sample's ``'points'`` is not ``None``, otherwise whatever
        ``forward_inner_eval`` returns.
    """
    first_sample = batched_inputs[0]
    has_interaction = first_sample["points"] is not None
    if has_interaction:
        print("Get Interactive Data.")
    else:
        print("Do not Get Interactive Data.")

    collate = DataCollatorForSupervisedDataset()
    llava_inputs = collate(batched_inputs, tokenizer=first_sample['tokenizer'])
    llava_inputs['seg_inputs'] = batched_inputs
    llava_inputs["temporature"] = (
        first_sample["temporature"] if "temporature" in first_sample else 0
    )

    if has_interaction:
        return self.forward_inner_eval_interactive(**llava_inputs)
    return self.forward_inner_eval(**llava_inputs)
attention_mask: Optional[torch.Tensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, images: Optional[torch.FloatTensor] = None, seg_inputs: Optional[torch.FloatTensor] = None, return_dict: Optional[bool] = None, temporature=0 ) -> Union[Tuple, CausalLMOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict #! extra interaction part boxes = seg_inputs[0]['points'] seg_inputs[0]['targets'] = [dict()] seg_inputs[0]['targets'][0]['points'] = boxes if seg_inputs[0]['mode_inter'].lower() == "click": seg_inputs[0]['targets'][0]['pb'] = boxes.new_tensor([0.0]) elif seg_inputs[0]['mode_inter'].lower() == "box": seg_inputs[0]['targets'][0]['pb'] = boxes.new_tensor([1.0]) seg_inputs[0]['targets'][0]['is_part'] = [0] inter_masks, _, obj_feats =self.interactive_model.forward(seg_inputs) num_it=len(seg_inputs) # _, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images, obj_feats=obj_feats,num_it=num_it) output_ids, seg_hidden_states = self.auto_regressive_generate(attention_mask, past_key_values, inputs_embeds, output_attentions, seg_inputs[0]["tokenizer"], return_dict, temporature) output_text = seg_inputs[0]["tokenizer"].batch_decode([output_ids], skip_special_tokens=True) if len(seg_hidden_states)==0: return output_text[0], [], None, inter_masks seg_tokens = torch.cat(seg_hidden_states, dim=1) padded_mask = seg_tokens.new_ones(seg_tokens.shape[:2]) > 0 predicted_boxes, 
def auto_regressive_generate(self,
                             attention_mask,
                             past_key_values,
                             inputs_embeds,
                             output_attentions,
                             tokenizer,
                             return_dict,
                             temporature=0.0,
                             max_new_tokens=1000,
                             seg_marker="<seg>",
                             eos_marker="</s>"):
    """Decode token ids one step at a time, collecting segmentation-token states.

    The first step feeds ``inputs_embeds`` to ``self.model``; every later
    step feeds only the previously chosen token id together with the cached
    key/values.  Only the first sequence of the batch is decoded
    (``logits[0][-1]``).

    Args:
        attention_mask: Not used for the first step; rebuilt as an all-ones
            mask of length ``cache_len + 1`` for each later step.
        past_key_values: Initial key/value cache passed to the first step.
        inputs_embeds: Prompt embeddings for the first step.
        output_attentions: Forwarded to ``self.model``.
        tokenizer: Provides ``encode`` (to find the segmentation token id)
            and ``decode`` (to detect the stop marker).
        return_dict: Forwarded to ``self.model``.  The result is accessed
            via ``.hidden_states`` / ``.past_key_values``, so it has to be
            an output object exposing those attributes.
        temporature: Sampling temperature; below ``1e-4`` decoding is
            greedy (argmax), otherwise tokens are sampled from the
            temperature-scaled softmax.
        max_new_tokens: Upper bound on the number of generated tokens.
        seg_marker: Text whose second encoded id is taken as the
            segmentation token id.
        eos_marker: Generation stops once this text shows up in the decoded
            output.

    Returns:
        ``(output_ids, seg_token_list)``: the generated ids, and the
        last-layer hidden state of every step that produced the
        segmentation token.
    """
    # NOTE(review): both marker literals were empty strings in the reviewed
    # text, which cannot work (`"".find("")` is always 0, so decoding would
    # stop after one token).  They are restored here to the presumed
    # `<seg>` / `</s>` markers -- confirm against the tokenizer's special
    # tokens.
    seg_token = tokenizer.encode(seg_marker)[1]
    seg_token_list = []
    output_ids = []
    cur_id = None
    cur_hidden = None
    for step in range(max_new_tokens):
        if step == 0:
            results = self.model(
                input_ids=None,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                use_cache=True,
                output_attentions=output_attentions,
                output_hidden_states=True,
                return_dict=return_dict
            )
        else:
            # Build the mask on the same device as the hidden states instead
            # of a hard-coded "cuda", so CPU and non-default devices work.
            attention_mask = cur_hidden.new_ones(
                1, past_key_values[0][0].shape[-2] + 1)
            results = self.model(
                input_ids=torch.as_tensor([[cur_id]], device=inputs_embeds.device),
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                use_cache=True,
                output_attentions=output_attentions,
                output_hidden_states=True,
                return_dict=return_dict
            )
        cur_hidden = results.hidden_states[-1][:, -1:]  # last layer, last token
        cur_logits = self.lm_head(results[0])[0][-1]
        if temporature < 1e-4:
            cur_id = int(torch.argmax(cur_logits))
        else:
            probs = torch.softmax(cur_logits / temporature, dim=-1)
            cur_id = int(torch.multinomial(probs, num_samples=1))
        past_key_values = results.past_key_values
        if cur_id == seg_token:
            seg_token_list.append(cur_hidden)
        output_ids.append(cur_id)
        if eos_marker in tokenizer.decode(output_ids):
            break
    return output_ids, seg_token_list
def __init__(self, config):
    """Build the LLaVA-wrapped MPT transformer and resolve the logit scale.

    Args:
        config: Model configuration.  ``tie_word_embeddings`` must be
            truthy.  ``logit_scale`` may be ``None``, a number, or the
            string ``'inv_sqrt_d_model'`` (resolved to
            ``1 / sqrt(config.d_model)``).

    Raises:
        ValueError: If word embeddings are not tied, or ``logit_scale`` is
            an unrecognized string.
    """
    # Skips MPTForCausalLM.__init__ itself and runs the next initializer in
    # the MRO; the transformer is built below as a LlavaMPTModel.
    super(MPTForCausalLM, self).__init__(config)
    if not config.tie_word_embeddings:
        raise ValueError('MPTForCausalLM only supports tied word embeddings')
    self.transformer = LlavaMPTModel(config)

    self.logit_scale = None
    scale = config.logit_scale
    if scale is None:
        return
    if isinstance(scale, str):
        if scale != 'inv_sqrt_d_model':
            raise ValueError(f"logit_scale={scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.")
        scale = 1 / math.sqrt(config.d_model)
    self.logit_scale = scale
def adapt_tokenizer_for_denoising(tokenizer: Tokenizer):
    """Adds sentinel tokens and padding token (if missing).

    Expands the tokenizer vocabulary to include sentinel tokens
    used in mixture-of-denoiser tasks as well as a padding token.

    All added tokens are added as special tokens. No tokens are
    added if sentinel tokens and padding token already exist.

    Side effects on ``tokenizer``:
        * ``add_tokens`` is called with ``NUM_SENTINEL_TOKENS`` sentinels
          named ``<extra_id_0>`` ... ``<extra_id_{N-1}>``;
        * when ``pad_token`` is ``None`` a ``<pad>`` token is added and set
          as the pad token;
        * ``sentinel_token_ids`` is set to the ids obtained by tokenizing
          the concatenated sentinels without special tokens.
    """
    # NOTE(review): the sentinel/pad literals were empty strings in the
    # reviewed text (an f-string without a placeholder), which would add 100
    # empty tokens; restored to the `<extra_id_{i}>` / `<pad>` convention --
    # confirm against the upstream MPT tokenizer adapter.
    sentinels_to_add = [f'<extra_id_{i}>' for i in range(NUM_SENTINEL_TOKENS)]
    tokenizer.add_tokens(sentinels_to_add, special_tokens=True)
    if tokenizer.pad_token is None:
        tokenizer.add_tokens('<pad>', special_tokens=True)
        tokenizer.pad_token = '<pad>'
        assert tokenizer.pad_token_id is not None
    # Reuse the list built above rather than formatting every sentinel twice.
    sentinels = ''.join(sentinels_to_add)
    _sentinel_token_ids = tokenizer(sentinels, add_special_tokens=False).input_ids
    tokenizer.sentinel_token_ids = _sentinel_token_ids
""" @classmethod def from_pretrained(cls, *args, **kwargs): """See `AutoTokenizer.from_pretrained` docstring.""" tokenizer = super().from_pretrained(*args, **kwargs) adapt_tokenizer_for_denoising(tokenizer) return tokenizer ================================================ FILE: llava/model/language_model/mpt/attention.py ================================================ """Attention layers.""" import math import warnings from typing import Optional import torch import torch.nn as nn from einops import rearrange from packaging import version from torch import nn from .norm import LPLayerNorm def _reset_is_causal(num_query_tokens: int, num_key_tokens: int, original_is_causal: bool): if original_is_causal and num_query_tokens != num_key_tokens: if num_query_tokens != 1: raise NotImplementedError('MPT does not support query and key with different number of tokens, unless number of query tokens is 1.') else: return False return original_is_causal def scaled_multihead_dot_product_attention(query, key, value, n_heads, past_key_value=None, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False): q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads) kv_n_heads = 1 if multiquery else n_heads k = rearrange(key, 'b s (h d) -> b h d s', h=kv_n_heads) v = rearrange(value, 'b s (h d) -> b h s d', h=kv_n_heads) if past_key_value is not None: if len(past_key_value) != 0: k = torch.cat([past_key_value[0], k], dim=3) v = torch.cat([past_key_value[1], v], dim=2) past_key_value = (k, v) (b, _, s_q, d) = q.shape s_k = k.size(-1) if softmax_scale is None: softmax_scale = 1 / math.sqrt(d) attn_weight = q.matmul(k) * softmax_scale if attn_bias is not None: _s_q = max(0, attn_bias.size(2) - s_q) _s_k = max(0, attn_bias.size(3) - s_k) attn_bias = attn_bias[:, :, _s_q:, _s_k:] if attn_bias.size(-1) != 1 and attn_bias.size(-1) != s_k or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q): raise 
def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]):
    """Validate that every tensor has an accepted dtype and lives on CUDA.

    Tensors are checked in order; the dtype check comes before the device
    check.

    Args:
        *tensors: Tensors to validate.
        valid_dtypes: Accepted dtypes (only read, never modified here).

    Raises:
        TypeError: On the first tensor whose dtype is not in
            ``valid_dtypes`` or that is not a CUDA tensor.
    """
    for candidate in tensors:
        dtype_ok = candidate.dtype in valid_dtypes
        if not dtype_ok:
            raise TypeError(f'tensor.dtype={candidate.dtype!r} must be in valid_dtypes={valid_dtypes!r}.')
        if candidate.is_cuda:
            continue
        raise TypeError(f'Inputs must be cuda tensors (tensor.is_cuda={candidate.is_cuda!r}).')
check_valid_inputs(query, key, value) if past_key_value is not None: if len(past_key_value) != 0: key = torch.cat([past_key_value[0], key], dim=1) value = torch.cat([past_key_value[1], value], dim=1) past_key_value = (key, value) if attn_bias is not None: _s_q = max(0, attn_bias.size(2) - query.size(1)) _s_k = max(0, attn_bias.size(3) - key.size(1)) attn_bias = attn_bias[:, :, _s_q:, _s_k:] if attn_bias is not None: raise NotImplementedError(f'attn_bias not implemented for flash attn.') (batch_size, seqlen) = query.shape[:2] if key_padding_mask is None: key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool) query_padding_mask = key_padding_mask[:, -query.size(1):] (query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input(query, query_padding_mask) query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads) (key_unpad, _, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input(key, key_padding_mask) key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=1 if multiquery else n_heads) (value_unpad, _, _, _) = bert_padding.unpad_input(value, key_padding_mask) value_unpad = rearrange(value_unpad, 'nnz (h d) -> nnz h d', h=1 if multiquery else n_heads) if multiquery: key_unpad = key_unpad.expand(key_unpad.size(0), n_heads, key_unpad.size(-1)) value_unpad = value_unpad.expand(value_unpad.size(0), n_heads, value_unpad.size(-1)) dropout_p = dropout_p if training else 0.0 reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal) output_unpad = flash_attn_interface.flash_attn_unpadded_func(query_unpad, key_unpad, value_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale=softmax_scale, causal=reset_is_causal, return_attn_probs=needs_weights) output = bert_padding.pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size, seqlen) return (output, None, past_key_value) def triton_flash_attn_fn(query, key, value, n_heads, past_key_value=None, 
softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False): try: from .flash_attn_triton import flash_attn_func except: _installed = False if version.parse(torch.__version__) < version.parse('2.0.0'): _installed = True try: from flash_attn.flash_attn_triton import flash_attn_func except: _installed = False if not _installed: raise RuntimeError('Requirements for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU and `pip install .[gpu]` if installing from llm-foundry source or `pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). Note: (1) requires you have CMake and PyTorch already installed.') check_valid_inputs(query, key, value) if past_key_value is not None: if len(past_key_value) != 0: key = torch.cat([past_key_value[0], key], dim=1) value = torch.cat([past_key_value[1], value], dim=1) past_key_value = (key, value) if attn_bias is not None: _s_q = max(0, attn_bias.size(2) - query.size(1)) _s_k = max(0, attn_bias.size(3) - key.size(1)) attn_bias = attn_bias[:, :, _s_q:, _s_k:] if dropout_p: raise NotImplementedError(f'Dropout not implemented for attn_impl: triton.') if needs_weights: raise NotImplementedError(f'attn_impl: triton cannot return attn weights.') if key_padding_mask is not None: warnings.warn('Propagating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unnecessary computation/memory usage. 
class MultiheadAttention(nn.Module):
    """Multi-head self attention.

    Using the torch or triton attention implementation enables the user to
    also use an additive bias.
    """

    def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, low_precision_layernorm: bool=False, verbose: int=0, device: Optional[str]=None):
        """Create the fused QKV projection, optional Q/K norms and output projection.

        Args:
            d_model: Model width; the fused projection maps it to ``3 * d_model``.
            n_heads: Number of attention heads.
            attn_impl: ``'flash'``, ``'triton'`` or ``'torch'``.
            clip_qkv: When truthy, the fused projection output is clamped to
                ``[-clip_qkv, clip_qkv]`` in ``forward``.
            qk_ln: Whether to layer-normalize queries and keys.
            softmax_scale: Defaults to ``1 / sqrt(d_model / n_heads)``.
            attn_pdrop: Dropout probability handed to the attention function.
            low_precision_layernorm: Use ``LPLayerNorm`` instead of
                ``nn.LayerNorm`` for the Q/K norms.
            verbose: When truthy, emit advisory warnings about the chosen
                implementation.
            device: Device passed to the created layers.

        Raises:
            ValueError: If ``attn_impl`` is not one of the supported values.
        """
        super().__init__()
        self.attn_impl = attn_impl
        self.clip_qkv = clip_qkv
        self.qk_ln = qk_ln
        self.d_model = d_model
        self.n_heads = n_heads
        if softmax_scale is None:
            softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
        self.softmax_scale = softmax_scale
        self.attn_dropout_p = attn_pdrop

        # Layers are created in a fixed order (Wqkv, q_ln, k_ln, out_proj).
        self.Wqkv = nn.Linear(self.d_model, 3 * self.d_model, device=device)
        self.Wqkv._fused = (0, (d_model, 2 * d_model))
        if self.qk_ln:
            norm_cls = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
            self.q_ln = norm_cls(self.d_model, device=device)
            self.k_ln = norm_cls(self.d_model, device=device)

        impl = self.attn_impl
        if impl == 'flash':
            self.attn_fn = flash_attn_fn
        elif impl == 'triton':
            self.attn_fn = triton_flash_attn_fn
            if verbose:
                warnings.warn(
                    'While `attn_impl: triton` can be faster than `attn_impl: flash` '
                    'it uses more memory. When training larger models this can trigger '
                    'alloc retries which hurts performance. If encountered, we recommend '
                    'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.'
                )
        elif impl == 'torch':
            self.attn_fn = scaled_multihead_dot_product_attention
            if torch.cuda.is_available() and verbose:
                warnings.warn(
                    'Using `attn_impl: torch`. If your model does not use `alibi` or '
                    '`prefix_lm` we recommend using `attn_impl: flash` otherwise '
                    'we recommend using `attn_impl: triton`.'
                )
        else:
            raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')

        self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
        self.out_proj._is_residual = True

    def forward(self, x, past_key_value=None, attn_bias=None, attention_mask=None, is_causal=True, needs_weights=False):
        """Project ``x`` to Q/K/V, run the configured attention, project the result.

        ``attention_mask`` is handed to the attention function as
        ``key_padding_mask``.

        Returns:
            ``(output, attn_weights, past_key_value)`` where ``output`` is
            the attention context passed through ``out_proj`` and the other
            two values come straight from the attention function.
        """
        projected = self.Wqkv(x)
        if self.clip_qkv:
            projected.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
        query, key, value = projected.chunk(3, dim=2)

        if self.qk_ln:
            # Cast back in case the norm layers change the dtype.
            input_dtype = query.dtype
            query = self.q_ln(query).to(input_dtype)
            key = self.k_ln(key).to(input_dtype)

        context, attn_weights, past_key_value = self.attn_fn(
            query,
            key,
            value,
            self.n_heads,
            past_key_value=past_key_value,
            softmax_scale=self.softmax_scale,
            attn_bias=attn_bias,
            key_padding_mask=attention_mask,
            is_causal=is_causal,
            dropout_p=self.attn_dropout_p,
            training=self.training,
            needs_weights=needs_weights,
        )
        return (self.out_proj(context), attn_weights, past_key_value)
""" def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, low_precision_layernorm: bool=False, verbose: int=0, device: Optional[str]=None): super().__init__() self.attn_impl = attn_impl self.clip_qkv = clip_qkv self.qk_ln = qk_ln self.d_model = d_model self.n_heads = n_heads self.head_dim = d_model // n_heads self.softmax_scale = softmax_scale if self.softmax_scale is None: self.softmax_scale = 1 / math.sqrt(self.head_dim) self.attn_dropout_p = attn_pdrop self.Wqkv = nn.Linear(d_model, d_model + 2 * self.head_dim, device=device) fuse_splits = (d_model, d_model + self.head_dim) self.Wqkv._fused = (0, fuse_splits) if self.qk_ln: layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm self.q_ln = layernorm_class(d_model, device=device) self.k_ln = layernorm_class(self.head_dim, device=device) if self.attn_impl == 'flash': self.attn_fn = flash_attn_fn elif self.attn_impl == 'triton': self.attn_fn = triton_flash_attn_fn if verbose: warnings.warn('While `attn_impl: triton` can be faster than `attn_impl: flash` ' + 'it uses more memory. When training larger models this can trigger ' + 'alloc retries which hurts performance. If encountered, we recommend ' + 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.') elif self.attn_impl == 'torch': self.attn_fn = scaled_multihead_dot_product_attention if torch.cuda.is_available() and verbose: warnings.warn('Using `attn_impl: torch`. 
def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id):
    """Return the shape of the attention-bias tensor, or ``None`` if none is needed.

    * ``'flash'`` never uses a bias.
    * For ``'torch'`` / ``'triton'`` with ``alibi``, the bias has one row per
      query position when ``prefix_lm`` or ``use_sequence_id`` is set or the
      model is not causal, otherwise a single broadcast row.
    * Without ``alibi`` a ``(1, 1, seq_len, seq_len)`` bias is only needed
      for ``prefix_lm`` or ``use_sequence_id``.

    Raises:
        ValueError: If ``attn_impl`` is not a supported implementation.
    """
    if attn_impl == 'flash':
        return None
    if attn_impl not in ('torch', 'triton'):
        raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')

    needs_square_mask = prefix_lm or use_sequence_id
    if alibi:
        query_rows = seq_len if (needs_square_mask or not causal) else 1
        return (1, n_heads, query_rows, seq_len)
    if needs_square_mask:
        return (1, 1, seq_len, seq_len)
    return None
def gen_slopes(n_heads, alibi_bias_max=8, device=None):
    """Build the per-head ALiBi slopes.

    The slopes are computed for the next power of two at or above ``n_heads``.
    When ``n_heads`` is not itself a power of two, the odd-indexed slopes are
    placed before the even-indexed ones and the first ``n_heads`` are kept.

    Args:
        n_heads: Number of attention heads.
        alibi_bias_max: Largest exponent used for the slopes.
        device: Device on which the tensor is created.

    Returns:
        A float32 tensor of shape ``(1, n_heads, 1, 1)``.
    """
    padded_heads = 2 ** math.ceil(math.log2(n_heads))
    exponents = torch.arange(1, padded_heads + 1, dtype=torch.float32, device=device)
    exponents = exponents * (alibi_bias_max / padded_heads)
    slopes = 1.0 / torch.pow(2, exponents)
    if padded_heads != n_heads:
        reordered = torch.cat([slopes[1::2], slopes[::2]])
        slopes = reordered[:n_heads]
    return slopes.view(1, n_heads, 1, 1)
def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
    """Run one block: attention and feed-forward, each behind a norm and a residual add.

    Args:
        x: Input hidden states.
        past_key_value: Cached key/value state, forwarded to the attention module.
        attn_bias: Optional attention bias, forwarded to the attention module.
        attention_mask: Optional mask, forwarded to the attention module.
        is_causal: Whether the attention module should apply causal masking.

    Returns:
        A 3-tuple ``(hidden_states, attn_weights, past_key_value)``, where the
        last two items are whatever the attention module returned.
    """
    # Attention sub-layer: normalise, attend, dropout, residual add.
    attn_out, attn_weights, present = self.attn(
        self.norm_1(x),
        past_key_value=past_key_value,
        attn_bias=attn_bias,
        attention_mask=attention_mask,
        is_causal=is_causal,
    )
    hidden = x + self.resid_attn_dropout(attn_out)

    # Feed-forward sub-layer: normalise, MLP, dropout, residual add.
    ffn_out = self.ffn(self.norm_2(hidden))
    hidden = hidden + self.resid_ffn_dropout(ffn_out)
    return (hidden, attn_weights, present)
class MPTConfig(PretrainedConfig):
    """HuggingFace-style configuration object for MPT models."""
    model_type = 'mpt'

    def __init__(self, d_model: int=2048, n_heads: int=16, n_layers: int=24, expansion_ratio: int=4, max_seq_len: int=2048, vocab_size: int=50368, resid_pdrop: float=0.0, emb_pdrop: float=0.0, learned_pos_emb: bool=True, attn_config: Dict=attn_config_defaults, init_device: str='cpu', logit_scale: Optional[Union[float, str]]=None, no_bias: bool=False, verbose: int=0, embedding_fraction: float=1.0, norm_type: str='low_precision_layernorm', use_cache: bool=False, init_config: Dict=init_config_defaults, **kwargs):
        """The MPT configuration class.

        ``attn_config`` and ``init_config`` are copied on entry, so the
        module-level default dictionaries and any dictionary supplied by the
        caller are never modified by this object. Passing ``None`` for either
        selects the defaults.

        Args:
            d_model (int): The size of the embedding dimension of the model.
            n_heads (int): The number of attention heads.
            n_layers (int): The number of layers in the model.
            expansion_ratio (int): The ratio of the up/down scale in the MLP.
            max_seq_len (int): The maximum sequence length of the model.
            vocab_size (int): The size of the vocabulary.
            resid_pdrop (float): The dropout probability applied to the attention output before combining with residual.
            emb_pdrop (float): The dropout probability for the embedding layer.
            learned_pos_emb (bool): Whether to use learned positional embeddings
            attn_config (Dict): A dictionary used to configure the model's attention module:
                attn_type (str): type of attention to use. Options: multihead_attention, multiquery_attention
                attn_pdrop (float): The dropout probability for the attention layers.
                attn_impl (str): The attention implementation to use. One of 'torch', 'flash', or 'triton'.
                qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer.
                clip_qkv (Optional[float]): If not None, clip the queries, keys, and values in the attention layer to
                    this value.
                softmax_scale (Optional[float]): If not None, scale the softmax in the attention layer by this value. If None,
                    use the default scale of ``1/sqrt(d_keys)``.
                prefix_lm (Optional[bool]): Whether the model should operate as a Prefix LM. This requires passing an
                    extra `prefix_mask` argument which indicates which tokens belong to the prefix. Tokens in the prefix
                    can attend to one another bi-directionally. Tokens outside the prefix use causal attention.
                attn_uses_sequence_id (Optional[bool]): Whether to restrict attention to tokens that have the same sequence_id.
                    When the model is in `train` mode, this requires passing an extra `sequence_id` argument which indicates
                    which sub-sequence each token belongs to.
                    Defaults to ``False`` meaning any provided `sequence_id` will be ignored.
                alibi (bool): Whether to use the alibi bias instead of position embeddings.
                alibi_bias_max (int): The maximum value of the alibi bias.
            init_device (str): The device to use for parameter initialization.
            logit_scale (Optional[Union[float, str]]): If not None, scale the logits by this value.
            no_bias (bool): Whether to use bias in all layers.
            verbose (int): The verbosity level. 0 is silent.
            embedding_fraction (float): The fraction to scale the gradients of the embedding layer by.
            norm_type (str): choose type of norm to use
            use_cache (bool): Whether or not the model should return the last key/values attentions
            init_config (Dict): A dictionary used to configure the model initialization:
                init_config.name: The parameter initialization scheme to use. Options: 'default_', 'baseline_',
                    'kaiming_uniform_', 'kaiming_normal_', 'neox_init_', 'small_init_', 'xavier_uniform_', or
                    'xavier_normal_'. These mimic the parameter initialization methods in PyTorch.
                init_div_is_residual (Union[int, float, str, bool]): Value to divide initial weights by if ``module._is_residual`` is True.
                emb_init_std (Optional[float]): The standard deviation of the normal distribution used to initialize the embedding layer.
                emb_init_uniform_lim (Optional[Union[Tuple[float, float], float]]): The lower and upper limits of the uniform distribution
                    used to initialize the embedding layer. Mutually exclusive with ``emb_init_std``.
                init_std (float): The standard deviation of the normal distribution used to initialize the model,
                    if using the baseline_ parameter initialization scheme.
                init_gain (float): The gain to use for parameter initialization with kaiming or xavier initialization schemes.
                fan_mode (str): The fan mode to use for parameter initialization with kaiming initialization schemes.
                init_nonlinearity (str): The nonlinearity to use for parameter initialization with kaiming initialization schemes.
                ---
                See llmfoundry.models.utils.param_init_fns.py for info on other param init config options
            **kwargs: Forwarded to ``PretrainedConfig.__init__`` after the
                ``name`` and ``loss_fn`` keys, if present, have been removed.

        Raises:
            ValueError: If a setting is out of range or not recognised.
            NotImplementedError: If ``prefix_lm``, ``alibi`` or
                ``attn_uses_sequence_id`` is requested together with an
                ``attn_impl`` other than 'torch' or 'triton'.
        """
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.expansion_ratio = expansion_ratio
        self.max_seq_len = max_seq_len
        self.vocab_size = vocab_size
        self.resid_pdrop = resid_pdrop
        self.emb_pdrop = emb_pdrop
        self.learned_pos_emb = learned_pos_emb
        # Copy instead of aliasing: the default argument is the shared
        # module-level dict, and _validate_config() fills in missing keys in
        # place. Without the copy, editing one config's attn_config would
        # silently change the defaults seen by every later config.
        self.attn_config = dict(attn_config_defaults if attn_config is None else attn_config)
        self.init_device = init_device
        self.logit_scale = logit_scale
        self.no_bias = no_bias
        self.verbose = verbose
        self.embedding_fraction = embedding_fraction
        self.norm_type = norm_type
        self.use_cache = use_cache
        self.init_config = dict(init_config_defaults if init_config is None else init_config)
        # These two keys are dropped before the remaining kwargs reach PretrainedConfig.
        kwargs.pop('name', None)
        kwargs.pop('loss_fn', None)
        super().__init__(**kwargs)
        self._validate_config()

    def _set_config_defaults(self, config, config_defaults):
        """Fill keys missing from ``config`` with the values in ``config_defaults``.

        ``config`` is updated in place and also returned.
        """
        for (k, v) in config_defaults.items():
            if k not in config:
                config[k] = v
        return config

    def _validate_config(self):
        """Apply sub-config defaults and reject inconsistent settings.

        Raises:
            ValueError: If a value is out of range or not recognised.
            NotImplementedError: If a feature is requested with an attention
                implementation that this check does not allow for it.
        """
        self.attn_config = self._set_config_defaults(self.attn_config, attn_config_defaults)
        self.init_config = self._set_config_defaults(self.init_config, init_config_defaults)
        if self.d_model % self.n_heads != 0:
            raise ValueError('d_model must be divisible by n_heads')
        if any((prob < 0 or prob > 1 for prob in [self.attn_config['attn_pdrop'], self.resid_pdrop, self.emb_pdrop])):
            raise ValueError("self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1")
        if self.attn_config['attn_impl'] not in ['torch', 'flash', 'triton']:
            raise ValueError(f"Unknown attn_impl={self.attn_config['attn_impl']}")
        if self.attn_config['prefix_lm'] and self.attn_config['attn_impl'] not in ['torch', 'triton']:
            raise NotImplementedError('prefix_lm only implemented with torch and triton attention.')
        if self.attn_config['alibi'] and self.attn_config['attn_impl'] not in ['torch', 'triton']:
            raise NotImplementedError('alibi only implemented with torch and triton attention.')
        if self.attn_config['attn_uses_sequence_id'] and self.attn_config['attn_impl'] not in ['torch', 'triton']:
            raise NotImplementedError('attn_uses_sequence_id only implemented with torch and triton attention.')
        if self.embedding_fraction > 1 or self.embedding_fraction <= 0:
            raise ValueError('model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!')
        if isinstance(self.logit_scale, str) and self.logit_scale != 'inv_sqrt_d_model':
            raise ValueError(f"self.logit_scale={self.logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.")
        if self.init_config.get('name', None) is None:
            raise ValueError(f"self.init_config={self.init_config!r} 'name' needs to be set.")
        if not self.learned_pos_emb and (not self.attn_config['alibi']):
            raise ValueError('Positional information must be provided to the model using either learned_pos_emb or alibi.')
class SharedEmbedding(nn.Embedding):
    """Embedding table whose weight can also be used as an output projection."""

    def forward(self, input: Tensor, unembed: bool=False) -> Tensor:
        """Look up embeddings, or project with the embedding weight.

        Args:
            input: Indices to embed, or, when ``unembed`` is true, a tensor
                that is passed to ``F.linear`` with ``self.weight``.
            unembed: Selects the projection path instead of the lookup.

        Returns:
            The embedded indices, or ``F.linear(input, self.weight)``.
        """
        if not unembed:
            return super().forward(input)
        # No bias: the projection reuses only the embedding weight.
        return F.linear(input, self.weight)
# Forward attention kernel.
#
# Each program instance handles one BLOCK_M-row tile of queries
# (program_id(0)) for one (batch, head) pair (program_id(1)), walks over the
# keys/values in BLOCK_N-sized tiles, and writes the output tile to Out and
# the per-row log-sum-exp to Lse.
#
# EVEN_M / EVEN_N / EVEN_HEADDIM come from the heuristics decorator below:
# they are true when seqlen_q is a multiple of BLOCK_M, seqlen_k is a multiple
# of BLOCK_N, and headdim equals BLOCK_HEADDIM. They select unmasked loads and
# stores.
#
# BIAS_TYPE is compared against 'none', 'vector' (bias indexed by key position
# only) and 'matrix' (bias indexed by query and key position).
#
# offs_d is added to every pointer without a stride, so the head dimension of
# Q, K, V and Out is addressed with unit stride.
@triton.heuristics({'EVEN_M': lambda args: args['seqlen_q'] % args['BLOCK_M'] == 0, 'EVEN_N': lambda args: args['seqlen_k'] % args['BLOCK_N'] == 0, 'EVEN_HEADDIM': lambda args: args['headdim'] == args['BLOCK_HEADDIM']})
@triton.jit
def _fwd_kernel(Q, K, V, Bias, Out, Lse, TMP, softmax_scale, stride_qb, stride_qh, stride_qm, stride_kb, stride_kh, stride_kn, stride_vb, stride_vh, stride_vn, stride_bb, stride_bh, stride_bm, stride_ob, stride_oh, stride_om, nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim, CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K, BIAS_TYPE: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_HEADDIM: tl.constexpr, EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # Locate this program's query tile and its (batch, head) pair.
    start_m = tl.program_id(0)
    off_hb = tl.program_id(1)
    off_b = off_hb // nheads
    off_h = off_hb % nheads
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    # Tile pointers; the K/V/bias pointers refer to the first key tile and are
    # offset by start_n inside the loop.
    q_ptrs = Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])
    k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])
    v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])
    if BIAS_TYPE == 'vector':
        b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n
    elif BIAS_TYPE == 'matrix':
        b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + (offs_m[:, None] * stride_bm + offs_n[None, :])
    # Per-row scratch slots in TMP, used for the store/load round trips below.
    t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m
    # Running log-sum-exp, running maximum, and the float32 output accumulator.
    lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')
    acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
    # Load the query tile once; out-of-range rows/columns are read as 0.
    if EVEN_M & EVEN_N:
        if EVEN_HEADDIM:
            q = tl.load(q_ptrs)
        else:
            q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
    elif EVEN_HEADDIM:
        q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
    else:
        q = tl.load(q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0)
    # With causal masking, key tiles that start past the end of this query
    # tile are not visited.
    end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
    for start_n in range(0, end_n, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        if EVEN_N & EVEN_M:
            if EVEN_HEADDIM:
                k = tl.load(k_ptrs + start_n * stride_kn)
            else:
                k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0)
        elif EVEN_HEADDIM:
            k = tl.load(k_ptrs + start_n * stride_kn, mask=(start_n + offs_n)[:, None] < seqlen_k, other=0.0)
        else:
            k = tl.load(k_ptrs + start_n * stride_kn, mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0)
        # Unscaled scores q @ k^T for this tile.
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k, trans_b=True)
        # Key positions past seqlen_k get -inf so they contribute exp(-inf) = 0.
        if not EVEN_N:
            qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float('-inf'))
        # Causal mask: a query row only keeps keys at positions <= its own.
        if IS_CAUSAL:
            qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float('-inf'))
        if BIAS_TYPE != 'none':
            if BIAS_TYPE == 'vector':
                if EVEN_N:
                    bias = tl.load(b_ptrs + start_n).to(tl.float32)
                else:
                    bias = tl.load(b_ptrs + start_n, mask=start_n + offs_n < seqlen_k, other=0.0).to(tl.float32)
                bias = bias[None, :]
            elif BIAS_TYPE == 'matrix':
                if EVEN_M & EVEN_N:
                    bias = tl.load(b_ptrs + start_n).to(tl.float32)
                else:
                    bias = tl.load(b_ptrs + start_n, mask=(offs_m[:, None] < seqlen_q) & ((start_n + offs_n)[None, :] < seqlen_k), other=0.0).to(tl.float32)
            # With a bias the scores are scaled first and the bias is added
            # unscaled.
            qk = qk * softmax_scale + bias
            m_ij = tl.maximum(tl.max(qk, 1), lse_i)
            p = tl.exp(qk - m_ij[:, None])
        else:
            # Without a bias the scale is applied to the row maximum and
            # inside the exponent instead of to the whole tile up front.
            m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
            p = tl.exp(qk * softmax_scale - m_ij[:, None])
        l_ij = tl.sum(p, 1)
        # Rescale the accumulator from the previous reference value m_i to m_ij.
        acc_o_scale = tl.exp(m_i - m_ij)
        # NOTE(review): the value is written to TMP and read straight back;
        # presumably a compiler workaround - confirm against upstream.
        tl.store(t_ptrs, acc_o_scale)
        acc_o_scale = tl.load(t_ptrs)
        acc_o = acc_o * acc_o_scale[:, None]
        if EVEN_N & EVEN_M:
            if EVEN_HEADDIM:
                v = tl.load(v_ptrs + start_n * stride_vn)
            else:
                v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0)
        elif EVEN_HEADDIM:
            v = tl.load(v_ptrs + start_n * stride_vn, mask=(start_n + offs_n)[:, None] < seqlen_k, other=0.0)
        else:
            v = tl.load(v_ptrs + start_n * stride_vn, mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0)
        # Cast the probabilities to V's dtype for the dot product.
        p = p.to(v.dtype)
        acc_o += tl.dot(p, v)
        # Update the running reference value and the running log-sum-exp.
        m_i = m_ij
        l_i_new = tl.exp(lse_i - m_ij) + l_ij
        lse_i = m_ij + tl.log(l_i_new)
    # Final rescale of the accumulator by exp(m_i - lse_i).
    o_scale = tl.exp(m_i - lse_i)
    # NOTE(review): same store/load round trip through TMP as inside the loop.
    tl.store(t_ptrs, o_scale)
    o_scale = tl.load(t_ptrs)
    acc_o = acc_o * o_scale[:, None]
    # start_m / offs_m / offs_d are recomputed here with the same expressions
    # as at the top of the kernel.
    start_m = tl.program_id(0)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # The LSE store is unmasked and indexed with seqlen_q_rounded.
    # NOTE(review): presumably Lse holds seqlen_q_rounded entries per
    # (batch, head) so this stays in bounds - confirm against the caller.
    lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m
    tl.store(lse_ptrs, lse_i)
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    out_ptrs = Out + off_b * stride_ob + off_h * stride_oh + (offs_m[:, None] * stride_om + offs_d[None, :])
    # Write the output tile, masking rows past seqlen_q and columns past headdim.
    if EVEN_M:
        if EVEN_HEADDIM:
            tl.store(out_ptrs, acc_o)
        else:
            tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim)
    elif EVEN_HEADDIM:
        tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)
    else:
        tl.store(out_ptrs, acc_o, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim))
# Store the dK and dV tiles through dk_ptrs / dv_ptrs.
#
# The EVEN_* compile-time flags choose the mask:
#   * EVEN_N and EVEN_M both true: no row mask; columns are masked against
#     headdim unless EVEN_HEADDIM is true.
#   * otherwise: rows are masked with offs_n < seqlen_k, and columns are also
#     masked against headdim unless EVEN_HEADDIM is true.
# dV is always stored before dK.
@triton.jit
def _bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim, EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr):
    if EVEN_N & EVEN_M:
        if EVEN_HEADDIM:
            # Fully aligned tile: unmasked stores.
            tl.store(dv_ptrs, dv)
            tl.store(dk_ptrs, dk)
        else:
            # Only the head dimension is padded.
            tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim)
            tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim)
    elif EVEN_HEADDIM:
        # Only the key rows need masking.
        tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k)
        tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k)
    else:
        # Mask both key rows and head-dimension columns.
        tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
        tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
tl.constexpr): begin_m = 0 if not IS_CAUSAL else start_n * BLOCK_N // BLOCK_M * BLOCK_M offs_qm = begin_m + tl.arange(0, BLOCK_M) offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) offs_m = tl.arange(0, BLOCK_M) offs_d = tl.arange(0, BLOCK_HEADDIM) q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_d[None, :]) k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :]) v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :]) do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :]) dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :]) if BIAS_TYPE == 'vector': b_ptrs = Bias + offs_n elif BIAS_TYPE == 'matrix': b_ptrs = Bias + (offs_qm[:, None] * stride_bm + offs_n[None, :]) dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) if begin_m >= seqlen_q: dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :]) dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :]) _bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim, EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM) return if EVEN_N & EVEN_M: if EVEN_HEADDIM: k = tl.load(k_ptrs) v = tl.load(v_ptrs) else: k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0) v = tl.load(v_ptrs, mask=offs_d[None, :] < headdim, other=0.0) elif EVEN_HEADDIM: k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0) v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0) else: k = tl.load(k_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0) v = tl.load(v_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0) num_block_m = tl.cdiv(seqlen_q, BLOCK_M) for start_m in range(begin_m, num_block_m * BLOCK_M, BLOCK_M): start_m = tl.multiple_of(start_m, BLOCK_M) offs_m_curr = start_m + offs_m if EVEN_M & EVEN_HEADDIM: q = tl.load(q_ptrs) elif EVEN_HEADDIM: q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0) else: q = 
tl.load(q_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0) qk = tl.dot(q, k, trans_b=True) if not EVEN_N: qk = tl.where(offs_n[None, :] < seqlen_k, qk, float('-inf')) if IS_CAUSAL: qk = tl.where(offs_m_curr[:, None] >= offs_n[None, :], qk, float('-inf')) if BIAS_TYPE != 'none': tl.debug_barrier() if BIAS_TYPE == 'vector': if EVEN_N: bias = tl.load(b_ptrs).to(tl.float32) else: bias = tl.load(b_ptrs, mask=offs_n < seqlen_k, other=0.0).to(tl.float32) bias = bias[None, :] elif BIAS_TYPE == 'matrix': if EVEN_M & EVEN_N: bias = tl.load(b_ptrs).to(tl.float32) else: bias = tl.load(b_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k), other=0.0).to(tl.float32) qk = qk * softmax_scale + bias if not EVEN_M & EVEN_HEADDIM: tl.debug_barrier() lse_i = tl.load(LSE + offs_m_curr) if BIAS_TYPE == 'none': p = tl.exp(qk * softmax_scale - lse_i[:, None]) else: p = tl.exp(qk - lse_i[:, None]) if EVEN_M & EVEN_HEADDIM: do = tl.load(do_ptrs) else: do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0) dv += tl.dot(p.to(do.dtype), do, trans_a=True) if not EVEN_M & EVEN_HEADDIM: tl.debug_barrier() dp = tl.dot(do, v, trans_b=True) if not EVEN_HEADDIM: tl.debug_barrier() Di = tl.load(D + offs_m_curr) ds = (p * (dp - Di[:, None]) * softmax_scale).to(q.dtype) dk += tl.dot(ds, q, trans_a=True) if not EVEN_M & EVEN_HEADDIM: tl.debug_barrier() if not ATOMIC_ADD: if EVEN_M & EVEN_HEADDIM: dq = tl.load(dq_ptrs, eviction_policy='evict_last') dq += tl.dot(ds, k) tl.store(dq_ptrs, dq, eviction_policy='evict_last') elif EVEN_HEADDIM: dq = tl.load(dq_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0, eviction_policy='evict_last') dq += tl.dot(ds, k) tl.store(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q, eviction_policy='evict_last') else: dq = tl.load(dq_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0, eviction_policy='evict_last') dq += tl.dot(ds, 
def init_to_zero(name):
    """Return a hook that zeroes the argument called ``name`` in place.

    The returned callable takes a mapping of argument names to objects,
    calls ``zero_()`` on the entry stored under ``name`` and returns the
    result of that call.
    """
    def _zero_named_arg(nargs):
        return nargs[name].zero_()

    return _zero_named_arg
def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
    """Check the inputs, allocate the outputs and launch ``_fwd_kernel``.

    Args:
        q: Query tensor, unpacked as (batch, seqlen_q, nheads, d).
        k: Key tensor; must have shape (batch, seqlen_k, nheads, d).
        v: Value tensor; must have the same shape as ``k``.
        bias: Optional 4-D bias whose last two dimensions are either
            (1, seqlen_k) or (seqlen_q, seqlen_k).
        causal: Passed to the kernel as its causal flag.
        softmax_scale: Score scale; a falsy value selects ``1 / sqrt(d)``.

    Returns:
        ``(o, lse, softmax_scale)``: the output tensor (same shape and dtype as
        ``q``), the float32 log-sum-exp buffer of shape
        (batch, nheads, seqlen_q rounded up to a multiple of 128), and the
        scale that was actually used.

    Raises:
        AssertionError: If a shape, dtype or device check fails.
        RuntimeError: If the last two dimensions of ``bias`` are unsupported.
    """
    batch, seqlen_q, nheads, d = q.shape
    _, seqlen_k, _, _ = k.shape
    expected_kv_shape = (batch, seqlen_k, nheads, d)
    assert k.shape == expected_kv_shape
    assert v.shape == expected_kv_shape
    assert d <= 128, 'FlashAttention only support head dimensions up to 128'
    assert q.dtype == k.dtype == v.dtype, 'All tensors must have the same type'
    assert q.dtype in [torch.float16, torch.bfloat16], 'Only support fp16 and bf16'
    assert q.is_cuda and k.is_cuda and v.is_cuda
    if not softmax_scale:
        softmax_scale = 1.0 / math.sqrt(d)

    if bias is None:
        bias_type = 'none'
        bias_strides = (0, 0, 0)
    else:
        assert bias.dtype in [q.dtype, torch.float]
        assert bias.is_cuda
        assert bias.dim() == 4
        if bias.stride(-1) != 1:
            bias = bias.contiguous()
        trailing_dims = bias.shape[2:]
        if trailing_dims == (1, seqlen_k):
            bias_type = 'vector'
        elif trailing_dims == (seqlen_q, seqlen_k):
            bias_type = 'matrix'
        else:
            raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k) or (seqlen_q, seqlen_k)')
        # Broadcast to the full shape so the strides passed to the kernel are
        # those of the expanded view.
        bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
        bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2))

    # The LSE and scratch buffers are sized with seqlen_q rounded up to 128.
    seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
    stats_shape = (batch, nheads, seqlen_q_rounded)
    lse = torch.empty(stats_shape, device=q.device, dtype=torch.float32)
    tmp = torch.empty(stats_shape, device=q.device, dtype=torch.float32)
    o = torch.empty_like(q)

    BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
    BLOCK = 128
    num_warps = 8 if d > 64 else 4

    def grid(META):
        # One program per query tile and per (batch, head) pair.
        return (triton.cdiv(seqlen_q, META['BLOCK_M']), batch * nheads)

    # Strides are passed in (batch, head, sequence) order, i.e. dims 0, 2, 1
    # of the (batch, seqlen, nheads, d) tensors.
    _fwd_kernel[grid](
        q, k, v, bias, o, lse, tmp, softmax_scale,
        q.stride(0), q.stride(2), q.stride(1),
        k.stride(0), k.stride(2), k.stride(1),
        v.stride(0), v.stride(2), v.stride(1),
        *bias_strides,
        o.stride(0), o.stride(2), o.stride(1),
        nheads, seqlen_q, seqlen_k, seqlen_q_rounded, d,
        seqlen_q // 32, seqlen_k // 32,
        bias_type, causal, BLOCK_HEADDIM,
        BLOCK_M=BLOCK, BLOCK_N=BLOCK, num_warps=num_warps, num_stages=1,
    )
    return (o, lse, softmax_scale)
class FlashAttnQKVPackedFunc(torch.autograd.Function):
    """Autograd wrapper around the Triton attention kernels for packed QKV input."""

    @staticmethod
    def forward(ctx, qkv, bias=None, causal=False, softmax_scale=None):
        """
            qkv: (batch, seqlen, 3, nheads, headdim)
            bias: optional, shape broadcastible to (batch, nheads, seqlen, seqlen).
                For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen).
                ALiBi mask for non-causal would have shape (1, nheads, seqlen, seqlen)
        """
        # Make the innermost dimension unit-stride before slicing out q, k, v.
        if qkv.stride(-1) != 1:
            qkv = qkv.contiguous()
        q, k, v = (qkv[:, :, i] for i in range(3))
        out, lse, used_scale = _flash_attn_forward(q, k, v, bias=bias, causal=causal, softmax_scale=softmax_scale)
        # Keep the scale that was actually used and the causal flag for backward.
        ctx.softmax_scale = used_scale
        ctx.save_for_backward(qkv, out, lse, bias)
        ctx.causal = causal
        return out

    @staticmethod
    def backward(ctx, do):
        """Compute the gradient with respect to ``qkv``; the other inputs get ``None``."""
        qkv, out, lse, bias = ctx.saved_tensors
        assert not ctx.needs_input_grad[1], 'FlashAttention does not support bias gradient yet'
        with torch.inference_mode():
            dqkv = torch.empty_like(qkv)
            q, k, v = (qkv[:, :, i] for i in range(3))
            # dq/dk/dv are views of dqkv, so the backward pass fills dqkv in place.
            dq, dk, dv = (dqkv[:, :, i] for i in range(3))
            _flash_attn_backward(do, q, k, v, out, lse, dq, dk, dv, bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale)
        return (dqkv, None, None, None)
ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k) """ (q, kv) = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, kv]] (o, lse, ctx.softmax_scale) = _flash_attn_forward(q, kv[:, :, 0], kv[:, :, 1], bias=bias, causal=causal, softmax_scale=softmax_scale) ctx.save_for_backward(q, kv, o, lse, bias) ctx.causal = causal return o @staticmethod def backward(ctx, do): (q, kv, o, lse, bias) = ctx.saved_tensors if len(ctx.needs_input_grad) >= 3: assert not ctx.needs_input_grad[2], 'FlashAttention does not support bias gradient yet' with torch.inference_mode(): dq = torch.empty_like(q) dkv = torch.empty_like(kv) _flash_attn_backward(do, q, kv[:, :, 0], kv[:, :, 1], o, lse, dq, dkv[:, :, 0], dkv[:, :, 1], bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale) return (dq, dkv, None, None, None) flash_attn_kvpacked_func = FlashAttnKVPackedFunc.apply class FlashAttnFunc(torch.autograd.Function): @staticmethod def forward(ctx, q, k, v, bias=None, causal=False, softmax_scale=None): """ q: (batch_size, seqlen_q, nheads, headdim) k, v: (batch_size, seqlen_k, nheads, headdim) bias: optional, shape broadcastible to (batch, nheads, seqlen_q, seqlen_k). For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k). 
ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k) """ (q, k, v) = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v]] (o, lse, ctx.softmax_scale) = _flash_attn_forward(q, k, v, bias=bias, causal=causal, softmax_scale=softmax_scale) ctx.save_for_backward(q, k, v, o, lse, bias) ctx.causal = causal return o @staticmethod def backward(ctx, do): (q, k, v, o, lse, bias) = ctx.saved_tensors assert not ctx.needs_input_grad[3], 'FlashAttention does not support bias gradient yet' with torch.inference_mode(): dq = torch.empty_like(q) dk = torch.empty_like(k) dv = torch.empty_like(v) _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale) return (dq, dk, dv, None, None, None) flash_attn_func = FlashAttnFunc.apply ================================================ FILE: llava/model/language_model/mpt/hf_prefixlm_converter.py ================================================ """Converts Huggingface Causal LM to Prefix LM. Conversion does lightweight surgery on a HuggingFace Causal LM to convert it to a Prefix LM. Prefix LMs accepts a `bidirectional_mask` input in `forward` and treat the input prompt as the prefix in `generate`. 
def _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_TYPES:
    """Converts a GPT-style Causal LM to a Prefix LM.

    Supported HuggingFace model classes:
        - `GPT2LMHeadModel`
        - `GPTNeoForCausalLM`
        - `GPTNeoXForCausalLM`
        - `GPTJForCausalLM`

    The model is modified in place: its `forward` and `generate` are replaced
    by wrappers, the originals are kept as `_original_forward` and
    `_original_generate`, and `_prefix_lm_converted` is set so that a second
    call returns the model unchanged.

    See `convert_hf_causal_lm_to_prefix_lm` for more details.
    """
    if hasattr(model, '_prefix_lm_converted'):
        return model
    assert isinstance(model, _SUPPORTED_GPT_MODELS)
    assert model.config.add_cross_attention == False, 'Only supports GPT-style decoder-only models'

    def _get_attn_modules(model: CAUSAL_GPT_TYPES) -> List[torch.nn.Module]:
        """Helper that gets a list of the model's attention modules.

        Each module has a `bias` buffer used for causal masking. The Prefix LM
        conversion adds logic to dynamically manipulate these biases to support
        Prefix LM attention masking.
        """
        attn_modules = []
        if isinstance(model, GPTNeoXForCausalLM):
            blocks = model.gpt_neox.layers
        else:
            blocks = model.transformer.h
        for block in blocks:
            if isinstance(model, GPTNeoForCausalLM):
                # Only "global" GPT-Neo layers are converted; others are skipped.
                if block.attn.attention_type != 'global':
                    continue
                attn_module = block.attn.attention
            elif isinstance(model, GPTNeoXForCausalLM):
                attn_module = block.attention
            else:
                attn_module = block.attn
            attn_modules.append(attn_module)
        return attn_modules

    def _restore_causal_masks(attn_modules: List[torch.nn.Module]) -> None:
        """Resets every `bias` buffer to the lower triangle of its first mask."""
        for attn_module in attn_modules:
            attn_module.bias.data = torch.tril(attn_module.bias.data[0, 0])[None, None]

    setattr(model, '_original_forward', getattr(model, 'forward'))
    setattr(model, '_original_generate', getattr(model, 'generate'))

    def forward(self: CAUSAL_GPT_TYPES, input_ids: Optional[torch.LongTensor]=None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]]=None, attention_mask: Optional[torch.FloatTensor]=None, bidirectional_mask: Optional[torch.Tensor]=None, token_type_ids: Optional[torch.LongTensor]=None, position_ids: Optional[torch.LongTensor]=None, head_mask: Optional[torch.FloatTensor]=None, inputs_embeds: Optional[torch.FloatTensor]=None, labels: Optional[torch.LongTensor]=None, use_cache: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, return_dict: Optional[bool]=None):
        """Wraps original forward to enable PrefixLM attention."""

        def call_og_forward():
            # The GPT-NeoX call omits token_type_ids and position_ids.
            if isinstance(self, GPTNeoXForCausalLM):
                return self._original_forward(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, head_mask=head_mask, inputs_embeds=inputs_embeds, labels=labels, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict)
            else:
                return self._original_forward(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds, labels=labels, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict)

        if bidirectional_mask is None:
            return call_og_forward()
        assert isinstance(bidirectional_mask, torch.Tensor)
        attn_modules = _get_attn_modules(model)
        (b, s) = bidirectional_mask.shape
        max_length = attn_modules[0].bias.shape[-1]
        if s > max_length:
            raise ValueError(f'bidirectional_mask sequence length (={s}) exceeds the ' + f'max length allowed by the model ({max_length}).')
        if s < max_length:
            # Right-pad with zeros so the mask spans the full bias width.
            pad = torch.zeros((int(b), int(max_length - s)), dtype=bidirectional_mask.dtype, device=bidirectional_mask.device)
            bidirectional_mask = torch.cat([bidirectional_mask, pad], dim=1)
        bidirectional = bidirectional_mask.unsqueeze(1).unsqueeze(1)
        try:
            for attn_module in attn_modules:
                attn_module.bias.data = torch.logical_or(attn_module.bias.data, bidirectional)
            return call_og_forward()
        finally:
            # Restore even when the wrapped forward raises, so a failed batch
            # cannot leave the model with non-causal masks.
            _restore_causal_masks(attn_modules)

    def generate(self: CAUSAL_GPT_TYPES, *args: tuple, **kwargs: Dict[str, Any]):
        """Wraps original generate to enable PrefixLM attention."""
        attn_modules = _get_attn_modules(model)
        try:
            for attn_module in attn_modules:
                attn_module.bias.data[:] = 1
            return self._original_generate(*args, **kwargs)
        finally:
            # Restore even when generation raises.
            _restore_causal_masks(attn_modules)

    setattr(model, 'forward', MethodType(forward, model))
    setattr(model, 'generate', MethodType(generate, model))
    setattr(model, '_prefix_lm_converted', True)
    return model
""" if hasattr(model, '_prefix_lm_converted'): return model assert isinstance(model, BloomForCausalLM) assert model.config.add_cross_attention == False, 'Only supports BLOOM decoder-only models' def _prepare_attn_mask(self: BloomModel, attention_mask: torch.Tensor, bidirectional_mask: Optional[torch.Tensor], input_shape: Tuple[int, int], past_key_values_length: int) -> torch.BoolTensor: combined_attention_mask = None device = attention_mask.device (_, src_length) = input_shape if src_length > 1: combined_attention_mask = _make_causal_mask_bloom(input_shape, device=device, past_key_values_length=past_key_values_length) if bidirectional_mask is not None: assert attention_mask.shape == bidirectional_mask.shape expanded_bidirectional_mask = _expand_mask_bloom(bidirectional_mask, tgt_length=src_length) combined_attention_mask = torch.logical_and(combined_attention_mask, expanded_bidirectional_mask) expanded_attn_mask = _expand_mask_bloom(attention_mask, tgt_length=src_length) combined_attention_mask = expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask return combined_attention_mask def _build_alibi_tensor(self: BloomModel, batch_size: int, query_length: int, key_length: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor: num_heads = self.config.n_head closest_power_of_2 = 2 ** math.floor(math.log2(num_heads)) base = torch.tensor(2 ** (-2 ** (-(math.log2(closest_power_of_2) - 3))), device=device, dtype=torch.float32) powers = torch.arange(1, 1 + closest_power_of_2, device=device, dtype=torch.int32) slopes = torch.pow(base, powers) if closest_power_of_2 != num_heads: extra_base = torch.tensor(2 ** (-2 ** (-(math.log2(2 * closest_power_of_2) - 3))), device=device, dtype=torch.float32) num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2) extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=device, dtype=torch.int32) slopes = torch.cat([slopes, torch.pow(extra_base, 
extra_powers)], dim=0) qa = torch.arange(query_length, device=device, dtype=torch.int32).view(-1, 1) ka = torch.arange(key_length, device=device, dtype=torch.int32).view(1, -1) diffs = qa - ka + key_length - query_length diffs = -diffs.abs() alibi = slopes.view(1, num_heads, 1, 1) * diffs.view(1, 1, query_length, key_length) alibi = alibi.expand(batch_size, -1, -1, -1).reshape(-1, query_length, key_length) return alibi.to(dtype) KeyValueT = Tuple[torch.Tensor, torch.Tensor] def forward(self: BloomModel, input_ids: Optional[torch.LongTensor]=None, past_key_values: Optional[Tuple[KeyValueT, ...]]=None, attention_mask: Optional[torch.Tensor]=None, bidirectional_mask: Optional[torch.Tensor]=None, head_mask: Optional[torch.LongTensor]=None, inputs_embeds: Optional[torch.LongTensor]=None, use_cache: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, return_dict: Optional[bool]=None, **deprecated_arguments) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: if deprecated_arguments.pop('position_ids', False) is not False: warnings.warn('`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. 
' + 'You can safely ignore passing `position_ids`.', FutureWarning) if len(deprecated_arguments) > 0: raise ValueError(f'Got unexpected arguments: {deprecated_arguments}') output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict if input_ids is not None and inputs_embeds is not None: raise ValueError('You cannot specify both input_ids and inputs_embeds at the same time') elif input_ids is not None: (batch_size, seq_length) = input_ids.shape elif inputs_embeds is not None: (batch_size, seq_length, _) = inputs_embeds.shape else: raise ValueError('You have to specify either input_ids or inputs_embeds') if past_key_values is None: past_key_values = tuple([None] * len(self.h)) head_mask = self.get_head_mask(head_mask, self.config.n_layer) if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) hidden_states = self.word_embeddings_layernorm(inputs_embeds) presents = () if use_cache else None all_self_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None seq_length_with_past = seq_length past_key_values_length = 0 if past_key_values[0] is not None: tmp = past_key_values[0][0] past_key_values_length = tmp.shape[2] seq_length_with_past = seq_length_with_past + past_key_values_length if attention_mask is None: attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device) else: attention_mask = attention_mask.to(hidden_states.device) alibi = self._build_alibi_tensor(batch_size=batch_size, query_length=seq_length, key_length=seq_length_with_past, dtype=hidden_states.dtype, device=hidden_states.device) causal_mask = self._prepare_attn_mask(attention_mask, 
bidirectional_mask, input_shape=(batch_size, seq_length), past_key_values_length=past_key_values_length) for (i, (block, layer_past)) in enumerate(zip(self.h, past_key_values)): if output_hidden_states: hst = (hidden_states,) all_hidden_states = all_hidden_states + hst if self.gradient_checkpointing and self.training: if use_cache: logger.warning('`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...') use_cache = False def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs, use_cache=use_cache, output_attentions=output_attentions) return custom_forward outputs = torch.utils.checkpoint.checkpoint(create_custom_forward(block), hidden_states, alibi, causal_mask, head_mask[i]) else: outputs = block(hidden_states, layer_past=layer_past, attention_mask=causal_mask, head_mask=head_mask[i], use_cache=use_cache, output_attentions=output_attentions, alibi=alibi) hidden_states = outputs[0] if use_cache is True: presents = presents + (outputs[1],) if output_attentions: oa = (outputs[2 if use_cache else 1],) all_self_attentions = all_self_attentions + oa hidden_states = self.ln_f(hidden_states) if output_hidden_states: hst = (hidden_states,) all_hidden_states = all_hidden_states + hst if not return_dict: return tuple((v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)) return BaseModelOutputWithPastAndCrossAttentions(last_hidden_state=hidden_states, past_key_values=presents, hidden_states=all_hidden_states, attentions=all_self_attentions) setattr(model.transformer, '_prepare_attn_mask', MethodType(_prepare_attn_mask, model.transformer)) setattr(model.transformer, '_build_alibi_tensor', MethodType(_build_alibi_tensor, model.transformer)) setattr(model.transformer, 'forward', MethodType(forward, model.transformer)) KeyValueT = Tuple[torch.Tensor, torch.Tensor] def forward(self: BloomForCausalLM, input_ids: Optional[torch.LongTensor]=None, past_key_values: 
Optional[Tuple[KeyValueT, ...]]=None, attention_mask: Optional[torch.Tensor]=None, bidirectional_mask: Optional[torch.Tensor]=None, head_mask: Optional[torch.Tensor]=None, inputs_embeds: Optional[torch.Tensor]=None, labels: Optional[torch.Tensor]=None, use_cache: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, return_dict: Optional[bool]=None, **deprecated_arguments) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]: """Replacement forward method for BloomCausalLM.""" if deprecated_arguments.pop('position_ids', False) is not False: warnings.warn('`position_ids` have no functionality in BLOOM and will be removed ' + 'in v5.0.0. You can safely ignore passing `position_ids`.', FutureWarning) if len(deprecated_arguments) > 0: raise ValueError(f'Got unexpected arguments: {deprecated_arguments}') return_dict = return_dict if return_dict is not None else self.config.use_return_dict transformer_outputs = self.transformer(input_ids, past_key_values=past_key_values, attention_mask=attention_mask, bidirectional_mask=bidirectional_mask, head_mask=head_mask, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict) hidden_states = transformer_outputs[0] lm_logits = self.lm_head(hidden_states) loss = None if labels is not None: shift_logits = lm_logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() (batch_size, seq_length, vocab_size) = shift_logits.shape loss_fct = CrossEntropyLoss() loss = loss_fct(shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)) if not return_dict: output = (lm_logits,) + transformer_outputs[1:] return (loss,) + output if loss is not None else output return CausalLMOutputWithCrossAttentions(loss=loss, logits=lm_logits, past_key_values=transformer_outputs.past_key_values, hidden_states=transformer_outputs.hidden_states, 
def _convert_opt_causal_lm_to_prefix_lm(model: OPTForCausalLM) -> OPTForCausalLM:
    """Converts an OPT Causal LM to a Prefix LM.

    Supported HuggingFace model classes:
        - `OPTForCausalLM`

    The model is modified in place: the decoder gets a `bidirectional_mask`
    attribute and a replacement `_prepare_decoder_attention_mask`, and the
    model's `forward` and `generate` are replaced by wrappers that set and
    clear that attribute around the original calls.

    See `convert_hf_causal_lm_to_prefix_lm` for more details.
    """
    if hasattr(model, '_prefix_lm_converted'):
        return model
    assert isinstance(model, OPTForCausalLM)
    assert model.config.add_cross_attention == False, 'Only supports OPT decoder-only models'
    setattr(model, '_original_forward', getattr(model, 'forward'))
    setattr(model, '_original_generate', getattr(model, 'generate'))
    model.model.decoder.bidirectional_mask = None

    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
        """Builds the decoder mask, honouring `self.bidirectional_mask`."""
        combined_attention_mask = None
        if input_shape[-1] > 1:
            if self.bidirectional_mask == 'g':
                # 'g' is the sentinel set by the `generate` wrapper below: the
                # causal mask is replaced with zeros.
                (bsz, src_length) = input_shape
                combined_attention_mask = torch.zeros((bsz, 1, src_length, src_length + past_key_values_length), dtype=inputs_embeds.dtype, device=inputs_embeds.device)
            else:
                combined_attention_mask = _make_causal_mask_opt(input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length).to(inputs_embeds.device)
                if self.bidirectional_mask is not None:
                    assert attention_mask.shape == self.bidirectional_mask.shape
                    expanded_bidirectional_mask = _expand_mask_opt(self.bidirectional_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(inputs_embeds.device)
                    combined_attention_mask = torch.maximum(expanded_bidirectional_mask, combined_attention_mask)
        if attention_mask is not None:
            expanded_attn_mask = _expand_mask_opt(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(inputs_embeds.device)
            combined_attention_mask = expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
        return combined_attention_mask

    setattr(model.model.decoder, '_prepare_decoder_attention_mask', MethodType(_prepare_decoder_attention_mask, model.model.decoder))

    def forward(self: OPTForCausalLM, input_ids: Optional[torch.LongTensor]=None, attention_mask: Optional[torch.Tensor]=None, bidirectional_mask: Optional[torch.ByteTensor]=None, head_mask: Optional[torch.Tensor]=None, past_key_values: Optional[List[torch.FloatTensor]]=None, inputs_embeds: Optional[torch.FloatTensor]=None, labels: Optional[torch.LongTensor]=None, use_cache: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, return_dict: Optional[bool]=None):
        """Wraps original forward to enable PrefixLM-style attention."""

        def call_og_forward():
            return self._original_forward(input_ids=input_ids, attention_mask=attention_mask, head_mask=head_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, labels=labels, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict)

        if bidirectional_mask is None:
            return call_og_forward()
        self.model.decoder.bidirectional_mask = bidirectional_mask
        try:
            return call_og_forward()
        finally:
            # Always clear the per-call mask, on success or failure.
            self.model.decoder.bidirectional_mask = None

    def generate(self: OPTForCausalLM, *args: tuple, **kwargs: Dict[str, Any]):
        """Wraps original generate to enable PrefixLM-style attention."""
        self.model.decoder.bidirectional_mask = 'g'
        try:
            return self._original_generate(*args, **kwargs)
        finally:
            # Always leave generation mode, on success or failure.
            self.model.decoder.bidirectional_mask = None

    setattr(model, 'forward', MethodType(forward, model))
    setattr(model, 'generate', MethodType(generate, model))
    setattr(model, '_prefix_lm_converted', True)
    return model
Supported HuggingFace model classes: - `GPT2LMHeadModel` - `GPTNeoForCausalLM` - `GPTNeoXForCausalLM` - `GPTJForCausalLM` - `BloomForCausalLM` - `OPTForCausalLM` Conversion to a Prefix LM is done by modifying the `forward` method, and possibly also the `generate` method and/or select underlying methods depending on the model class. These changes preserve the model API, but add a new input to `forward`: "bidirectional_mask". Notes on training: To actually train the converted model as a Prefix LM, training batches will need to indicate the prefix/target structure by including `bidirectional_mask` as part of the batch inputs. **This is not a standard input and requires custom layers either within or after your dataloader.** In addition to adding `bidirectional_mask` to the batch, this custom code should modify `labels` such that `batch['labels'][batch['bidirectional_mask'] == 1] == -100`. That is, the prefix portion of the sequence should not generate any loss. Loss should only be generated by the target portion of the sequence. Notes on `GPTNeoForCausalLM`: To simplify the implementation, "global" and "local" attention layers are handled differently. For "global" layers, we handle conversion as described above. For "local" layers, which use a causal attention mask within a restricted local window, we do not alter the masking. Notes on `forward` method conversion: After conversion, the `forward` method will handle a new input, `bidirectional_mask`, which should be a [batch_size, seq_length] byte tensor, where 1 indicates token positions belonging to the prefix (prefix tokens can attend to one another bidirectionally), and 0 indicates token positions belonging to the target. The new `forward` method will incorporate `bidirectional_mask` (if supplied) into the existing causal mask, call the original `forward` method, and (if the causal mask is a buffer) reset the causal masks before returning the result. 
Notes on `generate` method conversion: After conversion, the `generate` method will have the same signature but will internally convert all causal masks to be purely bidirectional, call the original `generate` method, and (where appropriate) reset the causal masks before returning the result. This works thanks to the logic of the HuggingFace `generate` API, which first encodes the token "prompt" passed to `generate` (which is treated as the prefix) and then sequentially generates each new token. Encodings are cached as generation happens, so all prefix tokens can attend to one another (as expected in a Prefix LM) and generated tokens can only attend to prefix tokens and previously-generated tokens (also as expected in a Prefix LM). To preserve the API, the original methods are renamed to `_original_forward` and `_original_generate`, and replaced with new `forward` and `generate` methods that wrap them, respectively. Although implementation details vary by model class. """ if isinstance(model, _SUPPORTED_GPT_MODELS): return _convert_gpt_causal_lm_to_prefix_lm(model) elif isinstance(model, BloomForCausalLM): return _convert_bloom_causal_lm_to_prefix_lm(model) elif isinstance(model, OPTForCausalLM): return _convert_opt_causal_lm_to_prefix_lm(model) else: raise TypeError(f'Cannot convert model to Prefix LM. ' + f'Model does not belong to set of supported HF models:' + f'\n{_SUPPORTED_HF_MODELS}') def add_bidirectional_mask_if_missing(batch: Dict[str, Any]): """Attempts to add bidirectional_mask to batch if missing. 
@contextmanager
def init_empty_weights(include_buffers: bool=False):
    """Meta initialization context manager.

    A context manager under which models are initialized with all parameters
    on the meta device, therefore creating an empty model. Useful when just
    initializing the model would blow the available RAM.

    Args:
        include_buffers (`bool`, *optional*, defaults to `False`): Whether or
            not to also put all buffers on the meta device while initializing.

    Example:
    ```python
    import torch.nn as nn

    # Initialize a model with 100 billions parameters in no time and without using any RAM.
    with init_empty_weights():
        tst = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)])
    ```

    Any model created under this context manager has no weights. As such you
    can't do something like `model.to(some_device)` with it. To load weights
    inside your empty model, see [`load_checkpoint_and_dispatch`].
    """
    with init_on_device(torch.device('meta'), include_buffers=include_buffers) as f:
        yield f


@contextmanager
def init_on_device(device: torch.device, include_buffers: bool=False):
    """Device initialization context manager.

    A context manager under which models are initialized with all parameters
    on the specified device. It works by temporarily replacing
    `nn.Module.register_parameter` (and, with `include_buffers`,
    `nn.Module.register_buffer` and the `torch.empty/zeros/ones/full`
    constructors); the originals are restored on exit, even on error.

    Args:
        device (`torch.device`): Device to initialize all parameters on.
        include_buffers (`bool`, *optional*, defaults to `False`): Whether or
            not to also put all buffers on the specified device while
            initializing.

    Example:
    ```python
    import torch.nn as nn

    with init_on_device(device=torch.device("cuda")):
        tst = nn.Linear(100, 100)  # on `cuda` device
    ```
    """
    old_register_parameter = nn.Module.register_parameter
    if include_buffers:
        old_register_buffer = nn.Module.register_buffer

    def register_empty_parameter(module, name, param):
        old_register_parameter(module, name, param)
        if param is not None:
            registered = module._parameters[name]
            param_cls = type(registered)
            # Copy so the parameter's own __dict__ is not mutated.
            kwargs = dict(registered.__dict__)
            # Re-wrapping would otherwise fall back to the constructor default
            # and silently turn frozen parameters into trainable ones.
            kwargs['requires_grad'] = param.requires_grad
            module._parameters[name] = param_cls(registered.to(device), **kwargs)

    def register_empty_buffer(module, name, buffer):
        old_register_buffer(module, name, buffer)
        if buffer is not None:
            module._buffers[name] = module._buffers[name].to(device)

    if include_buffers:
        tensor_constructors_to_patch = {torch_function_name: getattr(torch, torch_function_name) for torch_function_name in ['empty', 'zeros', 'ones', 'full']}
    else:
        tensor_constructors_to_patch = {}

    def patch_tensor_constructor(fn):

        def wrapper(*args, **kwargs):
            # Force the target device, overriding any caller-supplied one.
            kwargs['device'] = device
            return fn(*args, **kwargs)
        return wrapper

    try:
        nn.Module.register_parameter = register_empty_parameter
        if include_buffers:
            nn.Module.register_buffer = register_empty_buffer
        for torch_function_name in tensor_constructors_to_patch.keys():
            setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
        yield
    finally:
        nn.Module.register_parameter = old_register_parameter
        if include_buffers:
            nn.Module.register_buffer = old_register_buffer
        for (torch_function_name, old_torch_function) in tensor_constructors_to_patch.items():
            setattr(torch, torch_function_name, old_torch_function)
class MPTPreTrainedModel(PreTrainedModel):
    """Base class that binds the MPT models to ``MPTConfig``."""
    config_class = MPTConfig
    base_model_prefix = 'model'
    _no_split_modules = ['MPTBlock']


class MPTModel(MPTPreTrainedModel):
    """MPT decoder stack.

    Built from a token embedding (``wte``), a learned position embedding
    (``wpe``, only when ALiBi is disabled), embedding dropout, ``n_layers``
    ``MPTBlock`` modules and a final norm layer.
    """

    def __init__(self, config: MPTConfig):
        """Build all sub-modules described by ``config``.

        Raises:
            NotImplementedError: if ``config.norm_type`` is not a key of
                ``NORM_CLASS_REGISTRY``.
        """
        config._validate_config()
        super().__init__(config)
        self.attn_impl = config.attn_config['attn_impl']
        self.prefix_lm = config.attn_config['prefix_lm']
        self.attn_uses_sequence_id = config.attn_config['attn_uses_sequence_id']
        self.alibi = config.attn_config['alibi']
        self.alibi_bias_max = config.attn_config['alibi_bias_max']
        if config.init_device == 'mixed':
            # The previous code called `dist.get_local_rank()`, but no `dist`
            # name is imported in this module, so this branch raised NameError.
            # Use the LOCAL_RANK variable exported by the distributed launcher
            # (rank 0 materialises weights on CPU, other ranks stay on meta).
            import os
            local_rank = int(os.environ.get('LOCAL_RANK', 0))
            if local_rank == 0:
                config.init_device = 'cpu'
            else:
                config.init_device = 'meta'
        if config.norm_type.lower() not in NORM_CLASS_REGISTRY.keys():
            norm_options = ' | '.join(NORM_CLASS_REGISTRY.keys())
            raise NotImplementedError(f'Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options}).')
        norm_class = NORM_CLASS_REGISTRY[config.norm_type.lower()]
        self.embedding_fraction = config.embedding_fraction
        self.wte = SharedEmbedding(config.vocab_size, config.d_model, device=config.init_device)
        if not self.alibi:
            # Learned absolute positions are only needed without ALiBi.
            self.wpe = torch.nn.Embedding(config.max_seq_len, config.d_model, device=config.init_device)
        self.emb_drop = nn.Dropout(config.emb_pdrop)
        self.blocks = nn.ModuleList([MPTBlock(device=config.init_device, **config.to_dict()) for _ in range(config.n_layers)])
        self.norm_f = norm_class(config.d_model, device=config.init_device)
        if config.init_device != 'meta':
            print(f'You are using config.init_device={config.init_device!r}, but you can also use config.init_device="meta" with Composer + FSDP for fast initialization.')
            self.apply(self.param_init_fn)
        self.is_causal = not self.prefix_lm
        # The attention bias tensor is built lazily on the first forward pass.
        self._attn_bias_initialized = False
        self.attn_bias = None
        self.attn_bias_shape = attn_bias_shape(self.attn_impl, config.n_heads, config.max_seq_len, self.alibi, prefix_lm=self.prefix_lm, causal=self.is_causal, use_sequence_id=self.attn_uses_sequence_id)
        if config.no_bias:
            for module in self.modules():
                if hasattr(module, 'bias') and isinstance(module.bias, nn.Parameter):
                    if config.verbose:
                        warnings.warn(f'Removing bias ({module.bias}) from {module}.')
                    module.register_parameter('bias', None)
        if config.verbose and config.verbose > 2:
            print(self)
        if 'verbose' not in self.config.init_config:
            self.config.init_config['verbose'] = self.config.verbose
        if self.config.init_config['verbose'] > 1:
            init_fn_name = self.config.init_config['name']
            warnings.warn(f'Using {init_fn_name} initialization.')
        self.gradient_checkpointing = False

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.wte

    def set_input_embeddings(self, value):
        """Replace the token embedding module."""
        self.wte = value

    @torch.no_grad()
    def _attn_bias(self, device, dtype, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None):
        """Return ``(attn_bias, attention_mask)`` for the attention blocks.

        For ``attn_impl == 'flash'`` the cached bias and the untouched
        ``attention_mask`` are returned. Otherwise the padding mask is folded
        into the bias and ``None`` is returned in its place.

        Raises:
            ValueError: if ``attention_mask`` and ``prefix_mask`` shapes differ.
        """
        if not self._attn_bias_initialized:
            if self.attn_bias_shape:
                self.attn_bias = torch.zeros(self.attn_bias_shape, device=device, dtype=dtype)
                self.attn_bias = build_attn_bias(self.attn_impl, self.attn_bias, self.config.n_heads, self.config.max_seq_len, causal=self.is_causal, alibi=self.alibi, alibi_bias_max=self.alibi_bias_max)
            self._attn_bias_initialized = True
        if self.attn_impl == 'flash':
            return (self.attn_bias, attention_mask)
        if self.attn_bias is not None:
            self.attn_bias = self.attn_bias.to(dtype=dtype, device=device)
        attn_bias = self.attn_bias
        if self.prefix_lm:
            assert isinstance(attn_bias, torch.Tensor)
            assert isinstance(prefix_mask, torch.Tensor)
            attn_bias = self._apply_prefix_mask(attn_bias, prefix_mask)
        if self.attn_uses_sequence_id and sequence_id is not None:
            assert isinstance(attn_bias, torch.Tensor)
            attn_bias = self._apply_sequence_id(attn_bias, sequence_id)
        if attention_mask is not None:
            s_k = attention_mask.shape[-1]
            if attn_bias is None:
                attn_bias = torch.zeros((1, 1, 1, s_k), device=device, dtype=dtype)
            else:
                # Keep only the trailing s_k key positions of the cached bias.
                _s_k = max(0, attn_bias.size(-1) - s_k)
                attn_bias = attn_bias[:, :, :, _s_k:]
            if prefix_mask is not None and attention_mask.shape != prefix_mask.shape:
                raise ValueError(f'attention_mask shape={attention_mask.shape} ' + f'and prefix_mask shape={prefix_mask.shape} are not equal.')
            min_val = torch.finfo(attn_bias.dtype).min
            attn_bias = attn_bias.masked_fill(~attention_mask.view(-1, 1, 1, s_k), min_val)
        return (attn_bias, None)

    def _apply_prefix_mask(self, attn_bias: torch.Tensor, prefix_mask: torch.Tensor):
        """Mask ``attn_bias`` so positions attend causally or to the prefix.

        Raises:
            ValueError: if the last two dims of ``attn_bias`` are not both
                ``max_seq_len`` or ``prefix_mask`` is longer than ``max_seq_len``.
        """
        (s_k, s_q) = attn_bias.shape[-2:]
        if s_k != self.config.max_seq_len or s_q != self.config.max_seq_len:
            # Report max_seq_len (the value actually compared above); the old
            # message printed the unrelated `config.max_length`.
            raise ValueError('attn_bias does not match the expected shape. ' + f'The last two dimensions should both be {self.config.max_seq_len} ' + f'but are {s_k} and {s_q}.')
        seq_len = prefix_mask.shape[-1]
        if seq_len > self.config.max_seq_len:
            raise ValueError(f'prefix_mask sequence length cannot exceed max_seq_len={self.config.max_seq_len}')
        attn_bias = attn_bias[..., :seq_len, :seq_len]
        causal = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool, device=prefix_mask.device)).view(1, 1, seq_len, seq_len)
        prefix = prefix_mask.view(-1, 1, 1, seq_len)
        cannot_attend = ~torch.logical_or(causal, prefix.bool())
        min_val = torch.finfo(attn_bias.dtype).min
        attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
        return attn_bias

    def _apply_sequence_id(self, attn_bias: torch.Tensor, sequence_id: torch.LongTensor):
        """Mask ``attn_bias`` between positions whose ``sequence_id`` differs.

        Raises:
            ValueError: if ``sequence_id`` is longer than ``max_seq_len``.
        """
        seq_len = sequence_id.shape[-1]
        if seq_len > self.config.max_seq_len:
            raise ValueError(f'sequence_id sequence length cannot exceed max_seq_len={self.config.max_seq_len}')
        attn_bias = attn_bias[..., :seq_len, :seq_len]
        cannot_attend = torch.logical_not(torch.eq(sequence_id.view(-1, seq_len, 1), sequence_id.view(-1, 1, seq_len))).unsqueeze(1)
        min_val = torch.finfo(attn_bias.dtype).min
        attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
        return attn_bias

    def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, inputs_embeds: Optional[torch.Tensor]=None):
        """Run the decoder stack and return a ``BaseModelOutputWithPast``.

        Either ``input_ids`` or ``inputs_embeds`` must be given; the latter is
        only accepted when ALiBi is enabled.

        Raises:
            NotImplementedError: for ``return_dict=False``, for
                ``output_attentions`` with a non-``torch`` attention
                implementation, or for left-padded input in training mode.
            ValueError: for a missing ``prefix_mask`` / ``sequence_id`` when the
                configuration requires it, a ``past_key_values`` list of the
                wrong length, or a total length above ``max_seq_len``.
        """
        return_dict = return_dict if return_dict is not None else self.config.return_dict
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        if attention_mask is not None:
            attention_mask = attention_mask.bool()
        if prefix_mask is not None:
            prefix_mask = prefix_mask.bool()
        if not return_dict:
            raise NotImplementedError('return_dict False is not implemented yet for MPT')
        if output_attentions:
            if self.attn_impl != 'torch':
                raise NotImplementedError('output_attentions is not implemented for MPT when using attn_impl `flash` or `triton`.')
        if attention_mask is not None and attention_mask[:, 0].sum() != attention_mask.shape[0] and self.training:
            raise NotImplementedError('MPT does not support training with left padding.')
        if self.prefix_lm and prefix_mask is None:
            raise ValueError('prefix_mask is a required argument when MPT is configured with prefix_lm=True.')
        if self.training:
            if self.attn_uses_sequence_id and sequence_id is None:
                raise ValueError('sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True ' + 'and the model is in train mode.')
            elif self.attn_uses_sequence_id is False and sequence_id is not None:
                warnings.warn('MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. ' + 'This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True.')
        if input_ids is not None:
            S = input_ids.size(1)
            assert S <= self.config.max_seq_len, f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}'
            tok_emb = self.wte(input_ids)
        else:
            assert inputs_embeds is not None
            assert self.alibi, 'inputs_embeds is not implemented for MPT unless for alibi.'
            S = inputs_embeds.size(1)
            tok_emb = inputs_embeds
        if self.alibi:
            x = tok_emb
        else:
            past_position = 0
            if past_key_values is not None:
                if len(past_key_values) != self.config.n_layers:
                    raise ValueError(f'past_key_values must provide a past_key_value for each attention ' + f'layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r}).')
                past_position = past_key_values[0][0].size(1)
                if self.attn_impl == 'torch':
                    # The torch implementation reads the cached length from dim 3.
                    past_position = past_key_values[0][0].size(3)
            if S + past_position > self.config.max_seq_len:
                # Report the real current length S (the old message said S + 1).
                raise ValueError(f'Cannot forward input with past sequence length {past_position} and current sequence length {S}, this model only supports total sequence length <= {self.config.max_seq_len}.')
            pos = torch.arange(past_position, S + past_position, dtype=torch.long, device=input_ids.device).unsqueeze(0)
            if attention_mask is not None:
                # Shift positions left by the number of masked tokens seen so far.
                pos = torch.clamp(pos - torch.cumsum((~attention_mask).to(torch.int32), dim=1)[:, past_position:], min=0)
            pos_emb = self.wpe(pos)
            x = tok_emb + pos_emb
        if self.embedding_fraction == 1:
            x = self.emb_drop(x)
        else:
            # Only a fraction of the embedding output carries gradient.
            x_shrunk = x * self.embedding_fraction + x.detach() * (1 - self.embedding_fraction)
            assert isinstance(self.emb_drop, nn.Module)
            x = self.emb_drop(x_shrunk)
        (attn_bias, attention_mask) = self._attn_bias(device=x.device, dtype=torch.float32, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id)
        if use_cache and past_key_values is None:
            past_key_values = [() for _ in range(self.config.n_layers)]
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        for (b_idx, block) in enumerate(self.blocks):
            if output_hidden_states:
                assert all_hidden_states is not None
                all_hidden_states = all_hidden_states + (x,)
            past_key_value = past_key_values[b_idx] if past_key_values is not None else None
            if self.gradient_checkpointing and self.training:
                (x, attn_weights, past_key_value) = torch.utils.checkpoint.checkpoint(block, x, past_key_value, attn_bias, attention_mask, self.is_causal)
            else:
                (x, attn_weights, past_key_value) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal)
            if past_key_values is not None:
                past_key_values[b_idx] = past_key_value
            if output_attentions:
                assert all_self_attns is not None
                all_self_attns = all_self_attns + (attn_weights,)
        x = self.norm_f(x)
        if output_hidden_states:
            assert all_hidden_states is not None
            all_hidden_states = all_hidden_states + (x,)
        return BaseModelOutputWithPast(last_hidden_state=x, past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns)

    def param_init_fn(self, module):
        """Initialise ``module`` with the scheme named in ``config.init_config``."""
        init_fn_name = self.config.init_config['name']
        MODEL_INIT_REGISTRY[init_fn_name](module=module, n_layers=self.config.n_layers, d_model=self.config.d_model, **self.config.init_config)

    def fsdp_wrap_fn(self, module):
        """Return True for ``MPTBlock`` modules."""
        return isinstance(module, MPTBlock)

    def activation_checkpointing_fn(self, module):
        """Return True for ``MPTBlock`` modules."""
        return isinstance(module, MPTBlock)
class MPTForCausalLM(MPTPreTrainedModel):
    """Causal language-model head on top of ``MPTModel``.

    The output projection reuses the input embedding (``transformer.wte``),
    so only tied word embeddings are supported.
    """

    def __init__(self, config: MPTConfig):
        """Build the wrapped ``MPTModel`` and resolve ``config.logit_scale``.

        Raises:
            ValueError: if word embeddings are not tied, or ``logit_scale`` is
                a string other than ``'inv_sqrt_d_model'``.
        """
        super().__init__(config)
        if not config.tie_word_embeddings:
            raise ValueError('MPTForCausalLM only supports tied word embeddings')
        print(f'Instantiating an MPTForCausalLM model from {__file__}')
        self.transformer = MPTModel(config)
        # Flag every direct child except the block list for FSDP wrapping.
        for sub_module in self.transformer.children():
            is_block_list = isinstance(sub_module, torch.nn.ModuleList)
            if not is_block_list and isinstance(sub_module, torch.nn.Module):
                sub_module._fsdp_wrap = True
        self.logit_scale = None
        scale = config.logit_scale
        if scale is not None:
            if isinstance(scale, str):
                if scale != 'inv_sqrt_d_model':
                    raise ValueError(f"logit_scale={scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.")
                scale = 1 / math.sqrt(config.d_model)
            self.logit_scale = scale

    def get_input_embeddings(self):
        """Return the shared token embedding."""
        return self.transformer.wte

    def set_input_embeddings(self, value):
        """Replace the shared token embedding."""
        self.transformer.wte = value

    def get_output_embeddings(self):
        """Return the shared token embedding (it doubles as the output head)."""
        return self.transformer.wte

    def set_output_embeddings(self, new_embeddings):
        """Replace the shared token embedding used as the output head."""
        self.transformer.wte = new_embeddings

    def set_decoder(self, decoder):
        """Replace the wrapped transformer."""
        self.transformer = decoder

    def get_decoder(self):
        """Return the wrapped transformer."""
        return self.transformer

    def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, inputs_embeds: Optional[torch.FloatTensor]=None):
        """Compute logits (and the LM loss when ``labels`` is given).

        Raises:
            NotImplementedError: if ``inputs_embeds`` is provided.
        """
        if return_dict is None:
            return_dict = self.config.return_dict
        if use_cache is None:
            use_cache = self.config.use_cache
        if inputs_embeds is not None:
            raise NotImplementedError('inputs_embeds has to be None (for hf/peft support).')
        outputs = self.transformer(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            prefix_mask=prefix_mask,
            sequence_id=sequence_id,
            return_dict=return_dict,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            use_cache=use_cache,
        )
        # Second positional argument asks the shared embedding to un-embed.
        hidden = outputs.last_hidden_state.to(self.transformer.wte.weight.device)
        logits = self.transformer.wte(hidden, True)
        if self.logit_scale is not None:
            if self.logit_scale == 0:
                warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs.')
            logits *= self.logit_scale
        loss = None
        if labels is not None:
            # Shift targets left by one and ignore the final position.
            labels = torch.roll(labels, shifts=-1)
            labels[:, -1] = -100
            flat_logits = logits.view(-1, logits.size(-1))
            flat_labels = labels.to(logits.device).view(-1)
            loss = F.cross_entropy(flat_logits, flat_labels)
        return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions)

    def param_init_fn(self, module):
        """Initialise ``module`` with the scheme named in ``config.init_config``."""
        cfg = self.config
        initializer = MODEL_INIT_REGISTRY[cfg.init_config['name']]
        initializer(module=module, n_layers=cfg.n_layers, d_model=cfg.d_model, **cfg.init_config)

    def fsdp_wrap_fn(self, module):
        """Return True for ``MPTBlock`` modules."""
        return isinstance(module, MPTBlock)

    def activation_checkpointing_fn(self, module):
        """Return True for ``MPTBlock`` modules."""
        return isinstance(module, MPTBlock)

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
        """Assemble the keyword arguments for one generation step.

        Raises:
            NotImplementedError: for ``inputs_embeds``, right-padded input, or
                ``prefix_lm`` combined with ``use_cache=False``.
        """
        if inputs_embeds is not None:
            raise NotImplementedError('inputs_embeds is not implemented for MPT yet')
        attention_mask = kwargs['attention_mask'].bool()
        if attention_mask[:, -1].sum() != attention_mask.shape[0]:
            raise NotImplementedError('MPT does not support generation with right padding.')
        needs_sequence_id = self.transformer.attn_uses_sequence_id and self.training
        sequence_id = torch.zeros_like(input_ids[:1]) if needs_sequence_id else None
        if past_key_values is not None:
            # With a cache only the newest token has to be fed.
            input_ids = input_ids[:, -1].unsqueeze(-1)
        prefix_mask = None
        if self.transformer.prefix_lm:
            prefix_mask = torch.ones_like(attention_mask)
            if kwargs.get('use_cache') == False:
                raise NotImplementedError('MPT with prefix_lm=True does not support use_cache=False.')
        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'prefix_mask': prefix_mask,
            'sequence_id': sequence_id,
            'past_key_values': past_key_values,
            'use_cache': kwargs.get('use_cache', True),
        }

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Used by HuggingFace generate when using beam search with kv-caching.

        See https://github.com/huggingface/transformers/blob/3ec7a47664ebe40c40f4b722f6bb1cd30c3821ec/src/transformers/models/gpt2/modeling_gpt2.py#L1122-L1133
        for an example in transformers.
        """
        return [
            tuple(state.index_select(0, beam_idx) for state in layer_past)
            for layer_past in past_key_values
        ]
def _cast_if_autocast_enabled(tensor):
    """Cast ``tensor`` to the active autocast dtype; return it unchanged otherwise.

    Raises:
        NotImplementedError: if autocast is on and the tensor lives on a
            device type other than ``cuda`` or ``cpu``.
    """
    if not torch.is_autocast_enabled():
        return tensor
    device_type = tensor.device.type
    if device_type == 'cuda':
        target_dtype = torch.get_autocast_gpu_dtype()
    elif device_type == 'cpu':
        target_dtype = torch.get_autocast_cpu_dtype()
    else:
        raise NotImplementedError()
    return tensor.to(dtype=target_dtype)


class LPLayerNorm(torch.nn.LayerNorm):
    """LayerNorm evaluated in the autocast dtype with autocast itself disabled."""

    def __init__(self, normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None):
        super().__init__(
            normalized_shape=normalized_shape,
            eps=eps,
            elementwise_affine=elementwise_affine,
            device=device,
            dtype=dtype,
        )

    def forward(self, x):
        input_device = x.device
        x_cast = _cast_if_autocast_enabled(x)
        weight_cast = self.weight if self.weight is None else _cast_if_autocast_enabled(self.weight)
        bias_cast = self.bias if self.bias is None else _cast_if_autocast_enabled(self.bias)
        with torch.autocast(enabled=False, device_type=input_device.type):
            return torch.nn.functional.layer_norm(x_cast, self.normalized_shape, weight_cast, bias_cast, self.eps)


def rms_norm(x, weight=None, eps=1e-05):
    """Scale ``x`` by the reciprocal RMS over its last dim, then by ``weight`` if given."""
    inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    normed = x * inv_rms
    if weight is None:
        return normed
    return normed * weight


class RMSNorm(torch.nn.Module):
    """RMS normalisation computed in float32 and cast back to the input dtype."""

    def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None):
        super().__init__()
        self.eps = eps
        if not weight:
            # Register an explicit None so `self.weight` is always defined.
            self.register_parameter('weight', None)
        else:
            self.weight = torch.nn.Parameter(torch.ones(normalized_shape, dtype=dtype, device=device))

    def forward(self, x):
        result = rms_norm(x.float(), self.weight, self.eps)
        return result.to(dtype=x.dtype)


class LPRMSNorm(RMSNorm):
    """RMSNorm evaluated in the autocast dtype with autocast itself disabled."""

    def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None):
        super().__init__(normalized_shape=normalized_shape, eps=eps, weight=weight, dtype=dtype, device=device)

    def forward(self, x):
        x_cast = _cast_if_autocast_enabled(x)
        weight_cast = self.weight if self.weight is None else _cast_if_autocast_enabled(self.weight)
        with torch.autocast(enabled=False, device_type=x.device.type):
            return rms_norm(x_cast, weight_cast, self.eps).to(dtype=x.dtype)


# Maps the lowercase norm name from the model config to its implementation.
NORM_CLASS_REGISTRY = {
    'layernorm': torch.nn.LayerNorm,
    'low_precision_layernorm': LPLayerNorm,
    'rmsnorm': RMSNorm,
    'low_precision_rmsnorm': LPRMSNorm,
}
def torch_default_param_init_fn_(module: nn.Module, verbose: int=0, **kwargs):
    """Initialise ``module`` via its own ``reset_parameters`` method, if it has one."""
    del kwargs
    if verbose > 1:
        warnings.warn("Initializing network using module's reset_parameters attribute")
    if hasattr(module, 'reset_parameters'):
        module.reset_parameters()


def fused_init_helper_(module: nn.Module, init_fn_):
    """Apply ``init_fn_`` separately to each fused slice of ``module.weight``.

    ``module._fused`` must be ``(dim, splits)``: the weight is cut along
    ``dim`` at every index in ``splits`` and each piece is initialised alone.

    Raises:
        RuntimeError: if ``module`` has no ``_fused`` attribute.
    """
    _fused = getattr(module, '_fused', None)
    if _fused is None:
        raise RuntimeError('Internal logic error')
    (dim, splits) = _fused
    splits = (0, *splits, module.weight.size(dim))
    for (s, e) in zip(splits[:-1], splits[1:]):
        slice_indices = [slice(None)] * module.weight.ndim
        slice_indices[dim] = slice(s, e)
        # Index with a tuple: a list of slices is ambiguous multi-dim indexing.
        init_fn_(module.weight[tuple(slice_indices)])


def generic_param_init_fn_(module: nn.Module, init_fn_, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
    """Initialise one module according to its type.

    Handles ``nn.Linear``, ``nn.Embedding``, the registered norm classes and
    ``nn.MultiheadAttention``. Weights use ``init_fn_`` (embeddings may be
    overridden by ``emb_init_std`` / ``emb_init_uniform_lim``), biases are
    zeroed, norm weights are set to one, and layers flagged ``_is_residual``
    are divided by the factor derived from ``init_div_is_residual``.

    Args:
        init_div_is_residual: ``False`` disables the division, ``True`` uses
            ``sqrt(2 * n_layers)``, a number (or numeric string) is used as is.

    Raises:
        ValueError: for a non-numeric ``init_div_is_residual`` or an
            ``emb_init_uniform_lim`` sequence that is not exactly (min, max).
        NotImplementedError: for any other module that owns parameters.
    """
    del kwargs
    if verbose > 1:
        warnings.warn('If model has bias parameters they are initialized to 0.')
    if init_div_is_residual is False:
        div_is_residual = 1.0
    elif init_div_is_residual is True:
        div_is_residual = math.sqrt(2 * n_layers)
    elif isinstance(init_div_is_residual, (float, int)):
        div_is_residual = init_div_is_residual
    elif isinstance(init_div_is_residual, str):
        # float() also accepts decimal strings such as "2.5", which the
        # previous `str.isnumeric()` gate rejected.
        try:
            div_is_residual = float(init_div_is_residual)
        except ValueError as err:
            raise ValueError(f'Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}') from err
    else:
        raise ValueError(f'Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}')
    if init_div_is_residual is not False:
        if verbose > 1:
            warnings.warn(f'Initializing _is_residual layers then dividing them by {div_is_residual:.3f}. ' + 'Set `init_div_is_residual: false` in init config to disable this.')
    if isinstance(module, nn.Linear):
        if hasattr(module, '_fused'):
            fused_init_helper_(module, init_fn_)
        else:
            init_fn_(module.weight)
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)
        if init_div_is_residual is not False and getattr(module, '_is_residual', False):
            with torch.no_grad():
                module.weight.div_(div_is_residual)
    elif isinstance(module, nn.Embedding):
        if emb_init_std is not None:
            std = emb_init_std
            if std == 0:
                warnings.warn('Embedding layer initialized to 0.')
            emb_init_fn_ = partial(torch.nn.init.normal_, mean=0.0, std=std)
            if verbose > 1:
                warnings.warn(f'Embedding layer initialized using normal distribution with mean=0 and std={std!r}.')
        elif emb_init_uniform_lim is not None:
            lim = emb_init_uniform_lim
            if isinstance(lim, Sequence):
                # Exactly two limits are required; shorter sequences used to
                # slip through and fail later with an IndexError.
                if len(lim) != 2:
                    raise ValueError(f'Uniform init requires a min and a max limit. User input: {lim}.')
                if lim[0] == lim[1]:
                    warnings.warn(f'Embedding layer initialized to {lim[0]}.')
            else:
                if lim == 0:
                    warnings.warn('Embedding layer initialized to 0.')
                lim = [-lim, lim]
            (a, b) = lim
            emb_init_fn_ = partial(torch.nn.init.uniform_, a=a, b=b)
            if verbose > 1:
                warnings.warn(f'Embedding layer initialized using uniform distribution in range {lim}.')
        else:
            emb_init_fn_ = init_fn_
        emb_init_fn_(module.weight)
    elif isinstance(module, tuple(set(NORM_CLASS_REGISTRY.values()))):
        if verbose > 1:
            warnings.warn('Norm weights are set to 1. If norm layer has a bias it is initialized to 0.')
        if hasattr(module, 'weight') and module.weight is not None:
            torch.nn.init.ones_(module.weight)
        if hasattr(module, 'bias') and module.bias is not None:
            torch.nn.init.zeros_(module.bias)
    elif isinstance(module, nn.MultiheadAttention):
        if module._qkv_same_embed_dim:
            assert module.in_proj_weight is not None
            assert module.q_proj_weight is None and module.k_proj_weight is None and (module.v_proj_weight is None)
            assert d_model is not None
            # The fused in-projection stacks three d_model-row chunks; each is
            # initialised independently.
            _d = d_model
            splits = (0, _d, 2 * _d, 3 * _d)
            for (s, e) in zip(splits[:-1], splits[1:]):
                init_fn_(module.in_proj_weight[s:e])
        else:
            assert module.q_proj_weight is not None and module.k_proj_weight is not None and (module.v_proj_weight is not None)
            assert module.in_proj_weight is None
            init_fn_(module.q_proj_weight)
            init_fn_(module.k_proj_weight)
            init_fn_(module.v_proj_weight)
        if module.in_proj_bias is not None:
            torch.nn.init.zeros_(module.in_proj_bias)
        if module.bias_k is not None:
            torch.nn.init.zeros_(module.bias_k)
        if module.bias_v is not None:
            torch.nn.init.zeros_(module.bias_v)
        init_fn_(module.out_proj.weight)
        if init_div_is_residual is not False and getattr(module.out_proj, '_is_residual', False):
            with torch.no_grad():
                module.out_proj.weight.div_(div_is_residual)
        if module.out_proj.bias is not None:
            torch.nn.init.zeros_(module.out_proj.bias)
    else:
        # Any module that directly owns a parameter must be handled above.
        for _ in module.parameters(recurse=False):
            raise NotImplementedError(f'{module.__class__.__name__} parameters are not initialized by param_init_fn.')


def _normal_init_(std, mean=0.0):
    """Return ``torch.nn.init.normal_`` bound to the given ``mean`` and ``std``."""
    return partial(torch.nn.init.normal_, mean=mean, std=std)


def _normal_param_init_fn_(module: nn.Module, std: float, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
    """Run ``generic_param_init_fn_`` with a zero-mean normal of the given ``std``."""
    del kwargs
    init_fn_ = _normal_init_(std=std)
    if verbose > 1:
        warnings.warn(f'Using torch.nn.init.normal_ init fn mean=0.0, std={std}')
    generic_param_init_fn_(module=module, init_fn_=init_fn_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)


def baseline_param_init_fn_(module: nn.Module, init_std: float, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
    """Normal initialisation with a user-supplied ``init_std``.

    Raises:
        ValueError: if ``init_std`` is ``None``.
    """
    del kwargs
    if init_std is None:
        raise ValueError("You must set model.init_config['init_std'] to a float value to use the default initialization scheme.")
    _normal_param_init_fn_(module=module, std=init_std, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)


def small_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
    """Normal initialisation with ``std = sqrt(2 / (5 * d_model))``."""
    del kwargs
    std = math.sqrt(2 / (5 * d_model))
    _normal_param_init_fn_(module=module, std=std, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)


def neox_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
    """From section 2.3.1 of GPT-NeoX-20B:

    An Open-Source Autoregressive Language Model — Black et. al. (2022)
    see https://github.com/EleutherAI/gpt-neox/blob/9610391ab319403cef079b438edd016a2443af54/megatron/model/init_functions.py#L151
    and https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/transformer.py

    Uses the small-init scheme with a residual divisor of ``n_layers / sqrt(10)``.
    """
    del kwargs
    residual_div = n_layers / math.sqrt(10)
    if verbose > 1:
        warnings.warn(f'setting init_div_is_residual to {residual_div}')
    small_param_init_fn_(module=module, d_model=d_model, n_layers=n_layers, init_div_is_residual=residual_div, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)


def kaiming_uniform_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, fan_mode: str='fan_in', init_nonlinearity: str='leaky_relu', verbose: int=0, **kwargs):
    """Run ``generic_param_init_fn_`` with ``nn.init.kaiming_uniform_``."""
    del kwargs
    if verbose > 1:
        warnings.warn('Using nn.init.kaiming_uniform_ init fn with parameters: ' + f'a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}')
    kaiming_uniform_ = partial(nn.init.kaiming_uniform_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity)
    generic_param_init_fn_(module=module, init_fn_=kaiming_uniform_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)


def kaiming_normal_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, fan_mode: str='fan_in', init_nonlinearity: str='leaky_relu', verbose: int=0, **kwargs):
    """Run ``generic_param_init_fn_`` with ``nn.init.kaiming_normal_``."""
    del kwargs
    if verbose > 1:
        warnings.warn('Using nn.init.kaiming_normal_ init fn with parameters: ' + f'a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}')
    kaiming_normal_ = partial(torch.nn.init.kaiming_normal_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity)
    generic_param_init_fn_(module=module, init_fn_=kaiming_normal_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)


def xavier_uniform_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, verbose: int=0, **kwargs):
    """Run ``generic_param_init_fn_`` with ``nn.init.xavier_uniform_``."""
    del kwargs
    xavier_uniform_ = partial(torch.nn.init.xavier_uniform_, gain=init_gain)
    if verbose > 1:
        warnings.warn('Using torch.nn.init.xavier_uniform_ init fn with parameters: ' + f'gain={init_gain}')
    generic_param_init_fn_(module=module, init_fn_=xavier_uniform_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)


def xavier_normal_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, verbose: int=0, **kwargs):
    """Run ``generic_param_init_fn_`` with ``nn.init.xavier_normal_``."""
    # Discard unused extras, matching the sibling init functions.
    del kwargs
    xavier_normal_ = partial(torch.nn.init.xavier_normal_, gain=init_gain)
    if verbose > 1:
        warnings.warn('Using torch.nn.init.xavier_normal_ init fn with parameters: ' + f'gain={init_gain}')
    generic_param_init_fn_(module=module, init_fn_=xavier_normal_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)


# Maps the `name` entry of the model's init config to its init function.
MODEL_INIT_REGISTRY = {
    'default_': torch_default_param_init_fn_,
    'baseline_': baseline_param_init_fn_,
    'kaiming_uniform_': kaiming_uniform_param_init_fn_,
    'kaiming_normal_': kaiming_normal_param_init_fn_,
    'neox_init_': neox_param_init_fn_,
    'small_init_': small_param_init_fn_,
    'xavier_uniform_': xavier_uniform_param_init_fn_,
    'xavier_normal_': xavier_normal_param_init_fn_,
}
class LlavaMetaModel:
    """Mixin that gives a language-model backbone a vision tower and a linear multimodal projector."""

    def __init__(self, config):
        """Forward ``config`` to the next class in the MRO, then build the vision parts if configured.

        The tower and projector are only created when ``config`` already has
        an ``mm_vision_tower`` attribute.
        """
        super(LlavaMetaModel, self).__init__(config)

        if not hasattr(config, "mm_vision_tower"):
            return
        self.vision_tower = build_vision_tower(config, delay_load=True)
        self.mm_projector = nn.Linear(config.mm_hidden_size, config.hidden_size)

    def get_vision_tower(self):
        """Return the vision tower, unwrapping it when it is stored in a list; ``None`` if absent."""
        tower = getattr(self, 'vision_tower', None)
        return tower[0] if type(tower) is list else tower

    def initialize_vision_modules(self, model_args, fsdp=None):
        """Build the vision tower from ``model_args`` and record its settings on ``self.config``.

        Creates ``mm_projector`` when missing and, when
        ``model_args.pretrain_mm_mlp_adapter`` is given, loads the projector
        weights from that checkpoint.
        """
        tower_name = model_args.vision_tower
        select_layer = model_args.mm_vision_select_layer
        select_feature = model_args.mm_vision_select_feature
        adapter_path = model_args.pretrain_mm_mlp_adapter

        self.config.mm_vision_tower = tower_name

        tower = build_vision_tower(model_args)
        # With a non-empty ``fsdp`` setting the tower is stored inside a list
        # (presumably to keep it out of module registration -- confirm).
        keep_in_list = fsdp is not None and len(fsdp) > 0
        self.vision_tower = [tower] if keep_in_list else tower

        self.config.use_mm_proj = True
        self.config.mm_hidden_size = tower.hidden_size
        self.config.mm_vision_select_layer = select_layer
        self.config.mm_vision_select_feature = select_feature

        if not hasattr(self, 'mm_projector'):
            self.mm_projector = nn.Linear(self.config.mm_hidden_size, self.config.hidden_size)

        if adapter_path is None:
            return

        # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
        checkpoint = torch.load(adapter_path, map_location='cpu')
        keyword = 'mm_projector'
        # Keep only the projector entries and strip everything up to and
        # including "mm_projector." from their keys.
        projector_state = {
            name.split(keyword + '.')[1]: tensor
            for name, tensor in checkpoint.items()
            if keyword in name
        }
        self.mm_projector.load_state_dict(projector_state)
class LlavaMetaForCausalLM(ABC):
    """Mixin for causal LMs that swaps image placeholder tokens for projected vision features.

    Besides ``get_model`` the methods rely on attributes supplied by the
    concrete class: ``config``, ``device``, ``get_input_embeddings``,
    ``get_output_embeddings`` and ``resize_token_embeddings``.
    """

    @abstractmethod
    def get_model(self):
        """Return the base model exposing ``embed_tokens``, ``mm_projector`` and ``get_vision_tower``."""
        pass

    def get_vision_tower(self):
        """Return the vision tower held by the base model."""
        return self.get_model().get_vision_tower()

    def encode_images(self, images):
        """Run ``images`` through the vision tower and then the multimodal projector."""
        image_features = self.get_model().get_vision_tower()(images)
        image_features = self.get_model().mm_projector(image_features)
        return image_features

    def prepare_inputs_labels_for_multimodal(
        self, input_ids, attention_mask, past_key_values, labels, images
    ):
        """Build input embeddings in which each ``IMAGE_TOKEN_INDEX`` token is replaced by image features.

        Returns the 5-tuple ``(input_ids, attention_mask, past_key_values,
        inputs_embeds, labels)``. When there is no vision tower, no images, or
        only one token per sample, the ids are passed through and
        ``inputs_embeds`` is ``None``. Otherwise ``input_ids`` is ``None`` and
        embeddings, labels and attention mask are expanded to cover the
        inserted image features. Samples of different length are right-padded
        (zeros for embeddings, ``IGNORE_INDEX`` for labels, ``False`` for the
        mask).
        """
        vision_tower = self.get_vision_tower()
        if vision_tower is None or images is None or input_ids.shape[1] == 1:
            if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[1] == 1:
                # Single-token step with cached keys: the mask must span the
                # cached positions plus the new token.
                attention_mask = torch.ones((attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1), dtype=attention_mask.dtype, device=attention_mask.device)
            return input_ids, attention_mask, past_key_values, None, labels

        if type(images) is list or images.ndim == 5:
            # Several images per sample: encode them in one pass, then split
            # back per sample and merge each sample's images into one sequence.
            concat_images = torch.cat([image for image in images], dim=0)
            image_features = self.encode_images(concat_images)
            split_sizes = [image.shape[0] for image in images]
            image_features = torch.split(image_features, split_sizes, dim=0)
            image_features = [x.flatten(0, 1) for x in image_features]
        else:
            image_features = self.encode_images(images)

        new_input_embeds = []
        new_labels = [] if labels is not None else None
        cur_image_idx = 0

        orig_embeds_params = getattr(self, 'orig_embeds_params', None)
        if orig_embeds_params is not None:
            orig_embeds_params_in = orig_embeds_params[0]
            orig_embeds_params_out = orig_embeds_params[1]
            # Reset every embedding row except the last three to the snapshot
            # (presumably the last three are the grounding tokens added last
            # by initialize_vision_tokenizer -- confirm).
            with torch.no_grad():
                self.get_input_embeddings().weight[:-3] = orig_embeds_params_in[:-3].data
                self.get_output_embeddings().weight[:-3] = orig_embeds_params_out[:-3].data

        # Loop-invariant: whether the image token is wrapped by start/end
        # tokens that are handled separately while tuning the adapter.
        wrap_im_tokens = getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False)

        for batch_idx, cur_input_ids in enumerate(input_ids):
            if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0:
                # multimodal LLM, but the current sample is not multimodal
                cur_input_embeds = self.get_model().embed_tokens(cur_input_ids)
                # Zero-valued term that still routes the projector into the graph.
                cur_input_embeds = cur_input_embeds + (0. * self.get_model().mm_projector(vision_tower.dummy_feature)).sum()
                new_input_embeds.append(cur_input_embeds)
                if labels is not None:
                    new_labels.append(labels[batch_idx])
                cur_image_idx += 1
                continue

            image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
            cur_new_input_embeds = []
            if labels is not None:
                cur_labels = labels[batch_idx]
                cur_new_labels = []
                assert cur_labels.shape == cur_input_ids.shape

            # Consume the sequence one image token at a time.
            while image_token_indices.numel() > 0:
                cur_image_features = image_features[cur_image_idx]
                image_token_start = image_token_indices[0]
                if wrap_im_tokens:
                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start-1]).detach())
                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_start-1:image_token_start]))
                    cur_new_input_embeds.append(cur_image_features)
                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_start+1:image_token_start+2]))
                    if labels is not None:
                        cur_new_labels.append(cur_labels[:image_token_start])
                        cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
                        cur_new_labels.append(cur_labels[image_token_start:image_token_start+1])
                        cur_labels = cur_labels[image_token_start+2:]
                else:
                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start]))
                    cur_new_input_embeds.append(cur_image_features)
                    if labels is not None:
                        cur_new_labels.append(cur_labels[:image_token_start])
                        cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
                        cur_labels = cur_labels[image_token_start+1:]
                cur_image_idx += 1
                if wrap_im_tokens:
                    cur_input_ids = cur_input_ids[image_token_start+2:]
                else:
                    cur_input_ids = cur_input_ids[image_token_start+1:]
                image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]

            # Text remaining after the last image token.
            if cur_input_ids.numel() > 0:
                if wrap_im_tokens:
                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids).detach())
                else:
                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids))
                if labels is not None:
                    cur_new_labels.append(cur_labels)
            cur_new_input_embeds = [x.to(device=self.device) for x in cur_new_input_embeds]
            cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0)
            new_input_embeds.append(cur_new_input_embeds)
            if labels is not None:
                cur_new_labels = torch.cat(cur_new_labels, dim=0)
                new_labels.append(cur_new_labels)

        if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds):
            # Ragged batch: right-pad every sample to the longest one.
            seq_lens = [x.shape[0] for x in new_input_embeds]
            max_len = max(seq_lens)

            new_input_embeds = torch.stack([
                torch.cat((x, torch.zeros((max_len - x.shape[0], x.shape[1]), dtype=x.dtype, device=x.device)), dim=0)
                for x in new_input_embeds
            ], dim=0)

            if labels is not None:
                new_labels = torch.stack([
                    torch.cat((y, torch.full((max_len - y.shape[0],), IGNORE_INDEX, dtype=y.dtype, device=y.device)), dim=0)
                    for y in new_labels
                ], dim=0)

            if attention_mask is not None:
                # Lengths come from the embeddings, not the labels, so this
                # also works when labels is None (e.g. during generation).
                new_attention_mask = []
                for cur_attention_mask, cur_len in zip(attention_mask, seq_lens):
                    new_attn_mask_pad_left = torch.full((cur_len - input_ids.shape[1],), True, dtype=attention_mask.dtype, device=attention_mask.device)
                    new_attn_mask_pad_right = torch.full((max_len - cur_len,), False, dtype=attention_mask.dtype, device=attention_mask.device)
                    new_attention_mask.append(torch.cat((new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right), dim=0))
                attention_mask = torch.stack(new_attention_mask, dim=0)
                assert attention_mask.shape == new_input_embeds.shape[:2]
        else:
            new_input_embeds = torch.stack(new_input_embeds, dim=0)
            if labels is not None:
                new_labels = torch.stack(new_labels, dim=0)

            if attention_mask is not None:
                # All samples grew by the same amount: left-pad the mask with True.
                new_attn_mask_pad_left = torch.full((attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1]), True, dtype=attention_mask.dtype, device=attention_mask.device)
                attention_mask = torch.cat((new_attn_mask_pad_left, attention_mask), dim=1)
                assert attention_mask.shape == new_input_embeds.shape[:2]

        return None, attention_mask, past_key_values, new_input_embeds, new_labels

    def initialize_vision_tokenizer(self, model_args, tokenizer):
        """Register the special tokens selected by ``model_args`` and initialise their embeddings.

        New embedding rows are set to the mean of the pre-existing rows. When
        ``tune_mm_mlp_adapter`` is set, a CUDA copy of both embedding matrices
        is stored in ``self.orig_embeds_params`` and both matrices are made
        trainable (in the patch-token-only branch they are frozen instead).
        """
        if model_args.mm_use_im_patch_token:
            tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))

        if model_args.mm_use_im_start_end:
            num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, grounding_start, grounding_end, SEG_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))

            if num_new_tokens > 0:
                input_embeddings = self.get_input_embeddings().weight.data
                output_embeddings = self.get_output_embeddings().weight.data

                input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
                    dim=0, keepdim=True)
                output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
                    dim=0, keepdim=True)

                input_embeddings[-num_new_tokens:] = input_embeddings_avg
                output_embeddings[-num_new_tokens:] = output_embeddings_avg

            if model_args.tune_mm_mlp_adapter:
                # Snapshot read back by prepare_inputs_labels_for_multimodal.
                self.orig_embeds_params = [self.get_input_embeddings().weight.data.clone().cuda(),
                                           self.get_output_embeddings().weight.data.clone().cuda()]
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = True
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = True

            if model_args.pretrain_mm_mlp_adapter:
                # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
                mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
                embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
                # NOTE(review): five tokens are requested above, yet exactly two
                # new tokens are required here -- confirm the intended count.
                assert num_new_tokens == 2
                if input_embeddings.shape == embed_tokens_weight.shape:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
                elif embed_tokens_weight.shape[0] == num_new_tokens:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight
                else:
                    raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
        elif model_args.mm_use_im_patch_token:
            if model_args.tune_mm_mlp_adapter:
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = False
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = False
        else:
            # Only the grounding tokens are added in this configuration.
            num_new_tokens = tokenizer.add_tokens([grounding_start, grounding_end, SEG_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))

            if num_new_tokens > 0:
                input_embeddings = self.get_input_embeddings().weight.data
                output_embeddings = self.get_output_embeddings().weight.data

                input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
                    dim=0, keepdim=True)
                output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
                    dim=0, keepdim=True)

                input_embeddings[-num_new_tokens:] = input_embeddings_avg
                output_embeddings[-num_new_tokens:] = output_embeddings_avg

            if model_args.tune_mm_mlp_adapter:
                # Snapshot read back by prepare_inputs_labels_for_multimodal.
                self.orig_embeds_params = [self.get_input_embeddings().weight.data.clone().cuda(),
                                           self.get_output_embeddings().weight.data.clone().cuda()]
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = True
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = True
class LlavaMetaForCausalLM_gd(ABC):
    """Grounding variant of the multimodal causal-LM mixin; additionally owns a segmentation model.

    Besides ``get_model`` the methods rely on attributes supplied by the
    concrete class: ``config``, ``device``, ``get_input_embeddings``,
    ``get_output_embeddings`` and ``resize_token_embeddings``.
    """

    @abstractmethod
    def get_model(self):
        """Return the base model exposing ``embed_tokens``, ``mm_projector`` and ``get_vision_tower``."""
        pass

    def get_vision_tower(self):
        """Return the vision tower held by the base model."""
        return self.get_model().get_vision_tower()

    def encode_images(self, images):
        """Run ``images`` through the vision tower, cast to the projector's weight dtype, then project."""
        image_features = self.get_model().get_vision_tower()(images)
        image_features = self.get_model().mm_projector(image_features.to(self.get_model().mm_projector.state_dict()["weight"].dtype))
        return image_features

    def prepare_inputs_labels_for_multimodal(
        self, input_ids, attention_mask, past_key_values, labels, images
    ):
        """Build input embeddings in which each ``IMAGE_TOKEN_INDEX`` token is replaced by image features.

        Returns the 5-tuple ``(input_ids, attention_mask, past_key_values,
        inputs_embeds, labels)``. When there is no vision tower, no images, or
        only one token per sample, the ids are passed through and
        ``inputs_embeds`` is ``None``. Otherwise ``input_ids`` is ``None`` and
        embeddings, labels and attention mask are expanded to cover the
        inserted image features. Samples of different length are right-padded
        (zeros for embeddings, ``IGNORE_INDEX`` for labels, ``False`` for the
        mask).
        """
        vision_tower = self.get_vision_tower()
        if vision_tower is None or images is None or input_ids.shape[1] == 1:
            if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[1] == 1:
                # Single-token step with cached keys: the mask must span the
                # cached positions plus the new token.
                attention_mask = torch.ones((attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1), dtype=attention_mask.dtype, device=attention_mask.device)
            return input_ids, attention_mask, past_key_values, None, labels

        if type(images) is list or images.ndim == 5:
            # Several images per sample: encode them in one pass, then split
            # back per sample and merge each sample's images into one sequence.
            concat_images = torch.cat([image for image in images], dim=0)
            image_features = self.encode_images(concat_images)
            split_sizes = [image.shape[0] for image in images]
            image_features = torch.split(image_features, split_sizes, dim=0)
            image_features = [x.flatten(0, 1) for x in image_features]
        else:
            image_features = self.encode_images(images)

        new_input_embeds = []
        new_labels = [] if labels is not None else None
        cur_image_idx = 0

        orig_embeds_params = getattr(self, 'orig_embeds_params', None)
        if orig_embeds_params is not None:
            orig_embeds_params_in = orig_embeds_params[0]
            orig_embeds_params_out = orig_embeds_params[1]
            # Reset every embedding row except the last three to the snapshot
            # (presumably the last three are the grounding tokens added last
            # by initialize_vision_tokenizer -- confirm).
            with torch.no_grad():
                self.get_input_embeddings().weight[:-3] = orig_embeds_params_in[:-3].data
                self.get_output_embeddings().weight[:-3] = orig_embeds_params_out[:-3].data

        # Loop-invariant: whether the image token is wrapped by start/end
        # tokens that are handled separately while tuning the adapter.
        wrap_im_tokens = getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False)

        for batch_idx, cur_input_ids in enumerate(input_ids):
            if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0:
                # multimodal LLM, but the current sample is not multimodal
                cur_input_embeds = self.get_model().embed_tokens(cur_input_ids)
                # Zero-valued term that still routes the projector into the graph.
                cur_input_embeds = cur_input_embeds + (0. * self.get_model().mm_projector(vision_tower.dummy_feature)).sum()
                new_input_embeds.append(cur_input_embeds)
                if labels is not None:
                    new_labels.append(labels[batch_idx])
                cur_image_idx += 1
                continue

            image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
            cur_new_input_embeds = []
            if labels is not None:
                cur_labels = labels[batch_idx]
                cur_new_labels = []
                assert cur_labels.shape == cur_input_ids.shape

            # Consume the sequence one image token at a time.
            while image_token_indices.numel() > 0:
                cur_image_features = image_features[cur_image_idx]
                image_token_start = image_token_indices[0]
                if wrap_im_tokens:
                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start-1]))
                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_start-1:image_token_start]))
                    cur_new_input_embeds.append(cur_image_features)
                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_start+1:image_token_start+2]))
                    if labels is not None:
                        cur_new_labels.append(cur_labels[:image_token_start])
                        cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
                        cur_new_labels.append(cur_labels[image_token_start:image_token_start+1])
                        cur_labels = cur_labels[image_token_start+2:]
                else:
                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start]))
                    cur_new_input_embeds.append(cur_image_features)
                    if labels is not None:
                        cur_new_labels.append(cur_labels[:image_token_start])
                        cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
                        cur_labels = cur_labels[image_token_start+1:]
                cur_image_idx += 1
                if wrap_im_tokens:
                    cur_input_ids = cur_input_ids[image_token_start+2:]
                else:
                    cur_input_ids = cur_input_ids[image_token_start+1:]
                image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]

            # Text remaining after the last image token (embedded the same way
            # regardless of the adapter-tuning flags).
            if cur_input_ids.numel() > 0:
                cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids))
                if labels is not None:
                    cur_new_labels.append(cur_labels)
            cur_new_input_embeds = [x.to(device=self.device) for x in cur_new_input_embeds]
            cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0)
            new_input_embeds.append(cur_new_input_embeds)
            if labels is not None:
                cur_new_labels = torch.cat(cur_new_labels, dim=0)
                new_labels.append(cur_new_labels)

        if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds):
            # Ragged batch: right-pad every sample to the longest one.
            seq_lens = [x.shape[0] for x in new_input_embeds]
            max_len = max(seq_lens)

            new_input_embeds = torch.stack([
                torch.cat((x, torch.zeros((max_len - x.shape[0], x.shape[1]), dtype=x.dtype, device=x.device)), dim=0)
                for x in new_input_embeds
            ], dim=0)

            if labels is not None:
                new_labels = torch.stack([
                    torch.cat((y, torch.full((max_len - y.shape[0],), IGNORE_INDEX, dtype=y.dtype, device=y.device)), dim=0)
                    for y in new_labels
                ], dim=0)

            if attention_mask is not None:
                # Lengths come from the embeddings, not the labels, so this
                # also works when labels is None (e.g. during generation).
                new_attention_mask = []
                for cur_attention_mask, cur_len in zip(attention_mask, seq_lens):
                    new_attn_mask_pad_left = torch.full((cur_len - input_ids.shape[1],), True, dtype=attention_mask.dtype, device=attention_mask.device)
                    new_attn_mask_pad_right = torch.full((max_len - cur_len,), False, dtype=attention_mask.dtype, device=attention_mask.device)
                    new_attention_mask.append(torch.cat((new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right), dim=0))
                attention_mask = torch.stack(new_attention_mask, dim=0)
                assert attention_mask.shape == new_input_embeds.shape[:2]
        else:
            new_input_embeds = torch.stack(new_input_embeds, dim=0)
            if labels is not None:
                new_labels = torch.stack(new_labels, dim=0)

            if attention_mask is not None:
                # All samples grew by the same amount: left-pad the mask with True.
                new_attn_mask_pad_left = torch.full((attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1]), True, dtype=attention_mask.dtype, device=attention_mask.device)
                attention_mask = torch.cat((new_attn_mask_pad_left, attention_mask), dim=1)
                assert attention_mask.shape == new_input_embeds.shape[:2]

        return None, attention_mask, past_key_values, new_input_embeds, new_labels

    def initialize_vision_tokenizer(self, model_args, tokenizer):
        """Register the special tokens selected by ``model_args`` and initialise their embeddings.

        New embedding rows are set to the mean of the pre-existing rows. When
        ``tune_mm_mlp_adapter`` is set, a CUDA copy of both embedding matrices
        is stored in ``self.orig_embeds_params`` and both matrices are made
        trainable (in the patch-token-only branch they are frozen instead).
        """
        if model_args.mm_use_im_patch_token:
            tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))

        if model_args.mm_use_im_start_end:
            num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, grounding_start, grounding_end, SEG_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))

            if num_new_tokens > 0:
                input_embeddings = self.get_input_embeddings().weight.data
                output_embeddings = self.get_output_embeddings().weight.data

                input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
                    dim=0, keepdim=True)
                output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
                    dim=0, keepdim=True)

                input_embeddings[-num_new_tokens:] = input_embeddings_avg
                output_embeddings[-num_new_tokens:] = output_embeddings_avg

            if model_args.tune_mm_mlp_adapter:
                # Snapshot read back by prepare_inputs_labels_for_multimodal.
                self.orig_embeds_params = [self.get_input_embeddings().weight.data.clone().cuda(),
                                           self.get_output_embeddings().weight.data.clone().cuda()]
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = True
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = True

            if model_args.pretrain_mm_mlp_adapter:
                # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
                mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
                embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
                # NOTE(review): five tokens are requested above, yet exactly two
                # new tokens are required here -- confirm the intended count.
                assert num_new_tokens == 2
                if input_embeddings.shape == embed_tokens_weight.shape:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
                elif embed_tokens_weight.shape[0] == num_new_tokens:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight
                else:
                    raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
        elif model_args.mm_use_im_patch_token:
            if model_args.tune_mm_mlp_adapter:
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = False
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = False
        else:
            # Only the grounding tokens are added in this configuration.
            num_new_tokens = tokenizer.add_tokens([grounding_start, grounding_end, SEG_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))

            if num_new_tokens > 0:
                input_embeddings = self.get_input_embeddings().weight.data
                output_embeddings = self.get_output_embeddings().weight.data

                input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
                    dim=0, keepdim=True)
                output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
                    dim=0, keepdim=True)

                input_embeddings[-num_new_tokens:] = input_embeddings_avg
                output_embeddings[-num_new_tokens:] = output_embeddings_avg

            if model_args.tune_mm_mlp_adapter:
                # Snapshot read back by prepare_inputs_labels_for_multimodal.
                self.orig_embeds_params = [self.get_input_embeddings().weight.data.clone().cuda(),
                                           self.get_output_embeddings().weight.data.clone().cuda()]
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = True
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = True

    def initialize_seg_modules(self, cfg):
        """Build the segmentation model from ``cfg``, load ``cfg.MODEL.WEIGHTS`` and store it as ``self.seg_model``."""
        seg_model = BaseModel(cfg, build_model(cfg))
        seg_model = seg_model.from_pretrained(cfg.MODEL.WEIGHTS)
        self.seg_model = seg_model

    def freeze_seg_modules(self):
        """Disable gradients for every parameter of ``self.seg_model``."""
        for p in self.seg_model.parameters():
            p.requires_grad = False
def prepare_inputs_labels_for_multimodal(
    self, input_ids, attention_mask, past_key_values, labels, images,obj_feats=None,num_it=0
):
    """Build input embeddings with image features spliced in and, for interactive samples, object features.

    Each ``IMAGE_TOKEN_INDEX`` token is replaced by the corresponding image
    features. For the last ``num_it`` samples of the batch, positions after
    the final image token whose id is 1273 have their embedding overwritten
    with ``obj_feats[k]`` (``k`` counts within those last ``num_it`` samples).
    Labels at id 1273 are set to ``IGNORE_INDEX`` for every sample containing
    an image.

    Returns the 5-tuple ``(input_ids, attention_mask, past_key_values,
    inputs_embeds, labels)``. When there is no vision tower, no images, or
    only one token per sample, the ids are passed through and
    ``inputs_embeds`` is ``None``. Otherwise ``input_ids`` is ``None``, and
    samples of different length are right-padded (zeros for embeddings,
    ``IGNORE_INDEX`` for labels, ``False`` for the mask).
    """
    # NOTE(review): 1273 is hard-coded; presumably the tokenizer id of the
    # object placeholder token -- confirm against the tokenizer in use.
    obj_token_id = 1273

    vision_tower = self.get_vision_tower()
    if vision_tower is None or images is None or input_ids.shape[1] == 1:
        if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[1] == 1:
            # Single-token step with cached keys: the mask must span the
            # cached positions plus the new token.
            attention_mask = torch.ones((attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1), dtype=attention_mask.dtype, device=attention_mask.device)
        return input_ids, attention_mask, past_key_values, None, labels

    if type(images) is list or images.ndim == 5:
        # Several images per sample: encode them in one pass, then split
        # back per sample and merge each sample's images into one sequence.
        concat_images = torch.cat([image for image in images], dim=0)
        image_features = self.encode_images(concat_images)
        split_sizes = [image.shape[0] for image in images]
        image_features = torch.split(image_features, split_sizes, dim=0)
        image_features = [x.flatten(0, 1) for x in image_features]
    else:
        image_features = self.encode_images(images)

    new_input_embeds = []
    new_labels = [] if labels is not None else None
    cur_image_idx = 0

    orig_embeds_params = getattr(self, 'orig_embeds_params', None)
    if orig_embeds_params is not None:
        orig_embeds_params_in = orig_embeds_params[0]
        orig_embeds_params_out = orig_embeds_params[1]
        # Reset every embedding row except the last three to the snapshot
        # (presumably the last three are the grounding tokens added last by
        # initialize_vision_tokenizer -- confirm).
        with torch.no_grad():
            self.get_input_embeddings().weight[:-3] = orig_embeds_params_in[:-3].data
            self.get_output_embeddings().weight[:-3] = orig_embeds_params_out[:-3].data

    # Loop-invariant: whether the image token is wrapped by start/end tokens
    # that are handled separately while tuning the adapter.
    wrap_im_tokens = getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False)
    # Index of the first sample that receives object features.
    first_inter_idx = len(input_ids) - num_it

    for batch_idx, cur_input_ids in enumerate(input_ids):
        if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0:
            # multimodal LLM, but the current sample is not multimodal
            cur_input_embeds = self.get_model().embed_tokens(cur_input_ids)
            # Zero-valued term that still routes the projector into the graph.
            cur_input_embeds = cur_input_embeds + (0. * self.get_model().mm_projector(vision_tower.dummy_feature)).sum()
            new_input_embeds.append(cur_input_embeds)
            if labels is not None:
                new_labels.append(labels[batch_idx])
            cur_image_idx += 1
            continue

        image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
        cur_new_input_embeds = []
        if labels is not None:
            cur_labels = labels[batch_idx]
            cur_new_labels = []
            assert cur_labels.shape == cur_input_ids.shape

        # Consume the sequence one image token at a time.
        while image_token_indices.numel() > 0:
            cur_image_features = image_features[cur_image_idx]
            image_token_start = image_token_indices[0]
            if wrap_im_tokens:
                cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start-1]))
                cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_start-1:image_token_start]))
                cur_new_input_embeds.append(cur_image_features)
                cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_start+1:image_token_start+2]))
                if labels is not None:
                    cur_new_labels.append(cur_labels[:image_token_start])
                    cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
                    cur_new_labels.append(cur_labels[image_token_start:image_token_start+1])
                    cur_labels = cur_labels[image_token_start+2:]
            else:
                cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start]))
                cur_new_input_embeds.append(cur_image_features)
                if labels is not None:
                    cur_new_labels.append(cur_labels[:image_token_start])
                    cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
                    cur_labels = cur_labels[image_token_start+1:]
            cur_image_idx += 1
            if wrap_im_tokens:
                cur_input_ids = cur_input_ids[image_token_start+2:]
            else:
                cur_input_ids = cur_input_ids[image_token_start+1:]
            image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]

        # Text remaining after the last image token (embedded the same way
        # regardless of the adapter-tuning flags).
        if cur_input_ids.numel() > 0:
            cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids))
            if batch_idx >= first_inter_idx:
                # Interactive sample: overwrite the placeholder embeddings
                # with this sample's object features.
                obj_idx = cur_input_ids == obj_token_id
                idx_in_inter = batch_idx - first_inter_idx
                cur_new_input_embeds[-1][obj_idx] = obj_feats[idx_in_inter].to(cur_new_input_embeds[-1].dtype)
            if labels is not None:
                # NOTE(review): cur_labels is a slice of ``labels``, so this
                # assignment also modifies the caller's tensor in place.
                cur_labels[cur_labels == obj_token_id] = IGNORE_INDEX
                cur_new_labels.append(cur_labels)
        cur_new_input_embeds = [x.to(device=self.device) for x in cur_new_input_embeds]
        cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0)
        new_input_embeds.append(cur_new_input_embeds)
        if labels is not None:
            cur_new_labels = torch.cat(cur_new_labels, dim=0)
            new_labels.append(cur_new_labels)

    if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds):
        # Ragged batch: right-pad every sample to the longest one.
        seq_lens = [x.shape[0] for x in new_input_embeds]
        max_len = max(seq_lens)

        new_input_embeds = torch.stack([
            torch.cat((x, torch.zeros((max_len - x.shape[0], x.shape[1]), dtype=x.dtype, device=x.device)), dim=0)
            for x in new_input_embeds
        ], dim=0)

        if labels is not None:
            new_labels = torch.stack([
                torch.cat((y, torch.full((max_len - y.shape[0],), IGNORE_INDEX, dtype=y.dtype, device=y.device)), dim=0)
                for y in new_labels
            ], dim=0)

        if attention_mask is not None:
            # Lengths come from the embeddings, not the labels, so this also
            # works when labels is None (e.g. during generation).
            new_attention_mask = []
            for cur_attention_mask, cur_len in zip(attention_mask, seq_lens):
                new_attn_mask_pad_left = torch.full((cur_len - input_ids.shape[1],), True, dtype=attention_mask.dtype, device=attention_mask.device)
                new_attn_mask_pad_right = torch.full((max_len - cur_len,), False, dtype=attention_mask.dtype, device=attention_mask.device)
                new_attention_mask.append(torch.cat((new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right), dim=0))
            attention_mask = torch.stack(new_attention_mask, dim=0)
            assert attention_mask.shape == new_input_embeds.shape[:2]
    else:
        new_input_embeds = torch.stack(new_input_embeds, dim=0)
        if labels is not None:
            new_labels = torch.stack(new_labels, dim=0)

        if attention_mask is not None:
            # All samples grew by the same amount: left-pad the mask with True.
            new_attn_mask_pad_left = torch.full((attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1]), True, dtype=attention_mask.dtype, device=attention_mask.device)
            attention_mask = torch.cat((new_attn_mask_pad_left, attention_mask), dim=1)
            assert attention_mask.shape == new_input_embeds.shape[:2]

    return None, attention_mask, past_key_values, new_input_embeds, new_labels
def prepare_inputs_labels_for_multimodal_NoInter(
    self, input_ids, attention_mask, past_key_values, labels, images
):
    """Build input embeddings with image features spliced in, without any object-feature substitution.

    Each ``IMAGE_TOKEN_INDEX`` token is replaced by the corresponding image
    features.

    Returns the 5-tuple ``(input_ids, attention_mask, past_key_values,
    inputs_embeds, labels)``. When there is no vision tower, no images, or
    only one token per sample, the ids are passed through and
    ``inputs_embeds`` is ``None``. Otherwise ``input_ids`` is ``None``, and
    samples of different length are right-padded (zeros for embeddings,
    ``IGNORE_INDEX`` for labels, ``False`` for the mask).
    """
    vision_tower = self.get_vision_tower()
    if vision_tower is None or images is None or input_ids.shape[1] == 1:
        if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[1] == 1:
            # Single-token step with cached keys: the mask must span the
            # cached positions plus the new token.
            attention_mask = torch.ones((attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1), dtype=attention_mask.dtype, device=attention_mask.device)
        return input_ids, attention_mask, past_key_values, None, labels

    if type(images) is list or images.ndim == 5:
        # Several images per sample: encode them in one pass, then split
        # back per sample and merge each sample's images into one sequence.
        concat_images = torch.cat([image for image in images], dim=0)
        image_features = self.encode_images(concat_images)
        split_sizes = [image.shape[0] for image in images]
        image_features = torch.split(image_features, split_sizes, dim=0)
        image_features = [x.flatten(0, 1) for x in image_features]
    else:
        image_features = self.encode_images(images)

    new_input_embeds = []
    new_labels = [] if labels is not None else None
    cur_image_idx = 0

    orig_embeds_params = getattr(self, 'orig_embeds_params', None)
    if orig_embeds_params is not None:
        orig_embeds_params_in = orig_embeds_params[0]
        orig_embeds_params_out = orig_embeds_params[1]
        # Reset every embedding row except the last three to the snapshot
        # (presumably the last three are the grounding tokens added last by
        # initialize_vision_tokenizer -- confirm).
        with torch.no_grad():
            self.get_input_embeddings().weight[:-3] = orig_embeds_params_in[:-3].data
            self.get_output_embeddings().weight[:-3] = orig_embeds_params_out[:-3].data

    # Loop-invariant: whether the image token is wrapped by start/end tokens
    # that are handled separately while tuning the adapter.
    wrap_im_tokens = getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False)

    for batch_idx, cur_input_ids in enumerate(input_ids):
        if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0:
            # multimodal LLM, but the current sample is not multimodal
            cur_input_embeds = self.get_model().embed_tokens(cur_input_ids)
            # Zero-valued term that still routes the projector into the graph.
            cur_input_embeds = cur_input_embeds + (0. * self.get_model().mm_projector(vision_tower.dummy_feature)).sum()
            new_input_embeds.append(cur_input_embeds)
            if labels is not None:
                new_labels.append(labels[batch_idx])
            cur_image_idx += 1
            continue

        image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
        cur_new_input_embeds = []
        if labels is not None:
            cur_labels = labels[batch_idx]
            cur_new_labels = []
            assert cur_labels.shape == cur_input_ids.shape

        # Consume the sequence one image token at a time.
        while image_token_indices.numel() > 0:
            cur_image_features = image_features[cur_image_idx]
            image_token_start = image_token_indices[0]
            if wrap_im_tokens:
                cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start-1]))
                cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_start-1:image_token_start]))
                cur_new_input_embeds.append(cur_image_features)
                cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_start+1:image_token_start+2]))
                if labels is not None:
                    cur_new_labels.append(cur_labels[:image_token_start])
                    cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
                    cur_new_labels.append(cur_labels[image_token_start:image_token_start+1])
                    cur_labels = cur_labels[image_token_start+2:]
            else:
                cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start]))
                cur_new_input_embeds.append(cur_image_features)
                if labels is not None:
                    cur_new_labels.append(cur_labels[:image_token_start])
                    cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
                    cur_labels = cur_labels[image_token_start+1:]
            cur_image_idx += 1
            if wrap_im_tokens:
                cur_input_ids = cur_input_ids[image_token_start+2:]
            else:
                cur_input_ids = cur_input_ids[image_token_start+1:]
            image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]

        # Text remaining after the last image token (embedded the same way
        # regardless of the adapter-tuning flags).
        if cur_input_ids.numel() > 0:
            cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids))
            if labels is not None:
                cur_new_labels.append(cur_labels)
        cur_new_input_embeds = [x.to(device=self.device) for x in cur_new_input_embeds]
        cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0)
        new_input_embeds.append(cur_new_input_embeds)
        if labels is not None:
            cur_new_labels = torch.cat(cur_new_labels, dim=0)
            new_labels.append(cur_new_labels)

    if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds):
        # Ragged batch: right-pad every sample to the longest one.
        seq_lens = [x.shape[0] for x in new_input_embeds]
        max_len = max(seq_lens)

        new_input_embeds = torch.stack([
            torch.cat((x, torch.zeros((max_len - x.shape[0], x.shape[1]), dtype=x.dtype, device=x.device)), dim=0)
            for x in new_input_embeds
        ], dim=0)

        if labels is not None:
            new_labels = torch.stack([
                torch.cat((y, torch.full((max_len - y.shape[0],), IGNORE_INDEX, dtype=y.dtype, device=y.device)), dim=0)
                for y in new_labels
            ], dim=0)

        if attention_mask is not None:
            # Lengths come from the embeddings, not the labels, so this also
            # works when labels is None (e.g. during generation).
            new_attention_mask = []
            for cur_attention_mask, cur_len in zip(attention_mask, seq_lens):
                new_attn_mask_pad_left = torch.full((cur_len - input_ids.shape[1],), True, dtype=attention_mask.dtype, device=attention_mask.device)
                new_attn_mask_pad_right = torch.full((max_len - cur_len,), False, dtype=attention_mask.dtype, device=attention_mask.device)
                new_attention_mask.append(torch.cat((new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right), dim=0))
            attention_mask = torch.stack(new_attention_mask, dim=0)
            assert attention_mask.shape == new_input_embeds.shape[:2]
    else:
        new_input_embeds = torch.stack(new_input_embeds, dim=0)
        if labels is not None:
            new_labels = torch.stack(new_labels, dim=0)

        if attention_mask is not None:
            # All samples grew by the same amount: left-pad the mask with True.
            new_attn_mask_pad_left = torch.full((attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1]), True, dtype=attention_mask.dtype, device=attention_mask.device)
            attention_mask = torch.cat((new_attn_mask_pad_left, attention_mask), dim=1)
            assert attention_mask.shape == new_input_embeds.shape[:2]

    return None, attention_mask, past_key_values, new_input_embeds, new_labels
self.resize_token_embeddings(len(tokenizer)) if model_args.mm_use_im_start_end: num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, grounding_start, grounding_end, SEG_TOKEN], special_tokens=True) self.resize_token_embeddings(len(tokenizer)) if num_new_tokens > 0: input_embeddings = self.get_input_embeddings().weight.data output_embeddings = self.get_output_embeddings().weight.data input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( dim=0, keepdim=True) output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( dim=0, keepdim=True) input_embeddings[-num_new_tokens:] = input_embeddings_avg output_embeddings[-num_new_tokens:] = output_embeddings_avg if model_args.tune_mm_mlp_adapter: self.orig_embeds_params = [self.get_input_embeddings().weight.data.clone().cuda(), self.get_output_embeddings().weight.data.clone().cuda()] for p in self.get_input_embeddings().parameters(): p.requires_grad = True for p in self.get_output_embeddings().parameters(): p.requires_grad = True if model_args.pretrain_mm_mlp_adapter: mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu') embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight'] assert num_new_tokens == 2 if input_embeddings.shape == embed_tokens_weight.shape: input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:] elif embed_tokens_weight.shape[0] == num_new_tokens: input_embeddings[-num_new_tokens:] = embed_tokens_weight else: raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. 
def initialize_interactive_modules(self, cfg):
    """Build the SemSAM wrapper model and store it on ``self.interactive_model``.

    Args:
        cfg: configuration passed both to ``build_model`` of the ``semsam``
            package and to its ``BaseModel`` wrapper. ``cfg.MODEL.WEIGHTS`` is
            the checkpoint to load; the literal string ``"None"`` means that no
            checkpoint is loaded.
    """
    # Imported lazily, inside the method, exactly as the module is only needed here.
    from .semsam.BaseModel import BaseModel as SemSamBaseModel
    from .semsam import build_model as build_semsam_model

    interactive = SemSamBaseModel(cfg, build_semsam_model(cfg))

    weights_path = cfg.MODEL.WEIGHTS
    skip_loading = weights_path == "None"
    if not skip_loading:
        interactive = interactive.from_pretrained(weights_path)

    self.interactive_model = interactive
def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id):
    """Write ``target - base`` weights (a "delta" checkpoint) to ``delta_path``.

    The target model's parameters are modified in place: every tensor that
    also exists in the base model has the base values subtracted from it.
    The modified target model and the target tokenizer are then saved.

    Args:
        base_model_path: path or hub id of the base causal LM.
        target_model_path: path or hub id of the fine-tuned model.
        delta_path: output directory for the delta model and tokenizer.
        hub_repo_id: when truthy, both saves are also pushed to this hub repo.

    Raises:
        AssertionError: if a target tensor is missing from the base model and
            is not an ``mm_projector`` tensor, or if shapes differ for a
            tensor other than ``model.embed_tokens.weight`` / ``lm_head.weight``.
    """
    print("Loading base model")
    base = AutoModelForCausalLM.from_pretrained(
        base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)

    print("Loading target model")
    auto_upgrade(target_model_path)
    target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)

    print("Calculating delta")
    # state_dict() builds a fresh mapping over every parameter on each call;
    # build it once instead of several times per loop iteration.
    base_state = base.state_dict()
    for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"):
        if name not in base_state:
            # Only the multimodal projector may be absent from the base model.
            assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
            continue
        bparam = base_state[name]
        if param.data.shape == bparam.shape:
            param.data -= bparam
        else:
            # The vocabulary may have grown in the target: subtract the base
            # values from the overlapping top-left slice only.
            assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {bparam.shape}'
            param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam

    print("Saving delta")
    if hub_repo_id:
        kwargs = {"push_to_hub": True, "repo_id": hub_repo_id}
    else:
        kwargs = {}
    target.save_pretrained(delta_path, **kwargs)
    target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)
    target_tokenizer.save_pretrained(delta_path, **kwargs)
def build_vision_tower(vision_tower_cfg, **kwargs):
    """Instantiate the vision tower named in ``vision_tower_cfg``.

    The tower name is read from ``vision_tower_cfg.mm_vision_tower`` and, if
    that attribute does not exist, from ``vision_tower_cfg.vision_tower``.

    Args:
        vision_tower_cfg: configuration object; also forwarded to the tower
            as ``args``.
        **kwargs: forwarded unchanged to ``CLIPVisionTower``.

    Returns:
        A ``CLIPVisionTower`` for names starting with ``"openai"`` or
        ``"laion"``.

    Raises:
        ValueError: if no tower name is configured or the name is not
            recognised.
    """
    vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
    # A missing name resolves to None; without the type check the call to
    # startswith would raise AttributeError instead of the intended ValueError.
    if isinstance(vision_tower, str) and vision_tower.startswith(("openai", "laion")):
        return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)

    raise ValueError(f'Unknown vision tower: {vision_tower}')
def align_and_update_state_dicts(model_state_dict, ckpt_state_dict):
    """Select the checkpoint entries that can be loaded into the model as-is.

    An entry is kept when its key exists in both mappings and the two values
    have equal ``.shape``. Keys missing from the checkpoint, checkpoint keys
    unknown to the model, and shape mismatches are dropped without any
    message, so the result is meant for a non-strict ``load_state_dict``.

    Args:
        model_state_dict: mapping of parameter name to the model's value.
        ckpt_state_dict: mapping of parameter name to the checkpoint's value.

    Returns:
        A new dict, ordered by sorted model key, mapping each compatible key
        to the checkpoint's value.
    """
    result_dicts = {}
    # Direct mapping lookups replace the former list membership test plus
    # index()/pop() per key, which made the scan quadratic in the key count.
    for model_key in sorted(model_state_dict.keys()):
        if model_key not in ckpt_state_dict:
            continue
        ckpt_weight = ckpt_state_dict[model_key]
        if model_state_dict[model_key].shape == ckpt_weight.shape:
            result_dicts[model_key] = ckpt_weight
    return result_dicts
is_model(model_name): raise ValueError(f'Unkown model: {model_name}') return model_entrypoints(model_name)(config, **kwargs) ================================================ FILE: llava/model/openseed/architectures/openseed_model.py ================================================ # ------------------------------------------------------------------------ # DINO # Copyright (c) 2023 IDEA. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------ # Modified from MaskDINO https://github.com/IDEA-Research/MaskDINO by Feng Li and Hao Zhang. # ------------------------------------------------------------------------ from typing import Tuple import torch from torch import nn from torch.nn import functional as F from .registry import register_model from ..utils import configurable, box_ops #, get_class_names from ..backbone import build_backbone, Backbone from ..body import build_openseed_head from ..modules import sem_seg_postprocess, HungarianMatcher, SetCriterion from detectron2.structures import Boxes, ImageList, Instances, BitMasks from detectron2.utils.memory import retry_if_cuda_oom from detectron2.data import MetadataCatalog import random class OpenSeeD(nn.Module): """ Main class for mask classification semantic segmentation architectures. 
@configurable
def __init__(
    self,
    *,
    backbone: Backbone,
    sem_seg_head: nn.Module,
    criterion: nn.Module,
    num_queries: int,
    object_mask_threshold: float,
    overlap_threshold: float,
    metadata,
    size_divisibility: int,
    sem_seg_postprocess_before_inference: bool,
    pixel_mean: Tuple[float],
    pixel_std: Tuple[float],
    # inference
    semantic_on: bool,
    panoptic_on: bool,
    instance_on: bool,
    test_topk_per_image: int,
    data_loader: str,
    pano_temp: float,
    focus_on_box: bool = False,
    transform_eval: bool = False,
    semantic_ce_loss: bool = False,
    train_dataset_name: str,
    background: bool,
    coco_on=True,
    coco_mask_on=True,
    o365_on=True,
    merge_class=False,
    coco_only=False,
    detach_seg=False,
    eval_train=False,
):
    """Store the sub-modules and options of the model; all arguments are keyword-only.

    Args:
        backbone: a backbone module, must follow detectron2's backbone interface
        sem_seg_head: a module that predicts semantic segmentation from backbone features
        criterion: a module that defines the loss
        num_queries: int, number of queries
        object_mask_threshold: float, threshold to filter query based on classification score
            for panoptic segmentation inference
        overlap_threshold: overlap threshold used in general inference for panoptic segmentation
        metadata: dataset meta, get `thing` and `stuff` category names for panoptic
            segmentation inference
        size_divisibility: Some backbones require the input height and width to be divisible by a
            specific integer. We can use this to override such requirement. A negative
            value means "use ``backbone.size_divisibility``".
        sem_seg_postprocess_before_inference: whether to resize the prediction back
            to original input size before semantic segmentation inference or after.
            For high-resolution dataset like Mapillary, resizing predictions before
            inference will cause OOM error. Must be true when ``semantic_on`` is false.
        pixel_mean, pixel_std: list or tuple with #channels element, representing
            the per-channel mean and std to be used to normalize the input image
        semantic_on: bool, whether to output semantic segmentation prediction
        instance_on: bool, whether to output instance segmentation prediction
        panoptic_on: bool, whether to output panoptic segmentation prediction
        test_topk_per_image: int, instance segmentation parameter, keep topk instances per image
        data_loader, pano_temp, focus_on_box, transform_eval, semantic_ce_loss,
            train_dataset_name, coco_mask_on, detach_seg: stored unchanged on the
            instance under the same name.
        background, merge_class: accepted but not used by this constructor.
        coco_on, o365_on: stored in ``self.task_switch`` under ``'coco'`` and ``'o365'``.
        coco_only: stored; ``forward`` skips the flickr/refcoco/vg batches when it is true.
        eval_train: stored; ``forward_seg`` accumulates accuracy and IoU statistics
            during training when it is true.
    """
    super().__init__()
    self.backbone = backbone
    self.pano_temp = pano_temp
    self.sem_seg_head = sem_seg_head
    self.criterion = criterion
    self.num_queries = num_queries
    self.overlap_threshold = overlap_threshold
    self.object_mask_threshold = object_mask_threshold
    self.metadata = metadata
    if size_divisibility < 0:
        # use backbone size_divisibility if not set
        size_divisibility = self.backbone.size_divisibility
    self.size_divisibility = size_divisibility
    self.sem_seg_postprocess_before_inference = sem_seg_postprocess_before_inference
    # Third positional argument is `persistent`: these buffers are not saved
    # in the state dict. They are reshaped to broadcast over (C, H, W) images.
    self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
    self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
    self.detach_seg=detach_seg
    self.eval_train=eval_train

    # additional args
    self.semantic_on = semantic_on
    self.instance_on = instance_on
    self.panoptic_on = panoptic_on
    self.test_topk_per_image = test_topk_per_image

    self.data_loader = data_loader
    self.focus_on_box = focus_on_box
    self.transform_eval = transform_eval
    self.semantic_ce_loss = semantic_ce_loss

    self.train_class_names = dict()
    self.train_dataset_name = train_dataset_name
    self.coco_mask_on = coco_mask_on
    self.task_switch = {'coco': coco_on, 'o365': o365_on}
    # Running counters, one pair per data_type ('gd', 'ref', 'coco'); they are
    # looked up by name via getattr/setattr in forward_seg.
    self.num_correct_gd=0
    self.num_total_gd=0
    self.num_correct_ref = 0
    self.num_total_ref = 0
    self.num_correct_coco = 0
    self.num_total_coco = 0
    self.coco_only=coco_only
    # Last loss dict returned for the flickr batch (set in forward).
    self.loss_dict=None
    self.mean_iou=0.0
    # Running union / intersection areas used for the printed cIoU.
    self.total_union=0.0
    self.total_intersection=0.0
    # self.cIoU=0.0
    print("self.task_switch ", self.task_switch)
    # HACK for only two datasets for seg and det
    if not self.semantic_on:
        assert self.sem_seg_postprocess_before_inference
weight_dict.items()}) weight_dict.update(aux_weight_dict) if dec_cfg['BOX']: losses = ["labels", "masks","boxes"] else: losses = ["labels", "masks"] # update task switch task_switch = {} task_switch.update({'bbox': dec_cfg.get('DETECTION', True), 'mask': dec_cfg.get('MASK', True)}) top_x_layers = {'mask': dec_cfg.get('TOP_MASK_LAYERS', 10), 'box': dec_cfg.get('TOP_DETECTION_LAYERS', 10)} weight_multiplier= dec_cfg.get('WEIGHT_MULTIPLIER', 1.0) weight_dict={k:v*weight_multiplier for k,v in weight_dict.items()} # building criterion criterion = SetCriterion( enc_cfg['NUM_CLASSES'], matcher=matcher, weight_dict=weight_dict, top_x_layers=top_x_layers, eos_coef=no_object_weight, losses=losses, num_points=dec_cfg['TRAIN_NUM_POINTS'], oversample_ratio=dec_cfg['OVERSAMPLE_RATIO'], importance_sample_ratio=dec_cfg['IMPORTANCE_SAMPLE_RATIO'], grounding_weight=None, dn=dec_cfg['DN'], dn_losses=dn_losses, panoptic_on=dec_cfg['PANO_BOX_LOSS'], semantic_ce_loss=dec_cfg['TEST']['SEMANTIC_ON'] and dec_cfg['SEMANTIC_CE_LOSS'] and not dec_cfg['TEST']['PANOPTIC_ON'], ) # build model extra = {'task_switch': task_switch} backbone = build_backbone(cfg) # lang_encoder = build_language_encoder(cfg) sem_seg_head = build_openseed_head(cfg, backbone.output_shape(), None, extra=extra) return { "backbone": backbone, "sem_seg_head": sem_seg_head, "criterion": criterion, "num_queries": dec_cfg['NUM_OBJECT_QUERIES'], "object_mask_threshold": dec_cfg['TEST']['OBJECT_MASK_THRESHOLD'], "overlap_threshold": dec_cfg['TEST']['OVERLAP_THRESHOLD'], "metadata": MetadataCatalog.get(cfg['DATASETS']['TRAIN'][0]), "size_divisibility": dec_cfg['SIZE_DIVISIBILITY'], "sem_seg_postprocess_before_inference": ( dec_cfg['TEST']['SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE'] or dec_cfg['TEST']['PANOPTIC_ON'] or dec_cfg['TEST']['INSTANCE_ON'] ), "pixel_mean": cfg['INPUT']['PIXEL_MEAN'], "pixel_std": cfg['INPUT']['PIXEL_STD'], # inference "semantic_on": dec_cfg['TEST']['SEMANTIC_ON'], "instance_on": 
def forward(self, batched_inputs, inference_task='seg'):
    """Compute training losses, or run inference when not in training mode.

    Args:
        batched_inputs: in training mode, a dict that may hold the batches
            ``'flickr'``, ``'refcoco'``, ``'vg'`` and ``'coco'`` together with
            their ``'<name>_text_embeddings'`` entries; in inference mode it
            is handed to ``forward_seg`` unchanged. The ``'coco'`` entry is
            replaced in place by its grounded samples when there are any.
        inference_task: task name forwarded to ``forward_seg`` at inference.

    Returns:
        Training: a dict of losses keyed ``'<dataset>.<loss name>'``, with the
        coco entries first. Inference: whatever ``forward_seg`` returns.
    """
    if not self.training:
        return self.forward_seg(batched_inputs, task=inference_task)

    def tagged(dataset, loss_map):
        # Prefix every loss name with the dataset it came from.
        return {dataset + '.' + str(name): loss for name, loss in loss_map.items()}

    # (batch key, task, data_type) for the plain grounding datasets.
    grounding_sources = (
        ('flickr', 'seg', 'gd'),
        ('refcoco', 'seg', 'ref'),
        ('vg', 'det', 'ref'),
    )
    grounding_losses = {}
    for source, task, data_type in grounding_sources:
        if source not in batched_inputs or self.coco_only:
            continue
        self.criterion.conversation = False
        source_losses = self.forward_seg(
            batched_inputs[source],
            task=task,
            default_text_embeddings=batched_inputs[source + '_text_embeddings'],
            data_type=data_type,
        )
        grounding_losses.update(tagged(source, source_losses))
        if source == 'flickr':
            self.loss_dict = source_losses

    losses = {}
    if 'coco' in batched_inputs:
        coco_batch = batched_inputs['coco']
        grounded = [pos for pos, sample in enumerate(coco_batch) if sample['grounding']]
        if grounded:
            # Keep only the grounded samples and the matching embedding rows.
            batched_inputs['coco'] = [coco_batch[pos] for pos in grounded]
            embeddings = batched_inputs['coco_text_embeddings']
            embeddings = (embeddings[0][grounded], embeddings[1][grounded])
            self.criterion.conversation = True
            coco_losses = self.forward_seg(
                batched_inputs['coco'], task='seg', default_text_embeddings=embeddings)
            losses.update(tagged('coco', coco_losses))
        else:
            # No grounded coco sample: run the flickr batch instead and report
            # its losses under the coco names, scaled to zero.
            self.criterion.conversation = True
            stand_in = self.forward_seg(
                batched_inputs['flickr'],
                task='seg',
                default_text_embeddings=batched_inputs['flickr_text_embeddings'],
                data_type='coco',
            )
            self.loss_dict = stand_in
            losses.update(
                {'coco.' + str(name): loss * 0.0 for name, loss in stand_in.items()})

    losses.update(grounding_losses)
    return losses
mode="bilinear", align_corners=False, )[0,0]>0.5 if len(gt_boxes)==0: continue mask_iou+=float(torch.sum(pred_mask*gt_mask)/torch.sum(torch.logical_or(pred_mask,gt_mask))) self.total_union+=float(torch.sum(torch.logical_or(pred_mask,gt_mask)))/scale_factor[i]**2 self.total_intersection+=float(torch.sum(pred_mask*gt_mask))/scale_factor[i]**2 # self.mask_iou+=mask_iou if box_ops.box_iou(pred_box,gt_boxes)[0].max()>0.5: num_correct+=1 else: pass num_total+=1 print(f"{data_type} cIoU:" ,self.total_intersection/self.total_union) name_correct='num_correct_'+data_type name_total='num_total_'+data_type try: gathered_list=[None for _ in range(torch.distributed.get_world_size())] torch.distributed.all_gather_object(gathered_list,num_correct) # self.num_correct+=sum(gathered_list) num_correct_value=getattr(self,name_correct) setattr(self,name_correct,num_correct_value+sum(gathered_list)) gathered_list=[None for _ in range(torch.distributed.get_world_size())] torch.distributed.all_gather_object(gathered_list,num_total) # self.num_total+=sum(gathered_list) num_total_value=getattr(self,name_total) setattr(self,name_total,num_total_value+sum(gathered_list)) gathered_list=[None for _ in range(torch.distributed.get_world_size())] torch.distributed.all_gather_object(gathered_list,mask_iou) # self.mask_iou+=sum(gathered_list) if torch.distributed.get_rank()==0: print(f"{data_type} acc: ",getattr(self, name_correct) / getattr(self, name_total)) # print("mask_iou: ",self.mask_iou/self.num_total) except Exception as e: # self.num_correct+=num_correct # self.num_total+=num_total num_correct_value = getattr(self, name_correct) setattr(self, name_correct, num_correct_value + num_correct) num_total_value = getattr(self, name_total) setattr(self, name_total, num_total_value + num_total) try: print(f"{data_type} rank{torch.distributed.get_rank()} acc: ", getattr(self, name_correct) / getattr(self, name_total)) except Exception as e: print(f"{data_type} acc: ", getattr(self, name_correct) / 
getattr(self, name_total)) ########################### # bipartite matching-based loss self.criterion.default_text_embeddings = default_text_embeddings losses = self.criterion(outputs, targets, mask_dict, task=task) for k in list(losses.keys()): if k in self.criterion.weight_dict: losses[k] *= self.criterion.weight_dict[k] else: # remove this loss if not specified in `weight_dict` losses.pop(k) return losses else: outputs, _ = self.sem_seg_head(features) mask_cls_results = outputs["pred_logits"] mask_box_results = outputs["pred_boxes"] if 'seg' in task: if task == 'seg': self.semantic_on = self.panoptic_on = self.sem_seg_postprocess_before_inference = self.instance_on = True if task == 'inst_seg': self.semantic_on = self.panoptic_on = False self.instance_on = True self.sem_seg_postprocess_before_inference = True if task == 'sem_pan_seg': self.semantic_on = self.panoptic_on = True self.instance_on = False self.sem_seg_postprocess_before_inference = True if task == 'inst_pan_seg': self.instance_on = self.panoptic_on = True self.semantic_on = False self.sem_seg_postprocess_before_inference = True if task == 'sem_seg': self.instance_on = self.panoptic_on = False self.semantic_on = True self.sem_seg_postprocess_before_inference = True mask_pred_results = outputs["pred_masks"] # upsample masks mask_pred_results = F.interpolate( mask_pred_results, size=(images.tensor.shape[-2], images.tensor.shape[-1]), mode="bilinear", align_corners=False, ) else: self.semantic_on = self.panoptic_on = self.sem_seg_postprocess_before_inference = False self.instance_on = True mask_pred_results = torch.zeros(mask_box_results.shape[0], mask_box_results.shape[1],2, 2).to(mask_box_results) del outputs processed_results = [] for mask_cls_result, mask_pred_result, mask_box_result, input_per_image, image_size in zip( mask_cls_results, mask_pred_results, mask_box_results, batched_inputs, images.image_sizes ): height = input_per_image.get("height", image_size[0]) width = 
def prepare_targets(self, targets, images, task='seg'):
    """Convert per-image ground truth into the dict format used for training.

    Args:
        targets: iterable of per-image ground-truth objects exposing
            ``image_size``, ``gt_classes``, ``gt_boxes.tensor`` and, unless
            ``task == 'det'``, ``gt_masks``.
        images: batched images; only ``images.tensor.shape[-2:]`` (the padded
            height and width) is read.
        task: ``'det'`` skips mask handling and stores ``None`` for masks.

    Returns:
        A list with one dict per image holding ``"labels"``, ``"masks"``
        (zero-padded to the batch size, or ``None``) and ``"boxes"``
        (cxcywh, divided by the un-padded image width/height).
    """
    padded_h, padded_w = images.tensor.shape[-2:]
    wants_masks = task != 'det'
    prepared = []
    for per_image in targets:
        img_h, img_w = per_image.image_size
        # Boxes are normalised by the real image size, not the padded one.
        box_scale = torch.as_tensor(
            [img_w, img_h, img_w, img_h], dtype=torch.float, device=self.device
        )

        masks = None
        if wants_masks:
            source = per_image.gt_masks
            # Copy the masks into the top-left corner of a zero canvas that
            # matches the padded batch resolution.
            masks = torch.zeros(
                (source.shape[0], padded_h, padded_w),
                dtype=source.dtype,
                device=source.device,
            )
            masks[:, : source.shape[1], : source.shape[2]] = source

        prepared.append(
            {
                "labels": per_image.gt_classes,
                "masks": masks,
                "boxes": box_ops.box_xyxy_to_cxcywh(per_image.gt_boxes.tensor) / box_scale,
            }
        )
    return prepared
def instance_inference(self, mask_cls, mask_pred, mask_box_result):
    """Build the instance predictions of one image from per-query outputs.

    The sigmoid class scores are flattened over (query, class), the
    ``self.test_topk_per_image`` best pairs are kept, and the mask and box of
    the owning query are attached to each pair.

    Args:
        mask_cls: per-query class logits (the original comment gives
            ``[100, 80]`` as an example shape).
        mask_pred: per-query mask logits, already resized to the output
            resolution; its last two dimensions become the image size.
        mask_box_result: per-query boxes, indexed by query and stored as is.

    Returns:
        ``Instances`` with ``pred_masks`` (float 0/1), ``pred_boxes``,
        ``scores`` and ``pred_classes``.
    """
    # mask_pred is already processed to have the same shape as original input
    image_size = mask_pred.shape[-2:]
    num_classes = self.sem_seg_head.num_classes

    scores = mask_cls.sigmoid()
    # Class id of every flattened (query, class) slot.
    labels = (
        torch.arange(num_classes, device=self.device)
        .unsqueeze(0)
        .repeat(self.num_queries, 1)
        .flatten(0, 1)
    )
    scores_per_image, topk_indices = scores.flatten(0, 1).topk(
        self.test_topk_per_image, sorted=False
    )
    labels_per_image = labels[topk_indices]
    # Map the flattened slot index back to the query that produced it.
    topk_indices = topk_indices // num_classes
    mask_pred = mask_pred[topk_indices]
    mask_box_result = mask_box_result[topk_indices]

    # if this is panoptic segmentation, we only keep the "thing" classes
    if self.panoptic_on:
        # Build the lookup once and move the labels to Python in a single
        # transfer, instead of comparing one 0-d tensor at a time against
        # every value of the mapping.
        thing_ids = set(self.metadata.thing_dataset_id_to_contiguous_id.values())
        keep = torch.tensor(
            [label in thing_ids for label in labels_per_image.tolist()],
            dtype=torch.bool,
            device=scores_per_image.device,
        )
        scores_per_image = scores_per_image[keep]
        labels_per_image = labels_per_image[keep]
        mask_pred = mask_pred[keep]
        mask_box_result = mask_box_result[keep]

    result = Instances(image_size)
    # Binarise the mask logits (threshold 0, i.e. before sigmoid).
    result.pred_masks = (mask_pred > 0).float()
    result.pred_boxes = Boxes(mask_box_result)
    # To derive boxes from masks instead (slow):
    # result.pred_boxes = BitMasks(mask_pred > 0).get_bounding_boxes()

    if self.focus_on_box or not self.sem_seg_postprocess_before_inference:
        # Box-focused evaluation, or masks not post-processed: rank by the
        # class score alone.
        mask_scores_per_image = 1.0
    else:
        # Average mask probability inside the predicted foreground.
        foreground = result.pred_masks.flatten(1)
        mask_scores_per_image = (mask_pred.sigmoid().flatten(1) * foreground).sum(1) / (
            foreground.sum(1) + 1e-6
        )

    result.scores = scores_per_image * mask_scores_per_image
    result.pred_classes = labels_per_image
    return result
def forward_inner_eval(self, batched_inputs, task='seg', default_text_embeddings=None):
    """Ground every text query of a single image to boxes and masks.

    Args:
        batched_inputs: list of per-image dicts with an ``"image"`` tensor.
            The first dict may carry ``"matching_threshold"``; when present,
            every query whose softmax score (taken over the queries) exceeds
            it is returned, otherwise only the arg-max query is returned.
        task: forwarded to ``self.sem_seg_head``.
        default_text_embeddings: indexable with two entries; the first is
            cast to float before being handed to the head.

    Returns:
        ``(boxes, masks)``: two lists with one entry per grounding query
        (``pred_logits.shape[2]``). Boxes are converted from cxcywh to xyxy
        with no rescaling; each entry has a leading "match" dimension.

    Raises:
        NotImplementedError: if the head returns predictions for more than
            one image.
    """
    images = [x["image"].to(self.device) for x in batched_inputs]
    images = [(x - self.pixel_mean) / self.pixel_std for x in images]
    images = ImageList.from_tensors(images, self.size_divisibility)
    matching_threshold = batched_inputs[0].get("matching_threshold")

    features = self.backbone(images.tensor)
    # Ground-truth targets are intentionally not prepared here: the head is
    # always called with targets=None and nothing below reads them.
    default_text_embeddings_ = [default_text_embeddings[0].float(), default_text_embeddings[1]]
    outputs, _ = self.sem_seg_head(
        features, targets=None, task=task, default_text_embeddings=default_text_embeddings_
    )

    pred_logits = outputs["pred_logits"]
    pred_boxes = outputs["pred_boxes"]
    pred_masks = outputs["pred_masks"] > 0

    if len(pred_logits) > 1:
        raise NotImplementedError

    matched_pred_boxes = []
    matched_pred_masks = []
    for i in range(len(pred_logits)):
        num_grounding = pred_logits.shape[2]
        for gd_idx in range(num_grounding):
            query_logits = pred_logits[i, :, gd_idx]
            if matching_threshold is None:
                matched_idx = torch.argmax(query_logits, dim=0)
                # Keep a leading dimension of 1 so both branches return
                # (num_matches, ...) tensors.
                matched_boxes = pred_boxes[i][matched_idx][None, :]
                # Previously the mask was never selected in this branch, so
                # the append below hit an unbound (or stale) variable.
                matched_masks = pred_masks[i][matched_idx][None]
            else:
                matched_idx = torch.where(query_logits.softmax(dim=0) > matching_threshold)[0]
                if matched_idx.shape[0] == 0:
                    # No query passes the threshold: return an all-zero
                    # placeholder box and mask.
                    # NOTE(review): the placeholder mask is a fixed 256x256
                    # float tensor while real masks are boolean -- confirm
                    # callers cope with both.
                    matched_boxes = pred_boxes.new_zeros((1, 4))
                    matched_masks = pred_boxes.new_zeros((1, 256, 256))
                else:
                    matched_boxes = pred_boxes[i][matched_idx]
                    matched_masks = pred_masks[i][matched_idx]
            matched_pred_boxes.append(box_ops.box_cxcywh_to_xyxy(matched_boxes))
            matched_pred_masks.append(matched_masks)
    return matched_pred_boxes, matched_pred_masks
@configurable
def __init__(
    self,
    *,
    backbone: Backbone,
    sem_seg_head: nn.Module,
    num_queries: int,
    object_mask_threshold: float,
    overlap_threshold: float,
    metadata,
    size_divisibility: int,
    sem_seg_postprocess_before_inference: bool,
    pixel_mean: Tuple[float],
    pixel_std: Tuple[float],
    # inference
    semantic_on: bool,
    panoptic_on: bool,
    instance_on: bool,
    test_topk_per_image: int,
    data_loader: str,
    pano_temp: float,
    focus_on_box: bool = False,
    transform_eval: bool = False,
    semantic_ce_loss: bool = False,
    train_dataset_name: str,
    background: bool,
    coco_on=True,
    coco_mask_on=True,
    o365_on=True,
    criterion_coco=None,
    criterion_o365=None,
    split_panno=False,
):
    """
    Args:
        backbone: a backbone module, must follow detectron2's backbone interface
        sem_seg_head: a module that predicts semantic segmentation from backbone features
        num_queries: int, number of queries
        object_mask_threshold: float, threshold to filter query based on classification score
            for panoptic segmentation inference
        overlap_threshold: overlap threshold used in general inference for panoptic segmentation
        metadata: dataset meta, get `thing` and `stuff` category names for panoptic
            segmentation inference
        size_divisibility: Some backbones require the input height and width to be divisible by a
            specific integer. A negative value falls back to the backbone's own setting.
        sem_seg_postprocess_before_inference: whether to resize the prediction back
            to original input size before semantic segmentation inference or after.
            For high-resolution dataset like Mapillary, resizing predictions before
            inference will cause OOM error.
        pixel_mean, pixel_std: list or tuple with #channels element, representing
            the per-channel mean and std to be used to normalize the input image
        semantic_on: bool, whether to output semantic segmentation prediction
        instance_on: bool, whether to output instance segmentation prediction
        panoptic_on: bool, whether to output panoptic segmentation prediction
        test_topk_per_image: int, instance segmentation parameter, keep topk instances per image
        data_loader, pano_temp, focus_on_box, transform_eval, semantic_ce_loss,
            split_panno: stored unchanged on the instance.
        train_dataset_name: indexed and iterated here as a sequence of dataset
            names; the first entry supplies the COCO class names.
        background: forwarded to ``get_class_names``.
        coco_on, o365_on: enable the COCO / Objects365 branches (``task_switch``).
        coco_mask_on: when False the COCO branch is registered under ``'det'``
            instead of ``'seg'``.
        criterion_coco, criterion_o365: loss modules, stored unchanged.
    """
    super().__init__()
    self.backbone = backbone
    self.pano_temp = pano_temp
    self.sem_seg_head = sem_seg_head
    self.num_queries = num_queries
    self.overlap_threshold = overlap_threshold
    self.object_mask_threshold = object_mask_threshold
    self.metadata = metadata
    if size_divisibility < 0:
        # use backbone size_divisibility if not set
        size_divisibility = self.backbone.size_divisibility
    self.size_divisibility = size_divisibility
    self.sem_seg_postprocess_before_inference = sem_seg_postprocess_before_inference
    self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
    self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
    self.split_panno = split_panno

    # additional args
    self.semantic_on = semantic_on
    self.instance_on = instance_on
    self.panoptic_on = panoptic_on
    self.test_topk_per_image = test_topk_per_image

    self.data_loader = data_loader
    self.focus_on_box = focus_on_box
    self.transform_eval = transform_eval
    self.semantic_ce_loss = semantic_ce_loss

    self.train_class_names = dict()
    self.train_dataset_name = train_dataset_name
    self.coco_mask_on = coco_mask_on
    self.task_switch = {'coco': coco_on, 'o365': o365_on}
    self.criterion_coco = criterion_coco
    self.criterion_o365 = criterion_o365
    print("self.task_switch ", self.task_switch)

    # HACK for only two datasets for seg and det
    if coco_on:
        task = 'seg' if coco_mask_on else 'det'
        raw_names = get_class_names(train_dataset_name[0], background=background)
        raw_names = [
            a.replace("-merged", "").replace("-other", "").replace("-stuff", "")
            for a in raw_names
        ]
        train_class_names = []
        for name in raw_names:
            names = name.split('-')
            if len(names) > 1:
                # "a-b" style names are rewritten as "b a".
                assert len(names) == 2
                train_class_names.append(names[1] + ' ' + names[0])
            else:
                train_class_names.append(name)
        self.train_class_names[task] = train_class_names

    if o365_on and len(train_dataset_name) > 1:
        # Pick the dataset whose *name* mentions Objects365. The previous
        # check tested membership in the whole list, which never matched a
        # real dataset name, so the last dataset was always used. That last
        # entry is kept as the fallback when no name matches.
        dt = next(
            (name for name in train_dataset_name if "o365" in name or "object365" in name),
            train_dataset_name[-1],
        )
        det_names = get_class_names(dt, background=background)
        # Each entry becomes a list of alias strings split on '/'.
        self.train_class_names['det'] = [a.lower().split('/') for a in det_names]

    if not self.semantic_on:
        assert self.sem_seg_postprocess_before_inference
weight_dict.items()}) weight_dict.update(interm_weight_dict) # denoising training dn = dec_cfg['DN'] # TODO hack for dn lable loss if dn == "standard": weight_dict.update({k + f"_dn": v for k, v in weight_dict.items() if k!="loss_mask" and k!="loss_dice" }) dn_losses=["dn_labels", "boxes"] elif dn == "seg": weight_dict.update({k + f"_dn": v for k, v in weight_dict.items()}) dn_losses=["labels", "masks", "boxes"] else: dn_losses=[] if deep_supervision: dec_layers = dec_cfg['DEC_LAYERS'] aux_weight_dict = {} for i in range(dec_layers): aux_weight_dict.update({k.replace('_0', '_{}'.format(i+1)): v for k, v in weight_dict.items()}) weight_dict.update(aux_weight_dict) if dec_cfg['BOX']: losses = ["labels", "masks","boxes"] else: losses = ["labels", "masks"] # update task switch task_switch = {} task_switch.update({'bbox': dec_cfg.get('DETECTION', True), 'mask': dec_cfg.get('MASK', True)}) top_x_layers = {'mask': dec_cfg.get('TOP_MASK_LAYERS', 10), 'box': dec_cfg.get('TOP_DETECTION_LAYERS', 10)} # building criterion criterion_coco = SetCriterion( enc_cfg['NUM_CLASSES'], matcher=matcher, weight_dict=weight_dict, top_x_layers=top_x_layers, eos_coef=no_object_weight, losses=losses, num_points=dec_cfg['TRAIN_NUM_POINTS'], oversample_ratio=dec_cfg['OVERSAMPLE_RATIO'], importance_sample_ratio=dec_cfg['IMPORTANCE_SAMPLE_RATIO'], grounding_weight=None, dn=dec_cfg['DN'], dn_losses=dn_losses, panoptic_on=dec_cfg['PANO_BOX_LOSS'], semantic_ce_loss=dec_cfg['TEST']['SEMANTIC_ON'] and dec_cfg['SEMANTIC_CE_LOSS'] and not dec_cfg['TEST']['PANOPTIC_ON'], ) criterion_o365 = SetCriterion( enc_cfg.get('NUM_CLASSES_O365', 365), matcher=matcher, weight_dict=weight_dict, top_x_layers=top_x_layers, eos_coef=no_object_weight, losses=losses, num_points=dec_cfg['TRAIN_NUM_POINTS'], oversample_ratio=dec_cfg['OVERSAMPLE_RATIO'], importance_sample_ratio=dec_cfg['IMPORTANCE_SAMPLE_RATIO'], grounding_weight=None, dn=dec_cfg['DN'], dn_losses=dn_losses, panoptic_on=dec_cfg['PANO_BOX_LOSS'], 
semantic_ce_loss=dec_cfg['TEST']['SEMANTIC_ON'] and dec_cfg['SEMANTIC_CE_LOSS'] and not dec_cfg['TEST'][ 'PANOPTIC_ON'], ) # build model extra = {'task_switch': task_switch} backbone = build_backbone(cfg) lang_encoder = build_language_encoder(cfg) sem_seg_head = build_openseed_head(cfg, backbone.output_shape(), lang_encoder, extra=extra) return { "backbone": backbone, "sem_seg_head": sem_seg_head, "criterion_coco": criterion_coco, "criterion_o365": criterion_o365, "num_queries": dec_cfg['NUM_OBJECT_QUERIES'], "object_mask_threshold": dec_cfg['TEST']['OBJECT_MASK_THRESHOLD'], "overlap_threshold": dec_cfg['TEST']['OVERLAP_THRESHOLD'], "metadata": MetadataCatalog.get(cfg['DATASETS']['TRAIN'][0]), "size_divisibility": dec_cfg['SIZE_DIVISIBILITY'], "sem_seg_postprocess_before_inference": ( dec_cfg['TEST']['SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE'] or dec_cfg['TEST']['PANOPTIC_ON'] or dec_cfg['TEST']['INSTANCE_ON'] ), "pixel_mean": cfg['INPUT']['PIXEL_MEAN'], "pixel_std": cfg['INPUT']['PIXEL_STD'], # inference "semantic_on": dec_cfg['TEST']['SEMANTIC_ON'], "instance_on": dec_cfg['TEST']['INSTANCE_ON'], "panoptic_on": dec_cfg['TEST']['PANOPTIC_ON'], "test_topk_per_image": cfg['COCO']['TEST']['DETECTIONS_PER_IMAGE'], "data_loader": None, "focus_on_box": cfg['MODEL']['DECODER']['TEST']['TEST_FOUCUS_ON_BOX'], "transform_eval": cfg['MODEL']['DECODER']['TEST']['PANO_TRANSFORM_EVAL'], "pano_temp": cfg['MODEL']['DECODER']['TEST']['PANO_TEMPERATURE'], "semantic_ce_loss": cfg['MODEL']['DECODER']['TEST']['SEMANTIC_ON'] and cfg['MODEL']['DECODER']['SEMANTIC_CE_LOSS'] and not cfg['MODEL']['DECODER']['TEST']['PANOPTIC_ON'], "train_dataset_name": cfg['DATASETS']['TRAIN'], # HACK for only two training set "background": cfg['MODEL'].get('BACKGROUND', True), "coco_on": dec_cfg.get('COCO', True), "coco_mask_on": dec_cfg.get('COCO_MASK', True), "o365_on": dec_cfg.get('O365', True), "split_panno": dec_cfg.get('PANO_CRITERION', True), } @property def device(self): return 
def forward(self, batched_inputs, inference_task='seg'):
    """Dispatch to ``forward_seg`` for training losses or inference results.

    In training mode ``batched_inputs`` is looked up by dataset key
    (``'coco'`` / ``'o365'``); each enabled dataset that is present is run
    through ``forward_seg`` and its losses are returned under keys prefixed
    with ``'coco.'`` / ``'o365.'``. In eval mode ``batched_inputs`` is passed
    straight to ``forward_seg`` with ``task=inference_task``.
    """
    if not self.training:
        return self.forward_seg(batched_inputs, task=inference_task)

    losses = {}

    if self.task_switch['coco'] and 'coco' in batched_inputs:
        # 133 classes when the first training set name mentions 'pano', else 80.
        self.criterion_coco.num_classes = 133 if 'pano' in self.train_dataset_name[0] else 80
        coco_task = 'seg' if self.coco_mask_on else 'det'
        coco_losses = self.forward_seg(batched_inputs['coco'], task=coco_task)
        losses.update({'coco.' + str(name): loss for name, loss in coco_losses.items()})

    if self.task_switch['o365'] and 'o365' in batched_inputs:
        self.criterion_o365.num_classes = 365
        o365_losses = self.forward_seg(batched_inputs['o365'], task='det')
        losses.update({'o365.' + str(name): loss for name, loss in o365_losses.items()})

    return losses
segments_info (list[dict]): Describe each segment in `panoptic_seg`. Each dict contains keys "id", "category_id", "isthing". """ images = [x["image"].to(self.device) for x in batched_inputs] images = [(x - self.pixel_mean) / self.pixel_std for x in images] images = ImageList.from_tensors(images, self.size_divisibility) features = self.backbone(images.tensor) if self.training: if task == "det" and self.task_switch['o365']: train_class_names = [random.sample(name, 1)[0] for name in self.train_class_names['det']] else: train_class_names = self.train_class_names[task] self.sem_seg_head.predictor.lang_encoder.get_text_embeddings(train_class_names, is_eval=False) # mask classification target if "instances" in batched_inputs[0]: gt_instances = [x["instances"].to(self.device) for x in batched_inputs] targets = self.prepare_targets(gt_instances, images, task=task) else: targets = None outputs, mask_dict = self.sem_seg_head(features, targets=targets, task=task) # bipartite matching-based loss if task=='det': criterion=self.criterion_o365 losses = self.criterion_o365(outputs, targets, mask_dict, task=task) else: criterion=self.criterion_coco losses = self.criterion_coco(outputs, targets, mask_dict, task=task) # else for k in list(losses.keys()): if k in criterion.weight_dict: losses[k] *= criterion.weight_dict[k] else: # remove this loss if not specified in `weight_dict` losses.pop(k) return losses else: outputs, _ = self.sem_seg_head(features) mask_cls_results = outputs["pred_logits"] mask_box_results = outputs["pred_boxes"] if 'seg' in task: if task == 'seg': self.semantic_on = self.panoptic_on = self.sem_seg_postprocess_before_inference = self.instance_on = True if task == 'pan_seg': self.semantic_on = self.instance_on = False self.panoptic_on = True self.sem_seg_postprocess_before_inference = True if task == 'inst_seg': self.semantic_on = self.panoptic_on = False self.instance_on = True self.sem_seg_postprocess_before_inference = True if task == 'sem_pan_seg': 
self.semantic_on = self.panoptic_on = True self.instance_on = False self.sem_seg_postprocess_before_inference = True if task == 'inst_pan_seg': self.instance_on = self.panoptic_on = True self.semantic_on = False self.sem_seg_postprocess_before_inference = True if task == 'sem_seg': self.instance_on = self.panoptic_on = False self.semantic_on = True self.sem_seg_postprocess_before_inference = True mask_pred_results = outputs["pred_masks"] # upsample masks mask_pred_results = F.interpolate( mask_pred_results, size=(images.tensor.shape[-2], images.tensor.shape[-1]), mode="bilinear", align_corners=False, ) else: self.semantic_on = self.panoptic_on = self.sem_seg_postprocess_before_inference = False self.instance_on = True mask_pred_results = torch.zeros(mask_box_results.shape[0], mask_box_results.shape[1],2, 2).to(mask_box_results) del outputs processed_results = [] for mask_cls_result, mask_pred_result, mask_box_result, input_per_image, image_size in zip( mask_cls_results, mask_pred_results, mask_box_results, batched_inputs, images.image_sizes ): height = input_per_image.get("height", image_size[0]) width = input_per_image.get("width", image_size[1]) processed_results.append({}) new_size = (images.tensor.shape[-2], images.tensor.shape[-1]) # padded size (divisible to 32) if self.sem_seg_postprocess_before_inference: mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)( mask_pred_result, image_size, height, width ) mask_cls_result = mask_cls_result.to(mask_pred_result) # semantic segmentation inference if self.semantic_on: r = retry_if_cuda_oom(self.semantic_inference)(mask_cls_result, mask_pred_result) if not self.sem_seg_postprocess_before_inference: r = retry_if_cuda_oom(sem_seg_postprocess)(r, image_size, height, width) processed_results[-1]["sem_seg"] = r # panoptic segmentation inference if self.panoptic_on: panoptic_r = retry_if_cuda_oom(self.panoptic_inference)(mask_cls_result, mask_pred_result) processed_results[-1]["panoptic_seg"] = panoptic_r # instance 
def semantic_inference(self, mask_cls, mask_pred):
    """Combine per-query class scores and masks into a per-class score map.

    Args:
        mask_cls: per-query class logits, indexed ``(query, class)``.
        mask_pred: per-query mask logits, indexed ``(query, h, w)``.

    Returns:
        Tensor indexed ``(class, h, w)``: class scores times sigmoid mask
        probabilities, summed over the queries.
    """
    if self.semantic_ce_loss:
        # Trained with cross-entropy: softmax over classes, then drop the
        # last class column.
        class_scores = F.softmax(mask_cls, dim=-1)[..., :-1]
    else:
        # Trained with focal loss: sigmoid scores, optionally sharpened by a
        # temperature softmax applied on top of the sigmoid output.
        temperature = self.pano_temp
        class_scores = mask_cls.sigmoid()
        if self.transform_eval:
            class_scores = F.softmax(class_scores / temperature, dim=-1)

    mask_probs = mask_pred.sigmoid()
    return torch.einsum("qc,qhw->chw", class_scores, mask_probs)
current_segment_id segments_info.append( { "id": current_segment_id, "isthing": bool(isthing), "category_id": int(pred_class), } ) return panoptic_seg, segments_info def instance_inference(self, mask_cls, mask_pred, mask_box_result,split_anno): # mask_pred is already processed to have the same shape as original input image_size = mask_pred.shape[-2:] scores = mask_cls.sigmoid() # [100, 80] if split_anno: labels = torch.arange(self.sem_seg_head.num_classes, device=self.device).unsqueeze(0).repeat( self.sem_seg_head.predictor.num_queries_test, 1).flatten(0, 1) else: labels = torch.arange(self.sem_seg_head.num_classes, device=self.device).unsqueeze(0).repeat(self.num_queries, 1).flatten(0, 1) scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.test_topk_per_image, sorted=False) # select 100 labels_per_image = labels[topk_indices] topk_indices = topk_indices // self.sem_seg_head.num_classes mask_pred = mask_pred[topk_indices] # if this is panoptic segmentation, we only keep the "thing" classes if self.panoptic_on: keep = torch.zeros_like(scores_per_image).bool() for i, lab in enumerate(labels_per_image): # print(i, len(keep), lab) keep[i] = lab in self.metadata.thing_dataset_id_to_contiguous_id.values() scores_per_image = scores_per_image[keep] labels_per_image = labels_per_image[keep] mask_pred = mask_pred[keep] result = Instances(image_size) # mask (before sigmoid) result.pred_masks = (mask_pred > 0).float() # half mask box half pred box mask_box_result = mask_box_result[topk_indices] if self.panoptic_on: mask_box_result = mask_box_result[keep] result.pred_boxes = Boxes(mask_box_result) # Uncomment the following to get boxes from masks (this is slow) # result.pred_boxes = BitMasks(mask_pred > 0).get_bounding_boxes() # calculate average mask prob if self.sem_seg_postprocess_before_inference: mask_scores_per_image = (mask_pred.sigmoid().flatten(1) * result.pred_masks.flatten(1)).sum(1) / (result.pred_masks.flatten(1).sum(1) + 1e-6) else: 
# Maps a model name (the defining module's file name) to its factory function.
_model_entrypoints = {}


def register_model(fn):
    """Register ``fn`` under the last component of its module path.

    Usable as a decorator: the function is returned unchanged.
    """
    # "pkg.sub.my_model" -> "my_model"; a module name without dots is used as is.
    key = fn.__module__.rpartition('.')[2]
    _model_entrypoints[key] = fn
    return fn


def model_entrypoints(model_name):
    """Return the factory registered as ``model_name`` (KeyError if absent)."""
    return _model_entrypoints[model_name]


def is_model(model_name):
    """Tell whether a factory has been registered as ``model_name``."""
    return model_name in _model_entrypoints
def build_backbone(config, **kwargs):
    """Instantiate the backbone named by ``config['MODEL']['BACKBONE']['NAME']``.

    Args:
        config: nested mapping; the backbone name is read from
            ``config['MODEL']['BACKBONE']['NAME']`` and the whole config is
            forwarded to the registered factory.
        **kwargs: forwarded unchanged to the registered factory.

    Returns:
        Whatever the registered factory returns.

    Raises:
        ValueError: if no factory is registered under the configured name.
    """
    model_name = config['MODEL']['BACKBONE']['NAME']
    if not is_model(model_name):
        raise ValueError(f'Unknown model: {model_name}')

    return model_entrypoints(model_name)(config, **kwargs)
class FocalModulation(nn.Module):
    """Focal modulation: gated sum of multi-level depth-wise context.

    Args:
        dim (int): Number of input channels.
        proj_drop (float, optional): Dropout ratio applied after the output
            projection. Default: 0.0
        focal_level (int): Number of focal levels.
        focal_window (int): Kernel size of the first focal level.
        focal_factor (int, default=2): Kernel-size increment per level.
        use_postln (bool, default=False): Accepted but not used here.
        use_postln_in_modulation (bool, default=False): Apply a LayerNorm to
            the modulated features before the output projection.
        scaling_modulator (bool, default=False): Divide the aggregated context
            by ``focal_level + 1``.
    """

    def __init__(self, dim, proj_drop=0., focal_level=2, focal_window=7, focal_factor=2,
                 use_postln=False, use_postln_in_modulation=False, scaling_modulator=False):
        super().__init__()
        self.dim = dim

        # specific args for focalv3
        self.focal_level = focal_level
        self.focal_window = focal_window
        self.focal_factor = focal_factor
        self.use_postln_in_modulation = use_postln_in_modulation
        self.scaling_modulator = scaling_modulator

        # One projection yields the query (dim), the context (dim) and one
        # gate per focal level plus one for the global context.
        num_gates = focal_level + 1
        self.f = nn.Linear(dim, 2 * dim + num_gates, bias=True)
        self.h = nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0, groups=1, bias=True)

        self.act = nn.GELU()
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.focal_layers = nn.ModuleList()

        if use_postln_in_modulation:
            self.ln = nn.LayerNorm(dim)

        for level in range(focal_level):
            window = focal_factor * level + focal_window
            depthwise = nn.Conv2d(
                dim, dim, kernel_size=window, stride=1,
                groups=dim, padding=window // 2, bias=False,
            )
            self.focal_layers.append(nn.Sequential(depthwise, nn.GELU()))

    def forward(self, x):
        """Apply focal modulation.

        Args:
            x: input features with shape of (B, H, W, C)
        """
        _, _, _, channels = x.shape
        levels = self.focal_level

        projected = self.f(x).permute(0, 3, 1, 2).contiguous()
        q, ctx, gates = torch.split(projected, (channels, channels, levels + 1), 1)

        # Each level refines the previous level's context; its output is
        # weighted by that level's gate channel.
        ctx_all = 0
        for level in range(levels):
            ctx = self.focal_layers[level](ctx)
            ctx_all = ctx_all + ctx * gates[:, level:level + 1]

        # Global context: spatial mean of the last level, gated by the final
        # gate channel.
        ctx_global = self.act(ctx.mean(2, keepdim=True).mean(3, keepdim=True))
        ctx_all = ctx_all + ctx_global * gates[:, levels:]

        if self.scaling_modulator:
            ctx_all = ctx_all / (levels + 1)

        modulated = (q * self.h(ctx_all)).permute(0, 2, 3, 1).contiguous()
        if self.use_postln_in_modulation:
            modulated = self.ln(modulated)
        return self.proj_drop(self.proj(modulated))
def forward(self, x):
    """Run one focal modulation block on a token sequence.

    Args:
        x: Input feature, tensor size (B, H*W, C).

    ``self.H`` and ``self.W`` must hold the spatial resolution of ``x``
    before the call.
    """
    batch, num_tokens, channels = x.shape
    height, width = self.H, self.W
    assert num_tokens == height * width, "input feature has wrong size"

    post_norm = self.use_postln
    residual = x

    # Focal modulation: norm1 runs before it (pre-LN) or after it (post-LN).
    tokens = x if post_norm else self.norm1(x)
    modulated = self.modulation(tokens.view(batch, height, width, channels))
    modulated = modulated.view(batch, height * width, channels)
    if post_norm:
        modulated = self.norm1(modulated)
    x = residual + self.drop_path(self.gamma_1 * modulated)

    # Feed-forward: norm2 runs on the MLP output (post-LN) or input (pre-LN).
    if post_norm:
        ffn_out = self.norm2(self.mlp(x))
    else:
        ffn_out = self.mlp(self.norm2(x))
    return x + self.drop_path(self.gamma_2 * ffn_out)
def forward(self, x, H, W):
    """Run every block of the stage, then the optional downsampling layer.

    Args:
        x: Input feature, tensor size (B, H*W, C).
        H, W: Spatial resolution of the input feature.

    Returns:
        tuple: ``(x, H, W, x_next, H_next, W_next)``. The first three are the
        stage output at the input resolution. The last three are the
        downsampled tokens and their resolution, or repeat the first three
        when the stage has no downsample layer.
    """
    run_checkpointed = self.use_checkpoint
    for block in self.blocks:
        # Blocks read the spatial size from attributes, not call arguments.
        block.H, block.W = H, W
        x = checkpoint.checkpoint(block, x) if run_checkpointed else block(x)

    if self.downsample is None:
        return x, H, W, x, H, W

    # Tokens (B, H*W, C) -> feature map (B, C, H, W) for the downsample layer.
    feature_map = x.transpose(1, 2).view(x.shape[0], x.shape[-1], H, W)
    x_down = self.downsample(feature_map).flatten(2).transpose(1, 2)
    Wh = (H + 1) // 2
    Ww = (W + 1) // 2
    return x, H, W, x_down, Wh, Ww
x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww) return x class FocalNet(nn.Module): """ FocalNet backbone. Args: pretrain_img_size (int): Input image size for training the pretrained model, used in absolute postion embedding. Default 224. patch_size (int | tuple(int)): Patch size. Default: 4. in_chans (int): Number of input image channels. Default: 3. embed_dim (int): Number of linear projection output channels. Default: 96. depths (tuple[int]): Depths of each Swin Transformer stage. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. drop_rate (float): Dropout rate. drop_path_rate (float): Stochastic depth rate. Default: 0.2. norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. patch_norm (bool): If True, add normalization after patch embedding. Default: True. out_indices (Sequence[int]): Output from which stages. frozen_stages (int): Stages to be frozen (stop grad and set eval mode). -1 means not freezing any parameters. focal_levels (Sequence[int]): Number of focal levels at four stages focal_windows (Sequence[int]): Focal window sizes at first focal level at four stages use_conv_embed (bool): Whether use overlapped convolution for patch embedding use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
""" def __init__(self, pretrain_img_size=1600, patch_size=4, in_chans=3, embed_dim=96, depths=[2, 2, 6, 2], mlp_ratio=4., drop_rate=0., drop_path_rate=0.2, norm_layer=nn.LayerNorm, patch_norm=True, out_indices=[0, 1, 2, 3], frozen_stages=-1, focal_levels=[2,2,2,2], focal_windows=[9,9,9,9], use_conv_embed=False, use_postln=False, use_postln_in_modulation=False, scaling_modulator=False, use_layerscale=False, use_checkpoint=False, ): super().__init__() self.pretrain_img_size = pretrain_img_size self.num_layers = len(depths) self.embed_dim = embed_dim self.patch_norm = patch_norm self.out_indices = out_indices self.frozen_stages = frozen_stages # split image into non-overlapping patches self.patch_embed = PatchEmbed( patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, norm_layer=norm_layer if self.patch_norm else None, use_conv_embed=use_conv_embed, is_stem=True) self.pos_drop = nn.Dropout(p=drop_rate) # stochastic depth dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule # build layers self.layers = nn.ModuleList() for i_layer in range(self.num_layers): layer = BasicLayer( dim=int(embed_dim * 2 ** i_layer), depth=depths[i_layer], mlp_ratio=mlp_ratio, drop=drop_rate, drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], norm_layer=norm_layer, downsample=PatchEmbed if (i_layer < self.num_layers - 1) else None, focal_window=focal_windows[i_layer], focal_level=focal_levels[i_layer], use_conv_embed=use_conv_embed, use_postln=use_postln, use_postln_in_modulation=use_postln_in_modulation, scaling_modulator=scaling_modulator, use_layerscale=use_layerscale, use_checkpoint=use_checkpoint) self.layers.append(layer) num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] self.num_features = num_features # add a norm layer for each output for i_layer in out_indices: layer = norm_layer(num_features[i_layer]) layer_name = f'norm{i_layer}' self.add_module(layer_name, layer) self._freeze_stages() 
def load_weights(self, pretrained_dict=None, pretrained_layers=None, verbose=True):
    """Load a pretrained state dict, adapting focal kernels and gate counts.

    Only keys that also exist in this model are considered. A key is loaded
    when its first dotted component is listed in ``pretrained_layers`` or when
    the first entry of ``pretrained_layers`` is ``'*'``; keys containing
    ``relative_position_index`` or ``attn_mask`` are always skipped.

    Shape adaptation:
      * ``pool_layers`` / ``focal_layers`` tensors whose size differs from the
        model's are zero-padded or cropped along the two kernel dimensions.
      * ``modulation.f`` / ``pre_conv`` tensors with fewer rows than the model
        are expanded: every entry except the last keeps its position, the
        last entry (global gate) moves to the end, new slots stay zero.

    Args:
        pretrained_dict: mapping of parameter names to tensors.
        pretrained_layers: sequence of top-level module names to load, or a
            sequence starting with ``'*'`` to load everything. ``None`` or an
            empty sequence loads nothing.
        verbose: accepted for interface compatibility; not used.

    Raises:
        ValueError: if ``pretrained_dict`` is ``None``.
        NotImplementedError: if a ``modulation.f`` / ``pre_conv`` tensor has
            more rows than the model's.
    """
    if pretrained_dict is None:
        raise ValueError('pretrained_dict must be a mapping of parameter names to tensors')
    if pretrained_layers is None:
        pretrained_layers = ()
    # Guard the wildcard test so an empty selection means "load nothing"
    # instead of raising IndexError on pretrained_layers[0].
    load_all = len(pretrained_layers) > 0 and pretrained_layers[0] == '*'

    model_dict = self.state_dict()
    missed_dict = [k for k in model_dict if k not in pretrained_dict]
    logger.info(f'=> Missed keys {missed_dict}')
    unexpected_dict = [k for k in pretrained_dict if k not in model_dict]
    logger.info(f'=> Unexpected keys {unexpected_dict}')

    need_init_state_dict = {}
    for k, v in pretrained_dict.items():
        if k not in model_dict:
            continue
        need_init = (
            (k.split('.')[0] in pretrained_layers or load_all)
            and 'relative_position_index' not in k
            and 'attn_mask' not in k
        )
        if not need_init:
            continue

        table_current = model_dict[k]

        # Parenthesised on purpose: the size check must apply to both key
        # families, not only to 'focal_layers'.
        if ('pool_layers' in k or 'focal_layers' in k) and v.size() != table_current.size():
            fsize1 = v.shape[2]
            fsize2 = table_current.shape[2]
            # NOTE: different from interpolation used in self-attention, we
            # use padding or clipping for focal conv
            if fsize1 < fsize2:
                diff = fsize2 - fsize1
                lo, hi = diff // 2, (-diff) // 2
                resized = torch.zeros(table_current.shape)
                resized[:, :, lo:hi, lo:hi] = v
                v = resized
            elif fsize1 > fsize2:
                diff = fsize1 - fsize2
                lo, hi = diff // 2, (-diff) // 2
                v = v[:, :, lo:hi, lo:hi]

        if ('modulation.f' in k or 'pre_conv' in k) and v.shape != table_current.shape:
            if len(v.shape) == 2:
                assert table_current.shape[1] == v.shape[1]
            if len(v.shape) in (1, 2):
                L1 = v.shape[0]
                L2 = table_current.shape[0]
                if L1 < L2:
                    resized = torch.zeros(table_current.shape)
                    # Projection rows and existing focal-level gates keep
                    # their positions.
                    resized[:L1 - 1] = v[:L1 - 1]
                    # The global gate is always the last entry.
                    resized[-1] = v[-1]
                    v = resized
                elif L1 > L2:
                    raise NotImplementedError

        need_init_state_dict[k] = v

    self.load_state_dict(need_init_state_dict, strict=False)
def train(self, mode=True):
    """Switch training mode while keeping the frozen stages frozen.

    Args:
        mode: ``True`` for training mode, ``False`` for evaluation mode.

    Returns:
        The module itself, like ``nn.Module.train``. ``nn.Module.eval()``
        returns the result of ``train(False)``, so returning ``self`` here
        keeps ``model.eval()`` and ``model.train()`` usable in expressions.
    """
    super(FocalNet, self).train(mode)
    # The base-class call flips every sub-module, so re-apply the freeze.
    self._freeze_stages()
    return self
@register_backbone
def get_focal_backbone(cfg):
    """Build the FocalNet backbone and optionally load pretrained weights.

    Args:
        cfg: nested config mapping. Reads ``cfg['MODEL']`` (passed to
            ``D2FocalNet``), ``cfg['MODEL']['BACKBONE']['LOAD_PRETRAINED']``,
            ``cfg['MODEL']['BACKBONE']['PRETRAINED']``, the optional
            ``cfg['MODEL']['BACKBONE']['FOCAL']['PRETRAINED_LAYERS']``
            (default ``['*']``) and ``cfg['VERBOSE']``.

    Returns:
        The constructed ``D2FocalNet`` instance.
    """
    focal = D2FocalNet(cfg['MODEL'], 224)

    if cfg['MODEL']['BACKBONE']['LOAD_PRETRAINED'] is True:
        filename = cfg['MODEL']['BACKBONE']['PRETRAINED']
        logger.info(f'=> init from {filename}')
        with PathManager.open(filename, "rb") as f:
            # NOTE(security): torch.load unpickles the file; only point
            # PRETRAINED at checkpoints from a trusted source.
            # map_location='cpu' lets a checkpoint saved on a GPU load on any
            # host; load_state_dict copies values onto the parameters' device.
            ckpt = torch.load(f, map_location='cpu')['model']
        focal.load_weights(
            ckpt,
            cfg['MODEL']['BACKBONE']['FOCAL'].get('PRETRAINED_LAYERS', ['*']),
            cfg['VERBOSE'],
        )

    return focal
class Mlp(nn.Module):
    """Two-layer perceptron with dropout after each linear layer.

    Args:
        in_features: size of the input features.
        hidden_features: size of the hidden layer; a falsy value falls back
            to ``in_features``.
        out_features: size of the output; a falsy value falls back to
            ``in_features``.
        act_layer: activation module class, instantiated with no arguments.
        drop: dropout probability shared by both dropout applications.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        width_out = out_features if out_features else in_features
        width_hidden = hidden_features if hidden_features else in_features
        self.fc1 = nn.Linear(in_features, width_hidden)
        self.act = act_layer()
        self.fc2 = nn.Linear(width_hidden, width_out)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Apply fc1 -> activation -> dropout -> fc2 -> dropout."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
def forward(self, x):
    """Apply focal modulation.

    Args:
        x: input features with shape of (B, H, W, C)
    """
    _, _, _, channels = x.shape
    levels = self.focal_level

    # A single projection yields query, context and the per-level gates
    # (one extra gate channel for the global context).
    projected = self.f(x).permute(0, 3, 1, 2).contiguous()
    q, ctx, gates = torch.split(projected, (channels, channels, levels + 1), 1)

    # Hierarchical context: each level works on the previous level's output.
    ctx_all = 0
    for level in range(levels):
        ctx = self.focal_layers[level](ctx)
        ctx_all = ctx_all + ctx * gates[:, level:level + 1]

    # Global context is the spatial mean of the last level.
    ctx_global = self.act(ctx.mean(2, keepdim=True).mean(3, keepdim=True))
    ctx_all = ctx_all + ctx_global * gates[:, levels:]

    if self.scaling_modulator:
        ctx_all = ctx_all / (levels + 1)

    modulated = (q * self.h(ctx_all)).permute(0, 2, 3, 1).contiguous()
    if self.use_postln_in_modulation:
        modulated = self.ln(modulated)
    return self.proj_drop(self.proj(modulated))
def forward(self, x):
    """Run one focal modulation block with depth-wise conv residuals.

    Args:
        x: Input feature, tensor size (B, H*W, C).

    ``self.H`` and ``self.W`` must hold the spatial resolution of ``x``
    before the call.
    """
    batch, num_tokens, channels = x.shape
    height, width = self.H, self.W
    assert num_tokens == height * width, "input feature has wrong size"

    def conv_residual(tokens, conv):
        # Tokens -> (B, C, H, W), add the conv output, back to tokens.
        fmap = tokens.view(batch, height, width, channels).permute(0, 3, 1, 2).contiguous()
        fmap = fmap + conv(fmap)
        return fmap.permute(0, 2, 3, 1).contiguous().view(batch, num_tokens, channels)

    post_norm = self.use_postln

    x = conv_residual(x, self.dw1)
    residual = x

    # Focal modulation; pre-LN normalises its input.
    tokens = x if post_norm else self.norm1(x)
    modulated = self.modulation(tokens.view(batch, height, width, channels))
    modulated = modulated.view(batch, height * width, channels)
    x = residual + self.drop_path(self.gamma_1 * modulated)
    # Post-LN normalises the sum after the residual connection.
    if post_norm:
        x = self.norm1(x)

    x = conv_residual(x, self.dw2)

    # Feed-forward; post-LN normalises the sum after the residual connection.
    if post_norm:
        x = x + self.drop_path(self.gamma_2 * self.mlp(x))
        return self.norm2(x)
    return x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
Default: False """ def __init__(self, dim, depth, mlp_ratio=4., drop=0., drop_path=0., norm_layer=nn.LayerNorm, downsample=None, focal_window=9, focal_level=2, use_conv_embed=False, use_postln=False, use_postln_in_modulation=False, scaling_modulator=False, use_layerscale=False, use_checkpoint=False, use_pre_norm=False, ): super().__init__() self.depth = depth self.use_checkpoint = use_checkpoint # build blocks self.blocks = nn.ModuleList([ FocalModulationBlock( dim=dim, mlp_ratio=mlp_ratio, drop=drop, drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, focal_window=focal_window, focal_level=focal_level, use_postln=use_postln, use_postln_in_modulation=use_postln_in_modulation, scaling_modulator=scaling_modulator, use_layerscale=use_layerscale, norm_layer=norm_layer) for i in range(depth)]) # patch merging layer if downsample is not None: self.downsample = downsample( patch_size=2, in_chans=dim, embed_dim=2*dim, use_conv_embed=use_conv_embed, norm_layer=norm_layer, is_stem=False, use_pre_norm=use_pre_norm ) else: self.downsample = None def forward(self, x, H, W): """ Forward function. Args: x: Input feature, tensor size (B, H*W, C). H, W: Spatial resolution of the input feature. """ for blk in self.blocks: blk.H, blk.W = H, W if self.use_checkpoint: x = checkpoint.checkpoint(blk, x) else: x = blk(x) if self.downsample is not None: x_reshaped = x.transpose(1, 2).view(x.shape[0], x.shape[-1], H, W) x_down = self.downsample(x_reshaped) x_down = x_down.flatten(2).transpose(1, 2) Wh, Ww = (H + 1) // 2, (W + 1) // 2 return x, H, W, x_down, Wh, Ww else: return x, H, W, x, H, W # class PatchEmbed(nn.Module): # r""" Image to Patch Embedding # Args: # img_size (int): Image size. Default: 224. # patch_size (int): Patch token size. Default: 4. # in_chans (int): Number of input image channels. Default: 3. # embed_dim (int): Number of linear projection output channels. Default: 96. # norm_layer (nn.Module, optional): Normalization layer. 
Default: None # """ # def __init__(self, img_size=(224, 224), patch_size=4, in_chans=3, embed_dim=96, # use_conv_embed=False, norm_layer=None, is_stem=False, use_pre_norm=False): # super().__init__() # patch_size = to_2tuple(patch_size) # patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] # self.img_size = img_size # self.patch_size = patch_size # self.patches_resolution = patches_resolution # self.num_patches = patches_resolution[0] * patches_resolution[1] # self.in_chans = in_chans # self.embed_dim = embed_dim # self.use_pre_norm = use_pre_norm # if use_conv_embed: # # if we choose to use conv embedding, then we treat the stem and non-stem differently # if is_stem: # kernel_size = 7; padding = 3; stride = 4 # else: # kernel_size = 3; padding = 1; stride = 2 # self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding) # else: # self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) # if self.use_pre_norm: # if norm_layer is not None: # self.norm = norm_layer(in_chans) # else: # self.norm = None # else: # if norm_layer is not None: # self.norm = norm_layer(embed_dim) # else: # self.norm = None # def forward(self, x): # B, C, H, W = x.shape # # FIXME look at relaxing size constraints # assert H == self.img_size[0] and W == self.img_size[1], \ # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
class PatchEmbed(nn.Module):
    """Image to patch embedding.

    Args:
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
        use_conv_embed (bool): Whether to use overlapped convolution for patch
            embedding. Default: False
        is_stem (bool): Is the stem block or not. Only affects the kernel,
            stride and padding chosen when ``use_conv_embed`` is True.
        use_pre_norm (bool): If True, the norm layer is built over ``in_chans``
            and applied to the input before the projection; otherwise it is
            built over ``embed_dim`` and applied to the projected output.
            Default: False
    """

    def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None,
                 use_conv_embed=False, is_stem=False, use_pre_norm=False):
        super().__init__()
        patch_size = to_2tuple(patch_size)
        self.patch_size = patch_size

        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.use_pre_norm = use_pre_norm

        if use_conv_embed:
            # overlapped conv embedding: stem and non-stem use different geometry
            if is_stem:
                kernel_size, padding, stride = 7, 3, 4
            else:
                kernel_size, padding, stride = 3, 1, 2
            self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size,
                                  stride=stride, padding=padding)
        else:
            self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size,
                                  stride=patch_size)

        if norm_layer is None:
            self.norm = None
        else:
            # pre-norm normalises the input channels, post-norm the embedding
            self.norm = norm_layer(in_chans if self.use_pre_norm else embed_dim)

    def forward(self, x):
        """Pad the input to a multiple of the patch size, then project it.

        Args:
            x: Tensor of size (B, C, H, W).

        Returns:
            Tensor of size (B, embed_dim, Wh, Ww).
        """
        _, _, H, W = x.size()
        if W % self.patch_size[1] != 0:
            x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
        if H % self.patch_size[0] != 0:
            x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))

        if self.use_pre_norm:
            if self.norm is not None:
                # Re-read the shape: padding above may have enlarged H and W,
                # and restoring with the stale sizes makes the reshape fail.
                B, C, H, W = x.size()
                x = x.flatten(2).transpose(1, 2)  # B H*W C
                x = self.norm(x).transpose(1, 2).reshape(B, C, H, W)
            return self.proj(x)

        x = self.proj(x)  # B C Wh Ww
        if self.norm is not None:
            Wh, Ww = x.size(2), x.size(3)
            x = x.flatten(2).transpose(1, 2)
            x = self.norm(x)
            x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
        return x
""" def __init__(self, pretrain_img_size=1600, patch_size=4, in_chans=3, embed_dim=96, depths=[2, 2, 6, 2], mlp_ratio=4., drop_rate=0., drop_path_rate=0.2, norm_layer=nn.LayerNorm, patch_norm=True, out_indices=[0, 1, 2, 3], frozen_stages=-1, focal_levels=[2,2,2,2], focal_windows=[9,9,9,9], use_pre_norms=[False, False, False, False], use_conv_embed=False, use_postln=False, use_postln_in_modulation=False, scaling_modulator=False, use_layerscale=False, use_checkpoint=False, ): super().__init__() self.pretrain_img_size = pretrain_img_size self.num_layers = len(depths) self.embed_dim = embed_dim self.patch_norm = patch_norm self.out_indices = out_indices self.frozen_stages = frozen_stages # split image into non-overlapping patches self.patch_embed = PatchEmbed( patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, norm_layer=norm_layer if self.patch_norm else None, use_conv_embed=use_conv_embed, is_stem=True, use_pre_norm=False) self.pos_drop = nn.Dropout(p=drop_rate) # stochastic depth dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule # build layers self.layers = nn.ModuleList() for i_layer in range(self.num_layers): layer = BasicLayer( dim=int(embed_dim * 2 ** i_layer), depth=depths[i_layer], mlp_ratio=mlp_ratio, drop=drop_rate, drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], norm_layer=norm_layer, downsample=PatchEmbed if (i_layer < self.num_layers - 1) else None, focal_window=focal_windows[i_layer], focal_level=focal_levels[i_layer], use_pre_norm=use_pre_norms[i_layer], use_conv_embed=use_conv_embed, use_postln=use_postln, use_postln_in_modulation=use_postln_in_modulation, scaling_modulator=scaling_modulator, use_layerscale=use_layerscale, use_checkpoint=use_checkpoint) self.layers.append(layer) num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] self.num_features = num_features # self.norm = norm_layer(num_features[-1]) # add a norm layer for each output for i_layer 
def load_weights(self, pretrained_dict=None, pretrained_layers=None, verbose=True):
    """Load pretrained tensors into this backbone, adapting some mismatched shapes.

    Only keys present in both ``pretrained_dict`` and ``self.state_dict()``
    are considered. Keys containing ``relative_position_index`` or
    ``attn_mask`` are always skipped.

    Args:
        pretrained_dict (dict | None): Mapping from parameter name to tensor.
            ``None`` is treated as an empty mapping.
        pretrained_layers (Sequence[str] | None): Top-level module names
            (the part of the key before the first ``.``) to load. If the
            first entry is ``'*'`` every matching key is loaded. ``None`` or
            an empty sequence loads nothing.
        verbose (bool): Accepted for interface compatibility; not used by
            this method.

    Raises:
        NotImplementedError: If a ``modulation.f`` / ``pre_conv`` tensor in
            the checkpoint has more rows than the model's tensor.
    """
    if pretrained_dict is None:
        pretrained_dict = {}
    # Normalise so an empty/None selection cannot hit pretrained_layers[0].
    pretrained_layers = list(pretrained_layers) if pretrained_layers else []
    load_all = bool(pretrained_layers) and pretrained_layers[0] == '*'

    model_dict = self.state_dict()
    missed_dict = [k for k in model_dict if k not in pretrained_dict]
    logger.info(f'=> Missed keys {missed_dict}')
    unexpected_dict = [k for k in pretrained_dict if k not in model_dict]
    logger.info(f'=> Unexpected keys {unexpected_dict}')

    need_init_state_dict = {}
    for k, v in pretrained_dict.items():
        if k not in model_dict:
            continue
        if not (load_all or k.split('.')[0] in pretrained_layers):
            continue
        if 'relative_position_index' in k or 'attn_mask' in k:
            continue

        table_current = model_dict[k]

        # Focal conv kernels: pad with zeros or centre-crop along the two
        # kernel axes. The size test must cover both key families (the
        # original `a or b and c` only applied it to 'focal_layers'), and
        # the resize only makes sense for 4-D conv weights.
        is_focal_kernel = ('pool_layers' in k) or ('focal_layers' in k)
        if is_focal_kernel and v.dim() == 4 and v.size() != table_current.size():
            fsize1 = v.shape[2]
            fsize2 = table_current.shape[2]
            # NOTE: different from interpolation used in self-attention,
            # we use padding or clipping for focal conv
            if fsize1 < fsize2:
                diff = fsize2 - fsize1
                lo, hi = diff // 2, fsize2 - (diff - diff // 2)
                resized = torch.zeros(table_current.shape)
                resized[:, :, lo:hi, lo:hi] = v
                v = resized
            elif fsize1 > fsize2:
                diff = fsize1 - fsize2
                lo, hi = diff // 2, fsize1 - (diff - diff // 2)
                v = v[:, :, lo:hi, lo:hi]

        if ('modulation.f' in k or 'pre_conv' in k) and v.shape != table_current.shape:
            if len(v.shape) == 2:
                dim = v.shape[1]
                assert table_current.shape[1] == dim
                L1 = v.shape[0]
                L2 = table_current.shape[0]
                if L1 < L2:
                    resized = torch.zeros(table_current.shape)
                    # copy for linear project
                    resized[:2 * dim] = v[:2 * dim]
                    # copy for global token gating
                    resized[-1] = v[-1]
                    # copy for first multiple focal levels
                    resized[2 * dim:L1 - 1] = v[2 * dim:-1]
                    v = resized
                elif L1 > L2:
                    raise NotImplementedError
            elif len(v.shape) == 1:
                L1 = v.shape[0]
                L2 = table_current.shape[0]
                if L1 < L2:
                    resized = torch.zeros(table_current.shape)
                    # Kept as in the original: the whole pretrained vector is
                    # copied to the front, then its last entry is also written
                    # to the last slot.
                    # NOTE(review): unlike the 2-D case the old last entry
                    # stays at index L1-1 as well - confirm this is intended.
                    resized[:L1] = v[:L1]
                    resized[-1] = v[-1]
                    v = resized
                elif L1 > L2:
                    raise NotImplementedError

        need_init_state_dict[k] = v

    self.load_state_dict(need_init_state_dict, strict=False)
@register_backbone
def get_focal_backbone(cfg):
    """Build a ``D2FocalNet`` backbone and optionally load pretrained weights.

    Args:
        cfg: Nested config mapping. Reads ``cfg['MODEL']`` (passed to
            ``D2FocalNet``), ``cfg['MODEL']['BACKBONE']['LOAD_PRETRAINED']``,
            ``cfg['MODEL']['BACKBONE']['PRETRAINED']``, the optional
            ``cfg['MODEL']['BACKBONE']['FOCAL']['PRETRAINED_LAYERS']``
            (defaults to ``['*']``) and ``cfg['VERBOSE']``.

    Returns:
        The constructed backbone.
    """
    focal = D2FocalNet(cfg['MODEL'], 224)

    if cfg['MODEL']['BACKBONE']['LOAD_PRETRAINED'] is True:
        filename = cfg['MODEL']['BACKBONE']['PRETRAINED']
        logger.info(f'=> init from {filename}')
        with PathManager.open(filename, "rb") as f:
            # Deserialize onto CPU so a checkpoint saved from a GPU still
            # loads on hosts without that device; load_state_dict copies the
            # values into the model's own parameters afterwards.
            ckpt = torch.load(f, map_location='cpu')['model']
        focal.load_weights(
            ckpt,
            cfg['MODEL']['BACKBONE']['FOCAL'].get('PRETRAINED_LAYERS', ['*']),
            cfg['VERBOSE'],
        )

    return focal
def window_partition(x, window_size):
    """Split a feature map into non-overlapping square windows.

    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    batch, height, width, channels = x.shape
    rows = height // window_size
    cols = width // window_size
    tiles = x.view(batch, rows, window_size, cols, window_size, channels)
    # bring the two window-grid axes together ahead of the in-window axes
    tiles = tiles.permute(0, 1, 3, 2, 4, 5).contiguous()
    return tiles.view(-1, window_size, window_size, channels)
class WindowAttention(nn.Module):
    """Window based multi-head self attention (W-MSA) with relative position bias.

    Works for both shifted and non-shifted windows; the shift is expressed
    through the optional additive ``mask`` passed to ``forward``.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(
        self,
        dim,
        window_size,
        num_heads,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        win_h, win_w = window_size[0], window_size[1]

        # one learnable bias per (relative offset, head): (2*Wh-1)*(2*Ww-1), nH
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * win_h - 1) * (2 * win_w - 1), num_heads)
        )

        # lookup from every (query token, key token) pair into the bias table
        grid = torch.stack(
            torch.meshgrid([torch.arange(self.window_size[0]), torch.arange(self.window_size[1])])
        )  # 2, Wh, Ww
        flat = torch.flatten(grid, 1)  # 2, Wh*Ww
        delta = flat[:, :, None] - flat[:, None, :]  # 2, Wh*Ww, Wh*Ww
        # shift both offsets to start from 0, then flatten (row, col) to one index
        row_term = (delta[0] + (win_h - 1)) * (2 * win_w - 1)
        col_term = delta[1] + (win_w - 1)
        relative_position_index = row_term + col_term  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=0.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """Forward function.

        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        num_win_batch, tokens, channels = x.shape
        heads = self.num_heads

        packed = self.qkv(x).reshape(num_win_batch, tokens, 3, heads, channels // heads)
        packed = packed.permute(2, 0, 3, 1, 4)
        query, key, value = packed[0], packed[1], packed[2]

        scores = (query * self.scale) @ key.transpose(-2, -1)

        window_area = self.window_size[0] * self.window_size[1]
        bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
        bias = bias.view(window_area, window_area, -1)  # Wh*Ww, Wh*Ww, nH
        bias = bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        scores = scores + bias.unsqueeze(0)

        if mask is not None:
            # broadcast the per-window mask over batch and heads
            num_windows = mask.shape[0]
            scores = scores.view(num_win_batch // num_windows, num_windows, heads, tokens, tokens)
            scores = scores + mask.unsqueeze(1).unsqueeze(0)
            scores = scores.view(-1, heads, tokens, tokens)

        weights = self.attn_drop(self.softmax(scores))

        out = (weights @ value).transpose(1, 2).reshape(num_win_batch, tokens, channels)
        out = self.proj(out)
        return self.proj_drop(out)
Default: nn.LayerNorm """ def __init__( self, dim, num_heads, window_size=7, shift_size=0, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, drop=0.0, attn_drop=0.0, drop_path=0.0, act_layer=nn.GELU, norm_layer=nn.LayerNorm, ): super().__init__() self.dim = dim self.num_heads = num_heads self.window_size = window_size self.shift_size = shift_size self.mlp_ratio = mlp_ratio assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" self.norm1 = norm_layer(dim) self.attn = WindowAttention( dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, ) self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp( in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop ) self.H = None self.W = None def forward(self, x, mask_matrix): """Forward function. Args: x: Input feature, tensor size (B, H*W, C). H, W: Spatial resolution of the input feature. mask_matrix: Attention mask for cyclic shift. 
class PatchMerging(nn.Module):
    """Patch Merging Layer.

    Concatenates each 2x2 neighbourhood of tokens along the channel axis,
    normalises the result and linearly reduces it from 4*dim to 2*dim.

    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x, H, W):
        """Forward function.

        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.
        """
        batch, length, channels = x.shape
        assert length == H * W, "input feature has wrong size"

        grid = x.view(batch, H, W, channels)

        # pad odd sizes at the bottom/right so every 2x2 cell is complete
        if H % 2 == 1 or W % 2 == 1:
            grid = F.pad(grid, (0, 0, 0, W % 2, 0, H % 2))

        # (row, col) offsets inside each cell, in the order they are concatenated
        offsets = ((0, 0), (1, 0), (0, 1), (1, 1))
        corners = [grid[:, r::2, c::2, :] for r, c in offsets]  # each B H/2 W/2 C
        merged = torch.cat(corners, -1)  # B H/2 W/2 4*C
        merged = merged.view(batch, -1, 4 * channels)  # B H/2*W/2 4*C

        return self.reduction(self.norm(merged))
""" def __init__( self, dim, depth, num_heads, window_size=7, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, drop=0.0, attn_drop=0.0, drop_path=0.0, norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, ): super().__init__() self.window_size = window_size self.shift_size = window_size // 2 self.depth = depth self.use_checkpoint = use_checkpoint # build blocks self.blocks = nn.ModuleList( [ SwinTransformerBlock( dim=dim, num_heads=num_heads, window_size=window_size, shift_size=0 if (i % 2 == 0) else window_size // 2, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, attn_drop=attn_drop, drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer, ) for i in range(depth) ] ) # patch merging layer if downsample is not None: self.downsample = downsample(dim=dim, norm_layer=norm_layer) else: self.downsample = None def forward(self, x, H, W): """Forward function. Args: x: Input feature, tensor size (B, H*W, C). H, W: Spatial resolution of the input feature. 
class PatchEmbed(nn.Module):
    """Image to Patch Embedding.

    Args:
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        patch_size = to_2tuple(patch_size)
        self.patch_size = patch_size

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer is not None else None

    def forward(self, x):
        """Forward function."""
        _, _, height, width = x.size()
        patch_h, patch_w = self.patch_size[0], self.patch_size[1]

        # zero-pad on the right/bottom up to the next multiple of the patch size
        extra_w = (-width) % patch_w
        extra_h = (-height) % patch_h
        if extra_w or extra_h:
            x = F.pad(x, (0, extra_w, 0, extra_h))

        x = self.proj(x)  # B C Wh Ww
        if self.norm is None:
            return x

        # the norm works on the last axis, so go through (B, Wh*Ww, C) and back
        out_h, out_w = x.size(2), x.size(3)
        tokens = self.norm(x.flatten(2).transpose(1, 2))
        return tokens.transpose(1, 2).view(-1, self.embed_dim, out_h, out_w)
Default: nn.LayerNorm. ape (bool): If True, add absolute position embedding to the patch embedding. Default: False. patch_norm (bool): If True, add normalization after patch embedding. Default: True. out_indices (Sequence[int]): Output from which stages. frozen_stages (int): Stages to be frozen (stop grad and set eval mode). -1 means not freezing any parameters. use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. """ def __init__( self, pretrain_img_size=224, patch_size=4, in_chans=3, embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.2, norm_layer=nn.LayerNorm, ape=False, patch_norm=True, out_indices=(0, 1, 2, 3), frozen_stages=-1, use_checkpoint=False, ): super().__init__() self.pretrain_img_size = pretrain_img_size self.num_layers = len(depths) self.embed_dim = embed_dim self.ape = ape self.patch_norm = patch_norm self.out_indices = out_indices self.frozen_stages = frozen_stages # split image into non-overlapping patches self.patch_embed = PatchEmbed( patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, norm_layer=norm_layer if self.patch_norm else None, ) # absolute position embedding if self.ape: pretrain_img_size = to_2tuple(pretrain_img_size) patch_size = to_2tuple(patch_size) patches_resolution = [ pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1], ] self.absolute_pos_embed = nn.Parameter( torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1]) ) trunc_normal_(self.absolute_pos_embed, std=0.02) self.pos_drop = nn.Dropout(p=drop_rate) # stochastic depth dpr = [ x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) ] # stochastic depth decay rule # build layers self.layers = nn.ModuleList() for i_layer in range(self.num_layers): layer = BasicLayer( dim=int(embed_dim * 2 ** i_layer), depth=depths[i_layer], num_heads=num_heads[i_layer], 
window_size=window_size, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])], norm_layer=norm_layer, downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, use_checkpoint=use_checkpoint, ) self.layers.append(layer) num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] self.num_features = num_features # add a norm layer for each output for i_layer in out_indices: layer = norm_layer(num_features[i_layer]) layer_name = f"norm{i_layer}" self.add_module(layer_name, layer) self._freeze_stages() def _freeze_stages(self): if self.frozen_stages >= 0: self.patch_embed.eval() for param in self.patch_embed.parameters(): param.requires_grad = False if self.frozen_stages >= 1 and self.ape: self.absolute_pos_embed.requires_grad = False if self.frozen_stages >= 2: self.pos_drop.eval() for i in range(0, self.frozen_stages - 1): m = self.layers[i] m.eval() for param in m.parameters(): param.requires_grad = False def init_weights(self, pretrained=None): """Initialize the weights in backbone. Args: pretrained (str, optional): Path to pre-trained weights. Defaults to None. 
""" def _init_weights(m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=0.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) def load_weights(self, pretrained_dict=None, pretrained_layers=[], verbose=True): model_dict = self.state_dict() pretrained_dict = { k: v for k, v in pretrained_dict.items() if k in model_dict.keys() } need_init_state_dict = {} for k, v in pretrained_dict.items(): need_init = ( ( k.split('.')[0] in pretrained_layers or pretrained_layers[0] == '*' ) and 'relative_position_index' not in k and 'attn_mask' not in k ) if need_init: # if verbose: # logger.info(f'=> init {k} from {pretrained}') if 'relative_position_bias_table' in k and v.size() != model_dict[k].size(): relative_position_bias_table_pretrained = v relative_position_bias_table_current = model_dict[k] L1, nH1 = relative_position_bias_table_pretrained.size() L2, nH2 = relative_position_bias_table_current.size() if nH1 != nH2: logger.info(f"Error in loading {k}, passing") else: if L1 != L2: logger.info( '=> load_pretrained: resized variant: {} to {}' .format((L1, nH1), (L2, nH2)) ) S1 = int(L1 ** 0.5) S2 = int(L2 ** 0.5) relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate( relative_position_bias_table_pretrained.permute(1, 0).view(1, nH1, S1, S1), size=(S2, S2), mode='bicubic') v = relative_position_bias_table_pretrained_resized.view(nH2, L2).permute(1, 0) if 'absolute_pos_embed' in k and v.size() != model_dict[k].size(): absolute_pos_embed_pretrained = v absolute_pos_embed_current = model_dict[k] _, L1, C1 = absolute_pos_embed_pretrained.size() _, L2, C2 = absolute_pos_embed_current.size() if C1 != C1: logger.info(f"Error in loading {k}, passing") else: if L1 != L2: logger.info( '=> load_pretrained: resized variant: {} to {}' .format((1, L1, C1), (1, L2, C2)) ) S1 = int(L1 ** 0.5) S2 = int(L2 ** 0.5) 
absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.reshape(-1, S1, S1, C1) absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.permute(0, 3, 1, 2) absolute_pos_embed_pretrained_resized = torch.nn.functional.interpolate( absolute_pos_embed_pretrained, size=(S2, S2), mode='bicubic') v = absolute_pos_embed_pretrained_resized.permute(0, 2, 3, 1).flatten(1, 2) need_init_state_dict[k] = v self.load_state_dict(need_init_state_dict, strict=False) def forward(self, x): """Forward function.""" x = self.patch_embed(x) Wh, Ww = x.size(2), x.size(3) if self.ape: # interpolate the position embedding to the corresponding size absolute_pos_embed = F.interpolate( self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic" ) x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C else: x = x.flatten(2).transpose(1, 2) x = self.pos_drop(x) outs = {} for i in range(self.num_layers): layer = self.layers[i] x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww) if i in self.out_indices: norm_layer = getattr(self, f"norm{i}") x_out = norm_layer(x_out) out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous() outs["res{}".format(i + 2)] = out if len(self.out_indices) == 0: outs["res5"] = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous() return outs def train(self, mode=True): """Convert the model into training mode while keep layers freezed.""" super(SwinTransformer, self).train(mode) self._freeze_stages() class D2SwinTransformer(SwinTransformer, Backbone): def __init__(self, cfg, pretrain_img_size, patch_size, in_chans, embed_dim, depths, num_heads, window_size, mlp_ratio, qkv_bias, qk_scale, drop_rate, attn_drop_rate, drop_path_rate, norm_layer, ape, patch_norm, out_indices, use_checkpoint): super().__init__( pretrain_img_size, patch_size, in_chans, embed_dim, depths, num_heads, window_size, mlp_ratio, qkv_bias, qk_scale, drop_rate, attn_drop_rate, drop_path_rate, norm_layer, ape, patch_norm, out_indices, 
@register_backbone
def get_swin_backbone(cfg):
    """Build a ``D2SwinTransformer`` from ``cfg['MODEL']['BACKBONE']``.

    The architecture is read from the ``SWIN`` sub-config. When
    ``LOAD_PRETRAINED`` is exactly ``True`` the checkpoint at ``PRETRAINED``
    is opened, its ``'model'`` entry is taken, and it is handed to
    ``load_weights`` together with ``PRETRAINED_LAYERS`` (default ``['*']``)
    and ``cfg['VERBOSE']``.
    """
    backbone_cfg = cfg['MODEL']['BACKBONE']
    swin_cfg = backbone_cfg['SWIN']

    model = D2SwinTransformer(
        swin_cfg,
        swin_cfg['PRETRAIN_IMG_SIZE'],
        swin_cfg['PATCH_SIZE'],
        3,  # in_chans is fixed, not configurable
        swin_cfg['EMBED_DIM'],
        swin_cfg['DEPTHS'],
        swin_cfg['NUM_HEADS'],
        swin_cfg['WINDOW_SIZE'],
        swin_cfg['MLP_RATIO'],
        swin_cfg['QKV_BIAS'],
        swin_cfg['QK_SCALE'],
        swin_cfg['DROP_RATE'],
        swin_cfg['ATTN_DROP_RATE'],
        swin_cfg['DROP_PATH_RATE'],
        nn.LayerNorm,
        swin_cfg['APE'],
        swin_cfg['PATCH_NORM'],
        swin_cfg.get('OUT_INDICES', [0, 1, 2, 3]),
        use_checkpoint=swin_cfg['USE_CHECKPOINT'],
    )

    # only the literal True enables loading; truthy non-bool values do not
    if backbone_cfg['LOAD_PRETRAINED'] is not True:
        return model

    checkpoint_path = backbone_cfg['PRETRAINED']
    with PathManager.open(checkpoint_path, "rb") as handle:
        state = torch.load(handle, map_location='cpu')['model']
    model.load_weights(state, swin_cfg.get('PRETRAINED_LAYERS', ['*']), cfg['VERBOSE'])
    return model
class CrossAttentionLayer(nn.Module):
    """Residual multi-head cross-attention block with a single LayerNorm.

    ``normalize_before`` selects pre-norm (normalise the query input, then add
    the residual) versus post-norm (add the residual, then normalise). Every
    forward variant returns ``(updated_tgt, attention_weights)`` where the
    weights are whatever ``nn.MultiheadAttention`` returns as its second value.
    """

    def __init__(self, d_model, nhead, dropout=0.0,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        # resolved for interface parity with the sibling layers; not used in forward
        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-initialise every parameter that has more than one dimension."""
        for weight in (p for p in self.parameters() if p.dim() > 1):
            nn.init.xavier_uniform_(weight)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Return ``tensor + pos``, or ``tensor`` untouched when ``pos`` is None."""
        if pos is None:
            return tensor
        return tensor + pos

    def _attend(self, queries, memory, memory_mask, memory_key_padding_mask, pos, query_pos):
        """Run the attention module; positions are added to query and key, not to value."""
        return self.multihead_attn(
            query=self.with_pos_embed(queries, query_pos),
            key=self.with_pos_embed(memory, pos),
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )

    def forward_post(self, tgt, memory,
                     memory_mask: Optional[Tensor] = None,
                     memory_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None):
        """Post-norm variant: attention, residual add, then LayerNorm."""
        attended, avg_attn = self._attend(
            tgt, memory, memory_mask, memory_key_padding_mask, pos, query_pos
        )
        residual = tgt + self.dropout(attended)
        return self.norm(residual), avg_attn

    def forward_pre(self, tgt, memory,
                    memory_mask: Optional[Tensor] = None,
                    memory_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None):
        """Pre-norm variant: LayerNorm on the query input, attention, then residual add."""
        normed = self.norm(tgt)
        attended, avg_attn = self._attend(
            normed, memory, memory_mask, memory_key_padding_mask, pos, query_pos
        )
        return tgt + self.dropout(attended), avg_attn

    def forward(self, tgt, memory,
                memory_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        """Dispatch to the pre-norm or post-norm path according to ``normalize_before``."""
        branch = self.forward_pre if self.normalize_before else self.forward_post
        return branch(tgt, memory, memory_mask, memory_key_padding_mask, pos, query_pos)
nn.LayerNorm(d_model) self.activation = _get_activation_fn(activation) self.normalize_before = normalize_before self._reset_parameters() def _reset_parameters(self): for p in self.parameters(): if p.dim() > 1: nn.init.xavier_uniform_(p) def with_pos_embed(self, tensor, pos: Optional[Tensor]): return tensor if pos is None else tensor + pos def forward_post(self, tgt): tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) tgt = tgt + self.dropout(tgt2) tgt = self.norm(tgt) return tgt def forward_pre(self, tgt): tgt2 = self.norm(tgt) tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) tgt = tgt + self.dropout(tgt2) return tgt def forward(self, tgt): if self.normalize_before: return self.forward_pre(tgt) return self.forward_post(tgt) def _get_activation_fn(activation): """Return an activation function given a string""" if activation == "relu": return F.relu if activation == "gelu": return F.gelu if activation == "glu": return F.glu raise RuntimeError(F"activation should be relu/gelu, not {activation}.") class MLP(nn.Module): """ Very simple multi-layer perceptron (also called FFN)""" def __init__(self, input_dim, hidden_dim, output_dim, num_layers): super().__init__() self.num_layers = num_layers h = [hidden_dim] * (num_layers - 1) self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) def forward(self, x): for i, layer in enumerate(self.layers): x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) return x ================================================ FILE: llava/model/openseed/body/decoder/openseed_decoder.py ================================================ # ------------------------------------------------------------------------ # DINO # Copyright (c) 2023 IDEA. All Rights Reserved. 
@configurable
def __init__(
        self,
        # lang_encoder: nn.Module,
        in_channels,
        mask_classification=True,
        *,
        num_classes: int,
        hidden_dim: int,
        dim_proj: int,
        num_queries: int,
        nheads: int,
        dim_feedforward: int,
        dec_layers: int,
        mask_dim: int,
        enforce_input_project: bool,
        two_stage: bool,
        dn: str,
        noise_scale:float,
        dn_num:int,
        initialize_box_type: str,
        initial_pred:bool,
        learn_tgt: bool,
        total_num_feature_levels: int = 4,
        dropout: float = 0.0,
        activation: str = 'relu',
        nhead: int = 8,
        dec_n_points: int = 4,
        return_intermediate_dec: bool = True,
        query_dim: int = 4,
        dec_layer_share: bool = False,
        semantic_ce_loss: bool = False,
):
    """
    NOTE: this interface is experimental.
    Args:
        in_channels: channels of the input features
        mask_classification: whether to add mask classifier or not; must be
            truthy (asserted below)
        num_classes: number of classes (stored only)
        hidden_dim: Transformer feature dimension
        dim_proj: size of the projection space used by ``lang_mapper``
            (dim_proj -> hidden_dim) and ``class_embed`` (hidden_dim -> dim_proj)
        num_queries: number of queries
        nheads: number of heads; stored as ``self.num_heads``.
            NOTE(review): the decoder layer is built with ``nhead`` instead,
            so this value does not reach the attention modules -- confirm
            that this is intended.
        dim_feedforward: feature dimension in feedforward network
        dec_layers: number of Transformer decoder layers
        mask_dim: mask feature dimension (output size of ``mask_embed``)
        enforce_input_project: add input project 1x1 conv even if input
            channels and hidden dim is identical
        two_stage: if True, create ``enc_output`` / ``enc_output_norm``; if
            False, learnable ``query_feat`` (and possibly ``query_embed``)
            are created instead
        dn: denoising mode string (stored only)
        noise_scale: box noise scale for denoising (stored only)
        dn_num: number of denoising queries (stored only)
        initialize_box_type: compared against ``'no'`` here; ``from_config``
            lists the choices as 'no', 'bitmask', 'mask2box'
        initial_pred: stored as ``self.initial_pred``
        learn_tgt: if True, ``query_feat`` is created even when ``two_stage``
        total_num_feature_levels: number of feature levels; one input
            projection is created per level
        dropout: dropout rate
        activation: activation function
        nhead: num heads in multi-head attention (passed to the decoder layer)
        dec_n_points: number of sampling points in decoder
        return_intermediate_dec: return the intermediate results of decoder
        query_dim: 4 -> (x, y, w, h)
        dec_layer_share: whether to share each decoder layer
        semantic_ce_loss: use ce loss for semantic segmentation
    """
    super().__init__()

    assert mask_classification, "Only support mask classification model"
    self.mask_classification = mask_classification
    self.num_feature_levels = total_num_feature_levels
    self.initial_pred = initial_pred

    # define Transformer decoder here
    self.dn=dn
    self.learn_tgt = learn_tgt
    self.noise_scale=noise_scale
    self.dn_num=dn_num
    self.num_heads = nheads
    self.num_layers = dec_layers
    self.two_stage=two_stage
    self.initialize_box_type = initialize_box_type
    self.total_num_feature_levels = total_num_feature_levels

    self.num_queries = num_queries
    self.semantic_ce_loss = semantic_ce_loss
    # learnable query features
    if not two_stage or self.learn_tgt:
        self.query_feat = nn.Embedding(num_queries, hidden_dim)
    # learnable 4-d reference boxes, only without two-stage proposals
    if not two_stage and initialize_box_type == 'no':
        self.query_embed = nn.Embedding(num_queries, 4)
    if two_stage:
        self.enc_output = nn.Linear(hidden_dim, hidden_dim)
        self.enc_output_norm = nn.LayerNorm(hidden_dim)

    # one projection per feature level; identity (empty Sequential) when the
    # channel count already matches and projection is not enforced
    self.input_proj = nn.ModuleList()
    for _ in range(self.num_feature_levels):
        if in_channels != hidden_dim or enforce_input_project:
            self.input_proj.append(Conv2d(in_channels, hidden_dim, kernel_size=1))
            weight_init.c2_xavier_fill(self.input_proj[-1])
        else:
            self.input_proj.append(nn.Sequential())
    self.num_classes=num_classes
    # output FFNs
    assert self.mask_classification, "why not class embedding?"
    # self.label_enc=nn.Embedding(505, hidden_dim)  # this is a hack for o365+coco (365+133=498)
    self.dim_proj = dim_proj
    # self.lang_encoder = lang_encoder
    self.lang_mapper = nn.Parameter(torch.empty(dim_proj, hidden_dim))
    trunc_normal_(self.lang_mapper, std=.02)

    self.class_embed = nn.Parameter(torch.empty(hidden_dim, dim_proj))
    trunc_normal_(self.class_embed, std=.02)

    self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)

    # init decoder
    self.decoder_norm = decoder_norm = nn.LayerNorm(hidden_dim)
    decoder_layer = DeformableTransformerDecoderLayer(hidden_dim, dim_feedforward,
                                                      dropout, activation,
                                                      self.num_feature_levels, nhead, dec_n_points)
    self.decoder = TransformerDecoder(decoder_layer, self.num_layers, decoder_norm,
                                      return_intermediate=return_intermediate_dec,
                                      d_model=hidden_dim, query_dim=query_dim,
                                      num_feature_levels=self.num_feature_levels,
                                      dec_layer_share=dec_layer_share,
                                      )

    self.hidden_dim = hidden_dim
    # box head starts by predicting a zero delta (last layer zero-initialised)
    self._bbox_embed = _bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
    nn.init.constant_(_bbox_embed.layers[-1].weight.data, 0)
    nn.init.constant_(_bbox_embed.layers[-1].bias.data, 0)
    # the same MLP instance is repeated, so all decoder layers share its weights
    box_embed_layerlist = [_bbox_embed for i in range(self.num_layers)]  # share box prediction each layer
    self.bbox_embed = nn.ModuleList(box_embed_layerlist)
    self.decoder.bbox_embed = self.bbox_embed
    self.logit_scale = nn.Parameter(torch.ones([]))
    # both are overwritten at the start of every forward() call
    self.default_text_embeddings = None #for grounding tokens
    self.default_text_embeddings_mask = None #for grounding tokens
def prepare_for_dn(self, targets, tgt, refpoint_emb, batch_size):
    """
    Build the denoising (DN) queries and the attention mask that isolates them.

    modified from dn-detr. You can refer to dn-detr
    https://github.com/IDEA-Research/DN-DETR/blob/main/models/dn_dab_deformable_detr/dn_components.py
    for more details

    NOTE(review): forward() sets ``self.dn = "no"`` before its
    ``self.dn != "no"`` check, so this method appears not to be reached from
    forward() as written -- confirm before relying on it.
    NOTE(review): tensors are moved with hard-coded ``.cuda()`` / ``'cuda'``,
    so the training branch presumably requires a CUDA device.

        :param targets: per-image dicts; the keys ``'labels'`` and ``'boxes'`` are read
        :param tgt: original tgt (content) in the matching part, or None
        :param refpoint_emb: positional anchor queries in the matching part, or None
        :param batch_size: bs
        :return: ``(input_query_label, input_query_bbox, attn_mask, mask_dict)``;
            all four are None in training when no DN group fits, and in eval
            when ``refpoint_emb`` is None
        """
    if self.training:
        scalar, noise_scale = self.dn_num, self.noise_scale

        # every ground-truth entry is treated as "known" (all-ones masks)
        known = [(torch.ones_like(t['labels'])).cuda() for t in targets]
        know_idx = [torch.nonzero(t) for t in known]
        known_num = [sum(k) for k in known]

        # use fix number of dn queries
        # scalar becomes the number of DN groups: the dn_num budget divided by
        # the largest per-image object count
        if max(known_num) > 0:
            scalar = scalar // (int(max(known_num)))
        else:
            scalar = 0
        if scalar == 0:
            # no room for even one DN group: disable denoising for this batch
            input_query_label = None
            input_query_bbox = None
            attn_mask = None
            mask_dict = None
            return input_query_label, input_query_bbox, attn_mask, mask_dict

        # can be modified to selectively denosie some label or boxes; also known label prediction
        unmask_bbox = unmask_label = torch.cat(known)
        labels = torch.cat([t['labels'] for t in targets])
        # use languge as denosing content queries.
        # if task == 'det':
        #     labels = labels  # o365 start from 133 class
        boxes = torch.cat([t['boxes'] for t in targets])
        # image index of every concatenated ground-truth entry
        batch_idx = torch.cat([torch.full_like(t['labels'].long(), i) for i, t in enumerate(targets)])
        # known
        known_indice = torch.nonzero(unmask_label + unmask_bbox)
        known_indice = known_indice.view(-1)

        # noise
        # replicate every ground-truth entry once per DN group
        known_indice = known_indice.repeat(scalar, 1).view(-1)
        known_labels = labels.repeat(scalar, 1).view(-1)
        known_bid = batch_idx.repeat(scalar, 1).view(-1)
        known_bboxs = boxes.repeat(scalar, 1)
        known_labels_expaned = known_labels.clone()
        known_bbox_expand = known_bboxs.clone()

        if noise_scale > 0:
            # jitter range: columns 0-1 may move by half of columns 2-3, and
            # columns 2-3 by their own value (presumably cx, cy, w, h boxes --
            # TODO confirm), scaled by noise_scale and a uniform draw in [-1, 1)
            diff = torch.zeros_like(known_bbox_expand)
            diff[:, :2] = known_bbox_expand[:, 2:] / 2
            diff[:, 2:] = known_bbox_expand[:, 2:]
            known_bbox_expand += torch.mul((torch.rand_like(known_bbox_expand) * 2 - 1.0),
                                           diff).cuda() * noise_scale
            known_bbox_expand = known_bbox_expand.clamp(min=0.0, max=1.0)

        m = known_labels_expaned.long().to('cuda')
        # import ipdb; ipdb.set_trace()
        # look up the text embedding of each label and map it into the decoder space.
        # NOTE(review): the index passed to gather is 2-D, while
        # compute_similarity() transposes dims 1 and 2 of
        # default_text_embeddings -- confirm the layout expected here.
        input_label_embed = torch.gather(self.default_text_embeddings, 0,
                                         m[:, None].repeat(1, self.dim_proj)) @ self.lang_mapper

        input_bbox_embed = inverse_sigmoid(known_bbox_expand)
        # each DN group is padded to the largest per-image object count
        single_pad = int(max(known_num))
        pad_size = int(single_pad * scalar)

        padding_label = input_label_embed.new_zeros(pad_size, self.hidden_dim)
        padding_bbox = input_bbox_embed.new_zeros(pad_size, 4)

        if not refpoint_emb is None:
            input_query_label = torch.cat([padding_label, tgt], dim=0).repeat(batch_size, 1, 1)
            input_query_bbox = torch.cat([padding_bbox, refpoint_emb], dim=0).repeat(batch_size, 1, 1)
        else:
            input_query_label = padding_label.repeat(batch_size, 1, 1)
            input_query_bbox = padding_bbox.repeat(batch_size, 1, 1)

        # map
        # slot of every noised entry inside the padded DN region of its image
        map_known_indice = input_label_embed.new_tensor([])
        if len(known_num):
            map_known_indice = torch.cat(
                [input_label_embed.new_tensor(range(num)) for num in known_num])  # [1,2, 1,2,3]
            map_known_indice = torch.cat([map_known_indice + single_pad * i for i in range(scalar)]).long()
        if len(known_bid):
            input_query_label[(known_bid.long(), map_known_indice)] = input_label_embed
            input_query_bbox[(known_bid.long(), map_known_indice)] = input_bbox_embed

        tgt_size = pad_size + self.num_queries
        # all-False boolean mask; True entries written below block attention
        attn_mask = input_label_embed.new_ones(tgt_size, tgt_size) < 0
        # match query cannot see the reconstruct
        attn_mask[pad_size:, :pad_size] = True
        # reconstruct cannot see each other
        for i in range(scalar):
            if i == 0:
                attn_mask[single_pad * i:single_pad * (i + 1), single_pad * (i + 1):pad_size] = True
            if i == scalar - 1:
                attn_mask[single_pad * i:single_pad * (i + 1), :single_pad * i] = True
            else:
                attn_mask[single_pad * i:single_pad * (i + 1), single_pad * (i + 1):pad_size] = True
                attn_mask[single_pad * i:single_pad * (i + 1), :single_pad * i] = True
        # bookkeeping read back by dn_post_process ('pad_size') and, presumably,
        # by the loss code outside this chunk
        mask_dict = {
            'known_indice': torch.as_tensor(known_indice).long(),
            'batch_idx': torch.as_tensor(batch_idx).long(),
            'map_known_indice': torch.as_tensor(map_known_indice).long(),
            'known_lbs_bboxes': (known_labels, known_bboxs),
            'know_idx': know_idx,
            'pad_size': pad_size,
            'scalar': scalar,
        }
    else:
        # evaluation: no denoising queries, only replicate the matching queries
        if not refpoint_emb is None:
            input_query_label = tgt.repeat(batch_size, 1, 1)
            input_query_bbox = refpoint_emb.repeat(batch_size, 1, 1)
        else:
            input_query_label = None
            input_query_bbox = None
        attn_mask = None
        mask_dict = None

    # 100*batch*256
    # the two self-assignments below have no effect
    if not input_query_bbox is None:
        input_query_label = input_query_label
        input_query_bbox = input_query_bbox

    return input_query_label, input_query_bbox, attn_mask, mask_dict
def get_valid_ratio(self, mask):
    """Return, per batch element, the fraction of width and height that is not masked.

    ``mask`` is a 3-D boolean tensor where True marks masked-out positions.
    The valid height is counted along the first column and the valid width
    along the first row. The result stacks ``(ratio_w, ratio_h)`` on the last
    dimension.
    """
    height, width = mask.shape[1:]
    unmasked = ~mask
    rows_valid = unmasked[:, :, 0].sum(dim=1)
    cols_valid = unmasked[:, 0, :].sum(dim=1)
    ratio_w = cols_valid.float() / width
    ratio_h = rows_valid.float() / height
    return torch.stack((ratio_w, ratio_h), dim=-1)
# output = v_emb @ t_emb.unsqueeze(0).transpose(1, 2) return output def forward(self, x, mask_features, masks, targets=None, target_queries=None, target_vlp=None, task='seg',default_text_embeddings=None, extra={}): """ task: seg/det """ self.default_text_embeddings,self.default_text_embeddings_mask=default_text_embeddings self.dn="no" assert len(x) == self.num_feature_levels do_seg = (task != 'det') # if task is det, not do segmentation training size_list = [] # disable mask, it does not affect performance enable_mask = 0 if masks is not None: for src in x: if src.size(2) % 32 or src.size(3) % 32: enable_mask = 1 if enable_mask == 0: masks = [torch.zeros((src.size(0), src.size(2), src.size(3)), device=src.device, dtype=torch.bool) for src in x] src_flatten = [] mask_flatten = [] spatial_shapes = [] for i in range(self.num_feature_levels): idx=self.num_feature_levels-1-i bs, c , h, w=x[idx].shape size_list.append(x[i].shape[-2:]) spatial_shapes.append(x[idx].shape[-2:]) src_flatten.append(self.input_proj[idx](x[idx]).flatten(2).transpose(1, 2)) mask_flatten.append(masks[i].flatten(1)) src_flatten = torch.cat(src_flatten, 1) # bs, \sum{hxw}, c mask_flatten = torch.cat(mask_flatten, 1) # bs, \sum{hxw} spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device) level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1) predictions_class = [] predictions_mask = [] if self.two_stage: output_memory, output_proposals = gen_encoder_output_proposals(src_flatten, mask_flatten, spatial_shapes) output_memory = self.enc_output_norm(self.enc_output(output_memory)) output_memory_ = output_memory @ self.class_embed enc_outputs_class_unselected = self.compute_similarity(output_memory_,default_text_embeddings) enc_outputs_class_unselected[output_proposals.sum(-1).isinf()] = float("-inf") # enc_outputs_class_unselected = 
self.class_embed(output_memory) enc_outputs_coord_unselected = self._bbox_embed( output_memory) + output_proposals # (bs, \sum{hw}, 4) unsigmoid topk = self.num_queries topk_proposals = torch.topk(enc_outputs_class_unselected.max(-1)[0], topk, dim=1)[1] refpoint_embed_undetach = torch.gather(enc_outputs_coord_unselected, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)) # unsigmoid refpoint_embed = refpoint_embed_undetach.detach() tgt_undetach = torch.gather(output_memory, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, self.hidden_dim)) # unsigmoid outputs_class, outputs_mask = self.forward_prediction_heads(tgt_undetach.transpose(0, 1), mask_features, do_seg) tgt = tgt_undetach.detach() if self.learn_tgt: tgt = self.query_feat.weight[None].repeat(bs, 1, 1) interm_outputs=dict() interm_outputs['pred_logits'] = outputs_class interm_outputs['pred_boxes'] = refpoint_embed_undetach.sigmoid() interm_outputs['pred_masks'] = outputs_mask if self.initialize_box_type != 'no' and do_seg: # convert masks into boxes to better initialize box in the decoder assert self.initial_pred flaten_mask = outputs_mask.detach().flatten(0, 1) h, w = outputs_mask.shape[-2:] if self.initialize_box_type == 'bitmask': # slower, but more accurate refpoint_embed = BitMasks(flaten_mask > 0).get_bounding_boxes().tensor.cuda() elif self.initialize_box_type == 'mask2box': # faster conversion refpoint_embed = box_ops.masks_to_boxes(flaten_mask > 0).cuda() else: assert NotImplementedError refpoint_embed = box_ops.box_xyxy_to_cxcywh(refpoint_embed) / torch.as_tensor([w, h, w, h], dtype=torch.float).cuda() refpoint_embed = refpoint_embed.reshape(outputs_mask.shape[0], outputs_mask.shape[1], 4) refpoint_embed = inverse_sigmoid(refpoint_embed) elif not self.two_stage: tgt = self.query_feat.weight[None].repeat(bs, 1, 1) refpoint_embed = self.query_embed.weight[None].repeat(bs, 1, 1) tgt_mask = None mask_dict = None if self.dn != "no" and self.training: assert targets is not None input_query_label, 
input_query_bbox, tgt_mask, mask_dict = \ self.prepare_for_dn(targets, None, None, x[0].shape[0]) if mask_dict is not None: tgt=torch.cat([input_query_label, tgt],dim=1) # direct prediction from the matching and denoising part in the begining if self.initial_pred: outputs_class, outputs_mask = self.forward_prediction_heads(tgt.transpose(0, 1), mask_features, self.training and do_seg) predictions_class.append(outputs_class) predictions_mask.append(outputs_mask) if self.dn != "no" and self.training and mask_dict is not None: refpoint_embed=torch.cat([input_query_bbox,refpoint_embed],dim=1) # print('tgt',tgt.dtype) # print('src_flatten',src_flatten.dtype) # print('refpoint',refpoint_embed.dtype) tgt=tgt.to(src_flatten.dtype) refpoint_embed=refpoint_embed.to(src_flatten.dtype) hs, references = self.decoder( tgt=tgt.transpose(0, 1), memory=src_flatten.transpose(0, 1), memory_key_padding_mask=mask_flatten, pos=None, refpoints_unsigmoid=refpoint_embed.transpose(0, 1), level_start_index=level_start_index, spatial_shapes=spatial_shapes, valid_ratios=valid_ratios, tgt_mask=tgt_mask ) for i, output in enumerate(hs): outputs_class, outputs_mask = self.forward_prediction_heads(output.transpose(0, 1), mask_features, (self.training or (i == len(hs)-1)) and do_seg) predictions_class.append(outputs_class) predictions_mask.append(outputs_mask) # iteratively box prediction if self.initial_pred: out_boxes = self.pred_box(references, hs, refpoint_embed.sigmoid()) assert len(predictions_class) == self.num_layers + 1 else: out_boxes = self.pred_box(references, hs) if mask_dict is not None: predictions_mask = None if not do_seg else torch.stack(predictions_mask) predictions_class =torch.stack(predictions_class) predictions_class, out_boxes,predictions_mask=\ self.dn_post_process(predictions_class, out_boxes, mask_dict, predictions_mask) predictions_class = list(predictions_class) if predictions_mask is None: predictions_class[-1] = predictions_class[-1] + 0.0 * self.lang_mapper[0, 0] for 
i in range(self.mask_embed.num_layers): predictions_class[-1] = predictions_class[-1] + 0.0 * (self.mask_embed.layers[i].weight[0][0] + self.mask_embed.layers[i].bias[0]) # avoid no mask loss predictions_class[-1] = predictions_class[-1] + 0.0 * mask_features[0][0][0][0] # avoid no mask loss if do_seg: predictions_mask = list(predictions_mask) elif self.training: # this is to insure self.label_enc participate in the model predictions_class[-1] = predictions_class[-1] + 0.0 * self.lang_mapper[0, 0] for i in range(self.mask_embed.num_layers): predictions_class[-1] = predictions_class[-1] + 0.0 * ( self.mask_embed.layers[i].weight[0][0] + self.mask_embed.layers[i].bias[ 0]) # avoid no mask loss predictions_class[-1] = predictions_class[-1] + 0.0 * mask_features[0][0][0][0] # avoid no mask loss out = { 'pred_logits': predictions_class[-1], 'pred_masks': None if not do_seg else predictions_mask[-1], 'pred_boxes':out_boxes[-1], 'aux_outputs': self._set_aux_loss( predictions_class if self.mask_classification else None, predictions_mask,out_boxes ) } if self.two_stage: out['interm_outputs'] = interm_outputs return out, mask_dict def forward_prediction_heads(self, output, mask_features, pred_mask=True): decoder_output = self.decoder_norm(output) decoder_output = decoder_output.transpose(0, 1) class_embed = decoder_output @ self.class_embed outputs_class = self.compute_similarity(class_embed,self.default_text_embeddings) outputs_mask = None if pred_mask: mask_embed = self.mask_embed(decoder_output) outputs_mask = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features) return outputs_class, outputs_mask @torch.jit.unused def _set_aux_loss(self, outputs_class, outputs_seg_masks, out_boxes=None): # this is a workaround to make torchscript happy, as torchscript # doesn't support dictionary with non-homogeneous values, such # as a dict having both a Tensor and a list. 
# if self.mask_classification: if out_boxes is None: return [ {"pred_logits": a, "pred_masks": b} for a, b in zip(outputs_class[:-1], outputs_seg_masks[:-1]) ] elif outputs_seg_masks is None: return [ {"pred_logits": a, "pred_boxes": c} for a, c in zip(outputs_class[:-1], out_boxes[:-1]) ] else: return [ {"pred_logits": a, "pred_masks": b, "pred_boxes":c} for a, b, c in zip(outputs_class[:-1], outputs_seg_masks[:-1], out_boxes[:-1]) ] @register_decoder def get_maskdino_transformer_decoder(cfg, in_channels, mask_classification, extra): return OpenSeeDDecoder(cfg, in_channels, mask_classification, extra) ================================================ FILE: llava/model/openseed/body/decoder/openseed_decoder_decouple.py ================================================ # ------------------------------------------------------------------------ # DINO # Copyright (c) 2022 IDEA. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------ # Modified from Mask2Former https://github.com/facebookresearch/Mask2Former by Feng Li and Hao Zhang. 
@configurable
def __init__(
        self,
        lang_encoder: nn.Module,
        in_channels,
        mask_classification=True,
        *,
        num_classes: int,
        hidden_dim: int,
        dim_proj: int,
        num_queries: int,
        nheads: int,
        dim_feedforward: int,
        dec_layers: int,
        mask_dim: int,
        enforce_input_project: bool,
        two_stage: bool,
        dn: str,
        noise_scale:float,
        dn_num:int,
        initialize_box_type: str,
        initial_pred:bool,
        learn_tgt: bool,
        total_num_feature_levels: int = 4,
        dropout: float = 0.0,
        activation: str = 'relu',
        nhead: int = 8,
        dec_n_points: int = 4,
        return_intermediate_dec: bool = True,
        query_dim: int = 4,
        dec_layer_share: bool = False,
        semantic_ce_loss: bool = False,
        no_update=False,
        num_queries_stuff=100,
        num_queries_test=300,
):
    """
    Build the decoder: learnable queries, input projections, language
    projection parameters, mask/box heads and the deformable decoder stack.

    NOTE: this interface is experimental.
    NOTE(review): statement nesting was reconstructed from a
    whitespace-collapsed copy of this file -- confirm against the original.

    Args:
        lang_encoder: language encoder module, stored as ``self.lang_encoder``
        in_channels: channels of the input features; a 1x1 conv projection is
            created per level when this differs from ``hidden_dim``
        mask_classification: must be truthy (asserted)
        num_classes: number of classes (stored only)
        hidden_dim: Transformer feature dimension
        dim_proj: width of the language projection space; sizes
            ``lang_mapper`` (dim_proj x hidden_dim) and ``class_embed``
            (hidden_dim x dim_proj)
        num_queries: stored as ``self.num_queries``
        nheads: stored as ``self.num_heads``; it is NOT what is passed to the
            decoder layer (see ``nhead``)
        dim_feedforward: feature dimension in feedforward network
        dec_layers: number of Transformer decoder layers
        mask_dim: output dimension of the mask-embedding MLP
        enforce_input_project: add input project 1x1 conv even if input
            channels and hidden dim is identical
        two_stage: when true, ``enc_output`` / ``enc_output_norm`` are created
        dn: denoising mode string (stored only here)
        noise_scale: box noise scale for denoising (stored only here)
        dn_num: denoising query budget (stored only here)
        initialize_box_type: one of 'no', 'bitmask', 'mask2box' per the
            config comment in ``from_config``
        initial_pred: stored as ``self.initial_pred``
        learn_tgt: stored as ``self.learn_tgt``
        total_num_feature_levels: number of feature levels; one input
            projection is built per level
        dropout: dropout rate passed to the decoder layer
        activation: activation function name passed to the decoder layer
        nhead: number of attention heads passed to the decoder layer
        dec_n_points: number of sampling points in decoder
        return_intermediate_dec: return the intermediate results of decoder
        query_dim: 4 -> (x, y, w, h)
        dec_layer_share: whether to share each decoder layer
        semantic_ce_loss: use ce loss for semantic segmentation (stored only)
        no_update: stored as ``self.no_update``
        num_queries_stuff: number of rows in the learnable ``query_feat`` and
            ``query_embed`` tables
        num_queries_test: stored as ``self.num_queries_test``
    """
    super().__init__()

    assert mask_classification, "Only support mask classification model"
    self.mask_classification = mask_classification
    self.num_feature_levels = total_num_feature_levels
    self.initial_pred = initial_pred

    # define Transformer decoder here
    self.dn=dn
    self.learn_tgt = learn_tgt
    self.noise_scale=noise_scale
    self.dn_num=dn_num
    self.num_heads = nheads
    self.num_layers = dec_layers
    self.two_stage=two_stage
    self.initialize_box_type = initialize_box_type
    self.total_num_feature_levels = total_num_feature_levels

    self.num_queries = num_queries
    self.num_queries_test=num_queries_test
    self.semantic_ce_loss = semantic_ce_loss
    self.no_update=no_update
    # learnable query features (created unconditionally; the original
    # guard is kept below as a comment)
    # if not two_stage or self.learn_tgt:
    self.num_queries_stuff=num_queries_stuff
    self.query_feat = nn.Embedding(num_queries_stuff, hidden_dim)
    # if not two_stage and initialize_box_type == 'no':
    self.query_embed = nn.Embedding(num_queries_stuff, 4)
    if two_stage:
        # NOTE(review): forward() in this class calls self.enc_output
        # unconditionally, so two_stage=False looks unsupported -- confirm.
        self.enc_output = nn.Linear(hidden_dim, hidden_dim)
        self.enc_output_norm = nn.LayerNorm(hidden_dim)

    # One projection per feature level: 1x1 conv when channel counts
    # differ (or when forced), otherwise an empty nn.Sequential.
    self.input_proj = nn.ModuleList()
    for _ in range(self.num_feature_levels):
        if in_channels != hidden_dim or enforce_input_project:
            self.input_proj.append(Conv2d(in_channels, hidden_dim, kernel_size=1))
            weight_init.c2_xavier_fill(self.input_proj[-1])
        else:
            self.input_proj.append(nn.Sequential())
    self.num_classes=num_classes
    # output FFNs
    assert self.mask_classification, "why not class embedding?"
    # self.label_enc=nn.Embedding(505, hidden_dim)  # this is a hack for o365+coco (365+133=498)
    self.dim_proj = dim_proj
    self.lang_encoder = lang_encoder
    # lang_mapper: dim_proj -> hidden_dim; class_embed: hidden_dim -> dim_proj
    self.lang_mapper = nn.Parameter(torch.empty(dim_proj, hidden_dim))
    trunc_normal_(self.lang_mapper, std=.02)

    self.class_embed = nn.Parameter(torch.empty(hidden_dim, dim_proj))
    trunc_normal_(self.class_embed, std=.02)
    self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)

    # init decoder
    self.decoder_norm = decoder_norm = nn.LayerNorm(hidden_dim)
    # Note: the head count given to the layer is `nhead`, not `nheads`.
    decoder_layer = DeformableTransformerDecoderLayer(hidden_dim, dim_feedforward,
                                                      dropout, activation,
                                                      self.num_feature_levels, nhead, dec_n_points)
    self.decoder = TransformerDecoder(decoder_layer, self.num_layers, decoder_norm,
                                      return_intermediate=return_intermediate_dec,
                                      d_model=hidden_dim, query_dim=query_dim,
                                      num_feature_levels=self.num_feature_levels,
                                      dec_layer_share=dec_layer_share,
                                      )

    self.hidden_dim = hidden_dim
    # Box head: final layer zero-initialised (weight and bias).
    self._bbox_embed = _bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
    nn.init.constant_(_bbox_embed.layers[-1].weight.data, 0)
    nn.init.constant_(_bbox_embed.layers[-1].bias.data, 0)
    # The same MLP instance is listed once per layer, i.e. weights are shared.
    box_embed_layerlist = [_bbox_embed for i in range(self.num_layers)]  # share box prediction each layer
    self.bbox_embed = nn.ModuleList(box_embed_layerlist)
    self.decoder.bbox_embed = self.bbox_embed
def prepare_for_dn(self, targets, tgt, refpoint_emb, batch_size,task="other"):
    """
    Build the denoising (DN) queries, their attention mask and bookkeeping.

    modified from dn-detr. You can refer to dn-detr
    https://github.com/IDEA-Research/DN-DETR/blob/main/models/dn_dab_deformable_detr/dn_components.py
    for more details

    NOTE(review): statement nesting was reconstructed from a
    whitespace-collapsed copy of this file -- confirm against the original,
    in particular the placement of the two ``task=="cls"`` checks.

    :param targets: iterable of per-image dicts; ``t['labels']`` and
        ``t['boxes']`` are read
    :param tgt: original tgt (content) in the matching part, or None
    :param refpoint_emb: positional anchor queries in the matching part, or None
    :param batch_size: bs
    :param task: "cls" forces exactly one DN group and all-zero labels embedded
        via ``self.cls_emb``; any other value uses language embeddings
    :return: (input_query_label, input_query_bbox, attn_mask, mask_dict);
        all four are None in training when no DN group fits
    """
    if self.training:
        scalar, noise_scale = self.dn_num, self.noise_scale

        # One "1" per ground-truth label; sums give per-image object counts.
        # NOTE(review): .cuda() is hard-coded here and below.
        known = [(torch.ones_like(t['labels'])).cuda() for t in targets]
        know_idx = [torch.nonzero(t) for t in known]
        known_num = [sum(k) for k in known]

        # use fix number of dn queries: number of DN groups is the budget
        # divided by the largest per-image object count
        if max(known_num) > 0:
            scalar = scalar // (int(max(known_num)))
        else:
            scalar = 0
        if task=="cls":
            scalar=1
        if scalar == 0:
            input_query_label = None
            input_query_bbox = None
            attn_mask = None
            mask_dict = None
            return input_query_label, input_query_bbox, attn_mask, mask_dict

        # can be modified to selectively denoise some label or boxes; also known label prediction
        unmask_bbox = unmask_label = torch.cat(known)
        labels = torch.cat([t['labels'] for t in targets])
        # use language as denoising content queries.
        # if task == 'det':
        #     labels = labels  # o365 start from 133 class
        boxes = torch.cat([t['boxes'] for t in targets])
        batch_idx = torch.cat([torch.full_like(t['labels'].long(), i) for i, t in enumerate(targets)])
        # known
        known_indice = torch.nonzero(unmask_label + unmask_bbox)
        known_indice = known_indice.view(-1)

        # noise: replicate every target once per DN group
        known_indice = known_indice.repeat(scalar, 1).view(-1)
        known_labels = labels.repeat(scalar, 1).view(-1)
        known_bid = batch_idx.repeat(scalar, 1).view(-1)
        known_bboxs = boxes.repeat(scalar, 1)
        known_labels_expaned = known_labels.clone()
        known_bbox_expand = known_bboxs.clone()

        if noise_scale > 0:
            # Per-coordinate noise bound: half the box size for the first
            # two columns, the full box size for the last two.
            diff = torch.zeros_like(known_bbox_expand)
            diff[:, :2] = known_bbox_expand[:, 2:] / 2
            diff[:, 2:] = known_bbox_expand[:, 2:]
            # Uniform noise in [-1, 1) scaled by the bound and noise_scale.
            known_bbox_expand += torch.mul((torch.rand_like(known_bbox_expand) * 2 - 1.0),
                                           diff).cuda() * noise_scale
            known_bbox_expand = known_bbox_expand.clamp(min=0.0, max=1.0)
        if task=="cls":
            known_labels_expaned=torch.zeros_like(known_labels_expaned)
        m = known_labels_expaned.long().to('cuda')
        # import ipdb; ipdb.set_trace()
        if task=="cls":
            # NOTE(review): self.cls_emb is not assigned in this class's
            # __init__ as visible in this file -- confirm where it is set.
            input_label_embed=self.cls_emb(m)
        else:
            # Look up each label's row in the text embeddings, then project
            # from dim_proj to hidden_dim with lang_mapper.
            input_label_embed = torch.gather(self.lang_encoder.default_text_embeddings, 0,
                                             m[:, None].repeat(1, self.dim_proj)) @ self.lang_mapper

        input_bbox_embed = inverse_sigmoid(known_bbox_expand)
        single_pad = int(max(known_num))
        pad_size = int(single_pad * scalar)

        padding_label = input_label_embed.new_zeros(pad_size, self.hidden_dim)
        padding_bbox = input_bbox_embed.new_zeros(pad_size, 4)

        if not refpoint_emb is None:
            input_query_label = torch.cat([padding_label, tgt], dim=0).repeat(batch_size, 1, 1)
            input_query_bbox = torch.cat([padding_bbox, refpoint_emb], dim=0).repeat(batch_size, 1, 1)
        else:
            input_query_label = padding_label.repeat(batch_size, 1, 1)
            input_query_bbox = padding_bbox.repeat(batch_size, 1, 1)

        # map: slot index of every DN query inside its image's padded DN area
        map_known_indice = input_label_embed.new_tensor([])
        if len(known_num):
            map_known_indice = torch.cat(
                [input_label_embed.new_tensor(range(num)) for num in known_num])  # [1,2, 1,2,3]
            map_known_indice = torch.cat([map_known_indice + single_pad * i for i in range(scalar)]).long()
        if len(known_bid):
            input_query_label[(known_bid.long(), map_known_indice)] = input_label_embed
            input_query_bbox[(known_bid.long(), map_known_indice)] = input_bbox_embed

        tgt_size = pad_size + self.num_queries+self.num_queries_stuff
        # All-False boolean mask to start with.
        # presumably True means "may not attend" -- confirm against the decoder
        attn_mask = input_label_embed.new_ones(tgt_size, tgt_size) < 0
        # match query cannot see the reconstruct
        attn_mask[pad_size:, :pad_size] = True
        # reconstruct cannot see each other
        # NOTE(review): the second `if` is not an `elif`, so for i == 0 with
        # scalar > 1 the else-branch repeats the first assignment; this is
        # redundant but writes the same values.
        for i in range(scalar):
            if i == 0:
                attn_mask[single_pad * i:single_pad * (i + 1), single_pad * (i + 1):pad_size] = True
            if i == scalar - 1:
                attn_mask[single_pad * i:single_pad * (i + 1), :single_pad * i] = True
            else:
                attn_mask[single_pad * i:single_pad * (i + 1), single_pad * (i + 1):pad_size] = True
                attn_mask[single_pad * i:single_pad * (i + 1), :single_pad * i] = True
        mask_dict = {
            'known_indice': torch.as_tensor(known_indice).long(),
            'batch_idx': torch.as_tensor(batch_idx).long(),
            'map_known_indice': torch.as_tensor(map_known_indice).long(),
            'known_lbs_bboxes': (known_labels, known_bboxs),
            'know_idx': know_idx,
            'pad_size': pad_size,
            'scalar': scalar,
        }
    else:
        # Not training: no DN queries; just tile the matching queries if given.
        if not refpoint_emb is None:
            input_query_label = tgt.repeat(batch_size, 1, 1)
            input_query_bbox = refpoint_emb.repeat(batch_size, 1, 1)
        else:
            input_query_label = None
            input_query_bbox = None
        attn_mask = None
        mask_dict = None

    # 100*batch*256
    # The two self-assignments below have no effect.
    if not input_query_bbox is None:
        input_query_label = input_query_label
        input_query_bbox = input_query_bbox

    return input_query_label, input_query_bbox, attn_mask, mask_dict
def get_valid_ratio(self, mask):
    """Return the unmasked fraction of each sample along width and height.

    Counts the False entries in the first column (height) and the first row
    (width) of every sample's mask and divides by the full extent.

    :param mask: 3-D boolean tensor; True marks masked-out positions
    :return: float tensor whose last dimension is (width_ratio, height_ratio)
    """
    _, rows, cols = mask.shape
    visible = ~mask
    # First column tells how many rows are unmasked; first row how many columns.
    ratio_h = visible[:, :, 0].sum(1).float() / rows
    ratio_w = visible[:, 0, :].sum(1).float() / cols
    return torch.stack([ratio_w, ratio_h], -1)
if enable_mask == 0: masks = [torch.zeros((src.size(0), src.size(2), src.size(3)), device=src.device, dtype=torch.bool) for src in x] src_flatten = [] mask_flatten = [] spatial_shapes = [] for i in range(self.num_feature_levels): idx = self.num_feature_levels - 1 - i bs, c, h, w = x[idx].shape size_list.append(x[i].shape[-2:]) spatial_shapes.append(x[idx].shape[-2:]) src_flatten.append(self.input_proj[idx](x[idx]).flatten(2).transpose(1, 2)) mask_flatten.append(masks[i].flatten(1)) src_flatten = torch.cat(src_flatten, 1) # bs, \sum{hxw}, c mask_flatten = torch.cat(mask_flatten, 1) # bs, \sum{hxw} spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device) level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1) predictions_class = [] predictions_mask = [] # if self.two_stage: # output_memory, output_proposals = gen_encoder_output_proposals(src_flatten, mask_flatten, spatial_shapes) # output_memory = self.enc_output_norm(self.enc_output(output_memory)) # output_memory_ = output_memory @ self.class_embed # enc_outputs_class_unselected = self.lang_encoder.compute_similarity(output_memory_) # enc_outputs_class_unselected[output_proposals.sum(-1).isinf()] = float("-inf") # # enc_outputs_class_unselected = self.class_embed(output_memory) # enc_outputs_coord_unselected = self._bbox_embed( # output_memory) + output_proposals # (bs, \sum{hw}, 4) unsigmoid # topk = self.num_queries # topk_proposals = torch.topk(enc_outputs_class_unselected.max(-1)[0], topk, dim=1)[1] # refpoint_embed_undetach = torch.gather(enc_outputs_coord_unselected, 1, # topk_proposals.unsqueeze(-1).repeat(1, 1, 4)) # unsigmoid # refpoint_embed = refpoint_embed_undetach.detach() # # tgt_undetach = torch.gather(output_memory, 1, # topk_proposals.unsqueeze(-1).repeat(1, 1, self.hidden_dim)) # unsigmoid # outputs_class, outputs_mask = 
self.forward_prediction_heads(tgt_undetach.transpose(0, 1), mask_features, # do_seg) # tgt = tgt_undetach.detach() # if self.learn_tgt: # tgt = self.query_feat.weight[None].repeat(bs, 1, 1) # interm_outputs = dict() # interm_outputs['pred_logits'] = outputs_class # interm_outputs['pred_boxes'] = refpoint_embed_undetach.sigmoid() # interm_outputs['pred_masks'] = outputs_mask # # if self.initialize_box_type != 'no' and do_seg: # # convert masks into boxes to better initialize box in the decoder # assert self.initial_pred # flaten_mask = outputs_mask.detach().flatten(0, 1) # h, w = outputs_mask.shape[-2:] # if self.initialize_box_type == 'bitmask': # slower, but more accurate # refpoint_embed = BitMasks(flaten_mask > 0).get_bounding_boxes().tensor.cuda() # elif self.initialize_box_type == 'mask2box': # faster conversion # refpoint_embed = box_ops.masks_to_boxes(flaten_mask > 0).cuda() # else: # assert NotImplementedError # refpoint_embed = box_ops.box_xyxy_to_cxcywh(refpoint_embed) / torch.as_tensor([w, h, w, h], # dtype=torch.float).cuda() # refpoint_embed = refpoint_embed.reshape(outputs_mask.shape[0], outputs_mask.shape[1], 4) # refpoint_embed = inverse_sigmoid(refpoint_embed) # elif not self.two_stage: # tgt = self.query_feat.weight[None].repeat(bs, 1, 1) # refpoint_embed = self.query_embed.weight[None].repeat(bs, 1, 1) tgt_mask = None mask_dict = None # if self.dn != "no" and self.training: assert targets is not None input_query_label, input_query_bbox, tgt_mask, mask_dict = \ self.prepare_for_dn(targets, None, None, x[0].shape[0],task="cls") # if mask_dict is not None: tgt = input_query_label refpoint_embed = input_query_bbox # direct prediction from the matching and denoising part in the begining if self.initial_pred: outputs_class, outputs_mask = self.forward_prediction_heads(tgt.transpose(0, 1), mask_features, self.training and do_seg) predictions_class.append(outputs_class) predictions_mask.append(outputs_mask) # if self.dn != "no" and self.training and 
mask_dict is not None: # tgt=tgt.float() hs, references = self.decoder( tgt=tgt.transpose(0, 1), memory=src_flatten.transpose(0, 1), memory_key_padding_mask=mask_flatten, pos=None, refpoints_unsigmoid=refpoint_embed.transpose(0, 1), level_start_index=level_start_index, spatial_shapes=spatial_shapes, valid_ratios=valid_ratios, tgt_mask=None, no_update=self.no_update, ) for i, output in enumerate(hs): outputs_class, outputs_mask = self.forward_prediction_heads(output.transpose(0, 1), mask_features, ( self.training or (i == len(hs) - 1)) and do_seg) predictions_class.append(outputs_class) predictions_mask.append(outputs_mask) # iteratively box prediction # if self.initial_pred: # out_boxes = self.pred_box(references, hs, refpoint_embed.sigmoid()) # assert len(predictions_class) == self.num_layers + 1 # else: # out_boxes = self.pred_box(references, hs) if mask_dict is not None: predictions_mask = None if not do_seg else torch.stack(predictions_mask) predictions_class = torch.stack(predictions_class) # predictions_class, out_boxes, predictions_mask = \ # self.dn_post_process(predictions_class, out_boxes, mask_dict, predictions_mask) predictions_class = list(predictions_class) if predictions_mask is None: predictions_class[-1] = predictions_class[-1] + 0.0 * self.lang_mapper[0, 0] for i in range(self.mask_embed.num_layers): predictions_class[-1] = predictions_class[-1] + 0.0 * ( self.mask_embed.layers[i].weight[0][0] + self.mask_embed.layers[i].bias[ 0]) # avoid no mask loss predictions_class[-1] = predictions_class[-1] + 0.0 * mask_features[0][0][0][0] # avoid no mask loss # if do_seg: # predictions_mask = list(predictions_mask) elif self.training: # this is to insure self.label_enc participate in the model predictions_class[-1] = predictions_class[-1] + 0.0 * self.lang_mapper[0, 0] for i in range(self.mask_embed.num_layers): predictions_class[-1] = predictions_class[-1] + 0.0 * ( self.mask_embed.layers[i].weight[0][0] + self.mask_embed.layers[i].bias[ 0]) # avoid no 
def forward(self, x, mask_features, masks, targets=None, target_queries=None, target_vlp=None, task='seg', extra={}):
    """
    Run query selection, optional denoising and the decoder stack.

    The top-k encoder proposals (``num_queries`` in training,
    ``num_queries_test`` otherwise) are concatenated with the learnable
    "stuff" queries. When training on a task other than 'seg', the trailing
    ``num_queries_stuff`` queries are dropped from the predictions.

    NOTE(review): statement nesting was reconstructed from a
    whitespace-collapsed copy of this file -- confirm against the original.
    NOTE(review): ``self.enc_output`` is used unconditionally here but only
    created when ``two_stage`` is true -- confirm two_stage=False is unsupported.

    :param x: list of ``self.num_feature_levels`` 4-D feature maps
    :param mask_features: per-pixel features used by the mask head
    :param masks: per-level padding masks or None; replaced by all-False
        masks unless some feature map has a side not divisible by 32
    :param targets: ground-truth dicts; required when denoising is enabled
        in training
    :param target_queries: unused
    :param target_vlp: unused
    :param task: 'det' disables mask prediction; 'seg' keeps stuff queries
    :param extra: unused (default left as-is for interface compatibility)
    :return: ``(out, mask_dict)`` where ``out`` has 'pred_logits',
        'pred_masks', 'pred_boxes', 'aux_outputs' and, when
        ``self.two_stage``, 'interm_outputs'
    :raises NotImplementedError: if ``initialize_box_type`` is neither 'no',
        'bitmask' nor 'mask2box' while masks are predicted
    """
    assert len(x) == self.num_feature_levels
    do_seg = (task != 'det')  # if task is det, not do segmentation training
    # disable mask, it does not affect performance
    enable_mask = 0
    if masks is not None:
        for src in x:
            if src.size(2) % 32 or src.size(3) % 32:
                enable_mask = 1
    if enable_mask == 0:
        masks = [torch.zeros((src.size(0), src.size(2), src.size(3)), device=src.device, dtype=torch.bool) for src in x]
    src_flatten = []
    mask_flatten = []
    spatial_shapes = []
    for i in range(self.num_feature_levels):
        # Features are consumed in reverse level order.
        idx = self.num_feature_levels - 1 - i
        bs, _, _, _ = x[idx].shape
        spatial_shapes.append(x[idx].shape[-2:])
        src_flatten.append(self.input_proj[idx](x[idx]).flatten(2).transpose(1, 2))
        mask_flatten.append(masks[i].flatten(1))
    src_flatten = torch.cat(src_flatten, 1)  # bs, \sum{hxw}, c
    mask_flatten = torch.cat(mask_flatten, 1)  # bs, \sum{hxw}
    spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device)
    level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
    valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)

    predictions_class = []
    predictions_mask = []
    # Query selection from encoder memory (always executed in this variant).
    output_memory, output_proposals = gen_encoder_output_proposals(src_flatten, mask_flatten, spatial_shapes)
    output_memory = self.enc_output_norm(self.enc_output(output_memory))
    output_memory_ = output_memory @ self.class_embed
    enc_outputs_class_unselected = self.lang_encoder.compute_similarity(output_memory_)
    # Proposals with infinite coordinates can never be selected.
    enc_outputs_class_unselected[output_proposals.sum(-1).isinf()] = float("-inf")
    enc_outputs_coord_unselected = self._bbox_embed(
        output_memory) + output_proposals  # (bs, \sum{hw}, 4) unsigmoid
    topk = self.num_queries if self.training else self.num_queries_test
    topk_proposals = torch.topk(enc_outputs_class_unselected.max(-1)[0], topk, dim=1)[1]
    refpoint_embed_undetach = torch.gather(enc_outputs_coord_unselected, 1,
                                           topk_proposals.unsqueeze(-1).repeat(1, 1, 4))  # unsigmoid
    tgt_undetach = torch.gather(output_memory, 1,
                                topk_proposals.unsqueeze(-1).repeat(1, 1, self.hidden_dim))  # unsigmoid
    # Append the learnable "stuff" queries after the selected proposals.
    tgt_stuff = self.query_feat.weight[None].repeat(bs, 1, 1)
    refpoint_embed_stuff = self.query_embed.weight[None].repeat(bs, 1, 1)
    tgt_undetach = torch.cat([tgt_undetach, tgt_stuff], dim=1)
    refpoint_embed_undetach = torch.cat([refpoint_embed_undetach, refpoint_embed_stuff], dim=1)
    refpoint_embed = refpoint_embed_undetach.detach()

    outputs_class, outputs_mask = self.forward_prediction_heads(tgt_undetach.transpose(0, 1), mask_features, do_seg)
    tgt = tgt_undetach.detach()
    if self.learn_tgt:
        tgt = self.query_feat.weight[None].repeat(bs, 1, 1)
    interm_outputs = dict()
    interm_outputs['pred_logits'] = outputs_class
    interm_outputs['pred_boxes'] = refpoint_embed_undetach.sigmoid()
    interm_outputs['pred_masks'] = outputs_mask

    if self.initialize_box_type != 'no' and do_seg:
        # convert masks into boxes to better initialize box in the decoder
        assert self.initial_pred
        flaten_mask = outputs_mask.detach().flatten(0, 1)
        mask_device = flaten_mask.device
        h, w = outputs_mask.shape[-2:]
        if self.initialize_box_type == 'bitmask':  # slower, but more accurate
            refpoint_embed = BitMasks(flaten_mask > 0).get_bounding_boxes().tensor.to(mask_device)
        elif self.initialize_box_type == 'mask2box':  # faster conversion
            refpoint_embed = box_ops.masks_to_boxes(flaten_mask > 0).to(mask_device)
        else:
            # `assert NotImplementedError` is always true; fail loudly instead
            # of converting unsigmoided reference points as if they were boxes.
            raise NotImplementedError(
                "unsupported initialize_box_type: {!r}".format(self.initialize_box_type))
        refpoint_embed = box_ops.box_xyxy_to_cxcywh(refpoint_embed) / torch.as_tensor(
            [w, h, w, h], dtype=torch.float, device=mask_device)
        refpoint_embed = refpoint_embed.reshape(outputs_mask.shape[0], outputs_mask.shape[1], 4)
        refpoint_embed = inverse_sigmoid(refpoint_embed)

    tgt_mask = None
    mask_dict = None
    if self.dn != "no" and self.training:
        assert targets is not None
        input_query_label, input_query_bbox, tgt_mask, mask_dict = \
            self.prepare_for_dn(targets, None, None, x[0].shape[0])
        if mask_dict is not None:
            # Denoising queries go in front of the matching queries.
            tgt = torch.cat([input_query_label, tgt], dim=1)

    # True only when training on a task other than 'seg'.
    drop_stuff = not (task == 'seg' or not self.training)

    # direct prediction from the matching and denoising part in the beginning
    if self.initial_pred:
        outputs_class, outputs_mask = self.forward_prediction_heads(tgt.transpose(0, 1), mask_features, self.training and do_seg)
        if drop_stuff:
            outputs_class = outputs_class[:, :-self.num_queries_stuff]
        predictions_class.append(outputs_class)
        predictions_mask.append(outputs_mask)
    if self.dn != "no" and self.training and mask_dict is not None:
        refpoint_embed = torch.cat([input_query_bbox, refpoint_embed], dim=1)

    hs, references = self.decoder(
        tgt=tgt.transpose(0, 1),
        memory=src_flatten.transpose(0, 1),
        memory_key_padding_mask=mask_flatten,
        pos=None,
        refpoints_unsigmoid=refpoint_embed.transpose(0, 1),
        level_start_index=level_start_index,
        spatial_shapes=spatial_shapes,
        valid_ratios=valid_ratios,
        tgt_mask=tgt_mask
    )
    if drop_stuff:
        hs = [hs_[:, :-self.num_queries_stuff] for hs_ in hs]
        references = [references_[:, :-self.num_queries_stuff] for references_ in references]
        refpoint_embed = refpoint_embed[:, :-self.num_queries_stuff]
    for i, output in enumerate(hs):
        # Outside training, masks are only predicted for the last layer.
        outputs_class, outputs_mask = self.forward_prediction_heads(output.transpose(0, 1), mask_features, (self.training or (i == len(hs) - 1)) and do_seg)
        predictions_class.append(outputs_class)
        predictions_mask.append(outputs_mask)

    # iteratively box prediction
    if self.initial_pred:
        out_boxes = self.pred_box(references, hs, refpoint_embed.sigmoid())
        assert len(predictions_class) == self.num_layers + 1
    else:
        out_boxes = self.pred_box(references, hs)
    if mask_dict is not None:
        predictions_mask = None if not do_seg else torch.stack(predictions_mask)
        predictions_class = torch.stack(predictions_class)
        predictions_class, out_boxes, predictions_mask = \
            self.dn_post_process(predictions_class, out_boxes, mask_dict, predictions_mask)
        predictions_class = list(predictions_class)
        if predictions_mask is None:
            # Zero-weighted terms: values are unchanged, but the language
            # mapper, mask head and mask features stay in the autograd graph.
            predictions_class[-1] = predictions_class[-1] + 0.0 * self.lang_mapper[0, 0]
            for i in range(self.mask_embed.num_layers):
                predictions_class[-1] = predictions_class[-1] + 0.0 * (
                    self.mask_embed.layers[i].weight[0][0] + self.mask_embed.layers[i].bias[0])  # avoid no mask loss
            predictions_class[-1] = predictions_class[-1] + 0.0 * mask_features[0][0][0][0]  # avoid no mask loss
        if do_seg:
            predictions_mask = list(predictions_mask)
    elif self.training:  # keep otherwise-unused parameters in the graph
        predictions_class[-1] = predictions_class[-1] + 0.0 * self.lang_mapper[0, 0]
        for i in range(self.mask_embed.num_layers):
            predictions_class[-1] = predictions_class[-1] + 0.0 * (
                self.mask_embed.layers[i].weight[0][0] + self.mask_embed.layers[i].bias[0])  # avoid no mask loss
        predictions_class[-1] = predictions_class[-1] + 0.0 * mask_features[0][0][0][0]  # avoid no mask loss

    out = {
        'pred_logits': predictions_class[-1],
        'pred_masks': None if not do_seg else predictions_mask[-1],
        'pred_boxes': out_boxes[-1],
        'aux_outputs': self._set_aux_loss(
            predictions_class if self.mask_classification else None, predictions_mask, out_boxes
        )
    }
    if self.two_stage:
        out['interm_outputs'] = interm_outputs
    return out, mask_dict
outputs_class = self.lang_encoder.compute_similarity(class_embed) outputs_mask = None if pred_mask: mask_embed = self.mask_embed(decoder_output) outputs_mask = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features) return outputs_class, outputs_mask @torch.jit.unused def _set_aux_loss(self, outputs_class, outputs_seg_masks, out_boxes=None): # this is a workaround to make torchscript happy, as torchscript # doesn't support dictionary with non-homogeneous values, such # as a dict having both a Tensor and a list. # if self.mask_classification: if out_boxes is None: if outputs_seg_masks is None: return [ {"pred_logits": a} for a in outputs_class[:-1] ] else: return [ {"pred_logits": a, "pred_masks": b} for a, b in zip(outputs_class[:-1], outputs_seg_masks[:-1]) ] elif outputs_seg_masks is None: return [ {"pred_logits": a, "pred_boxes": c} for a, c in zip(outputs_class[:-1], out_boxes[:-1]) ] else: return [ {"pred_logits": a, "pred_masks": b, "pred_boxes":c} for a, b, c in zip(outputs_class[:-1], outputs_seg_masks[:-1], out_boxes[:-1]) ] @register_decoder def get_maskdino_transformer_decoder(cfg, in_channels, lang_encoder, mask_classification, extra): return MaskDINODecoder(cfg, in_channels, lang_encoder, mask_classification, extra) ================================================ FILE: llava/model/openseed/body/decoder/registry.py ================================================ _model_entrypoints = {} def register_decoder(fn): module_name_split = fn.__module__.split('.') model_name = module_name_split[-1] _model_entrypoints[model_name] = fn return fn def model_entrypoints(model_name): return _model_entrypoints[model_name] def is_model(model_name): return model_name in _model_entrypoints ================================================ FILE: llava/model/openseed/body/decoder/utils/__init__.py ================================================ from .utils import * ================================================ FILE: 
class TransformerDecoder(nn.Module):
    """Stack of decoder layers that refines reference points layer by layer.

    Every layer receives the current queries plus a positional embedding
    derived from the current reference points. When ``self.bbox_embed`` has
    been set (it is ``None`` after construction), each layer's output is used
    to update the reference points before the next layer runs.
    """

    def __init__(self, decoder_layer, num_layers, norm=None,
                 return_intermediate=False,
                 d_model=256, query_dim=4,
                 modulate_hw_attn=True,
                 num_feature_levels=1,
                 deformable_decoder=True,
                 decoder_query_perturber=None,
                 dec_layer_number=None,  # number of queries each layer in decoder
                 rm_dec_query_scale=True,
                 dec_layer_share=False,
                 dec_layer_dropout_prob=None,
                 task_switch=None,
                 ):
        """Build the layer stack and the heads that embed reference points.

        Only ``return_intermediate=True``, ``query_dim`` in ``{2, 4}`` and
        ``rm_dec_query_scale=True`` are supported; anything else fails an
        assertion or raises ``NotImplementedError``.
        """
        super().__init__()
        if num_layers > 0:
            self.layers = _get_clones(decoder_layer, num_layers, layer_share=dec_layer_share)
        else:
            self.layers = []
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate
        assert return_intermediate, "support return_intermediate only"
        self.query_dim = query_dim
        assert query_dim in [2, 4], "query_dim should be 2/4 but {}".format(query_dim)
        self.num_feature_levels = num_feature_levels

        # Maps the sine embedding of a reference point to a query position embedding.
        self.ref_point_head = MLP(query_dim // 2 * d_model, d_model, d_model, 2)
        if not deformable_decoder:
            self.query_pos_sine_scale = MLP(d_model, d_model, d_model, 2)
        else:
            self.query_pos_sine_scale = None

        if rm_dec_query_scale:
            self.query_scale = None
        else:
            # A learned query scale is not supported by this implementation.
            raise NotImplementedError

        # Both heads are expected to be attached by the owner after construction.
        self.bbox_embed = None
        self.class_embed = None

        self.d_model = d_model
        self.modulate_hw_attn = modulate_hw_attn
        self.deformable_decoder = deformable_decoder

        if not deformable_decoder and modulate_hw_attn:
            self.ref_anchor_head = MLP(d_model, d_model, 2, 2)
        else:
            self.ref_anchor_head = None
        self.decoder_query_perturber = decoder_query_perturber
        self.box_pred_damping = None

        self.dec_layer_number = dec_layer_number
        if dec_layer_number is not None:
            assert isinstance(dec_layer_number, list)
            assert len(dec_layer_number) == num_layers

        self.dec_layer_dropout_prob = dec_layer_dropout_prob
        if dec_layer_dropout_prob is not None:
            assert isinstance(dec_layer_dropout_prob, list)
            assert len(dec_layer_dropout_prob) == num_layers
            for i in dec_layer_dropout_prob:
                assert 0.0 <= i <= 1.0
        self.task_switch = task_switch

        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-initialise every multi-dimensional parameter, then let each
        ``MSDeformAttn`` submodule re-apply its own initialisation."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        for m in self.modules():
            if isinstance(m, MSDeformAttn):
                m._reset_parameters()

    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                refpoints_unsigmoid: Optional[Tensor] = None,  # num_queries, bs, 2
                # for memory
                level_start_index: Optional[Tensor] = None,  # num_levels
                spatial_shapes: Optional[Tensor] = None,  # bs, num_levels, 2
                valid_ratios: Optional[Tensor] = None,
                # misc
                extra: Optional[dict] = None,  # extra information
                ):
        """Run all decoder layers and collect per-layer outputs.

        Input:
            - tgt: nq, bs, d_model
            - memory: hw, bs, d_model
            - pos: hw, bs, d_model
            - refpoints_unsigmoid: nq, bs, 2/4
            - valid_ratios/spatial_shapes: bs, nlevel, 2
            - extra: optional dict of task information. The keys read here
              are ``lang_refpoint_embed``, ``grounding_tokens``,
              ``grounding_len`` and ``task``; it is also forwarded to every
              layer. ``None`` is treated as an empty dict.

        Returns:
            A two-element list: the normalised output of every layer and the
            reference points (initial ones first, then one entry per layer
            when ``self.bbox_embed`` is set), each transposed on dims 0 and 1.
        """
        # A fresh dict per call: avoids a shared mutable default and makes an
        # explicit ``extra=None`` behave like "no extra information".
        if extra is None:
            extra = {}

        output = tgt

        intermediate = []
        reference_points = refpoints_unsigmoid.sigmoid()
        ref_points = [reference_points]

        # Append the grounding language tokens (and their reference points)
        # after the regular queries.
        if 'lang_refpoint_embed' in extra and 'grounding_tokens' in extra:
            reference_points = torch.cat(
                (reference_points, extra['lang_refpoint_embed'].transpose(0, 1).sigmoid()), dim=0)
            output = torch.cat((output, extra['grounding_tokens']), dim=0)

        for layer_id, layer in enumerate(self.layers):
            # preprocess ref points
            if self.training and self.decoder_query_perturber is not None and layer_id != 0:
                reference_points = self.decoder_query_perturber(reference_points)

            reference_points_input = reference_points[:, :, None] \
                                     * torch.cat([valid_ratios, valid_ratios], -1)[None, :]  # nq, bs, nlevel, 4
            # Keep the reference points in the same dtype as the memory features.
            reference_points_input = reference_points_input.to(memory.dtype)
            query_sine_embed = gen_sineembed_for_position(
                reference_points_input[:, :, 0, :], dim=output.shape[-1] // 2)  # nq, bs, 256*2

            raw_query_pos = self.ref_point_head(query_sine_embed)  # nq, bs, 256
            pos_scale = self.query_scale(output) if self.query_scale is not None else 1
            query_pos = pos_scale * raw_query_pos

            output = layer(
                tgt=output,
                tgt_query_pos=query_pos,
                tgt_query_sine_embed=query_sine_embed,
                tgt_key_padding_mask=tgt_key_padding_mask,
                tgt_reference_points=reference_points_input,

                memory=memory,
                memory_key_padding_mask=memory_key_padding_mask,
                memory_level_start_index=level_start_index,
                memory_spatial_shapes=spatial_shapes,
                memory_pos=pos,

                self_attn_mask=tgt_mask,
                cross_attn_mask=memory_mask,
                task_switch=self.task_switch,
                extra=extra,
            )

            # The trailing ``grounding_len`` entries are grounding language
            # tokens: they are set aside for the box update / intermediate
            # output below and re-attached afterwards.
            split_grounding = (
                self.task_switch is not None
                and self.task_switch['grounding']
                and 'grounding_len' in extra
                and extra['task'] == 'seg'
            )

            # grounding language token reference point will not update and saved
            if split_grounding:
                _reference_points = reference_points[-extra['grounding_len']:]
                reference_points = reference_points[:-extra['grounding_len']]
                _output = output[-extra['grounding_len']:]
                output = output[:-extra['grounding_len']]

            # iter update
            if self.bbox_embed is not None:
                reference_before_sigmoid = inverse_sigmoid(reference_points)
                delta_unsig = self.bbox_embed[layer_id](output)
                outputs_unsig = delta_unsig + reference_before_sigmoid
                new_reference_points = outputs_unsig.sigmoid()

                # The next layer sees detached points; the attached version
                # is what gets returned.
                reference_points = new_reference_points.detach()
                ref_points.append(new_reference_points)

            intermediate.append(self.norm(output))

            # add back grounding language token
            if split_grounding:
                reference_points = torch.cat((reference_points, _reference_points))
                output = torch.cat((output, _output))

        return [
            [itm_out.transpose(0, 1) for itm_out in intermediate],
            [itm_refpoint.transpose(0, 1) for itm_refpoint in ref_points]
        ]
@autocast(enabled=False) def forward(self, # for tgt tgt: Optional[Tensor], # nq, bs, d_model tgt_query_pos: Optional[Tensor] = None, # pos for query. MLP(Sine(pos)) tgt_query_sine_embed: Optional[Tensor] = None, # pos for query. Sine(pos) tgt_key_padding_mask: Optional[Tensor] = None, tgt_reference_points: Optional[Tensor] = None, # nq, bs, 4 # for memory memory: Optional[Tensor] = None, # hw, bs, d_model memory_key_padding_mask: Optional[Tensor] = None, memory_level_start_index: Optional[Tensor] = None, # num_levels memory_spatial_shapes: Optional[Tensor] = None, # bs, num_levels, 2 memory_pos: Optional[Tensor] = None, # pos for memory # sa self_attn_mask: Optional[Tensor] = None, # mask used for self-attention cross_attn_mask: Optional[Tensor] = None, # mask used for cross-attention # misc task_switch: Optional[Tensor] = {}, # extra information extra: Optional[Tensor] = {}, # extra information ): """ Input: - tgt/tgt_query_pos: nq, bs, d_model - """ # self attention if self.self_attn is not None: q = k = self.with_pos_embed(tgt, tgt_query_pos) tgt2 = self.self_attn(q, k, tgt, attn_mask=self_attn_mask)[0] tgt = tgt + self.dropout2(tgt2) tgt = self.norm2(tgt) # exclude grounding token for cross attention if (task_switch is not None) and (extra is not None) and (task_switch['grounding']) and ('grounding_len' in extra) and extra['task']=='seg': _grounding_lang_tokens = tgt[-extra['grounding_len']:,] _grounding_lang_pos = tgt_query_pos[-extra['grounding_len']:,] _grounding_ref_points = tgt_reference_points[-extra['grounding_len']:,] tgt = tgt[:-extra['grounding_len'],] tgt_query_pos = tgt_query_pos[:-extra['grounding_len'],] tgt_reference_points = tgt_reference_points[:-extra['grounding_len'],] # cross attention if self.key_aware_type is not None: if self.key_aware_type == 'mean': tgt = tgt + memory.mean(0, keepdim=True) elif self.key_aware_type == 'proj_mean': tgt = tgt + self.key_aware_proj(memory).mean(0, keepdim=True) else: raise NotImplementedError("Unknown 
def gen_encoder_output_proposals(memory: Tensor, memory_padding_mask: Tensor, spatial_shapes: Tensor):
    r"""Create one box proposal per encoder position and mask unusable positions.

    Input:
        - memory: bs, \sum{hw}, d_model
        - memory_padding_mask: bs, \sum{hw}
        - spatial_shapes: nlevel, 2
    Output:
        - output_memory: bs, \sum{hw}, d_model
        - output_proposals: bs, \sum{hw}, 4

    Each position gets a proposal ``(cx, cy, w, h)``. The centre is the cell
    centre divided by the number of unpadded columns/rows of that level, and
    ``w = h = 0.05 * 2**lvl``. Proposals are returned as logits. Positions
    that are padded, or whose proposal has a coordinate outside
    ``(0.01, 0.99)``, get ``inf`` proposals and zeroed memory.
    """
    # Unpacking also checks that `memory` is three-dimensional.
    N_, _, _ = memory.shape
    proposals = []
    _cur = 0
    for lvl, (H_, W_) in enumerate(spatial_shapes):
        H_, W_ = int(H_), int(W_)
        mask_flatten_ = memory_padding_mask[:, _cur:(_cur + H_ * W_)].view(N_, H_, W_, 1)
        # Valid extent, measured along the first column / first row.
        valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
        valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)

        # "ij" is the layout the implicit default produced; stating it
        # silences the deprecation warning.
        grid_y, grid_x = torch.meshgrid(
            torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device),
            torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device),
            indexing="ij",
        )
        grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)

        scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(N_, 1, 1, 2)
        grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale
        wh = torch.ones_like(grid) * 0.05 * (2.0 ** lvl)
        proposal = torch.cat((grid, wh), -1).view(N_, -1, 4)
        proposals.append(proposal)
        _cur += (H_ * W_)

    output_proposals = torch.cat(proposals, 1)
    output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True)
    # Inverse sigmoid; out-of-range entries are overwritten with inf below.
    output_proposals = torch.log(output_proposals / (1 - output_proposals))
    output_proposals = output_proposals.masked_fill(memory_padding_mask.unsqueeze(-1), float('inf'))
    output_proposals = output_proposals.masked_fill(~output_proposals_valid, float('inf'))

    output_memory = memory
    output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0))
    output_memory = output_memory.masked_fill(~output_proposals_valid, float(0))
    return output_memory, output_proposals
torch.cat((pos_y, pos_x), dim=2) elif pos_tensor.size(-1) == 4: w_embed = pos_tensor[:, :, 2] * scale pos_w = w_embed[:, :, None] / dim_t pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2) h_embed = pos_tensor[:, :, 3] * scale pos_h = h_embed[:, :, None] / dim_t pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2) pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2) else: raise ValueError("Unknown pos_tensor shape(-1):{}".format(pos_tensor.size(-1))) return pos.to(pos_tensor.dtype) def _get_activation_fn(activation): """Return an activation function given a string""" if activation == "relu": return F.relu if activation == "gelu": return F.gelu if activation == "glu": return F.glu if activation == "prelu": return nn.PReLU() if activation == "selu": return F.selu raise RuntimeError(F"activation should be relu/gelu, not {activation}.") def _get_clones(module, N, layer_share=False): if layer_share: return nn.ModuleList([module for i in range(N)]) else: return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) ================================================ FILE: llava/model/openseed/body/encoder/__init__.py ================================================ from .build import build_encoder ================================================ FILE: llava/model/openseed/body/encoder/build.py ================================================ from .registry import model_entrypoints from .registry import is_model from .transformer_encoder_fpn import * from .encoder_deform import * def build_encoder(config, *args, **kwargs): model_name = config['MODEL']['ENCODER']['NAME'] if not is_model(model_name): raise ValueError(f'Unkown model: {model_name}') return model_entrypoints(model_name)(config, *args, **kwargs) ================================================ FILE: llava/model/openseed/body/encoder/encoder_deform.py ================================================ # 
class MSDeformAttnTransformerEncoderOnly(nn.Module):
    """Multi-scale deformable-attention encoder without a decoder.

    Flattens a list of feature maps (one per level), adds a learned per-level
    embedding to the positional encodings, and runs the concatenated sequence
    through ``MSDeformAttnTransformerEncoder``.
    """

    def __init__(self, d_model=256, nhead=8,
                 num_encoder_layers=6, dim_feedforward=1024, dropout=0.1,
                 activation="relu",
                 num_feature_levels=4, enc_n_points=4,):
        super().__init__()

        self.d_model = d_model
        self.nhead = nhead

        layer = MSDeformAttnTransformerEncoderLayer(
            d_model, dim_feedforward, dropout, activation,
            num_feature_levels, nhead, enc_n_points,
        )
        self.encoder = MSDeformAttnTransformerEncoder(layer, num_encoder_layers)

        # One learned embedding vector per feature level.
        self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))

        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-init matrices, re-init deformable attention, normal-init level embeddings."""
        for param in self.parameters():
            if param.dim() > 1:
                nn.init.xavier_uniform_(param)
        for module in self.modules():
            if isinstance(module, MSDeformAttn):
                module._reset_parameters()
        normal_(self.level_embed)

    def get_valid_ratio(self, mask):
        """Return the ``(width, height)`` fraction of ``mask`` that is not padding.

        ``mask`` is a 3-D boolean tensor with ``True`` on padded positions.
        The valid height is counted along the first column and the valid
        width along the first row.
        """
        _, height, width = mask.shape
        unpadded = ~mask
        ratio_h = unpadded[:, :, 0].sum(1).float() / height
        ratio_w = unpadded[:, 0, :].sum(1).float() / width
        return torch.stack([ratio_w, ratio_h], -1)

    def forward(self, srcs, masks, pos_embeds, use_ckpt=False):
        """Encode multi-level features.

        The supplied ``masks`` are used only when at least one feature map
        has a height or width that is not a multiple of 32; otherwise (or
        when ``masks`` is ``None``) all-``False`` masks are substituted.

        Returns:
            ``(memory, spatial_shapes, level_start_index)``.
        """
        needs_padding_mask = masks is not None and any(
            src.size(2) % 32 or src.size(3) % 32 for src in srcs
        )
        if not needs_padding_mask:
            masks = [
                torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
                for x in srcs
            ]

        # Flatten every level to (bs, h*w, ...) and remember its spatial size.
        flat_srcs, flat_masks, flat_pos, level_shapes = [], [], [], []
        for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
            _, _, h, w = src.shape
            level_shapes.append((h, w))
            flat_srcs.append(src.flatten(2).transpose(1, 2))
            flat_masks.append(mask.flatten(1))
            level_tag = self.level_embed[lvl].view(1, 1, -1)
            flat_pos.append(pos_embed.flatten(2).transpose(1, 2) + level_tag)

        src_flatten = torch.cat(flat_srcs, 1)
        mask_flatten = torch.cat(flat_masks, 1)
        lvl_pos_embed_flatten = torch.cat(flat_pos, 1)
        spatial_shapes = torch.as_tensor(level_shapes, dtype=torch.long, device=src_flatten.device)
        # Offset of each level inside the flattened sequence.
        level_sizes = spatial_shapes.prod(1)
        level_start_index = torch.cat((spatial_shapes.new_zeros((1, )), level_sizes.cumsum(0)[:-1]))
        valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)

        memory = self.encoder(
            src_flatten, spatial_shapes, level_start_index, valid_ratios,
            lvl_pos_embed_flatten, mask_flatten, use_ckpt=use_ckpt,
        )

        return memory, spatial_shapes, level_start_index
class MSDeformAttnTransformerEncoder(nn.Module):
    """Sequential stack of cloned deformable-attention encoder layers."""

    def __init__(self, encoder_layer, num_layers):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers

    @staticmethod
    def get_reference_points(spatial_shapes, valid_ratios, device):
        """Build the reference point of every position on every level.

        Args:
            spatial_shapes: iterable of ``(H, W)`` pairs, one per level.
            valid_ratios: indexed as ``[batch, level, (w, h)]``.
            device: device on which the grids are created.

        Returns:
            Cell centres of all levels concatenated along dim 1, divided by
            the valid extent of their own level and then multiplied by the
            valid ratio of every level (one copy per level on dim 2).
        """
        reference_points_list = []
        for lvl, (H_, W_) in enumerate(spatial_shapes):
            H_, W_ = int(H_), int(W_)
            # "ij" matches the layout the implicit default produced and
            # avoids the deprecation warning on every forward pass.
            ref_y, ref_x = torch.meshgrid(
                torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
                torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device),
                indexing="ij",
            )
            ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
            ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
            ref = torch.stack((ref_x, ref_y), -1)
            reference_points_list.append(ref)
        reference_points = torch.cat(reference_points_list, 1)
        reference_points = reference_points[:, :, None] * valid_ratios[:, None]
        return reference_points

    def forward(self, src, spatial_shapes, level_start_index, valid_ratios, pos=None, padding_mask=None, use_ckpt=False):
        """Apply every layer in order and return the final sequence.

        With ``use_ckpt=True`` each layer is run through
        ``torch.utils.checkpoint.checkpoint``.
        """
        output = src
        reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=src.device)
        for layer in self.layers:
            if use_ckpt:
                output = checkpoint.checkpoint(
                    layer, output, pos, reference_points, spatial_shapes, level_start_index, padding_mask)
            else:
                output = layer(output, pos, reference_points, spatial_shapes, level_start_index, padding_mask)

        return output
""" super().__init__() self.use_ckpt = use_ckpt transformer_input_shape = { k: v for k, v in input_shape.items() if k in transformer_in_features } # this is the input shape of pixel decoder input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride) self.in_features = [k for k, v in input_shape] # starting from "res2" to "res5" self.feature_strides = [v.stride for k, v in input_shape] self.feature_channels = [v.channels for k, v in input_shape] self.feature_order = feature_order if feature_order == "low2high": transformer_input_shape = sorted(transformer_input_shape.items(), key=lambda x: -x[1].stride) else: transformer_input_shape = sorted(transformer_input_shape.items(), key=lambda x: x[1].stride) self.transformer_in_features = [k for k, v in transformer_input_shape] # starting from "res2" to "res5" transformer_in_channels = [v.channels for k, v in transformer_input_shape] self.transformer_feature_strides = [v.stride for k, v in transformer_input_shape] # to decide extra FPN layers self.maskdino_num_feature_levels = num_feature_levels # always use 3 scales self.total_num_feature_levels = total_num_feature_levels self.common_stride = common_stride self.transformer_num_feature_levels = len(self.transformer_in_features) self.low_resolution_index = transformer_in_channels.index(max(transformer_in_channels)) self.high_resolution_index = 0 if self.feature_order == 'low2high' else -1 if self.transformer_num_feature_levels > 1: input_proj_list = [] for in_channels in transformer_in_channels[::-1]: input_proj_list.append(nn.Sequential( nn.Conv2d(in_channels, conv_dim, kernel_size=1), nn.GroupNorm(32, conv_dim), )) # input projectino for downsample in_channels = max(transformer_in_channels) for _ in range(self.total_num_feature_levels - self.transformer_num_feature_levels): # exclude the res2 input_proj_list.append(nn.Sequential( nn.Conv2d(in_channels, conv_dim, kernel_size=3, stride=2, padding=1), nn.GroupNorm(32, conv_dim), )) in_channels = conv_dim 
self.input_proj = nn.ModuleList(input_proj_list) else: self.input_proj = nn.ModuleList([ nn.Sequential( nn.Conv2d(transformer_in_channels[-1], conv_dim, kernel_size=1), nn.GroupNorm(32, conv_dim), )]) for proj in self.input_proj: nn.init.xavier_uniform_(proj[0].weight, gain=1) nn.init.constant_(proj[0].bias, 0) self.transformer = MSDeformAttnTransformerEncoderOnly( d_model=conv_dim, dropout=transformer_dropout, nhead=transformer_nheads, dim_feedforward=transformer_dim_feedforward, num_encoder_layers=transformer_enc_layers, num_feature_levels=self.total_num_feature_levels, ) N_steps = conv_dim // 2 self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True) self.mask_dim = mask_dim # use 1x1 conv instead self.mask_features = Conv2d( conv_dim, mask_dim, kernel_size=1, stride=1, padding=0, ) weight_init.c2_xavier_fill(self.mask_features) # extra fpn levels stride = min(self.transformer_feature_strides) self.num_fpn_levels = max(int(np.log2(stride) - np.log2(self.common_stride)), 1) lateral_convs = [] output_convs = [] use_bias = norm == "" for idx, in_channels in enumerate(self.feature_channels[:self.num_fpn_levels]): lateral_norm = get_norm(norm, conv_dim) output_norm = get_norm(norm, conv_dim) lateral_conv = Conv2d( in_channels, conv_dim, kernel_size=1, bias=use_bias, norm=lateral_norm ) output_conv = Conv2d( conv_dim, conv_dim, kernel_size=3, stride=1, padding=1, bias=use_bias, norm=output_norm, activation=F.relu, ) weight_init.c2_xavier_fill(lateral_conv) weight_init.c2_xavier_fill(output_conv) self.add_module("adapter_{}".format(idx + 1), lateral_conv) self.add_module("layer_{}".format(idx + 1), output_conv) lateral_convs.append(lateral_conv) output_convs.append(output_conv) # Place convs into top-down order (from low to high resolution) # to make the top-down computation in forward clearer. 
@classmethod
def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec], *args, **kwargs):
    """Translate the config into constructor keyword arguments.

    Reads ``cfg['MODEL']['ENCODER']`` and ``cfg['MODEL']['DECODER']``.
    ``input_shape`` is filtered down to the features named in the encoder's
    ``IN_FEATURES``. The transformer dropout, head count and feed-forward
    width come from the decoder section; everything else comes from the
    encoder section. ``USE_CKPT`` is optional and defaults to ``False``.
    """
    enc_cfg = cfg['MODEL']['ENCODER']
    dec_cfg = cfg['MODEL']['DECODER']

    return {
        "input_shape": {
            name: spec for name, spec in input_shape.items() if name in enc_cfg['IN_FEATURES']
        },
        "conv_dim": enc_cfg['CONVS_DIM'],
        "mask_dim": enc_cfg['MASK_DIM'],
        "norm": enc_cfg['NORM'],
        "transformer_dropout": dec_cfg['DROPOUT'],
        "transformer_nheads": dec_cfg['NHEADS'],
        "transformer_dim_feedforward": dec_cfg['DIM_FEEDFORWARD'],  # deformable transformer encoder
        "transformer_enc_layers": enc_cfg['TRANSFORMER_ENC_LAYERS'],  # a separate config
        "transformer_in_features": enc_cfg['DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES'],  # ['res3', 'res4', 'res5']
        "common_stride": enc_cfg['COMMON_STRIDE'],
        "total_num_feature_levels": enc_cfg['TOTAL_NUM_FEATURE_LEVELS'],
        "num_feature_levels": enc_cfg['NUM_FEATURE_LEVELS'],
        "feature_order": enc_cfg['FEATURE_ORDER'],
        "use_ckpt": enc_cfg.get('USE_CKPT', False),
    }
@register_encoder
def get_maskdino_encoder_deform(cfg, input_shape):
    """Build an ``OpenSeeDEncoder`` and check it can serve as a pixel decoder.

    :param cfg: model configuration, passed straight to ``OpenSeeDEncoder``
    :param input_shape: backbone output shapes, passed straight to ``OpenSeeDEncoder``
    :return: the constructed encoder
    :raises ValueError: if the encoder has no callable ``forward_features``
    """
    model = OpenSeeDEncoder(cfg, input_shape)
    forward_features = getattr(model, "forward_features", None)
    if not callable(forward_features):
        # Name the offending class explicitly; the message previously
        # interpolated an undefined variable and raised NameError instead.
        raise ValueError(
            "Only SEM_SEG_HEADS with forward_features method can be used as pixel decoder. "
            f"Please implement forward_features for {type(model).__name__} to only return mask features."
        )
    return model
class MSDeformAttnFunction(Function):
    """Autograd wrapper around the compiled ``MSDA`` deformable-attention op."""

    @staticmethod
    def forward(ctx, value, value_spatial_shapes, value_level_start_index,
                sampling_locations, attention_weights, im2col_step):
        """Run ``MSDA.ms_deform_attn_forward`` and stash its inputs for backward."""
        ctx.im2col_step = im2col_step
        op_inputs = (
            value,
            value_spatial_shapes,
            value_level_start_index,
            sampling_locations,
            attention_weights,
        )
        output = MSDA.ms_deform_attn_forward(*op_inputs, ctx.im2col_step)
        ctx.save_for_backward(*op_inputs)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        """Return gradients for ``value``, ``sampling_locations`` and ``attention_weights``.

        The remaining forward inputs (spatial shapes, level start index,
        ``im2col_step``) receive ``None``.
        """
        (value, value_spatial_shapes, value_level_start_index,
         sampling_locations, attention_weights) = ctx.saved_tensors
        grads = MSDA.ms_deform_attn_backward(
            value, value_spatial_shapes, value_level_start_index,
            sampling_locations, attention_weights, grad_output, ctx.im2col_step)
        grad_value, grad_sampling_loc, grad_attn_weight = grads
        return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None


def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights):
    """Pure-PyTorch multi-scale deformable attention (debug/test reference).

    Intended for debugging and testing only; use the CUDA op otherwise.

    :param value: (N, sum_l H_l*W_l, n_heads, head_dim)
    :param value_spatial_shapes: iterable of (H_l, W_l), one per level
    :param sampling_locations: (N, Len_q, n_heads, n_levels, n_points, 2);
        mapped from [0, 1] to grid_sample's [-1, 1] range
    :param attention_weights: (N, Len_q, n_heads, n_levels, n_points)
    :return: contiguous tensor of shape (N, Len_q, n_heads * head_dim)
    """
    batch, _, heads, head_dim = value.shape
    _, num_queries, heads, num_levels, num_points, _ = sampling_locations.shape

    per_level_values = value.split([h * w for h, w in value_spatial_shapes], dim=1)
    # grid_sample expects coordinates in [-1, 1].
    grids = 2 * sampling_locations - 1

    sampled_per_level = []
    for level, (height, width) in enumerate(value_spatial_shapes):
        # (N, H*W, heads, dim) -> (N, H*W, heads*dim) -> (N, heads*dim, H*W) -> (N*heads, dim, H, W)
        feature_map = (
            per_level_values[level]
            .flatten(2)
            .transpose(1, 2)
            .reshape(batch * heads, head_dim, height, width)
        )
        # (N, Len_q, heads, points, 2) -> (N, heads, Len_q, points, 2) -> (N*heads, Len_q, points, 2)
        level_grid = grids[:, :, :, level].transpose(1, 2).flatten(0, 1)
        # (N*heads, dim, Len_q, points)
        sampled_per_level.append(
            F.grid_sample(feature_map, level_grid,
                          mode='bilinear', padding_mode='zeros', align_corners=False)
        )

    # (N, Len_q, heads, levels, points) -> (N, heads, Len_q, levels, points)
    # -> (N*heads, 1, Len_q, levels*points)
    weights = attention_weights.transpose(1, 2).reshape(
        batch * heads, 1, num_queries, num_levels * num_points)
    stacked = torch.stack(sampled_per_level, dim=-2).flatten(-2)
    attended = (stacked * weights).sum(-1).view(batch, heads * head_dim, num_queries)
    return attended.transpose(1, 2).contiguous()
and its affiliates. # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR python setup.py build install --user ================================================ FILE: llava/model/openseed/body/encoder/ops/modules/__init__.py ================================================ # ------------------------------------------------------------------------------------------------ # Deformable DETR # Copyright (c) 2020 SenseTime. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------------------------------ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 # ------------------------------------------------------------------------------------------------ # Copyright (c) Facebook, Inc. and its affiliates. # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR from .ms_deform_attn import MSDeformAttn ================================================ FILE: llava/model/openseed/body/encoder/ops/modules/ms_deform_attn.py ================================================ # ------------------------------------------------------------------------------------------------ # Deformable DETR # Copyright (c) 2020 SenseTime. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------------------------------ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 # ------------------------------------------------------------------------------------------------ # Copyright (c) Facebook, Inc. and its affiliates. 
def _is_power_of_2(n):
    """Return True when ``n`` is a positive power of two.

    :raises ValueError: if ``n`` is not an ``int`` or is negative
    """
    if (not isinstance(n, int)) or (n < 0):
        raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)))
    return (n & (n - 1) == 0) and n != 0


class MSDeformAttn(nn.Module):
    """Multi-scale deformable attention layer."""

    def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4):
        """
        Multi-Scale Deformable Attention Module
        :param d_model      hidden dimension
        :param n_levels     number of feature levels
        :param n_heads      number of attention heads
        :param n_points     number of sampling points per attention head per feature level
        :raises ValueError  if d_model is not divisible by n_heads
        """
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads))
        _d_per_head = d_model // n_heads
        # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation
        if not _is_power_of_2(_d_per_head):
            warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 "
                          "which is more efficient in our CUDA implementation.")

        # Upper bound on the batch chunk handed to the CUDA op per launch.
        self.im2col_step = 128

        self.d_model = d_model
        self.n_levels = n_levels
        self.n_heads = n_heads
        self.n_points = n_points

        self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
        self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
        self.value_proj = nn.Linear(d_model, d_model)
        self.output_proj = nn.Linear(d_model, d_model)

        self._reset_parameters()

    def _reset_parameters(self):
        """Initialise the projections.

        Offset weights start at zero; the offset bias places each head's
        points along its own direction (angles evenly spaced over the
        circle), with the k-th point scaled by ``k + 1``. Attention logits
        start at zero; value/output projections use Xavier-uniform weights
        and zero biases.
        """
        constant_(self.sampling_offsets.weight.data, 0.)
        thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
        grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
        grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1)
        for i in range(self.n_points):
            grid_init[:, :, i, :] *= i + 1
        with torch.no_grad():
            self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
        constant_(self.attention_weights.weight.data, 0.)
        constant_(self.attention_weights.bias.data, 0.)
        xavier_uniform_(self.value_proj.weight.data)
        constant_(self.value_proj.bias.data, 0.)
        xavier_uniform_(self.output_proj.weight.data)
        constant_(self.output_proj.bias.data, 0.)

    def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None):
        r"""
        :param query                       (N, Length_{query}, C)
        :param reference_points            (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area
                                        or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes
        :param input_flatten               (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C)
        :param input_spatial_shapes        (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
        :param input_level_start_index     (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}]
        :param input_padding_mask          (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements

        :return output                     (N, Length_{query}, C)
        :raises ValueError                 if the last dim of reference_points is neither 2 nor 4
        """
        N, Len_q, _ = query.shape
        N, Len_in, _ = input_flatten.shape
        assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in

        value = self.value_proj(input_flatten)
        if input_padding_mask is not None:
            # Padded positions must not contribute to the sampled values.
            value = value.masked_fill(input_padding_mask[..., None], float(0))
        value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)
        sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2)
        attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points)
        # Softmax jointly over all levels and points of a head.
        attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points)
        # N, Len_q, n_heads, n_levels, n_points, 2
        if reference_points.shape[-1] == 2:
            # Offsets are in pixels of each level; normalise by (W_l, H_l).
            offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1)
            sampling_locations = reference_points[:, :, None, :, None, :] \
                                 + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
        elif reference_points.shape[-1] == 4:
            # Offsets are relative to the reference box size.
            sampling_locations = reference_points[:, :, None, :, None, :2] \
                                 + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
        else:
            raise ValueError(
                'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1]))

        # The CUDA op is dispatched for float/double only, so half-precision
        # inputs (float16 as well as bfloat16) are computed in float32 and
        # cast back afterwards. All three tensors must share one dtype.
        out_dtype = value.dtype
        compute_dtype = torch.float32 if out_dtype in (torch.float16, torch.bfloat16) else out_dtype
        value = value.to(compute_dtype)
        sampling_locations = sampling_locations.to(compute_dtype)
        attention_weights = attention_weights.to(compute_dtype)

        if value.is_cuda:
            output = MSDeformAttnFunction.apply(
                value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step)
        else:
            # The compiled op requires CUDA tensors; fall back to the
            # pure-PyTorch reference implementation elsewhere.
            output = ms_deform_attn_core_pytorch(value, input_spatial_shapes, sampling_locations, attention_weights)

        output = self.output_proj(output.to(out_dtype))
        return output
def get_extensions():
    """Describe the ``MultiScaleDeformableAttention`` extension for ``setup()``.

    Sources are globbed from the ``src`` directory next to this file:
    ``src/*.cpp``, ``src/cpu/*.cpp`` and ``src/cuda/*.cu``.

    :return: one-element list holding the ``CUDAExtension``
    :raises NotImplementedError: if ``CUDA_HOME`` is unset, or if neither the
        ``FORCE_CUDA`` environment variable nor ``torch.cuda.is_available()``
        indicates a CUDA build is wanted
    """
    ext_root = os.path.join(os.path.dirname(os.path.abspath(__file__)), "src")

    binding_sources = glob.glob(os.path.join(ext_root, "*.cpp"))
    cpu_sources = glob.glob(os.path.join(ext_root, "cpu", "*.cpp"))
    cuda_sources = glob.glob(os.path.join(ext_root, "cuda", "*.cu"))

    # Only a CUDA build is supported. FORCE_CUDA lets a machine without a
    # visible GPU (e.g. a build container) still compile the kernels.
    if CUDA_HOME is None:
        raise NotImplementedError('CUDA_HOME is None. Please set environment variable CUDA_HOME.')
    if not (os.environ.get('FORCE_CUDA') or torch.cuda.is_available()):
        raise NotImplementedError('No CUDA runtime is found. Please set FORCE_CUDA=1 or test it by running torch.cuda.is_available().')

    nvcc_flags = [
        "-DCUDA_HAS_FP16=1",
        "-D__CUDA_NO_HALF_OPERATORS__",
        "-D__CUDA_NO_HALF_CONVERSIONS__",
        "-D__CUDA_NO_HALF2_OPERATORS__",
    ]
    all_sources = [
        os.path.join(ext_root, path)
        for path in binding_sources + cpu_sources + cuda_sources
    ]

    return [
        CUDAExtension(
            "MultiScaleDeformableAttention",
            all_sources,
            include_dirs=[ext_root],
            define_macros=[("WITH_CUDA", None)],
            extra_compile_args={"cxx": [], "nvcc": nvcc_flags},
        )
    ]
* Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR */ #include #include #include at::Tensor ms_deform_attn_cpu_forward( const at::Tensor &value, const at::Tensor &spatial_shapes, const at::Tensor &level_start_index, const at::Tensor &sampling_loc, const at::Tensor &attn_weight, const int im2col_step) { AT_ERROR("Not implement on cpu"); } std::vector ms_deform_attn_cpu_backward( const at::Tensor &value, const at::Tensor &spatial_shapes, const at::Tensor &level_start_index, const at::Tensor &sampling_loc, const at::Tensor &attn_weight, const at::Tensor &grad_output, const int im2col_step) { AT_ERROR("Not implement on cpu"); } ================================================ FILE: llava/model/openseed/body/encoder/ops/src/cpu/ms_deform_attn_cpu.h ================================================ /*! ************************************************************************************************** * Deformable DETR * Copyright (c) 2020 SenseTime. All Rights Reserved. * Licensed under the Apache License, Version 2.0 [see LICENSE for details] ************************************************************************************************** * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 ************************************************************************************************** */ /*! * Copyright (c) Facebook, Inc. and its affiliates. 
* Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR */ #pragma once #include at::Tensor ms_deform_attn_cpu_forward( const at::Tensor &value, const at::Tensor &spatial_shapes, const at::Tensor &level_start_index, const at::Tensor &sampling_loc, const at::Tensor &attn_weight, const int im2col_step); std::vector ms_deform_attn_cpu_backward( const at::Tensor &value, const at::Tensor &spatial_shapes, const at::Tensor &level_start_index, const at::Tensor &sampling_loc, const at::Tensor &attn_weight, const at::Tensor &grad_output, const int im2col_step); ================================================ FILE: llava/model/openseed/body/encoder/ops/src/cuda/ms_deform_attn_cuda.cu ================================================ /*! ************************************************************************************************** * Deformable DETR * Copyright (c) 2020 SenseTime. All Rights Reserved. * Licensed under the Apache License, Version 2.0 [see LICENSE for details] ************************************************************************************************** * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 ************************************************************************************************** */ /*! * Copyright (c) Facebook, Inc. and its affiliates. 
// Forward pass of multi-scale deformable attention on CUDA tensors.
//
// `value` is indexed as (batch, spatial_size, num_heads, channels);
// `sampling_loc` supplies num_query (dim 1) and num_point (dim 4).
// The batch is processed in chunks of min(batch, im2col_step) samples.
// Returns a tensor of shape (batch, num_query, num_heads * channels).
at::Tensor ms_deform_attn_cuda_forward(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const int im2col_step)
{
    AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
    AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
    AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
    AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
    AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");

    AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
    AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
    AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
    AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
    AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");

    const int batch = value.size(0);
    const int spatial_size = value.size(1);
    const int num_heads = value.size(2);
    const int channels = value.size(3);

    const int num_levels = spatial_shapes.size(0);

    const int num_query = sampling_loc.size(1);
    const int num_point = sampling_loc.size(4);

    const int im2col_step_ = std::min(batch, im2col_step);

    // Guard the modulo/division below: an empty batch or a non-positive step
    // would otherwise divide by zero.
    AT_ASSERTM(im2col_step_ > 0, "batch and im2col_step must be positive, got batch=", batch, ", im2col_step=", im2col_step);
    // AT_ASSERTM concatenates its arguments; it does not interpret printf placeholders.
    AT_ASSERTM(batch % im2col_step_ == 0, "batch(", batch, ") must be divisible by im2col_step(", im2col_step_, ")");

    auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());

    const int batch_n = im2col_step_;
    auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
    // Element counts per sample, used to advance the raw pointers per chunk.
    auto per_value_size = spatial_size * num_heads * channels;
    auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
    auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
    for (int n = 0; n < batch/im2col_step_; ++n)
    {
        auto columns = output_n.select(0, n);
        AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
            ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
                value.data<scalar_t>() + n * im2col_step_ * per_value_size,
                spatial_shapes.data<int64_t>(),
                level_start_index.data<int64_t>(),
                sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
                attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
                batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
                columns.data<scalar_t>());

        }));
    }

    output = output.view({batch, num_query, num_heads*channels});

    return output;
}


// Backward pass of multi-scale deformable attention on CUDA tensors.
//
// Takes the forward inputs plus `grad_output` and returns
// {grad_value, grad_sampling_loc, grad_attn_weight}, each shaped like the
// corresponding input. Chunking mirrors the forward pass.
std::vector<at::Tensor> ms_deform_attn_cuda_backward(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const at::Tensor &grad_output,
    const int im2col_step)
{

    AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
    AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
    AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
    AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
    AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
    AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");

    AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
    AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
    AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
    AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
    AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
    AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");

    const int batch = value.size(0);
    const int spatial_size = value.size(1);
    const int num_heads = value.size(2);
    const int channels = value.size(3);

    const int num_levels = spatial_shapes.size(0);

    const int num_query = sampling_loc.size(1);
    const int num_point = sampling_loc.size(4);

    const int im2col_step_ = std::min(batch, im2col_step);

    // Same guards as the forward pass (see there for the rationale).
    AT_ASSERTM(im2col_step_ > 0, "batch and im2col_step must be positive, got batch=", batch, ", im2col_step=", im2col_step);
    AT_ASSERTM(batch % im2col_step_ == 0, "batch(", batch, ") must be divisible by im2col_step(", im2col_step_, ")");

    auto grad_value = at::zeros_like(value);
    auto grad_sampling_loc = at::zeros_like(sampling_loc);
    auto grad_attn_weight = at::zeros_like(attn_weight);

    const int batch_n = im2col_step_;
    auto per_value_size = spatial_size * num_heads * channels;
    auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
    auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
    auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});

    for (int n = 0; n < batch/im2col_step_; ++n)
    {
        auto grad_output_g = grad_output_n.select(0, n);
        AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
            ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
                                    grad_output_g.data<scalar_t>(),
                                    value.data<scalar_t>() + n * im2col_step_ * per_value_size,
                                    spatial_shapes.data<int64_t>(),
                                    level_start_index.data<int64_t>(),
                                    sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
                                    attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
                                    batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
                                    grad_value.data<scalar_t>() +  n * im2col_step_ * per_value_size,
                                    grad_sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
                                    grad_attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size);

        }));
    }

    return {
        grad_value, grad_sampling_loc, grad_attn_weight
    };
}
************************************************************************************************** * Deformable DETR * Copyright (c) 2020 SenseTime. All Rights Reserved. * Licensed under the Apache License, Version 2.0 [see LICENSE for details] ************************************************************************************************** * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 ************************************************************************************************** */ /*! * Copyright (c) Facebook, Inc. and its affiliates. * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR */ #pragma once #include at::Tensor ms_deform_attn_cuda_forward( const at::Tensor &value, const at::Tensor &spatial_shapes, const at::Tensor &level_start_index, const at::Tensor &sampling_loc, const at::Tensor &attn_weight, const int im2col_step); std::vector ms_deform_attn_cuda_backward( const at::Tensor &value, const at::Tensor &spatial_shapes, const at::Tensor &level_start_index, const at::Tensor &sampling_loc, const at::Tensor &attn_weight, const at::Tensor &grad_output, const int im2col_step); ================================================ FILE: llava/model/openseed/body/encoder/ops/src/cuda/ms_deform_im2col_cuda.cuh ================================================ /*! ************************************************************************** * Deformable DETR * Copyright (c) 2020 SenseTime. All Rights Reserved. * Licensed under the Apache License, Version 2.0 [see LICENSE for details] ************************************************************************** * Modified from DCN (https://github.com/msracver/Deformable-ConvNets) * Copyright (c) 2018 Microsoft ************************************************************************** */ /*! * Copyright (c) Facebook, Inc. and its affiliates. 
// Loop that lets each thread cover indices i, i + stride, i + 2*stride, ...
// below n, where stride is the total number of launched threads.
#define CUDA_KERNEL_LOOP(i, n)                          \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x;   \
      i < (n);                                          \
      i += blockDim.x * gridDim.x)

const int CUDA_NUM_THREADS = 1024;

// Number of blocks of `num_threads` threads needed to cover N items
// (ceiling division).
inline int GET_BLOCKS(const int N, const int num_threads)
{
  return (N + num_threads - 1) / num_threads;
}


// Bilinearly interpolate channel `c` of head `m` at fractional position
// (h, w) of one feature level.
//
// `bottom_data` is indexed as ((row * width + col) * nheads + m) * channels + c.
// A corner outside the tested bounds contributes zero. Note that the low
// corners are only checked against 0 and the high corners only against the
// upper edge -- presumably the caller restricts (h, w) to (-1, height) x
// (-1, width); confirm at the call site.
template <typename scalar_t>
__device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data,
                                                   const int &height, const int &width, const int &nheads, const int &channels,
                                                   const scalar_t &h, const scalar_t &w, const int &m, const int &c)
{
  const int h_low = floor(h);
  const int w_low = floor(w);
  const int h_high = h_low + 1;
  const int w_high = w_low + 1;

  // Fractional distances to the low corner and their complements.
  const scalar_t lh = h - h_low;
  const scalar_t lw = w - w_low;
  const scalar_t hh = 1 - lh, hw = 1 - lw;

  const int w_stride = nheads * channels;
  const int h_stride = width * w_stride;
  const int h_low_ptr_offset = h_low * h_stride;
  const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
  const int w_low_ptr_offset = w_low * w_stride;
  const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
  const int base_ptr = m * channels + c;

  scalar_t v1 = 0;
  if (h_low >= 0 && w_low >= 0)
  {
    const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
    v1 = bottom_data[ptr1];
  }
  scalar_t v2 = 0;
  if (h_low >= 0 && w_high <= width - 1)
  {
    const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
    v2 = bottom_data[ptr2];
  }
  scalar_t v3 = 0;
  if (h_high <= height - 1 && w_low >= 0)
  {
    const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
    v3 = bottom_data[ptr3];
  }
  scalar_t v4 = 0;
  if (h_high <= height - 1 && w_high <= width - 1)
  {
    const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
    v4 = bottom_data[ptr4];
  }

  const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;

  const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
  return val;
}
ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data, const int &height, const int &width, const int &nheads, const int &channels, const scalar_t &h, const scalar_t &w, const int &m, const int &c, const scalar_t &top_grad, const scalar_t &attn_weight, scalar_t* &grad_value, scalar_t* grad_sampling_loc, scalar_t* grad_attn_weight) { const int h_low = floor(h); const int w_low = floor(w); const int h_high = h_low + 1; const int w_high = w_low + 1; const scalar_t lh = h - h_low; const scalar_t lw = w - w_low; const scalar_t hh = 1 - lh, hw = 1 - lw; const int w_stride = nheads * channels; const int h_stride = width * w_stride; const int h_low_ptr_offset = h_low * h_stride; const int h_high_ptr_offset = h_low_ptr_offset + h_stride; const int w_low_ptr_offset = w_low * w_stride; const int w_high_ptr_offset = w_low_ptr_offset + w_stride; const int base_ptr = m * channels + c; const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; const scalar_t top_grad_value = top_grad * attn_weight; scalar_t grad_h_weight = 0, grad_w_weight = 0; scalar_t v1 = 0; if (h_low >= 0 && w_low >= 0) { const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; v1 = bottom_data[ptr1]; grad_h_weight -= hw * v1; grad_w_weight -= hh * v1; atomicAdd(grad_value+ptr1, w1*top_grad_value); } scalar_t v2 = 0; if (h_low >= 0 && w_high <= width - 1) { const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; v2 = bottom_data[ptr2]; grad_h_weight -= lw * v2; grad_w_weight += hh * v2; atomicAdd(grad_value+ptr2, w2*top_grad_value); } scalar_t v3 = 0; if (h_high <= height - 1 && w_low >= 0) { const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; v3 = bottom_data[ptr3]; grad_h_weight += hw * v3; grad_w_weight -= lh * v3; atomicAdd(grad_value+ptr3, w3*top_grad_value); } scalar_t v4 = 0; if (h_high <= height - 1 && w_high <= width - 1) { const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; v4 = bottom_data[ptr4]; grad_h_weight += lw * v4; 
grad_w_weight += lh * v4; atomicAdd(grad_value+ptr4, w4*top_grad_value); } const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); *grad_attn_weight = top_grad * val; *grad_sampling_loc = width * grad_w_weight * top_grad_value; *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value; } template __device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data, const int &height, const int &width, const int &nheads, const int &channels, const scalar_t &h, const scalar_t &w, const int &m, const int &c, const scalar_t &top_grad, const scalar_t &attn_weight, scalar_t* &grad_value, scalar_t* grad_sampling_loc, scalar_t* grad_attn_weight) { const int h_low = floor(h); const int w_low = floor(w); const int h_high = h_low + 1; const int w_high = w_low + 1; const scalar_t lh = h - h_low; const scalar_t lw = w - w_low; const scalar_t hh = 1 - lh, hw = 1 - lw; const int w_stride = nheads * channels; const int h_stride = width * w_stride; const int h_low_ptr_offset = h_low * h_stride; const int h_high_ptr_offset = h_low_ptr_offset + h_stride; const int w_low_ptr_offset = w_low * w_stride; const int w_high_ptr_offset = w_low_ptr_offset + w_stride; const int base_ptr = m * channels + c; const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; const scalar_t top_grad_value = top_grad * attn_weight; scalar_t grad_h_weight = 0, grad_w_weight = 0; scalar_t v1 = 0; if (h_low >= 0 && w_low >= 0) { const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; v1 = bottom_data[ptr1]; grad_h_weight -= hw * v1; grad_w_weight -= hh * v1; atomicAdd(grad_value+ptr1, w1*top_grad_value); } scalar_t v2 = 0; if (h_low >= 0 && w_high <= width - 1) { const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; v2 = bottom_data[ptr2]; grad_h_weight -= lw * v2; grad_w_weight += hh * v2; atomicAdd(grad_value+ptr2, w2*top_grad_value); } scalar_t v3 = 0; if (h_high <= height - 1 && w_low >= 0) { const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + 
base_ptr; v3 = bottom_data[ptr3]; grad_h_weight += hw * v3; grad_w_weight -= lh * v3; atomicAdd(grad_value+ptr3, w3*top_grad_value); } scalar_t v4 = 0; if (h_high <= height - 1 && w_high <= width - 1) { const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; v4 = bottom_data[ptr4]; grad_h_weight += lw * v4; grad_w_weight += lh * v4; atomicAdd(grad_value+ptr4, w4*top_grad_value); } const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); atomicAdd(grad_attn_weight, top_grad * val); atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value); atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value); } template __global__ void ms_deformable_im2col_gpu_kernel(const int n, const scalar_t *data_value, const int64_t *data_spatial_shapes, const int64_t *data_level_start_index, const scalar_t *data_sampling_loc, const scalar_t *data_attn_weight, const int batch_size, const int spatial_size, const int num_heads, const int channels, const int num_levels, const int num_query, const int num_point, scalar_t *data_col) { CUDA_KERNEL_LOOP(index, n) { int _temp = index; const int c_col = _temp % channels; _temp /= channels; const int sampling_index = _temp; const int m_col = _temp % num_heads; _temp /= num_heads; const int q_col = _temp % num_query; _temp /= num_query; const int b_col = _temp; scalar_t *data_col_ptr = data_col + index; int data_weight_ptr = sampling_index * num_levels * num_point; int data_loc_w_ptr = data_weight_ptr << 1; const int qid_stride = num_heads * channels; const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; scalar_t col = 0; for (int l_col=0; l_col < num_levels; ++l_col) { const int level_start_id = data_level_start_index[l_col]; const int spatial_h_ptr = l_col << 1; const int spatial_h = data_spatial_shapes[spatial_h_ptr]; const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride); 
for (int p_col=0; p_col < num_point; ++p_col) { const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; const scalar_t weight = data_attn_weight[data_weight_ptr]; const scalar_t h_im = loc_h * spatial_h - 0.5; const scalar_t w_im = loc_w * spatial_w - 0.5; if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) { col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight; } data_weight_ptr += 1; data_loc_w_ptr += 2; } } *data_col_ptr = col; } } template __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n, const scalar_t *grad_col, const scalar_t *data_value, const int64_t *data_spatial_shapes, const int64_t *data_level_start_index, const scalar_t *data_sampling_loc, const scalar_t *data_attn_weight, const int batch_size, const int spatial_size, const int num_heads, const int channels, const int num_levels, const int num_query, const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc, scalar_t *grad_attn_weight) { CUDA_KERNEL_LOOP(index, n) { __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2]; __shared__ scalar_t cache_grad_attn_weight[blockSize]; unsigned int tid = threadIdx.x; int _temp = index; const int c_col = _temp % channels; _temp /= channels; const int sampling_index = _temp; const int m_col = _temp % num_heads; _temp /= num_heads; const int q_col = _temp % num_query; _temp /= num_query; const int b_col = _temp; const scalar_t top_grad = grad_col[index]; int data_weight_ptr = sampling_index * num_levels * num_point; int data_loc_w_ptr = data_weight_ptr << 1; const int grad_sampling_ptr = data_weight_ptr; grad_sampling_loc += grad_sampling_ptr << 1; grad_attn_weight += grad_sampling_ptr; const int grad_weight_stride = 1; const int grad_loc_stride = 2; const int qid_stride = num_heads * channels; const int data_value_ptr_init_offset = b_col * 
spatial_size * qid_stride; for (int l_col=0; l_col < num_levels; ++l_col) { const int level_start_id = data_level_start_index[l_col]; const int spatial_h_ptr = l_col << 1; const int spatial_h = data_spatial_shapes[spatial_h_ptr]; const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; const scalar_t *data_value_ptr = data_value + value_ptr_offset; scalar_t *grad_value_ptr = grad_value + value_ptr_offset; for (int p_col=0; p_col < num_point; ++p_col) { const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; const scalar_t weight = data_attn_weight[data_weight_ptr]; const scalar_t h_im = loc_h * spatial_h - 0.5; const scalar_t w_im = loc_w * spatial_w - 0.5; *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; *(cache_grad_attn_weight+threadIdx.x)=0; if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) { ms_deform_attn_col2im_bilinear( data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, top_grad, weight, grad_value_ptr, cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); } __syncthreads(); if (tid == 0) { scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0]; int sid=2; for (unsigned int tid = 1; tid < blockSize; ++tid) { _grad_w += cache_grad_sampling_loc[sid]; _grad_h += cache_grad_sampling_loc[sid + 1]; _grad_a += cache_grad_attn_weight[tid]; sid += 2; } *grad_sampling_loc = _grad_w; *(grad_sampling_loc + 1) = _grad_h; *grad_attn_weight = _grad_a; } __syncthreads(); data_weight_ptr += 1; data_loc_w_ptr += 2; grad_attn_weight += grad_weight_stride; grad_sampling_loc += grad_loc_stride; } } } } template __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n, const scalar_t *grad_col, 
const scalar_t *data_value, const int64_t *data_spatial_shapes, const int64_t *data_level_start_index, const scalar_t *data_sampling_loc, const scalar_t *data_attn_weight, const int batch_size, const int spatial_size, const int num_heads, const int channels, const int num_levels, const int num_query, const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc, scalar_t *grad_attn_weight) { CUDA_KERNEL_LOOP(index, n) { __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2]; __shared__ scalar_t cache_grad_attn_weight[blockSize]; unsigned int tid = threadIdx.x; int _temp = index; const int c_col = _temp % channels; _temp /= channels; const int sampling_index = _temp; const int m_col = _temp % num_heads; _temp /= num_heads; const int q_col = _temp % num_query; _temp /= num_query; const int b_col = _temp; const scalar_t top_grad = grad_col[index]; int data_weight_ptr = sampling_index * num_levels * num_point; int data_loc_w_ptr = data_weight_ptr << 1; const int grad_sampling_ptr = data_weight_ptr; grad_sampling_loc += grad_sampling_ptr << 1; grad_attn_weight += grad_sampling_ptr; const int grad_weight_stride = 1; const int grad_loc_stride = 2; const int qid_stride = num_heads * channels; const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; for (int l_col=0; l_col < num_levels; ++l_col) { const int level_start_id = data_level_start_index[l_col]; const int spatial_h_ptr = l_col << 1; const int spatial_h = data_spatial_shapes[spatial_h_ptr]; const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; const scalar_t *data_value_ptr = data_value + value_ptr_offset; scalar_t *grad_value_ptr = grad_value + value_ptr_offset; for (int p_col=0; p_col < num_point; ++p_col) { const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; const scalar_t weight = data_attn_weight[data_weight_ptr]; const 
scalar_t h_im = loc_h * spatial_h - 0.5; const scalar_t w_im = loc_w * spatial_w - 0.5; *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; *(cache_grad_attn_weight+threadIdx.x)=0; if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) { ms_deform_attn_col2im_bilinear( data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, top_grad, weight, grad_value_ptr, cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); } __syncthreads(); for (unsigned int s=blockSize/2; s>0; s>>=1) { if (tid < s) { const unsigned int xid1 = tid << 1; const unsigned int xid2 = (tid + s) << 1; cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1]; } __syncthreads(); } if (tid == 0) { *grad_sampling_loc = cache_grad_sampling_loc[0]; *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1]; *grad_attn_weight = cache_grad_attn_weight[0]; } __syncthreads(); data_weight_ptr += 1; data_loc_w_ptr += 2; grad_attn_weight += grad_weight_stride; grad_sampling_loc += grad_loc_stride; } } } } template __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n, const scalar_t *grad_col, const scalar_t *data_value, const int64_t *data_spatial_shapes, const int64_t *data_level_start_index, const scalar_t *data_sampling_loc, const scalar_t *data_attn_weight, const int batch_size, const int spatial_size, const int num_heads, const int channels, const int num_levels, const int num_query, const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc, scalar_t *grad_attn_weight) { CUDA_KERNEL_LOOP(index, n) { extern __shared__ int _s[]; scalar_t* cache_grad_sampling_loc = (scalar_t*)_s; scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; unsigned int tid = threadIdx.x; int _temp = index; const int c_col 
= _temp % channels; _temp /= channels; const int sampling_index = _temp; const int m_col = _temp % num_heads; _temp /= num_heads; const int q_col = _temp % num_query; _temp /= num_query; const int b_col = _temp; const scalar_t top_grad = grad_col[index]; int data_weight_ptr = sampling_index * num_levels * num_point; int data_loc_w_ptr = data_weight_ptr << 1; const int grad_sampling_ptr = data_weight_ptr; grad_sampling_loc += grad_sampling_ptr << 1; grad_attn_weight += grad_sampling_ptr; const int grad_weight_stride = 1; const int grad_loc_stride = 2; const int qid_stride = num_heads * channels; const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; for (int l_col=0; l_col < num_levels; ++l_col) { const int level_start_id = data_level_start_index[l_col]; const int spatial_h_ptr = l_col << 1; const int spatial_h = data_spatial_shapes[spatial_h_ptr]; const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; const scalar_t *data_value_ptr = data_value + value_ptr_offset; scalar_t *grad_value_ptr = grad_value + value_ptr_offset; for (int p_col=0; p_col < num_point; ++p_col) { const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; const scalar_t weight = data_attn_weight[data_weight_ptr]; const scalar_t h_im = loc_h * spatial_h - 0.5; const scalar_t w_im = loc_w * spatial_w - 0.5; *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; *(cache_grad_attn_weight+threadIdx.x)=0; if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) { ms_deform_attn_col2im_bilinear( data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, top_grad, weight, grad_value_ptr, cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); } __syncthreads(); if (tid == 0) { scalar_t 
_grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0]; int sid=2; for (unsigned int tid = 1; tid < blockDim.x; ++tid) { _grad_w += cache_grad_sampling_loc[sid]; _grad_h += cache_grad_sampling_loc[sid + 1]; _grad_a += cache_grad_attn_weight[tid]; sid += 2; } *grad_sampling_loc = _grad_w; *(grad_sampling_loc + 1) = _grad_h; *grad_attn_weight = _grad_a; } __syncthreads(); data_weight_ptr += 1; data_loc_w_ptr += 2; grad_attn_weight += grad_weight_stride; grad_sampling_loc += grad_loc_stride; } } } } template __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n, const scalar_t *grad_col, const scalar_t *data_value, const int64_t *data_spatial_shapes, const int64_t *data_level_start_index, const scalar_t *data_sampling_loc, const scalar_t *data_attn_weight, const int batch_size, const int spatial_size, const int num_heads, const int channels, const int num_levels, const int num_query, const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc, scalar_t *grad_attn_weight) { CUDA_KERNEL_LOOP(index, n) { extern __shared__ int _s[]; scalar_t* cache_grad_sampling_loc = (scalar_t*)_s; scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; unsigned int tid = threadIdx.x; int _temp = index; const int c_col = _temp % channels; _temp /= channels; const int sampling_index = _temp; const int m_col = _temp % num_heads; _temp /= num_heads; const int q_col = _temp % num_query; _temp /= num_query; const int b_col = _temp; const scalar_t top_grad = grad_col[index]; int data_weight_ptr = sampling_index * num_levels * num_point; int data_loc_w_ptr = data_weight_ptr << 1; const int grad_sampling_ptr = data_weight_ptr; grad_sampling_loc += grad_sampling_ptr << 1; grad_attn_weight += grad_sampling_ptr; const int grad_weight_stride = 1; const int grad_loc_stride = 2; const int qid_stride = num_heads * channels; const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; for 
(int l_col=0; l_col < num_levels; ++l_col) { const int level_start_id = data_level_start_index[l_col]; const int spatial_h_ptr = l_col << 1; const int spatial_h = data_spatial_shapes[spatial_h_ptr]; const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; const scalar_t *data_value_ptr = data_value + value_ptr_offset; scalar_t *grad_value_ptr = grad_value + value_ptr_offset; for (int p_col=0; p_col < num_point; ++p_col) { const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; const scalar_t weight = data_attn_weight[data_weight_ptr]; const scalar_t h_im = loc_h * spatial_h - 0.5; const scalar_t w_im = loc_w * spatial_w - 0.5; *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; *(cache_grad_attn_weight+threadIdx.x)=0; if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) { ms_deform_attn_col2im_bilinear( data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, top_grad, weight, grad_value_ptr, cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); } __syncthreads(); for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1) { if (tid < s) { const unsigned int xid1 = tid << 1; const unsigned int xid2 = (tid + s) << 1; cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1]; if (tid + (s << 1) < spre) { cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)]; cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)]; cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)]; } } __syncthreads(); } if (tid == 0) { *grad_sampling_loc = cache_grad_sampling_loc[0]; *(grad_sampling_loc + 1) 
= cache_grad_sampling_loc[1]; *grad_attn_weight = cache_grad_attn_weight[0]; } __syncthreads(); data_weight_ptr += 1; data_loc_w_ptr += 2; grad_attn_weight += grad_weight_stride; grad_sampling_loc += grad_loc_stride; } } } } template __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n, const scalar_t *grad_col, const scalar_t *data_value, const int64_t *data_spatial_shapes, const int64_t *data_level_start_index, const scalar_t *data_sampling_loc, const scalar_t *data_attn_weight, const int batch_size, const int spatial_size, const int num_heads, const int channels, const int num_levels, const int num_query, const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc, scalar_t *grad_attn_weight) { CUDA_KERNEL_LOOP(index, n) { extern __shared__ int _s[]; scalar_t* cache_grad_sampling_loc = (scalar_t*)_s; scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; unsigned int tid = threadIdx.x; int _temp = index; const int c_col = _temp % channels; _temp /= channels; const int sampling_index = _temp; const int m_col = _temp % num_heads; _temp /= num_heads; const int q_col = _temp % num_query; _temp /= num_query; const int b_col = _temp; const scalar_t top_grad = grad_col[index]; int data_weight_ptr = sampling_index * num_levels * num_point; int data_loc_w_ptr = data_weight_ptr << 1; const int grad_sampling_ptr = data_weight_ptr; grad_sampling_loc += grad_sampling_ptr << 1; grad_attn_weight += grad_sampling_ptr; const int grad_weight_stride = 1; const int grad_loc_stride = 2; const int qid_stride = num_heads * channels; const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; for (int l_col=0; l_col < num_levels; ++l_col) { const int level_start_id = data_level_start_index[l_col]; const int spatial_h_ptr = l_col << 1; const int spatial_h = data_spatial_shapes[spatial_h_ptr]; const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; const int value_ptr_offset = 
data_value_ptr_init_offset + level_start_id * qid_stride; const scalar_t *data_value_ptr = data_value + value_ptr_offset; scalar_t *grad_value_ptr = grad_value + value_ptr_offset; for (int p_col=0; p_col < num_point; ++p_col) { const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; const scalar_t weight = data_attn_weight[data_weight_ptr]; const scalar_t h_im = loc_h * spatial_h - 0.5; const scalar_t w_im = loc_w * spatial_w - 0.5; *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; *(cache_grad_attn_weight+threadIdx.x)=0; if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) { ms_deform_attn_col2im_bilinear( data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, top_grad, weight, grad_value_ptr, cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); } __syncthreads(); for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1) { if (tid < s) { const unsigned int xid1 = tid << 1; const unsigned int xid2 = (tid + s) << 1; cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1]; if (tid + (s << 1) < spre) { cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)]; cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)]; cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)]; } } __syncthreads(); } if (tid == 0) { atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]); atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]); atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]); } __syncthreads(); data_weight_ptr += 1; data_loc_w_ptr += 2; grad_attn_weight += grad_weight_stride; grad_sampling_loc += grad_loc_stride; } } } } template __global__ void 
ms_deformable_col2im_gpu_kernel_gm(const int n, const scalar_t *grad_col, const scalar_t *data_value, const int64_t *data_spatial_shapes, const int64_t *data_level_start_index, const scalar_t *data_sampling_loc, const scalar_t *data_attn_weight, const int batch_size, const int spatial_size, const int num_heads, const int channels, const int num_levels, const int num_query, const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc, scalar_t *grad_attn_weight) { CUDA_KERNEL_LOOP(index, n) { int _temp = index; const int c_col = _temp % channels; _temp /= channels; const int sampling_index = _temp; const int m_col = _temp % num_heads; _temp /= num_heads; const int q_col = _temp % num_query; _temp /= num_query; const int b_col = _temp; const scalar_t top_grad = grad_col[index]; int data_weight_ptr = sampling_index * num_levels * num_point; int data_loc_w_ptr = data_weight_ptr << 1; const int grad_sampling_ptr = data_weight_ptr; grad_sampling_loc += grad_sampling_ptr << 1; grad_attn_weight += grad_sampling_ptr; const int grad_weight_stride = 1; const int grad_loc_stride = 2; const int qid_stride = num_heads * channels; const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; for (int l_col=0; l_col < num_levels; ++l_col) { const int level_start_id = data_level_start_index[l_col]; const int spatial_h_ptr = l_col << 1; const int spatial_h = data_spatial_shapes[spatial_h_ptr]; const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; const scalar_t *data_value_ptr = data_value + value_ptr_offset; scalar_t *grad_value_ptr = grad_value + value_ptr_offset; for (int p_col=0; p_col < num_point; ++p_col) { const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; const scalar_t weight = data_attn_weight[data_weight_ptr]; const scalar_t h_im = loc_h * spatial_h - 0.5; const scalar_t w_im = loc_w * 
spatial_w - 0.5; if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) { ms_deform_attn_col2im_bilinear_gm( data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, top_grad, weight, grad_value_ptr, grad_sampling_loc, grad_attn_weight); } data_weight_ptr += 1; data_loc_w_ptr += 2; grad_attn_weight += grad_weight_stride; grad_sampling_loc += grad_loc_stride; } } } } template void ms_deformable_im2col_cuda(cudaStream_t stream, const scalar_t* data_value, const int64_t* data_spatial_shapes, const int64_t* data_level_start_index, const scalar_t* data_sampling_loc, const scalar_t* data_attn_weight, const int batch_size, const int spatial_size, const int num_heads, const int channels, const int num_levels, const int num_query, const int num_point, scalar_t* data_col) { const int num_kernels = batch_size * num_query * num_heads * channels; const int num_actual_kernels = batch_size * num_query * num_heads * channels; const int num_threads = CUDA_NUM_THREADS; ms_deformable_im2col_gpu_kernel <<>>( num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col); cudaError_t err = cudaGetLastError(); if (err != cudaSuccess) { printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err)); } } template void ms_deformable_col2im_cuda(cudaStream_t stream, const scalar_t* grad_col, const scalar_t* data_value, const int64_t * data_spatial_shapes, const int64_t * data_level_start_index, const scalar_t * data_sampling_loc, const scalar_t * data_attn_weight, const int batch_size, const int spatial_size, const int num_heads, const int channels, const int num_levels, const int num_query, const int num_point, scalar_t* grad_value, scalar_t* grad_sampling_loc, scalar_t* grad_attn_weight) { const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels; const int num_kernels = batch_size 
* num_query * num_heads * channels; const int num_actual_kernels = batch_size * num_query * num_heads * channels; if (channels > 1024) { if ((channels & 1023) == 0) { ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks <<>>( num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, grad_value, grad_sampling_loc, grad_attn_weight); } else { ms_deformable_col2im_gpu_kernel_gm <<>>( num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, grad_value, grad_sampling_loc, grad_attn_weight); } } else{ switch(channels) { case 1: ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 <<>>( num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, grad_value, grad_sampling_loc, grad_attn_weight); break; case 2: ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 <<>>( num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, grad_value, grad_sampling_loc, grad_attn_weight); break; case 4: ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 <<>>( num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, grad_value, grad_sampling_loc, grad_attn_weight); break; case 8: ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 <<>>( num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight, 
batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, grad_value, grad_sampling_loc, grad_attn_weight); break; case 16: ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 <<>>( num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, grad_value, grad_sampling_loc, grad_attn_weight); break; case 32: ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 <<>>( num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, grad_value, grad_sampling_loc, grad_attn_weight); break; case 64: ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 <<>>( num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, grad_value, grad_sampling_loc, grad_attn_weight); break; case 128: ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 <<>>( num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, grad_value, grad_sampling_loc, grad_attn_weight); break; case 256: ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 <<>>( num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, grad_value, grad_sampling_loc, grad_attn_weight); break; case 512: ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 <<>>( num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight, 
batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, grad_value, grad_sampling_loc, grad_attn_weight); break; case 1024: ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 <<>>( num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, grad_value, grad_sampling_loc, grad_attn_weight); break; default: if (channels < 64) { ms_deformable_col2im_gpu_kernel_shm_reduce_v1 <<>>( num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, grad_value, grad_sampling_loc, grad_attn_weight); } else { ms_deformable_col2im_gpu_kernel_shm_reduce_v2 <<>>( num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, grad_value, grad_sampling_loc, grad_attn_weight); } } } cudaError_t err = cudaGetLastError(); if (err != cudaSuccess) { printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err)); } } ================================================ FILE: llava/model/openseed/body/encoder/ops/src/ms_deform_attn.h ================================================ /*! ************************************************************************************************** * Deformable DETR * Copyright (c) 2020 SenseTime. All Rights Reserved. * Licensed under the Apache License, Version 2.0 [see LICENSE for details] ************************************************************************************************** * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 ************************************************************************************************** */ /*! 
* Copyright (c) Facebook, Inc. and its affiliates. * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR */ #pragma once #include "cpu/ms_deform_attn_cpu.h" #ifdef WITH_CUDA #include "cuda/ms_deform_attn_cuda.h" #endif at::Tensor ms_deform_attn_forward( const at::Tensor &value, const at::Tensor &spatial_shapes, const at::Tensor &level_start_index, const at::Tensor &sampling_loc, const at::Tensor &attn_weight, const int im2col_step) { if (value.type().is_cuda()) { #ifdef WITH_CUDA return ms_deform_attn_cuda_forward( value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step); #else AT_ERROR("Not compiled with GPU support"); #endif } AT_ERROR("Not implemented on the CPU"); } std::vector ms_deform_attn_backward( const at::Tensor &value, const at::Tensor &spatial_shapes, const at::Tensor &level_start_index, const at::Tensor &sampling_loc, const at::Tensor &attn_weight, const at::Tensor &grad_output, const int im2col_step) { if (value.type().is_cuda()) { #ifdef WITH_CUDA return ms_deform_attn_cuda_backward( value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step); #else AT_ERROR("Not compiled with GPU support"); #endif } AT_ERROR("Not implemented on the CPU"); } ================================================ FILE: llava/model/openseed/body/encoder/ops/src/vision.cpp ================================================ /*! ************************************************************************************************** * Deformable DETR * Copyright (c) 2020 SenseTime. All Rights Reserved. * Licensed under the Apache License, Version 2.0 [see LICENSE for details] ************************************************************************************************** * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 ************************************************************************************************** */ /*! 
// Python bindings for the deformable-attention extension.
// Exposes the two dispatchers declared in ms_deform_attn.h under the same
// names they have in C++; the third argument of each m.def is the Python
// docstring. TORCH_EXTENSION_NAME is not defined in this file -- presumably
// it is injected by the torch extension build (confirm in setup.py).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward");
  m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward");
}
@torch.no_grad()
def check_forward_equal_with_pytorch_double():
    """Run the CUDA op and the pure-PyTorch reference in float64 and print how they compare.

    Uses the module-level problem sizes (N, S, M, D, Lq, L, P), ``shapes`` and
    ``level_start_index``. Prints the ``torch.allclose`` verdict together with
    the maximum absolute and relative errors; returns nothing.
    """
    # Draw order (value, locations, weights) is kept so the seeded random
    # stream is consumed exactly as before.
    feats = torch.rand(N, S, M, D).cuda() * 0.01
    locs = torch.rand(N, Lq, M, L, P, 2).cuda()
    weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
    # Normalise so the weights sum to one over the last two axes.
    weights /= weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
    im2col_step = 2

    feats64 = feats.double()
    locs64 = locs.double()
    weights64 = weights.double()

    reference = ms_deform_attn_core_pytorch(feats64, shapes, locs64, weights64).detach().cpu()
    result = MSDeformAttnFunction.apply(
        feats64, shapes, level_start_index, locs64, weights64, im2col_step
    ).detach().cpu()

    abs_diff = (result - reference).abs()
    fwdok = torch.allclose(result, reference)
    max_abs_err = abs_diff.max()
    max_rel_err = (abs_diff / reference.abs()).max()

    print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
# Maps an encoder name (last component of the defining module's dotted path)
# to the factory function registered under it.
_model_entrypoints = {}


def register_encoder(fn):
    """Decorator: record ``fn`` under the name of the module that defines it.

    The key is the final dotted component of ``fn.__module__``; a later
    registration from a module with the same final component replaces the
    earlier one. Returns ``fn`` unchanged.
    """
    model_name = fn.__module__.rsplit('.', 1)[-1]
    _model_entrypoints[model_name] = fn
    return fn


def model_entrypoints(model_name):
    """Return the factory registered as ``model_name`` (KeyError if absent)."""
    return _model_entrypoints[model_name]


def is_model(model_name):
    """Tell whether a factory has been registered as ``model_name``."""
    return model_name in _model_entrypoints
llava/model/openseed/body/encoder/transformer_encoder_fpn.py ================================================ # Copyright (c) Facebook, Inc. and its affiliates. import logging import numpy as np from typing import Callable, Dict, List, Optional, Tuple, Union import torch from torch import nn from torch.nn import functional as F from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_ from torch.cuda.amp import autocast import fvcore.nn.weight_init as weight_init from detectron2.layers import Conv2d, DeformConv, ShapeSpec, get_norm from .registry import register_encoder from ..transformer_blocks import TransformerEncoder, TransformerEncoderLayer, _get_clones, _get_activation_fn from ...modules import PositionEmbeddingSine from ...utils import configurable # This is a modified FPN decoder. class BasePixelDecoder(nn.Module): def __init__( self, input_shape: Dict[str, ShapeSpec], *, conv_dim: int, mask_dim: int, mask_on: bool, norm: Optional[Union[str, Callable]] = None, ): """ NOTE: this interface is experimental. Args: input_shape: shapes (channels and stride) of the input features conv_dims: number of output channels for the intermediate conv layers. mask_dim: number of output channels for the final conv layer. 
norm (str or callable): normalization for all conv layers """ super().__init__() input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride) self.in_features = [k for k, v in input_shape] # starting from "res2" to "res5" feature_channels = [v.channels for k, v in input_shape] lateral_convs = [] output_convs = [] use_bias = norm == "" for idx, in_channels in enumerate(feature_channels): if idx == len(self.in_features) - 1: output_norm = get_norm(norm, conv_dim) output_conv = Conv2d( in_channels, conv_dim, kernel_size=3, stride=1, padding=1, bias=use_bias, norm=output_norm, activation=F.relu, ) weight_init.c2_xavier_fill(output_conv) self.add_module("layer_{}".format(idx + 1), output_conv) lateral_convs.append(None) output_convs.append(output_conv) else: lateral_norm = get_norm(norm, conv_dim) output_norm = get_norm(norm, conv_dim) lateral_conv = Conv2d( in_channels, conv_dim, kernel_size=1, bias=use_bias, norm=lateral_norm ) output_conv = Conv2d( conv_dim, conv_dim, kernel_size=3, stride=1, padding=1, bias=use_bias, norm=output_norm, activation=F.relu, ) weight_init.c2_xavier_fill(lateral_conv) weight_init.c2_xavier_fill(output_conv) self.add_module("adapter_{}".format(idx + 1), lateral_conv) self.add_module("layer_{}".format(idx + 1), output_conv) lateral_convs.append(lateral_conv) output_convs.append(output_conv) # Place convs into top-down order (from low to high resolution) # to make the top-down computation in forward clearer. 
class TransformerEncoderOnly(nn.Module):
    """Transformer encoder stack applied to a 4-D feature map.

    Wraps a ``TransformerEncoder`` built from ``TransformerEncoderLayer`` and
    takes care of flattening the spatial axes before, and restoring them
    after, the encoder call.
    """

    def __init__(
        self,
        d_model=512,
        nhead=8,
        num_encoder_layers=6,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=False,
    ):
        super().__init__()

        layer = TransformerEncoderLayer(
            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
        )
        # A final LayerNorm is only added in the pre-norm configuration.
        final_norm = None
        if normalize_before:
            final_norm = nn.LayerNorm(d_model)
        self.encoder = TransformerEncoder(layer, num_encoder_layers, final_norm)

        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead

    def _reset_parameters(self):
        """Xavier-initialise every parameter that has more than one dimension."""
        for weight in (p for p in self.parameters() if p.dim() > 1):
            nn.init.xavier_uniform_(weight)

    def forward(self, src, mask, pos_embed):
        """Encode ``src`` and return a tensor with the same 4-D shape.

        Args:
            src: 4-D tensor; its last two axes are flattened and moved to the front.
            mask: optional tensor flattened from axis 1 and passed as
                ``src_key_padding_mask``; may be ``None``.
            pos_embed: tensor treated exactly like ``src`` and passed as ``pos``.
        """
        batch, channels, height, width = src.shape

        # N x C x H x W  ->  HW x N x C
        tokens = src.flatten(2).permute(2, 0, 1)
        pos_tokens = pos_embed.flatten(2).permute(2, 0, 1)
        padding_mask = mask.flatten(1) if mask is not None else mask

        memory = self.encoder(tokens, src_key_padding_mask=padding_mask, pos=pos_tokens)

        # HW x N x C  ->  N x C x H x W
        return memory.permute(1, 2, 0).view(batch, channels, height, width)
norm (str or callable): normalization for all conv layers """ super().__init__(input_shape, conv_dim=conv_dim, mask_dim=mask_dim, norm=norm, mask_on=mask_on) input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride) self.in_features = [k for k, v in input_shape] # starting from "res2" to "res5" feature_strides = [v.stride for k, v in input_shape] feature_channels = [v.channels for k, v in input_shape] in_channels = feature_channels[len(self.in_features) - 1] self.input_proj = Conv2d(in_channels, conv_dim, kernel_size=1) weight_init.c2_xavier_fill(self.input_proj) self.transformer = TransformerEncoderOnly( d_model=conv_dim, dropout=transformer_dropout, nhead=transformer_nheads, dim_feedforward=transformer_dim_feedforward, num_encoder_layers=transformer_enc_layers, normalize_before=transformer_pre_norm, ) N_steps = conv_dim // 2 self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True) # update layer use_bias = norm == "" output_norm = get_norm(norm, conv_dim) output_conv = Conv2d( conv_dim, conv_dim, kernel_size=3, stride=1, padding=1, bias=use_bias, norm=output_norm, activation=F.relu, ) weight_init.c2_xavier_fill(output_conv) delattr(self, "layer_{}".format(len(self.in_features))) self.add_module("layer_{}".format(len(self.in_features)), output_conv) self.output_convs[0] = output_conv @classmethod def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]): enc_cfg = cfg['MODEL']['ENCODER'] dec_cfg = cfg['MODEL']['DECODER'] ret = super().from_config(cfg, input_shape) ret["transformer_dropout"] = dec_cfg['DROPOUT'] ret["transformer_nheads"] = dec_cfg['NHEADS'] ret["transformer_dim_feedforward"] = dec_cfg['DIM_FEEDFORWARD'] ret["transformer_enc_layers"] = enc_cfg['TRANSFORMER_ENC_LAYERS'] # a separate config ret["transformer_pre_norm"] = dec_cfg['PRE_NORM'] ret['mask_on'] = cfg['MODEL']['DECODER']['MASK'] return ret def forward_features(self, features): multi_scale_features = [] num_cur_levels = 0 # Reverse feature maps into top-down order (from low 
@register_encoder
def get_transformer_encoder_fpn(cfg, input_shape):
    """Build a ``TransformerEncoderPixelDecoder`` and check it is usable as a pixel decoder.

    Args:
        cfg: configuration passed straight to ``TransformerEncoderPixelDecoder``.
        input_shape: backbone feature shapes, passed straight through as well.

    Returns:
        The constructed pixel decoder.

    Raises:
        ValueError: if the constructed model has no callable ``forward_features``.
    """
    model = TransformerEncoderPixelDecoder(cfg, input_shape)
    forward_features = getattr(model, "forward_features", None)
    if not callable(forward_features):
        # The message used to interpolate an undefined ``name``, which raised
        # NameError instead of the intended ValueError; report the model's
        # class name instead.
        raise ValueError(
            "Only SEM_SEG_HEADS with forward_features method can be used as pixel decoder. "
            f"Please implement forward_features for {type(model).__name__} to only return mask features."
        )
    return model
class OpenSeeDHead(nn.Module):
    """Segmentation head that chains a pixel decoder with a transformer predictor."""

    @configurable
    def __init__(
        self,
        input_shape: Dict[str, ShapeSpec],
        *,
        num_classes: int,
        pixel_decoder: nn.Module,
        loss_weight: float = 1.0,
        ignore_value: int = -1,
        transformer_predictor: nn.Module,
    ):
        """
        Args:
            input_shape: shapes (channels and stride) of the input features
            num_classes: number of classes to predict
            pixel_decoder: the pixel decoder module
            loss_weight: loss weight
            ignore_value: category id to be ignored during training.
            transformer_predictor: the transformer decoder that makes prediction
        """
        super().__init__()
        # Feature names ordered by increasing stride.
        input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
        self.in_features = [k for k, v in input_shape]
        self.ignore_value = ignore_value
        self.common_stride = 4
        self.loss_weight = loss_weight

        self.pixel_decoder = pixel_decoder
        self.predictor = transformer_predictor

        self.num_classes = num_classes

    @classmethod
    def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec], lang_encoder: nn.Module, extra: dict):
        """Translate ``cfg`` into the keyword arguments of ``__init__``.

        Args:
            cfg: configuration; only ``cfg['MODEL']['ENCODER']`` is read here,
                the full ``cfg`` is also handed to the encoder/decoder builders.
            input_shape: backbone feature shapes; filtered to
                ``ENCODER.IN_FEATURES`` for the head, passed unfiltered to
                ``build_encoder``.
            lang_encoder: accepted for interface compatibility; not used here.
            extra: forwarded to ``build_decoder``.
        """
        enc_cfg = cfg['MODEL']['ENCODER']
        transformer_predictor_in_channels = enc_cfg['CONVS_DIM']

        return {
            "input_shape": {
                k: v for k, v in input_shape.items() if k in enc_cfg['IN_FEATURES']
            },
            "ignore_value": enc_cfg['IGNORE_VALUE'],
            "num_classes": enc_cfg.get('NUM_CLASSES', None),
            "pixel_decoder": build_encoder(cfg, input_shape),
            "loss_weight": enc_cfg['LOSS_WEIGHT'],
            "transformer_predictor": build_decoder(
                cfg,
                transformer_predictor_in_channels,
                mask_classification=True,
                extra=extra,
            ),
        }

    def forward(self, features, mask=None, targets=None, target_queries=None, target_vlp=None, task='seg',
                extra=None, default_text_embeddings=None):
        """Run the pixel decoder, then the predictor, and return the predictions.

        ``features`` and ``mask`` go to ``pixel_decoder.forward_features``,
        which must return ``(mask_features, transformer_encoder_features,
        multi_scale_features)``. The remaining arguments are forwarded to the
        predictor unchanged. ``extra`` defaults to a fresh empty dict per call.
        """
        # A ``{}`` default would be one dict shared by every call; create a
        # new one so nothing written into it can leak between calls.
        if extra is None:
            extra = {}
        mask_features, transformer_encoder_features, multi_scale_features = \
            self.pixel_decoder.forward_features(features, mask)
        predictions = self.predictor(
            multi_scale_features,
            mask_features,
            mask,
            targets=targets,
            target_queries=target_queries,
            target_vlp=target_vlp,
            task=task,
            extra=extra,
            default_text_embeddings=default_text_embeddings,
        )
        return predictions
module_name_split[-1] _model_entrypoints[model_name] = fn return fn def model_entrypoints(model_name): return _model_entrypoints[model_name] def is_model(model_name): return model_name in _model_entrypoints ================================================ FILE: llava/model/openseed/body/transformer_blocks.py ================================================ # Copyright (c) Facebook, Inc. and its affiliates. # Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/transformer.py """ Transformer class. Copy-paste from torch.nn.Transformer with modifications: * positional encodings are passed in MHattention * extra LN at the end of encoder is removed * decoder returns a stack of activations from all decoding layers """ import copy from typing import List, Optional import torch import torch.nn.functional as F from torch import Tensor, nn class Transformer(nn.Module): def __init__( self, d_model=512, nhead=8, num_encoder_layers=6, num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False, return_intermediate_dec=False, ): super().__init__() encoder_layer = TransformerEncoderLayer( d_model, nhead, dim_feedforward, dropout, activation, normalize_before ) encoder_norm = nn.LayerNorm(d_model) if normalize_before else None self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) decoder_layer = TransformerDecoderLayer( d_model, nhead, dim_feedforward, dropout, activation, normalize_before ) decoder_norm = nn.LayerNorm(d_model) self.decoder = TransformerDecoder( decoder_layer, num_decoder_layers, decoder_norm, return_intermediate=return_intermediate_dec, ) self._reset_parameters() self.d_model = d_model self.nhead = nhead def _reset_parameters(self): for p in self.parameters(): if p.dim() > 1: nn.init.xavier_uniform_(p) def forward(self, src, mask, query_embed, pos_embed): # flatten NxCxHxW to HWxNxC bs, c, h, w = src.shape src = src.flatten(2).permute(2, 0, 1) pos_embed = 
class TransformerEncoder(nn.Module):
    """Stack of ``num_layers`` copies of ``encoder_layer`` with an optional final norm."""

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(
        self,
        src,
        mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        """Feed ``src`` through every layer in order, then through ``norm`` if one is set.

        ``mask`` is handed to each layer as ``src_mask``;
        ``src_key_padding_mask`` and ``pos`` are passed under their own names.
        """
        # The same keyword arguments go to every layer.
        layer_kwargs = dict(
            src_mask=mask,
            src_key_padding_mask=src_key_padding_mask,
            pos=pos,
        )

        hidden = src
        for block in self.layers:
            hidden = block(hidden, **layer_kwargs)

        if self.norm is None:
            return hidden
        return self.norm(hidden)
class TransformerEncoderLayer(nn.Module):
    """Self-attention followed by a two-layer feed-forward network, each with a residual.

    ``normalize_before`` selects where the LayerNorms sit: before each
    sub-block (``forward_pre``) or after each residual sum (``forward_post``).
    """

    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=False,
    ):
        super().__init__()
        # Attention sub-block.
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        # Feed-forward sub-block: d_model -> dim_feedforward -> d_model.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        # One norm and one residual dropout per sub-block.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Return ``tensor + pos``, or ``tensor`` itself when ``pos`` is None."""
        if pos is None:
            return tensor
        return tensor + pos

    def forward_post(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        """Post-norm variant: normalise after each residual addition."""
        # Positional embedding is added to queries and keys only, not values.
        query = key = self.with_pos_embed(src, pos)
        attn_out, _ = self.self_attn(
            query, key, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
        )
        src = self.norm1(src + self.dropout1(attn_out))

        hidden = self.activation(self.linear1(src))
        ffn_out = self.linear2(self.dropout(hidden))
        return self.norm2(src + self.dropout2(ffn_out))

    def forward_pre(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        """Pre-norm variant: normalise the input of each sub-block, leave the residual path raw."""
        normed = self.norm1(src)
        # Positional embedding is added to queries and keys only, not values.
        query = key = self.with_pos_embed(normed, pos)
        attn_out, _ = self.self_attn(
            query, key, value=normed, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
        )
        src = src + self.dropout1(attn_out)

        normed = self.norm2(src)
        hidden = self.activation(self.linear1(normed))
        ffn_out = self.linear2(self.dropout(hidden))
        return src + self.dropout2(ffn_out)

    def forward(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        """Dispatch to ``forward_pre`` or ``forward_post`` according to ``normalize_before``."""
        branch = self.forward_pre if self.normalize_before else self.forward_post
        return branch(src, src_mask, src_key_padding_mask, pos)
memory_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None, query_pos: Optional[Tensor] = None, ): tgt2 = self.norm1(tgt) q = k = self.with_pos_embed(tgt2, query_pos) tgt2 = self.self_attn( q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask )[0] tgt = tgt + self.dropout1(tgt2) tgt2 = self.norm2(tgt) tgt2 = self.multihead_attn( query=self.with_pos_embed(tgt2, query_pos), key=self.with_pos_embed(memory, pos), value=memory, attn_mask=memory_mask, key_padding_mask=memory_key_padding_mask, )[0] tgt = tgt + self.dropout2(tgt2) tgt2 = self.norm3(tgt) tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) tgt = tgt + self.dropout3(tgt2) return tgt def forward( self, tgt, memory, tgt_mask: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None, query_pos: Optional[Tensor] = None, ): if self.normalize_before: return self.forward_pre( tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos, ) return self.forward_post( tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos, ) def _get_clones(module, N): return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) def _get_activation_fn(activation): """Return an activation function given a string""" if activation == "relu": return F.relu if activation == "gelu": return F.gelu if activation == "glu": return F.glu raise RuntimeError(f"activation should be relu/gelu, not {activation}.") ================================================ FILE: llava/model/openseed/language/LangEncoder/__init__.py ================================================ from __future__ import absolute_import from __future__ import division from __future__ import print_function from .build import build_lang_encoder from .build import build_tokenizer from .transformer import * 
# Language-encoder factories, keyed by the final component of the dotted
# module path in which each factory is defined.
_lang_encoders = {}


def register_lang_encoder(fn):
    """Decorator that files ``fn`` under its defining module's short name.

    Only the text after the last ``.`` of ``fn.__module__`` is used as the
    key, so two modules with the same short name overwrite one another.
    ``fn`` is returned untouched.
    """
    _, _, encoder_name = fn.__module__.rpartition('.')
    _lang_encoders[encoder_name] = fn
    return fn


def lang_encoders(model_name):
    """Look up the factory stored under ``model_name``; raises KeyError when missing."""
    return _lang_encoders[model_name]


def is_lang_encoder(model_name):
    """Return True when ``model_name`` has a registered factory."""
    return model_name in _lang_encoders
class QuickGELU(nn.Module):
    """Activation computing ``x * sigmoid(1.702 * x)`` element-wise."""

    def forward(self, x: torch.Tensor):
        gate = torch.sigmoid(1.702 * x)
        return x * gate
def build_attention_mask(self):
    """Build the additive causal mask of shape (context_length, context_length).

    Entries strictly above the main diagonal are ``-inf``; the diagonal and
    everything below it are ``0``, so adding the mask to attention scores
    leaves only positions ``j <= i`` visible to query ``i``.
    """
    size = self.context_length
    blocked = torch.full((size, size), float("-inf"))
    # triu(diagonal=1) keeps the strict upper triangle and zeroes the rest.
    return torch.triu(blocked, diagonal=1)
def load_pretrained(self, pretrained='', pretrained_layers=None, verbose=True):
    """Load selected weights from a checkpoint file into this module.

    Nothing happens when ``pretrained`` is not an existing file.

    Args:
        pretrained: path of a checkpoint readable by ``torch.load`` that
            holds a mapping from parameter names to tensors. A leading
            ``'lang_encoder.'`` prefix on a name is stripped before matching.
        pretrained_layers: names of the top-level sub-modules/parameters to
            load (matched against the part of each key before the first
            ``'.'``). If the first entry is ``'*'`` every matching key is
            loaded. ``None`` or an empty sequence selects nothing.
        verbose: log every key that is loaded.

    A 2-D ``positional_embedding`` whose length differs from the model's is
    linearly interpolated to the model's length. If its width differs it is
    skipped, because ``load_state_dict`` would otherwise raise on the shape
    mismatch even with ``strict=False``.
    """
    if not os.path.isfile(pretrained):
        return

    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from trusted sources.
    pretrained_dict = torch.load(pretrained, map_location='cpu')
    logging.info(f'=> loading pretrained model {pretrained}')
    model_dict = self.state_dict()

    def stripped_key(name):
        # 13 == len('lang_encoder.')
        return name[13:] if name.startswith('lang_encoder.') else name

    pretrained_dict = {
        stripped_key(k): v for k, v in pretrained_dict.items()
        if stripped_key(k) in model_dict
    }

    layers = pretrained_layers if pretrained_layers is not None else ()
    # Guard the [0] lookup: an empty selection used to raise IndexError here.
    load_all = len(layers) > 0 and layers[0] == '*'

    need_init_state_dict = {}
    for k, v in pretrained_dict.items():
        if not (load_all or k.split('.')[0] in layers):
            continue

        if verbose:
            logger.info(f'=> init {k} from {pretrained}')

        if 'positional_embedding' in k and v.size() != model_dict[k].size():
            L1, nH1 = v.size()
            L2, nH2 = model_dict[k].size()
            if nH1 != nH2:
                # Width cannot be adapted; really skip the tensor instead of
                # passing a mismatched shape on to load_state_dict.
                logger.info(f"Error in loading {k}, passing")
                continue
            if L1 != L2:
                logger.info(
                    '=> load_pretrained: resized variant: {} to {}'
                    .format((L1, nH1), (L2, nH2))
                )
                # interpolate() works on (batch, channels, length), so the
                # (length, width) table is viewed as (1, width, length).
                posemb = v.float()
                posemb_grid = posemb.unsqueeze(dim=0).permute(0, 2, 1)
                posemb_grid = torch.nn.functional.interpolate(posemb_grid, size=L2, mode='linear')
                posemb_grid = posemb_grid.permute(0, 2, 1).squeeze(dim=0)
                v = posemb_grid

        need_init_state_dict[k] = v

    self.load_state_dict(need_init_state_dict, strict=False)
@register_lang_encoder
def lang_encoder(config_encoder, tokenizer, verbose, **kwargs):
    """Build a ``Transformer`` text encoder from ``config_encoder``.

    The vocabulary size comes from ``tokenizer.vocab_size``. When
    ``LOAD_PRETRAINED`` is set, weights are read from ``PRETRAINED``,
    restricted to ``PRETRAINED_LAYERS`` (default ``['*']``).
    ``verbose`` and ``**kwargs`` are accepted but not used here.
    """
    model_args = dict(
        context_length=config_encoder['CONTEXT_LENGTH'],
        vocab_size=tokenizer.vocab_size,
        width=config_encoder['WIDTH'],
        layers=config_encoder['LAYERS'],
        heads=config_encoder['HEADS'],
        autogressive=config_encoder.get('AUTOGRESSIVE', True),
    )
    model = Transformer(**model_args)

    if not config_encoder.get('LOAD_PRETRAINED', False):
        return model

    checkpoint = config_encoder['PRETRAINED']
    layers = config_encoder.get('PRETRAINED_LAYERS', ['*'])
    model.load_pretrained(checkpoint, layers)
    return model
def get_text_embeddings(self, class_names, name='default', is_eval=False, add_bgd=False, prompt=True, norm=True):
    """Encode ``class_names`` and cache the result on ``self``.

    The embeddings are stored as the attribute ``'{name}_text_embeddings'``;
    nothing is returned.

    Args:
        class_names: sequence of class-name strings. The suffixes
            ``-other``, ``-merged`` and ``-stuff`` are removed before a name
            is inserted into a prompt.
        name: prefix of the attribute the embeddings are stored under.
        is_eval: if False, every text is encoded once (no ``no_grad``
            wrapper). If True, each class is formatted with every template
            from ``get_prompt_templates()``, encoded under
            ``torch.no_grad()``, averaged over templates and L2-normalised.
        add_bgd: append the text ``"A background in coco."`` as an extra
            entry (when ``is_eval`` is False this only happens if ``prompt``
            is True).
        prompt: only read when ``is_eval`` is False -- wrap each class name
            with ``prompt_engineering``; otherwise ``class_names`` is
            tokenized as given.
        norm: forwarded to ``forward_language``.
    """
    # Tokenizer output lives on the CPU. Move it to the device holding the
    # projection weights instead of unconditionally calling .cuda(), which
    # fails on CPU-only hosts and may pick a different GPU than the model's.
    device = self.lang_proj.device

    def _strip_suffixes(label):
        return label.replace('-other', '').replace('-merged', '').replace('-stuff', '')

    if not is_eval:
        if prompt:
            # randomly sample one template
            arbitary_concepts = [
                prompt_engineering(_strip_suffixes(label), topk=10000, suffix='.')
                for label in class_names
            ]
            if add_bgd:
                arbitary_concepts.append("A background in coco.")
        else:
            arbitary_concepts = class_names

        input_ids = []
        attention_masks = []
        for txt in arbitary_concepts:
            tokens = self.tokenizer(
                txt, padding='max_length', truncation=True, max_length=self.max_token_num, return_tensors='pt'
            )
            # Drop the leading batch dimension of size one so that the
            # per-text tensors can be stacked below.
            tokens['input_ids'].squeeze_()
            tokens['attention_mask'].squeeze_()

            input_ids.append(tokens['input_ids'])
            attention_masks.append(tokens['attention_mask'])

        arbitary_tokens = torch.stack(input_ids)
        arbitary_attention_masks = torch.stack(attention_masks)

        text_emb = self.forward_language(
            (arbitary_tokens.to(device), arbitary_attention_masks.to(device)), norm=norm
        )
        setattr(self, '{}_text_embeddings'.format(name), text_emb)
    else:
        with torch.no_grad():
            def extract_mean_emb(txts):
                tokens = self.tokenizer(
                    txts, padding='max_length', truncation=True, max_length=self.max_token_num, return_tensors='pt'
                )
                clss_embedding = self.forward_language(
                    (tokens['input_ids'].to(device), tokens['attention_mask'].to(device)), norm=norm
                )
                # Average over the prompts, then renormalise to unit length.
                clss_embedding = clss_embedding.mean(dim=0)
                clss_embedding /= clss_embedding.norm()
                return clss_embedding

            templates = get_prompt_templates()
            clss_embeddings = []
            for clss in class_names:
                txts = [template.format(_strip_suffixes(clss)) for template in templates]
                clss_embeddings.append(extract_mean_emb(txts))

            if add_bgd:
                txts = ["A background in coco."]
                clss_embeddings.append(extract_mean_emb(txts))

            text_emb = torch.stack(clss_embeddings, dim=0)
            setattr(self, '{}_text_embeddings'.format(name), text_emb)
@classmethod
def from_config(cls, cfg):
    """Translate ``cfg`` into the keyword arguments of ``__init__``.

    Builds the tokenizer and text encoder described by ``cfg['MODEL']['TEXT']``
    and a freshly initialised ``(WIDTH, DIM_PROJ)`` projection parameter.
    """
    text_cfg = cfg['MODEL']['TEXT']

    # text encoder used for segmentation
    tokenizer = build_tokenizer(text_cfg)
    tokenizer_type = text_cfg['TOKENIZER']
    lang_encoder = build_lang_encoder(text_cfg, tokenizer, cfg['VERBOSE'])
    max_token_num = text_cfg['CONTEXT_LENGTH']

    proj_shape = (text_cfg['WIDTH'], cfg['MODEL']['DIM_PROJ'])
    lang_projection = nn.Parameter(torch.empty(*proj_shape))
    trunc_normal_(lang_projection, std=.02)

    return dict(
        tokenizer=tokenizer,
        tokenizer_type=tokenizer_type,
        lang_encoder=lang_encoder,
        lang_projection=lang_projection,
        max_token_num=max_token_num,
        # tested not working better, hence no queue buffers are registered
        queue_operator={},
    )
def get_text_token_embeddings(self, txts, name='default', token=False, norm=False):
    """Compute per-token and per-text embeddings and cache them on ``self``.

    Args:
        txts: the texts to tokenize, or -- when ``token`` is True -- an
            already tokenized mapping providing ``'input_ids'`` and
            ``'attention_mask'``, which is used as given.
        name: the result is also stored as ``'{name}_token_embeddings'``.
        token: whether ``txts`` is already tokenized.
        norm: forwarded to ``forward_language_token``.

    Returns:
        dict with the keys ``"tokens"``, ``"token_emb"`` and ``"class_emb"``.
    """
    if not token:
        tokens = self.tokenizer(
            txts, padding='max_length', truncation=True, max_length=self.max_token_num, return_tensors='pt'
        )
        # Follow the projection weights' device rather than forcing .cuda(),
        # which fails on CPU-only hosts and may differ from the model's GPU.
        device = self.lang_proj.device
        tokens = {key: value.to(device) for key, value in tokens.items()}
    else:
        tokens = txts

    token_emb, class_emb = self.forward_language_token(
        (tokens['input_ids'], tokens['attention_mask']), norm=norm
    )
    ret = {
        "tokens": tokens,
        "token_emb": token_emb,
        "class_emb": class_emb,
    }
    setattr(self, '{}_token_embeddings'.format(name), ret)
    return ret
def multi_head_attention_forward(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    embed_dim_to_check: int,
    num_heads: int,
    in_proj_weight: Tensor,
    in_proj_bias: Tensor,
    bias_k: Optional[Tensor],
    bias_v: Optional[Tensor],
    add_zero_attn: bool,
    dropout_p: float,
    out_proj_weight: Tensor,
    out_proj_bias: Tensor,
    training: bool = True,
    key_padding_mask: Optional[Tensor] = None,
    need_weights: bool = True,
    attn_mask: Optional[Tensor] = None,
    use_separate_proj_weight: bool = False,
    q_proj_weight: Optional[Tensor] = None,
    k_proj_weight: Optional[Tensor] = None,
    v_proj_weight: Optional[Tensor] = None,
    static_k: Optional[Tensor] = None,
    static_v: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
    r"""Multi-head scaled dot-product attention on sequence-first tensors.

    Args:
        query, key, value: map a query and a set of key-value pairs to an output.
            See "Attention Is All You Need" for more details.
        embed_dim_to_check: total dimension of the model.
        num_heads: parallel attention heads.
        in_proj_weight, in_proj_bias: input projection weight and bias.
        bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
        add_zero_attn: add a new batch of zeros to the key and
                       value sequences at dim=1.
        dropout_p: probability of an element to be zeroed.
        out_proj_weight, out_proj_bias: the output projection weight and bias.
        training: apply dropout if is ``True``.
        key_padding_mask: if provided, specified padding elements in the key will
            be ignored by the attention. This is a binary mask. When the value is True,
            the corresponding value on the attention layer will be filled with -inf.
        need_weights: output attn_output_weights.
        attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
            the batches while a 3D mask allows to specify a different mask for the entries of each batch.
        use_separate_proj_weight: the function accept the proj. weights for query, key,
            and value in different forms. If false, in_proj_weight will be used, which is
            a combination of q_proj_weight, k_proj_weight, v_proj_weight.
        q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
        static_k, static_v: static key and value used for attention operators.

    Shape:
        Inputs:
        - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
          the embedding dimension.
        - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length,
          applied to every query and head; or :math:`(N, L, S)` to give each query position its own
          padding mask (shared by all heads).
          If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
          will be unchanged. If a BoolTensor is provided, the positions with the
          value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
        - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
          3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
          S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
          positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
          while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
          are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
          is provided, it will be added to the attention weight.
        - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
          N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
        - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
          N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.

        Outputs:
        - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
          E is the embedding dimension.
        - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
          L is the target sequence length, S is the source sequence length.
          ``None`` when ``need_weights`` is False.
    """
    tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias)
    if has_torch_function(tens_ops):
        return handle_torch_function(
            multi_head_attention_forward,
            tens_ops,
            query,
            key,
            value,
            embed_dim_to_check,
            num_heads,
            in_proj_weight,
            in_proj_bias,
            bias_k,
            bias_v,
            add_zero_attn,
            dropout_p,
            out_proj_weight,
            out_proj_bias,
            training=training,
            key_padding_mask=key_padding_mask,
            need_weights=need_weights,
            attn_mask=attn_mask,
            use_separate_proj_weight=use_separate_proj_weight,
            q_proj_weight=q_proj_weight,
            k_proj_weight=k_proj_weight,
            v_proj_weight=v_proj_weight,
            static_k=static_k,
            static_v=static_v,
        )
    tgt_len, bsz, embed_dim = query.size()
    assert embed_dim == embed_dim_to_check
    # allow MHA to have different sizes for the feature dimension
    assert key.size(0) == value.size(0) and key.size(1) == value.size(1)

    head_dim = embed_dim // num_heads
    assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
    scaling = float(head_dim) ** -0.5

    if not use_separate_proj_weight:
        if (query is key or torch.equal(query, key)) and (key is value or torch.equal(key, value)):
            # self-attention: one fused projection, split into q, k, v
            q, k, v = linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1)

        elif key is value or torch.equal(key, value):
            # encoder-decoder attention
            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = 0
            _end = embed_dim
            _w = in_proj_weight[_start:_end, :]
            if _b is not None:
                _b = _b[_start:_end]
            q = linear(query, _w, _b)

            if key is None:
                assert value is None
                k = None
                v = None
            else:
                # This is inline in_proj function with in_proj_weight and in_proj_bias
                _b = in_proj_bias
                _start = embed_dim
                _end = None
                _w = in_proj_weight[_start:, :]
                if _b is not None:
                    _b = _b[_start:]
                k, v = linear(key, _w, _b).chunk(2, dim=-1)

        else:
            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = 0
            _end = embed_dim
            _w = in_proj_weight[_start:_end, :]
            if _b is not None:
                _b = _b[_start:_end]
            q = linear(query, _w, _b)

            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = embed_dim
            _end = embed_dim * 2
            _w = in_proj_weight[_start:_end, :]
            if _b is not None:
                _b = _b[_start:_end]
            k = linear(key, _w, _b)

            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = embed_dim * 2
            _end = None
            _w = in_proj_weight[_start:, :]
            if _b is not None:
                _b = _b[_start:]
            v = linear(value, _w, _b)
    else:
        q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight)
        len1, len2 = q_proj_weight_non_opt.size()
        assert len1 == embed_dim and len2 == query.size(-1)

        k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight)
        len1, len2 = k_proj_weight_non_opt.size()
        assert len1 == embed_dim and len2 == key.size(-1)

        v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight)
        len1, len2 = v_proj_weight_non_opt.size()
        assert len1 == embed_dim and len2 == value.size(-1)

        if in_proj_bias is not None:
            q = linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim])
            k = linear(key, k_proj_weight_non_opt, in_proj_bias[embed_dim : (embed_dim * 2)])
            v = linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2) :])
        else:
            q = linear(query, q_proj_weight_non_opt, in_proj_bias)
            k = linear(key, k_proj_weight_non_opt, in_proj_bias)
            v = linear(value, v_proj_weight_non_opt, in_proj_bias)
    q = q * scaling

    if attn_mask is not None:
        assert (
            attn_mask.dtype == torch.float32
            or attn_mask.dtype == torch.float64
            or attn_mask.dtype == torch.float16
            or attn_mask.dtype == torch.uint8
            or attn_mask.dtype == torch.bool
        ), "Only float, byte, and bool types are supported for attn_mask, not {}".format(attn_mask.dtype)
        if attn_mask.dtype == torch.uint8:
            warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
            attn_mask = attn_mask.to(torch.bool)

        if attn_mask.dim() == 2:
            attn_mask = attn_mask.unsqueeze(0)
            if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
                raise RuntimeError("The size of the 2D attn_mask is not correct.")
        elif attn_mask.dim() == 3:
            if list(attn_mask.size()) != [bsz * num_heads, query.size(0), key.size(0)]:
                raise RuntimeError("The size of the 3D attn_mask is not correct.")
        else:
            raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim()))
        # attn_mask's dim is 3 now.

    # convert ByteTensor key_padding_mask to bool
    if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
        warnings.warn(
            "Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead."
        )
        key_padding_mask = key_padding_mask.to(torch.bool)

    if bias_k is not None and bias_v is not None:
        if static_k is None and static_v is None:
            k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = pad(attn_mask, (0, 1))
            if key_padding_mask is not None:
                key_padding_mask = pad(key_padding_mask, (0, 1))
        else:
            assert static_k is None, "bias cannot be added to static key."
            assert static_v is None, "bias cannot be added to static value."
    else:
        assert bias_k is None
        assert bias_v is None

    # (len, N, E) -> (N * num_heads, len, head_dim)
    q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    if k is not None:
        k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
    if v is not None:
        v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)

    if static_k is not None:
        assert static_k.size(0) == bsz * num_heads
        assert static_k.size(2) == head_dim
        k = static_k

    if static_v is not None:
        assert static_v.size(0) == bsz * num_heads
        assert static_v.size(2) == head_dim
        v = static_v

    src_len = k.size(1)

    if key_padding_mask is not None:
        # The key axis is always the last one, for both the (N, S) and the
        # (N, L, S) layout; checking dim 1 rejected valid (N, L, S) masks
        # whenever L != S.
        assert key_padding_mask.size(-1) == src_len

    if add_zero_attn:
        src_len += 1
        k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1)
        v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1)
        if attn_mask is not None:
            attn_mask = pad(attn_mask, (0, 1))
        if key_padding_mask is not None:
            key_padding_mask = pad(key_padding_mask, (0, 1))

    attn_output_weights = torch.bmm(q, k.transpose(1, 2))
    assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_output_weights.masked_fill_(attn_mask, float("-inf"))
        else:
            attn_output_weights += attn_mask

    if key_padding_mask is not None:
        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
        if key_padding_mask.dim() == 2:
            # (N, S) -> (N, 1, 1, S): same mask for every head and query.
            # A single unsqueeze would line the batch axis up with the head
            # axis and either fail to broadcast or mask the wrong entries.
            padding = key_padding_mask.unsqueeze(1).unsqueeze(2)
        else:
            # (N, L, S) -> (N, 1, L, S): per-query mask shared by all heads.
            padding = key_padding_mask.unsqueeze(1)
        attn_output_weights = attn_output_weights.masked_fill(padding, float("-inf"))
        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)

    attn_output_weights = softmax(attn_output_weights, dim=-1)
    attn_output_weights = dropout(attn_output_weights, p=dropout_p, training=training)

    attn_output = torch.bmm(attn_output_weights, v)
    assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
    attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
    attn_output = linear(attn_output, out_proj_weight, out_proj_bias)

    if need_weights:
        # average attention weights over heads
        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
        return attn_output, attn_output_weights.sum(dim=1) / num_heads
    else:
        return attn_output, None
Examples:: >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) >>> attn_output, attn_output_weights = multihead_attn(query, key, value) """ bias_k: Optional[torch.Tensor] bias_v: Optional[torch.Tensor] def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None): super(MultiheadAttention, self).__init__() self.embed_dim = embed_dim self.kdim = kdim if kdim is not None else embed_dim self.vdim = vdim if vdim is not None else embed_dim self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim self.num_heads = num_heads self.dropout = dropout self.head_dim = embed_dim // num_heads assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" if self._qkv_same_embed_dim is False: self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim)) self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim)) self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim)) self.register_parameter('in_proj_weight', None) else: self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim)) self.register_parameter('q_proj_weight', None) self.register_parameter('k_proj_weight', None) self.register_parameter('v_proj_weight', None) if bias: self.in_proj_bias = Parameter(torch.empty(3 * embed_dim)) else: self.register_parameter('in_proj_bias', None) self.out_proj = _LinearWithBias(embed_dim, embed_dim) if add_bias_kv: self.bias_k = Parameter(torch.empty(1, 1, embed_dim)) self.bias_v = Parameter(torch.empty(1, 1, embed_dim)) else: self.bias_k = self.bias_v = None self.add_zero_attn = add_zero_attn self._reset_parameters() def _reset_parameters(self): if self._qkv_same_embed_dim: xavier_uniform_(self.in_proj_weight) else: xavier_uniform_(self.q_proj_weight) xavier_uniform_(self.k_proj_weight) xavier_uniform_(self.v_proj_weight) if self.in_proj_bias is not None: constant_(self.in_proj_bias, 0.) constant_(self.out_proj.bias, 0.) 
if self.bias_k is not None: xavier_normal_(self.bias_k) if self.bias_v is not None: xavier_normal_(self.bias_v) def __setstate__(self, state): # Support loading old MultiheadAttention checkpoints generated by v1.1.0 if '_qkv_same_embed_dim' not in state: state['_qkv_same_embed_dim'] = True super(MultiheadAttention, self).__setstate__(state) def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: Optional[Tensor] = None, need_weights: bool = True, attn_mask: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]: r""" Args: query, key, value: map a query and a set of key-value pairs to an output. See "Attention Is All You Need" for more details. key_padding_mask: if provided, specified padding elements in the key will be ignored by the attention. When given a binary mask and a value is True, the corresponding value on the attention layer will be ignored. When given a byte mask and a value is non-zero, the corresponding value on the attention layer will be ignored need_weights: output attn_output_weights. attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all the batches while a 3D mask allows to specify a different mask for the entries of each batch. Shapes for inputs: - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is the embedding dimension. - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is the embedding dimension. - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is the embedding dimension. - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. If a ByteTensor is provided, the non-zero positions will be ignored while the position with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. 
- attn_mask: if a 2D mask: :math:`(L, S)` where L is the target sequence length, S is the source sequence length. If a 3D mask: :math:`(N\cdot\text{num\_heads}, L, S)` where N is the batch size, L is the target sequence length, S is the source sequence length. ``attn_mask`` ensure that position i is allowed to attend the unmasked positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor is provided, it will be added to the attention weight. Shapes for outputs: - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is the embedding dimension. - attn_output_weights: :math:`(N, L, S)` where N is the batch size, L is the target sequence length, S is the source sequence length. """ if not self._qkv_same_embed_dim: return multi_head_attention_forward( query, key, value, self.embed_dim, self.num_heads, self.in_proj_weight, self.in_proj_bias, self.bias_k, self.bias_v, self.add_zero_attn, self.dropout, self.out_proj.weight, self.out_proj.bias, training=self.training, key_padding_mask=key_padding_mask, need_weights=need_weights, attn_mask=attn_mask, use_separate_proj_weight=True, q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight, v_proj_weight=self.v_proj_weight) else: return multi_head_attention_forward( query, key, value, self.embed_dim, self.num_heads, self.in_proj_weight, self.in_proj_bias, self.bias_k, self.bias_v, self.add_zero_attn, self.dropout, self.out_proj.weight, self.out_proj.bias, training=self.training, key_padding_mask=key_padding_mask, need_weights=need_weights, attn_mask=attn_mask) ================================================ FILE: llava/model/openseed/modules/criterion.py ================================================ # 
------------------------------------------------------------------------ # DINO # Copyright (c) 2023 IDEA. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------ # Modified from MaskDINO https://github.com/IDEA-Research/MaskDINO by Hao Zhang and Feng Li. # ------------------------------------------------------------------------ """ MaskFormer criterion. """ import logging import torch import torch.nn.functional as F from torch import nn from timm.loss import SoftTargetCrossEntropy from detectron2.utils.comm import get_world_size from detectron2.projects.point_rend.point_features import ( get_uncertain_point_coords_with_randomness, point_sample, ) from ..utils.misc import is_dist_avail_and_initialized, nested_tensor_from_tensor_list, _max_by_axis from ..utils import box_ops import random def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2): """ Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. Args: inputs: A float tensor of arbitrary shape. The predictions for each example. targets: A float tensor with the same shape as inputs. Stores the binary classification label for each element in inputs (0 for the negative class and 1 for the positive class). alpha: (optional) Weighting factor in range (0,1) to balance positive vs negative examples. Default = -1 (no weighting). gamma: Exponent of the modulating factor (1 - p_t) to balance easy vs hard examples. Returns: Loss tensor """ prob = inputs.sigmoid() ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") p_t = prob * targets + (1 - prob) * (1 - targets) loss = ce_loss * ((1 - p_t) ** gamma) if alpha >= 0: alpha_t = alpha * targets + (1 - alpha) * (1 - targets) loss = alpha_t * loss return loss.mean(-1).mean(-1).sum()*10. 
/ num_boxes def dice_loss( inputs: torch.Tensor, targets: torch.Tensor, num_masks: float, ): """ Compute the DICE loss, similar to generalized IOU for masks Args: inputs: A float tensor of arbitrary shape. The predictions for each example. targets: A float tensor with the same shape as inputs. Stores the binary classification label for each element in inputs (0 for the negative class and 1 for the positive class). """ inputs = inputs.sigmoid() inputs = inputs.flatten(1) numerator = 2 * (inputs * targets).sum(-1) denominator = inputs.sum(-1) + targets.sum(-1) loss = 1 - (numerator + 1) / (denominator + 1) return loss.sum() / num_masks dice_loss_jit = torch.jit.script( dice_loss ) # type: torch.jit.ScriptModule def sigmoid_ce_loss( inputs: torch.Tensor, targets: torch.Tensor, num_masks: float, ): """ Args: inputs: A float tensor of arbitrary shape. The predictions for each example. targets: A float tensor with the same shape as inputs. Stores the binary classification label for each element in inputs (0 for the negative class and 1 for the positive class). Returns: Loss tensor """ loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") return loss.mean(1).sum() / num_masks sigmoid_ce_loss_jit = torch.jit.script( sigmoid_ce_loss ) # type: torch.jit.ScriptModule def calculate_uncertainty(logits): """ We estimate uncerainty as L1 distance between 0.0 and the logit prediction in 'logits' for the foreground class in `classes`. Args: logits (Tensor): A tensor of shape (R, 1, ...) for class-specific or class-agnostic, where R is the total number of predicted masks in all images and C is the number of foreground classes. The values are logits. Returns: scores (Tensor): A tensor of shape (R, 1, ...) that contains uncertainty scores with the most uncertain locations having the highest uncertainty score. 
""" assert logits.shape[1] == 1 gt_class_logits = logits.clone() return -(torch.abs(gt_class_logits)) class SetCriterion(nn.Module): """This class computes the loss for DETR. The process happens in two steps: 1) we compute hungarian assignment between ground truth boxes and the outputs of the model 2) we supervise each pair of matched ground-truth / prediction (supervise class and box) """ def __init__(self, num_classes, matcher, weight_dict, eos_coef, top_x_layers, losses, num_points, oversample_ratio, importance_sample_ratio, grounding_weight, dn="no",dn_losses=[], panoptic_on=False, semantic_ce_loss=False): """Create the criterion. Parameters: num_classes: number of object categories, omitting the special no-object category matcher: module able to compute a matching between targets and proposals weight_dict: dict containing as key the names of the losses and as values their relative weight. eos_coef: relative classification weight applied to the no-object category losses: list of all the losses to be applied. See get_loss for list of available losses. 
""" super().__init__() self.num_classes = num_classes self.matcher = matcher self.weight_dict = weight_dict self.eos_coef = eos_coef self.top_x_layers = top_x_layers self.losses = losses self.dn = dn self.dn_losses = dn_losses empty_weight = torch.ones(self.num_classes + 1) empty_weight[-1] = self.eos_coef self.register_buffer("empty_weight", empty_weight) # pointwise mask loss parameters self.num_points = num_points self.oversample_ratio = oversample_ratio self.importance_sample_ratio = importance_sample_ratio self.focal_alpha = 0.25 self.panoptic_on = panoptic_on self.semantic_ce_loss = semantic_ce_loss self.grounding_weight = grounding_weight self.conversation=False def loss_labels_ce(self, outputs, targets, indices, num_masks, layer_id=None, extra=None): """Classification loss (NLL) targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes] """ if layer_id > self.top_x_layers['mask']: return {"loss_mask_cls_0": 0} assert "pred_logits" in outputs if indices is None or len(targets) == 0: loss_ce = outputs['pred_logits'].sum() * 0.0 losses = {"loss_mask_cls_0": loss_ce} return losses src_logits = outputs["pred_logits"].type(self.empty_weight.dtype) idx = self._get_src_permutation_idx(indices) target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)]) target_classes = torch.full( src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device ) target_classes[idx] = target_classes_o loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight) losses = {"loss_mask_cls_0": loss_ce} return losses def loss_labels_masked(self, outputs, targets, indices, num_boxes, log=True, layer_id=None, extra=None): """Classification loss (Binary focal loss) targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes] """ if layer_id > self.top_x_layers['mask']: return {"loss_mask_cls_0": 0} assert 'pred_logits' in outputs if indices is None or 
len(targets) == 0: loss_ce = outputs['pred_logits'].sum() * 0.0 losses = {"loss_mask_cls_0": loss_ce} return losses src_logits = outputs['pred_logits'] idx = self._get_src_permutation_idx(indices) target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)]) target_classes = torch.full(src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device) target_classes[idx] = target_classes_o target_classes_onehot = torch.zeros([src_logits.shape[0], src_logits.shape[1], src_logits.shape[2]+1], dtype=src_logits.dtype, layout=src_logits.layout, device=src_logits.device) target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1) target_classes_onehot = target_classes_onehot[:,:,:-1] loss_ce = sigmoid_focal_loss(src_logits[idx], target_classes_onehot[idx], num_boxes, alpha=self.focal_alpha, gamma=2) * src_logits.shape[1] losses = {'loss_mask_cls_0': loss_ce} return losses def loss_labels(self, outputs, targets, indices, num_boxes, log=True, layer_id=None, extra=None): """Classification loss (Binary focal loss) targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes] """ if layer_id > self.top_x_layers['mask']: return {"loss_mask_cls_0": 0} assert 'pred_logits' in outputs if indices is None or len(targets) == 0: loss_ce = outputs['pred_logits'].sum() * 0.0 losses = {"loss_mask_cls_0": loss_ce} return losses src_logits = outputs['pred_logits'] idx = self._get_src_permutation_idx(indices) # target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)]) # target_classes = torch.full(src_logits.shape[:2], self.num_classes, # dtype=torch.int64, device=src_logits.device) # target_classes[idx] = target_classes_o target_classes_onehot = torch.zeros([src_logits.shape[0], src_logits.shape[1], src_logits.shape[2]], dtype=src_logits.dtype, layout=src_logits.layout, device=src_logits.device) for batch_id,indices_ in enumerate(indices): for src,tgt in zip(*indices_): 
gt_lbs=targets[batch_id]['labels'][tgt] target_classes_onehot[batch_id,src,gt_lbs]=1 # target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1) # target_classes_onehot = target_classes_onehot[:,:,:-1] loss_ce = sigmoid_focal_loss(src_logits, target_classes_onehot, num_boxes, alpha=self.focal_alpha, gamma=2) * src_logits.shape[1] losses = {'loss_mask_cls_0': loss_ce} return losses def loss_boxes(self, outputs, targets, indices, num_boxes, layer_id=None, extra=None): """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4] The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size. """ if layer_id >= self.top_x_layers['box']: return {"loss_bbox_0": 0, "loss_giou_0": 0} assert 'pred_boxes' in outputs if indices is None or len(targets) == 0: loss = outputs['pred_boxes'].sum() * 0.0 losses = {"loss_bbox_0": loss, "loss_giou_0": loss} return losses idx = self._get_src_permutation_idx(indices) src_boxes = outputs['pred_boxes'][idx] target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0) loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none') losses = {} losses['loss_bbox_0'] = loss_bbox.sum() / num_boxes loss_giou = 1 - torch.diag(box_ops.generalized_box_iou( box_ops.box_cxcywh_to_xyxy(src_boxes), box_ops.box_cxcywh_to_xyxy(target_boxes))) losses['loss_giou_0'] = loss_giou.sum() / num_boxes return losses def loss_boxes_panoptic(self, outputs, targets, indices, num_boxes, layer_id=None, extra=None): """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4] The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size. 
""" if layer_id >= self.top_x_layers['box']: return {"loss_bbox_0": 0, "loss_giou_0": 0} assert 'pred_boxes' in outputs if indices is None or len(targets) == 0: loss = outputs['pred_boxes'].sum() * 0.0 losses = {"loss_bbox_0": loss, "loss_giou_0": loss} return losses idx = self._get_src_permutation_idx(indices) src_boxes = outputs['pred_boxes'][idx] target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0) target_labels = torch.cat([t['labels'][i] for t, (_, i) in zip(targets, indices)], dim=0) isthing=target_labels<80 target_boxes=target_boxes[isthing] src_boxes=src_boxes[isthing] loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none') losses = {} losses['loss_bbox_0'] = loss_bbox.sum() / num_boxes loss_giou = 1 - torch.diag(box_ops.generalized_box_iou( box_ops.box_cxcywh_to_xyxy(src_boxes), box_ops.box_cxcywh_to_xyxy(target_boxes))) losses['loss_giou_0'] = loss_giou.sum() / num_boxes return losses def loss_masks(self, outputs, targets, indices, num_masks, layer_id=None, extra=None): """Compute the losses related to the masks: the focal loss and the dice loss. 
targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w] """ if layer_id >= self.top_x_layers['mask']: return {"loss_mask_bce_0": 0, "loss_mask_dice_0": 0} assert "pred_masks" in outputs if indices is None or len(targets) == 0: loss = outputs['pred_masks'].sum() * 0.0 losses = {"loss_mask_bce_0": loss, "loss_mask_dice_0": loss} return losses src_idx = self._get_src_permutation_idx(indices) tgt_idx = self._get_tgt_permutation_idx(indices) src_masks = outputs["pred_masks"] src_masks = src_masks[src_idx] masks = [t["masks"] for t in targets] # TODO use valid to mask invalid areas due to padding in loss target_masks, valid = nested_tensor_from_tensor_list(masks).decompose() target_masks = target_masks.to(src_masks) target_masks = target_masks[tgt_idx] # No need to upsample predictions as we are using normalized coordinates :) # N x 1 x H x W src_masks = src_masks[:, None] target_masks = target_masks[:, None] with torch.no_grad(): # sample point_coords point_coords = get_uncertain_point_coords_with_randomness( src_masks.float(), lambda logits: calculate_uncertainty(logits.float()), self.num_points, self.oversample_ratio, self.importance_sample_ratio, ) # get gt labels point_labels = point_sample( target_masks.float(), point_coords.float(), align_corners=False, ).squeeze(1) point_logits = point_sample( src_masks.float(), point_coords.float(), align_corners=False, ).squeeze(1) losses = { "loss_mask_bce_0": sigmoid_ce_loss_jit(point_logits, point_labels, num_masks), "loss_mask_dice_0": dice_loss_jit(point_logits, point_labels, num_masks), } del src_masks del target_masks return losses def prep_for_dn(self,mask_dict): output_known_lbs_bboxes = mask_dict['output_known_lbs_bboxes'] known_indice = mask_dict['known_indice'] scalar,pad_size=mask_dict['scalar'],mask_dict['pad_size'] assert pad_size % scalar==0 single_pad=pad_size//scalar num_tgt = known_indice.numel() return output_known_lbs_bboxes,num_tgt,single_pad,scalar def 
_get_src_permutation_idx(self, indices): # permute predictions following indices batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) src_idx = torch.cat([src for (src, _) in indices]) return batch_idx, src_idx def _get_tgt_permutation_idx(self, indices): # permute targets following indices batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]) tgt_idx = torch.cat([tgt for (_, tgt) in indices]) return batch_idx, tgt_idx def get_loss(self, loss, outputs, targets, indices, num_masks=None, layer_id=None, extra=None): loss_map = { 'labels': self.loss_labels_ce if self.semantic_ce_loss else self.loss_labels, 'labels_dn': self.loss_labels_ce if self.semantic_ce_loss else self.loss_labels_masked, 'dn_labels': self.loss_labels_ce if self.semantic_ce_loss else self.loss_labels_masked, 'masks': self.loss_masks, 'boxes': self.loss_boxes_panoptic if self.panoptic_on else self.loss_boxes, } assert loss in loss_map, f"do you really want to compute {loss} loss?" return loss_map[loss](outputs, targets, indices, num_masks, layer_id=layer_id, extra=extra) def forward(self, outputs, targets, mask_dict=None, extra=None, task='seg'): """This performs the loss computation. Parameters: outputs: dict of tensors, see the output specification of the model for the format targets: list of dicts, such that len(targets) == batch_size. 
The expected keys in each dict depends on the losses applied, see each loss' doc """ # TODO: use different matching and loss weight when only detection outputs_without_aux = {k: v for k, v in outputs.items() if k != "aux_outputs"} match_cost = ["cls", "box", "mask"] if task == 'det' or task == 'seg_from_teacher': match_cost = ["cls", "box"] # Retrieve the matching between the outputs of the last layer and the targets if self.dn != "no" and mask_dict is not None: output_known_lbs_bboxes,num_tgt,single_pad,scalar = self.prep_for_dn(mask_dict) exc_idx = [] for i in range(len(targets)): if len(targets[i]['labels']) > 0: t = torch.arange(0, len(targets[i]['labels'])).long().cuda() t = t.unsqueeze(0).repeat(scalar, 1) tgt_idx = t.flatten() output_idx = (torch.tensor(range(scalar)) * single_pad).long().cuda().unsqueeze(1) + t output_idx = output_idx.flatten() else: output_idx = tgt_idx = torch.tensor([]).long().cuda() exc_idx.append((output_idx, tgt_idx)) extra=dict() # if task == "seg": # extra['split_pano']={'n_q_th':300} # # else: # extra['split_pano'] = None indices = self.matcher(outputs_without_aux, targets, match_cost, extra=extra) # Compute the average number of target boxes accross all nodes, for normalization purposes num_masks = sum(len(t["labels"]) for t in targets) num_masks = torch.as_tensor( [num_masks], dtype=torch.float, device=next(iter(outputs.values())).device ) if is_dist_avail_and_initialized() and not self.conversation: torch.distributed.all_reduce(num_masks) num_masks = torch.clamp(num_masks / get_world_size(), min=1).item() else: num_masks = torch.clamp(num_masks, min=1).item() # Compute all the requested losses losses = {} for loss in self.losses: if task == 'det' and loss == 'masks': continue # not compute mask loss for detection data only losses.update(self.get_loss(loss, outputs, targets, indices, num_masks, layer_id=0, extra=extra)) if self.dn != "no" and mask_dict is not None: l_dict={} for loss in self.dn_losses: if task == 'det' and loss 
== 'masks': continue # not compute mask loss for detection data only if loss == 'labels': loss='labels_dn' l_dict.update(self.get_loss(loss, output_known_lbs_bboxes, targets, exc_idx, num_masks*scalar, layer_id=0)) l_dict = {k + f'_dn': v for k, v in l_dict.items()} losses.update(l_dict) elif self.dn != "no": l_dict = dict() l_dict['loss_bbox_0_dn'] = torch.as_tensor(0.).to('cuda') l_dict['loss_giou_0_dn'] = torch.as_tensor(0.).to('cuda') l_dict['loss_mask_cls_0_dn'] = torch.as_tensor(0.).to('cuda') if task != 'det' and 'masks' in self.dn_losses: l_dict['loss_mask_bce_0_dn'] = torch.as_tensor(0.).to('cuda') l_dict['loss_mask_dice_0_dn'] = torch.as_tensor(0.).to('cuda') losses.update(l_dict) # In case of auxiliary losses, we repeat this process with the output of each intermediate layer. if "aux_outputs" in outputs: for i, aux_outputs in enumerate(outputs["aux_outputs"]): indices = self.matcher(aux_outputs, targets, match_cost, extra=extra) for loss in self.losses: if task == 'det' and loss == 'masks': continue # not compute mask loss for detection data only l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_masks, layer_id=(i+1), extra=extra) l_dict = {k.replace('_0', f"_{i+1}"): v for k, v in l_dict.items()} losses.update(l_dict) if 'interm_outputs' in outputs: start = 0 else: start = 1 if i>=start: if self.dn != "no" and mask_dict is not None: out_=output_known_lbs_bboxes['aux_outputs'][i] l_dict = {} for loss in self.dn_losses: if task == 'det' and loss == 'masks': continue # not compute mask loss for detection data only if loss == 'labels': loss = 'labels_dn' l_dict.update( self.get_loss(loss, out_, targets, exc_idx, num_masks * scalar, layer_id=(i+1), extra=extra)) l_dict = {k.replace('_0', f"_{i+1}_dn"): v for k, v in l_dict.items()} losses.update(l_dict) elif self.dn != "no": l_dict = dict() l_dict[f'loss_bbox_{i+1}_dn'] = torch.as_tensor(0.).to('cuda') l_dict[f'loss_giou_{i+1}_dn'] = torch.as_tensor(0.).to('cuda') 
l_dict[f'loss_mask_cls_{i+1}_dn'] = torch.as_tensor(0.).to('cuda') if self.dn == "seg" and task != 'det' and 'masks' in self.dn_losses: l_dict[f'loss_mask_bce_{i+1}_dn'] = torch.as_tensor(0.).to('cuda') l_dict[f'loss_mask_dice_{i+1}_dn'] = torch.as_tensor(0.).to('cuda') losses.update(l_dict) # interm_outputs loss if 'interm_outputs' in outputs: interm_outputs = outputs['interm_outputs'] indices = self.matcher(interm_outputs, targets, match_cost, extra=extra) full_set = ['labels', 'masks', 'boxes'] for loss in list(set(self.losses) and set(full_set)): if task == 'det' and loss == 'masks': continue # not compute mask loss for detection data only l_dict = self.get_loss(loss, interm_outputs, targets, indices, num_masks, layer_id=-1, extra=extra) l_dict = {k + f'_interm': v for k, v in l_dict.items()} losses.update(l_dict) return losses def __repr__(self): head = "Criterion " + self.__class__.__name__ body = [ "matcher: {}".format(self.matcher.__repr__(_repr_indent=8)), "losses: {}".format(self.losses), "weight_dict: {}".format(self.weight_dict), "num_classes: {}".format(self.num_classes), "eos_coef: {}".format(self.eos_coef), "num_points: {}".format(self.num_points), "oversample_ratio: {}".format(self.oversample_ratio), "importance_sample_ratio: {}".format(self.importance_sample_ratio), ] _repr_indent = 4 lines = [head] + [" " * _repr_indent + line for line in body] return "\n".join(lines) ================================================ FILE: llava/model/openseed/modules/matcher.py ================================================ # ------------------------------------------------------------------------ # DINO # Copyright (c) 2023 IDEA. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------ # Modified from MaskDINO https://github.com/IDEA-Research/MaskDINO by Hao Zhang and Feng Li. 
# ------------------------------------------------------------------------ """ Modules to compute the matching cost and solve the corresponding LSAP. """ import torch import torch.nn.functional as F import numpy as np from scipy.optimize import linear_sum_assignment from torch import nn from torch.cuda.amp import autocast from detectron2.projects.point_rend.point_features import point_sample from ..utils.box_ops import generalized_box_iou,box_cxcywh_to_xyxy def batch_dice_loss(inputs: torch.Tensor, targets: torch.Tensor): """ Compute the DICE loss, similar to generalized IOU for masks Args: inputs: A float tensor of arbitrary shape. The predictions for each example. targets: A float tensor with the same shape as inputs. Stores the binary classification label for each element in inputs (0 for the negative class and 1 for the positive class). """ inputs = inputs.sigmoid() inputs = inputs.flatten(1) numerator = 2 * torch.einsum("nc,mc->nm", inputs, targets) denominator = inputs.sum(-1)[:, None] + targets.sum(-1)[None, :] loss = 1 - (numerator + 1) / (denominator + 1) return loss batch_dice_loss_jit = torch.jit.script( batch_dice_loss ) # type: torch.jit.ScriptModule def batch_sigmoid_ce_loss(inputs: torch.Tensor, targets: torch.Tensor): """ Args: inputs: A float tensor of arbitrary shape. The predictions for each example. targets: A float tensor with the same shape as inputs. Stores the binary classification label for each element in inputs (0 for the negative class and 1 for the positive class). 
Returns: Loss tensor """ hw = inputs.shape[1] pos = F.binary_cross_entropy_with_logits( inputs, torch.ones_like(inputs), reduction="none" ) neg = F.binary_cross_entropy_with_logits( inputs, torch.zeros_like(inputs), reduction="none" ) loss = torch.einsum("nc,mc->nm", pos, targets) + torch.einsum( "nc,mc->nm", neg, (1 - targets) ) return loss / hw batch_sigmoid_ce_loss_jit = torch.jit.script( batch_sigmoid_ce_loss ) # type: torch.jit.ScriptModule class HungarianMatcher(nn.Module): """This class computes an assignment between the targets and the predictions of the network For efficiency reasons, the targets don't include the no_object. Because of this, in general, there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, while the others are un-matched (and thus treated as non-objects). """ def __init__(self, cost_class: float = 1, cost_mask: float = 1, cost_dice: float = 1, num_points: int = 0, cost_box: float = 0, cost_giou: float = 0, panoptic_on: bool = False): """Creates the matcher Params: cost_class: This is the relative weight of the classification error in the matching cost cost_mask: This is the relative weight of the focal loss of the binary mask in the matching cost cost_dice: This is the relative weight of the dice loss of the binary mask in the matching cost """ super().__init__() self.cost_class = cost_class self.cost_mask = cost_mask self.cost_dice = cost_dice self.cost_box = cost_box self.cost_giou = cost_giou self.panoptic_on = panoptic_on assert cost_class != 0 or cost_mask != 0 or cost_dice != 0, "all costs cant be 0" self.num_points = num_points @torch.no_grad() def memory_efficient_forward(self, outputs, targets, cost=["cls", "box", "mask"],split_pano=None): """More memory-friendly matching. 
Change cost to compute only certain loss in matching""" bs, num_queries = outputs["pred_logits"].shape[:2] indices = [] # Iterate through batch size for b in range(bs): out_bbox = outputs["pred_boxes"][b].float() if 'box' in cost: tgt_bbox=targets[b]["boxes"] cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1) cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox)) else: cost_bbox = torch.tensor(0).to(out_bbox) cost_giou = torch.tensor(0).to(out_bbox) out_prob = outputs["pred_logits"][b].sigmoid().float() # [num_queries, num_classes] tgt_ids = targets[b]["labels"] cost_class=torch.zeros_like(cost_bbox) alpha = 0.25 gamma = 2.0 neg_cost_class = (1 - alpha) * (out_prob ** gamma) * (-(1 - out_prob + 1e-6).log()) pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-6).log()) for idx, tgt_ids_ in enumerate(tgt_ids): if len(tgt_ids_) == 0: continue cost_class_tmp = pos_cost_class[:, tgt_ids_] - neg_cost_class[:, tgt_ids_] cost_class_tmp = cost_class_tmp.mean(dim=1, keepdim=False) cost_class[:, idx] = cost_class_tmp # Compute the classification cost. Contrary to the loss, we don't use the NLL, # but approximate it in 1 - proba[target class]. # The 1 is a constant that doesn't change the matching, it can be ommitted. # cost_class = -out_prob[:, tgt_ids] if 'mask' in cost: out_mask = outputs["pred_masks"][b].float() # [num_queries, H_pred, W_pred] # gt masks are already padded when preparing target tgt_mask = targets[b]["masks"].to(out_mask).float() out_mask = out_mask[:, None] tgt_mask = tgt_mask[:, None] # all masks share the same set of points for efficient matching! 
point_coords = torch.rand(1, self.num_points, 2, device=out_mask.device, dtype=tgt_mask.dtype) # get gt labels tgt_mask = point_sample( tgt_mask.float(), point_coords.repeat(tgt_mask.shape[0], 1, 1).float(), align_corners=False, ).squeeze(1) out_mask = point_sample( out_mask.float(), point_coords.repeat(out_mask.shape[0], 1, 1).float(), align_corners=False, ).squeeze(1) with autocast(enabled=False): out_mask = out_mask.float() tgt_mask = tgt_mask.float() # Compute the focal loss between masks cost_mask = batch_sigmoid_ce_loss_jit(out_mask, tgt_mask) # Compute the dice loss betwen masks cost_dice = batch_dice_loss_jit(out_mask, tgt_mask) else: cost_mask = torch.tensor(0).to(out_bbox) cost_dice = torch.tensor(0).to(out_bbox) # Final cost matrix if self.panoptic_on: isthing = tgt_ids<80 cost_bbox[:, ~isthing] = cost_bbox[:, isthing].mean() cost_giou[:, ~isthing] = cost_giou[:, isthing].mean() cost_bbox[cost_bbox.isnan()] = 0.0 cost_giou[cost_giou.isnan()] = 0.0 C = ( self.cost_mask * cost_mask + self.cost_class * cost_class + self.cost_dice * cost_dice + self.cost_box*cost_bbox + self.cost_giou*cost_giou ) C = C.reshape(num_queries, -1).cpu() # if split_pano is not None: # n_q_th=split_pano['n_q_th'] # th_mask=tgt_ids<80 # There are 80 COCO thing classes (should be modified when trained with other panoptic datasets) # C[n_q_th:,th_mask]=1e4 # C[:n_q_th,~th_mask]=1e4 indices.append(linear_sum_assignment(C)) return [ (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices ] @torch.no_grad() def forward(self, outputs, targets, cost=["cls", "box", "mask"], mode='default', extra={}): """Performs the matching Params: outputs: This is a dict that contains at least these entries: "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits "pred_masks": Tensor of dim [batch_size, num_queries, H_pred, W_pred] with the predicted masks targets: This is a list of targets (len(targets) = batch_size), 
where each target is a dict containing: "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth objects in the target) containing the class labels "masks": Tensor of dim [num_target_boxes, H_gt, W_gt] containing the target masks Returns: A list of size batch_size, containing tuples of (index_i, index_j) where: - index_i is the indices of the selected predictions (in order) - index_j is the indices of the corresponding selected targets (in order) For each batch element, it holds: len(index_i) = len(index_j) = min(num_queries, num_target_boxes) """ if mode == 'default': if extra is not None: split_pano = extra.get('split_pano', None) else: split_pano=None return self.memory_efficient_forward(outputs, targets, cost,split_pano=split_pano) else: assert False, "Mode {} is not supported.".format(mode) def __repr__(self, _repr_indent=4): head = "Matcher " + self.__class__.__name__ body = [ "cost_class: {}".format(self.cost_class), "cost_mask: {}".format(self.cost_mask), "cost_dice: {}".format(self.cost_dice), ] lines = [head] + [" " * _repr_indent + line for line in body] return "\n".join(lines) ================================================ FILE: llava/model/openseed/modules/point_features.py ================================================ # Copyright (c) Facebook, Inc. and its affiliates. import torch from torch.nn import functional as F from detectron2.layers import cat, shapes_to_tensor from detectron2.structures import BitMasks, Boxes # from ..layers import cat, shapes_to_tensor # from ..structures import BitMasks, Boxes """ Shape shorthand in this module: N: minibatch dimension size, i.e. the number of RoIs for instance segmenation or the number of images for semantic segmenation. R: number of ROIs, combined over all images, in the minibatch P: number of points """ def point_sample(input, point_coords, **kwargs): """ A wrapper around :function:`torch.nn.functional.grid_sample` to support 3D point_coords tensors. 
def point_sample(input, point_coords, **kwargs):
    """
    A wrapper around :function:`torch.nn.functional.grid_sample` to support 3D point_coords tensors.
    Unlike :function:`torch.nn.functional.grid_sample` it assumes `point_coords` to lie inside
    [0, 1] x [0, 1] square.

    Args:
        input (Tensor): A tensor of shape (N, C, H, W) that contains features map on a H x W grid.
        point_coords (Tensor): A tensor of shape (N, P, 2) or (N, Hgrid, Wgrid, 2) that contains
            [0, 1] x [0, 1] normalized point coordinates.
        **kwargs: forwarded unchanged to ``grid_sample`` (e.g. ``align_corners``).

    Returns:
        output (Tensor): A tensor of shape (N, C, P) or (N, C, Hgrid, Wgrid) that contains
            features for points in `point_coords`. The features are obtained via bilinear
            interpolation from `input` the same way as :function:`torch.nn.functional.grid_sample`.
    """
    # grid_sample needs a 4D grid; a (N, P, 2) input is lifted to (N, P, 1, 2)
    # and the dummy axis is removed again from the output.
    squeeze_back = point_coords.dim() == 3
    if squeeze_back:
        point_coords = point_coords.unsqueeze(2)
    # Map [0, 1] coordinates to the [-1, 1] range grid_sample expects.
    sampled = F.grid_sample(input, 2.0 * point_coords - 1.0, **kwargs)
    return sampled.squeeze(3) if squeeze_back else sampled


def generate_regular_grid_point_coords(R, side_size, device):
    """
    Generate regular square grid of points in [0, 1] x [0, 1] coordinate space.

    Args:
        R (int): The number of grids to sample, one for each region.
        side_size (int): The side size of the regular grid.
        device (torch.device): Desired device of returned tensor.

    Returns:
        (Tensor): A tensor of shape (R, side_size^2, 2) that contains coordinates
            for the regular grids.
    """
    # Affine map x -> 0.5 * x + 0.5 takes affine_grid's [-1, 1] range to [0, 1].
    aff = torch.tensor([[[0.5, 0, 0.5], [0, 0.5, 0.5]]], device=device)
    grid = F.affine_grid(aff, torch.Size((1, 1, side_size, side_size)), align_corners=False)
    return grid.view(1, -1, 2).expand(R, -1, -1)


def get_uncertain_point_coords_with_randomness(
    coarse_logits, uncertainty_func, num_points, oversample_ratio, importance_sample_ratio
):
    """
    Sample points in [0, 1] x [0, 1] coordinate space based on their uncertainty. The
    uncertainties are calculated for each point using 'uncertainty_func' function that takes
    point's logit prediction as input.
    See PointRend paper for details.

    Args:
        coarse_logits (Tensor): A tensor of shape (N, C, Hmask, Wmask) or (N, 1, Hmask, Wmask) for
            class-specific or class-agnostic prediction.
        uncertainty_func: A function that takes a Tensor of shape (N, C, P) or (N, 1, P) that
            contains logit predictions for P points and returns their uncertainties as a Tensor of
            shape (N, 1, P).
        num_points (int): The number of points P to sample.
        oversample_ratio (int): Oversampling parameter.
        importance_sample_ratio (float): Ratio of points that are sampled via importance sampling.

    Returns:
        point_coords (Tensor): A tensor of shape (N, P, 2) that contains the coordinates of P
            sampled points, in the dtype and on the device of `coarse_logits`.
    """
    assert oversample_ratio >= 1
    assert importance_sample_ratio <= 1 and importance_sample_ratio >= 0
    device = coarse_logits.device
    dtype = coarse_logits.dtype
    num_boxes = coarse_logits.shape[0]
    num_sampled = int(num_points * oversample_ratio)
    point_coords = torch.rand(num_boxes, num_sampled, 2, device=device, dtype=dtype)
    point_logits = point_sample(coarse_logits, point_coords, align_corners=False)
    # It is crucial to calculate uncertainty based on the sampled prediction value for the points.
    # Calculating uncertainties of the coarse predictions first and sampling them for points leads
    # to incorrect results.
    # To illustrate this: assume uncertainty_func(logits)=-abs(logits), a sampled point between
    # two coarse predictions with -1 and 1 logits has 0 logits, and therefore 0 uncertainty value.
    # However, if we calculate uncertainties for the coarse predictions first,
    # both will have -1 uncertainty, and the sampled point will get -1 uncertainty.
    point_uncertainties = uncertainty_func(point_logits)
    num_uncertain_points = int(importance_sample_ratio * num_points)
    num_random_points = num_points - num_uncertain_points
    idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
    # Offset per-box indices so they address the flattened (N * num_sampled, 2) view.
    shift = num_sampled * torch.arange(num_boxes, dtype=torch.long, device=device)
    idx += shift[:, None]
    point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view(
        num_boxes, num_uncertain_points, 2
    )
    if num_random_points > 0:
        # Same dtype as the importance-sampled points: without it the random
        # points default to float32 and the concatenated result no longer
        # matches reduced-precision logits.
        random_coords = torch.rand(
            num_boxes, num_random_points, 2, device=device, dtype=dtype
        )
        point_coords = torch.cat([point_coords, random_coords], dim=1)
    return point_coords
def get_uncertain_point_coords_on_grid(uncertainty_map, num_points):
    """
    Find `num_points` most uncertain points from `uncertainty_map` grid.

    Args:
        uncertainty_map (Tensor): A tensor of shape (N, 1, H, W) that contains uncertainty
            values for a set of points on a regular H x W grid.
        num_points (int): The number of points P to select; capped at H * W.

    Returns:
        point_indices (Tensor): A tensor of shape (N, P) that contains indices from
            [0, H x W) of the most uncertain points.
        point_coords (Tensor): A tensor of shape (N, P, 2) that contains [0, 1] x [0, 1] normalized
            coordinates of the most uncertain points from the H x W grid.
    """
    num_rois, _, grid_h, grid_w = uncertainty_map.shape
    step_y = 1.0 / float(grid_h)
    step_x = 1.0 / float(grid_w)

    num_points = min(grid_h * grid_w, num_points)
    flat_uncertainty = uncertainty_map.view(num_rois, grid_h * grid_w)
    point_indices = torch.topk(flat_uncertainty, k=num_points, dim=1)[1]

    # Flat index -> (column, row); coordinates are taken at cell centres.
    cols = (point_indices % grid_w).to(torch.float)
    rows = (point_indices // grid_w).to(torch.float)
    point_coords = torch.zeros(
        num_rois, num_points, 2, dtype=torch.float, device=uncertainty_map.device
    )
    point_coords[:, :, 0] = step_x / 2.0 + cols * step_x
    point_coords[:, :, 1] = step_y / 2.0 + rows * step_y
    return point_indices, point_coords


def point_sample_fine_grained_features(features_list, feature_scales, boxes, point_coords):
    """
    Get features from feature maps in `features_list` that correspond to specific point coordinates
    inside each bounding box from `boxes`.

    Args:
        features_list (list[Tensor]): A list of feature map tensors to get features from.
        feature_scales (list[float]): A list of scales for tensors in `features_list`.
        boxes (list[Boxes]): A list of I Boxes objects that contain R_1 + ... + R_I = R boxes all
            together.
        point_coords (Tensor): A tensor of shape (R, P, 2) that contains
            [0, 1] x [0, 1] box-normalized coordinates of the P sampled points.

    Returns:
        point_features (Tensor): A tensor of shape (R, C, P) that contains features sampled
            from all features maps in feature_list for P sampled points for all R boxes in `boxes`.
        point_coords_wrt_image (Tensor): A tensor of shape (R, P, 2) that contains image-level
            coordinates of P points.
    """
    boxes_per_image = [b.tensor.size(0) for b in boxes]
    all_boxes = Boxes.cat(boxes)
    point_coords_wrt_image = get_point_coords_wrt_image(all_boxes.tensor, point_coords)
    coords_by_image = torch.split(point_coords_wrt_image, boxes_per_image)

    point_features = []
    for img_idx, image_coords in enumerate(coords_by_image):
        per_level = []
        for level_idx, feature_map in enumerate(features_list):
            h, w = feature_map.shape[-2:]
            # Divisor that converts image-level coordinates to [0, 1] on this level.
            scale = shapes_to_tensor([w, h]) / feature_scales[level_idx]
            scaled_coords = image_coords / scale.to(feature_map.device)
            sampled = point_sample(
                feature_map[img_idx].unsqueeze(0),
                scaled_coords.unsqueeze(0),
                align_corners=False,
            )
            per_level.append(sampled.squeeze(0).transpose(1, 0))
        point_features.append(cat(per_level, dim=1))

    return cat(point_features, dim=0), point_coords_wrt_image
def get_point_coords_wrt_image(boxes_coords, point_coords):
    """
    Convert box-normalized [0, 1] x [0, 1] point coordinates to image-level coordinates.

    Args:
        boxes_coords (Tensor): A tensor of shape (R, 4) that contains bounding boxes
            coordinates; columns 0/1 are used as the box origin and columns 2/3
            as the opposite corner.
        point_coords (Tensor): A tensor of shape (R, P, 2) that contains
            [0, 1] x [0, 1] box-normalized coordinates of the P sampled points.

    Returns:
        point_coords_wrt_image (Tensor): A tensor of shape (R, P, 2) that contains
            image-level coordinates of P sampled points (same units as
            `boxes_coords`). The input `point_coords` is not modified.
    """
    with torch.no_grad():
        x0 = boxes_coords[:, None, 0]
        y0 = boxes_coords[:, None, 1]
        box_w = boxes_coords[:, None, 2] - x0
        box_h = boxes_coords[:, None, 3] - y0

        # Work on a copy: scale by the box size first, then shift by its origin.
        image_coords = point_coords.clone()
        image_coords[:, :, 0] = image_coords[:, :, 0] * box_w
        image_coords[:, :, 1] = image_coords[:, :, 1] * box_h
        image_coords[:, :, 0] += x0
        image_coords[:, :, 1] += y0
    return image_coords
def sample_point_labels(instances, point_coords):
    """
    Sample point labels from ground truth mask given point_coords.

    Args:
        instances (list[Instances]): A list of N Instances, where N is the number of images
            in the batch. So, i_th element of the list contains R_i objects and R_1 + ... + R_N is
            equal to R. The ground-truth gt_masks in each instance will be used to compute labels.
        point_coords (Tensor): A tensor of shape (R, P, 2), where R is the total number of
            instances and P is the number of points for each instance. The coordinates are in
            the absolute image pixel coordinate space, i.e. [0, H] x [0, W].

    Returns:
        Tensor: A tensor of shape (R, P) that contains the labels of P sampled points.
    """
    with torch.no_grad():
        counts = [len(per_image) for per_image in instances]
        coords_by_image = torch.split(point_coords, counts)

        sampled_labels = []
        for img_idx, per_image in enumerate(instances):
            # Images without instances contribute nothing to the output.
            if len(per_image) == 0:
                continue
            assert isinstance(
                per_image.gt_masks, BitMasks
            ), "Point head works with GT in 'bitmask' format. Set INPUT.MASK_FORMAT to 'bitmask'."

            bit_masks = per_image.gt_masks.tensor
            h, w = per_image.gt_masks.image_size
            # Pixel coordinates -> [0, 1] coordinates expected by point_sample.
            scale = torch.tensor([w, h], dtype=torch.float, device=bit_masks.device)
            normalized_coords = coords_by_image[img_idx] / scale
            labels = point_sample(
                bit_masks.to(torch.float32).unsqueeze(1),
                normalized_coords,
                align_corners=False,
            )
            sampled_labels.append(labels.squeeze(1))

    point_labels = cat(sampled_labels)
    return point_labels
class PositionEmbeddingSine(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the Attention is all you need paper, generalized to work on images.

    The output has ``2 * num_pos_feats`` channels: the first half encodes the
    row (y) position, the second half the column (x) position.
    """

    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        self.scale = 2 * math.pi if scale is None else scale

    def forward(self, x, mask=None):
        """Return the positional encoding for `x`.

        Args:
            x: tensor indexed as (batch, channels, height, width); only its
                size, dtype and device are used.
            mask: optional boolean tensor (batch, height, width), True marking
                positions to exclude from the cumulative position count.
                Defaults to an all-False mask.

        Returns:
            Tensor of shape (batch, 2 * num_pos_feats, height, width).
        """
        if mask is None:
            mask = torch.zeros(
                (x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool
            )
        valid = ~mask
        # 1-based running position along rows / columns, counting valid cells only.
        y_embed = valid.cumsum(1, dtype=x.dtype)
        x_embed = valid.cumsum(2, dtype=x.dtype)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=x.dtype, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        def _interleave_sin_cos(embed):
            # Even feature slots get sin, odd slots get cos, interleaved.
            scaled = embed[:, :, :, None] / dim_t
            paired = torch.stack(
                (scaled[:, :, :, 0::2].sin(), scaled[:, :, :, 1::2].cos()), dim=4
            )
            return paired.flatten(3)

        pos_x = _interleave_sin_cos(x_embed)
        pos_y = _interleave_sin_cos(y_embed)
        return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)

    def __repr__(self, _repr_indent=4):
        pad = " " * _repr_indent
        body = [
            "num_pos_feats: {}".format(self.num_pos_feats),
            "temperature: {}".format(self.temperature),
            "normalize: {}".format(self.normalize),
            "scale: {}".format(self.scale),
        ]
        lines = ["Positional encoding " + self.__class__.__name__]
        lines.extend(pad + entry for entry in body)
        return "\n".join(lines)
def detector_postprocess(
    results: Instances, output_height: int, output_width: int, mask_threshold: float = 0.5
):
    """
    Resize the output instances.
    The input images are often resized when entering an object detector.
    As a result, we often need the outputs of the detector in a different
    resolution from its inputs.

    This function will resize the raw outputs of an R-CNN detector
    to produce outputs according to the desired output resolution.

    Args:
        results (Instances): the raw outputs from the detector.
            `results.image_size` contains the input image resolution the detector sees.
            This object might be modified in-place.
        output_height, output_width: the desired output resolution.
        mask_threshold: threshold passed to `ROIMasks.to_bitmasks` when
            `pred_masks` are converted.

    Returns:
        Instances: the resized output from the model, based on the output resolution
    """
    if isinstance(output_width, torch.Tensor):
        # This shape might (but not necessarily) be tensors during tracing.
        # Converts integer tensors to float temporaries to ensure true
        # division is performed when computing scale_x and scale_y.
        width_for_scale = output_width.float()
        height_for_scale = output_height.float()
        new_size = torch.stack([output_height, output_width])
    else:
        new_size = (output_height, output_width)
        width_for_scale = output_width
        height_for_scale = output_height

    input_height, input_width = results.image_size[0], results.image_size[1]
    scale_x = width_for_scale / input_width
    scale_y = height_for_scale / input_height
    results = Instances(new_size, **results.get_fields())

    if results.has("pred_boxes"):
        output_boxes = results.pred_boxes
    elif results.has("proposal_boxes"):
        output_boxes = results.proposal_boxes
    else:
        output_boxes = None
    assert output_boxes is not None, "Predictions must contain boxes!"

    output_boxes.scale(scale_x, scale_y)
    output_boxes.clip(results.image_size)

    # Drop instances whose box became empty after scaling and clipping.
    results = results[output_boxes.nonempty()]

    if results.has("pred_masks"):
        if isinstance(results.pred_masks, ROIMasks):
            roi_masks = results.pred_masks
        else:
            # pred_masks is a tensor of shape (N, 1, M, M)
            roi_masks = ROIMasks(results.pred_masks[:, 0, :, :])
        bitmasks = roi_masks.to_bitmasks(
            results.pred_boxes, output_height, output_width, mask_threshold
        )
        results.pred_masks = bitmasks.tensor  # TODO return ROIMasks/BitMask object in the future

    if results.has("pred_keypoints"):
        results.pred_keypoints[:, :, 0] *= scale_x
        results.pred_keypoints[:, :, 1] *= scale_y

    return results
def bbox_postprocess(result, input_size, img_size, output_height, output_width):
    """
    Convert raw box predictions to pixel boxes at the output resolution.

    result: [xc,yc,w,h] logits; after sigmoid they are in range [0,1] and are
        converted to [x1,y1,x2,y2] in range [0,w], [0,h]. ``None`` is passed through.
    input_size: (height, width) used to scale the sigmoid outputs.
    img_size: (height, width) the boxes are clamped to.
    output_height, output_width: resolution the clamped boxes are rescaled to.
    """
    if result is None:
        return None

    in_h, in_w = input_size[0], input_size[1]
    to_pixels = torch.tensor([in_w, in_h, in_w, in_h])[None, :].to(result.device)
    result = result.sigmoid() * to_pixels

    half_w = result[:, 2] / 2
    half_h = result[:, 3] / 2
    h, w = img_size
    x1 = (result[:, 0] - half_w).clamp(min=0, max=w)
    y1 = (result[:, 1] - half_h).clamp(min=0, max=h)
    x2 = (result[:, 0] + half_w).clamp(min=0, max=w)
    y2 = (result[:, 1] + half_h).clamp(min=0, max=h)

    box = torch.stack([x1, y1, x2, y2]).permute(1, 0)
    ratio_w = output_width / w
    ratio_h = output_height / h
    to_output = torch.tensor([ratio_w, ratio_h, ratio_w, ratio_h])[None, :].to(result.device)
    return box * to_output


def sem_seg_postprocess(result, img_size, output_height, output_width):
    """
    Return semantic segmentation predictions in the original resolution.

    The input images are often resized when entering semantic segmentor. Moreover, in same
    cases, they also padded inside segmentor to be divisible by maximum network stride.
    As a result, we often need the predictions of the segmentor in a different
    resolution from its inputs.

    Args:
        result (Tensor): semantic segmentation prediction logits. A tensor of shape (C, H, W),
            where C is the number of classes, and H, W are the height and width of the prediction.
        img_size (tuple): image size that segmentor is taking as input.
        output_height, output_width: the desired output resolution.

    Returns:
        semantic segmentation prediction (Tensor): A tensor of the shape
            (C, output_height, output_width) that contains per-pixel soft predictions.
    """
    # Crop away padding, then add a batch axis for interpolate.
    cropped = result[:, : img_size[0], : img_size[1]].expand(1, -1, -1, -1)
    resized = F.interpolate(
        cropped,
        size=(output_height, output_width),
        mode="bicubic",
        align_corners=False,
        antialias=True,
    )
    return resized[0]
def box_cxcywh_to_xyxy(x):
    """Convert boxes from (center_x, center_y, width, height) to (x0, y0, x1, y1)."""
    x_c, y_c, w, h = x.unbind(-1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=-1)


def box_xyxy_to_cxcywh(x):
    """Convert boxes from (x0, y0, x1, y1) to (center_x, center_y, width, height)."""
    x0, y0, x1, y1 = x.unbind(-1)
    b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)]
    return torch.stack(b, dim=-1)


def box_xywh_to_xyxy(x):
    """Convert boxes from (x0, y0, width, height) to (x0, y0, x1, y1)."""
    x0, y0, w, h = x.unbind(-1)
    b = [x0, y0, (x0 + w), (y0 + h)]
    return torch.stack(b, dim=-1)


# modified from torchvision to also return the union
def box_iou(boxes1, boxes2):
    """Return the pairwise IoU and union area, each of shape [N, M], for xyxy boxes."""
    area1 = box_area(boxes1)
    area2 = box_area(boxes2)

    lt = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
    rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]

    wh = (rb - lt).clamp(min=0)  # [N,M,2]
    inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]

    union = area1[:, None] + area2 - inter

    # Small epsilon keeps the ratio finite when the union is zero.
    iou = inter / (union + 1e-6)
    return iou, union


def generalized_box_iou(boxes1, boxes2):
    """
    Generalized IoU from https://giou.stanford.edu/

    The boxes should be in [x0, y0, x1, y1] format

    Returns a [N, M] pairwise matrix, where N = len(boxes1)
    and M = len(boxes2)
    """
    # degenerate boxes gives inf / nan results
    # so do an early check
    assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
    assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
    iou, union = box_iou(boxes1, boxes2)

    # Smallest box enclosing each pair.
    lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
    rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])

    wh = (rb - lt).clamp(min=0)  # [N,M,2]
    area = wh[:, :, 0] * wh[:, :, 1]

    return iou - (area - union) / (area + 1e-6)


def masks_to_boxes(masks):
    """Compute the bounding boxes around the provided masks

    The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.

    Returns a [N, 4] tensors, with the boxes in xyxy format
    """
    if masks.numel() == 0:
        return torch.zeros((0, 4), device=masks.device)

    h, w = masks.shape[-2:]

    # Build the coordinate grids on the masks' device; CPU grids cannot be
    # multiplied with masks that live on an accelerator.
    y = torch.arange(0, h, dtype=torch.float, device=masks.device)
    x = torch.arange(0, w, dtype=torch.float, device=masks.device)
    # Explicit "ij" keeps the row-major behaviour and avoids the deprecation warning.
    y, x = torch.meshgrid(y, x, indexing="ij")

    foreground = masks.bool()

    x_mask = (masks * x.unsqueeze(0))
    x_max = x_mask.flatten(1).max(-1)[0]
    # Background pixels are pushed to a huge value so they never win the min.
    x_min = x_mask.masked_fill(~foreground, 1e8).flatten(1).min(-1)[0]

    y_mask = (masks * y.unsqueeze(0))
    y_max = y_mask.flatten(1).max(-1)[0]
    y_min = y_mask.masked_fill(~foreground, 1e8).flatten(1).min(-1)[0]

    return torch.stack([x_min, y_min, x_max, y_max], 1)
def configurable(init_func=None, *, from_config=None):
    """
    Decorate a function or a class's __init__ method so that it can be called
    with a config object using a :func:`from_config` function that translates
    the config to arguments.

    Examples:
    ::
        # Usage 1: Decorator on __init__:
        class A:
            @configurable
            def __init__(self, a, b=2, c=3):
                pass

            @classmethod
            def from_config(cls, cfg):   # 'cfg' must be the first argument
                # Returns kwargs to be passed to __init__
                return {"a": cfg.A, "b": cfg.B}

        a1 = A(a=1, b=2)  # regular construction
        a2 = A(cfg)       # construct with a cfg
        a3 = A(cfg, b=3, c=4)  # construct with extra overwrite

        # Usage 2: Decorator on any function. Needs an extra from_config argument:
        @configurable(from_config=lambda cfg: {"a": cfg.A, "b": cfg.B})
        def a_func(a, b=2, c=3):
            pass

        a1 = a_func(a=1, b=2)  # regular call
        a2 = a_func(cfg)       # call with a cfg
        a3 = a_func(cfg, b=3, c=4)  # call with extra overwrite

    Args:
        init_func (callable): a class's ``__init__`` method in usage 1. The
            class must have a ``from_config`` classmethod which takes `cfg` as
            the first argument.
        from_config (callable): the from_config function in usage 2. It must take `cfg`
            as its first argument.
    """
    if init_func is not None:
        assert (
            inspect.isfunction(init_func)
            and from_config is None
            and init_func.__name__ == "__init__"
        ), "Incorrect use of @configurable. Check API documentation for examples."

        @functools.wraps(init_func)
        def wrapped(self, *args, **kwargs):
            try:
                from_config_func = type(self).from_config
            except AttributeError as e:
                raise AttributeError(
                    "Class with @configurable must have a 'from_config' classmethod."
                ) from e
            if not inspect.ismethod(from_config_func):
                raise TypeError("Class with @configurable must have a 'from_config' classmethod.")

            if _called_with_cfg(*args, **kwargs):
                explicit_args = _get_args_from_config(from_config_func, *args, **kwargs)
                init_func(self, **explicit_args)
            else:
                init_func(self, *args, **kwargs)

        return wrapped

    if from_config is None:
        return configurable  # @configurable() is made equivalent to @configurable
    assert inspect.isfunction(
        from_config
    ), "from_config argument of configurable must be a function!"

    def wrapper(orig_func):
        @functools.wraps(orig_func)
        def wrapped(*args, **kwargs):
            if _called_with_cfg(*args, **kwargs):
                explicit_args = _get_args_from_config(from_config, *args, **kwargs)
                return orig_func(**explicit_args)
            return orig_func(*args, **kwargs)

        wrapped.from_config = from_config
        return wrapped

    return wrapper


def _called_with_cfg(*args, **kwargs):
    """
    Returns:
        bool: whether the arguments contain a config object (a ``dict`` or, when
            omegaconf is installed, a ``DictConfig``) and should be considered
            forwarded to from_config. The config may be the first positional
            argument or the ``cfg`` keyword argument.
    """
    cfg_types = (dict,)
    try:
        from omegaconf import DictConfig
    except ImportError:
        # Plain-dict configs keep working when omegaconf is not installed.
        pass
    else:
        cfg_types = (dict, DictConfig)

    if args and isinstance(args[0], cfg_types):
        return True
    # `from_config`'s first argument is forced to be "cfg".
    # So the positional check and this keyword check cover all cases.
    return isinstance(kwargs.get("cfg"), cfg_types)


def _get_args_from_config(from_config_func, *args, **kwargs):
    """
    Use `from_config` to obtain explicit arguments.

    Keyword arguments that `from_config` does not accept are not passed to it;
    they are merged into (and override) its result instead.

    Returns:
        dict: arguments to be used for cls.__init__

    Raises:
        TypeError: if the first parameter of `from_config` is not named ``cfg``.
    """
    signature = inspect.signature(from_config_func)
    if list(signature.parameters.keys())[0] != "cfg":
        if inspect.isfunction(from_config_func):
            name = from_config_func.__name__
        else:
            name = f"{from_config_func.__self__}.from_config"
        raise TypeError(f"{name} must take 'cfg' as the first argument!")

    support_var_arg = any(
        param.kind in [param.VAR_POSITIONAL, param.VAR_KEYWORD]
        for param in signature.parameters.values()
    )
    if support_var_arg:  # forward all arguments to from_config, if from_config accepts them
        return from_config_func(*args, **kwargs)

    # forward supported arguments to from_config
    supported_arg_names = set(signature.parameters.keys())
    extra_kwargs = {}
    for name in list(kwargs.keys()):
        if name not in supported_arg_names:
            extra_kwargs[name] = kwargs.pop(name)
    ret = from_config_func(*args, **kwargs)
    # forward the other arguments to __init__
    ret.update(extra_kwargs)
    return ret
def _max_by_axis(the_list):
    # type: (List[List[int]]) -> List[int]
    """Return the element-wise maximum over a list of equally long int lists.

    The result is a new list; the input sub-lists are left untouched.
    """
    # Copy the first row: aliasing it would overwrite the caller's data in place.
    maxes = list(the_list[0])
    for sublist in the_list[1:]:
        for index, item in enumerate(sublist):
            maxes[index] = max(maxes[index], item)
    return maxes


class NestedTensor(object):
    """A batched tensor bundled with an optional mask tensor."""

    def __init__(self, tensors, mask: Optional[Tensor]):
        self.tensors = tensors
        self.mask = mask

    def to(self, device):
        # type: (Device) -> NestedTensor # noqa
        """Return a new NestedTensor with tensors (and mask, if any) moved to `device`."""
        cast_tensor = self.tensors.to(device)
        cast_mask = self.mask.to(device) if self.mask is not None else None
        return NestedTensor(cast_tensor, cast_mask)

    def decompose(self):
        """Return the ``(tensors, mask)`` pair."""
        return self.tensors, self.mask

    def __repr__(self):
        return str(self.tensors)


def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
    """Zero-pad a list of 3D or 2D tensors to a common shape and batch them.

    The returned NestedTensor's mask is True on padded positions and False
    where original data was copied in.

    Raises:
        ValueError: if the first tensor is neither 3D nor 2D.
    """
    # TODO make this more general
    if tensor_list[0].ndim == 3:
        if torchvision._is_tracing():
            # nested_tensor_from_tensor_list() does not export well to ONNX
            # call _onnx_nested_tensor_from_tensor_list() instead
            return _onnx_nested_tensor_from_tensor_list(tensor_list)

        # TODO make it support different-sized images
        max_size = _max_by_axis([list(img.shape) for img in tensor_list])
        batch_shape = [len(tensor_list)] + max_size
        b, c, h, w = batch_shape
        dtype = tensor_list[0].dtype
        device = tensor_list[0].device
        tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
        mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
        for img, pad_img, m in zip(tensor_list, tensor, mask):
            pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
            m[: img.shape[1], : img.shape[2]] = False
    elif tensor_list[0].ndim == 2:
        if torchvision._is_tracing():
            # nested_tensor_from_tensor_list() does not export well to ONNX
            # call _onnx_nested_tensor_from_tensor_list() instead
            return _onnx_nested_tensor_from_tensor_list(tensor_list)

        # TODO make it support different-sized images
        max_size = _max_by_axis([list(txt.shape) for txt in tensor_list])
        batch_shape = [len(tensor_list)] + max_size
        b, c, l = batch_shape
        dtype = tensor_list[0].dtype
        device = tensor_list[0].device
        tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
        mask = torch.ones((b, l), dtype=torch.bool, device=device)
        for txt, pad_txt, m in zip(tensor_list, tensor, mask):
            pad_txt[: txt.shape[0], : txt.shape[1]] = txt
            m[: txt.shape[1]] = False
    else:
        raise ValueError("not supported")
    return NestedTensor(tensor, mask)
def _collate_and_pad_divisibility(tensor_list: list, div=32):
    """Zero-pad every 3D tensor to one common shape with H and W divisible by `div`.

    The common shape is the per-axis maximum over `tensor_list`, with the last
    two axes rounded up to the next multiple of `div`. Padding is appended at
    the end of each axis.

    Args:
        tensor_list: non-empty list of 3D tensors.
        div: divisor the padded height and width must be multiples of.

    Returns:
        list of padded tensors, in input order (not stacked).
    """
    num_axes = tensor_list[0].dim()
    max_size = [max(img.shape[i] for img in tensor_list) for i in range(num_axes)]

    c, h, w = max_size
    pad_h = (div - h % div) if h % div != 0 else 0
    pad_w = (div - w % div) if w % div != 0 else 0
    target_size = (c, h + pad_h, w + pad_w)

    padded_imgs = []
    for img in tensor_list:
        padding = [(s1 - s2) for s1, s2 in zip(target_size, tuple(img.shape))]
        # F.pad takes (left, right) pairs starting from the last axis.
        padded_imgs.append(
            torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
        )

    return padded_imgs
def is_dist_avail_and_initialized():
    """Tell whether ``torch.distributed`` is usable right now.

    Returns:
        ``True`` only when the distributed package is available AND a process
        group has been initialized; ``False`` otherwise.
    """
    # `and` short-circuits, so is_initialized() is only queried when the
    # package is available.
    return dist.is_available() and dist.is_initialized()
class_names = PASCAL_CONTEXT_459 + ["background"] # elif 'context_59' in name: # class_names = PASCAL_CONTEXT_59 + ["background"] # elif 'context_33' in name: # class_names = PASCAL_CONTEXT_33 # elif 'sunrgbd_37' in name: # class_names = SUN_RGBD_37 # elif 'scannet_41' in name: # class_names = SCAN_40 # elif 'scannet_38' in name: # class_names = SCAN_37 # elif 'scannet_21' in name: # class_names = SCAN_20 # elif 'object365' in name: # class_names = OBJECT365 # elif 'lvis' in name: # class_names = LVIS_CATEGORIES # elif 'seginw' in name: # class_names = SEGINW_CATEGORIES[name.replace('_train', '').replace('_val', '')] + ["background"] # elif name == 'cityscapes_fine_sem_seg_val': # class_names = CITYSCAPES # elif name == 'cityscapes_fine_instance_seg_val': # class_names = CITYSCAPES_THING + ["background"] # elif name in ['cityscapes_fine_panoptic_val', 'cityscapes_fine_panoptic_train']: # class_names = CITYSCAPES + ["background"] # elif name == 'bdd10k_val_sem_seg': # class_names = BDD_SEM # elif name == 'bdd10k_40_panoptic_val': # class_names = BDD_PANO # else: # assert False, "text dataset name {} is not defined".format(name) # # if background == False and "background" in class_names: # class_names.pop(class_names.index("background")) # # return class_names # TODO: add background to # def get_class_names(name): # if name is None: # return None # elif 'refcoco' in name: # return ["background"] # elif 'coco' in name: # return COCO_PANOPTIC_CLASSES + ["background"] # elif 'ade20k_full' in name: # return ADE20K_847 + ["background"] # elif 'ade' in name: # return ADE_PANOPTIC_CLASSES + ["background"] # elif 'scannet_41' in name: # return SCAN_40 # elif 'scannet_21' in name: # return SCAN_20 # elif 'sun' in name: # return SUN_RGBD_37 # elif name == 'cityscapes_fine_sem_seg_val': # return CITYSCAPES + ["background"] # elif name == 'cityscapes_fine_instance_seg_val': # return CITYSCAPES_THING + ["background"] # elif name in ['cityscapes_fine_panoptic_val']: # return 
class BaseModel(nn.Module):
    """Thin ``nn.Module`` wrapper adding checkpoint save/load helpers.

    Attributes are set in ``__init__``:
        opt: configuration object, stored as given.
        model: the wrapped module; ``forward`` delegates to it.
    """

    def __init__(self, opt, module: nn.Module):
        super(BaseModel, self).__init__()
        self.opt = opt
        self.model = module

    def forward(self, *inputs, **kwargs):
        """Run the wrapped module and return its outputs unchanged."""
        outputs = self.model(*inputs, **kwargs)
        return outputs

    def save_pretrained(self, save_dir):
        """Write the wrapped module's state dict to ``<save_dir>/model_state_dict.pt``.

        Args:
            save_dir: Target directory; created if it does not exist.
        """
        # The previous body referenced an undefined `save_path` and raised
        # NameError on every call.
        os.makedirs(save_dir, exist_ok=True)
        save_path = os.path.join(save_dir, "model_state_dict.pt")
        torch.save(self.model.state_dict(), save_path)

    def from_pretrained(self, load_dir):
        """Load weights from a checkpoint file into the wrapped module.

        Args:
            load_dir: Despite the name, this is passed straight to
                ``torch.load``, so it must be a checkpoint *file* path.

        Returns:
            ``self``, to allow call chaining.
        """
        state_dict = torch.load(load_dir, map_location='cpu')
        if 'model' in state_dict:
            state_dict = state_dict['model']
            # Drops the first 6 characters of every key.
            # NOTE(review): presumably strips a "model." prefix; keys are not
            # checked for that prefix before slicing - confirm against the
            # checkpoints in use.
            state_dict = {k[6:]: v for k, v in state_dict.items()}
        # if self.opt['MODEL']['LLAMA'].get('lora_r',0)>0:
        #     new_sd = dict()
        #     for k,v in state_dict.items():
        #         if k.startswith("llama."):
        #             if k.startswith("llama.base_model."):
        #                 new_sd=state_dict
        #                 break
        #             new_sd[k.replace("llama.","llama.base_model.model.")]=v
        #         else:
        #             new_sd[k]=v
        # else:
        #     new_sd = state_dict
        new_sd = align_and_update_state_dicts(self.model.state_dict(), state_dict)
        # strict=False: missing or unexpected keys do not raise.
        self.model.load_state_dict(new_sd, strict=False)
        return self
def dice_loss(inputs: torch.Tensor, targets: torch.Tensor):
    """
    Compute the DICE loss, similar to generalized IOU for masks.

    Args:
        inputs: Float tensor of logits; a sigmoid is applied here and every
            dimension after the first is flattened.
        targets: Float tensor of binary labels (0 negative, 1 positive) whose
            shape matches the flattened ``inputs``.
    Returns:
        Scalar tensor: the per-row dice losses summed over all rows.
    """
    probs = inputs.sigmoid().flatten(1)
    overlap = (probs * targets).sum(-1)
    total = probs.sum(-1) + targets.sum(-1)
    # +1 in numerator and denominator smooths the ratio for empty masks.
    dice = (2 * overlap + 1) / (total + 1)
    return (1 - dice).sum()
The predictions for each example. targets: A float tensor with the same shape as inputs. Stores the binary classification label for each element in inputs (0 for the negative class and 1 for the positive class). alpha: (optional) Weighting factor in range (0,1) to balance positive vs negative examples. Default = -1 (no weighting). gamma: Exponent of the modulating factor (1 - p_t) to balance easy vs hard examples. Returns: Loss tensor """ prob = inputs.sigmoid() ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") p_t = prob * targets + (1 - prob) * (1 - targets) loss = ce_loss * ((1 - p_t) ** gamma) if alpha >= 0: alpha_t = alpha * targets + (1 - alpha) * (1 - targets) loss = alpha_t * loss return loss.sum() class SemanticSAM(nn.Module): """ Main class for mask classification semantic segmentation architectures. """ @configurable def __init__( self, *, backbone: Backbone, sem_seg_head: nn.Module, criterion: nn.Module, num_queries: int, object_mask_threshold: float, overlap_threshold: float, metadata, size_divisibility: int, sem_seg_postprocess_before_inference: bool, pixel_mean: Tuple[float], pixel_std: Tuple[float], # inference semantic_on: bool, panoptic_on: bool, instance_on: bool, test_topk_per_image: int, data_loader: str, pano_temp: float, focus_on_box: bool = False, transform_eval: bool = False, semantic_ce_loss: bool = False, train_dataset_name: str, background: bool, coco_on=True, coco_mask_on=True, o365_on=True, ade_on=True, merge_class=False, sam_on: bool = True, pascal_part_on: bool = True, regenerate_point: bool = False, num_mask_tokens: int = 3, interactive_pretrain=False, match_loss=True, num_vg=2, vis_out="vis/", coco_old=True, clip_on=False, ): """ Args: backbone: a backbone module, must follow detectron2's backbone interface sem_seg_head: a module that predicts semantic segmentation from backbone features criterion: a module that defines the loss num_queries: int, number of queries object_mask_threshold: float, 
threshold to filter query based on classification score for panoptic segmentation inference overlap_threshold: overlap threshold used in general inference for panoptic segmentation metadata: dataset meta, get `thing` and `stuff` category names for panoptic segmentation inference size_divisibility: Some backbones require the input height and width to be divisible by a specific integer. We can use this to override such requirement. sem_seg_postprocess_before_inference: whether to resize the prediction back to original input size before semantic segmentation inference or after. For high-resolution dataset like Mapillary, resizing predictions before inference will cause OOM error. pixel_mean, pixel_std: list or tuple with #channels element, representing the per-channel mean and std to be used to normalize the input image semantic_on: bool, whether to output semantic segmentation prediction instance_on: bool, whether to output instance segmentation prediction panoptic_on: bool, whether to output panoptic segmentation prediction test_topk_per_image: int, instance segmentation parameter, keep topk instances per image """ super().__init__() self.backbone = backbone self.pano_temp = pano_temp self.sem_seg_head = sem_seg_head self.criterion = criterion self.num_queries = num_queries self.overlap_threshold = overlap_threshold self.object_mask_threshold = object_mask_threshold self.metadata = metadata self.num_vg = num_vg if size_divisibility < 0: size_divisibility = self.backbone.size_divisibility self.size_divisibility = size_divisibility self.sem_seg_postprocess_before_inference = sem_seg_postprocess_before_inference self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) self.semantic_on = semantic_on self.instance_on = instance_on self.panoptic_on = panoptic_on self.test_topk_per_image = test_topk_per_image self.data_loader = data_loader self.focus_on_box = 
@classmethod
def from_config(cls, cfg):
    """Build the components and keyword arguments for this model from ``cfg``.

    Reads ``cfg['MODEL']['ENCODER']`` and ``cfg['MODEL']['DECODER']``, then
    builds a HungarianMatcher, the loss ``weight_dict``, a SetCriterion, the
    backbone and the openseed head.

    Args:
        cfg: nested mapping; besides item access, ``cfg.get`` and
            ``dec_cfg.get`` are called, so both must offer ``.get``.

    Returns:
        dict whose keys match the keyword-only parameters of ``__init__``.
        Presumably consumed by the ``@configurable`` decorator on
        ``__init__`` - TODO confirm.
    """
    enc_cfg = cfg['MODEL']['ENCODER']
    dec_cfg = cfg['MODEL']['DECODER']

    deep_supervision = dec_cfg['DEEP_SUPERVISION']
    no_object_weight = dec_cfg['NO_OBJECT_WEIGHT']

    # loss weights
    iou_weight = dec_cfg['IOU_WEIGHT']
    class_weight = dec_cfg['CLASS_WEIGHT']
    cost_class_weight = dec_cfg['COST_CLASS_WEIGHT']
    cost_dice_weight = dec_cfg['COST_DICE_WEIGHT']
    dice_weight = dec_cfg['DICE_WEIGHT']
    cost_mask_weight = dec_cfg['COST_MASK_WEIGHT']
    mask_weight = dec_cfg['MASK_WEIGHT']
    cost_box_weight = dec_cfg['COST_BOX_WEIGHT']
    box_weight = dec_cfg['BOX_WEIGHT']
    cost_giou_weight = dec_cfg['COST_GIOU_WEIGHT']
    giou_weight = dec_cfg['GIOU_WEIGHT']
    # NOTE(review): refer_weight is read (so the key must exist) but is not
    # used anywhere else in this method.
    refer_weight = dec_cfg['REFER_WEIGHT']
    fix_backbone = cfg.get('fix_backbone', False)

    # building matcher
    matcher = HungarianMatcher(
        cost_class=cost_class_weight,
        cost_mask=cost_mask_weight,
        cost_dice=cost_dice_weight,
        cost_box=cost_box_weight,
        cost_giou=cost_giou_weight,
        num_points=dec_cfg['TRAIN_NUM_POINTS'],
    )

    # MaskDINO losses and weight_dict
    # Base keys all carry the "_0" suffix; the deep-supervision step below
    # rewrites that suffix to "_1", "_2", ...
    weight_dict = {"loss_mask_cls_0": class_weight}
    weight_dict.update({"loss_mask_bce_0": mask_weight, "loss_mask_dice_0": dice_weight})
    weight_dict.update({"loss_bbox_0": box_weight, "loss_giou_0": giou_weight})
    weight_dict.update({"iou_score_loss_0": iou_weight})
    weight_dict.update({"loss_mask_part_cls_0": class_weight})
    # two stage is the query selection scheme
    if dec_cfg['TWO_STAGE']:
        # Duplicate every current entry under "<key>_interm".
        interm_weight_dict = {}
        interm_weight_dict.update({k + f'_interm': v for k, v in weight_dict.items()})
        weight_dict.update(interm_weight_dict)
    # denoising training
    dn = dec_cfg['DN']
    # TODO hack for dn label loss
    if dn == "standard":
        # NOTE(review): no key built above equals "loss_mask" or "loss_dice"
        # (they all carry suffixes), so this filter currently excludes nothing.
        weight_dict.update({k + f"_dn": v for k, v in weight_dict.items() if k != "loss_mask" and k != "loss_dice"})
        dn_losses = ["dn_labels", "boxes"]
    elif dn == "seg":
        weight_dict.update({k + f"_dn": v for k, v in weight_dict.items()})
        dn_losses = ["masks", "dn_labels", "boxes"]
    else:
        dn_losses = []
    if deep_supervision:
        dec_layers = dec_cfg['DEC_LAYERS']
        aux_weight_dict = {}
        # One copy of the whole weight dict per decoder layer, with "_0"
        # replaced by "_<layer index + 1>" in each key.
        for i in range(dec_layers):
            aux_weight_dict.update({k.replace('_0', '_{}'.format(i + 1)): v for k, v in weight_dict.items()})
        weight_dict.update(aux_weight_dict)
    if dec_cfg['BOX']:
        losses = ["masks", "labels", "boxes"]
    else:
        losses = ["masks", "labels", ]
    if dec_cfg['PART']:
        losses.append('labels_part')
    weight_dict.update({'all': 1.0, 'sam': 1.0, 'pas': 1.0})

    # update task switch
    task_switch = {}
    task_switch.update({'bbox': dec_cfg.get('DETECTION', True), 'mask': dec_cfg.get('MASK', True)})
    # Global scale applied to every entry, including 'all'/'sam'/'pas'.
    weight_multiplier= dec_cfg.get('WEIGHT_MULTIPLIER', 1.0)
    weight_dict={k:v*weight_multiplier for k,v in weight_dict.items()}

    # building criterion
    criterion = SetCriterion(
        enc_cfg['NUM_CLASSES'],
        matcher=matcher,
        weight_dict=weight_dict,
        # top_x_layers=top_x_layers,
        eos_coef=no_object_weight,
        losses=losses,
        num_points=dec_cfg['TRAIN_NUM_POINTS'],
        oversample_ratio=dec_cfg['OVERSAMPLE_RATIO'],
        importance_sample_ratio=dec_cfg['IMPORTANCE_SAMPLE_RATIO'],
        # grounding_weight=None,
        dn=dec_cfg['DN'],
        dn_losses=dn_losses,
        panoptic_on=dec_cfg['PANO_BOX_LOSS'],
        semantic_ce_loss=dec_cfg['TEST']['SEMANTIC_ON'] and dec_cfg['SEMANTIC_CE_LOSS'] and not dec_cfg['TEST'][
            'PANOPTIC_ON'],
        num_mask_tokens=dec_cfg.get('NUM_INTERACTIVE_TOKENS', 3)
    )

    # build model
    extra = {'task_switch': task_switch}
    backbone = build_backbone(cfg)
    if fix_backbone:
        for name, param in backbone.named_parameters():
            param.requires_grad = False
    # backbone
    sem_seg_head = build_openseed_head(cfg, backbone.output_shape(), None, extra=extra)
    # Despite its name, `fix_backbone` also freezes every head parameter.
    if fix_backbone:
        for name, param in sem_seg_head.named_parameters():
            param.requires_grad = False

    return {
        "backbone": backbone,
        "sem_seg_head": sem_seg_head,
        "criterion": criterion,
        "num_queries": dec_cfg['NUM_OBJECT_QUERIES'],
        "object_mask_threshold": dec_cfg['TEST']['OBJECT_MASK_THRESHOLD'],
        "overlap_threshold": dec_cfg['TEST']['OVERLAP_THRESHOLD'],
        # Metadata is taken from the FIRST training dataset only.
        "metadata": MetadataCatalog.get(cfg['DATASETS']['TRAIN'][0]),
        "size_divisibility": dec_cfg['SIZE_DIVISIBILITY'],
        "sem_seg_postprocess_before_inference": (
            dec_cfg['TEST']['SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE']
            or dec_cfg['TEST']['PANOPTIC_ON']
            or dec_cfg['TEST']['INSTANCE_ON']
        ),
        "pixel_mean": cfg['INPUT']['PIXEL_MEAN'],
        "pixel_std": cfg['INPUT']['PIXEL_STD'],
        # inference
        "semantic_on": dec_cfg['TEST']['SEMANTIC_ON'],
        "instance_on": dec_cfg['TEST']['INSTANCE_ON'],
        "panoptic_on": dec_cfg['TEST']['PANOPTIC_ON'],
        "test_topk_per_image": cfg['COCO']['TEST']['DETECTIONS_PER_IMAGE'],
        "data_loader": None,
        # The config key really is spelled 'TEST_FOUCUS_ON_BOX'.
        "focus_on_box": cfg['MODEL']['DECODER']['TEST']['TEST_FOUCUS_ON_BOX'],
        "transform_eval": cfg['MODEL']['DECODER']['TEST']['PANO_TRANSFORM_EVAL'],
        "pano_temp": cfg['MODEL']['DECODER']['TEST']['PANO_TEMPERATURE'],
        "semantic_ce_loss": cfg['MODEL']['DECODER']['TEST']['SEMANTIC_ON'] and cfg['MODEL']['DECODER'][
            'SEMANTIC_CE_LOSS'] and not cfg['MODEL']['DECODER']['TEST']['PANOPTIC_ON'],
        "train_dataset_name": cfg['DATASETS']['TRAIN'],  # HACK for only two training set
        "background": cfg['MODEL'].get('BACKGROUND', True),
        "coco_on": dec_cfg.get('COCO', True),
        "coco_mask_on": dec_cfg.get('COCO_MASK', True),
        "o365_on": dec_cfg.get('O365', True),
        "sam_on": dec_cfg.get('SAM', True),
        "pascal_part_on": dec_cfg.get('PASCAL', True),
        "regenerate_point": dec_cfg.get('RE_POINT', False),
        "num_mask_tokens": dec_cfg.get('NUM_INTERACTIVE_TOKENS', 3),
        "ade_on": dec_cfg.get('ADE', False),
        "interactive_pretrain": dec_cfg.get('pretrain', False),
        "match_loss": dec_cfg.get('match_loss', True),
        "vis_out": os.path.join(cfg.get('OUTPUT_DIR', 'out'), str(cfg.get('VIS_OUT', 'vis'))),
        "coco_old": cfg.get("coco_old", True),
        # "points_per_side_eval": cfg.get("points_per_side_eval", 30),
        "clip_on": cfg.get("clip", False),
    }
def forward(self, batched_inputs, inference_task='seg', detach=False):
    """Dispatch to training or demo inference depending on ``self.training``.

    Args:
        batched_inputs: forwarded unchanged to ``forward_det_pretrain``
            (training) or ``evaluate_demo`` (evaluation).
        inference_task: accepted but not used by this method.
        detach: training only; when true the object features are detached
            before being passed through ``self.obj_projector``.

    Returns:
        Training: ``(projected_features, losses)``. ``losses`` holds every
        loss whose name is in ``self.criterion.weight_dict``, multiplied by
        its weight and re-keyed as ``"inter.<name>"``.
        Evaluation: whatever ``evaluate_demo`` returns.
    """
    if not self.training:
        return self.evaluate_demo(batched_inputs)

    obj_feats, inter_losses = self.forward_det_pretrain(batched_inputs)

    # Iterate over a snapshot of the names because entries are popped below.
    for loss_name in list(inter_losses.keys()):
        if loss_name not in self.criterion.weight_dict:
            # remove this loss if not specified in `weight_dict`
            inter_losses.pop(loss_name)
            continue
        inter_losses[loss_name] *= self.criterion.weight_dict[loss_name]

    new_losses = {
        'inter' + '.' + str(loss_name): loss_value
        for loss_name, loss_value in inter_losses.items()
    }

    # NOTE(review): the detach branch returns the full projector output while
    # the other branch indexes [0]; kept exactly as-is - confirm whether the
    # asymmetry is intended.
    if detach:
        projected = [self.obj_projector(feat.detach()) for feat in obj_feats]
    else:
        projected = [self.obj_projector(feat)[0] for feat in obj_feats]
    return projected, new_losses
def prepare_targets_sam(self, targets, images, prediction_switch, task='seg', min_box=0.33, max_box=1.0):
    """Convert per-image ground truth into the target dicts used for training.

    For every image a random ``box_start`` is drawn. The first ``box_start``
    instances get a point prompt: a random foreground pixel of the mask,
    expanded by 3 px on each side into a small box. The remaining instances
    keep their ground-truth box.

    Args:
        targets: list of per-image ground-truth objects exposing
            ``gt_boxes``, ``gt_masks``, ``gt_classes``, ``image_size`` and
            ``has()`` (presumably detectron2 ``Instances`` - confirm).
        images: object whose ``.tensor`` is the padded batch; only its last
            two dimensions (padded H, W) are read.
        prediction_switch: dict with boolean ``'whole'`` and ``'part'``.
        task: accepted but not used by this method.
        min_box, max_box: fractions of the instance count bounding the random
            ``box_start``.

    Returns:
        list of dicts, one per image, with keys ``ori_mask_num``, ``labels``,
        ``masks``, ``boxes``, ``points``, ``pb``, ``gt_whole_classes``,
        ``gt_part_classes``, ``boxes_dn``, ``box_start`` and ``empty``.

    Side effects:
        Caches the first batch ever seen in ``self.empty_targets`` and sets
        ``self.dbg``; overwrites ``self.max_num_instance`` for every image.
    """
    h_pad, w_pad = images.tensor.shape[-2:]
    new_targets = []

    # Remember the first batch ever seen; it is substituted whenever a later
    # batch starts with an image that has no instances.
    # NOTE(review): only targets[0] is checked for emptiness, and if the very
    # first batch is itself empty the assert below fires.
    if not self.dbg:
        self.empty_targets = targets
        self.dbg = True
    if len(targets[0]) == 0:
        empty = True
        targets = self.empty_targets
    else:
        empty = False

    for targets_per_image in targets:
        gt_boxes = targets_per_image.gt_boxes if torch.is_tensor(
            targets_per_image.gt_boxes) else targets_per_image.gt_boxes.tensor
        assert len(gt_boxes) > 0
        self.max_num_instance = len(gt_boxes)
        # Instances [0, box_start) become point prompts, the rest box prompts.
        box_start = random.randint(int(self.max_num_instance * min_box), int(self.max_num_instance * max_box))

        # pad gt
        h, w = targets_per_image.image_size
        image_size_xyxy = torch.as_tensor([w, h, w, h], dtype=torch.float, device=self.device)
        gt_masks = targets_per_image.gt_masks if torch.is_tensor(
            targets_per_image.gt_masks) else targets_per_image.gt_masks.tensor
        padded_masks = torch.zeros((gt_masks.shape[0], h_pad, w_pad), dtype=gt_masks.dtype,
                                   device=gt_masks.device)
        padded_masks[:, : gt_masks.shape[1], : gt_masks.shape[2]] = gt_masks

        num_mask = targets_per_image.gt_classes.shape[0]
        index = torch.arange(num_mask)
        if self.max_num_instance > num_mask:
            # Cycle through the available masks until enough indices exist.
            rep = 0 if num_mask == 0 else int(self.max_num_instance / num_mask) + 1
            index = index.repeat(rep)
        index = index[:self.max_num_instance]

        point_coords = []
        if box_start > 0:
            for i in range(box_start):
                mask = gt_masks[index[i]].clone()
                candidate_indices = mask.nonzero()
                if len(candidate_indices) == 0:
                    print('wrong')
                    # Fallback point, created on the same device as points
                    # sampled from the masks (previously hard-coded .cuda()).
                    selected_point = torch.tensor([0, 0], device=gt_masks.device)
                else:
                    selected_index = random.randint(0, len(candidate_indices) - 1)
                    # nonzero() yields (row, col); flip to (x, y).
                    selected_point = candidate_indices[selected_index].flip(0)
                # Grow the point into a 6x6 xyxy box centred on it.
                selected_point = torch.cat([selected_point - 3, selected_point + 3], 0)
                point_coords.append(selected_point)
            # Follow the model's device rather than the literal 'cuda', so CPU
            # runs and non-default GPUs work and match image_size_xyxy.
            point_coords = torch.stack(point_coords).to(self.device)

        new_target = {
            "ori_mask_num": len(targets_per_image.gt_classes),
            "labels": targets_per_image.gt_classes[index] if prediction_switch['whole'] else None,
            "masks": padded_masks[index],
            # Boxes and points are normalised cxcywh relative to the image size.
            "boxes": box_ops.box_xyxy_to_cxcywh(gt_boxes[index]) / image_size_xyxy,
            "points": box_ops.box_xyxy_to_cxcywh(point_coords) / image_size_xyxy if len(point_coords) > 0 else None,
            # 0 for the first box_start (point) entries, 1 for the box entries.
            "pb": torch.cat([torch.zeros(box_start), torch.ones(self.max_num_instance - box_start)], 0),
            "gt_whole_classes": targets_per_image.gt_whole_classes[index] if targets_per_image.has(
                'gt_whole_classes') and prediction_switch['whole'] else None,
            "gt_part_classes": targets_per_image.gt_part_classes[index] if targets_per_image.has(
                'gt_part_classes') and prediction_switch['part'] else None,
        }
        # handle coco data format
        if prediction_switch['whole'] and not prediction_switch['part']:
            new_target['gt_whole_classes'] = targets_per_image.gt_classes[index]

        # Point prompts first, then the ground-truth boxes of the rest.
        if new_target["points"] is not None:
            new_target["boxes_dn"] = torch.cat([new_target["points"], new_target["boxes"][box_start:]], 0)
        else:
            new_target["boxes_dn"] = new_target["boxes"][box_start:]

        new_target['box_start'] = box_start
        new_target['empty'] = empty
        new_targets.append(new_target)

    return new_targets
def build_backbone(config, **kwargs):
    """Instantiate the backbone named by ``config['MODEL']['BACKBONE']['NAME']``.

    Args:
        config: nested mapping holding the backbone name.
        **kwargs: forwarded to the registered entrypoint.

    Returns:
        Whatever the registered entrypoint returns for ``(config, **kwargs)``.

    Raises:
        ValueError: if no entrypoint is registered under that name.
    """
    backbone_name = config['MODEL']['BACKBONE']['NAME']
    if is_model(backbone_name):
        entrypoint = model_entrypoints(backbone_name)
        return entrypoint(config, **kwargs)
    raise ValueError(f'Unkown model: {backbone_name}')
class FocalModulation(nn.Module):
    """Focal modulation operator.

    A single linear layer produces a query, a context map and
    ``focal_level + 1`` gate maps. The context is passed through a stack of
    depth-wise convolutions; every level's output (plus a globally averaged
    context) is gated and summed into a modulator that multiplies the query.

    Args:
        dim (int): Number of input channels.
        proj_drop (float, optional): Dropout ratio applied to the output
            projection. Default: 0.0
        focal_level (int): Number of focal levels.
        focal_window (int): Kernel size of the first focal level.
        focal_factor (int, default=2): Kernel size increment per level.
        use_postln (bool, default=False): Accepted for interface
            compatibility; not read by this class.
        use_postln_in_modulation (bool, default=False): Apply a LayerNorm to
            the modulated features before the output projection.
        scaling_modulator (bool, default=False): Divide the summed modulator
            by ``focal_level + 1``.
    """

    def __init__(self, dim, proj_drop=0., focal_level=2, focal_window=7,
                 focal_factor=2, use_postln=False,
                 use_postln_in_modulation=False, scaling_modulator=False):
        super().__init__()

        self.dim = dim

        # specific args for focalv3
        self.focal_level = focal_level
        self.focal_window = focal_window
        self.focal_factor = focal_factor
        self.use_postln_in_modulation = use_postln_in_modulation
        self.scaling_modulator = scaling_modulator

        gate_channels = focal_level + 1
        self.f = nn.Linear(dim, 2 * dim + gate_channels, bias=True)
        self.h = nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0,
                           groups=1, bias=True)
        self.act = nn.GELU()
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        # The (still empty) list is registered before the optional LayerNorm
        # so that sub-module / state_dict ordering is unchanged.
        self.focal_layers = nn.ModuleList()
        if use_postln_in_modulation:
            self.ln = nn.LayerNorm(dim)

        for level in range(focal_level):
            window = focal_factor * level + focal_window
            self.focal_layers.append(self._make_focal_layer(dim, window))

    @staticmethod
    def _make_focal_layer(dim, kernel_size):
        """Return one depth-wise conv + GELU stage with 'same'-style padding."""
        return nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=1,
                      groups=dim, padding=kernel_size // 2, bias=False),
            nn.GELU(),
        )

    def forward(self, x):
        """Forward function.

        Args:
            x: input features with shape of (B, H, W, C)

        Returns:
            Tensor of shape (B, H, W, C).
        """
        _, _, _, channels = x.shape
        num_levels = self.focal_level

        projected = self.f(x).permute(0, 3, 1, 2).contiguous()
        query, context, gates = torch.split(
            projected, (channels, channels, num_levels + 1), 1
        )

        # Accumulate gated context, one level at a time; each level consumes
        # the previous level's output.
        modulator = 0
        for level in range(num_levels):
            context = self.focal_layers[level](context)
            modulator = modulator + context * gates[:, level:level + 1]

        pooled = context.mean(2, keepdim=True).mean(3, keepdim=True)
        modulator = modulator + self.act(pooled) * gates[:, num_levels:]

        if self.scaling_modulator:
            modulator = modulator / (num_levels + 1)

        out = (query * self.h(modulator)).permute(0, 2, 3, 1).contiguous()
        if self.use_postln_in_modulation:
            out = self.ln(out)
        return self.proj_drop(self.proj(out))
class BasicLayer(nn.Module):
    """A basic focal modulation layer for one stage.

    Args:
        dim (int): Number of feature channels
        depth (int): Depths of this stage.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
        drop (float, optional): Dropout rate. Default: 0.0
        drop_path (float | list[float] | tuple[float], optional): Stochastic
            depth rate. A list/tuple is indexed per block; a scalar is shared
            by every block. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer class built
            at the end of the layer. Default: None
        focal_window (int): Focal window size at focal level 1
        focal_level (int): Number of focal levels
        use_conv_embed (bool): Use overlapped convolution for patch embedding
            or not. Default: False
        use_postln, use_postln_in_modulation, scaling_modulator,
        use_layerscale: Forwarded unchanged to every FocalModulationBlock.
        use_checkpoint (bool): Whether to use checkpointing to save memory.
            Default: False
    """

    def __init__(self,
                 dim,
                 depth,
                 mlp_ratio=4.,
                 drop=0.,
                 drop_path=0.,
                 norm_layer=nn.LayerNorm,
                 downsample=None,
                 focal_window=9,
                 focal_level=2,
                 use_conv_embed=False,
                 use_postln=False,
                 use_postln_in_modulation=False,
                 scaling_modulator=False,
                 use_layerscale=False,
                 use_checkpoint=False
                 ):
        super().__init__()
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # The docstring has always advertised tuple support, but only lists
        # were indexed per block; a tuple used to be handed whole to each
        # block. Treat both sequence types the same way.
        per_block_rates = isinstance(drop_path, (list, tuple))

        # build blocks
        self.blocks = nn.ModuleList([
            FocalModulationBlock(
                dim=dim,
                mlp_ratio=mlp_ratio,
                drop=drop,
                drop_path=drop_path[i] if per_block_rates else drop_path,
                focal_window=focal_window,
                focal_level=focal_level,
                use_postln=use_postln,
                use_postln_in_modulation=use_postln_in_modulation,
                scaling_modulator=scaling_modulator,
                use_layerscale=use_layerscale,
                norm_layer=norm_layer)
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(
                patch_size=2,
                in_chans=dim, embed_dim=2*dim,
                use_conv_embed=use_conv_embed,
                norm_layer=norm_layer,
                is_stem=False
            )
        else:
            self.downsample = None

    def forward(self, x, H, W):
        """Forward function.

        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.

        Returns:
            tuple: ``(x, H, W, x_down, Wh, Ww)`` where ``x`` is the stage
            output at the input resolution and ``x_down`` is the flattened
            downsampled feature with resolution ``(Wh, Ww)``. Without a
            downsample layer the last three entries repeat the first three.
        """
        for blk in self.blocks:
            # Blocks read the spatial size from attributes, not arguments.
            blk.H, blk.W = H, W
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)

        if self.downsample is None:
            return x, H, W, x, H, W

        x_reshaped = x.transpose(1, 2).view(x.shape[0], x.shape[-1], H, W)
        x_down = self.downsample(x_reshaped)
        x_down = x_down.flatten(2).transpose(1, 2)
        # Stride-2 downsampling rounds odd sizes up.
        Wh, Ww = (H + 1) // 2, (W + 1) // 2
        return x, H, W, x_down, Wh, Ww
class FocalNet(nn.Module):
    """FocalNet backbone.

    Args:
        pretrain_img_size (int): Input image size used for the pretrained
            model. Only stored on the instance by this class. Default: 1600.
        patch_size (int | tuple(int)): Patch size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        depths (tuple[int]): Number of blocks in each stage.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
        drop_rate (float): Dropout rate.
        drop_path_rate (float): Stochastic depth rate. Default: 0.2.
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        patch_norm (bool): If True, add normalization after patch embedding. Default: True.
        out_indices (Sequence[int]): Output from which stages.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters.
        focal_levels (Sequence[int]): Number of focal levels at four stages
        focal_windows (Sequence[int]): Focal window sizes at first focal level at four stages
        use_conv_embed (bool): Whether use overlapped convolution for patch embedding
        use_postln, use_postln_in_modulation, scaling_modulator,
        use_layerscale: Forwarded unchanged to every stage.
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self,
                 pretrain_img_size=1600,
                 patch_size=4,
                 in_chans=3,
                 embed_dim=96,
                 depths=[2, 2, 6, 2],
                 mlp_ratio=4.,
                 drop_rate=0.,
                 drop_path_rate=0.2,
                 norm_layer=nn.LayerNorm,
                 patch_norm=True,
                 out_indices=[0, 1, 2, 3],
                 frozen_stages=-1,
                 focal_levels=[2,2,2,2],
                 focal_windows=[9,9,9,9],
                 use_conv_embed=False,
                 use_postln=False,
                 use_postln_in_modulation=False,
                 scaling_modulator=False,
                 use_layerscale=False,
                 use_checkpoint=False,
                 ):
        super().__init__()

        self.pretrain_img_size = pretrain_img_size
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.patch_norm = patch_norm
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None,
            use_conv_embed=use_conv_embed, is_stem=True)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth decay rule: one rate per block, linearly increasing
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(
                dim=int(embed_dim * 2 ** i_layer),
                depth=depths[i_layer],
                mlp_ratio=mlp_ratio,
                drop=drop_rate,
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                norm_layer=norm_layer,
                # every stage except the last ends with a downsample layer
                downsample=PatchEmbed if (i_layer < self.num_layers - 1) else None,
                focal_window=focal_windows[i_layer],
                focal_level=focal_levels[i_layer],
                use_conv_embed=use_conv_embed,
                use_postln=use_postln,
                use_postln_in_modulation=use_postln_in_modulation,
                scaling_modulator=scaling_modulator,
                use_layerscale=use_layerscale,
                use_checkpoint=use_checkpoint)
            self.layers.append(layer)

        num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
        self.num_features = num_features

        # add a norm layer for each output
        for i_layer in out_indices:
            layer = norm_layer(num_features[i_layer])
            layer_name = f'norm{i_layer}'
            self.add_module(layer_name, layer)

        self._freeze_stages()

    def _freeze_stages(self):
        """Put frozen parts in eval mode and disable their gradients.

        ``frozen_stages >= 0`` freezes the patch embedding;
        ``frozen_stages >= 2`` additionally puts ``pos_drop`` in eval mode
        and freezes ``layers[0 : frozen_stages - 1]``.
        """
        if self.frozen_stages >= 0:
            self.patch_embed.eval()
            for param in self.patch_embed.parameters():
                param.requires_grad = False

        if self.frozen_stages >= 2:
            self.pos_drop.eval()
            for i in range(0, self.frozen_stages - 1):
                m = self.layers[i]
                m.eval()
                for param in m.parameters():
                    param.requires_grad = False

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """

        def _init_weights(m):
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)

        if isinstance(pretrained, str):
            self.apply(_init_weights)
            # NOTE(review): get_root_logger / load_checkpoint are not among
            # the imports visible in this file, so this branch presumably
            # raises NameError unless they are provided elsewhere - confirm.
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            self.apply(_init_weights)
        else:
            raise TypeError('pretrained must be a str or None')

    def load_weights(self, pretrained_dict=None, pretrained_layers=[], verbose=True):
        """Load (and, where needed, resize) pretrained tensors into the model.

        Args:
            pretrained_dict (dict): Mapping of parameter name to tensor.
            pretrained_layers (list[str]): Top-level module names to load.
                A first element of ``'*'`` loads every matching key. An empty
                list loads nothing.
            verbose (bool): Currently unused.

        Keys absent from the model are ignored. Focal convolution kernels
        whose spatial size differs are zero-padded or centre-cropped;
        ``modulation.f`` / ``pre_conv`` tensors with fewer rows than the
        model are expanded into a zero tensor.
        """
        model_dict = self.state_dict()
        missed_dict = [k for k in model_dict.keys() if k not in pretrained_dict]
        logger.info(f'=> Missed keys {missed_dict}')
        unexpected_dict = [k for k in pretrained_dict.keys() if k not in model_dict]
        logger.info(f'=> Unexpected keys {unexpected_dict}')

        pretrained_dict = {
            k: v for k, v in pretrained_dict.items()
            if k in model_dict.keys()
        }

        # Guard the wildcard test: indexing [0] on the default empty list
        # used to raise IndexError.
        load_all = bool(pretrained_layers) and pretrained_layers[0] == '*'

        need_init_state_dict = {}
        for k, v in pretrained_dict.items():
            need_init = (
                (k.split('.')[0] in pretrained_layers or load_all)
                and 'relative_position_index' not in k
                and 'attn_mask' not in k
            )
            if not need_init:
                continue

            # Parenthesised on purpose: without them `and` bound tighter than
            # `or`, so every 'pool_layers' key entered this branch even when
            # the sizes already matched.
            is_focal_conv = ('pool_layers' in k) or ('focal_layers' in k)
            if is_focal_conv and v.size() != model_dict[k].size():
                table_pretrained = v
                table_current = model_dict[k]
                fsize1 = table_pretrained.shape[2]
                fsize2 = table_current.shape[2]

                # NOTE: different from interpolation used in self-attention,
                # we use padding or clipping for focal conv
                if fsize1 < fsize2:
                    diff = fsize2 - fsize1
                    # floor on the left, ceil on the right, so odd
                    # differences still leave exactly fsize1 entries
                    lo, hi = diff // 2, -diff // 2
                    table_pretrained_resized = torch.zeros(table_current.shape)
                    table_pretrained_resized[:, :, lo:hi, lo:hi] = table_pretrained
                    v = table_pretrained_resized
                elif fsize1 > fsize2:
                    diff = fsize1 - fsize2
                    lo, hi = diff // 2, -diff // 2
                    v = table_pretrained[:, :, lo:hi, lo:hi]

            if "modulation.f" in k or "pre_conv" in k:
                table_pretrained = v
                table_current = model_dict[k]
                if table_pretrained.shape != table_current.shape:
                    if len(table_pretrained.shape) == 2:
                        dim = table_pretrained.shape[1]
                        assert table_current.shape[1] == dim
                        L1 = table_pretrained.shape[0]
                        L2 = table_current.shape[0]

                        if L1 < L2:
                            table_pretrained_resized = torch.zeros(table_current.shape)
                            # copy for linear project
                            table_pretrained_resized[:2*dim] = table_pretrained[:2*dim]
                            # copy for global token gating
                            table_pretrained_resized[-1] = table_pretrained[-1]
                            # copy for first multiple focal levels
                            table_pretrained_resized[2*dim:2*dim+(L1-2*dim-1)] = table_pretrained[2*dim:-1]
                            # reassign pretrained weights
                            v = table_pretrained_resized
                        elif L1 > L2:
                            raise NotImplementedError
                    elif len(table_pretrained.shape) == 1:
                        # NOTE(review): here `dim` equals L1, so the whole
                        # pretrained vector is copied to the front and its
                        # last entry is additionally copied to the final
                        # slot - confirm this is the intended layout.
                        dim = table_pretrained.shape[0]
                        L1 = table_pretrained.shape[0]
                        L2 = table_current.shape[0]
                        if L1 < L2:
                            table_pretrained_resized = torch.zeros(table_current.shape)
                            # copy for linear project
                            table_pretrained_resized[:dim] = table_pretrained[:dim]
                            # copy for global token gating
                            table_pretrained_resized[-1] = table_pretrained[-1]
                            # reassign pretrained weights
                            v = table_pretrained_resized
                        elif L1 > L2:
                            raise NotImplementedError

            need_init_state_dict[k] = v

        self.load_state_dict(need_init_state_dict, strict=False)

    def forward(self, x):
        """Forward function.

        Returns:
            dict[str, Tensor]: ``"res{i + 2}"`` for every stage index ``i`` in
            ``out_indices``, each of shape (B, C_i, H_i, W_i). When
            ``out_indices`` is empty, only ``"res5"`` (the un-normalised
            output of the last stage) is returned.
        """
        x = self.patch_embed(x)
        Wh, Ww = x.size(2), x.size(3)

        x = x.flatten(2).transpose(1, 2)
        x = self.pos_drop(x)

        outs = {}
        for i in range(self.num_layers):
            layer = self.layers[i]
            x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
            if i in self.out_indices:
                norm_layer = getattr(self, f'norm{i}')
                x_out = norm_layer(x_out)

                out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
                outs["res{}".format(i + 2)] = out

        if len(self.out_indices) == 0:
            outs["res5"] = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()

        return outs

    def train(self, mode=True):
        """Convert the model into training mode while keeping frozen layers frozen.

        Returns:
            FocalNet: ``self``, matching ``nn.Module.train``.
        """
        super(FocalNet, self).train(mode)
        self._freeze_stages()
        return self
@register_backbone
def get_focal_backbone(cfg):
    """Build a D2FocalNet and optionally initialise it from a checkpoint.

    Args:
        cfg: Nested mapping. ``cfg['MODEL']`` is passed to ``D2FocalNet``.
            When ``cfg['MODEL']['BACKBONE']['LOAD_PRETRAINED']`` is exactly
            ``True``, the file named by ``cfg['MODEL']['BACKBONE']['PRETRAINED']``
            is loaded and its ``'model'`` entry is handed to ``load_weights``.

    Returns:
        The constructed D2FocalNet instance.
    """
    focal = D2FocalNet(cfg['MODEL'], 224)

    if cfg['MODEL']['BACKBONE']['LOAD_PRETRAINED'] is True:
        filename = cfg['MODEL']['BACKBONE']['PRETRAINED']
        logger.info(f'=> init from {filename}')
        with PathManager.open(filename, "rb") as f:
            # Load onto CPU so a checkpoint saved from a GPU process can be
            # read on a CPU-only host; load_state_dict copies the values
            # into the model's own parameters afterwards.
            # NOTE: torch.load unpickles the file - only point PRETRAINED at
            # checkpoints from a trusted source.
            ckpt = torch.load(f, map_location='cpu')['model']
        focal.load_weights(ckpt, cfg['MODEL']['BACKBONE']['FOCAL'].get('PRETRAINED_LAYERS', ['*']), cfg['VERBOSE'])

    return focal
class Mlp(nn.Module):
    """Two-layer perceptron: Linear -> activation -> dropout -> Linear -> dropout.

    Args:
        in_features (int): Size of the last input dimension.
        hidden_features (int, optional): Width of the hidden layer; falls
            back to ``in_features`` when falsy.
        out_features (int, optional): Size of the last output dimension;
            falls back to ``in_features`` when falsy.
        act_layer: Callable returning the activation module. Default: nn.GELU
        drop (float): Dropout probability used after both linear layers.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None,
                 act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden_width = hidden_features or in_features
        output_width = out_features or in_features

        self.fc1 = nn.Linear(in_features, hidden_width)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_width, output_width)
        # One dropout module is shared by both application points.
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Apply the perceptron to the last dimension of ``x``."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
class FocalModulationBlock(nn.Module):
    """Focal modulation block with depth-wise 3x3 convolutions.

    A residual depth-wise convolution is applied before the focal modulation
    and again before the MLP. The spatial size is read from the ``H`` / ``W``
    attributes, which the caller must set before calling ``forward``.

    Args:
        dim (int): Number of input channels.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        drop (float, optional): Dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        focal_level (int): number of focal levels
        focal_window (int): focal kernel size at level 1
        use_postln (bool): Normalise after (rather than before) each sub-layer.
        use_postln_in_modulation (bool): Forwarded to FocalModulation.
        scaling_modulator (bool): Forwarded to FocalModulation.
        use_layerscale (bool): Use learnable per-channel residual scales.
        layerscale_value (float): Initial value of those scales.
    """

    def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 focal_level=2, focal_window=9,
                 use_postln=False, use_postln_in_modulation=False,
                 scaling_modulator=False,
                 use_layerscale=False,
                 layerscale_value=1e-4):
        super().__init__()
        self.dim = dim
        self.mlp_ratio = mlp_ratio
        self.focal_window = focal_window
        self.focal_level = focal_level
        self.use_postln = use_postln
        self.use_layerscale = use_layerscale

        # Sub-modules are created in this exact order so that parameter
        # names and initialisation order are unchanged.
        self.dw1 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim)
        self.norm1 = norm_layer(dim)
        self.modulation = FocalModulation(
            dim,
            focal_window=focal_window,
            focal_level=focal_level,
            proj_drop=drop,
            use_postln_in_modulation=use_postln_in_modulation,
            scaling_modulator=scaling_modulator,
        )
        self.dw2 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim)

        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim,
                       hidden_features=int(dim * mlp_ratio),
                       act_layer=act_layer,
                       drop=drop)

        # Spatial size of the token grid; assigned by the owning layer.
        self.H = None
        self.W = None

        # Residual scales: plain 1.0 unless layer-scale is enabled.
        if use_layerscale:
            self.gamma_1 = nn.Parameter(layerscale_value * torch.ones((dim)), requires_grad=True)
            self.gamma_2 = nn.Parameter(layerscale_value * torch.ones((dim)), requires_grad=True)
        else:
            self.gamma_1 = 1.0
            self.gamma_2 = 1.0

    def _depthwise_residual(self, tokens, conv, B, H, W, C):
        """Add ``conv(tokens)`` to ``tokens``, computed on the (B, C, H, W) grid."""
        grid = tokens.view(B, H, W, C).permute(0, 3, 1, 2).contiguous()
        grid = grid + conv(grid)
        return grid.permute(0, 2, 3, 1).contiguous().view(B, H * W, C)

    def forward(self, x):
        """Forward function.

        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature, taken from
                ``self.H`` and ``self.W``.
        """
        B, L, C = x.shape
        H, W = self.H, self.W
        assert L == H * W, "input feature has wrong size"

        pre_norm = not self.use_postln

        x = self._depthwise_residual(x, self.dw1, B, H, W, C)
        shortcut = x
        if pre_norm:
            x = self.norm1(x)

        # FM
        x = self.modulation(x.view(B, H, W, C)).view(B, H * W, C)
        x = shortcut + self.drop_path(self.gamma_1 * x)
        if not pre_norm:
            x = self.norm1(x)

        x = self._depthwise_residual(x, self.dw2, B, H, W, C)

        # FFN
        if pre_norm:
            return x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))

        x = x + self.drop_path(self.gamma_2 * self.mlp(x))
        return self.norm2(x)
Default: False """ def __init__(self, dim, depth, mlp_ratio=4., drop=0., drop_path=0., norm_layer=nn.LayerNorm, downsample=None, focal_window=9, focal_level=2, use_conv_embed=False, use_postln=False, use_postln_in_modulation=False, scaling_modulator=False, use_layerscale=False, use_checkpoint=False, use_pre_norm=False, ): super().__init__() self.depth = depth self.use_checkpoint = use_checkpoint # build blocks self.blocks = nn.ModuleList([ FocalModulationBlock( dim=dim, mlp_ratio=mlp_ratio, drop=drop, drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, focal_window=focal_window, focal_level=focal_level, use_postln=use_postln, use_postln_in_modulation=use_postln_in_modulation, scaling_modulator=scaling_modulator, use_layerscale=use_layerscale, norm_layer=norm_layer) for i in range(depth)]) # patch merging layer if downsample is not None: self.downsample = downsample( patch_size=2, in_chans=dim, embed_dim=2*dim, use_conv_embed=use_conv_embed, norm_layer=norm_layer, is_stem=False, use_pre_norm=use_pre_norm ) else: self.downsample = None def forward(self, x, H, W): """ Forward function. Args: x: Input feature, tensor size (B, H*W, C). H, W: Spatial resolution of the input feature. """ for blk in self.blocks: blk.H, blk.W = H, W if self.use_checkpoint: x = checkpoint.checkpoint(blk, x) else: x = blk(x) if self.downsample is not None: x_reshaped = x.transpose(1, 2).view(x.shape[0], x.shape[-1], H, W) x_down = self.downsample(x_reshaped) x_down = x_down.flatten(2).transpose(1, 2) Wh, Ww = (H + 1) // 2, (W + 1) // 2 return x, H, W, x_down, Wh, Ww else: return x, H, W, x, H, W # class PatchEmbed(nn.Module): # r""" Image to Patch Embedding # Args: # img_size (int): Image size. Default: 224. # patch_size (int): Patch token size. Default: 4. # in_chans (int): Number of input image channels. Default: 3. # embed_dim (int): Number of linear projection output channels. Default: 96. # norm_layer (nn.Module, optional): Normalization layer. 
Default: None # """ # def __init__(self, img_size=(224, 224), patch_size=4, in_chans=3, embed_dim=96, # use_conv_embed=False, norm_layer=None, is_stem=False, use_pre_norm=False): # super().__init__() # patch_size = to_2tuple(patch_size) # patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] # self.img_size = img_size # self.patch_size = patch_size # self.patches_resolution = patches_resolution # self.num_patches = patches_resolution[0] * patches_resolution[1] # self.in_chans = in_chans # self.embed_dim = embed_dim # self.use_pre_norm = use_pre_norm # if use_conv_embed: # # if we choose to use conv embedding, then we treat the stem and non-stem differently # if is_stem: # kernel_size = 7; padding = 3; stride = 4 # else: # kernel_size = 3; padding = 1; stride = 2 # self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding) # else: # self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) # if self.use_pre_norm: # if norm_layer is not None: # self.norm = norm_layer(in_chans) # else: # self.norm = None # else: # if norm_layer is not None: # self.norm = norm_layer(embed_dim) # else: # self.norm = None # def forward(self, x): # B, C, H, W = x.shape # # FIXME look at relaxing size constraints # assert H == self.img_size[0] and W == self.img_size[1], \ # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
# if self.use_pre_norm: # if self.norm is not None: # x = x.flatten(2).transpose(1, 2) # B Ph*Pw C # x = self.norm(x).transpose(1, 2).view(B, C, H, W) # x = self.proj(x).flatten(2).transpose(1, 2) # else: # x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C # if self.norm is not None: # x = self.norm(x) # return x # def flops(self): # Ho, Wo = self.patches_resolution # flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) # if self.norm is not None: # flops += Ho * Wo * self.embed_dim # return flops class PatchEmbed(nn.Module): """ Image to Patch Embedding Args: patch_size (int): Patch token size. Default: 4. in_chans (int): Number of input image channels. Default: 3. embed_dim (int): Number of linear projection output channels. Default: 96. norm_layer (nn.Module, optional): Normalization layer. Default: None use_conv_embed (bool): Whether use overlapped convolution for patch embedding. Default: False is_stem (bool): Is the stem block or not. """ def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None, use_conv_embed=False, is_stem=False, use_pre_norm=False): super().__init__() patch_size = to_2tuple(patch_size) self.patch_size = patch_size self.in_chans = in_chans self.embed_dim = embed_dim self.use_pre_norm = use_pre_norm if use_conv_embed: # if we choose to use conv embedding, then we treat the stem and non-stem differently if is_stem: kernel_size = 7; padding = 3; stride = 4 else: kernel_size = 3; padding = 1; stride = 2 self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding) else: self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) if self.use_pre_norm: if norm_layer is not None: self.norm = norm_layer(in_chans) else: self.norm = None else: if norm_layer is not None: self.norm = norm_layer(embed_dim) else: self.norm = None def forward(self, x): """Forward function.""" B, C, H, W = x.size() if W % self.patch_size[1] != 0: 
class FocalNet(nn.Module):
    """FocalNet backbone.

    Args:
        pretrain_img_size (int): Input image size used when training the
            pretrained model. Only stored on the instance here. Default: 1600.
        patch_size (int | tuple(int)): Patch size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        depths (Sequence[int]): Number of blocks in each stage.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
        drop_rate (float): Dropout rate.
        drop_path_rate (float): Stochastic depth rate. Default: 0.2.
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        patch_norm (bool): If True, add normalization after patch embedding. Default: True.
        out_indices (Sequence[int]): Output from which stages.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters.
        focal_levels (Sequence[int]): Number of focal levels at four stages.
        focal_windows (Sequence[int]): Focal window sizes at first focal level at four stages.
        use_pre_norms (Sequence[bool]): Per-stage ``use_pre_norm`` flag forwarded to ``BasicLayer``.
        use_conv_embed (bool): Whether to use overlapped convolution for patch embedding.
        use_postln, use_postln_in_modulation, scaling_modulator, use_layerscale:
            Forwarded unchanged to every ``BasicLayer``.
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self,
                 pretrain_img_size=1600,
                 patch_size=4,
                 in_chans=3,
                 embed_dim=96,
                 depths=[2, 2, 6, 2],
                 mlp_ratio=4.,
                 drop_rate=0.,
                 drop_path_rate=0.2,
                 norm_layer=nn.LayerNorm,
                 patch_norm=True,
                 out_indices=[0, 1, 2, 3],
                 frozen_stages=-1,
                 focal_levels=[2, 2, 2, 2],
                 focal_windows=[9, 9, 9, 9],
                 use_pre_norms=[False, False, False, False],
                 use_conv_embed=False,
                 use_postln=False,
                 use_postln_in_modulation=False,
                 scaling_modulator=False,
                 use_layerscale=False,
                 use_checkpoint=False,
                 ):
        # NOTE: the list defaults above are only read, never mutated.
        super().__init__()

        self.pretrain_img_size = pretrain_img_size
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.patch_norm = patch_norm
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None,
            use_conv_embed=use_conv_embed, is_stem=True, use_pre_norm=False)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth decay rule: one rate per block, linearly increasing
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(
                dim=int(embed_dim * 2 ** i_layer),
                depth=depths[i_layer],
                mlp_ratio=mlp_ratio,
                drop=drop_rate,
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                norm_layer=norm_layer,
                # every stage but the last downsamples with a PatchEmbed
                downsample=PatchEmbed if (i_layer < self.num_layers - 1) else None,
                focal_window=focal_windows[i_layer],
                focal_level=focal_levels[i_layer],
                use_pre_norm=use_pre_norms[i_layer],
                use_conv_embed=use_conv_embed,
                use_postln=use_postln,
                use_postln_in_modulation=use_postln_in_modulation,
                scaling_modulator=scaling_modulator,
                use_layerscale=use_layerscale,
                use_checkpoint=use_checkpoint)
            self.layers.append(layer)

        num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
        self.num_features = num_features

        # add a norm layer for each output
        for i_layer in self.out_indices:
            layer = norm_layer(num_features[i_layer])
            layer_name = f'norm{i_layer}'
            self.add_module(layer_name, layer)

        self._freeze_stages()

    def _freeze_stages(self):
        """Put frozen stages in eval mode and disable their gradients.

        ``frozen_stages >= 0`` freezes the patch embedding; ``frozen_stages >= 2``
        additionally freezes ``self.layers[0 : frozen_stages - 1]``.
        """
        if self.frozen_stages >= 0:
            self.patch_embed.eval()
            for param in self.patch_embed.parameters():
                param.requires_grad = False

        if self.frozen_stages >= 2:
            self.pos_drop.eval()
            for i in range(0, self.frozen_stages - 1):
                m = self.layers[i]
                m.eval()
                for param in m.parameters():
                    param.requires_grad = False

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """

        def _init_weights(m):
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if isinstance(m, nn.Linear) and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)

        if isinstance(pretrained, str):
            self.apply(_init_weights)
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            self.apply(_init_weights)
        else:
            raise TypeError('pretrained must be a str or None')

    @staticmethod
    def _resize_focal_kernel(pretrained, current):
        """Zero-pad or center-crop a focal conv kernel to ``current``'s spatial size."""
        if pretrained.dim() < 3 or current.dim() < 3:
            # not a conv kernel (e.g. a bias vector): nothing spatial to resize
            return pretrained
        fsize1 = pretrained.shape[2]
        fsize2 = current.shape[2]
        # NOTE: different from interpolation used in self-attention,
        # we use padding or clipping for focal conv
        if fsize1 < fsize2:
            gap = fsize2 - fsize1
            lo, hi = gap // 2, -gap // 2
            resized = torch.zeros(current.shape)
            resized[:, :, lo:hi, lo:hi] = pretrained
            return resized
        if fsize1 > fsize2:
            gap = fsize1 - fsize2
            lo, hi = gap // 2, -gap // 2
            return pretrained[:, :, lo:hi, lo:hi]
        return pretrained

    @staticmethod
    def _expand_modulation(pretrained, current):
        """Copy a smaller modulation / pre_conv tensor into a zero tensor of ``current``'s shape."""
        if pretrained.shape == current.shape:
            return pretrained

        if len(pretrained.shape) == 2:
            dim = pretrained.shape[1]
            assert current.shape[1] == dim
            L1 = pretrained.shape[0]
            L2 = current.shape[0]
            if L1 < L2:
                resized = torch.zeros(current.shape)
                # copy for linear project
                resized[:2 * dim] = pretrained[:2 * dim]
                # copy for global token gating
                resized[-1] = pretrained[-1]
                # copy for first multiple focal levels
                resized[2 * dim:2 * dim + (L1 - 2 * dim - 1)] = pretrained[2 * dim:-1]
                return resized
            if L1 > L2:
                raise NotImplementedError
        elif len(pretrained.shape) == 1:
            dim = pretrained.shape[0]
            L1 = pretrained.shape[0]
            L2 = current.shape[0]
            if L1 < L2:
                resized = torch.zeros(current.shape)
                # copy for linear project
                resized[:dim] = pretrained[:dim]
                # copy for global token gating
                resized[-1] = pretrained[-1]
                return resized
            if L1 > L2:
                raise NotImplementedError
        return pretrained

    def load_weights(self, pretrained_dict=None, pretrained_layers=[], verbose=True):
        """Load matching entries of ``pretrained_dict`` into this model.

        Args:
            pretrained_dict (dict): Checkpoint state dict (required despite the
                ``None`` default).
            pretrained_layers (Sequence[str]): Top-level module names to load.
                A first element of ``'*'`` loads every layer; an empty sequence
                loads nothing.
            verbose (bool): Currently unused.

        Raises:
            NotImplementedError: If a ``modulation.f`` / ``pre_conv`` tensor in the
                checkpoint is larger than the model's.
        """
        model_dict = self.state_dict()
        missed_dict = [k for k in model_dict.keys() if k not in pretrained_dict]
        logger.info(f'=> Missed keys {missed_dict}')
        unexpected_dict = [k for k in pretrained_dict.keys() if k not in model_dict]
        logger.info(f'=> Unexpected keys {unexpected_dict}')

        pretrained_dict = {
            k: v for k, v in pretrained_dict.items()
            if k in model_dict.keys()
        }

        # Guarded so that an empty ``pretrained_layers`` does not raise IndexError.
        load_all = bool(pretrained_layers) and pretrained_layers[0] == '*'

        need_init_state_dict = {}
        for k, v in pretrained_dict.items():
            need_init = (
                (k.split('.')[0] in pretrained_layers or load_all)
                and 'relative_position_index' not in k
                and 'attn_mask' not in k
            )
            if not need_init:
                continue

            # Parenthesized: the size check must apply to both key patterns.
            is_focal_conv = ('pool_layers' in k) or ('focal_layers' in k)
            if is_focal_conv and v.size() != model_dict[k].size():
                v = self._resize_focal_kernel(v, model_dict[k])

            if "modulation.f" in k or "pre_conv" in k:
                v = self._expand_modulation(v, model_dict[k])

            need_init_state_dict[k] = v

        self.load_state_dict(need_init_state_dict, strict=False)

    def forward(self, x):
        """Run the backbone and return a dict of ``res{i+2}`` feature maps.

        Each selected stage output is normalized and reshaped to
        ``(B, C, H, W)``. If ``out_indices`` is empty, the un-normalized output of
        the last stage is returned under ``"res5"``.
        """
        x = self.patch_embed(x)
        Wh, Ww = x.size(2), x.size(3)

        x = x.flatten(2).transpose(1, 2)
        x = self.pos_drop(x)

        outs = {}
        for i in range(self.num_layers):
            layer = self.layers[i]
            x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
            if i in self.out_indices:
                norm_layer = getattr(self, f'norm{i}')
                x_out = norm_layer(x_out)

                out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
                outs["res{}".format(i + 2)] = out

        if len(self.out_indices) == 0:
            outs["res5"] = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()

        return outs

    def train(self, mode=True):
        """Convert the model into training mode while keeping frozen stages frozen."""
        super(FocalNet, self).train(mode)
        self._freeze_stages()
@register_backbone
def get_focal_backbone(cfg):
    """Build a ``D2FocalNet`` backbone and optionally load pretrained weights.

    Args:
        cfg (dict): Full config. Reads ``cfg['MODEL']`` for the backbone and,
            when ``cfg['MODEL']['BACKBONE']['LOAD_PRETRAINED'] is True``, the
            checkpoint path ``cfg['MODEL']['BACKBONE']['PRETRAINED']``, the
            optional ``FOCAL.PRETRAINED_LAYERS`` filter (default ``['*']``) and
            ``cfg['VERBOSE']``.

    Returns:
        D2FocalNet: The constructed backbone.
    """
    backbone_cfg = cfg['MODEL']['BACKBONE']
    focal = D2FocalNet(cfg['MODEL'], 224)

    if backbone_cfg['LOAD_PRETRAINED'] is True:
        filename = backbone_cfg['PRETRAINED']
        logger.info(f'=> init from {filename}')
        with PathManager.open(filename, "rb") as f:
            # Load onto CPU so checkpoints saved from GPU also open on CPU-only
            # hosts; load_state_dict copies into the model's own tensors anyway.
            ckpt = torch.load(f, map_location='cpu')['model']
        focal.load_weights(
            ckpt,
            backbone_cfg['FOCAL'].get('PRETRAINED_LAYERS', ['*']),
            cfg['VERBOSE'],
        )

    return focal
def window_partition(x, window_size):
    """Split a feature map into non-overlapping square windows.

    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    batch, height, width, channels = x.shape
    rows = height // window_size
    cols = width // window_size
    # expose the window grid as separate axes, then bring the two grid axes together
    tiled = x.view(batch, rows, window_size, cols, window_size, channels)
    tiled = tiled.permute(0, 1, 3, 2, 4, 5).contiguous()
    return tiled.view(-1, window_size, window_size, channels)
class WindowAttention(nn.Module):
    """Window-based multi-head self-attention (W-MSA) with relative position bias.

    Works for both shifted and non-shifted windows.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(
        self,
        dim,
        window_size,
        num_heads,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        self.scale = qk_scale or (dim // num_heads) ** -0.5

        win_h, win_w = self.window_size[0], self.window_size[1]

        # learnable bias table: one row per possible relative offset, one column per head
        table_rows = (2 * window_size[0] - 1) * (2 * window_size[1] - 1)
        self.relative_position_bias_table = nn.Parameter(torch.zeros(table_rows, num_heads))

        # pair-wise relative position index for every token pair inside the window
        grid = torch.stack(torch.meshgrid([torch.arange(win_h), torch.arange(win_w)]))  # 2, Wh, Ww
        flat = torch.flatten(grid, 1)  # 2, Wh*Ww
        offsets = flat[:, :, None] - flat[:, None, :]  # 2, Wh*Ww, Wh*Ww
        offsets = offsets.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        offsets[:, :, 0] += win_h - 1  # shift to start from 0
        offsets[:, :, 1] += win_w - 1
        offsets[:, :, 0] *= 2 * win_w - 1  # row-major flattening of (dh, dw)
        self.register_buffer("relative_position_index", offsets.sum(-1))  # Wh*Ww, Wh*Ww

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=0.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """Forward function.

        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        total_windows, tokens, channels = x.shape
        heads = self.num_heads

        packed = self.qkv(x).reshape(total_windows, tokens, 3, heads, channels // heads)
        packed = packed.permute(2, 0, 3, 1, 4)
        # indexed one by one to keep torchscript happy (cannot use tensor as tuple)
        query, key, value = packed[0], packed[1], packed[2]

        scores = (query * self.scale) @ key.transpose(-2, -1)

        area = self.window_size[0] * self.window_size[1]
        bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
        bias = bias.view(area, area, -1)  # Wh*Ww, Wh*Ww, nH
        bias = bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        scores = scores + bias.unsqueeze(0)

        if mask is not None:
            # add the per-window mask, broadcasting over batch and heads
            n_masks = mask.shape[0]
            scores = scores.view(total_windows // n_masks, n_masks, heads, tokens, tokens)
            scores = scores + mask.unsqueeze(1).unsqueeze(0)
            scores = scores.view(-1, heads, tokens, tokens)

        weights = self.attn_drop(self.softmax(scores))

        out = (weights @ value).transpose(1, 2).reshape(total_windows, tokens, channels)
        out = self.proj(out)
        return self.proj_drop(out)
Default: nn.LayerNorm """ def __init__( self, dim, num_heads, window_size=7, shift_size=0, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, drop=0.0, attn_drop=0.0, drop_path=0.0, act_layer=nn.GELU, norm_layer=nn.LayerNorm, ): super().__init__() self.dim = dim self.num_heads = num_heads self.window_size = window_size self.shift_size = shift_size self.mlp_ratio = mlp_ratio assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" self.norm1 = norm_layer(dim) self.attn = WindowAttention( dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, ) self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp( in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop ) self.H = None self.W = None def forward(self, x, mask_matrix): """Forward function. Args: x: Input feature, tensor size (B, H*W, C). H, W: Spatial resolution of the input feature. mask_matrix: Attention mask for cyclic shift. 
""" B, L, C = x.shape H, W = self.H, self.W assert L == H * W, "input feature has wrong size" # HACK model will not upsampling # if min([H, W]) <= self.window_size: # if window size is larger than input resolution, we don't partition windows # self.shift_size = 0 # self.window_size = min([H,W]) shortcut = x x = self.norm1(x) x = x.view(B, H, W, C) # pad feature maps to multiples of window size pad_l = pad_t = 0 pad_r = (self.window_size - W % self.window_size) % self.window_size pad_b = (self.window_size - H % self.window_size) % self.window_size x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) _, Hp, Wp, _ = x.shape # cyclic shift if self.shift_size > 0: shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) attn_mask = mask_matrix else: shifted_x = x attn_mask = None # partition windows x_windows = window_partition( shifted_x, self.window_size ) # nW*B, window_size, window_size, C x_windows = x_windows.view( -1, self.window_size * self.window_size, C ) # nW*B, window_size*window_size, C # W-MSA/SW-MSA attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C # merge windows attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C # reverse cyclic shift if self.shift_size > 0: x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) else: x = shifted_x if pad_r > 0 or pad_b > 0: x = x[:, :H, :W, :].contiguous() x = x.view(B, H * W, C) # FFN x = shortcut + self.drop_path(x) x = x + self.drop_path(self.mlp(self.norm2(x))) return x class PatchMerging(nn.Module): """Patch Merging Layer Args: dim (int): Number of input channels. norm_layer (nn.Module, optional): Normalization layer. 
class BasicLayer(nn.Module):
    """One Swin Transformer stage: a stack of blocks plus optional downsampling.

    Args:
        dim (int): Number of feature channels
        depth (int): Depths of this stage.
        num_heads (int): Number of attention head.
        window_size (int): Local window size. Default: 7.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | list[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(
        self,
        dim,
        depth,
        num_heads,
        window_size=7,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        norm_layer=nn.LayerNorm,
        downsample=None,
        use_checkpoint=False,
    ):
        super().__init__()
        self.window_size = window_size
        self.shift_size = window_size // 2
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks: even-indexed blocks are unshifted, odd-indexed ones shifted
        stage_blocks = []
        for idx in range(depth):
            if idx % 2 == 0:
                block_shift = 0
            else:
                block_shift = window_size // 2
            if isinstance(drop_path, list):
                block_drop_path = drop_path[idx]
            else:
                block_drop_path = drop_path
            stage_blocks.append(
                SwinTransformerBlock(
                    dim=dim,
                    num_heads=num_heads,
                    window_size=window_size,
                    shift_size=block_shift,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop,
                    attn_drop=attn_drop,
                    drop_path=block_drop_path,
                    norm_layer=norm_layer,
                )
            )
        self.blocks = nn.ModuleList(stage_blocks)

        # patch merging layer
        self.downsample = None if downsample is None else downsample(dim=dim, norm_layer=norm_layer)

    def forward(self, x, H, W):
        """Forward function.

        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.

        Returns:
            tuple: ``(x, H, W, x_down, Wh, Ww)`` where ``x_down`` / ``Wh`` / ``Ww``
            describe the downsampled output, or repeat ``x`` / ``H`` / ``W`` when
            the stage has no downsample layer.
        """
        win = self.window_size
        shift = self.shift_size

        # calculate attention mask for SW-MSA on the window-aligned (padded) grid
        padded_h = int(np.ceil(H / win)) * win
        padded_w = int(np.ceil(W / win)) * win
        region_map = torch.zeros((1, padded_h, padded_w, 1), device=x.device)  # 1 Hp Wp 1

        bands = (
            slice(0, -win),
            slice(-win, -shift),
            slice(-shift, None),
        )
        # label each (row band, column band) combination with its own id
        band_pairs = [(rows, cols) for rows in bands for cols in bands]
        for region_id, (rows, cols) in enumerate(band_pairs):
            region_map[:, rows, cols, :] = region_id

        region_windows = window_partition(region_map, win)  # nW, window_size, window_size, 1
        region_windows = region_windows.view(-1, win * win)
        id_diff = region_windows.unsqueeze(1) - region_windows.unsqueeze(2)
        crosses_region = id_diff != 0
        attn_mask = (
            id_diff.masked_fill(crosses_region, float(-100.0))
            .masked_fill(~crosses_region, float(0.0))
            .type(x.dtype)
        )

        for blk in self.blocks:
            blk.H, blk.W = H, W
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x, attn_mask)
            else:
                x = blk(x, attn_mask)

        if self.downsample is None:
            return x, H, W, x, H, W

        x_down = self.downsample(x, H, W)
        Wh, Ww = (H + 1) // 2, (W + 1) // 2
        return x, H, W, x_down, Wh, Ww
class SwinTransformer(nn.Module):
    """Swin Transformer backbone.

    A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
    https://arxiv.org/pdf/2103.14030

    Args:
        pretrain_img_size (int): Input image size for training the pretrained model,
            used in absolute position embedding. Default 224.
        patch_size (int | tuple(int)): Patch size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        depths (tuple[int]): Depths of each Swin Transformer stage.
        num_heads (tuple[int]): Number of attention head of each stage.
        window_size (int): Window size. Default: 7.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
        drop_rate (float): Dropout rate.
        attn_drop_rate (float): Attention dropout rate. Default: 0.
        drop_path_rate (float): Stochastic depth rate. Default: 0.2.
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
        patch_norm (bool): If True, add normalization after patch embedding. Default: True.
        out_indices (Sequence[int]): Output from which stages.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters.
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(
        self,
        pretrain_img_size=224,
        patch_size=4,
        in_chans=3,
        embed_dim=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=7,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.2,
        norm_layer=nn.LayerNorm,
        ape=False,
        patch_norm=True,
        out_indices=(0, 1, 2, 3),
        frozen_stages=-1,
        use_checkpoint=False,
    ):
        # NOTE: the list defaults above are only read, never mutated.
        super().__init__()

        self.pretrain_img_size = pretrain_img_size
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None,
        )

        # absolute position embedding, stored as a (1, C, Hp, Wp) grid
        if self.ape:
            pretrain_img_size = to_2tuple(pretrain_img_size)
            patch_size = to_2tuple(patch_size)
            patches_resolution = [
                pretrain_img_size[0] // patch_size[0],
                pretrain_img_size[1] // patch_size[1],
            ]

            self.absolute_pos_embed = nn.Parameter(
                torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])
            )
            trunc_normal_(self.absolute_pos_embed, std=0.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth decay rule: one rate per block, linearly increasing
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(
                dim=int(embed_dim * 2 ** i_layer),
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                window_size=window_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_layer]): sum(depths[: i_layer + 1])],
                norm_layer=norm_layer,
                # every stage but the last ends with patch merging
                downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                use_checkpoint=use_checkpoint,
            )
            self.layers.append(layer)

        num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
        self.num_features = num_features

        # add a norm layer for each output
        for i_layer in out_indices:
            layer = norm_layer(num_features[i_layer])
            layer_name = f"norm{i_layer}"
            self.add_module(layer_name, layer)

        self._freeze_stages()

    def _freeze_stages(self):
        """Put frozen stages in eval mode and disable their gradients.

        ``frozen_stages >= 0`` freezes the patch embedding, ``>= 1`` also the
        absolute position embedding (when ``ape``), and ``>= 2`` additionally
        freezes ``self.layers[0 : frozen_stages - 1]``.
        """
        if self.frozen_stages >= 0:
            self.patch_embed.eval()
            for param in self.patch_embed.parameters():
                param.requires_grad = False

        if self.frozen_stages >= 1 and self.ape:
            self.absolute_pos_embed.requires_grad = False

        if self.frozen_stages >= 2:
            self.pos_drop.eval()
            for i in range(0, self.frozen_stages - 1):
                m = self.layers[i]
                m.eval()
                for param in m.parameters():
                    param.requires_grad = False

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.

        NOTE(review): this method only defines the per-module initializer and
        never applies it, so calling it currently changes nothing. That is kept
        as-is because applying it could overwrite weights a caller has already
        loaded; confirm the intended behaviour before wiring it up.

        Args:
            pretrained (str, optional): Path to pre-trained weights. Unused.
                Defaults to None.
        """

        def _init_weights(m):
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=0.02)
                if isinstance(m, nn.Linear) and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)

    @staticmethod
    def _resize_absolute_pos_embed(pretrained, current):
        """Resample a pretrained absolute position embedding to ``current``'s layout.

        Accepts either a ``(1, L, C)`` token layout (L a perfect square) or a
        ``(1, C, H, W)`` grid on both sides. Returns None when the two tensors
        cannot be reconciled (channel mismatch or unsupported shape).
        """
        if pretrained.dim() == 3:
            _, length, chans = pretrained.size()
            side = int(length ** 0.5)
            if side * side != length:
                return None
            grid = pretrained.reshape(-1, side, side, chans).permute(0, 3, 1, 2)
        elif pretrained.dim() == 4:
            grid = pretrained
        else:
            return None

        if current.dim() == 4:
            if grid.size(1) != current.size(1):
                return None
            return F.interpolate(grid, size=tuple(current.shape[-2:]), mode='bicubic')

        if current.dim() == 3:
            _, length, chans = current.size()
            side = int(length ** 0.5)
            if grid.size(1) != chans or side * side != length:
                return None
            resized = F.interpolate(grid, size=(side, side), mode='bicubic')
            return resized.permute(0, 2, 3, 1).flatten(1, 2)

        return None

    def load_weights(self, pretrained_dict=None, pretrained_layers=[], verbose=True):
        """Load matching entries of ``pretrained_dict`` into this model.

        Relative position bias tables and absolute position embeddings whose
        size differs from the model's are bicubically resized. Entries that
        cannot be reconciled (different head or channel count) are logged and
        skipped.

        Args:
            pretrained_dict (dict): Checkpoint state dict (required despite the
                ``None`` default).
            pretrained_layers (Sequence[str]): Top-level module names to load.
                A first element of ``'*'`` loads every layer; an empty sequence
                loads nothing.
            verbose (bool): Currently unused.
        """
        model_dict = self.state_dict()
        pretrained_dict = {
            k: v for k, v in pretrained_dict.items()
            if k in model_dict.keys()
        }

        # Guarded so that an empty ``pretrained_layers`` does not raise IndexError.
        load_all = bool(pretrained_layers) and pretrained_layers[0] == '*'

        need_init_state_dict = {}
        for k, v in pretrained_dict.items():
            need_init = (
                (k.split('.')[0] in pretrained_layers or load_all)
                and 'relative_position_index' not in k
                and 'attn_mask' not in k
            )
            if not need_init:
                continue

            current = model_dict[k]

            if 'relative_position_bias_table' in k and v.size() != current.size():
                L1, nH1 = v.size()
                L2, nH2 = current.size()
                if nH1 != nH2:
                    # head counts differ: nothing sensible to load for this key
                    logger.info(f"Error in loading {k}, passing")
                    continue
                if L1 != L2:
                    logger.info(
                        '=> load_pretrained: resized variant: {} to {}'
                        .format((L1, nH1), (L2, nH2))
                    )
                    S1 = int(L1 ** 0.5)
                    S2 = int(L2 ** 0.5)
                    resized = F.interpolate(
                        v.permute(1, 0).view(1, nH1, S1, S1),
                        size=(S2, S2),
                        mode='bicubic')
                    v = resized.view(nH2, L2).permute(1, 0)

            if 'absolute_pos_embed' in k and v.size() != current.size():
                resized = self._resize_absolute_pos_embed(v, current)
                if resized is None:
                    logger.info(f"Error in loading {k}, passing")
                    continue
                logger.info(
                    '=> load_pretrained: resized variant: {} to {}'
                    .format(tuple(v.size()), tuple(current.size()))
                )
                v = resized

            need_init_state_dict[k] = v

        self.load_state_dict(need_init_state_dict, strict=False)

    def forward(self, x):
        """Run the backbone and return a dict of ``res{i+2}`` feature maps.

        Each selected stage output is normalized and reshaped to
        ``(B, C, H, W)``. If ``out_indices`` is empty, the un-normalized output of
        the last stage is returned under ``"res5"``.
        """
        x = self.patch_embed(x)

        Wh, Ww = x.size(2), x.size(3)
        if self.ape:
            # interpolate the position embedding to the corresponding size
            absolute_pos_embed = F.interpolate(
                self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic"
            )
            x = (x + absolute_pos_embed).flatten(2).transpose(1, 2)  # B Wh*Ww C
        else:
            x = x.flatten(2).transpose(1, 2)
        x = self.pos_drop(x)

        outs = {}
        for i in range(self.num_layers):
            layer = self.layers[i]
            x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)

            if i in self.out_indices:
                norm_layer = getattr(self, f"norm{i}")
                x_out = norm_layer(x_out)

                out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
                outs["res{}".format(i + 2)] = out

        if len(self.out_indices) == 0:
            outs["res5"] = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()

        return outs

    def train(self, mode=True):
        """Convert the model into training mode while keeping frozen stages frozen."""
        super(SwinTransformer, self).train(mode)
        self._freeze_stages()
def output_shape(self):
    """Describe the exposed feature maps.

    Returns:
        dict[str, ShapeSpec]: one entry per name that appears both in
        ``self._out_features`` and in ``self._out_feature_strides``, carrying
        that feature's channel count and stride.
    """
    known = set(self._out_feature_strides.keys())
    selected = known & set(self._out_features)

    shapes = {}
    for feature in selected:
        shapes[feature] = ShapeSpec(
            channels=self._out_feature_channels[feature],
            stride=self._out_feature_strides[feature],
        )
    return shapes
class Mlp(nn.Module):
    """Two-layer perceptron: linear, activation, dropout, linear, dropout.

    Args:
        in_features (int): Size of the input's last dimension.
        hidden_features (int, optional): Width of the hidden layer; falls back
            to ``in_features`` when falsy.
        out_features (int, optional): Size of the output's last dimension;
            falls back to ``in_features`` when falsy.
        act_layer: Zero-argument factory for the activation module.
        drop (float): Dropout probability, shared by both dropout applications.
    """

    def __init__(
        self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0
    ):
        super().__init__()
        hidden_width = hidden_features if hidden_features else in_features
        output_width = out_features if out_features else in_features

        self.fc1 = nn.Linear(in_features, hidden_width)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_width, output_width)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Apply the perceptron to the last dimension of ``x``."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
def window_reverse(windows, window_size, H, W):
    """Stitch windows back into full feature maps (inverse of window partitioning).

    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    windows_per_image = H * W / window_size / window_size
    B = int(windows.shape[0] / windows_per_image)

    rows, cols = H // window_size, W // window_size
    tiled = windows.view(B, rows, cols, window_size, window_size, -1)
    # Interleave (window row, in-window row) and (window col, in-window col).
    tiled = tiled.permute(0, 1, 3, 2, 4, 5).contiguous()
    return tiled.view(B, H, W, -1)
def forward(self, x, mask=None):
    """Windowed multi-head self-attention with relative position bias.

    Args:
        x: input features with shape of (num_windows*B, N, C)
        mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None

    Returns:
        Tensor of shape (num_windows*B, N, C).
    """
    B_, N, C = x.shape
    heads = self.num_heads
    tokens_per_window = self.window_size[0] * self.window_size[1]

    # (3, B_, heads, N, C // heads): query, key and value stacked on dim 0.
    qkv = self.qkv(x).reshape(B_, N, 3, heads, C // heads).permute(2, 0, 3, 1, 4)
    query, key, value = qkv[0], qkv[1], qkv[2]

    scores = (query * self.scale) @ key.transpose(-2, -1)

    # Gather the learned bias for every (token, token) pair -> nH, Wh*Ww, Wh*Ww
    bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
    bias = bias.view(tokens_per_window, tokens_per_window, -1)
    bias = bias.permute(2, 0, 1).contiguous()
    scores = scores + bias.unsqueeze(0)

    if mask is not None:
        nW = mask.shape[0]
        # The same per-window mask is broadcast over batch and heads.
        scores = scores.view(B_ // nW, nW, heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
        scores = scores.view(-1, heads, N, N)

    weights = self.attn_drop(self.softmax(scores))

    out = (weights @ value).transpose(1, 2).reshape(B_, N, C)
    return self.proj_drop(self.proj(out))
def __init__(
    self,
    dim,
    num_heads,
    window_size=7,
    shift_size=0,
    mlp_ratio=4.0,
    qkv_bias=True,
    qk_scale=None,
    drop=0.0,
    attn_drop=0.0,
    drop_path=0.0,
    act_layer=nn.GELU,
    norm_layer=nn.LayerNorm,
):
    """Build one Swin block: norm, window attention, drop-path, norm and MLP.

    ``shift_size`` must satisfy ``0 <= shift_size < window_size``. ``H`` and
    ``W`` start as ``None``; ``forward`` reads them as the spatial resolution
    of its input.
    """
    super().__init__()
    self.dim, self.num_heads = dim, num_heads
    self.window_size, self.shift_size = window_size, shift_size
    self.mlp_ratio = mlp_ratio
    assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

    # Attention branch.
    self.norm1 = norm_layer(dim)
    self.attn = WindowAttention(
        dim,
        window_size=to_2tuple(self.window_size),
        num_heads=num_heads,
        qkv_bias=qkv_bias,
        qk_scale=qk_scale,
        attn_drop=attn_drop,
        proj_drop=drop,
    )

    # Stochastic depth is only instantiated when it can actually drop a path.
    if drop_path > 0.0:
        self.drop_path = DropPath(drop_path)
    else:
        self.drop_path = nn.Identity()

    # Feed-forward branch.
    self.norm2 = norm_layer(dim)
    hidden_width = int(dim * mlp_ratio)
    self.mlp = Mlp(
        in_features=dim, hidden_features=hidden_width, act_layer=act_layer, drop=drop
    )

    self.H = None
    self.W = None
class PatchMerging(nn.Module):
    """Patch Merging Layer

    Concatenates each 2x2 neighbourhood of tokens along the channel axis,
    normalises the result and projects ``4 * dim`` channels down to ``2 * dim``.

    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        merged_dim = 4 * dim
        self.dim = dim
        self.reduction = nn.Linear(merged_dim, 2 * dim, bias=False)
        self.norm = norm_layer(merged_dim)

    def forward(self, x, H, W):
        """Forward function.

        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.
        """
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        grid = x.view(B, H, W, C)

        # Pad odd heights/widths by one row/column so the 2x2 grouping is exact.
        if H % 2 == 1 or W % 2 == 1:
            grid = F.pad(grid, (0, 0, 0, W % 2, 0, H % 2))

        # Order: (row 0, col 0), (row 1, col 0), (row 0, col 1), (row 1, col 1).
        corners = [
            grid[:, row_off::2, col_off::2, :]
            for col_off in (0, 1)
            for row_off in (0, 1)
        ]
        merged = torch.cat(corners, -1).view(B, -1, 4 * C)  # B H/2*W/2 4*C

        return self.reduction(self.norm(merged))
def forward(self, x, H, W):
    """Forward function.

    Args:
        x: Input feature, tensor size (B, H*W, C).
        H, W: Spatial resolution of the input feature.

    Returns:
        tuple: ``(x, H, W, x_down, Wh, Ww)``. ``x`` is the stage output at the
        input resolution. ``x_down, Wh, Ww`` are the downsampled output and its
        resolution, or ``x, H, W`` again when the stage has no downsample layer.
    """
    ws, shift = self.window_size, self.shift_size

    # Resolution after padding up to a multiple of the window size.
    Hp = int(np.ceil(H / ws)) * ws
    Wp = int(np.ceil(W / ws)) * ws

    # Label the 3x3 regions created by the cyclic shift with distinct ids.
    region_ids = torch.zeros((1, Hp, Wp, 1), device=x.device)  # 1 Hp Wp 1
    bands = (slice(0, -ws), slice(-ws, -shift), slice(-shift, None))
    band_pairs = ((rows, cols) for rows in bands for cols in bands)
    for region, (rows, cols) in enumerate(band_pairs):
        region_ids[:, rows, cols, :] = region

    # Token pairs from different regions get -100, pairs from the same region 0.
    per_window = window_partition(region_ids, ws).view(-1, ws * ws)
    id_diff = per_window.unsqueeze(1) - per_window.unsqueeze(2)
    attn_mask = id_diff.masked_fill(id_diff != 0, float(-100.0)).masked_fill(
        id_diff == 0, float(0.0)
    )

    for blk in self.blocks:
        blk.H, blk.W = H, W
        if self.use_checkpoint:
            x = checkpoint.checkpoint(blk, x, attn_mask)
        else:
            x = blk(x, attn_mask)

    if self.downsample is None:
        return x, H, W, x, H, W

    x_down = self.downsample(x, H, W)
    return x, H, W, x_down, (H + 1) // 2, (W + 1) // 2
def forward(self, x):
    """Forward function.

    Zero-pads the right and bottom edges so width and height become multiples
    of the patch size, applies the patch projection, and optionally normalises
    the resulting tokens. The output keeps the ``B C Wh Ww`` layout.
    """
    _, _, height, width = x.size()
    patch_h, patch_w = self.patch_size[0], self.patch_size[1]

    extra_w = width % patch_w
    if extra_w != 0:
        x = F.pad(x, (0, patch_w - extra_w))
    extra_h = height % patch_h
    if extra_h != 0:
        x = F.pad(x, (0, 0, 0, patch_h - extra_h))

    x = self.proj(x)  # B C Wh Ww
    if self.norm is None:
        return x

    # Normalise over channels in token layout, then restore the spatial layout.
    grid_h, grid_w = x.size(2), x.size(3)
    tokens = self.norm(x.flatten(2).transpose(1, 2))
    return tokens.transpose(1, 2).view(-1, self.embed_dim, grid_h, grid_w)
def __init__(
    self,
    pretrain_img_size=224,
    patch_size=4,
    in_chans=3,
    embed_dim=96,
    depths=[2, 2, 6, 2],
    num_heads=[3, 6, 12, 24],
    window_size=7,
    mlp_ratio=4.0,
    qkv_bias=True,
    qk_scale=None,
    drop_rate=0.0,
    attn_drop_rate=0.0,
    drop_path_rate=0.2,
    norm_layer=nn.LayerNorm,
    ape=False,
    patch_norm=True,
    out_indices=(0, 1, 2, 3),
    frozen_stages=-1,
    use_checkpoint=False,
):
    """Assemble the backbone.

    Builds, in order: the patch embedding, the optional absolute position
    embedding, the stages, and one ``norm{i}`` layer per entry of
    ``out_indices``. Stage ``i`` has ``embed_dim * 2**i`` channels and every
    stage but the last ends with a ``PatchMerging`` downsample. Finishes by
    applying ``frozen_stages``.
    """
    super().__init__()

    self.pretrain_img_size = pretrain_img_size
    self.num_layers = len(depths)
    self.embed_dim = embed_dim
    self.ape = ape
    self.patch_norm = patch_norm
    self.out_indices = out_indices
    self.frozen_stages = frozen_stages

    # split image into non-overlapping patches
    self.patch_embed = PatchEmbed(
        patch_size=patch_size,
        in_chans=in_chans,
        embed_dim=embed_dim,
        norm_layer=norm_layer if patch_norm else None,
    )

    # absolute position embedding, sized for the pretraining resolution
    if ape:
        img_hw = to_2tuple(pretrain_img_size)
        patch_hw = to_2tuple(patch_size)
        grid_h = img_hw[0] // patch_hw[0]
        grid_w = img_hw[1] // patch_hw[1]
        self.absolute_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, grid_h, grid_w))
        trunc_normal_(self.absolute_pos_embed, std=0.02)

    self.pos_drop = nn.Dropout(p=drop_rate)

    # stochastic depth decay rule: one linearly increasing rate per block
    dpr = [rate.item() for rate in torch.linspace(0, drop_path_rate, sum(depths))]

    stage_dims = [int(embed_dim * 2 ** stage) for stage in range(self.num_layers)]
    last_stage = self.num_layers - 1

    # build layers
    self.layers = nn.ModuleList()
    first_block = 0  # index into dpr of this stage's first block
    for stage in range(self.num_layers):
        depth = depths[stage]
        self.layers.append(
            BasicLayer(
                dim=stage_dims[stage],
                depth=depth,
                num_heads=num_heads[stage],
                window_size=window_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[first_block : first_block + depth],
                norm_layer=norm_layer,
                downsample=PatchMerging if stage < last_stage else None,
                use_checkpoint=use_checkpoint,
            )
        )
        first_block += depth

    self.num_features = stage_dims

    # add a norm layer for each output
    for stage in out_indices:
        self.add_module(f"norm{stage}", norm_layer(stage_dims[stage]))

    self._freeze_stages()
@BACKBONE_REGISTRY.register()
class D2SwinTransformer(SwinTransformer, Backbone):
    """Swin Transformer registered as a backbone and configured from ``cfg.MODEL.SWIN``."""

    def __init__(self, cfg, input_shape):
        swin = cfg.MODEL.SWIN

        super().__init__(
            pretrain_img_size=swin.PRETRAIN_IMG_SIZE,
            patch_size=swin.PATCH_SIZE,
            in_chans=3,
            embed_dim=swin.EMBED_DIM,
            depths=swin.DEPTHS,
            num_heads=swin.NUM_HEADS,
            window_size=swin.WINDOW_SIZE,
            mlp_ratio=swin.MLP_RATIO,
            qkv_bias=swin.QKV_BIAS,
            qk_scale=swin.QK_SCALE,
            drop_rate=swin.DROP_RATE,
            attn_drop_rate=swin.ATTN_DROP_RATE,
            drop_path_rate=swin.DROP_PATH_RATE,
            norm_layer=nn.LayerNorm,
            ape=swin.APE,
            patch_norm=swin.PATCH_NORM,
            use_checkpoint=swin.USE_CHECKPOINT,
        )

        self._out_features = cfg.MODEL.SWIN.OUT_FEATURES

        self._out_feature_strides = {
            "res2": 4,
            "res3": 8,
            "res4": 16,
            "res5": 32,
        }
        # res2..res5 map onto stages 0..3.
        self._out_feature_channels = {
            f"res{stage + 2}": self.num_features[stage] for stage in range(4)
        }

    def forward(self, x):
        """
        Args:
            x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
        Returns:
            dict[str->Tensor]: names and the corresponding features
        """
        assert (
            x.dim() == 4
        ), f"SwinTransformer takes an input of shape (N, C, H, W). Got {x.shape} instead!"
        features = super().forward(x)
        # Expose only the feature maps requested in the config.
        return {
            name: feature
            for name, feature in features.items()
            if name in self._out_features
        }

    def output_shape(self):
        """Return a ``ShapeSpec`` (channels and stride) for every configured output feature."""
        shapes = {}
        for name in self._out_features:
            shapes[name] = ShapeSpec(
                channels=self._out_feature_channels[name],
                stride=self._out_feature_strides[name],
            )
        return shapes

    @property
    def size_divisibility(self):
        return 32
def build_decoder(config, *args, **kwargs):
    """Instantiate the decoder named by ``config['MODEL']['DECODER']['NAME']``.

    Args:
        config: Configuration mapping; also passed to the decoder entrypoint as
            its first argument.
        *args, **kwargs: Forwarded unchanged to the decoder entrypoint.

    Returns:
        Whatever the registered entrypoint for that name returns.

    Raises:
        ValueError: If no decoder is registered under the configured name.
    """
    model_name = config['MODEL']['DECODER']['NAME']

    if not is_model(model_name):
        raise ValueError(f'Unknown model: {model_name}')

    return model_entrypoints(model_name)(config, *args, **kwargs)
activation: str = 'relu', nhead: int = 8, dec_n_points: int = 4, return_intermediate_dec: bool = True, query_dim: int = 4, dec_layer_share: bool = False, semantic_ce_loss: bool = False, num_mask_tokens: int = 3, ): """ NOTE: this interface is experimental. Args: in_channels: channels of the input features mask_classification: whether to add mask classifier or not num_classes: number of classes hidden_dim: Transformer feature dimension num_queries: number of queries nheads: number of heads dim_feedforward: feature dimension in feedforward network enc_layers: number of Transformer encoder layers dec_layers: number of Transformer decoder layers pre_norm: whether to use pre-LayerNorm or not mask_dim: mask feature dimension enforce_input_project: add input project 1x1 conv even if input channels and hidden dim is identical d_model: transformer dimension dropout: dropout rate activation: activation function nhead: num heads in multi-head attention dec_n_points: number of sampling points in decoder return_intermediate_dec: return the intermediate results of decoder query_dim: 4 -> (x, y, w, h) dec_layer_share: whether to share each decoder layer semantic_ce_loss: use ce loss for semantic segmentation """ super().__init__() assert mask_classification, "Only support mask classification model" self.mask_classification = mask_classification self.num_feature_levels = total_num_feature_levels self.initial_pred = initial_pred # define Transformer decoder here self.dn=dn self.learn_tgt = learn_tgt self.noise_scale=noise_scale self.dn_num=dn_num self.num_heads = nheads self.num_layers = dec_layers self.two_stage=two_stage self.initialize_box_type = initialize_box_type self.total_num_feature_levels = total_num_feature_levels self.num_queries = num_queries self.semantic_ce_loss = semantic_ce_loss interactive_only = True # learnable query features if num_queries>0 and not interactive_only: if not two_stage or self.learn_tgt: self.query_feat = nn.Embedding(num_queries, hidden_dim) if 
not two_stage and initialize_box_type == 'no': self.query_embed = nn.Embedding(num_queries, 4) # if two_stage: # self.enc_output = nn.Linear(hidden_dim, hidden_dim) # self.enc_output_norm = nn.LayerNorm(hidden_dim) self.input_proj = nn.ModuleList() for _ in range(self.num_feature_levels): if in_channels != hidden_dim or enforce_input_project: self.input_proj.append(Conv2d(in_channels, hidden_dim, kernel_size=1)) weight_init.c2_xavier_fill(self.input_proj[-1]) else: self.input_proj.append(nn.Sequential()) self.num_classes=num_classes # output FFNs assert self.mask_classification, "why not class embedding?" # self.label_enc=nn.Embedding(505, hidden_dim) # this is a hack for o365+coco (365+133=498) self.dim_proj = dim_proj self.lang_encoder = lang_encoder # if lang_encoder is not None: self.lang_mapper = nn.Parameter(torch.empty(dim_proj, hidden_dim)) trunc_normal_(self.lang_mapper, std=.02) self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3) # init decoder self.decoder_norm = decoder_norm = nn.LayerNorm(hidden_dim) decoder_layer = DeformableTransformerDecoderLayer(hidden_dim, dim_feedforward, dropout, activation, self.num_feature_levels, nhead, dec_n_points) self.decoder = TransformerDecoder(decoder_layer, self.num_layers, decoder_norm, return_intermediate=return_intermediate_dec, d_model=hidden_dim, query_dim=query_dim, num_feature_levels=self.num_feature_levels, dec_layer_share=dec_layer_share, ) self.hidden_dim = hidden_dim self._bbox_embed = _bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3) nn.init.constant_(_bbox_embed.layers[-1].weight.data, 0) nn.init.constant_(_bbox_embed.layers[-1].bias.data, 0) box_embed_layerlist = [_bbox_embed for i in range(self.num_layers)] # share box prediction each layer self.bbox_embed = nn.ModuleList(box_embed_layerlist) self.decoder.bbox_embed = self.bbox_embed # whole category classification self.class_embed = nn.Parameter(torch.empty(hidden_dim, dim_proj)) trunc_normal_(self.class_embed, std=.02) # part category 
@classmethod
def from_config(cls, cfg, in_channels, lang_encoder, mask_classification, extra):
    """
    Collect the constructor keyword arguments from the project config.

    The keys of the returned dict mirror the constructor's parameter names.

    Args:
        cfg: nested mapping; reads ``cfg['MODEL']['ENCODER']``,
            ``cfg['MODEL']['DECODER']`` and ``cfg['MODEL']['DIM_PROJ']``.
        in_channels: forwarded unchanged as ``in_channels``.
        lang_encoder: forwarded unchanged as ``lang_encoder``.
        mask_classification: forwarded unchanged as ``mask_classification``.
        extra: accepted for signature compatibility; not read here.

    Returns:
        dict of keyword arguments.

    Raises:
        KeyError: if a required config entry is missing.
    """
    ret = {}
    ret["in_channels"] = in_channels
    ret["lang_encoder"] = lang_encoder
    ret["mask_classification"] = mask_classification

    enc_cfg = cfg['MODEL']['ENCODER']
    dec_cfg = cfg['MODEL']['DECODER']

    ret["num_classes"] = enc_cfg['NUM_CLASSES']
    ret["hidden_dim"] = dec_cfg['HIDDEN_DIM']
    ret["dim_proj"] = cfg['MODEL']['DIM_PROJ']
    ret["num_queries"] = dec_cfg['NUM_OBJECT_QUERIES']
    # Number of mask tokens per prompt. This key used to be written twice, so
    # NUM_MASK_TOKENS was always overwritten and silently ignored.
    # NUM_INTERACTIVE_TOKENS keeps precedence (unchanged result whenever it is
    # set); NUM_MASK_TOKENS is now honoured as a fallback; the default stays 3.
    ret["num_mask_tokens"] = dec_cfg.get(
        'NUM_INTERACTIVE_TOKENS', dec_cfg.get('NUM_MASK_TOKENS', 3))
    # Transformer parameters:
    ret["nheads"] = dec_cfg['NHEADS']
    ret["dim_feedforward"] = dec_cfg['DIM_FEEDFORWARD']
    ret["dec_layers"] = dec_cfg['DEC_LAYERS']
    ret["enforce_input_project"] = dec_cfg['ENFORCE_INPUT_PROJ']
    ret["mask_dim"] = enc_cfg['MASK_DIM']
    ret["two_stage"] = dec_cfg['TWO_STAGE']
    ret["initialize_box_type"] = dec_cfg['INITIALIZE_BOX_TYPE']  # ['no', 'bitmask', 'mask2box']
    # Denoising (DN) settings.
    ret["dn"] = dec_cfg['DN']
    ret["noise_scale"] = dec_cfg['DN_NOISE_SCALE']
    ret["dn_num"] = dec_cfg['DN_NUM']
    ret["initial_pred"] = dec_cfg['INITIAL_PRED']
    ret["learn_tgt"] = dec_cfg['LEARN_TGT']
    ret["total_num_feature_levels"] = dec_cfg['TOTAL_NUM_FEATURE_LEVELS']
    # CE loss for semantic segmentation is enabled only for semantic-only
    # testing (SEMANTIC_ON set, PANOPTIC_ON unset) with SEMANTIC_CE_LOSS set.
    ret["semantic_ce_loss"] = dec_cfg['TEST']['SEMANTIC_ON'] and dec_cfg['SEMANTIC_CE_LOSS'] and not dec_cfg['TEST']['PANOPTIC_ON']
    return ret
# if task == 'det': # labels = labels # o365 start from 133 class boxes = torch.cat([t['boxes'] for t in targets]) batch_idx = torch.cat([torch.full_like(t['labels'].long(), i) for i, t in enumerate(targets)]) # known known_indice = torch.nonzero(unmask_label + unmask_bbox) known_indice = known_indice.view(-1) # noise known_indice = known_indice.repeat(scalar, 1).view(-1) known_labels = labels.repeat(scalar, 1).view(-1) known_bid = batch_idx.repeat(scalar, 1).view(-1) known_bboxs = boxes.repeat(scalar, 1) known_labels_expaned = known_labels.clone() known_bbox_expand = known_bboxs.clone() if noise_scale > 0: diff = torch.zeros_like(known_bbox_expand) diff[:, :2] = known_bbox_expand[:, 2:] / 2 diff[:, 2:] = known_bbox_expand[:, 2:] known_bbox_expand += torch.mul((torch.rand_like(known_bbox_expand) * 2 - 1.0), diff).cuda() * noise_scale known_bbox_expand = known_bbox_expand.clamp(min=0.0, max=1.0) m = known_labels_expaned.long().to('cuda') # import ipdb; ipdb.set_trace() input_label_embed = torch.gather(self.lang_encoder.default_text_embeddings, 0, m[:, None].repeat(1, self.dim_proj)) @ self.lang_mapper input_bbox_embed = inverse_sigmoid(known_bbox_expand) single_pad = int(max(known_num)) pad_size = int(single_pad * scalar) padding_label = input_label_embed.new_zeros(pad_size, self.hidden_dim) padding_bbox = input_bbox_embed.new_zeros(pad_size, 4) if not refpoint_emb is None: input_query_label = torch.cat([padding_label, tgt], dim=0).repeat(batch_size, 1, 1) input_query_bbox = torch.cat([padding_bbox, refpoint_emb], dim=0).repeat(batch_size, 1, 1) else: input_query_label = padding_label.repeat(batch_size, 1, 1) input_query_bbox = padding_bbox.repeat(batch_size, 1, 1) # map map_known_indice = input_label_embed.new_tensor([]) if len(known_num): map_known_indice = torch.cat( [input_label_embed.new_tensor(range(num)) for num in known_num]) # [1,2, 1,2,3] map_known_indice = torch.cat([map_known_indice + single_pad * i for i in range(scalar)]).long() if len(known_bid): 
def prepare_for_dn_o3(self, targets, tgt, refpoint_emb, batch_size):
    """
    Build denoising (DN) queries from ground-truth boxes.

    Modified from DN-DETR, see
    https://github.com/IDEA-Research/DN-DETR/blob/main/models/dn_dab_deformable_detr/dn_components.py

    In this variant a single DN group is used (``scalar`` is forced to 1 when
    any ground truth exists) and every DN query gets the same content
    embedding, ``self.pb_embedding`` at index 1; the label values themselves
    are only carried along in ``mask_dict``.

    :param targets: list of per-image dicts; ``'labels'`` and ``'boxes'`` are read.
        Boxes are presumably normalized (cx, cy, w, h) -- TODO confirm against the data mapper.
    :param tgt: content queries of the matching part; only used when ``refpoint_emb`` is not None
    :param refpoint_emb: positional anchor queries of the matching part, or None
    :param batch_size: bs
    :return: ``(input_query_label, input_query_bbox, attn_mask, mask_dict)``; in training
        with no ground truth all four are None; outside training ``attn_mask`` and
        ``mask_dict`` are always None.
    """
    if self.training:
        scalar, noise_scale = self.dn_num, self.noise_scale

        # One all-ones tensor per image, shaped like its labels: every GT is "known".
        known = [(torch.ones_like(t['labels'])).cuda() for t in targets]
        know_idx = [torch.nonzero(t) for t in known]
        # Number of GT objects per image.
        known_num = [sum(k) for k in known]

        # Use a fixed number of DN groups: exactly one when any GT exists.
        if max(known_num) > 0:
            scalar = 1
        else:
            scalar = 0
        if scalar == 0:
            # Nothing to denoise in this batch.
            input_query_label = None
            input_query_bbox = None
            attn_mask = None
            mask_dict = None
            return input_query_label, input_query_bbox, attn_mask, mask_dict

        # Can be modified to selectively denoise some labels or boxes; also known-label prediction.
        unmask_bbox = unmask_label = torch.cat(known)
        labels = torch.cat([t['labels'] for t in targets])
        boxes = torch.cat([t['boxes'] for t in targets])
        # Image index of every GT in the flattened batch.
        batch_idx = torch.cat([torch.full_like(t['labels'].long(), i) for i, t in enumerate(targets)])
        # Flat indices of the known GTs (all of them, since both operands are all ones).
        known_indice = torch.nonzero(unmask_label + unmask_bbox)
        known_indice = known_indice.view(-1)

        # Replicate once per DN group.
        known_indice = known_indice.repeat(scalar, 1).view(-1)
        known_labels = labels.repeat(scalar, 1).view(-1)
        known_bid = batch_idx.repeat(scalar, 1).view(-1)
        known_bboxs = boxes.repeat(scalar, 1)
        known_labels_expaned = known_labels.clone()
        known_bbox_expand = known_bboxs.clone()

        if noise_scale > 0:
            # Jitter range: columns 0-1 move by up to half of columns 2-3,
            # columns 2-3 by up to their own value, all scaled by noise_scale.
            diff = torch.zeros_like(known_bbox_expand)
            diff[:, :2] = known_bbox_expand[:, 2:] / 2
            diff[:, 2:] = known_bbox_expand[:, 2:]
            known_bbox_expand += torch.mul((torch.rand_like(known_bbox_expand) * 2 - 1.0),
                                           diff).cuda() * noise_scale
            known_bbox_expand = known_bbox_expand.clamp(min=0.0, max=1.0)

        m = known_labels_expaned.long().to('cuda')
        # Content embedding: index 1 of pb_embedding for every query; m only supplies shape/device.
        input_label_embed = self.pb_embedding(torch.ones_like(m))
        input_bbox_embed = inverse_sigmoid(known_bbox_expand)
        # Every image is padded to the largest GT count in the batch.
        single_pad = int(max(known_num))
        pad_size = int(single_pad * scalar)

        padding_label = input_label_embed.new_zeros(pad_size, self.hidden_dim)
        padding_bbox = input_bbox_embed.new_zeros(pad_size, 4)

        if not refpoint_emb is None:
            # DN slots first, then the matching queries.
            input_query_label = torch.cat([padding_label, tgt], dim=0).repeat(batch_size, 1, 1)
            input_query_bbox = torch.cat([padding_bbox, refpoint_emb], dim=0).repeat(batch_size, 1, 1)
        else:
            input_query_label = padding_label.repeat(batch_size, 1, 1)
            input_query_bbox = padding_bbox.repeat(batch_size, 1, 1)

        # Slot index of every GT inside its image's DN padding, e.g. [0,1, 0,1,2].
        map_known_indice = input_label_embed.new_tensor([])
        if len(known_num):
            map_known_indice = torch.cat(
                [input_label_embed.new_tensor(range(num)) for num in known_num])
            map_known_indice = torch.cat([map_known_indice + single_pad * i for i in range(scalar)]).long()
        if len(known_bid):
            # Scatter the DN embeddings into (image, slot) positions.
            input_query_label[(known_bid.long(), map_known_indice)] = input_label_embed
            input_query_bbox[(known_bid.long(), map_known_indice)] = input_bbox_embed

        # NOTE(review): when refpoint_emb is None the queries built above have only
        # pad_size rows, so this mask matches them only if self.num_queries == 0 -- confirm.
        tgt_size = pad_size + self.num_queries
        # All-False boolean mask; True marks a blocked (query, key) pair.
        attn_mask = input_label_embed.new_ones(tgt_size, tgt_size) < 0
        # Matching queries cannot see the reconstruction (DN) part.
        attn_mask[pad_size:, :pad_size] = True
        # DN groups cannot see each other.
        for i in range(scalar):
            if i == 0:
                attn_mask[single_pad * i:single_pad * (i + 1), single_pad * (i + 1):pad_size] = True
            if i == scalar - 1:
                attn_mask[single_pad * i:single_pad * (i + 1), :single_pad * i] = True
            else:
                attn_mask[single_pad * i:single_pad * (i + 1), single_pad * (i + 1):pad_size] = True
                attn_mask[single_pad * i:single_pad * (i + 1), :single_pad * i] = True
        # Bookkeeping the caller needs to split DN outputs from matching outputs.
        mask_dict = {
            'known_indice': torch.as_tensor(known_indice).long(),
            'batch_idx': torch.as_tensor(batch_idx).long(),
            'map_known_indice': torch.as_tensor(map_known_indice).long(),
            'known_lbs_bboxes': (known_labels, known_bboxs),
            'know_idx': know_idx,
            'pad_size': pad_size,
            'scalar': scalar,
        }
    else:
        # Not training: no DN part, just broadcast the matching queries (if any).
        if not refpoint_emb is None:
            input_query_label = tgt.repeat(batch_size, 1, 1)
            input_query_bbox = refpoint_emb.repeat(batch_size, 1, 1)
        else:
            input_query_label = None
            input_query_bbox = None
        attn_mask = None
        mask_dict = None

    # The two assignments below are no-ops kept from the upstream code.
    if not input_query_bbox is None:
        input_query_label = input_query_label
        input_query_bbox = input_query_bbox

    return input_query_label, input_query_bbox, attn_mask, mask_dict
else: scalar=0 if scalar==0: input_query_label = None input_query_bbox = None attn_mask = None mask_dict = None # return input_query_label, input_query_bbox, attn_mask, mask_dict pb_labels = torch.stack([t['pb'] for t in targets]) # FIXME this is for future content-based interaction; pool content features as label embedding labels = torch.zeros_like(pb_labels).long() boxes = torch.stack([t['boxes_dn'] for t in targets]) box_start = [t['box_start'] for t in targets] known_labels = labels known_pb_labels = pb_labels known_bboxs = boxes known_labels_expaned = known_labels.clone() known_pb_labels_expaned = known_pb_labels.clone() known_bbox_expand = known_bboxs.clone() ############ noise on the label # if noise_scale > 0: # p = torch.rand_like(known_labels_expaned.float()) # chosen_indice = torch.nonzero(p < (noise_scale * 0.5)).view(-1) # half of bbox prob # new_label = torch.randint_like(chosen_indice, 0, self.num_classes) # randomly put a new one here # known_labels_expaned.scatter_(0, chosen_indice, new_label) if noise_scale > 0 and self.training: diff = torch.zeros_like(known_bbox_expand) diff[:, :, :2] = known_bbox_expand[:, :, 2:] / 2 diff[:, :, 2:] = known_bbox_expand[:, :, 2:] sc = 0.01 for i, st in enumerate(box_start): diff[i, :st] = diff[i, :st] * sc known_bbox_expand += torch.mul((torch.rand_like(known_bbox_expand) * 2 - 1.0), diff).cuda() * noise_scale # known_bbox_expand+=(torch.rand_like(known_bbox_expand)*2-1.0)*torch.tensor([[1,1,0.1,0.1]]).cuda()*noise_scale known_bbox_expand = known_bbox_expand.clamp(min=0.0, max=1.0) m = known_labels_expaned.long().to('cuda') m_pb = known_pb_labels_expaned.long().to('cuda') input_label_embed = self.label_enc(m)+self.pb_embedding(m_pb) input_bbox_embed = inverse_sigmoid(known_bbox_expand) input_label_embed = input_label_embed.repeat_interleave(self.num_all_tokens,1) + self.mask_tokens.weight.unsqueeze(0).repeat(input_label_embed.shape[0], input_label_embed.shape[1], 1) input_bbox_embed = 
input_bbox_embed.repeat_interleave(self.num_all_tokens,1) single_pad = self.num_all_tokens # NOTE scalar is modified to 100, each click cannot see each other scalar = int(input_label_embed.shape[1]/self.num_all_tokens) pad_size = input_label_embed.shape[1] if input_label_embed.shape[1]>0: input_query_label = input_label_embed input_query_bbox = input_bbox_embed tgt_size = pad_size attn_mask = torch.ones(tgt_size, tgt_size).to('cuda') < 0 # match query cannot see the reconstruct attn_mask[pad_size:, :pad_size] = True # reconstruct cannot see each other for i in range(scalar): if i == 0: attn_mask[single_pad * i:single_pad * (i + 1), single_pad * (i + 1):pad_size] = True if i == scalar - 1: attn_mask[single_pad * i:single_pad * (i + 1), :single_pad * i] = True else: attn_mask[single_pad * i:single_pad * (i + 1), single_pad * (i + 1):pad_size] = True attn_mask[single_pad * i:single_pad * (i + 1), :single_pad * i] = True mask_dict = { 'known_lbs_bboxes': (known_labels, known_bboxs), # 'know_idx': know_idx, 'pad_size': pad_size, 'scalar': scalar, } # 100*batch*256 if not input_query_bbox is None: input_query_label = input_query_label input_query_bbox = input_query_bbox return input_query_label,input_query_bbox,attn_mask,mask_dict def prepare_for_dn_mo_infer(self, targets, tgt, refpoint_emb, batch_size): known = [(torch.ones_like(t['points'])).cuda() for t in targets] known_num = [k.sum() for k in known] assert max(known_num)>0 pb_labels = torch.stack([t['pb'] for t in targets]) # FIXME this is for future content-based interaction; pool content features as label embedding labels = torch.zeros_like(pb_labels).long() boxes = torch.stack([t['points'] for t in targets]) known_labels = labels known_pb_labels = pb_labels known_bboxs = boxes known_labels_expaned = known_labels.clone() known_pb_labels_expaned = known_pb_labels.clone() known_bbox_expand = known_bboxs.clone() m = known_labels_expaned.long().to('cuda') m_pb = known_pb_labels_expaned.long().to('cuda') 
def dn_post_process(self, outputs_class, outputs_coord, mask_dict, outputs_mask):
    """
    Split decoder outputs into the denoising (DN) part and the matching part.

    The first ``mask_dict['pad_size']`` queries along dim 2 are the DN
    queries. Their last-layer predictions plus auxiliary outputs are stored
    in ``mask_dict['output_known_lbs_bboxes']``; the remaining queries are
    returned.

    :param outputs_class: class predictions, sliced along dim 2
    :param outputs_coord: box predictions, sliced along dim 2
    :param mask_dict: DN bookkeeping dict; must hold a positive ``'pad_size'``
    :param outputs_mask: mask predictions or None
    :return: ``(outputs_class, outputs_coord, outputs_mask)`` without the DN queries
    """
    dn_count = mask_dict['pad_size']
    assert dn_count > 0

    dn_class, outputs_class = outputs_class[:, :, :dn_count, :], outputs_class[:, :, dn_count:, :]
    dn_coord, outputs_coord = outputs_coord[:, :, :dn_count, :], outputs_coord[:, :, dn_count:, :]

    if outputs_mask is None:
        dn_mask = None
    else:
        dn_mask, outputs_mask = outputs_mask[:, :, :dn_count, :], outputs_mask[:, :, dn_count:, :]

    # Index -1 selects the last decoder layer.
    dn_out = {
        'pred_logits': dn_class[-1],
        'pred_boxes': dn_coord[-1],
        'pred_masks': dn_mask[-1] if dn_mask is not None else None,
    }
    dn_out['aux_outputs'] = self._set_aux_loss(dn_class, dn_mask, dn_coord)
    mask_dict['output_known_lbs_bboxes'] = dn_out

    return outputs_class, outputs_coord, outputs_mask
def pred_box(self, reference, hs, ref0=None):
    """
    Turn per-layer box deltas into box predictions.

    For each decoder layer the matching ``self.bbox_embed`` head predicts a
    delta from the layer's content features. The layer's reference boxes are
    regrouped into chunks of ``self.num_all_tokens`` and only the first
    ``self.num_mask_tokens`` of each chunk are kept. The delta is added to
    the kept references after ``inverse_sigmoid`` and the sum is passed
    through a sigmoid.

    :param reference: reference boxes from each decoder layer; the last entry is not used
    :param hs: content features per decoder layer
    :param ref0: optional prediction placed before the per-layer outputs
    :return: stacked tensor of per-layer box predictions
    """
    coords = [] if ref0 is None else [ref0]
    for ref_sig, box_head, layer_hs in zip(reference[:-1], self.bbox_embed, hs):
        delta_unsig = box_head(layer_hs)
        lead, last = ref_sig.shape[0], ref_sig.shape[-1]
        # (lead, groups, num_all_tokens, last): one group per prompt.
        grouped = ref_sig.view(lead, -1, self.num_all_tokens, last)
        kept = grouped[:, :, :self.num_mask_tokens].reshape(lead, -1, last)
        coords.append((delta_unsig + inverse_sigmoid(kept)).sigmoid())
    return torch.stack(coords)
self.prediction_switch = prediction_switch assert len(x) == self.num_feature_levels do_seg = (task != 'det') # if task is det, not do segmentation training size_list = [] # disable mask, it does not affect performance enable_mask = 0 if masks is not None: for src in x: if src.size(2) % 32 or src.size(3) % 32: enable_mask = 1 if enable_mask == 0: masks = [torch.zeros((src.size(0), src.size(2), src.size(3)), device=src.device, dtype=torch.bool) for src in x] src_flatten = [] mask_flatten = [] spatial_shapes = [] for i in range(self.num_feature_levels): idx=self.num_feature_levels-1-i bs, c , h, w=x[idx].shape size_list.append(x[i].shape[-2:]) spatial_shapes.append(x[idx].shape[-2:]) src_flatten.append(self.input_proj[idx](x[idx]).flatten(2).transpose(1, 2)) mask_flatten.append(masks[i].flatten(1)) src_flatten = torch.cat(src_flatten, 1) # bs, \sum{hxw}, c mask_flatten = torch.cat(mask_flatten, 1) # bs, \sum{hxw} spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device) level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1) predictions_class = [] predictions_class_part = [] predictions_mask = [] predictions_iou_score = [] tgt_mask = None mask_dict = None if self.dn != "no": assert targets is not None if task=='demo': input_query_label, input_query_bbox, tgt_mask, mask_dict = \ self.prepare_for_dn_mo_infer(targets, None, None, x[0].shape[0]) else: input_query_label, input_query_bbox, tgt_mask, mask_dict = \ self.prepare_for_dn_mo(targets, None, None, x[0].shape[0]) tgt=input_query_label refpoint_embed=input_query_bbox if tgt is None: tgt = torch.zeros(bs, self.num_queries, self.hidden_dim).cuda() refpoint_embed = torch.zeros(bs, self.num_queries, 4).cuda() # import pdb;pdb.set_trace() refpoint_embed=refpoint_embed.to(tgt.dtype) hs, references = self.decoder( tgt=tgt.transpose(0, 1), memory=src_flatten.transpose(0, 
1), memory_key_padding_mask=mask_flatten, pos=None, refpoints_unsigmoid=refpoint_embed.transpose(0, 1), level_start_index=level_start_index, spatial_shapes=spatial_shapes, valid_ratios=valid_ratios, tgt_mask=tgt_mask ) new_hs = [] feats=[] for i, output in enumerate(hs): outputs_class, outputs_mask, iou_score, decoder_output_mask,decoder_output = self.idno_forward_prediction_heads(output.transpose(0, 1), mask_features, (self.training or (i == len(hs)-1)) and do_seg) outputs_class_whole, outputs_class_part = outputs_class predictions_class.append(outputs_class_whole) predictions_class_part.append(outputs_class_part) predictions_mask.append(outputs_mask) feats.append(decoder_output) if iou_score is not None: predictions_iou_score.append(iou_score) new_hs.append(decoder_output_mask) if new_hs is not None: hs = new_hs # iteratively box prediction out_boxes = self.pred_box(references, hs) out_boxes[-1] = out_boxes[-1] + 0.0 * (self.label_enc.weight.sum() + self.pb_embedding.weight.sum() + self.mask_tokens.weight.sum() + self.lang_mapper.sum()+iou_score.sum()) if mask_dict is not None: if predictions_mask is None: predictions_class[-1] = predictions_class[-1] for i in range(self.mask_embed.num_layers): predictions_class[-1] = predictions_class[-1] + 0.0 * (self.mask_embed.layers[i].weight[0][0] + self.mask_embed.layers[i].bias[0]) # avoid no mask loss predictions_class[-1] = predictions_class[-1] + 0.0 * mask_features[0][0][0][0] # avoid no mask loss if do_seg: predictions_mask = list(predictions_mask) elif self.training: # this is to insure self.label_enc participate in the model for i in range(self.mask_embed.num_layers): predictions_class[-1] = predictions_class[-1] + 0.0 * ( self.mask_embed.layers[i].weight[0][0] + self.mask_embed.layers[i].bias[ 0]) # avoid no mask loss predictions_class[-1] = predictions_class[-1] + 0.0 * mask_features[0][0][0][0] # avoid no mask loss out = { 'pred_logits': predictions_class[-1], 'obj_features': feats[-1], 'pred_logits_part': 
predictions_class_part[-1], 'pred_masks': None if not do_seg else predictions_mask[-1], 'pred_boxes':out_boxes[-1], 'pred_ious': predictions_iou_score[-1], 'aux_outputs': self._set_aux_loss( predictions_class if self.mask_classification else None, predictions_mask, out_boxes, predictions_iou_score, predictions_class_part ) } return out, mask_dict def forward_o365(self, x, mask_features, masks, targets=None, target_queries=None, target_vlp=None, task='seg', extra={}): """ task: seg/det TODO add sam """ # task = 'sam' prediction_switch = extra self.prediction_switch = prediction_switch assert len(x) == self.num_feature_levels do_seg = False # if task is det, not do segmentation training size_list = [] # disable mask, it does not affect performance enable_mask = 0 if masks is not None: for src in x: if src.size(2) % 32 or src.size(3) % 32: enable_mask = 1 if enable_mask == 0: masks = [torch.zeros((src.size(0), src.size(2), src.size(3)), device=src.device, dtype=torch.bool) for src in x] src_flatten = [] mask_flatten = [] spatial_shapes = [] for i in range(self.num_feature_levels): idx=self.num_feature_levels-1-i bs, c , h, w=x[idx].shape size_list.append(x[i].shape[-2:]) spatial_shapes.append(x[idx].shape[-2:]) src_flatten.append(self.input_proj[idx](x[idx]).flatten(2).transpose(1, 2)) mask_flatten.append(masks[i].flatten(1)) src_flatten = torch.cat(src_flatten, 1) # bs, \sum{hxw}, c mask_flatten = torch.cat(mask_flatten, 1) # bs, \sum{hxw} spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device) level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1) predictions_class = [] # predictions_class_part = [] predictions_mask = [] # predictions_iou_score = [] tgt_mask = None mask_dict = None # if self.dn != "no": assert targets is not None input_query_label, input_query_bbox, tgt_mask, mask_dict = \ 
self.prepare_for_dn_o3(targets, None, None, x[0].shape[0]) tgt=input_query_label refpoint_embed=input_query_bbox if tgt is None: tgt = torch.zeros(bs, self.num_queries, self.hidden_dim).cuda() refpoint_embed = torch.zeros(bs, self.num_queries, 4).cuda() hs, references = self.decoder( tgt=tgt.transpose(0, 1), memory=src_flatten.transpose(0, 1), memory_key_padding_mask=mask_flatten, pos=None, refpoints_unsigmoid=refpoint_embed.transpose(0, 1), level_start_index=level_start_index, spatial_shapes=spatial_shapes, valid_ratios=valid_ratios, tgt_mask=tgt_mask ) # new_hs = [] for i, output in enumerate(hs): outputs_class, outputs_mask = self.forward_prediction_heads(output.transpose(0, 1), mask_features, (self.training or (i == len(hs)-1)) and do_seg) outputs_class_whole = outputs_class predictions_class.append(outputs_class_whole) # predictions_class_part.append(outputs_class_part) predictions_mask.append(outputs_mask) # if iou_score is not None: # predictions_iou_score.append(iou_score) # new_hs.append(decoder_output_mask) # if new_hs is not None: # hs = new_hs # iteratively box prediction out_boxes = self.pred_box_old(references, hs) out_boxes[-1] = out_boxes[-1] + 0.0 * (self.label_enc.weight.sum() + self.pb_embedding.weight.sum() + self.mask_tokens.weight.sum() + self.lang_mapper.sum()) if mask_dict is not None: if predictions_mask is None: predictions_class[-1] = predictions_class[-1] for i in range(self.mask_embed.num_layers): predictions_class[-1] = predictions_class[-1] + 0.0 * (self.mask_embed.layers[i].weight[0][0] + self.mask_embed.layers[i].bias[0]) # avoid no mask loss predictions_class[-1] = predictions_class[-1] + 0.0 * mask_features[0][0][0][0] # avoid no mask loss if do_seg: predictions_mask = list(predictions_mask) elif self.training: # this is to insure self.label_enc participate in the model for i in range(self.mask_embed.num_layers): predictions_class[-1] = predictions_class[-1] + 0.0 * ( self.mask_embed.layers[i].weight[0][0] + 
def forward_prediction_heads(self, output, mask_features, pred_mask=True):
    """
    Compute class logits and, optionally, mask logits for one decoder layer.

    ``output`` is normalized with ``self.decoder_norm`` and its first two
    dims are swapped. Class logits are the language-encoder similarity
    (``name='whole'``) of the features projected by ``self.class_embed``.

    :param output: decoder features for one layer
    :param mask_features: per-pixel features, ``(b, c, h, w)`` per the einsum below
    :param pred_mask: when False, no mask is computed and None is returned for it
    :return: ``(outputs_class, outputs_mask)``
    """
    normed = self.decoder_norm(output).transpose(0, 1)
    logits = self.lang_encoder.compute_similarity(normed @ self.class_embed, name='whole')
    if not pred_mask:
        return logits, None
    # Dot product of each query's mask embedding with every pixel feature.
    masks = torch.einsum("bqc,bchw->bqhw", self.mask_embed(normed), mask_features)
    return logits, masks
self.lang_encoder.compute_similarity(class_embed_part, name='part') outputs_class = (outputs_class_whole, outputs_class_part) if self.prediction_switch['seg']: mask_embed = self.mask_embed(decoder_output_mask) if mask_embed.dtype==torch.float16 and mask_features.dtype==torch.float32: mask_embed=mask_embed.to(torch.float32) if mask_embed.dtype==torch.float32 and mask_features.dtype==torch.float16: mask_features=mask_features.to(torch.float32) outputs_mask = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features.to(mask_embed.dtype)) iou_score = self.iou_prediction_head(decoder_output_iou).squeeze(-1).view(decoder_output.shape[0], -1, self.num_mask_tokens) # outputs_mask = outputs_mask + 0.0 * iou_score.sum() # TODO add iou prediction head return outputs_class, outputs_mask, iou_score, decoder_output_mask,decoder_output @torch.jit.unused def _set_aux_loss(self, outputs_class=None, outputs_seg_masks=None, out_boxes=None, predictions_iou_score=None, predictions_class_part=None): # this is a workaround to make torchscript happy, as torchscript # doesn't support dictionary with non-homogeneous values, such # as a dict having both a Tensor and a list. 
# if self.mask_classification: if out_boxes is None: return [ {"pred_logits": a, "pred_masks": b} for a, b in zip(outputs_class[:-1], outputs_seg_masks[:-1]) ] elif outputs_seg_masks is None: return [ {"pred_logits": a, "pred_boxes": c} for a, c in zip(outputs_class[:-1], out_boxes[:-1]) ] elif predictions_iou_score is None: return [ {"pred_logits": a, "pred_masks": b, "pred_boxes":c} for a, b, c in zip(outputs_class[:-1], outputs_seg_masks[:-1], out_boxes[:-1]) ] else: return [ {"pred_logits": a, "pred_masks": b, "pred_boxes":c, "pred_ious":d, "pred_logits_part": e} for a, b, c, d, e in zip(outputs_class[:-1], outputs_seg_masks[:-1],out_boxes[:-1], predictions_iou_score[:-1], predictions_class_part[:-1]) ] @register_decoder def get_maskdino_transformer_decoder(cfg, in_channels, lang_encoder, mask_classification, extra): return MaskDINODecoder(cfg, in_channels, lang_encoder, mask_classification, extra) ================================================ FILE: llava/model/semsam/body/decoder/modules.py ================================================ from typing import Optional import torch from torch import nn, Tensor from torch.nn import functional as F from timm.models.layers import trunc_normal_ from detectron2.layers import Conv2d import fvcore.nn.weight_init as weight_init class SelfAttentionLayer(nn.Module): def __init__(self, d_model, nhead, dropout=0.0, activation="relu", normalize_before=False): super().__init__() self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) self.norm = nn.LayerNorm(d_model) self.dropout = nn.Dropout(dropout) self.activation = _get_activation_fn(activation) self.normalize_before = normalize_before self._reset_parameters() def _reset_parameters(self): for p in self.parameters(): if p.dim() > 1: nn.init.xavier_uniform_(p) def with_pos_embed(self, tensor, pos: Optional[Tensor]): return tensor if pos is None else tensor + pos def forward_post(self, tgt, tgt_mask: Optional[Tensor] = None, tgt_key_padding_mask: 
class CrossAttentionLayer(nn.Module):
    """Multi-head cross-attention followed by dropout, a residual add and LayerNorm.

    ``normalize_before`` selects between the pre-norm variant
    (:meth:`forward_pre`) and the post-norm variant (:meth:`forward_post`).
    Both variants return ``(updated_tgt, attention_weights)`` where the
    weights are whatever ``nn.MultiheadAttention`` returns as its second
    output.
    """

    def __init__(self, d_model, nhead, dropout=0.0,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        # The activation is resolved and stored, but no method of this class
        # applies it.
        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-initialise every parameter that has more than one dimension."""
        for weight in self.parameters():
            if weight.dim() <= 1:
                continue
            nn.init.xavier_uniform_(weight)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Return ``tensor + pos``, or ``tensor`` unchanged when ``pos`` is None."""
        if pos is None:
            return tensor
        return tensor + pos

    def forward_post(self, tgt, memory,
                     memory_mask: Optional[Tensor] = None,
                     memory_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None):
        """Post-norm variant: attend, add the residual, then normalise."""
        queries = self.with_pos_embed(tgt, query_pos)
        keys = self.with_pos_embed(memory, pos)
        attended, avg_attn = self.multihead_attn(
            query=queries,
            key=keys,
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )
        residual = tgt + self.dropout(attended)
        return self.norm(residual), avg_attn

    def forward_pre(self, tgt, memory,
                    memory_mask: Optional[Tensor] = None,
                    memory_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None):
        """Pre-norm variant: normalise the queries, attend, then add the residual."""
        normed = self.norm(tgt)
        queries = self.with_pos_embed(normed, query_pos)
        keys = self.with_pos_embed(memory, pos)
        attended, avg_attn = self.multihead_attn(
            query=queries,
            key=keys,
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )
        return tgt + self.dropout(attended), avg_attn

    def forward(self, tgt, memory,
                memory_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        """Dispatch to the pre- or post-norm variant according to ``normalize_before``."""
        variant = self.forward_pre if self.normalize_before else self.forward_post
        return variant(tgt, memory, memory_mask,
                       memory_key_padding_mask, pos, query_pos)
class MLP(nn.Module):
    """Plain multi-layer perceptron (also called FFN).

    Stacks ``nn.Linear`` layers going ``input_dim -> hidden_dim -> ... ->
    output_dim`` and applies ReLU after every layer whose index is below
    ``num_layers - 1``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        # Feature width at every stage: input, the hidden stages, output.
        widths = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        self.layers = nn.ModuleList(
            nn.Linear(widths[idx], widths[idx + 1])
            for idx in range(len(widths) - 1)
        )

    def forward(self, x):
        last_index = self.num_layers - 1
        for index, linear in enumerate(self.layers):
            x = linear(x)
            if index < last_index:
                x = F.relu(x)
        return x
class TransformerDecoder(nn.Module):
    """Stack of decoder layers with iterative reference-point refinement.

    After every layer, when ``bbox_embed`` has been set, the layer output is
    turned into a delta that is added to the inverse-sigmoid of the current
    reference points; the refined points are detached and fed to the next
    layer. ``bbox_embed`` and ``class_embed`` start as ``None`` here and are
    presumably assigned by the owning model -- confirm against the caller.
    """

    def __init__(self, decoder_layer, num_layers, norm=None,
                 return_intermediate=False,
                 d_model=256, query_dim=4,
                 modulate_hw_attn=True,
                 num_feature_levels=1,
                 deformable_decoder=True,
                 decoder_query_perturber=None,
                 dec_layer_number=None,  # number of queries each layer in decoder
                 rm_dec_query_scale=True,
                 dec_layer_share=False,
                 dec_layer_dropout_prob=None,
                 task_switch=None,
                 ):
        super().__init__()
        if num_layers > 0:
            self.layers = _get_clones(decoder_layer, num_layers, layer_share=dec_layer_share)
        else:
            self.layers = []
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate
        assert return_intermediate, "support return_intermediate only"
        self.query_dim = query_dim
        assert query_dim in [2, 4], "query_dim should be 2/4 but {}".format(query_dim)
        self.num_feature_levels = num_feature_levels

        self.ref_point_head = MLP(query_dim // 2 * d_model, d_model, d_model, 2)
        if not deformable_decoder:
            self.query_pos_sine_scale = MLP(d_model, d_model, d_model, 2)
        else:
            self.query_pos_sine_scale = None

        if not rm_dec_query_scale:
            # A learned query scale is not supported; the assignment that used
            # to follow this raise could never execute and has been removed.
            raise NotImplementedError
        self.query_scale = None
        self.bbox_embed = None
        self.class_embed = None

        self.d_model = d_model
        self.modulate_hw_attn = modulate_hw_attn
        self.deformable_decoder = deformable_decoder

        if not deformable_decoder and modulate_hw_attn:
            self.ref_anchor_head = MLP(d_model, d_model, 2, 2)
        else:
            self.ref_anchor_head = None

        self.decoder_query_perturber = decoder_query_perturber
        self.box_pred_damping = None

        self.dec_layer_number = dec_layer_number
        if dec_layer_number is not None:
            assert isinstance(dec_layer_number, list)
            assert len(dec_layer_number) == num_layers

        self.dec_layer_dropout_prob = dec_layer_dropout_prob
        if dec_layer_dropout_prob is not None:
            assert isinstance(dec_layer_dropout_prob, list)
            assert len(dec_layer_dropout_prob) == num_layers
            for i in dec_layer_dropout_prob:
                assert 0.0 <= i <= 1.0
        self.task_switch = task_switch
        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-initialise matrices, then let MSDeformAttn modules re-init themselves."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        for m in self.modules():
            if isinstance(m, MSDeformAttn):
                m._reset_parameters()

    def _holds_out_grounding_tokens(self, extra):
        """Tell whether trailing grounding language tokens must skip box refinement.

        Missing ``'grounding'`` / ``'task'`` keys are treated as "no" instead
        of raising ``KeyError``.
        """
        switch = self.task_switch
        return bool(
            switch is not None
            and 'grounding' in switch
            and switch['grounding']
            and 'grounding_len' in extra
            and 'task' in extra
            and extra['task'] == 'seg'
        )

    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                refpoints_unsigmoid: Optional[Tensor] = None,  # num_queries, bs, 2
                # for memory
                level_start_index: Optional[Tensor] = None,  # num_levels
                spatial_shapes: Optional[Tensor] = None,  # bs, num_levels, 2
                valid_ratios: Optional[Tensor] = None,
                # misc
                extra: Optional[dict] = None,  # extra information
                ):
        """
        Input:
            - tgt: nq, bs, d_model
            - memory: hw, bs, d_model
            - pos: hw, bs, d_model
            - refpoints_unsigmoid: nq, bs, 2/4
            - valid_ratios/spatial_shapes: bs, nlevel, 2
            - extra: optional dict of side information; ``None`` is treated
              as an empty dict. Keys read here: ``lang_refpoint_embed``,
              ``grounding_tokens``, ``grounding_len`` and ``task``.

        Returns a two-element list: the normalised output of every layer and
        the reference points (initial ones plus one entry per refinement),
        each transposed on dims 0 and 1.
        """
        # A fresh dict per call: the previous ``extra={}`` default was a
        # single object shared by every call, and an explicit ``None``
        # crashed on ``extra.keys()``.
        if extra is None:
            extra = {}

        output = tgt

        intermediate = []
        reference_points = refpoints_unsigmoid.sigmoid()
        ref_points = [reference_points]

        if 'lang_refpoint_embed' in extra and 'grounding_tokens' in extra:
            reference_points = torch.cat(
                (reference_points, extra['lang_refpoint_embed'].transpose(0, 1).sigmoid()), dim=0)
            output = torch.cat((output, extra['grounding_tokens']), dim=0)

        # Neither ``self.task_switch`` nor ``extra`` is modified in this
        # method, so the decision is taken once instead of twice per layer.
        hold_out_grounding = self._holds_out_grounding_tokens(extra)

        for layer_id, layer in enumerate(self.layers):
            # preprocess ref points
            if self.training and self.decoder_query_perturber is not None and layer_id != 0:
                reference_points = self.decoder_query_perturber(reference_points)

            reference_points_input = reference_points[:, :, None] \
                                     * torch.cat([valid_ratios, valid_ratios], -1)[None, :].to(reference_points.dtype)  # nq, bs, nlevel, 4
            query_sine_embed = gen_sineembed_for_position(reference_points_input[:, :, 0, :], dim=output.shape[-1] // 2)  # nq, bs, 256*2

            raw_query_pos = self.ref_point_head(query_sine_embed)  # nq, bs, 256
            pos_scale = self.query_scale(output) if self.query_scale is not None else 1
            query_pos = pos_scale * raw_query_pos

            output = layer(
                output,
                query_pos,
                query_sine_embed,
                tgt_key_padding_mask,
                reference_points_input,
                memory,
                memory_key_padding_mask,
                level_start_index,
                spatial_shapes,
                pos,
                tgt_mask,
                memory_mask,
                self.task_switch,
                extra,
            )

            # Grounding language tokens sit at the tail; their reference
            # points are not refined, so set them aside and re-attach later.
            if hold_out_grounding:
                grounding_len = extra['grounding_len']
                _reference_points = reference_points[-grounding_len:]
                reference_points = reference_points[:-grounding_len]
                _output = output[-grounding_len:]
                output = output[:-grounding_len]

            # iter update
            if self.bbox_embed is not None:
                reference_before_sigmoid = inverse_sigmoid(reference_points)
                output = output.to(query_sine_embed.dtype)
                delta_unsig = self.bbox_embed[layer_id](output)
                outputs_unsig = delta_unsig + reference_before_sigmoid
                new_reference_points = outputs_unsig.sigmoid()

                # Detach so the next layer's input does not backpropagate
                # into this layer's box prediction.
                reference_points = new_reference_points.detach()
                ref_points.append(new_reference_points)

            intermediate.append(self.norm(output))

            # add back grounding language token
            if hold_out_grounding:
                reference_points = torch.cat((reference_points, _reference_points))
                output = torch.cat((output, _output))

        return [
            [itm_out.transpose(0, 1) for itm_out in intermediate],
            [itm_refpoint.transpose(0, 1) for itm_refpoint in ref_points]
        ]
Sine(pos) tgt_key_padding_mask: Optional[Tensor] = None, tgt_reference_points: Optional[Tensor] = None, # nq, bs, 4 # for memory memory: Optional[Tensor] = None, # hw, bs, d_model memory_key_padding_mask: Optional[Tensor] = None, memory_level_start_index: Optional[Tensor] = None, # num_levels memory_spatial_shapes: Optional[Tensor] = None, # bs, num_levels, 2 memory_pos: Optional[Tensor] = None, # pos for memory # sa self_attn_mask: Optional[Tensor] = None, # mask used for self-attention cross_attn_mask: Optional[Tensor] = None, # mask used for cross-attention # misc task_switch: Optional[Tensor] = {}, # extra information extra: Optional[Tensor] = {}, # extra information ): """ Input: - tgt/tgt_query_pos: nq, bs, d_model - """ # self attention # import pdb;pdb.set_trace() if self.self_attn is not None: q = k = self.with_pos_embed(tgt, tgt_query_pos) tgt2 = self.self_attn(q, k, tgt, attn_mask=self_attn_mask)[0] tgt = tgt + self.dropout2(tgt2) tgt = self.norm2(tgt) # exclude grounding token for cross attention if (task_switch is not None) and (extra is not None) and (task_switch['grounding']) and ('grounding_len' in extra) and extra['task']=='seg': _grounding_lang_tokens = tgt[-extra['grounding_len']:,] _grounding_lang_pos = tgt_query_pos[-extra['grounding_len']:,] _grounding_ref_points = tgt_reference_points[-extra['grounding_len']:,] tgt = tgt[:-extra['grounding_len'],] tgt_query_pos = tgt_query_pos[:-extra['grounding_len'],] tgt_reference_points = tgt_reference_points[:-extra['grounding_len'],] # cross attention if self.key_aware_type is not None: if self.key_aware_type == 'mean': tgt = tgt + memory.mean(0, keepdim=True) elif self.key_aware_type == 'proj_mean': tgt = tgt + self.key_aware_proj(memory).mean(0, keepdim=True) else: raise NotImplementedError("Unknown key_aware_type: {}".format(self.key_aware_type)) # import pdb;pdb.set_trace() tgt2 = self.cross_attn(self.with_pos_embed(tgt, tgt_query_pos).transpose(0, 1), tgt_reference_points.transpose(0, 
def gen_encoder_output_proposals(memory: Tensor, memory_padding_mask: Tensor, spatial_shapes: Tensor):
    """Build one box proposal per encoder position and mask invalid ones.

    Input:
        - memory: bs, sum(h*w), d_model
        - memory_padding_mask: bs, sum(h*w)  (True marks padding)
        - spatial_shapes: nlevel, 2  (height, width per level)
    Output:
        - output_memory: bs, sum(h*w), d_model -- ``memory`` with padded or
          invalid positions zeroed
        - output_proposals: bs, sum(h*w), 4 -- proposals in inverse-sigmoid
          (logit) space; padded or invalid positions are filled with ``inf``

    A proposal is the cell centre divided by the number of valid columns and
    rows, plus a width/height of ``0.05 * 2**level``. It is valid when all
    four values lie strictly between 0.01 and 0.99.
    """
    batch_size = memory.shape[0]
    proposals = []
    offset = 0
    for lvl, (H_, W_) in enumerate(spatial_shapes):
        # Plain ints keep the slicing, ``view`` and ``linspace`` arguments
        # unambiguous whether ``spatial_shapes`` is a tensor or a list.
        H_, W_ = int(H_), int(W_)
        level_mask = memory_padding_mask[:, offset:(offset + H_ * W_)].view(batch_size, H_, W_, 1)
        # Valid extent, counted along the first column and the first row.
        valid_H = torch.sum(~level_mask[:, :, 0, 0], 1)
        valid_W = torch.sum(~level_mask[:, 0, :, 0], 1)

        # ``indexing="ij"`` is the behaviour the implicit default provided;
        # spelling it out avoids the warning newer torch versions emit.
        grid_y, grid_x = torch.meshgrid(
            torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device),
            torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device),
            indexing="ij",
        )
        grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)

        scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(batch_size, 1, 1, 2)
        grid = (grid.unsqueeze(0).expand(batch_size, -1, -1, -1) + 0.5) / scale
        wh = torch.ones_like(grid) * 0.05 * (2.0 ** lvl)
        proposals.append(torch.cat((grid, wh), -1).view(batch_size, -1, 4))
        offset += H_ * W_

    output_proposals = torch.cat(proposals, 1)
    output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True)
    # Move to logit space.
    output_proposals = torch.log(output_proposals / (1 - output_proposals))
    padding = memory_padding_mask.unsqueeze(-1)
    output_proposals = output_proposals.masked_fill(padding, float('inf'))
    output_proposals = output_proposals.masked_fill(~output_proposals_valid, float('inf'))

    output_memory = memory.masked_fill(padding, float(0))
    output_memory = output_memory.masked_fill(~output_proposals_valid, float(0))
    return output_memory, output_proposals
def _get_activation_fn(activation):
    """Return an activation callable given its name.

    Supported names are ``relu``, ``gelu``, ``glu``, ``prelu`` and ``selu``.
    ``prelu`` returns a newly constructed ``nn.PReLU`` module on every call;
    the others return the corresponding ``torch.nn.functional`` function.

    Raises:
        RuntimeError: if ``activation`` is not one of the supported names.
    """
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    if activation == "prelu":
        return nn.PReLU()
    if activation == "selu":
        return F.selu
    # The message used to list only relu/gelu although five names are accepted.
    raise RuntimeError(
        f"activation should be relu/gelu/glu/prelu/selu, not {activation}."
    )
# MSDeformAttn Transformer encoder in deformable detr
class MSDeformAttnTransformerEncoderOnly(nn.Module):
    """Flatten multi-level feature maps and run them through the deformable encoder.

    Owns a learned per-level embedding (``level_embed``) that is added to the
    positional embedding of every level before encoding.
    """

    def __init__(self, d_model=256, nhead=8,
                 num_encoder_layers=6, dim_feedforward=1024, dropout=0.1,
                 activation="relu",
                 num_feature_levels=4, enc_n_points=4,):
        super().__init__()

        self.d_model = d_model
        self.nhead = nhead

        layer_template = MSDeformAttnTransformerEncoderLayer(
            d_model, dim_feedforward, dropout, activation,
            num_feature_levels, nhead, enc_n_points,
        )
        self.encoder = MSDeformAttnTransformerEncoder(layer_template, num_encoder_layers)

        self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))

        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-init matrices, re-init MSDeformAttn modules, then draw ``level_embed`` from a normal."""
        for param in self.parameters():
            if param.dim() > 1:
                nn.init.xavier_uniform_(param)
        for module in self.modules():
            if isinstance(module, MSDeformAttn):
                module._reset_parameters()
        nn.init.normal_(self.level_embed)

    def get_valid_ratio(self, mask):
        """Return the (width, height) fraction of non-masked entries.

        Counts are taken along the first row and the first column of each
        ``(N, H, W)`` mask, where ``True`` marks a masked position.
        """
        _, height, width = mask.shape
        unmasked = ~mask
        ratio_h = unmasked[:, :, 0].sum(1).float() / height
        ratio_w = unmasked[:, 0, :].sum(1).float() / width
        return torch.stack([ratio_w, ratio_h], -1)

    def forward(self, srcs, masks, pos_embeds, use_ckpt=False):
        """Encode the feature levels.

        Returns ``(memory, spatial_shapes, level_start_index)``.
        """
        # The supplied masks are kept only when at least one level has a
        # height or width that is not a multiple of 32; otherwise (or when no
        # masks were given) all-False masks are used.
        keep_masks = masks is not None and any(
            feat.size(2) % 32 or feat.size(3) % 32 for feat in srcs
        )
        if not keep_masks:
            masks = [
                torch.zeros((feat.size(0), feat.size(2), feat.size(3)),
                            device=feat.device, dtype=torch.bool)
                for feat in srcs
            ]

        # prepare input for encoder
        level_tokens = []
        level_masks = []
        level_pos = []
        level_shapes = []
        for level, (feat, feat_mask, feat_pos) in enumerate(zip(srcs, masks, pos_embeds)):
            _, _, height, width = feat.shape
            level_shapes.append((height, width))
            level_tokens.append(feat.flatten(2).transpose(1, 2))
            level_masks.append(feat_mask.flatten(1))
            flat_pos = feat_pos.flatten(2).transpose(1, 2)
            level_pos.append(flat_pos + self.level_embed[level].view(1, 1, -1))

        src_flatten = torch.cat(level_tokens, 1)
        mask_flatten = torch.cat(level_masks, 1)
        lvl_pos_embed_flatten = torch.cat(level_pos, 1)
        spatial_shapes = torch.as_tensor(level_shapes, dtype=torch.long, device=src_flatten.device)
        # Offset of each level inside the flattened token sequence.
        tokens_before = spatial_shapes.prod(1).cumsum(0)[:-1]
        level_start_index = torch.cat((spatial_shapes.new_zeros((1, )), tokens_before))
        valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)

        # encoder
        memory = self.encoder(src_flatten, spatial_shapes, level_start_index,
                              valid_ratios, lvl_pos_embed_flatten, mask_flatten,
                              use_ckpt=use_ckpt)

        return memory, spatial_shapes, level_start_index
class MSDeformAttnTransformerEncoder(nn.Module):
    """Runs ``num_layers`` clones of an encoder layer over the flattened features."""

    def __init__(self, encoder_layer, num_layers):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers

    @staticmethod
    def get_reference_points(spatial_shapes, valid_ratios, device):
        """Return the cell-centre reference points of every level.

        Args:
            spatial_shapes: iterable of ``(H, W)`` pairs, one per level
                (tensor rows or plain tuples).
            valid_ratios: tensor indexed as ``[batch, level, 2]`` with the
                (width, height) valid fraction of each level.
            device: device on which the grids are created.

        Returns:
            Tensor of shape ``(batch, sum(H*W), num_levels, 2)``: the centres
            of each level, normalised by that level's valid extent and then
            multiplied by the valid ratios of every level.
        """
        reference_points_list = []
        for lvl, (H_, W_) in enumerate(spatial_shapes):
            # Plain ints keep the ``linspace`` step count unambiguous when
            # ``spatial_shapes`` is a tensor.
            H_, W_ = int(H_), int(W_)
            # ``indexing="ij"`` matches what the implicit default provided
            # and avoids the warning newer torch versions emit.
            ref_y, ref_x = torch.meshgrid(
                torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
                torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device),
                indexing="ij",
            )
            ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
            ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
            reference_points_list.append(torch.stack((ref_x, ref_y), -1))
        reference_points = torch.cat(reference_points_list, 1)
        reference_points = reference_points[:, :, None] * valid_ratios[:, None]
        return reference_points

    def forward(self, src, spatial_shapes, level_start_index, valid_ratios, pos=None, padding_mask=None, use_ckpt=False):
        """Apply every encoder layer in turn and return the final output.

        NOTE(review): ``use_ckpt`` is accepted for interface compatibility
        but has no effect. The previous code reset it to ``False`` on every
        iteration, so the ``torch.utils.checkpoint`` branch could never run.
        That behaviour is kept; confirm whether the override was intentional
        before re-enabling activation checkpointing here.
        """
        output = src
        reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=src.device)
        for layer in self.layers:
            output = layer(output, pos, reference_points, spatial_shapes, level_start_index, padding_mask)

        return output
""" super().__init__() self.use_ckpt = use_ckpt transformer_input_shape = { k: v for k, v in input_shape.items() if k in transformer_in_features } # this is the input shape of pixel decoder input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride) self.in_features = [k for k, v in input_shape] # starting from "res2" to "res5" self.feature_strides = [v.stride for k, v in input_shape] self.feature_channels = [v.channels for k, v in input_shape] self.feature_order = feature_order if feature_order == "low2high": transformer_input_shape = sorted(transformer_input_shape.items(), key=lambda x: -x[1].stride) else: transformer_input_shape = sorted(transformer_input_shape.items(), key=lambda x: x[1].stride) self.transformer_in_features = [k for k, v in transformer_input_shape] # starting from "res2" to "res5" transformer_in_channels = [v.channels for k, v in transformer_input_shape] self.transformer_feature_strides = [v.stride for k, v in transformer_input_shape] # to decide extra FPN layers self.maskdino_num_feature_levels = num_feature_levels # always use 3 scales self.total_num_feature_levels = total_num_feature_levels self.common_stride = common_stride self.transformer_num_feature_levels = len(self.transformer_in_features) self.low_resolution_index = transformer_in_channels.index(max(transformer_in_channels)) self.high_resolution_index = 0 if self.feature_order == 'low2high' else -1 if self.transformer_num_feature_levels > 1: input_proj_list = [] for in_channels in transformer_in_channels[::-1]: input_proj_list.append(nn.Sequential( nn.Conv2d(in_channels, conv_dim, kernel_size=1), nn.GroupNorm(32, conv_dim), )) # input projectino for downsample in_channels = max(transformer_in_channels) for _ in range(self.total_num_feature_levels - self.transformer_num_feature_levels): # exclude the res2 input_proj_list.append(nn.Sequential( nn.Conv2d(in_channels, conv_dim, kernel_size=3, stride=2, padding=1), nn.GroupNorm(32, conv_dim), )) in_channels = conv_dim 
self.input_proj = nn.ModuleList(input_proj_list) else: self.input_proj = nn.ModuleList([ nn.Sequential( nn.Conv2d(transformer_in_channels[-1], conv_dim, kernel_size=1), nn.GroupNorm(32, conv_dim), )]) for proj in self.input_proj: nn.init.xavier_uniform_(proj[0].weight, gain=1) nn.init.constant_(proj[0].bias, 0) self.transformer = MSDeformAttnTransformerEncoderOnly( d_model=conv_dim, dropout=transformer_dropout, nhead=transformer_nheads, dim_feedforward=transformer_dim_feedforward, num_encoder_layers=transformer_enc_layers, num_feature_levels=self.total_num_feature_levels, ) N_steps = conv_dim // 2 self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True) self.mask_dim = mask_dim # use 1x1 conv instead self.mask_features = Conv2d( conv_dim, mask_dim, kernel_size=1, stride=1, padding=0, ) weight_init.c2_xavier_fill(self.mask_features) # extra fpn levels stride = min(self.transformer_feature_strides) self.num_fpn_levels = max(int(np.log2(stride) - np.log2(self.common_stride)), 1) lateral_convs = [] output_convs = [] use_bias = norm == "" for idx, in_channels in enumerate(self.feature_channels[:self.num_fpn_levels]): lateral_norm = get_norm(norm, conv_dim) output_norm = get_norm(norm, conv_dim) lateral_conv = Conv2d( in_channels, conv_dim, kernel_size=1, bias=use_bias, norm=lateral_norm ) output_conv = Conv2d( conv_dim, conv_dim, kernel_size=3, stride=1, padding=1, bias=use_bias, norm=output_norm, activation=F.relu, ) weight_init.c2_xavier_fill(lateral_conv) weight_init.c2_xavier_fill(output_conv) self.add_module("adapter_{}".format(idx + 1), lateral_conv) self.add_module("layer_{}".format(idx + 1), output_conv) lateral_convs.append(lateral_conv) output_convs.append(output_conv) # Place convs into top-down order (from low to high resolution) # to make the top-down computation in forward clearer. 
self.lateral_convs = lateral_convs[::-1] self.output_convs = output_convs[::-1] @classmethod def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec], *args, **kwargs): enc_cfg = cfg['MODEL']['ENCODER'] dec_cfg = cfg['MODEL']['DECODER'] ret = {} ret["input_shape"] = { k: v for k, v in input_shape.items() if k in enc_cfg['IN_FEATURES'] } ret["conv_dim"] = enc_cfg['CONVS_DIM'] ret["mask_dim"] = enc_cfg['MASK_DIM'] ret["norm"] = enc_cfg['NORM'] ret["transformer_dropout"] = dec_cfg['DROPOUT'] ret["transformer_nheads"] = dec_cfg['NHEADS'] ret["transformer_dim_feedforward"] = dec_cfg['DIM_FEEDFORWARD'] # deformable transformer encoder ret[ "transformer_enc_layers" ] = enc_cfg['TRANSFORMER_ENC_LAYERS'] # a separate config ret["transformer_in_features"] = enc_cfg['DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES'] # ['res3', 'res4', 'res5'] ret["common_stride"] = enc_cfg['COMMON_STRIDE'] ret["total_num_feature_levels"] = enc_cfg['TOTAL_NUM_FEATURE_LEVELS'] ret["num_feature_levels"] = enc_cfg['NUM_FEATURE_LEVELS'] ret["feature_order"] = enc_cfg['FEATURE_ORDER'] ret["use_ckpt"] = enc_cfg.get('USE_CKPT', False) return ret @autocast(enabled=True) def forward_features(self, features, masks): """ :param features: multi-scale features from the backbone :param masks: image mask :return: enhanced multi-scale features and mask feature (1/4 resolution) for the decoder to produce binary mask """ # backbone features srcs = [] pos = [] # additional downsampled features srcsl = [] posl = [] # import pdb; pdb.set_trace() if self.total_num_feature_levels > self.transformer_num_feature_levels: smallest_feat = features[self.transformer_in_features[self.low_resolution_index]]#.float() _len_srcs = self.transformer_num_feature_levels for l in range(_len_srcs, self.total_num_feature_levels): if l == _len_srcs: src = self.input_proj[l](smallest_feat) else: src = self.input_proj[l](srcsl[-1]) srcsl.append(src) posl.append(self.pe_layer(src)) srcsl = srcsl[::-1] # Reverse feature maps for idx, f in 
enumerate(self.transformer_in_features[::-1]): x = features[f]#.float() # deformable detr does not support half precision srcs.append(self.input_proj[idx](x)) pos.append(self.pe_layer(x)) srcs.extend(srcsl) if self.feature_order == 'low2high' else srcsl.extend(srcs) pos.extend(posl) if self.feature_order == 'low2high' else posl.extend(pos) if self.feature_order != 'low2high': srcs = srcsl pos = posl y, spatial_shapes, level_start_index = self.transformer(srcs, masks, pos, use_ckpt=self.use_ckpt) bs = y.shape[0] split_size_or_sections = [None] * self.total_num_feature_levels for i in range(self.total_num_feature_levels): if i < self.total_num_feature_levels - 1: split_size_or_sections[i] = level_start_index[i + 1] - level_start_index[i] else: split_size_or_sections[i] = y.shape[1] - level_start_index[i] y = torch.split(y, split_size_or_sections, dim=1) out = [] multi_scale_features = [] num_cur_levels = 0 for i, z in enumerate(y): out.append(z.transpose(1, 2).view(bs, -1, spatial_shapes[i][0], spatial_shapes[i][1])) # append `out` with extra FPN levels # Reverse feature maps into top-down order (from low to high resolution) for idx, f in enumerate(self.in_features[:self.num_fpn_levels][::-1]): x = features[f]#.float() lateral_conv = self.lateral_convs[idx] output_conv = self.output_convs[idx] cur_fpn = lateral_conv(x) # Following FPN implementation, we use nearest upsampling here y = cur_fpn + F.interpolate(out[self.high_resolution_index], size=cur_fpn.shape[-2:], mode="bilinear", align_corners=False) y = output_conv(y) out.append(y) for o in out: if num_cur_levels < self.total_num_feature_levels: multi_scale_features.append(o) num_cur_levels += 1 return self.mask_features(out[-1]), out[0], multi_scale_features @register_encoder def get_maskdino_encoder_deform(cfg, input_shape): """ Build a pixel decoder from `cfg.MODEL.MASK_FORMER.PIXEL_DECODER_NAME`. 
""" model = MaskDINOEncoder(cfg, input_shape) forward_features = getattr(model, "forward_features", None) if not callable(forward_features): raise ValueError( "Only SEM_SEG_HEADS with forward_features method can be used as pixel decoder. " f"Please implement forward_features for {name} to only return mask features." ) return model ================================================ FILE: llava/model/semsam/body/encoder/ops/functions/__init__.py ================================================ # ------------------------------------------------------------------------------------------------ # Deformable DETR # Copyright (c) 2020 SenseTime. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------------------------------ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 # ------------------------------------------------------------------------------------------------ # Copyright (c) Facebook, Inc. and its affiliates. # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR from .ms_deform_attn_func import MSDeformAttnFunction ================================================ FILE: llava/model/semsam/body/encoder/ops/functions/ms_deform_attn_func.py ================================================ # ------------------------------------------------------------------------------------------------ # Deformable DETR # Copyright (c) 2020 SenseTime. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------------------------------ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 # ------------------------------------------------------------------------------------------------ # Copyright (c) Facebook, Inc. and its affiliates. 
# Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR from __future__ import absolute_import from __future__ import print_function from __future__ import division import torch import torch.nn.functional as F from torch.autograd import Function from torch.autograd.function import once_differentiable try: import MultiScaleDeformableAttention as MSDA except ModuleNotFoundError as e: info_string = ( "\n\nPlease compile MultiScaleDeformableAttention CUDA op with the following commands:\n" "\t`cd mask2former/modeling/pixel_decoder/ops`\n" "\t`sh make.sh`\n" ) raise ModuleNotFoundError(info_string) class MSDeformAttnFunction(Function): @staticmethod def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step): ctx.im2col_step = im2col_step output = MSDA.ms_deform_attn_forward( value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step) ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights) return output @staticmethod @once_differentiable def backward(ctx, grad_output): value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors grad_value, grad_sampling_loc, grad_attn_weight = \ MSDA.ms_deform_attn_backward( value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step) return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights): # for debug and test only, # need to use cuda version instead N_, S_, M_, D_ = value.shape _, Lq_, M_, L_, P_, _ = sampling_locations.shape value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) sampling_grids = 2 * sampling_locations - 1 sampling_value_list = [] for lid_, (H_, W_) in 
enumerate(value_spatial_shapes): # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_ value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_) # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2 sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1) # N_*M_, D_, Lq_, P_ sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_, mode='bilinear', padding_mode='zeros', align_corners=False) sampling_value_list.append(sampling_value_l_) # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_) attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_) output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_) return output.transpose(1, 2).contiguous() ================================================ FILE: llava/model/semsam/body/encoder/ops/make.sh ================================================ #!/usr/bin/env bash # ------------------------------------------------------------------------------------------------ # Deformable DETR # Copyright (c) 2020 SenseTime. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------------------------------ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 # ------------------------------------------------------------------------------------------------ # Copyright (c) Facebook, Inc. and its affiliates. 
# Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR python setup.py build install --user ================================================ FILE: llava/model/semsam/body/encoder/ops/modules/__init__.py ================================================ # ------------------------------------------------------------------------------------------------ # Deformable DETR # Copyright (c) 2020 SenseTime. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------------------------------ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 # ------------------------------------------------------------------------------------------------ # Copyright (c) Facebook, Inc. and its affiliates. # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR from .ms_deform_attn import MSDeformAttn ================================================ FILE: llava/model/semsam/body/encoder/ops/modules/ms_deform_attn.py ================================================ # ------------------------------------------------------------------------------------------------ # Deformable DETR # Copyright (c) 2020 SenseTime. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------------------------------ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 # ------------------------------------------------------------------------------------------------ # Copyright (c) Facebook, Inc. and its affiliates. 
# Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR from __future__ import absolute_import from __future__ import print_function from __future__ import division import warnings import math import torch from torch import nn import torch.nn.functional as F from torch.nn.init import xavier_uniform_, constant_ from ..functions import MSDeformAttnFunction from ..functions.ms_deform_attn_func import ms_deform_attn_core_pytorch def _is_power_of_2(n): if (not isinstance(n, int)) or (n < 0): raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))) return (n & (n-1) == 0) and n != 0 class MSDeformAttn(nn.Module): def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4): """ Multi-Scale Deformable Attention Module :param d_model hidden dimension :param n_levels number of feature levels :param n_heads number of attention heads :param n_points number of sampling points per attention head per feature level """ super().__init__() if d_model % n_heads != 0: raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads)) _d_per_head = d_model // n_heads # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation if not _is_power_of_2(_d_per_head): warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 " "which is more efficient in our CUDA implementation.") self.im2col_step = 128 self.d_model = d_model self.n_levels = n_levels self.n_heads = n_heads self.n_points = n_points self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2) self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points) self.value_proj = nn.Linear(d_model, d_model) self.output_proj = nn.Linear(d_model, d_model) self._reset_parameters() def _reset_parameters(self): constant_(self.sampling_offsets.weight.data, 0.) 
thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1) for i in range(self.n_points): grid_init[:, :, i, :] *= i + 1 with torch.no_grad(): self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) constant_(self.attention_weights.weight.data, 0.) constant_(self.attention_weights.bias.data, 0.) xavier_uniform_(self.value_proj.weight.data) constant_(self.value_proj.bias.data, 0.) xavier_uniform_(self.output_proj.weight.data) constant_(self.output_proj.bias.data, 0.) def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None): """ :param query (N, Length_{query}, C) :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C) :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}] :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements :return output (N, Length_{query}, C) """ N, Len_q, _ = query.shape N, Len_in, _ = input_flatten.shape assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in # input_flatten=input_flatten.to(self.value_proj.bias.data.dtype) value = self.value_proj(input_flatten) if input_padding_mask is not None: value = value.masked_fill(input_padding_mask[..., None], float(0)) value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads) sampling_offsets = 
self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2) attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points) attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points) # N, Len_q, n_heads, n_levels, n_points, 2 if reference_points.shape[-1] == 2: offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1) sampling_locations = reference_points[:, :, None, :, None, :] \ + sampling_offsets / offset_normalizer[None, None, None, :, None, :] elif reference_points.shape[-1] == 4: sampling_locations = reference_points[:, :, None, :, None, :2] \ + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 else: raise ValueError( 'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1])) # try: # print(value.dtype) convert=False # import pdb; pdb.set_trace() dtype=value.dtype if value.dtype== torch.bfloat16 or value.dtype== torch.float16: value = value.float() attention_weights = attention_weights.float() sampling_locations = sampling_locations.float() convert=True # value = value.float() # attention_weights = attention_weights.float() # sampling_locations = sampling_locations.float() # convert=True output = MSDeformAttnFunction.apply( value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step) if convert: output = output.to(dtype) # except: # # CPU # output = ms_deform_attn_core_pytorch(value, input_spatial_shapes, sampling_locations, attention_weights) # # For FLOPs calculation only # output = ms_deform_attn_core_pytorch(value, input_spatial_shapes, sampling_locations, attention_weights) output = self.output_proj(output) return output ================================================ FILE: llava/model/semsam/body/encoder/ops/setup.py 
================================================ # ------------------------------------------------------------------------------------------------ # Deformable DETR # Copyright (c) 2020 SenseTime. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------------------------------ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 # ------------------------------------------------------------------------------------------------ # Copyright (c) Facebook, Inc. and its affiliates. # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR import os import glob import torch from torch.utils.cpp_extension import CUDA_HOME from torch.utils.cpp_extension import CppExtension from torch.utils.cpp_extension import CUDAExtension from setuptools import find_packages from setuptools import setup requirements = ["torch", "torchvision"] def get_extensions(): this_dir = os.path.dirname(os.path.abspath(__file__)) extensions_dir = os.path.join(this_dir, "src") main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu")) sources = main_file + source_cpu extension = CppExtension extra_compile_args = {"cxx": []} define_macros = [] # Force cuda since torch ask for a device, not if cuda is in fact available. if (os.environ.get('FORCE_CUDA') or torch.cuda.is_available()) and CUDA_HOME is not None: extension = CUDAExtension sources += source_cuda define_macros += [("WITH_CUDA", None)] extra_compile_args["nvcc"] = [ "-DCUDA_HAS_FP16=1", "-D__CUDA_NO_HALF_OPERATORS__", "-D__CUDA_NO_HALF_CONVERSIONS__", "-D__CUDA_NO_HALF2_OPERATORS__", ] else: if CUDA_HOME is None: raise NotImplementedError('CUDA_HOME is None. 
Please set environment variable CUDA_HOME.') else: raise NotImplementedError('No CUDA runtime is found. Please set FORCE_CUDA=1 or test it by running torch.cuda.is_available().') sources = [os.path.join(extensions_dir, s) for s in sources] include_dirs = [extensions_dir] ext_modules = [ extension( "MultiScaleDeformableAttention", sources, include_dirs=include_dirs, define_macros=define_macros, extra_compile_args=extra_compile_args, ) ] return ext_modules setup( name="MultiScaleDeformableAttention", version="1.0", author="Weijie Su", url="https://github.com/fundamentalvision/Deformable-DETR", description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention", packages=find_packages(exclude=("configs", "tests",)), ext_modules=get_extensions(), cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, ) ================================================ FILE: llava/model/semsam/body/encoder/ops/src/cpu/ms_deform_attn_cpu.cpp ================================================ /*! ************************************************************************************************** * Deformable DETR * Copyright (c) 2020 SenseTime. All Rights Reserved. * Licensed under the Apache License, Version 2.0 [see LICENSE for details] ************************************************************************************************** * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 ************************************************************************************************** */ /*! * Copyright (c) Facebook, Inc. and its affiliates. 
* Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR */ #include #include #include at::Tensor ms_deform_attn_cpu_forward( const at::Tensor &value, const at::Tensor &spatial_shapes, const at::Tensor &level_start_index, const at::Tensor &sampling_loc, const at::Tensor &attn_weight, const int im2col_step) { AT_ERROR("Not implement on cpu"); } std::vector ms_deform_attn_cpu_backward( const at::Tensor &value, const at::Tensor &spatial_shapes, const at::Tensor &level_start_index, const at::Tensor &sampling_loc, const at::Tensor &attn_weight, const at::Tensor &grad_output, const int im2col_step) { AT_ERROR("Not implement on cpu"); } ================================================ FILE: llava/model/semsam/body/encoder/ops/src/cpu/ms_deform_attn_cpu.h ================================================ /*! ************************************************************************************************** * Deformable DETR * Copyright (c) 2020 SenseTime. All Rights Reserved. * Licensed under the Apache License, Version 2.0 [see LICENSE for details] ************************************************************************************************** * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 ************************************************************************************************** */ /*! * Copyright (c) Facebook, Inc. and its affiliates. 
* Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR */ #pragma once #include at::Tensor ms_deform_attn_cpu_forward( const at::Tensor &value, const at::Tensor &spatial_shapes, const at::Tensor &level_start_index, const at::Tensor &sampling_loc, const at::Tensor &attn_weight, const int im2col_step); std::vector ms_deform_attn_cpu_backward( const at::Tensor &value, const at::Tensor &spatial_shapes, const at::Tensor &level_start_index, const at::Tensor &sampling_loc, const at::Tensor &attn_weight, const at::Tensor &grad_output, const int im2col_step); ================================================ FILE: llava/model/semsam/body/encoder/ops/src/cuda/ms_deform_attn_cuda.cu ================================================ /*! ************************************************************************************************** * Deformable DETR * Copyright (c) 2020 SenseTime. All Rights Reserved. * Licensed under the Apache License, Version 2.0 [see LICENSE for details] ************************************************************************************************** * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 ************************************************************************************************** */ /*! * Copyright (c) Facebook, Inc. and its affiliates. 
* Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR */ #include #include "cuda/ms_deform_im2col_cuda.cuh" #include #include #include #include at::Tensor ms_deform_attn_cuda_forward( const at::Tensor &value, const at::Tensor &spatial_shapes, const at::Tensor &level_start_index, const at::Tensor &sampling_loc, const at::Tensor &attn_weight, const int im2col_step) { AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous"); AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor"); AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); const int batch = value.size(0); const int spatial_size = value.size(1); const int num_heads = value.size(2); const int channels = value.size(3); const int num_levels = spatial_shapes.size(0); const int num_query = sampling_loc.size(1); const int num_point = sampling_loc.size(4); const int im2col_step_ = std::min(batch, im2col_step); AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); auto output = at::zeros({batch, num_query, num_heads, channels}, value.options()); const int batch_n = im2col_step_; auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); auto per_value_size = spatial_size * num_heads * channels; auto per_sample_loc_size = num_query * num_heads * num_levels * 
num_point * 2; auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; for (int n = 0; n < batch/im2col_step_; ++n) { auto columns = output_n.select(0, n); AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] { ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(), value.data() + n * im2col_step_ * per_value_size, spatial_shapes.data(), level_start_index.data(), sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, attn_weight.data() + n * im2col_step_ * per_attn_weight_size, batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, columns.data()); })); } output = output.view({batch, num_query, num_heads*channels}); return output; } std::vector ms_deform_attn_cuda_backward( const at::Tensor &value, const at::Tensor &spatial_shapes, const at::Tensor &level_start_index, const at::Tensor &sampling_loc, const at::Tensor &attn_weight, const at::Tensor &grad_output, const int im2col_step) { AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous"); AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous"); AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor"); AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor"); const int batch = 
value.size(0); const int spatial_size = value.size(1); const int num_heads = value.size(2); const int channels = value.size(3); const int num_levels = spatial_shapes.size(0); const int num_query = sampling_loc.size(1); const int num_point = sampling_loc.size(4); const int im2col_step_ = std::min(batch, im2col_step); AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); auto grad_value = at::zeros_like(value); auto grad_sampling_loc = at::zeros_like(sampling_loc); auto grad_attn_weight = at::zeros_like(attn_weight); const int batch_n = im2col_step_; auto per_value_size = spatial_size * num_heads * channels; auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); for (int n = 0; n < batch/im2col_step_; ++n) { auto grad_output_g = grad_output_n.select(0, n); AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] { ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(), grad_output_g.data(), value.data() + n * im2col_step_ * per_value_size, spatial_shapes.data(), level_start_index.data(), sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, attn_weight.data() + n * im2col_step_ * per_attn_weight_size, batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, grad_value.data() + n * im2col_step_ * per_value_size, grad_sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, grad_attn_weight.data() + n * im2col_step_ * per_attn_weight_size); })); } return { grad_value, grad_sampling_loc, grad_attn_weight }; } ================================================ FILE: llava/model/semsam/body/encoder/ops/src/cuda/ms_deform_attn_cuda.h ================================================ /*! 
************************************************************************************************** * Deformable DETR * Copyright (c) 2020 SenseTime. All Rights Reserved. * Licensed under the Apache License, Version 2.0 [see LICENSE for details] ************************************************************************************************** * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 ************************************************************************************************** */ /*! * Copyright (c) Facebook, Inc. and its affiliates. * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR */ #pragma once #include at::Tensor ms_deform_attn_cuda_forward( const at::Tensor &value, const at::Tensor &spatial_shapes, const at::Tensor &level_start_index, const at::Tensor &sampling_loc, const at::Tensor &attn_weight, const int im2col_step); std::vector ms_deform_attn_cuda_backward( const at::Tensor &value, const at::Tensor &spatial_shapes, const at::Tensor &level_start_index, const at::Tensor &sampling_loc, const at::Tensor &attn_weight, const at::Tensor &grad_output, const int im2col_step); ================================================ FILE: llava/model/semsam/body/encoder/ops/src/cuda/ms_deform_im2col_cuda.cuh ================================================ /*! ************************************************************************** * Deformable DETR * Copyright (c) 2020 SenseTime. All Rights Reserved. * Licensed under the Apache License, Version 2.0 [see LICENSE for details] ************************************************************************** * Modified from DCN (https://github.com/msracver/Deformable-ConvNets) * Copyright (c) 2018 Microsoft ************************************************************************** */ /*! * Copyright (c) Facebook, Inc. and its affiliates. 
* Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR */ #include #include #include #include #include #include #define CUDA_KERNEL_LOOP(i, n) \ for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ i < (n); \ i += blockDim.x * gridDim.x) const int CUDA_NUM_THREADS = 1024; inline int GET_BLOCKS(const int N, const int num_threads) { return (N + num_threads - 1) / num_threads; } template __device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data, const int &height, const int &width, const int &nheads, const int &channels, const scalar_t &h, const scalar_t &w, const int &m, const int &c) { const int h_low = floor(h); const int w_low = floor(w); const int h_high = h_low + 1; const int w_high = w_low + 1; const scalar_t lh = h - h_low; const scalar_t lw = w - w_low; const scalar_t hh = 1 - lh, hw = 1 - lw; const int w_stride = nheads * channels; const int h_stride = width * w_stride; const int h_low_ptr_offset = h_low * h_stride; const int h_high_ptr_offset = h_low_ptr_offset + h_stride; const int w_low_ptr_offset = w_low * w_stride; const int w_high_ptr_offset = w_low_ptr_offset + w_stride; const int base_ptr = m * channels + c; scalar_t v1 = 0; if (h_low >= 0 && w_low >= 0) { const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; v1 = bottom_data[ptr1]; } scalar_t v2 = 0; if (h_low >= 0 && w_high <= width - 1) { const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; v2 = bottom_data[ptr2]; } scalar_t v3 = 0; if (h_high <= height - 1 && w_low >= 0) { const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; v3 = bottom_data[ptr3]; } scalar_t v4 = 0; if (h_high <= height - 1 && w_high <= width - 1) { const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; v4 = bottom_data[ptr4]; } const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); return val; } template __device__ void 
ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data, const int &height, const int &width, const int &nheads, const int &channels, const scalar_t &h, const scalar_t &w, const int &m, const int &c, const scalar_t &top_grad, const scalar_t &attn_weight, scalar_t* &grad_value, scalar_t* grad_sampling_loc, scalar_t* grad_attn_weight) { const int h_low = floor(h); const int w_low = floor(w); const int h_high = h_low + 1; const int w_high = w_low + 1; const scalar_t lh = h - h_low; const scalar_t lw = w - w_low; const scalar_t hh = 1 - lh, hw = 1 - lw; const int w_stride = nheads * channels; const int h_stride = width * w_stride; const int h_low_ptr_offset = h_low * h_stride; const int h_high_ptr_offset = h_low_ptr_offset + h_stride; const int w_low_ptr_offset = w_low * w_stride; const int w_high_ptr_offset = w_low_ptr_offset + w_stride; const int base_ptr = m * channels + c; const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; const scalar_t top_grad_value = top_grad * attn_weight; scalar_t grad_h_weight = 0, grad_w_weight = 0; scalar_t v1 = 0; if (h_low >= 0 && w_low >= 0) { const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; v1 = bottom_data[ptr1]; grad_h_weight -= hw * v1; grad_w_weight -= hh * v1; atomicAdd(grad_value+ptr1, w1*top_grad_value); } scalar_t v2 = 0; if (h_low >= 0 && w_high <= width - 1) { const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; v2 = bottom_data[ptr2]; grad_h_weight -= lw * v2; grad_w_weight += hh * v2; atomicAdd(grad_value+ptr2, w2*top_grad_value); } scalar_t v3 = 0; if (h_high <= height - 1 && w_low >= 0) { const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; v3 = bottom_data[ptr3]; grad_h_weight += hw * v3; grad_w_weight -= lh * v3; atomicAdd(grad_value+ptr3, w3*top_grad_value); } scalar_t v4 = 0; if (h_high <= height - 1 && w_high <= width - 1) { const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; v4 = bottom_data[ptr4]; grad_h_weight += lw * v4; 
grad_w_weight += lh * v4; atomicAdd(grad_value+ptr4, w4*top_grad_value); } const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); *grad_attn_weight = top_grad * val; *grad_sampling_loc = width * grad_w_weight * top_grad_value; *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value; } template __device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data, const int &height, const int &width, const int &nheads, const int &channels, const scalar_t &h, const scalar_t &w, const int &m, const int &c, const scalar_t &top_grad, const scalar_t &attn_weight, scalar_t* &grad_value, scalar_t* grad_sampling_loc, scalar_t* grad_attn_weight) { const int h_low = floor(h); const int w_low = floor(w); const int h_high = h_low + 1; const int w_high = w_low + 1; const scalar_t lh = h - h_low; const scalar_t lw = w - w_low; const scalar_t hh = 1 - lh, hw = 1 - lw; const int w_stride = nheads * channels; const int h_stride = width * w_stride; const int h_low_ptr_offset = h_low * h_stride; const int h_high_ptr_offset = h_low_ptr_offset + h_stride; const int w_low_ptr_offset = w_low * w_stride; const int w_high_ptr_offset = w_low_ptr_offset + w_stride; const int base_ptr = m * channels + c; const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; const scalar_t top_grad_value = top_grad * attn_weight; scalar_t grad_h_weight = 0, grad_w_weight = 0; scalar_t v1 = 0; if (h_low >= 0 && w_low >= 0) { const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; v1 = bottom_data[ptr1]; grad_h_weight -= hw * v1; grad_w_weight -= hh * v1; atomicAdd(grad_value+ptr1, w1*top_grad_value); } scalar_t v2 = 0; if (h_low >= 0 && w_high <= width - 1) { const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; v2 = bottom_data[ptr2]; grad_h_weight -= lw * v2; grad_w_weight += hh * v2; atomicAdd(grad_value+ptr2, w2*top_grad_value); } scalar_t v3 = 0; if (h_high <= height - 1 && w_low >= 0) { const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + 
base_ptr; v3 = bottom_data[ptr3]; grad_h_weight += hw * v3; grad_w_weight -= lh * v3; atomicAdd(grad_value+ptr3, w3*top_grad_value); } scalar_t v4 = 0; if (h_high <= height - 1 && w_high <= width - 1) { const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; v4 = bottom_data[ptr4]; grad_h_weight += lw * v4; grad_w_weight += lh * v4; atomicAdd(grad_value+ptr4, w4*top_grad_value); } const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); atomicAdd(grad_attn_weight, top_grad * val); atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value); atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value); } template __global__ void ms_deformable_im2col_gpu_kernel(const int n, const scalar_t *data_value, const int64_t *data_spatial_shapes, const int64_t *data_level_start_index, const scalar_t *data_sampling_loc, const scalar_t *data_attn_weight, const int batch_size, const int spatial_size, const int num_heads, const int channels, const int num_levels, const int num_query, const int num_point, scalar_t *data_col) { CUDA_KERNEL_LOOP(index, n) { int _temp = index; const int c_col = _temp % channels; _temp /= channels; const int sampling_index = _temp; const int m_col = _temp % num_heads; _temp /= num_heads; const int q_col = _temp % num_query; _temp /= num_query; const int b_col = _temp; scalar_t *data_col_ptr = data_col + index; int data_weight_ptr = sampling_index * num_levels * num_point; int data_loc_w_ptr = data_weight_ptr << 1; const int qid_stride = num_heads * channels; const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; scalar_t col = 0; for (int l_col=0; l_col < num_levels; ++l_col) { const int level_start_id = data_level_start_index[l_col]; const int spatial_h_ptr = l_col << 1; const int spatial_h = data_spatial_shapes[spatial_h_ptr]; const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride); 
for (int p_col=0; p_col < num_point; ++p_col) { const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; const scalar_t weight = data_attn_weight[data_weight_ptr]; const scalar_t h_im = loc_h * spatial_h - 0.5; const scalar_t w_im = loc_w * spatial_w - 0.5; if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) { col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight; } data_weight_ptr += 1; data_loc_w_ptr += 2; } } *data_col_ptr = col; } } template __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n, const scalar_t *grad_col, const scalar_t *data_value, const int64_t *data_spatial_shapes, const int64_t *data_level_start_index, const scalar_t *data_sampling_loc, const scalar_t *data_attn_weight, const int batch_size, const int spatial_size, const int num_heads, const int channels, const int num_levels, const int num_query, const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc, scalar_t *grad_attn_weight) { CUDA_KERNEL_LOOP(index, n) { __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2]; __shared__ scalar_t cache_grad_attn_weight[blockSize]; unsigned int tid = threadIdx.x; int _temp = index; const int c_col = _temp % channels; _temp /= channels; const int sampling_index = _temp; const int m_col = _temp % num_heads; _temp /= num_heads; const int q_col = _temp % num_query; _temp /= num_query; const int b_col = _temp; const scalar_t top_grad = grad_col[index]; int data_weight_ptr = sampling_index * num_levels * num_point; int data_loc_w_ptr = data_weight_ptr << 1; const int grad_sampling_ptr = data_weight_ptr; grad_sampling_loc += grad_sampling_ptr << 1; grad_attn_weight += grad_sampling_ptr; const int grad_weight_stride = 1; const int grad_loc_stride = 2; const int qid_stride = num_heads * channels; const int data_value_ptr_init_offset = b_col * 
spatial_size * qid_stride; for (int l_col=0; l_col < num_levels; ++l_col) { const int level_start_id = data_level_start_index[l_col]; const int spatial_h_ptr = l_col << 1; const int spatial_h = data_spatial_shapes[spatial_h_ptr]; const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; const scalar_t *data_value_ptr = data_value + value_ptr_offset; scalar_t *grad_value_ptr = grad_value + value_ptr_offset; for (int p_col=0; p_col < num_point; ++p_col) { const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; const scalar_t weight = data_attn_weight[data_weight_ptr]; const scalar_t h_im = loc_h * spatial_h - 0.5; const scalar_t w_im = loc_w * spatial_w - 0.5; *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; *(cache_grad_attn_weight+threadIdx.x)=0; if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) { ms_deform_attn_col2im_bilinear( data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, top_grad, weight, grad_value_ptr, cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); } __syncthreads(); if (tid == 0) { scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0]; int sid=2; for (unsigned int tid = 1; tid < blockSize; ++tid) { _grad_w += cache_grad_sampling_loc[sid]; _grad_h += cache_grad_sampling_loc[sid + 1]; _grad_a += cache_grad_attn_weight[tid]; sid += 2; } *grad_sampling_loc = _grad_w; *(grad_sampling_loc + 1) = _grad_h; *grad_attn_weight = _grad_a; } __syncthreads(); data_weight_ptr += 1; data_loc_w_ptr += 2; grad_attn_weight += grad_weight_stride; grad_sampling_loc += grad_loc_stride; } } } } template __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n, const scalar_t *grad_col, 
const scalar_t *data_value, const int64_t *data_spatial_shapes, const int64_t *data_level_start_index, const scalar_t *data_sampling_loc, const scalar_t *data_attn_weight, const int batch_size, const int spatial_size, const int num_heads, const int channels, const int num_levels, const int num_query, const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc, scalar_t *grad_attn_weight) { CUDA_KERNEL_LOOP(index, n) { __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2]; __shared__ scalar_t cache_grad_attn_weight[blockSize]; unsigned int tid = threadIdx.x; int _temp = index; const int c_col = _temp % channels; _temp /= channels; const int sampling_index = _temp; const int m_col = _temp % num_heads; _temp /= num_heads; const int q_col = _temp % num_query; _temp /= num_query; const int b_col = _temp; const scalar_t top_grad = grad_col[index]; int data_weight_ptr = sampling_index * num_levels * num_point; int data_loc_w_ptr = data_weight_ptr << 1; const int grad_sampling_ptr = data_weight_ptr; grad_sampling_loc += grad_sampling_ptr << 1; grad_attn_weight += grad_sampling_ptr; const int grad_weight_stride = 1; const int grad_loc_stride = 2; const int qid_stride = num_heads * channels; const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; for (int l_col=0; l_col < num_levels; ++l_col) { const int level_start_id = data_level_start_index[l_col]; const int spatial_h_ptr = l_col << 1; const int spatial_h = data_spatial_shapes[spatial_h_ptr]; const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; const scalar_t *data_value_ptr = data_value + value_ptr_offset; scalar_t *grad_value_ptr = grad_value + value_ptr_offset; for (int p_col=0; p_col < num_point; ++p_col) { const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; const scalar_t weight = data_attn_weight[data_weight_ptr]; const 
scalar_t h_im = loc_h * spatial_h - 0.5; const scalar_t w_im = loc_w * spatial_w - 0.5; *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; *(cache_grad_attn_weight+threadIdx.x)=0; if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) { ms_deform_attn_col2im_bilinear( data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, top_grad, weight, grad_value_ptr, cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); } __syncthreads(); for (unsigned int s=blockSize/2; s>0; s>>=1) { if (tid < s) { const unsigned int xid1 = tid << 1; const unsigned int xid2 = (tid + s) << 1; cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1]; } __syncthreads(); } if (tid == 0) { *grad_sampling_loc = cache_grad_sampling_loc[0]; *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1]; *grad_attn_weight = cache_grad_attn_weight[0]; } __syncthreads(); data_weight_ptr += 1; data_loc_w_ptr += 2; grad_attn_weight += grad_weight_stride; grad_sampling_loc += grad_loc_stride; } } } } template __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n, const scalar_t *grad_col, const scalar_t *data_value, const int64_t *data_spatial_shapes, const int64_t *data_level_start_index, const scalar_t *data_sampling_loc, const scalar_t *data_attn_weight, const int batch_size, const int spatial_size, const int num_heads, const int channels, const int num_levels, const int num_query, const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc, scalar_t *grad_attn_weight) { CUDA_KERNEL_LOOP(index, n) { extern __shared__ int _s[]; scalar_t* cache_grad_sampling_loc = (scalar_t*)_s; scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; unsigned int tid = threadIdx.x; int _temp = index; const int c_col 
= _temp % channels; _temp /= channels; const int sampling_index = _temp; const int m_col = _temp % num_heads; _temp /= num_heads; const int q_col = _temp % num_query; _temp /= num_query; const int b_col = _temp; const scalar_t top_grad = grad_col[index]; int data_weight_ptr = sampling_index * num_levels * num_point; int data_loc_w_ptr = data_weight_ptr << 1; const int grad_sampling_ptr = data_weight_ptr; grad_sampling_loc += grad_sampling_ptr << 1; grad_attn_weight += grad_sampling_ptr; const int grad_weight_stride = 1; const int grad_loc_stride = 2; const int qid_stride = num_heads * channels; const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; for (int l_col=0; l_col < num_levels; ++l_col) { const int level_start_id = data_level_start_index[l_col]; const int spatial_h_ptr = l_col << 1; const int spatial_h = data_spatial_shapes[spatial_h_ptr]; const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; const scalar_t *data_value_ptr = data_value + value_ptr_offset; scalar_t *grad_value_ptr = grad_value + value_ptr_offset; for (int p_col=0; p_col < num_point; ++p_col) { const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; const scalar_t weight = data_attn_weight[data_weight_ptr]; const scalar_t h_im = loc_h * spatial_h - 0.5; const scalar_t w_im = loc_w * spatial_w - 0.5; *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; *(cache_grad_attn_weight+threadIdx.x)=0; if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) { ms_deform_attn_col2im_bilinear( data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, top_grad, weight, grad_value_ptr, cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); } __syncthreads(); if (tid == 0) { scalar_t 
_grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0]; int sid=2; for (unsigned int tid = 1; tid < blockDim.x; ++tid) { _grad_w += cache_grad_sampling_loc[sid]; _grad_h += cache_grad_sampling_loc[sid + 1]; _grad_a += cache_grad_attn_weight[tid]; sid += 2; } *grad_sampling_loc = _grad_w; *(grad_sampling_loc + 1) = _grad_h; *grad_attn_weight = _grad_a; } __syncthreads(); data_weight_ptr += 1; data_loc_w_ptr += 2; grad_attn_weight += grad_weight_stride; grad_sampling_loc += grad_loc_stride; } } } } template __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n, const scalar_t *grad_col, const scalar_t *data_value, const int64_t *data_spatial_shapes, const int64_t *data_level_start_index, const scalar_t *data_sampling_loc, const scalar_t *data_attn_weight, const int batch_size, const int spatial_size, const int num_heads, const int channels, const int num_levels, const int num_query, const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc, scalar_t *grad_attn_weight) { CUDA_KERNEL_LOOP(index, n) { extern __shared__ int _s[]; scalar_t* cache_grad_sampling_loc = (scalar_t*)_s; scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; unsigned int tid = threadIdx.x; int _temp = index; const int c_col = _temp % channels; _temp /= channels; const int sampling_index = _temp; const int m_col = _temp % num_heads; _temp /= num_heads; const int q_col = _temp % num_query; _temp /= num_query; const int b_col = _temp; const scalar_t top_grad = grad_col[index]; int data_weight_ptr = sampling_index * num_levels * num_point; int data_loc_w_ptr = data_weight_ptr << 1; const int grad_sampling_ptr = data_weight_ptr; grad_sampling_loc += grad_sampling_ptr << 1; grad_attn_weight += grad_sampling_ptr; const int grad_weight_stride = 1; const int grad_loc_stride = 2; const int qid_stride = num_heads * channels; const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; for 
(int l_col=0; l_col < num_levels; ++l_col) { const int level_start_id = data_level_start_index[l_col]; const int spatial_h_ptr = l_col << 1; const int spatial_h = data_spatial_shapes[spatial_h_ptr]; const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; const scalar_t *data_value_ptr = data_value + value_ptr_offset; scalar_t *grad_value_ptr = grad_value + value_ptr_offset; for (int p_col=0; p_col < num_point; ++p_col) { const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; const scalar_t weight = data_attn_weight[data_weight_ptr]; const scalar_t h_im = loc_h * spatial_h - 0.5; const scalar_t w_im = loc_w * spatial_w - 0.5; *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; *(cache_grad_attn_weight+threadIdx.x)=0; if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) { ms_deform_attn_col2im_bilinear( data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, top_grad, weight, grad_value_ptr, cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); } __syncthreads(); for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1) { if (tid < s) { const unsigned int xid1 = tid << 1; const unsigned int xid2 = (tid + s) << 1; cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1]; if (tid + (s << 1) < spre) { cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)]; cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)]; cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)]; } } __syncthreads(); } if (tid == 0) { *grad_sampling_loc = cache_grad_sampling_loc[0]; *(grad_sampling_loc + 1) 
= cache_grad_sampling_loc[1]; *grad_attn_weight = cache_grad_attn_weight[0]; } __syncthreads(); data_weight_ptr += 1; data_loc_w_ptr += 2; grad_attn_weight += grad_weight_stride; grad_sampling_loc += grad_loc_stride; } } } } template __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n, const scalar_t *grad_col, const scalar_t *data_value, const int64_t *data_spatial_shapes, const int64_t *data_level_start_index, const scalar_t *data_sampling_loc, const scalar_t *data_attn_weight, const int batch_size, const int spatial_size, const int num_heads, const int channels, const int num_levels, const int num_query, const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc, scalar_t *grad_attn_weight) { CUDA_KERNEL_LOOP(index, n) { extern __shared__ int _s[]; scalar_t* cache_grad_sampling_loc = (scalar_t*)_s; scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; unsigned int tid = threadIdx.x; int _temp = index; const int c_col = _temp % channels; _temp /= channels; const int sampling_index = _temp; const int m_col = _temp % num_heads; _temp /= num_heads; const int q_col = _temp % num_query; _temp /= num_query; const int b_col = _temp; const scalar_t top_grad = grad_col[index]; int data_weight_ptr = sampling_index * num_levels * num_point; int data_loc_w_ptr = data_weight_ptr << 1; const int grad_sampling_ptr = data_weight_ptr; grad_sampling_loc += grad_sampling_ptr << 1; grad_attn_weight += grad_sampling_ptr; const int grad_weight_stride = 1; const int grad_loc_stride = 2; const int qid_stride = num_heads * channels; const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; for (int l_col=0; l_col < num_levels; ++l_col) { const int level_start_id = data_level_start_index[l_col]; const int spatial_h_ptr = l_col << 1; const int spatial_h = data_spatial_shapes[spatial_h_ptr]; const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; const int value_ptr_offset = 
data_value_ptr_init_offset + level_start_id * qid_stride; const scalar_t *data_value_ptr = data_value + value_ptr_offset; scalar_t *grad_value_ptr = grad_value + value_ptr_offset; for (int p_col=0; p_col < num_point; ++p_col) { const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; const scalar_t weight = data_attn_weight[data_weight_ptr]; const scalar_t h_im = loc_h * spatial_h - 0.5; const scalar_t w_im = loc_w * spatial_w - 0.5; *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; *(cache_grad_attn_weight+threadIdx.x)=0; if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) { ms_deform_attn_col2im_bilinear( data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, top_grad, weight, grad_value_ptr, cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); } __syncthreads(); for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1) { if (tid < s) { const unsigned int xid1 = tid << 1; const unsigned int xid2 = (tid + s) << 1; cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1]; if (tid + (s << 1) < spre) { cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)]; cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)]; cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)]; } } __syncthreads(); } if (tid == 0) { atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]); atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]); atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]); } __syncthreads(); data_weight_ptr += 1; data_loc_w_ptr += 2; grad_attn_weight += grad_weight_stride; grad_sampling_loc += grad_loc_stride; } } } } template __global__ void 
ms_deformable_col2im_gpu_kernel_gm(const int n, const scalar_t *grad_col, const scalar_t *data_value, const int64_t *data_spatial_shapes, const int64_t *data_level_start_index, const scalar_t *data_sampling_loc, const scalar_t *data_attn_weight, const int batch_size, const int spatial_size, const int num_heads, const int channels, const int num_levels, const int num_query, const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc, scalar_t *grad_attn_weight) { CUDA_KERNEL_LOOP(index, n) { int _temp = index; const int c_col = _temp % channels; _temp /= channels; const int sampling_index = _temp; const int m_col = _temp % num_heads; _temp /= num_heads; const int q_col = _temp % num_query; _temp /= num_query; const int b_col = _temp; const scalar_t top_grad = grad_col[index]; int data_weight_ptr = sampling_index * num_levels * num_point; int data_loc_w_ptr = data_weight_ptr << 1; const int grad_sampling_ptr = data_weight_ptr; grad_sampling_loc += grad_sampling_ptr << 1; grad_attn_weight += grad_sampling_ptr; const int grad_weight_stride = 1; const int grad_loc_stride = 2; const int qid_stride = num_heads * channels; const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; for (int l_col=0; l_col < num_levels; ++l_col) { const int level_start_id = data_level_start_index[l_col]; const int spatial_h_ptr = l_col << 1; const int spatial_h = data_spatial_shapes[spatial_h_ptr]; const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; const scalar_t *data_value_ptr = data_value + value_ptr_offset; scalar_t *grad_value_ptr = grad_value + value_ptr_offset; for (int p_col=0; p_col < num_point; ++p_col) { const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; const scalar_t weight = data_attn_weight[data_weight_ptr]; const scalar_t h_im = loc_h * spatial_h - 0.5; const scalar_t w_im = loc_w * 
spatial_w - 0.5; if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) { ms_deform_attn_col2im_bilinear_gm( data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, top_grad, weight, grad_value_ptr, grad_sampling_loc, grad_attn_weight); } data_weight_ptr += 1; data_loc_w_ptr += 2; grad_attn_weight += grad_weight_stride; grad_sampling_loc += grad_loc_stride; } } } } template void ms_deformable_im2col_cuda(cudaStream_t stream, const scalar_t* data_value, const int64_t* data_spatial_shapes, const int64_t* data_level_start_index, const scalar_t* data_sampling_loc, const scalar_t* data_attn_weight, const int batch_size, const int spatial_size, const int num_heads, const int channels, const int num_levels, const int num_query, const int num_point, scalar_t* data_col) { const int num_kernels = batch_size * num_query * num_heads * channels; const int num_actual_kernels = batch_size * num_query * num_heads * channels; const int num_threads = CUDA_NUM_THREADS; ms_deformable_im2col_gpu_kernel <<>>( num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col); cudaError_t err = cudaGetLastError(); if (err != cudaSuccess) { printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err)); } } template void ms_deformable_col2im_cuda(cudaStream_t stream, const scalar_t* grad_col, const scalar_t* data_value, const int64_t * data_spatial_shapes, const int64_t * data_level_start_index, const scalar_t * data_sampling_loc, const scalar_t * data_attn_weight, const int batch_size, const int spatial_size, const int num_heads, const int channels, const int num_levels, const int num_query, const int num_point, scalar_t* grad_value, scalar_t* grad_sampling_loc, scalar_t* grad_attn_weight) { const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels; const int num_kernels = batch_size 
* num_query * num_heads * channels; const int num_actual_kernels = batch_size * num_query * num_heads * channels; if (channels > 1024) { if ((channels & 1023) == 0) { ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks <<>>( num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, grad_value, grad_sampling_loc, grad_attn_weight); } else { ms_deformable_col2im_gpu_kernel_gm <<>>( num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, grad_value, grad_sampling_loc, grad_attn_weight); } } else{ switch(channels) { case 1: ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 <<>>( num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, grad_value, grad_sampling_loc, grad_attn_weight); break; case 2: ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 <<>>( num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, grad_value, grad_sampling_loc, grad_attn_weight); break; case 4: ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 <<>>( num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, grad_value, grad_sampling_loc, grad_attn_weight); break; case 8: ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 <<>>( num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight, 
batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, grad_value, grad_sampling_loc, grad_attn_weight); break; case 16: ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 <<>>( num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, grad_value, grad_sampling_loc, grad_attn_weight); break; case 32: ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 <<>>( num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, grad_value, grad_sampling_loc, grad_attn_weight); break; case 64: ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 <<>>( num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, grad_value, grad_sampling_loc, grad_attn_weight); break; case 128: ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 <<>>( num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, grad_value, grad_sampling_loc, grad_attn_weight); break; case 256: ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 <<>>( num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, grad_value, grad_sampling_loc, grad_attn_weight); break; case 512: ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 <<>>( num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight, 
batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, grad_value, grad_sampling_loc, grad_attn_weight); break; case 1024: ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 <<>>( num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, grad_value, grad_sampling_loc, grad_attn_weight); break; default: if (channels < 64) { ms_deformable_col2im_gpu_kernel_shm_reduce_v1 <<>>( num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, grad_value, grad_sampling_loc, grad_attn_weight); } else { ms_deformable_col2im_gpu_kernel_shm_reduce_v2 <<>>( num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, grad_value, grad_sampling_loc, grad_attn_weight); } } } cudaError_t err = cudaGetLastError(); if (err != cudaSuccess) { printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err)); } } ================================================ FILE: llava/model/semsam/body/encoder/ops/src/ms_deform_attn.h ================================================ /*! ************************************************************************************************** * Deformable DETR * Copyright (c) 2020 SenseTime. All Rights Reserved. * Licensed under the Apache License, Version 2.0 [see LICENSE for details] ************************************************************************************************** * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 ************************************************************************************************** */ /*! 
* Copyright (c) Facebook, Inc. and its affiliates. * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR */ #pragma once #include "cpu/ms_deform_attn_cpu.h" #ifdef WITH_CUDA #include "cuda/ms_deform_attn_cuda.h" #endif at::Tensor ms_deform_attn_forward( const at::Tensor &value, const at::Tensor &spatial_shapes, const at::Tensor &level_start_index, const at::Tensor &sampling_loc, const at::Tensor &attn_weight, const int im2col_step) { if (value.type().is_cuda()) { #ifdef WITH_CUDA return ms_deform_attn_cuda_forward( value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step); #else AT_ERROR("Not compiled with GPU support"); #endif } AT_ERROR("Not implemented on the CPU"); } std::vector ms_deform_attn_backward( const at::Tensor &value, const at::Tensor &spatial_shapes, const at::Tensor &level_start_index, const at::Tensor &sampling_loc, const at::Tensor &attn_weight, const at::Tensor &grad_output, const int im2col_step) { if (value.type().is_cuda()) { #ifdef WITH_CUDA return ms_deform_attn_cuda_backward( value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step); #else AT_ERROR("Not compiled with GPU support"); #endif } AT_ERROR("Not implemented on the CPU"); } ================================================ FILE: llava/model/semsam/body/encoder/ops/src/vision.cpp ================================================ /*! ************************************************************************************************** * Deformable DETR * Copyright (c) 2020 SenseTime. All Rights Reserved. * Licensed under the Apache License, Version 2.0 [see LICENSE for details] ************************************************************************************************** * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 ************************************************************************************************** */ /*! 
// Python bindings for the extension module named by TORCH_EXTENSION_NAME.
// Each dispatcher from ms_deform_attn.h is exposed under its own C++ name,
// and the same string is reused as the binding's docstring.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward");
  m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward");
}
@torch.no_grad()
def check_forward_equal_with_pytorch_double():
    """Compare the CUDA op against the pure-PyTorch reference in float64.

    Random inputs are drawn using the module-level sizes (N, S, M, D, Lq, L, P),
    both implementations are run on double-precision copies, and a one-line
    report with the `torch.allclose` verdict plus max absolute / relative
    error is printed.
    """
    # Draw order (value, locations, weights) is kept so seeded runs match.
    value = torch.rand(N, S, M, D).cuda() * 0.01
    sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
    raw_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
    # Normalise over the last two axes so they sum to one jointly.
    weight_total = raw_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
    attention_weights = raw_weights / weight_total
    im2col_step = 2

    reference = ms_deform_attn_core_pytorch(
        value.double(),
        shapes,
        sampling_locations.double(),
        attention_weights.double(),
    )
    output_pytorch = reference.detach().cpu()

    cuda_result = MSDeformAttnFunction.apply(
        value.double(),
        shapes,
        level_start_index,
        sampling_locations.double(),
        attention_weights.double(),
        im2col_step,
    )
    output_cuda = cuda_result.detach().cpu()

    fwdok = torch.allclose(output_cuda, output_pytorch)
    abs_diff = (output_cuda - output_pytorch).abs()
    max_abs_err = abs_diff.max()
    max_rel_err = (abs_diff / output_pytorch.abs()).max()

    print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
# Maps the last component of a builder's module path to the builder itself.
_model_entrypoints = {}


def register_encoder(fn):
    """Register `fn` under the final dotted component of `fn.__module__`.

    Intended for use as a decorator; `fn` is returned unchanged. A later
    registration under the same key replaces the earlier one.
    """
    key = fn.__module__.rsplit('.', 1)[-1]
    _model_entrypoints[key] = fn
    return fn


def model_entrypoints(model_name):
    """Return the builder registered as `model_name` (KeyError if absent)."""
    entry = _model_entrypoints[model_name]
    return entry


def is_model(model_name):
    """Tell whether a builder has been registered as `model_name`."""
    known = model_name in _model_entrypoints
    return known
# This is a modified FPN decoder.
class BasePixelDecoder(nn.Module):
    def __init__(
        self,
        input_shape: Dict[str, ShapeSpec],
        *,
        conv_dim: int,
        mask_dim: int,
        mask_on: bool,
        norm: Optional[Union[str, Callable]] = None,
    ):
        """
        Build the lateral/output convolutions of an FPN-style pixel decoder.

        NOTE: this interface is experimental.
        Args:
            input_shape: shapes (channels and stride) of the input features
            conv_dim: number of output channels for the intermediate conv layers.
            mask_dim: number of output channels for the final conv layer.
            mask_on: when true, a 3x3 `mask_features` conv (conv_dim -> mask_dim)
                is created and applied at the end of `forward_features`.
            norm (str or callable): normalization for all conv layers
        """
        super().__init__()

        # Order the feature maps by stride, i.e. from highest to lowest resolution.
        input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
        self.in_features = [k for k, v in input_shape]  # starting from "res2" to "res5"
        feature_channels = [v.channels for k, v in input_shape]

        lateral_convs = []
        output_convs = []

        # Convs carry a bias only when no normalization is requested (norm == "").
        use_bias = norm == ""
        for idx, in_channels in enumerate(feature_channels):
            if idx == len(self.in_features) - 1:
                # Largest-stride level: no lateral conv, the 3x3 output conv
                # consumes the backbone feature directly.
                output_norm = get_norm(norm, conv_dim)
                output_conv = Conv2d(
                    in_channels,
                    conv_dim,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias=use_bias,
                    norm=output_norm,
                    activation=F.relu,
                )
                weight_init.c2_xavier_fill(output_conv)
                self.add_module("layer_{}".format(idx + 1), output_conv)

                lateral_convs.append(None)
                output_convs.append(output_conv)
            else:
                # Every other level: 1x1 lateral conv to conv_dim, then a 3x3
                # output conv applied after the top-down sum.
                lateral_norm = get_norm(norm, conv_dim)
                output_norm = get_norm(norm, conv_dim)

                lateral_conv = Conv2d(
                    in_channels, conv_dim, kernel_size=1, bias=use_bias, norm=lateral_norm
                )
                output_conv = Conv2d(
                    conv_dim,
                    conv_dim,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias=use_bias,
                    norm=output_norm,
                    activation=F.relu,
                )
                weight_init.c2_xavier_fill(lateral_conv)
                weight_init.c2_xavier_fill(output_conv)
                # The lists below are plain Python lists, so the convs are
                # registered as submodules explicitly under these names.
                self.add_module("adapter_{}".format(idx + 1), lateral_conv)
                self.add_module("layer_{}".format(idx + 1), output_conv)

                lateral_convs.append(lateral_conv)
                output_convs.append(output_conv)
        # Place convs into top-down order (from low to high resolution)
        # to make the top-down computation in forward clearer.
        self.lateral_convs = lateral_convs[::-1]
        self.output_convs = output_convs[::-1]

        self.mask_on = mask_on
        if self.mask_on:
            self.mask_dim = mask_dim
            self.mask_features = Conv2d(
                conv_dim,
                mask_dim,
                kernel_size=3,
                stride=1,
                padding=1,
            )
            weight_init.c2_xavier_fill(self.mask_features)

        # Read unconditionally by forward_features, so set regardless of mask_on.
        self.maskformer_num_feature_levels = 3  # always use 3 scales

    @classmethod
    def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
        """
        Translate `cfg['MODEL']['ENCODER']` into constructor keyword arguments.

        Only the inputs listed in `IN_FEATURES` are kept. The returned dict holds
        `input_shape`, `conv_dim`, `mask_dim` and `norm`; it does not contain
        `mask_on`, which the constructor also requires.
        """
        enc_cfg = cfg['MODEL']['ENCODER']
        ret = {}
        ret["input_shape"] = {
            k: v for k, v in input_shape.items() if k in enc_cfg['IN_FEATURES']
        }
        ret["conv_dim"] = enc_cfg['CONVS_DIM']
        ret["mask_dim"] = enc_cfg['MASK_DIM']
        ret["norm"] = enc_cfg['NORM']
        return ret

    def forward_features(self, features):
        """
        Run the top-down FPN pass over `features` (a mapping keyed by the names
        in `self.in_features`).

        Returns:
            A 3-tuple `(mask_features, None, multi_scale_features)`:
            `mask_features` is `self.mask_features` applied to the last computed
            map, or None when `mask_on` is false; `multi_scale_features` holds
            the first `maskformer_num_feature_levels` maps in the order they are
            produced (lowest resolution first).
        """
        multi_scale_features = []
        num_cur_levels = 0
        # Reverse feature maps into top-down order (from low to high resolution)
        for idx, f in enumerate(self.in_features[::-1]):
            x = features[f]
            lateral_conv = self.lateral_convs[idx]
            output_conv = self.output_convs[idx]
            if lateral_conv is None:
                # First iteration: this is the level built without a lateral conv.
                y = output_conv(x)
            else:
                cur_fpn = lateral_conv(x)
                # Following FPN implementation, we use nearest upsampling here
                y = cur_fpn + F.interpolate(y, size=cur_fpn.shape[-2:], mode="nearest")
                y = output_conv(y)
            if num_cur_levels < self.maskformer_num_feature_levels:
                multi_scale_features.append(y)
                num_cur_levels += 1

        mask_features = self.mask_features(y) if self.mask_on else None
        return mask_features, None, multi_scale_features

    def forward(self, features, targets=None):
        """Log a warning, then delegate to `forward_features`; `targets` is ignored."""
        logger = logging.getLogger(__name__)
        logger.warning("Calling forward() may cause unpredicted behavior of PixelDecoder module.")
        return self.forward_features(features)
normalize_before else None self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) self._reset_parameters() self.d_model = d_model self.nhead = nhead def _reset_parameters(self): for p in self.parameters(): if p.dim() > 1: nn.init.xavier_uniform_(p) def forward(self, src, mask, pos_embed): # flatten NxCxHxW to HWxNxC bs, c, h, w = src.shape src = src.flatten(2).permute(2, 0, 1) pos_embed = pos_embed.flatten(2).permute(2, 0, 1) if mask is not None: mask = mask.flatten(1) memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) return memory.permute(1, 2, 0).view(bs, c, h, w) # This is a modified FPN decoder with extra Transformer encoder that processes the lowest-resolution feature map. class TransformerEncoderPixelDecoder(BasePixelDecoder): @configurable def __init__( self, input_shape: Dict[str, ShapeSpec], *, transformer_dropout: float, transformer_nheads: int, transformer_dim_feedforward: int, transformer_enc_layers: int, transformer_pre_norm: bool, conv_dim: int, mask_dim: int, mask_on: int, norm: Optional[Union[str, Callable]] = None, ): """ NOTE: this interface is experimental. Args: input_shape: shapes (channels and stride) of the input features transformer_dropout: dropout probability in transformer transformer_nheads: number of heads in transformer transformer_dim_feedforward: dimension of feedforward network transformer_enc_layers: number of transformer encoder layers transformer_pre_norm: whether to use pre-layernorm or not conv_dims: number of output channels for the intermediate conv layers. mask_dim: number of output channels for the final conv layer. 
norm (str or callable): normalization for all conv layers """ super().__init__(input_shape, conv_dim=conv_dim, mask_dim=mask_dim, norm=norm, mask_on=mask_on) input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride) self.in_features = [k for k, v in input_shape] # starting from "res2" to "res5" feature_strides = [v.stride for k, v in input_shape] feature_channels = [v.channels for k, v in input_shape] in_channels = feature_channels[len(self.in_features) - 1] self.input_proj = Conv2d(in_channels, conv_dim, kernel_size=1) weight_init.c2_xavier_fill(self.input_proj) self.transformer = TransformerEncoderOnly( d_model=conv_dim, dropout=transformer_dropout, nhead=transformer_nheads, dim_feedforward=transformer_dim_feedforward, num_encoder_layers=transformer_enc_layers, normalize_before=transformer_pre_norm, ) N_steps = conv_dim // 2 self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True) # update layer use_bias = norm == "" output_norm = get_norm(norm, conv_dim) output_conv = Conv2d( conv_dim, conv_dim, kernel_size=3, stride=1, padding=1, bias=use_bias, norm=output_norm, activation=F.relu, ) weight_init.c2_xavier_fill(output_conv) delattr(self, "layer_{}".format(len(self.in_features))) self.add_module("layer_{}".format(len(self.in_features)), output_conv) self.output_convs[0] = output_conv @classmethod def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]): enc_cfg = cfg['MODEL']['ENCODER'] dec_cfg = cfg['MODEL']['DECODER'] ret = super().from_config(cfg, input_shape) ret["transformer_dropout"] = dec_cfg['DROPOUT'] ret["transformer_nheads"] = dec_cfg['NHEADS'] ret["transformer_dim_feedforward"] = dec_cfg['DIM_FEEDFORWARD'] ret["transformer_enc_layers"] = enc_cfg['TRANSFORMER_ENC_LAYERS'] # a separate config ret["transformer_pre_norm"] = dec_cfg['PRE_NORM'] ret['mask_on'] = cfg['MODEL']['DECODER']['MASK'] return ret def forward_features(self, features): multi_scale_features = [] num_cur_levels = 0 # Reverse feature maps into top-down order (from low 
@register_encoder
def get_transformer_encoder_fpn(cfg, input_shape):
    """
    Build a `TransformerEncoderPixelDecoder` from `cfg` and `input_shape`.

    The function is registered through `register_encoder`.

    Returns:
        The constructed pixel decoder.

    Raises:
        ValueError: if the constructed model exposes no callable
            `forward_features` attribute.
    """
    model = TransformerEncoderPixelDecoder(cfg, input_shape)
    forward_features = getattr(model, "forward_features", None)
    if not callable(forward_features):
        # Identify the offending model by its class name; no `name` variable
        # exists in this scope, so interpolating one would raise NameError
        # and mask the intended ValueError.
        raise ValueError(
            "Only SEM_SEG_HEADS with forward_features method can be used as pixel decoder. "
            f"Please implement forward_features for {type(model).__name__} to only return mask features."
        )
    return model
class MaskDINOHead(nn.Module):
    """Segmentation head that chains a pixel decoder with a transformer predictor."""

    @configurable
    def __init__(
        self,
        input_shape: Dict[str, ShapeSpec],
        *,
        num_classes: int,
        pixel_decoder: nn.Module,
        loss_weight: float = 1.0,
        ignore_value: int = -1,
        transformer_predictor: nn.Module,
    ):
        """
        Args:
            input_shape: shapes (channels and stride) of the input features
            num_classes: number of classes to predict
            pixel_decoder: the pixel decoder module
            loss_weight: loss weight
            ignore_value: category id to be ignored during training.
            transformer_predictor: the transformer decoder that makes prediction
        """
        super().__init__()
        # Feature names ordered by increasing stride.
        input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
        self.in_features = [k for k, v in input_shape]
        self.ignore_value = ignore_value
        self.common_stride = 4
        self.loss_weight = loss_weight

        self.pixel_decoder = pixel_decoder
        self.predictor = transformer_predictor

        self.num_classes = num_classes

    @classmethod
    def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec], lang_encoder: nn.Module, extra: dict):
        """
        Translate `cfg` into constructor keyword arguments.

        The pixel decoder is built with `build_encoder` and the predictor with
        `build_decoder`, whose input width is `cfg['MODEL']['ENCODER']['CONVS_DIM']`.
        """
        enc_cfg = cfg['MODEL']['ENCODER']
        dec_cfg = cfg['MODEL']['DECODER']
        transformer_predictor_in_channels = enc_cfg['CONVS_DIM']

        return {
            "input_shape": {
                k: v for k, v in input_shape.items() if k in enc_cfg['IN_FEATURES']
            },
            "ignore_value": enc_cfg['IGNORE_VALUE'],
            "num_classes": enc_cfg.get('NUM_CLASSES', None),
            "pixel_decoder": build_encoder(cfg, input_shape),
            "loss_weight": enc_cfg['LOSS_WEIGHT'],
            "transformer_predictor": build_decoder(
                cfg,
                transformer_predictor_in_channels,
                lang_encoder,
                mask_classification=True,
                extra=extra,
            ),
        }

    def forward(self, features, mask=None, targets=None, target_queries=None, target_vlp=None, task='seg', extra=None):
        """
        Delegate to `layers` with the same arguments.

        `extra` defaults to a fresh empty dict per call (None means "no extras"),
        so state can never leak between calls through a shared default object.
        """
        if extra is None:
            extra = {}
        return self.layers(features, mask, targets=targets, target_queries=target_queries, target_vlp=target_vlp, task=task, extra=extra)

    def layers(self, features, mask=None, targets=None, target_queries=None, target_vlp=None, prediction_switch=None, task='seg', extra=None):
        """
        Run the pixel decoder, then the predictor.

        Args:
            features: backbone features handed to `pixel_decoder.forward_features`.
            mask: forwarded to both the pixel decoder and the predictor.
            targets, target_queries, target_vlp: forwarded to the predictor.
            prediction_switch: accepted for interface compatibility; not used here.
            task: 'teacher' routes to `predictor.forward_teacher`; any other
                value calls the predictor directly.
            extra: optional dict forwarded to the predictor; None is replaced by
                a new empty dict.

        Returns:
            Whatever the selected predictor entry point returns.
        """
        if extra is None:
            extra = {}
        # Only the mask features and the multi-scale pyramid are consumed; the
        # middle element of the decoder output is ignored.
        mask_features, _, multi_scale_features = self.pixel_decoder.forward_features(features, mask)
        if task == 'teacher':
            predictions = self.predictor.forward_teacher(multi_scale_features, mask_features, mask, targets=targets, target_queries=target_queries, target_vlp=target_vlp, task=task, extra=extra)
        else:
            predictions = self.predictor(multi_scale_features, mask_features, mask, targets=targets, target_queries=target_queries, target_vlp=target_vlp, task=task, extra=extra)
        return predictions
class TransformerDecoder(nn.Module):
    """Stack of `num_layers` deep copies of `decoder_layer`, with an optional final norm."""

    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
        """
        Args:
            decoder_layer: layer module that is cloned `num_layers` times.
            num_layers: number of clones in the stack.
            norm: optional module applied to the output (and to each
                intermediate output when `return_intermediate` is set).
            return_intermediate: when true, `forward` returns the output of
                every layer stacked along a new leading dimension.
        """
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate

    def forward(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        """
        Feed `tgt` through every layer in turn; all other arguments are passed
        to each layer unchanged.

        Returns:
            `torch.stack` of the per-layer outputs when `return_intermediate`
            is true, otherwise the final output with a leading dimension of
            size one (`unsqueeze(0)`). `self.norm` is applied wherever it is
            not None.
        """
        output = tgt

        intermediate = []

        for layer in self.layers:
            output = layer(
                output,
                memory,
                tgt_mask=tgt_mask,
                memory_mask=memory_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask,
                pos=pos,
                query_pos=query_pos,
            )
            if self.return_intermediate:
                # `norm` is optional; collect the raw output when it is absent
                # instead of calling None.
                intermediate.append(output if self.norm is None else self.norm(output))

        if self.norm is not None:
            output = self.norm(output)
            if self.return_intermediate:
                # Replace the last entry with the final normalised output.
                intermediate.pop()
                intermediate.append(output)

        if self.return_intermediate:
            return torch.stack(intermediate)

        return output.unsqueeze(0)
class TransformerDecoderLayer(nn.Module):
    """Decoder layer: self-attention, cross-attention over `memory`, then a
    two-layer feed-forward block, each wrapped in a residual connection.

    `normalize_before` selects between the pre-norm (`forward_pre`) and
    post-norm (`forward_post`) arrangement of the LayerNorms.
    """

    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=False,
    ):
        super().__init__()
        # Attention over the targets themselves, then over the encoder memory.
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        # Position-wise feed-forward network.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        # One norm / dropout pair per residual branch.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Return `tensor + pos`, or `tensor` itself when `pos` is None."""
        if pos is None:
            return tensor
        return tensor + pos

    def forward_post(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        """Post-norm variant: each LayerNorm follows its residual sum."""
        # Self-attention: positions are added to queries/keys, not to values.
        self_qk = self.with_pos_embed(tgt, query_pos)
        self_out = self.self_attn(
            self_qk,
            self_qk,
            value=tgt,
            attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask,
        )[0]
        hidden = self.norm1(tgt + self.dropout1(self_out))

        # Cross-attention against the memory.
        cross_out = self.multihead_attn(
            query=self.with_pos_embed(hidden, query_pos),
            key=self.with_pos_embed(memory, pos),
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )[0]
        hidden = self.norm2(hidden + self.dropout2(cross_out))

        # Feed-forward block.
        expanded = self.activation(self.linear1(hidden))
        ffn_out = self.linear2(self.dropout(expanded))
        return self.norm3(hidden + self.dropout3(ffn_out))

    def forward_pre(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        """Pre-norm variant: each branch normalises its input; the residual
        stream itself stays un-normalised."""
        # Self-attention on the normalised targets.
        normed = self.norm1(tgt)
        self_qk = self.with_pos_embed(normed, query_pos)
        self_out = self.self_attn(
            self_qk,
            self_qk,
            value=normed,
            attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask,
        )[0]
        stream = tgt + self.dropout1(self_out)

        # Cross-attention against the memory.
        normed = self.norm2(stream)
        cross_out = self.multihead_attn(
            query=self.with_pos_embed(normed, query_pos),
            key=self.with_pos_embed(memory, pos),
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )[0]
        stream = stream + self.dropout2(cross_out)

        # Feed-forward block.
        normed = self.norm3(stream)
        expanded = self.activation(self.linear1(normed))
        ffn_out = self.linear2(self.dropout(expanded))
        return stream + self.dropout3(ffn_out)

    def forward(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        """Dispatch to `forward_pre` or `forward_post` per `normalize_before`."""
        variant = self.forward_pre if self.normalize_before else self.forward_post
        return variant(
            tgt,
            memory,
            tgt_mask,
            memory_mask,
            tgt_key_padding_mask,
            memory_key_padding_mask,
            pos,
            query_pos,
        )
def build_lang_encoder(config_encoder, tokenizer, verbose, **kwargs):
    """Instantiate the language encoder registered as `config_encoder['NAME']`.

    Raises:
        ValueError: if no encoder is registered under that name.
    """
    model_name = config_encoder['NAME']

    if is_lang_encoder(model_name):
        builder = lang_encoders(model_name)
        return builder(config_encoder, tokenizer, verbose, **kwargs)

    raise ValueError(f'Unkown model: {model_name}')


def build_tokenizer(config_encoder):
    """Create the tokenizer selected by `config_encoder['TOKENIZER']`.

    'clip' and 'clip-fast' load a CLIP tokenizer from
    `config_encoder['PRETRAINED_TOKENIZER']` (default
    'openai/clip-vit-base-patch32'); any other value is passed to
    `AutoTokenizer.from_pretrained` as-is. As a side effect the
    TOKENIZERS_PARALLELISM environment variable is set to 'true'.
    """
    os.environ['TOKENIZERS_PARALLELISM'] = 'true'

    if config_encoder['TOKENIZER'] == 'clip':
        checkpoint = config_encoder.get(
            'PRETRAINED_TOKENIZER', 'openai/clip-vit-base-patch32'
        )
        clip_tokenizer = CLIPTokenizer.from_pretrained(checkpoint)
        # Reuse the end-of-sequence token as the CLS token.
        clip_tokenizer.add_special_tokens({'cls_token': clip_tokenizer.eos_token})
        return clip_tokenizer

    if config_encoder['TOKENIZER'] == 'clip-fast':
        checkpoint = config_encoder.get(
            'PRETRAINED_TOKENIZER', 'openai/clip-vit-base-patch32'
        )
        return CLIPTokenizerFast.from_pretrained(checkpoint, from_slow=True)

    return AutoTokenizer.from_pretrained(config_encoder['TOKENIZER'])
class QuickGELU(nn.Module):
    """Activation computing ``x * sigmoid(1.702 * x)`` element-wise."""

    def forward(self, x: torch.Tensor):
        # Sigmoid gate evaluated on the scaled input, then applied to x.
        gate = torch.sigmoid(1.702 * x)
        return x * gate
def build_attention_mask(self):
    """Build the additive causal attention mask.

    Returns:
        A ``(context_length, context_length)`` tensor whose entries strictly
        above the main diagonal are ``-inf`` and whose diagonal and lower
        triangle are ``0``.
    """
    size = self.context_length
    # Start from an all -inf square, then keep only the part strictly above
    # the diagonal; everything else becomes 0.
    blocked = torch.full((size, size), float("-inf"))
    return torch.triu(blocked, diagonal=1)
def load_pretrained(self, pretrained='', pretrained_layers=(), verbose=True):
    """Load matching weights from a checkpoint file into this module.

    Args:
        pretrained: path of a checkpoint readable by ``torch.load``. Nothing
            happens when it is not an existing file.
        pretrained_layers: first-level key prefixes (text before the first
            ``'.'``) to load. A first element of ``'*'`` loads every key
            that also exists in this module. An empty sequence loads
            nothing.
        verbose: log every key that gets initialised.

    A leading ``'lang_encoder.'`` is stripped from checkpoint keys. A
    ``positional_embedding`` whose length differs is linearly interpolated
    to the current length; one whose width differs is skipped.
    """
    if not os.path.isfile(pretrained):
        return

    # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
    pretrained_dict = torch.load(pretrained, map_location='cpu')
    logger.info(f'=> loading pretrained model {pretrained}')
    model_dict = self.state_dict()

    def stripped_key(key):
        return key[13:] if key.startswith('lang_encoder.') else key

    pretrained_dict = {
        stripped_key(k): v for k, v in pretrained_dict.items()
        if stripped_key(k) in model_dict
    }

    # Guard the wildcard test so an empty selection no longer raises IndexError.
    load_all = len(pretrained_layers) > 0 and pretrained_layers[0] == '*'

    need_init_state_dict = {}
    for k, v in pretrained_dict.items():
        if not (load_all or k.split('.')[0] in pretrained_layers):
            continue
        if verbose:
            logger.info(f'=> init {k} from {pretrained}')

        if 'positional_embedding' in k and v.size() != model_dict[k].size():
            L1, nH1 = v.size()
            L2, nH2 = model_dict[k].size()
            if nH1 != nH2:
                # Width cannot be reconciled: really skip the tensor instead
                # of handing a mismatched shape to load_state_dict.
                logger.warning(f"Error in loading {k}, passing")
                continue
            logger.info(
                '=> load_pretrained: resized variant: {} to {}'
                .format((L1, nH1), (L2, nH2))
            )
            # Interpolate along the sequence axis: (L1, C) -> (1, C, L1)
            # -> (1, C, L2) -> (L2, C).
            posemb_grid = v.float().unsqueeze(dim=0).permute(0, 2, 1)
            posemb_grid = torch.nn.functional.interpolate(
                posemb_grid, size=L2, mode='linear')
            v = posemb_grid.permute(0, 2, 1).squeeze(dim=0)

        need_init_state_dict[k] = v
    self.load_state_dict(need_init_state_dict, strict=False)
@register_lang_encoder
def lang_encoder(config_encoder, tokenizer, verbose, **kwargs):
    """Build the text ``Transformer`` described by ``config_encoder``.

    Args:
        config_encoder: mapping that provides ``CONTEXT_LENGTH``, ``WIDTH``,
            ``LAYERS`` and ``HEADS``; optionally ``AUTOGRESSIVE`` (default
            True), ``LOAD_PRETRAINED`` (default False), ``PRETRAINED`` and
            ``PRETRAINED_LAYERS`` (default ``['*']``).
        tokenizer: object whose ``vocab_size`` sizes the token embedding.
        verbose: forwarded to ``load_pretrained`` to control per-key logging.
        **kwargs: accepted for interface compatibility; unused here.

    Returns:
        The constructed ``Transformer``, with pretrained weights loaded when
        ``LOAD_PRETRAINED`` is truthy.
    """
    transformer = Transformer(
        context_length=config_encoder['CONTEXT_LENGTH'],
        vocab_size=tokenizer.vocab_size,
        width=config_encoder['WIDTH'],
        layers=config_encoder['LAYERS'],
        heads=config_encoder['HEADS'],
        autogressive=config_encoder.get('AUTOGRESSIVE', True)
    )

    if config_encoder.get('LOAD_PRETRAINED', False):
        # Honour the caller's verbosity; it used to be silently dropped so
        # load_pretrained always ran with its default verbose=True.
        transformer.load_pretrained(
            config_encoder['PRETRAINED'],
            config_encoder.get('PRETRAINED_LAYERS', ['*']),
            verbose=verbose,
        )
    return transformer
import register_model from ..utils import configurable from .LangEncoder import build_tokenizer, build_lang_encoder from utils.prompt_engineering import prompt_engineering, get_prompt_templates class LanguageEncoder(nn.Module): @configurable def __init__( self, tokenizer, tokenizer_type, lang_encoder, lang_projection, max_token_num, ): super().__init__() self.tokenizer = tokenizer self.tokenizer_type = tokenizer_type self.lang_encoder = lang_encoder self.lang_proj = lang_projection self.max_token_num = max_token_num self.logit_scale = nn.Parameter(torch.ones([])) @classmethod def from_config(cls, cfg): # build up text encoder tokenizer = build_tokenizer(cfg['MODEL']['TEXT']) tokenizer_type = cfg['MODEL']['TEXT']['TOKENIZER'] lang_encoder = build_lang_encoder(cfg['MODEL']['TEXT'], tokenizer, cfg['VERBOSE']) max_token_num = cfg['MODEL']['TEXT']['CONTEXT_LENGTH'] dim_lang = cfg['MODEL']['TEXT']['WIDTH'] dim_projection = cfg['MODEL']['DIM_PROJ'] lang_projection = nn.Parameter(torch.empty(dim_lang, dim_projection)) trunc_normal_(lang_projection, std=.02) return { "tokenizer": tokenizer, "tokenizer_type": tokenizer_type, "lang_encoder": lang_encoder, "lang_projection": lang_projection, "max_token_num": max_token_num, } # @torch.no_grad() def get_text_embeddings(self, class_names, name='default', is_eval=False, add_bgd=False, prompt=True, norm=True): if not is_eval: if prompt: # randomly sample one template arbitary_concepts = [ prompt_engineering(class_names[label].replace('-other','').replace('-merged','').replace('-stuff',''), topk=10000, suffix='.') \ for label in range(len(class_names)) ] if add_bgd: arbitary_concepts.append("A background in coco.") else: arbitary_concepts = class_names input_ids = [] attention_masks = [] for txt in arbitary_concepts: tokens = self.tokenizer( txt, padding='max_length', truncation=True, max_length=self.max_token_num, return_tensors='pt' ) tokens['input_ids'].squeeze_() tokens['attention_mask'].squeeze_() 
def forward_language(self, texts, norm=True):
    """Encode tokenised text into projected (optionally normalised) embeddings.

    Args:
        texts: positional arguments for ``self.lang_encoder``; ``texts[0]``
            holds the token ids.
        norm: when True, divide each embedding by its L2 norm plus 1e-7.

    Returns:
        One embedding per input sequence, multiplied by ``self.lang_proj``.
    """
    hidden = self.lang_encoder(*texts)['last_hidden_state']

    if self.tokenizer_type == 'clip':
        # Pool at the position of the largest token id in each sequence.
        # NOTE(review): 'clip-fast' does not take this branch — confirm intended.
        rows = torch.arange(hidden.size(0))
        pooled = hidden[rows, texts[0].argmax(dim=-1)]
    else:
        # Otherwise use the first position.
        pooled = hidden[:, 0]

    projected = pooled @ self.lang_proj
    if not norm:
        return projected
    return projected / (projected.norm(dim=-1, keepdim=True) + 1e-7)
import torch from torch import nn from torch.nn import functional as F from timm.models.layers import trunc_normal_ from .registry import register_model from ..utils import configurable from .LangEncoder import build_tokenizer, build_lang_encoder from utils.prompt_engineering import prompt_engineering, get_prompt_templates class LanguageEncoder(nn.Module): @configurable def __init__( self, tokenizer, tokenizer_type, lang_encoder, lang_projection, max_token_num, ): super().__init__() self.tokenizer = tokenizer self.tokenizer_type = tokenizer_type self.lang_encoder = lang_encoder self.lang_proj = lang_projection self.max_token_num = max_token_num self.logit_scale = nn.Parameter(torch.ones([])) @classmethod def from_config(cls, cfg): # build up text encoder tokenizer = build_tokenizer(cfg['MODEL']['TEXT']) tokenizer_type = cfg['MODEL']['TEXT']['TOKENIZER'] lang_encoder = build_lang_encoder(cfg['MODEL']['TEXT'], tokenizer, cfg['VERBOSE']) max_token_num = cfg['MODEL']['TEXT']['CONTEXT_LENGTH'] dim_lang = cfg['MODEL']['TEXT']['WIDTH'] dim_projection = cfg['MODEL']['DIM_PROJ'] lang_projection = nn.Parameter(torch.empty(dim_lang, dim_projection)) trunc_normal_(lang_projection, std=.02) return { "tokenizer": tokenizer, "tokenizer_type": tokenizer_type, "lang_encoder": lang_encoder, "lang_projection": lang_projection, "max_token_num": max_token_num, } @torch.no_grad() def get_text_embeddings(self, class_names, name='default', is_eval=False, add_bgd=False, prompt=True, norm=True): if not is_eval: if prompt: # randomly sample one template arbitary_concepts = [ prompt_engineering(class_names[label].replace('-other','').replace('-merged','').replace('-stuff',''), topk=10000, suffix='.') \ for label in range(len(class_names)) ] if add_bgd: arbitary_concepts.append("A background in coco.") else: arbitary_concepts = class_names input_ids = [] attention_masks = [] for txt in arbitary_concepts: tokens = self.tokenizer( txt, padding='max_length', truncation=True, 
@torch.no_grad()  # FIXME hack to freeze all parameters
def compute_similarity(self, v_emb, name='default'):
    """Return scaled similarity logits between ``v_emb`` and cached text embeddings.

    Args:
        v_emb: visual embeddings; normalised here along the last dimension.
        name: selects the cached ``'{name}_text_embeddings'`` attribute.

    Returns:
        ``exp(logit_scale) * normalised(v_emb) @ text_embeddings^T``, computed
        without gradient tracking.
    """
    unit_v = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
    text_bank = getattr(self, '{}_text_embeddings'.format(name))
    # Scale first, then multiply by the transposed text bank (same
    # left-to-right evaluation as ``scale * v @ t``).
    scaled_v = self.logit_scale.exp() * unit_v
    return scaled_v @ text_bank.unsqueeze(0).transpose(1, 2)
return LanguageEncoder(cfg) ================================================ FILE: llava/model/semsam/language/llama_encoder.py ================================================ # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright: # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright: # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os import copy from dataclasses import dataclass, field import json import logging import pathlib from typing import Dict, Optional, Sequence import torch import transformers from torch.utils.data import Dataset from transformers import Trainer from llava import conversation as conversation_lib from PIL import Image import torch.nn as nn # from openseed.BaseModel import BaseModel # from openseed import build_model import torch from torch import nn from torch.nn import functional as F from timm.models.layers import trunc_normal_ from .registry import register_model from ..utils import configurable from .LangEncoder import build_tokenizer, build_lang_encoder from utils.prompt_engineering import prompt_engineering, get_prompt_templates from openseed.language import LlamaForCausalLM # TODO: import and use code from ../data/dataset.py IGNORE_INDEX = -100 DEFAULT_PAD_TOKEN = "[PAD]" DEFAULT_EOS_TOKEN = "" DEFAULT_BOS_TOKEN = "" DEFAULT_UNK_TOKEN = "" DEFAULT_IMAGE_TOKEN = "" DEFAULT_IMAGE_PATCH_TOKEN = "" 
DEFAULT_IM_START_TOKEN = "" DEFAULT_IM_END_TOKEN = "" DEFAULT_OBJECT_START_TOKEN = "" DEFAULT_OBJECT_END_TOKEN = "" ENC_LENS=[140*64,140*16,140*4,140] ENC_ID=-1 @dataclass class ModelArguments: model_name_or_path: Optional[str] = field(default="facebook/opt-125m") freeze_backbone: bool = field(default=False) dbg: bool = field(default=False) tune_mm_mlp_adapter: bool = field(default=False) config_file: Optional[str] = field(default="") os_weights:Optional[str]=field(default="") vision_tower: Optional[str] = field(default=None) mm_vision_select_layer: Optional[int] = field(default=-1) # default to the last layer pretrain_mm_mlp_adapter: Optional[str] = field(default=None) pretrain_obj_mlp_adapter: Optional[str] = field(default=None) mm_use_im_start_end: bool = field(default=False) @dataclass class DataArguments: data_path: str = field(default=None, metadata={"help": "Path to the training data."}) lazy_preprocess: bool = False is_multimodal: bool = False image_token_len: int = 0 image_folder: Optional[str] = field(default=None) image_aspect_ratio: str = 'square' @dataclass class TrainingArguments(transformers.TrainingArguments): cache_dir: Optional[str] = field(default=None) optim: str = field(default="adamw_torch") remove_unused_columns: bool = field(default=False) # dbg: bool = field(default=False) model_max_length: int = field( default=512, metadata={ "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)." 
def _tokenize_fn(strings: Sequence[str],
                 tokenizer: transformers.PreTrainedTokenizer) -> Dict:
    """Tokenize each string on its own and count its non-pad tokens.

    Returns:
        A dict with ``input_ids`` / ``labels`` (the first row of each
        tokenizer output) and ``input_ids_lens`` / ``labels_lens`` (number of
        ids different from ``tokenizer.pad_token_id``).
    """
    ids_per_text = []
    lens_per_text = []
    for text in strings:
        encoded = tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        )
        ids_per_text.append(encoded.input_ids[0])
        lens_per_text.append(
            encoded.input_ids.ne(tokenizer.pad_token_id).sum().item()
        )

    # "labels" and "labels_lens" intentionally share the very same list
    # objects as their "input_ids" counterparts.
    return dict(
        input_ids=ids_per_text,
        labels=ids_per_text,
        input_ids_lens=lens_per_text,
        labels_lens=lens_per_text,
    )
def preprocess_multimodal(
    sources: Sequence[str],
    multimodal_cfg: dict,
    cur_token_len: int,
) -> Dict:
    """Expand the image placeholder in every sentence of ``sources``.

    When ``multimodal_cfg['is_multimodal']`` is falsy, ``sources`` is returned
    untouched. Otherwise each sentence's ``"value"`` is rewritten in place:
    every ``DEFAULT_IMAGE_TOKEN`` becomes ``cur_token_len`` repetitions of
    ``DEFAULT_IMAGE_PATCH_TOKEN``, wrapped in the image start/end tokens when
    ``multimodal_cfg['use_im_start_end']`` is truthy.

    Returns:
        The same ``sources`` object that was passed in.
    """
    if not multimodal_cfg['is_multimodal']:
        return sources

    for conversation in sources:
        for turn in conversation:
            patches = DEFAULT_IMAGE_PATCH_TOKEN * cur_token_len
            expansion = (
                DEFAULT_IM_START_TOKEN + patches + DEFAULT_IM_END_TOKEN
                if multimodal_cfg['use_im_start_end']
                else patches
            )
            turn["value"] = turn["value"].replace(DEFAULT_IMAGE_TOKEN, expansion)
    return sources
class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning.

    The whole JSON file is loaded and tokenized eagerly in ``__init__``;
    items are then served from the in-memory ``input_ids`` / ``labels``.
    """

    def __init__(self, data_path: str,
                 tokenizer: transformers.PreTrainedTokenizer):
        """Read ``data_path`` (a JSON list of examples) and preprocess it.

        Each example must provide a ``"conversations"`` entry, which is
        passed to ``preprocess`` together with ``tokenizer``.
        """
        super(SupervisedDataset, self).__init__()
        logging.warning("Loading data...")
        # Context manager so the file handle is closed deterministically
        # (the bare json.load(open(...)) left it to the garbage collector).
        with open(data_path, "r") as data_file:
            list_data_dict = json.load(data_file)

        logging.warning("Formatting inputs...")
        sources = [example["conversations"] for example in list_data_dict]
        data_dict = preprocess(sources, tokenizer)

        self.input_ids = data_dict["input_ids"]
        self.labels = data_dict["labels"]

    def __len__(self):
        """Return the number of preprocessed examples."""
        return len(self.input_ids)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        """Return the ``input_ids`` / ``labels`` pair stored at index ``i``."""
        return dict(input_ids=self.input_ids[i], labels=self.labels[i])
isinstance(i, int): sources = [sources] assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME try: if 'image' in sources[0]: image_file = self.list_data_dict[i]['image'] image_folder = self.multimodal_cfg['image_folder'] processor = self.multimodal_cfg['image_processor'] image = Image.open(os.path.join(image_folder, image_file)) if self.multimodal_cfg['image_aspect_ratio'] == 'keep': max_hw, min_hw = max(image.size), min(image.size) aspect_ratio = max_hw / min_hw max_len, min_len = 1333, 800 shortest_edge = int(min(max_len / aspect_ratio, min_len)) try: image = processor.preprocess(image, return_tensors='pt', do_center_crop=False, size={"shortest_edge": shortest_edge}, do_rescale=False, do_normalize=False)['pixel_values'][0] except Exception: return self.__getitem__(i + 1) else: # try: image = processor.preprocess(image, return_tensors='pt', do_rescale=False, do_normalize=False, do_center_crop=False,size=(640,64*14))[ 'pixel_values'][0] # except Exception: # return self.__getitem__(i+1) # FIXME: cur_token_len should be num_queries when using det # cur_token_len = (image.shape[1]//14) * (image.shape[2]//14) # FIXME: 14 is hardcoded patch size cur_token_len = ENC_LENS[ENC_ID] sources = preprocess_multimodal( copy.deepcopy([e["conversations"] for e in sources]), self.multimodal_cfg, cur_token_len) else: sources = copy.deepcopy([e["conversations"] for e in sources]) except Exception: return self.__getitem__(i + 1) data_dict = preprocess( sources, self.tokenizer) if isinstance(i, int): data_dict = dict(input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0]) # image exist in the data if 'image' in self.list_data_dict[i]: data_dict['image'] = image elif self.multimodal_cfg['is_multimodal']: # image does not exist in the data, but the model is multimodal crop_size = self.multimodal_cfg['image_processor'].crop_size data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width']) return data_dict @dataclass class 
def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
                                data_args) -> Dict:
    """Make dataset and collator for supervised fine-tuning.

    Args:
        tokenizer: tokenizer handed to both the dataset and the collator.
        data_args: object providing ``lazy_preprocess``, ``data_path`` and,
            for the lazy dataset, ``is_multimodal``, ``image_token_len``,
            ``image_folder`` and ``image_aspect_ratio``;
            ``mm_use_im_start_end`` and ``image_processor`` are optional.

    Returns:
        A dict with ``train_dataset``, ``eval_dataset`` (always None) and
        ``data_collator``.
    """
    dataset_kwargs = dict(tokenizer=tokenizer, data_path=data_args.data_path)
    if data_args.lazy_preprocess:
        dataset_cls = LazySupervisedDataset
        # Only the lazy dataset accepts multimodal_cfg; passing it to
        # SupervisedDataset raised TypeError (unexpected keyword argument).
        dataset_kwargs["multimodal_cfg"] = dict(
            is_multimodal=data_args.is_multimodal,
            image_token_len=data_args.image_token_len,
            image_folder=data_args.image_folder,
            image_aspect_ratio=data_args.image_aspect_ratio,
            use_im_start_end=getattr(data_args, 'mm_use_im_start_end', False),
            image_processor=getattr(data_args, 'image_processor', None))
    else:
        dataset_cls = SupervisedDataset

    train_dataset = dataset_cls(**dataset_kwargs)
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return dict(train_dataset=train_dataset,
                eval_dataset=None,
                data_collator=data_collator)
@register_model
def get_language_model(cfg, **kwargs):
    """Load ``LlamaForCausalLM`` as configured under ``cfg['MODEL']['LLAMA']``.

    Reads ``load_fp16``, ``model_name_or_path`` and ``cache_dir``. When
    ``load_fp16`` is truthy the model is requested with
    ``torch_dtype=torch.float16``; otherwise no dtype is passed.
    ``**kwargs`` is accepted but unused.
    """
    llama_cfg = cfg['MODEL']['LLAMA']

    # Only the dtype differs between the two loading modes.
    dtype_kwargs = {}
    if llama_cfg['load_fp16']:
        dtype_kwargs['torch_dtype'] = torch.float16

    return LlamaForCausalLM.from_pretrained(
        llama_cfg['model_name_or_path'],
        cache_dir=llama_cfg['cache_dir'],
        **dtype_kwargs,
    )
build_model(cfg)).cuda() # if not model_args.dbg: checkpoint = torch.load(model_args.os_weights, map_location='cpu') model_dict = openseed_vision.state_dict() pretrained_dict = {"model."+k: v for k, v in checkpoint.items() if "model."+k in model_dict} model_dict.update(pretrained_dict) openseed_vision.load_state_dict(model_dict) # openseed_vision.stat if not hasattr(model.model, 'vision_tower'): vision_tower = CLIPVisionModel.from_pretrained(model_args.vision_tower) else: vision_tower = model.model.vision_tower[0] image_processor = CLIPImageProcessor.from_pretrained(model_args.vision_tower) image_processor.size['shortest_edge'] = 800 vision_config = vision_tower.config vision_tower=openseed_vision vision_tower.config=vision_config # vision_tower.num_queries=300 # vision_tower.idx=model_args.enc_idx vision_tower.idx=ENC_ID vision_tower.num_enc_tokens=ENC_LENS[vision_tower.idx] vision_tower.dim_queries=256 # num_patches = (vision_config.image_size // vision_config.patch_size) ** 2 num_patches=vision_tower.num_enc_tokens data_args.image_token_len = num_patches data_args.image_processor = image_processor data_args.is_multimodal = True vision_tower.requires_grad_(False) # model.model.vision_tower = vision_tower # HACK: for FSDP vision_tower.to(device=training_args.device) model.model.vision_tower = [vision_tower] model.config.use_mm_proj = True model.config.mm_hidden_size = vision_config.hidden_size=vision_tower.dim_queries model.config.mm_vision_select_layer = model_args.mm_vision_select_layer if not hasattr(model.model, 'mm_projector') or model.model.mm_projector.weight.shape[1]!=vision_config.hidden_size: mm_projector = nn.Linear(vision_config.hidden_size, model.config.hidden_size) model.model.mm_projector = mm_projector else: mm_projector = model.model.mm_projector if not hasattr(model.model, 'obj_projector'): obj_projector = nn.Linear(vision_config.hidden_size+4, model.config.hidden_size) model.model.obj_projector = obj_projector else: obj_projector = 
model.model.obj_projector if model_args.pretrain_mm_mlp_adapter is not None: mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu') mm_projector.load_state_dict({k.split('.')[-1]: v for k, v in mm_projector_weights.items()}) if model_args.pretrain_obj_mlp_adapter is not None: obj_projector_weights = torch.load(model_args.pretrain_obj_mlp_adapter, map_location='cpu') obj_projector.load_state_dict({k.split('.')[-1]: v for k, v in obj_projector_weights.items()}) model.config.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter if model_args.tune_mm_mlp_adapter: model.requires_grad_(False) for p in mm_projector.parameters(): p.requires_grad = True for p in obj_projector.parameters(): p.requires_grad = True model.config.mm_use_im_start_end = model_args.mm_use_im_start_end data_args.mm_use_im_start_end = model_args.mm_use_im_start_end tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) vision_config.use_im_start_end = model_args.mm_use_im_start_end if model_args.mm_use_im_start_end: num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_OBJECT_START_TOKEN,DEFAULT_OBJECT_END_TOKEN], special_tokens=True) model.resize_token_embeddings(len(tokenizer)) vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN]) if num_new_tokens > 0: input_embeddings = model.get_input_embeddings().weight.data output_embeddings = model.get_output_embeddings().weight.data input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( dim=0, keepdim=True) output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( dim=0, keepdim=True) input_embeddings[-num_new_tokens:] = input_embeddings_avg output_embeddings[-num_new_tokens:] = output_embeddings_avg if model_args.tune_mm_mlp_adapter: model.model.orig_embeds_params = [model.get_input_embeddings().weight.data.clone().to(device=training_args.device)] for p in 
def is_dist_initialized():
    """Return True when torch.distributed is usable and a process group exists."""
    # `is_initialized` is only exported when the distributed package is
    # compiled in, so availability has to be checked first.
    return dist.is_available() and dist.is_initialized()


def get_world_size():
    """Return the number of participating processes (1 when not distributed)."""
    if is_dist_initialized():
        return dist.get_world_size()
    return 1


def get_rank():
    """Return the rank of the current process (0 when not distributed)."""
    if is_dist_initialized():
        return dist.get_rank()
    return 0


def all_gather_grad(x):
    """Concatenate `x` from every rank along dim 0.

    With a single process the input is returned unchanged. Otherwise the
    tensors of all ranks are gathered and concatenated in rank order.

    Args:
        x: tensor to gather. The receive buffers are created with
            `zeros_like(x)`, so every rank is expected to pass the same shape.

    Returns:
        `x` itself (world size 1) or the concatenation over all ranks.
    """
    world_size = get_world_size()
    if world_size > 1:
        gathered = [torch.zeros_like(x) for _ in range(world_size)]
        dist.all_gather(gathered, x)
        # The receive buffers are freshly allocated tensors; put the local
        # tensor back into its own slot so the result stays connected to the
        # autograd graph of `x`.
        gathered[dist.get_rank()] = x
        x = torch.cat(gathered, dim=0)
    return x
(torch.Tensor): shape [B, L1, C] # B: batch_size, L1: 1, C: 256 text_feat (torch.Tensor): shape [B, L2, C] # B:batch_size, L2: number of selected nouns, C: 256 Returns: """ # [B, L1, C], L1 = 1 # image_feat = F.normalize(image_feat, dim=-1) # [B, L2, C] # text_feat = F.normalize(text_feat, dim=-1) # HACK: normalize outside # [B, L1, L2] dist_per_img = image_feat @ rearrange(text_feat, 'b l c -> b c l') # [B, L2, L1] dist_per_text = text_feat @ rearrange(image_feat, 'b l c -> b c l') batch = image_feat.shape[0] img_len = image_feat.shape[1] text_len = text_feat.shape[1] # [B, L1, L2] pos_labels_batch_img = rearrange(torch.ones_like(dist_per_text) / dist_per_text.size(1), 'b l2 l1 -> b l1 l2') # [B, L2, L1] pos_labels_batch_text = rearrange(torch.ones_like(dist_per_img) / dist_per_img.size(1), 'b l1 l2 -> b l2 l1') image_x = rearrange(image_feat, 'b l c -> (b l) c') text_x = rearrange(text_feat, 'b l c -> (b l) c') logits_per_img = image_x @ all_gather_grad(text_x).t() logits_per_text = text_x @ all_gather_grad(image_x).t() # get label globally # [B, L1, B, L2, W] labels_per_img = F.one_hot( torch.ones(batch, img_len, batch, text_len, dtype=torch.long, device=image_x.device) * get_rank(), num_classes=get_world_size()).to(image_x.dtype) labels_per_img *= rearrange(pos_labels_batch_img, 'b l1 l2 -> b l1 1 l2 1') * repeat( torch.eye(batch, dtype=image_x.dtype, device=image_x.device), 'b1 b2 -> b1 1 b2 1 1') # [BxL1, WxBxL2] labels_per_img = rearrange(labels_per_img, 'b1 l1 b2 l2 w -> (b1 l1) (w b2 l2)') # [B, L2, B, L1, W] labels_per_text = F.one_hot( torch.ones(batch, text_len, batch, img_len, dtype=torch.long, device=text_x.device) * get_rank(), num_classes=get_world_size()).to(text_x.dtype) labels_per_text *= rearrange(pos_labels_batch_text, 'b l2 l1 -> b l2 1 l1 1') * repeat( torch.eye(batch, dtype=text_x.dtype, device=image_x.device), 'b2 b1 -> b2 1 b1 1 1') # [BxL2, WxBxL1] labels_per_text = rearrange(labels_per_text, 'b2 l2 b1 l1 w -> (b2 l2) (w b1 l1)') 
def all_gather_pickle(data, device):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors)

    Args:
        data: any picklable object
        device: device on which the communication buffers are allocated;
            it must be a device the active process-group backend can use.

    Returns:
        list[data]: list of data gathered from each rank
    """
    world_size = get_world_size()
    if world_size == 1:
        return [data]

    # Serialize to a byte tensor on the requested device.
    payload = pickle.dumps(data)
    storage = torch.ByteStorage.from_buffer(payload)
    tensor = torch.ByteTensor(storage).to(device)

    # Exchange payload sizes. Every buffer lives on `device`, so the later
    # concatenation and collectives never mix devices.
    local_size = tensor.numel()
    local_size_buf = torch.tensor([local_size], dtype=torch.long, device=device)
    size_bufs = [torch.zeros_like(local_size_buf) for _ in range(world_size)]
    dist.all_gather(size_bufs, local_size_buf)
    sizes = [int(size.item()) for size in size_bufs]
    max_size = max(sizes)

    # all_gather needs equally shaped tensors, so pad the local payload up to
    # the largest one and trim again after the exchange.
    receive_bufs = [
        torch.empty((max_size,), dtype=torch.uint8, device=device) for _ in sizes
    ]
    if local_size != max_size:
        padding = torch.zeros((max_size - local_size,), dtype=torch.uint8, device=device)
        tensor = torch.cat((tensor, padding), dim=0)
    dist.all_gather(receive_bufs, tensor)

    # NOTE: pickle.loads executes arbitrary code from the payload; this is only
    # acceptable because the bytes come from peer ranks of the same job.
    return [
        pickle.loads(buf.cpu().numpy().tobytes()[:size])
        for size, buf in zip(sizes, receive_bufs)
    ]
def ql_multi_contrastive_loss(image_feat, text_feat, text_hash, temperature=1):
    """Contrastive loss where several rows may share the same text.

    Features and hashes are gathered from all ranks. Texts with equal hash
    are collapsed into one column (using the feature of the first
    occurrence), and row i gets target 1 in the column of its own hash.

    Args:
        image_feat: per-sample image/query features (assumed 2-D, one row
            per sample, already normalised by the caller; TODO confirm).
        text_feat: per-sample text features, row-aligned with `image_feat`.
        text_hash: 1-D tensor with one hash per row of `text_feat`.
        temperature: log-scale temperature; the logits are multiplied by
            `exp(temperature)` clamped to at most 100. May be a tensor or a
            plain Python number.

    Returns:
        Scalar loss `0.7 * image_to_text + 0.3 * text_to_image`.
    """
    image_feat = all_gather_arbitary_tensor(image_feat)
    text_feat = all_gather_arbitary_tensor(text_feat)

    text_hash_batch = all_gather_pickle(text_hash, text_feat.device)
    text_hash_all = torch.cat(text_hash_batch)

    unique_hashes = torch.unique(text_hash_all).tolist()
    hashes = text_hash_all.tolist()

    # One pass to remember the first row of every hash and one dict for the
    # column of every hash, instead of repeated linear list.index() scans.
    first_row = {}
    for row, value in enumerate(hashes):
        first_row.setdefault(value, row)
    column_of = {value: col for col, value in enumerate(unique_hashes)}

    text_feat_unique = torch.stack([text_feat[first_row[value]] for value in unique_hashes])

    gt = torch.zeros((image_feat.shape[0], len(unique_hashes)), device=text_feat.device)
    rows = torch.arange(len(hashes), device=gt.device)
    cols = torch.tensor([column_of[value] for value in hashes], dtype=torch.long, device=gt.device)
    gt[rows, cols] = 1

    logits = torch.matmul(image_feat, text_feat_unique.t())
    # The default `temperature=1` is a Python int, which has no `.exp()`;
    # promote plain numbers to a tensor and leave tensors untouched.
    if not torch.is_tensor(temperature):
        temperature = torch.tensor(float(temperature), device=logits.device)
    logits = logits * temperature.exp().clamp(max=100)

    loss_img = soft_cross_entropy(logits, gt)
    # Text side: every column is normalised over the rows that share the text.
    loss_text = soft_cross_entropy(logits.t(), gt.t() / gt.t().sum(-1, keepdim=True))

    loss = 0.7 * loss_img + 0.3 * loss_text
    return loss
def get_noun_phrase(tokenized):
    """Extract de-duplicated noun phrases from a tokenized sentence.

    The tokens are POS-tagged and chunked with a small regular-expression
    grammar. Adjacent noun-phrase chunks are joined with a space, and each
    distinct phrase is reported once, in order of first appearance.

    Args:
        tokenized: list of word tokens (e.g. from `nltk.word_tokenize`).

    Returns:
        list[str]: the noun phrases found in the sentence.
    """
    # Taken from Su Nam Kim Paper...
    # The tag patterns between the braces are required; an empty `{}` rule
    # is not a valid chunk pattern.
    grammar = r"""
        NBAR:
            {<NN.*|JJ>*<NN.*>}  # Nouns and Adjectives, terminated with Nouns

        NP:
            {<NBAR>}
            {<NBAR><IN><NBAR>}  # Above, connected with in/of/etc...
    """
    chunker = nltk.RegexpParser(grammar)

    chunked = chunker.parse(nltk.pos_tag(tokenized))
    continuous_chunk = []
    current_chunk = []

    for subtree in chunked:
        if isinstance(subtree, nltk.Tree):
            current_chunk.append(' '.join([token for token, pos in subtree.leaves()]))
        elif current_chunk:
            # A non-chunk token ends the current run of chunks.
            named_entity = ' '.join(current_chunk)
            if named_entity not in continuous_chunk:
                continuous_chunk.append(named_entity)
            current_chunk = []

    # Flush a phrase that runs up to the end of the sentence; without this a
    # sentence ending in a noun phrase silently loses it.
    if current_chunk:
        named_entity = ' '.join(current_chunk)
        if named_entity not in continuous_chunk:
            continuous_chunk.append(named_entity)

    return continuous_chunk
It has been modified from its # original forms to accommodate minor architectural differences compared # to GPT-NeoX and OPT used by the Meta AI team that trained the model. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ PyTorch LLaMA model.""" import math from typing import List, Optional, Tuple, Union import torch import torch.utils.checkpoint from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from transformers.activations import ACT2FN from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast from transformers.modeling_utils import PreTrainedModel from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from transformers.models.llama.configuration_llama import LlamaConfig logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "LlamaConfig" def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0): """ Make causal mask used for bi-directional self-attention. 
""" bsz, tgt_len = input_ids_shape mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min)) mask_cond = torch.arange(mask.size(-1)) mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) mask = mask.to(dtype) if past_key_values_length > 0: mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1) return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): """ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. """ bsz, src_len = mask.size() tgt_len = tgt_len if tgt_len is not None else src_len expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) inverted_mask = 1.0 - expanded_mask return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) class LlamaRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ LlamaRMSNorm is equivalent to T5LayerNorm """ super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps def forward(self, hidden_states): variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) # convert into half-precision if necessary if self.weight.dtype in [torch.float16, torch.bfloat16]: hidden_states = hidden_states.to(self.weight.dtype) return self.weight * hidden_states class LlamaRotaryEmbedding(torch.nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) self.register_buffer("inv_freq", inv_freq) # Build here to make `torch.jit.trace` work. 
def rotate_half(x):
    """Swap the two halves of the last dimension, negating the second half."""
    half = x.shape[-1] // 2
    front = x[..., :half]
    back = x[..., half:]
    return torch.cat((-back, front), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0):
    """Apply rotary position embedding to query and key tensors.

    `cos` and `sin` are sliced along their second-to-last dimension to the
    window `[offset, offset + q.shape[-2])` before being combined with `q`
    and `k`.

    Returns:
        Tuple `(q_embed, k_embed)`.
    """
    window = slice(offset, q.shape[-2] + offset)
    cos = cos[..., window, :]
    sin = sin[..., window, :]

    def _rotate(states):
        # Same arithmetic for queries and keys.
        return (states * cos) + (rotate_half(states) * sin)

    return _rotate(q), _rotate(k)
class LlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
    ):
        """Create the q/k/v/o projections and the rotary embedding.

        Raises:
            ValueError: if `hidden_size` is not divisible by `num_heads`.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        if (self.head_dim * num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {num_heads})."
            )
        self.q_proj = nn.Linear(
            hidden_size,
            num_heads * self.head_dim,
            bias=False,
        )
        self.k_proj = nn.Linear(
            hidden_size,
            num_heads * self.head_dim,
            bias=False,
        )
        self.v_proj = nn.Linear(
            hidden_size,
            num_heads * self.head_dim,
            bias=False,
        )
        self.o_proj = nn.Linear(
            num_heads * self.head_dim,
            hidden_size,
            bias=False,
        )
        self.rotary_emb = LlamaRotaryEmbedding(self.head_dim)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Reshape `[bsz, seq_len, hidden]` to `[bsz, num_heads, seq_len, head_dim]`."""
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel

        Args:
            hidden_states: input of shape `(bsz, q_len, hidden_size)`.
            past_key_value: optional cached `(key, value)` pair; its sequence
                length is used as the rotary offset and the new keys/values
                are appended to it.
            attention_mask: optional additive mask of shape
                `(bsz, 1, q_len, kv_seq_len)`.
            output_attentions: return the attention probabilities when True.
            use_cache: return the updated `(key, value)` pair when True.

        Returns:
            `(attn_output, attn_weights or None, past_key_value or None)`.

        Raises:
            ValueError: if the attention scores, the mask or the attention
                output do not have the expected shape.
        """
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        offset = 0
        if past_key_value is not None:
            # New tokens are positioned after the cached ones.
            offset = past_key_value[0].shape[-2]
            kv_seq_len += offset
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, offset=offset)
        # [bsz, nh, t, hd]

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            # Report the same 4-D shape that is actually being checked.
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask
            # Adding two large negative values can overflow to -inf; clamp back
            # to the dtype minimum. The floor is created on the same device as
            # the scores so no CPU tensor is mixed into the computation.
            floor = torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device)
            attn_weights = torch.max(attn_weights, floor)

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
num_heads=config.num_attention_heads, ) self.mlp = LlamaMLP( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, ) self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, past_key_value: Optional[Tuple[torch.Tensor]] = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). 
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states """ residual = hidden_states hidden_states = self.input_layernorm(hidden_states) # Self Attention hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, past_key_value=past_key_value, attention_mask=attention_mask, output_attentions=output_attentions, use_cache=use_cache, ) hidden_states = residual + hidden_states # Fully Connected residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states outputs = (hidden_states,) if output_attentions: outputs += (self_attn_weights,) if use_cache: outputs += (present_key_value,) return outputs LLAMA_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads etc.) This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior. Parameters: config ([`LlamaConfig`]): Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
""" @add_start_docstrings( "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", LLAMA_START_DOCSTRING, ) class LlamaPreTrainedModel(PreTrainedModel): config_class = LlamaConfig base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["LlamaDecoderLayer"] _keys_to_ignore_on_load_unexpected = [r"decoder\.version"] def _init_weights(self, module): std = self.config.initializer_range if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, LlamaModel): module.gradient_checkpointing = value LLAMA_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`). If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] and modify to your needs. 
See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy. - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. 
def __init__(self, config: LlamaConfig):
    """Build the decoder stack described by `config`.

    Creates the token embedding, `config.num_hidden_layers` decoder layers
    and the final RMS norm, then runs `post_init()`. Depending on attributes
    present on `config`, a vision-tower placeholder and a multimodal
    projector are added afterwards.

    Args:
        config: LlamaConfig
    """
    super().__init__(config)
    self.padding_idx = config.pad_token_id
    self.vocab_size = config.vocab_size

    self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
    self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
    self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    self.gradient_checkpointing = False
    # Value returned by get_output_embeddings(); starts out unset.
    self.embed_tokens_out=None
    # Initialize weights and apply final processing
    self.post_init()
    # Everything below is created after post_init().
    if hasattr(config, "mm_vision_tower"):
        from transformers import CLIPVisionModel
        # Placeholder held in a plain list; presumably replaced with the real
        # vision tower by the training/eval script (TODO confirm).
        self.vision_tower = [None]
    if hasattr(config, "use_mm_proj"):
        # Linear map from vision feature size to the LLM hidden size.
        self.mm_projector = nn.Linear(config.mm_hidden_size, config.hidden_size)
def find_pattern_list(self, pattern, src):
    """Locate the first occurrence of `pattern` inside `src`.

    Args:
        pattern: sequence of ints to search for.
        src: sequence whose elements are convertible with `int()` (e.g. a
            list of ids or a 1-D tensor of token ids).

    Returns:
        The index in `src` of the LAST element of the first match, or -1
        when `pattern` does not occur. An empty pattern yields -1.

    Raises:
        AssertionError: if `pattern` is longer than `src`.
    """
    assert len(pattern) <= len(src)
    width = len(pattern)
    # `end` is the candidate position of the pattern's final element. The
    # range runs through the last index of `src`, so a match that finishes on
    # the final element is found as well.
    for end in range(width - 1, len(src)):
        start = end - width + 1
        if all(int(src[start + k]) == pattern[k] for k in range(width)):
            return end
    return -1
[What are attention masks?](../glossary#attention-mask) past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `decoder_input_ids` of shape `(batch_size, sequence_length)`. use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict # retrieve input_ids and inputs_embeds if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") elif input_ids is not None: batch_size, seq_length = input_ids.shape elif inputs_embeds is not None: batch_size, seq_length, _ = inputs_embeds.shape else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") seq_length_with_past = seq_length past_key_values_length = 0 if past_key_values is not None: past_key_values_length = past_key_values[0][0].shape[2] seq_length_with_past = seq_length_with_past + past_key_values_length # HACK: replace back original embeddings for pretraining orig_embeds_params = getattr(self, 'orig_embeds_params', None) if orig_embeds_params is not None: orig_embeds_params_in = orig_embeds_params[0] orig_embeds_params_out = orig_embeds_params[1] st=self.tokenizer.encode("")[1] with torch.no_grad(): self.get_input_embeddings().weight[:st] = orig_embeds_params_in[:st].data # if self.tokenizer.decode([len(self.tokenizer)-1])=='': self.get_output_embeddings().weight[:st] = orig_embeds_params_out[:st].data # if fp16: # self.get_input_embeddings().weight=self.get_input_embeddings().weight.to(torch.float16) if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) if (input_ids.shape[1] != 1 or self.training) and im_feats is not None: if mm_projector is None: mm_projector=self.mm_projector if obj_projector is None: obj_projector=self.obj_projector if fp16: image_features = mm_projector(im_feats.to(torch.float16)) 
obj_features=[] for feat in obj_feats: if feat is not None: obj_features.append(obj_projector(feat.to(torch.float16))) else: obj_features.append(None) if question_ref_queries is not None: for i,q in enumerate(question_ref_queries): if q is not None: question_ref_queries[i]=obj_projector(q.to(torch.float16)) # question_ref_queries=[obj_projector(q.to(torch.float16)) for q in question_ref_queries if q is not None] else: image_features = mm_projector(im_feats) obj_features = [] for feat in obj_feats: if feat is not None: obj_features.append(obj_projector(feat)) else: obj_features.append(None) # import pdb;pdb.set_trace() if question_ref_queries is not None: for i,q in enumerate(question_ref_queries): if q is not None: question_ref_queries[i]=obj_projector(q) # question_ref_queries=[obj_projector(q) for q in question_ref_queries if q is not None] new_input_embeds = [] cur_image_idx = 0 for cur_input_ids, cur_input_embeds in zip(input_ids, inputs_embeds): cur_image_features = image_features[cur_image_idx] if (cur_input_ids == self.im_start_token).sum() != (cur_input_ids == self.im_end_token).sum(): raise ValueError("The number of im_start_token and im_end_token should be the same") image_start_tokens = torch.where(cur_input_ids == self.im_start_token)[0] assert len(image_start_tokens)==1 image_start_token_pos=image_start_tokens[0] # for image_start_token_pos in image_start_tokens: #currently only one image cur_image_features = image_features[cur_image_idx] num_patches = cur_image_features.shape[0] if cur_input_ids[image_start_token_pos + num_patches + 1] != self.im_end_token: raise ValueError("Seems that the image is cut.") cur_input_embeds = torch.cat((cur_input_embeds[:image_start_token_pos+1], cur_image_features, cur_input_embeds[image_start_token_pos + num_patches + 1:]), dim=0) ############## OBJ cur_obj_features = obj_features[cur_image_idx] if cur_obj_features is not None: if (cur_input_ids == self.obj_start_token).sum() != (cur_input_ids == 
self.obj_end_token).sum(): raise ValueError("The number of obj_start_token and obj_end_token should be the same") obj_start_tokens = torch.where(cur_input_ids == self.obj_start_token)[0] assert len(obj_start_tokens)==1 obj_start_token_pos=obj_start_tokens[0] obj_end_tokens = torch.where(cur_input_ids == self.obj_end_token)[0] assert len(obj_end_tokens) == 1 obj_end_token_pos = obj_end_tokens[0] # for image_start_token_pos in image_start_tokens: #currently only one image num_patches = cur_obj_features.shape[0] if obj_num: starts=[] for i_obj in range(num_patches): mark=self.tokenizer.encode(f"{i_obj}.")[1:] start_i=self.find_pattern_list(mark,cur_input_ids)+1 assert cur_input_ids[start_i]==self.tokenizer.encode("")[1] if start_i!=-1 and start_i>obj_start_token_pos and start_iobj_end_tokens[0]] if len(obj_patch_tokens)>0: cur_new_input_embeds[obj_patch_tokens]=question_ref_queries[cur_image_idx].to(cur_input_embeds.dtype) else: cur_new_input_embeds=cur_input_embeds cur_image_idx += 1 new_input_embeds.append(cur_new_input_embeds) inputs_embeds = torch.stack(new_input_embeds, dim=0) # embed positions if attention_mask is None: attention_mask = torch.ones( (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device ) attention_mask = self._prepare_decoder_attention_mask( attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length ) hidden_states = inputs_embeds if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
def __init__(self, config):
    """Wrap a ``LlamaModel`` backbone with a bias-free linear language-model head.

    Args:
        config: Model configuration; ``hidden_size`` and ``vocab_size`` set the
            input and output width of ``lm_head``.
    """
    super().__init__(config)
    self.model = LlamaModel(config)

    # Projects hidden states to vocabulary logits (no bias term).
    self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    # Initialize weights and apply final processing
    self.post_init()
self.model.embed_tokens = value def get_output_embeddings(self): return self.lm_head def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings def set_decoder(self, decoder): self.model = decoder def get_decoder(self): return self.model @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, images: Optional[torch.FloatTensor] = None, return_dict: Optional[bool] = None, im_feats=None, obj_feats=None,fp16=True,tokenizer=None,training=True,return_hidden=False,reduce_loss=True, **kwargs ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. 
[What are attention masks?](../glossary#attention-mask) past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional tensors are only required when the model is used as a decoder in a Sequence to Sequence model. Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). 
output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. Returns: Example: ```python >>> from transformers import AutoTokenizer, LlamaForCausalLM >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) >>> prompt = "Hey, are you consciours? Can you talk to me?" >>> inputs = tokenizer(prompt, return_tensors="pt") >>> # Generate >>> generate_ids = model.generate(inputs.input_ids, max_length=30) >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." 
```""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict self.model.tokenizer=tokenizer if not training: self.model.training=False self.model.embed_tokens_out=self.get_output_embeddings() # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, images=images, im_feats=im_feats, obj_feats=obj_feats,fp16=fp16,**kwargs ) hidden_states = outputs[0] logits = self.lm_head(hidden_states) loss = None loss_ls=[] if labels is not None: # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # Flatten the tokens loss_fct = CrossEntropyLoss() if not reduce_loss: shift_logits = shift_logits.view(shift_logits.shape[0], -1,self.config.vocab_size) shift_labels = shift_labels.view(shift_logits.shape[0],-1) # Enable model/pipeline parallelism shift_labels = shift_labels.to(shift_logits.device) for shift_logits_,shift_labels_ in zip(shift_logits, shift_labels): loss_ls.append(loss_fct(shift_logits_, shift_labels_)) loss=loss_ls else: shift_logits = shift_logits.view(-1, self.config.vocab_size) shift_labels = shift_labels.view(-1) # Enable model/pipeline parallelism shift_labels = shift_labels.to(shift_logits.device) loss=loss_fct(shift_logits, shift_labels) if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output return CausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, 
def prepare_inputs_for_generation(
    self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
    """Assemble the keyword arguments for a single decoding step.

    Args:
        input_ids: Token ids; when ``past_key_values`` is truthy only the last
            column (``input_ids[:, -1:]``) is kept.
        past_key_values: Cached key/value states from earlier steps, or ``None``.
        attention_mask: Passed through unchanged.
        inputs_embeds: Used in place of ``input_ids`` only while
            ``past_key_values`` is ``None``.
        **kwargs: Only ``use_cache`` and ``images`` are read; anything else is ignored.

    Returns:
        dict with either ``"inputs_embeds"`` or ``"input_ids"`` as its first key,
        followed by ``"past_key_values"``, ``"use_cache"``, ``"attention_mask"``
        and ``"images"``.
    """
    if past_key_values:
        # With a cache present, only the newest token has to be fed to the model.
        input_ids = input_ids[:, -1:]

    # Embeddings are only honoured when no cache has been passed in.
    embeds_step = inputs_embeds is not None and past_key_values is None
    if embeds_step:
        lead_key, lead_value = "inputs_embeds", inputs_embeds
    else:
        lead_key, lead_value = "input_ids", input_ids

    return {
        lead_key: lead_value,
        "past_key_values": past_key_values,
        "use_cache": kwargs.get("use_cache"),
        "attention_mask": attention_mask,
        "images": kwargs.get("images", None),
    }
""", LLAMA_START_DOCSTRING, ) class LlamaForSequenceClassification(LlamaPreTrainedModel): _keys_to_ignore_on_load_missing = [r"lm_head.weight"] def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels self.model = LlamaModel(config) self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.model.embed_tokens def set_input_embeddings(self, value): self.model.embed_tokens = value @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, SequenceClassifierOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
""" return_dict = return_dict if return_dict is not None else self.config.use_return_dict transformer_outputs = self.model( input_ids, past_key_values=past_key_values, attention_mask=attention_mask, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) hidden_states = transformer_outputs[0] logits = self.score(hidden_states) if input_ids is not None: batch_size = input_ids.shape[0] else: batch_size = inputs_embeds.shape[0] if self.config.pad_token_id is None and batch_size != 1: raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") if self.config.pad_token_id is None: sequence_lengths = -1 else: if input_ids is not None: sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) else: sequence_lengths = -1 pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] loss = None if labels is not None: if self.config.problem_type is None: if self.num_labels == 1: self.config.problem_type = "regression" elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): self.config.problem_type = "single_label_classification" else: self.config.problem_type = "multi_label_classification" if self.config.problem_type == "regression": loss_fct = MSELoss() if self.num_labels == 1: loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) else: loss = loss_fct(pooled_logits, labels) elif self.config.problem_type == "single_label_classification": loss_fct = CrossEntropyLoss() loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) elif self.config.problem_type == "multi_label_classification": loss_fct = BCEWithLogitsLoss() loss = loss_fct(pooled_logits, labels) if not return_dict: output = (pooled_logits,) + transformer_outputs[1:] return ((loss,) + output) if loss is not None else output return SequenceClassifierOutputWithPast( loss=loss, 
# Maps a short model name to the factory function registered under it.
_model_entrypoints = {}


def register_model(fn):
    """Register ``fn`` under the last dotted component of its ``__module__``.

    Intended for use as a decorator: the function is returned unchanged. A later
    registration coming from a module with the same final name replaces the
    earlier entry.
    """
    short_name = fn.__module__.rsplit('.', 1)[-1]
    _model_entrypoints[short_name] = fn
    return fn


def model_entrypoints(model_name):
    """Return the factory registered as ``model_name``; raises ``KeyError`` if absent."""
    return _model_entrypoints[model_name]


def is_model(model_name):
    """Tell whether a factory has been registered as ``model_name``."""
    return model_name in _model_entrypoints
build_tokenizer(cfg['MODEL']['TEXT']) tokenizer_type = cfg['MODEL']['TEXT']['TOKENIZER'] lang_encoder = build_lang_encoder(cfg['MODEL']['TEXT'], tokenizer, cfg['VERBOSE']) max_token_num = cfg['MODEL']['TEXT']['CONTEXT_LENGTH'] dim_lang = cfg['MODEL']['TEXT']['WIDTH'] dim_projection = cfg['MODEL']['DIM_PROJ'] lang_projection = nn.Parameter(torch.empty(dim_lang, dim_projection)) trunc_normal_(lang_projection, std=.02) # tested not working better queue_operator = {} return { "tokenizer": tokenizer, "tokenizer_type": tokenizer_type, "lang_encoder": lang_encoder, "lang_projection": lang_projection, "max_token_num": max_token_num, "queue_operator": queue_operator, } def get_text_embeddings(self, class_names, name='default', is_eval=False, add_bgd=False, prompt=True, norm=True): if not is_eval: if prompt: # randomly sample one template arbitary_concepts = [ prompt_engineering(class_names[label].replace('-other','').replace('-merged','').replace('-stuff',''), topk=10000, suffix='.') \ for label in range(len(class_names)) ] if add_bgd: arbitary_concepts.append("A background in coco.") else: arbitary_concepts = class_names input_ids = [] attention_masks = [] for txt in arbitary_concepts: tokens = self.tokenizer( txt, padding='max_length', truncation=True, max_length=self.max_token_num, return_tensors='pt' ) tokens['input_ids'].squeeze_() tokens['attention_mask'].squeeze_() input_ids.append(tokens['input_ids']) attention_masks.append(tokens['attention_mask']) arbitary_tokens = torch.stack(input_ids) arbitary_attention_masks = torch.stack(attention_masks) text_emb = self.forward_language((arbitary_tokens.cuda(), arbitary_attention_masks.cuda()), norm=norm) setattr(self, '{}_text_embeddings'.format(name), text_emb) else: with torch.no_grad(): def extract_mean_emb(txts): tokens = self.tokenizer( txts, padding='max_length', truncation=True, max_length=self.max_token_num, return_tensors='pt' ) clss_embedding = self.forward_language((tokens['input_ids'].cuda(), 
def forward_language(self, texts, norm=True):
    """Encode a batch of texts into one projected embedding per row.

    Args:
        texts: Positional arguments for ``self.lang_encoder``; callers in this
            class pass ``(input_ids, attention_mask)``. ``texts[0]`` is also used
            to choose the pooled position in the ``'clip'`` case.
        norm: When true, divide each output row by its L2 norm (plus ``1e-7``).

    Returns:
        The pooled hidden state multiplied by ``self.lang_proj``, optionally
        normalised along the last dimension.
    """
    hidden = self.lang_encoder(*texts)['last_hidden_state']

    if self.tokenizer_type == 'clip':
        # Per row, take the position whose id in texts[0] is largest.
        # NOTE(review): presumably the end-of-text token for CLIP vocabularies -- confirm.
        rows = torch.arange(hidden.size(0))
        pooled = hidden[rows, texts[0].argmax(dim=-1)]
    else:
        # Any other tokenizer type: use the first position.
        pooled = hidden[:, 0]

    projected = pooled @ self.lang_proj
    if not norm:
        return projected
    return projected / (projected.norm(dim=-1, keepdim=True) + 1e-7)
def compute_similarity(self, v_emb, name='default', fake=False):
    """Score visual embeddings against the text embeddings cached under ``name``.

    Args:
        v_emb: Visual embeddings; L2-normalised along the last dimension before use.
            Assumed to be ``(batch, queries, dim)`` -- TODO confirm with callers.
        name: Prefix of the ``<name>_text_embeddings`` attribute to read; that
            attribute is set by ``get_text_embeddings``.
        fake: When true, nothing is computed and ``None`` is returned.

    Returns:
        ``None`` if ``fake``; otherwise ``exp(logit_scale) * v_emb`` matrix-multiplied
        with the transposed text embeddings.
    """
    if fake:
        return None

    unit_v = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
    text_bank = getattr(self, '{}_text_embeddings'.format(name))
    # Scale first, then multiply: `*` and `@` share precedence and bind left to right.
    scaled_v = self.logit_scale.exp() * unit_v
    # unsqueeze(0) adds a leading axis so the bank broadcasts over the batch;
    # transpose(1, 2) turns it into (1, dim, num_texts) for a 2-D text bank.
    return scaled_v @ text_bank.unsqueeze(0).transpose(1, 2)
an output. See "Attention Is All You Need" for more details. embed_dim_to_check: total dimension of the model. num_heads: parallel attention heads. in_proj_weight, in_proj_bias: input projection weight and bias. bias_k, bias_v: bias of the key and value sequences to be added at dim=0. add_zero_attn: add a new batch of zeros to the key and value sequences at dim=1. dropout_p: probability of an element to be zeroed. out_proj_weight, out_proj_bias: the output projection weight and bias. training: apply dropout if is ``True``. key_padding_mask: if provided, specified padding elements in the key will be ignored by the attention. This is an binary mask. When the value is True, the corresponding value on the attention layer will be filled with -inf. need_weights: output attn_output_weights. attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all the batches while a 3D mask allows to specify a different mask for the entries of each batch. use_separate_proj_weight: the function accept the proj. weights for query, key, and value in different forms. If false, in_proj_weight will be used, which is a combination of q_proj_weight, k_proj_weight, v_proj_weight. q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias. static_k, static_v: static key and value used for attention operators. Shape: Inputs: - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is the embedding dimension. - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is the embedding dimension. - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is the embedding dimension. - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions will be unchanged. 
If a BoolTensor is provided, the positions with the value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor is provided, it will be added to the attention weight. - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length, N is the batch size, E is the embedding dimension. E/num_heads is the head dimension. - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length, N is the batch size, E is the embedding dimension. E/num_heads is the head dimension. Outputs: - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is the embedding dimension. - attn_output_weights: :math:`(N, L, S)` where N is the batch size, L is the target sequence length, S is the source sequence length. 
""" tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias) if has_torch_function(tens_ops): return handle_torch_function( multi_head_attention_forward, tens_ops, query, key, value, embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias, bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight, out_proj_bias, training=training, key_padding_mask=key_padding_mask, need_weights=need_weights, attn_mask=attn_mask, use_separate_proj_weight=use_separate_proj_weight, q_proj_weight=q_proj_weight, k_proj_weight=k_proj_weight, v_proj_weight=v_proj_weight, static_k=static_k, static_v=static_v, ) tgt_len, bsz, embed_dim = query.size() assert embed_dim == embed_dim_to_check # allow MHA to have different sizes for the feature dimension assert key.size(0) == value.size(0) and key.size(1) == value.size(1) head_dim = embed_dim // num_heads assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads" scaling = float(head_dim) ** -0.5 if not use_separate_proj_weight: if (query is key or torch.equal(query, key)) and (key is value or torch.equal(key, value)): # self-attention q, k, v = linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1) elif key is value or torch.equal(key, value): # encoder-decoder attention # This is inline in_proj function with in_proj_weight and in_proj_bias _b = in_proj_bias _start = 0 _end = embed_dim _w = in_proj_weight[_start:_end, :] if _b is not None: _b = _b[_start:_end] q = linear(query, _w, _b) if key is None: assert value is None k = None v = None else: # This is inline in_proj function with in_proj_weight and in_proj_bias _b = in_proj_bias _start = embed_dim _end = None _w = in_proj_weight[_start:, :] if _b is not None: _b = _b[_start:] k, v = linear(key, _w, _b).chunk(2, dim=-1) else: # This is inline in_proj function with in_proj_weight and in_proj_bias _b = in_proj_bias _start = 0 _end = embed_dim _w = in_proj_weight[_start:_end, :] if _b is not None: _b = 
_b[_start:_end] q = linear(query, _w, _b) # This is inline in_proj function with in_proj_weight and in_proj_bias _b = in_proj_bias _start = embed_dim _end = embed_dim * 2 _w = in_proj_weight[_start:_end, :] if _b is not None: _b = _b[_start:_end] k = linear(key, _w, _b) # This is inline in_proj function with in_proj_weight and in_proj_bias _b = in_proj_bias _start = embed_dim * 2 _end = None _w = in_proj_weight[_start:, :] if _b is not None: _b = _b[_start:] v = linear(value, _w, _b) else: q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight) len1, len2 = q_proj_weight_non_opt.size() assert len1 == embed_dim and len2 == query.size(-1) k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight) len1, len2 = k_proj_weight_non_opt.size() assert len1 == embed_dim and len2 == key.size(-1) v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight) len1, len2 = v_proj_weight_non_opt.size() assert len1 == embed_dim and len2 == value.size(-1) if in_proj_bias is not None: q = linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim]) k = linear(key, k_proj_weight_non_opt, in_proj_bias[embed_dim : (embed_dim * 2)]) v = linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2) :]) else: q = linear(query, q_proj_weight_non_opt, in_proj_bias) k = linear(key, k_proj_weight_non_opt, in_proj_bias) v = linear(value, v_proj_weight_non_opt, in_proj_bias) q = q * scaling if attn_mask is not None: assert ( attn_mask.dtype == torch.float32 or attn_mask.dtype == torch.float64 or attn_mask.dtype == torch.float16 or attn_mask.dtype == torch.uint8 or attn_mask.dtype == torch.bool ), "Only float, byte, and bool types are supported for attn_mask, not {}".format(attn_mask.dtype) if attn_mask.dtype == torch.uint8: warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. 
Use bool tensor instead.") attn_mask = attn_mask.to(torch.bool) if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [bsz * num_heads, query.size(0), key.size(0)]: raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim())) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( "Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead." ) key_padding_mask = key_padding_mask.to(torch.bool) if bias_k is not None and bias_v is not None: if static_k is None and static_v is None: k = torch.cat([k, bias_k.repeat(1, bsz, 1)]) v = torch.cat([v, bias_v.repeat(1, bsz, 1)]) if attn_mask is not None: attn_mask = pad(attn_mask, (0, 1)) if key_padding_mask is not None: key_padding_mask = pad(key_padding_mask, (0, 1)) else: assert static_k is None, "bias cannot be added to static key." assert static_v is None, "bias cannot be added to static value." 
else: assert bias_k is None assert bias_v is None q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1) if k is not None: k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) if v is not None: v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) if static_k is not None: assert static_k.size(0) == bsz * num_heads assert static_k.size(2) == head_dim k = static_k if static_v is not None: assert static_v.size(0) == bsz * num_heads assert static_v.size(2) == head_dim v = static_v src_len = k.size(1) if key_padding_mask is not None: # assert key_padding_mask.size(0) == bsz assert key_padding_mask.size(1) == src_len if add_zero_attn: src_len += 1 k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1) v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1) if attn_mask is not None: attn_mask = pad(attn_mask, (0, 1)) if key_padding_mask is not None: key_padding_mask = pad(key_padding_mask, (0, 1)) attn_output_weights = torch.bmm(q, k.transpose(1, 2)) assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len] if attn_mask is not None: if attn_mask.dtype == torch.bool: attn_output_weights.masked_fill_(attn_mask, float("-inf")) else: attn_output_weights += attn_mask if key_padding_mask is not None: attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) attn_output_weights = attn_output_weights.masked_fill( key_padding_mask.unsqueeze(1), float("-inf"), ) attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len) attn_output_weights = softmax(attn_output_weights, dim=-1) attn_output_weights = dropout(attn_output_weights, p=dropout_p, training=training) attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) 
# TorchScript needs ``bias`` to be statically typed as a Tensor (never None).
# This thin ``nn.Linear`` subclass carries that annotation and always builds
# the layer with a bias; it exists only for use by the attention module below.
class _LinearWithBias(nn.Linear):
    bias: Tensor  # type: ignore

    def __init__(self, in_features: int, out_features: int) -> None:
        super().__init__(in_features, out_features, bias=True)  # type: ignore


class MultiheadAttention(nn.Module):
    r"""Multi-head attention layer ("Attention Is All You Need").

    .. math::
        \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O

    with :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.

    Args:
        embed_dim: total dimension of the model.
        num_heads: number of parallel attention heads; must divide ``embed_dim``.
        dropout: dropout probability forwarded to the attention computation.
            Default: 0.0.
        bias: create the input-projection bias parameter. Default: True.
        add_bias_kv: create learnable ``bias_k`` / ``bias_v`` parameters of
            shape ``(1, 1, embed_dim)``.
        add_zero_attn: forwarded unchanged to ``multi_head_attention_forward``.
        kdim: number of features in key. Default: None (use ``embed_dim``).
        vdim: number of features in value. Default: None (use ``embed_dim``).

    When ``kdim`` and ``vdim`` both equal ``embed_dim`` a single packed
    ``in_proj_weight`` of shape ``(3 * embed_dim, embed_dim)`` is created;
    otherwise three separate ``q/k/v_proj_weight`` parameters are created and
    ``in_proj_weight`` is registered as ``None``.

    Examples::

        >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
        >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
    """
    bias_k: Optional[torch.Tensor]
    bias_v: Optional[torch.Tensor]

    def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False,
                 add_zero_attn=False, kdim=None, vdim=None):
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = embed_dim if kdim is None else kdim
        self.vdim = embed_dim if vdim is None else vdim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        # Exactly one of the two layouts holds real parameters; the other is
        # registered as None so the attribute names always exist.
        needs_separate_proj = self._qkv_same_embed_dim is False
        if needs_separate_proj:
            self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
            self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
            self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
            self.register_parameter('in_proj_weight', None)
        else:
            self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
            for absent in ('q_proj_weight', 'k_proj_weight', 'v_proj_weight'):
                self.register_parameter(absent, None)

        if bias:
            self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
        else:
            self.register_parameter('in_proj_bias', None)
        self.out_proj = _LinearWithBias(embed_dim, embed_dim)

        if add_bias_kv:
            self.bias_k = Parameter(torch.empty(1, 1, embed_dim))
            self.bias_v = Parameter(torch.empty(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self._reset_parameters()

    def _reset_parameters(self):
        """Initialise projection weights (Xavier uniform) and biases."""
        if self._qkv_same_embed_dim:
            xavier_uniform_(self.in_proj_weight)
        else:
            for weight in (self.q_proj_weight, self.k_proj_weight, self.v_proj_weight):
                xavier_uniform_(weight)

        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.)
            constant_(self.out_proj.bias, 0.)
        for kv_bias in (self.bias_k, self.bias_v):
            if kv_bias is not None:
                xavier_normal_(kv_bias)

    def __setstate__(self, state):
        # Support loading old MultiheadAttention checkpoints generated by v1.1.0,
        # which predate the ``_qkv_same_embed_dim`` flag.
        state.setdefault('_qkv_same_embed_dim', True)
        super().__setstate__(state)

    def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: Optional[Tensor] = None,
                need_weights: bool = True, attn_mask: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
        r"""Run attention by delegating to ``multi_head_attention_forward``.

        Args:
            query, key, value: map a query and a set of key-value pairs to an
                output.
            key_padding_mask: if provided, the marked padding elements in the
                key are ignored by the attention (``True`` / non-zero means
                "ignore").
            need_weights: also return ``attn_output_weights``.
            attn_mask: 2D or 3D mask that prevents attention to certain
                positions. A 2D mask is broadcast over the batch; a 3D mask
                gives one mask per batch entry and head.

        Shapes for inputs:
            - query: :math:`(L, N, E)`; L target length, N batch size,
              E embedding dimension.
            - key: :math:`(S, N, E)`; S source length.
            - value: :math:`(S, N, E)`.
            - key_padding_mask: :math:`(N, S)`.
            - attn_mask: :math:`(L, S)` or
              :math:`(N\cdot\text{num\_heads}, L, S)`. Bool/byte masks mark
              positions that may not be attended; a float mask is added to the
              attention weights.

        Shapes for outputs:
            - attn_output: :math:`(L, N, E)`.
            - attn_output_weights: :math:`(N, L, S)`.
        """
        # The two call sites are kept explicit (rather than building kwargs
        # dynamically) so the module stays friendly to TorchScript.
        if self._qkv_same_embed_dim:
            return multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask)

        return multi_head_attention_forward(
            query, key, value, self.embed_dim, self.num_heads,
            self.in_proj_weight, self.in_proj_bias,
            self.bias_k, self.bias_v, self.add_zero_attn,
            self.dropout, self.out_proj.weight, self.out_proj.bias,
            training=self.training,
            key_padding_mask=key_padding_mask, need_weights=need_weights,
            attn_mask=attn_mask, use_separate_proj_weight=True,
            q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
            v_proj_weight=self.v_proj_weight)
def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2):
    """
    Element-wise sigmoid focal loss (RetinaNet, https://arxiv.org/abs/1708.02002).

    Args:
        inputs: A float tensor of arbitrary shape holding the logits.
        targets: A float tensor with the same shape as inputs holding the
                 binary label of each element (0 negative, 1 positive).
        num_boxes: Accepted for interface compatibility; it is not used here,
                   so no normalisation is applied by this function.
        alpha: Weighting factor balancing positive vs negative examples.
               A negative value disables the weighting. Default = 0.25.
        gamma: Exponent of the modulating factor (1 - p_t) that balances easy
               vs hard examples.
    Returns:
        The unreduced loss tensor, same shape as inputs.
    """
    probabilities = inputs.sigmoid()
    bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    # p_t is the predicted probability of the true class of each element.
    p_t = probabilities * targets + (1 - probabilities) * (1 - targets)
    focal = bce * ((1 - p_t) ** gamma)

    if alpha >= 0:
        class_weight = alpha * targets + (1 - alpha) * (1 - targets)
        focal = class_weight * focal

    return focal


def dice_loss(
        inputs: torch.Tensor,
        targets: torch.Tensor,
        num_masks: float,
    ):
    """
    Per-row DICE loss (similar to a generalized IoU for masks).

    Args:
        inputs: A float tensor of logits; everything after the first
                dimension is flattened.
        targets: A float tensor of binary labels matching the flattened inputs.
        num_masks: Accepted for interface compatibility; not used, the loss is
                   returned unreduced (one value per row).
    """
    probabilities = inputs.sigmoid().flatten(1)
    overlap = 2 * (probabilities * targets).sum(-1)
    total = probabilities.sum(-1) + targets.sum(-1)
    # The +1 terms smooth the ratio so an empty prediction/target pair is finite.
    return 1 - (overlap + 1) / (total + 1)


def iou_score_loss(inputs, targets):
    """Element-wise squared error between predicted and target IoU scores."""
    squared_error = F.mse_loss(inputs, targets, reduction="none")
    return squared_error


dice_loss_jit = torch.jit.script(
    dice_loss
)  # type: torch.jit.ScriptModule


def sigmoid_ce_loss(
        inputs: torch.Tensor,
        targets: torch.Tensor,
        num_masks: float,
    ):
    """
    Binary cross-entropy with logits, averaged over dimension 1.

    Args:
        inputs: A float tensor of logits.
        targets: A float tensor with the same shape as inputs holding the
                 binary label of each element.
        num_masks: Accepted for interface compatibility; not used, one value
                   per row is returned.
    Returns:
        Loss tensor with dimension 1 reduced by the mean.
    """
    per_element = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    return per_element.mean(1)


sigmoid_ce_loss_jit = torch.jit.script(
    sigmoid_ce_loss
)  # type: torch.jit.ScriptModule


def calculate_uncertainty(logits):
    """
    Score uncertainty as the negated distance between each logit and 0.0, so
    logits closest to zero receive the highest (least negative) score.

    Args:
        logits (Tensor): A tensor of shape (R, 1, ...); the second dimension
            must be exactly 1.
    Returns:
        scores (Tensor): A tensor of the same shape holding the uncertainty
            scores.
    """
    assert logits.shape[1] == 1
    return -(torch.abs(logits.clone()))
def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses, num_points, oversample_ratio,
             importance_sample_ratio, dn="no", dn_losses=None, panoptic_on=False, semantic_ce_loss=False,
             num_mask_tokens=3, iou_loss=True):
    """Create the criterion.

    Parameters:
        num_classes: number of object categories, omitting the special no-object category
        matcher: module able to compute a matching between targets and proposals
        weight_dict: dict containing as key the names of the losses and as values their relative weight.
        eos_coef: relative classification weight applied to the no-object category
        losses: list of all the losses to be applied. See get_loss for list of available losses.
        num_points, oversample_ratio, importance_sample_ratio: parameters of the
            point-sampled mask loss; stored as given.
        dn: stored as given.
        dn_losses: stored as given; ``None`` (the default) becomes a fresh empty
            list so instances never share one default list object.
        panoptic_on: selects ``loss_boxes_panoptic`` for the 'boxes' loss in ``get_loss``.
        semantic_ce_loss: selects ``loss_labels_ce`` for the 'labels' loss in ``get_loss``.
        num_mask_tokens: number of mask tokens; sizes the ``index_switch`` ranges.
        iou_loss: stored as given.
    """
    super().__init__()
    self.num_classes = num_classes
    # Placeholder; NOTE(review): presumably overwritten by the owner before
    # loss_labels_part is used -- confirm against the caller.
    self.num_classes_part = -1
    self.matcher = matcher
    self.weight_dict = weight_dict
    self.eos_coef = eos_coef
    self.losses = losses
    self.dn = dn
    self.dn_losses = [] if dn_losses is None else dn_losses

    # Per-class weights for the cross-entropy variant; the last slot is the
    # no-object class.
    empty_weight = torch.ones(self.num_classes + 1)
    empty_weight[-1] = self.eos_coef
    self.register_buffer("empty_weight", empty_weight)

    # pointwise mask loss parameters
    self.num_points = num_points
    self.oversample_ratio = oversample_ratio
    self.importance_sample_ratio = importance_sample_ratio
    self.focal_alpha = 0.25

    self.panoptic_on = panoptic_on
    self.semantic_ce_loss = semantic_ce_loss
    self.num_mask_tokens = num_mask_tokens
    self.index = None
    self.iou_loss = iou_loss
    self.prediction_switch = None

    # Same placement as before on CUDA hosts; fall back to CPU instead of
    # failing at construction time when no GPU is available. This is a plain
    # dict of tensors (not registered buffers).
    index_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    self.index_switch = {
        'part': torch.arange(0, self.num_mask_tokens - 1, device=index_device),
        'whole': torch.arange(self.num_mask_tokens - 1, self.num_mask_tokens, device=index_device),
        'all': torch.arange(0, self.num_mask_tokens, device=index_device),
    }

    # Loss names for suffixes 0..8, five names per suffix, in the original order.
    per_suffix_names = ("loss_bbox", "loss_giou", "loss_mask_bce", "loss_mask_dice", "iou_score_loss")
    self.keys = [f"{name}_{suffix}" for suffix in range(9) for name in per_suffix_names]
    print("iou_loss is ", iou_loss)


def loss_labels_ce(self, outputs, targets, indices, num_masks):
    """Classification loss (weighted cross-entropy).

    targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes].
    Unmatched predictions are assigned the no-object class ``self.num_classes``.
    ``num_masks`` is accepted for a uniform loss signature and is not used.
    """
    assert "pred_logits" in outputs
    src_logits = outputs["pred_logits"]

    idx = self._get_src_permutation_idx(indices)
    target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
    target_classes = torch.full(
        src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device
    )
    target_classes[idx] = target_classes_o

    loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
    losses = {"loss_ce": loss_ce}
    return losses


def loss_labels(self, outputs, targets, indices, num_boxes, log=True, key='gt_whole_classes'):
    """Classification loss (binary focal loss) over ``outputs['pred_logits']``.

    The target dict entry read is ``key``; it is switched to 'labels' when no
    'whole' prediction switch is configured and the first target carries
    non-None 'labels'. Returns a one-entry dict. When the 'whole' switch is
    off a float placeholder under "fake_no_loss_mask_cls_0" is returned, and
    when ``key`` is missing from the first target a zero that stays attached
    to the graph is returned under "loss_mask_cls_0".
    """
    # assert 'pred_logits' in outputs
    if self.prediction_switch is None or 'whole' not in self.prediction_switch.keys():
        if 'labels' in targets[0].keys() and targets[0]['labels'] is not None:
            key = 'labels'
    else:
        if not self.prediction_switch['whole']:
            return {"fake_no_loss_mask_cls_0": 0.0}
        elif key not in targets[0].keys():
            # FIXME only consider batchsize=1 case
            assert len(targets) == 1
            return {"loss_mask_cls_0": 0.0 * outputs['pred_logits'].sum()}
    src_logits = outputs['pred_logits']

    idx = self._get_src_permutation_idx(indices)
    target_classes_o = torch.cat([t[key][J] for t, (_, J) in zip(targets, indices)])
    target_classes = torch.full(src_logits.shape[:2], self.num_classes,
                                dtype=torch.int64, device=src_logits.device)
    target_classes[idx] = target_classes_o

    # One-hot with an extra trailing slot for the no-object class, which is
    # then dropped so unmatched predictions get an all-zero target row.
    target_classes_onehot = torch.zeros([src_logits.shape[0], src_logits.shape[1], src_logits.shape[2] + 1],
                                        dtype=src_logits.dtype, layout=src_logits.layout,
                                        device=src_logits.device)
    target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1)

    target_classes_onehot = target_classes_onehot[:, :, :-1]
    loss_ce = sigmoid_focal_loss(src_logits, target_classes_onehot, num_boxes, alpha=self.focal_alpha, gamma=2) * \
              src_logits.shape[1]
    losses = {}
    loss_ce = loss_ce.sum(2)
    # Rows are grouped in threes and ``self.index`` (cached by loss_masks)
    # picks one entry per group. NOTE(review): the literal 3 presumably
    # mirrors num_mask_tokens -- confirm. The trailing ``.sum()`` acts on the
    # scalar produced by ``.mean()``.
    losses["loss_mask_cls_0"] = torch.gather(loss_ce.view(-1, 3), 1, self.index.unsqueeze(1)).mean().sum() / num_boxes
    # losses = {"loss_mask_cls_0": loss_ce}
    # losses={k:losses[k].to(torch.bfloat16) for k in losses.keys()}
    return losses
def loss_boxes_o365(self, outputs, targets, indices, num_boxes, layer_id=None, extra=None):
    """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss.

    targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4].
    The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.
    Returns zero losses (still attached to ``pred_boxes``) when ``indices`` is
    None or ``targets`` is empty. ``layer_id`` and ``extra`` are not used.
    """
    assert 'pred_boxes' in outputs
    if indices is None or len(targets) == 0:
        loss = outputs['pred_boxes'].sum() * 0.0
        losses = {"loss_bbox_0": loss, "loss_giou_0": loss}
        return losses

    idx = self._get_src_permutation_idx(indices)
    src_boxes = outputs['pred_boxes'][idx]
    target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)

    loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')

    losses = {}
    losses['loss_bbox_0'] = loss_bbox.sum() / num_boxes

    # Only the diagonal is needed: prediction i is paired with target i.
    loss_giou = 1 - torch.diag(box_ops.generalized_box_iou(
        box_ops.box_cxcywh_to_xyxy(src_boxes),
        box_ops.box_cxcywh_to_xyxy(target_boxes)))
    losses['loss_giou_0'] = loss_giou.sum() / num_boxes

    return losses


def loss_boxes(self, outputs, targets, indices, num_boxes):
    """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss.

    targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4].
    The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.

    Per-box losses are grouped in threes and ``self.index`` (cached by
    ``loss_masks``) selects one entry per group before summing.
    NOTE(review): the literal 3 presumably mirrors num_mask_tokens -- confirm.

    If that selection or the GIoU computation fails, both losses fall back to
    zeros that stay attached to the graph and a warning is logged, so a single
    bad batch does not abort training.
    """
    assert 'pred_boxes' in outputs
    if 'boxes' not in targets[0].keys():
        # FIXME only consider batchsize=1 case
        assert len(targets) == 1
        return {"loss_bbox_0": 0.0 * outputs['pred_boxes'].sum(),
                "loss_giou_0": 0.0 * outputs['pred_boxes'].sum(), }
    assert self.index is not None
    idx = self._get_src_permutation_idx(indices)
    src_boxes = outputs['pred_boxes'][idx]
    target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)

    loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')
    losses = {}
    loss_bbox = loss_bbox.sum(1)
    try:
        losses["loss_bbox_0"] = torch.gather(loss_bbox.view(-1, 3), 1, self.index.unsqueeze(1)).sum() / num_boxes

        loss_giou = 1 - torch.diag(box_ops.generalized_box_iou(
            box_ops.box_cxcywh_to_xyxy(src_boxes),
            box_ops.box_cxcywh_to_xyxy(target_boxes)))
        losses["loss_giou_0"] = torch.gather(loss_giou.view(-1, 3), 1, self.index.unsqueeze(1)).sum() / num_boxes
    except Exception as err:
        # Deliberate best-effort, but no longer a bare ``except`` (which also
        # swallowed KeyboardInterrupt/SystemExit). Only shapes are reported:
        # re-running ``loss_bbox.view(-1, 3)`` for a debug print would raise
        # again when the reshape itself was the failure.
        losses["loss_bbox_0"] = loss_bbox.sum() * 0.0
        losses["loss_giou_0"] = loss_bbox.sum() * 0.0
        logging.getLogger(__name__).warning(
            "loss_boxes fell back to zero losses (%s: %s); loss_bbox shape=%s, index shape=%s",
            type(err).__name__, err, tuple(loss_bbox.shape), tuple(self.index.shape),
        )
    # losses={k:losses[k].to(torch.bfloat16) for k in losses.keys()}
    return losses
""" assert 'pred_boxes' in outputs idx = self._get_src_permutation_idx(indices) src_boxes = outputs['pred_boxes'][idx] target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0) target_labels = torch.cat([t['labels'][i] for t, (_, i) in zip(targets, indices)], dim=0) isthing = target_labels < 80 target_boxes = target_boxes[isthing] src_boxes = src_boxes[isthing] loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none') losses = {} losses["loss_bbox_0"] = loss_bbox.sum() / num_boxes loss_giou = 1 - torch.diag(box_ops.generalized_box_iou( box_ops.box_cxcywh_to_xyxy(src_boxes), box_ops.box_cxcywh_to_xyxy(target_boxes))) losses["loss_giou_0"] = loss_giou.sum() / num_boxes return losses def loss_masks(self, outputs, targets, indices, num_masks): """Compute the losses related to the masks: the focal loss and the dice loss. targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w] """ assert "pred_masks" in outputs if 'masks' not in targets[0].keys(): # FIXME only consider batchsize=1 case assert len(targets) == 1 return {"loss_mask_bce_0": 0.0 * outputs['pred_masks'].sum(), "loss_mask_dice_0": 0.0 * outputs['pred_masks'].sum(), "iou_score_loss_0": 0.0 * outputs['pred_masks'].sum(), } src_idx = self._get_src_permutation_idx(indices) tgt_idx = self._get_tgt_permutation_idx(indices) src_masks = outputs["pred_masks"] src_masks = src_masks[src_idx] masks = [t["masks"] for t in targets] # TODO use valid to mask invalid areas due to padding in loss target_masks, valid = nested_tensor_from_tensor_list(masks).decompose() target_masks = target_masks.to(src_masks) target_masks = target_masks[tgt_idx] # No need to upsample predictions as we are using normalized coordinates :) # N x 1 x H x W # import pdb;pdb.set_trace() src_masks = src_masks[:, None] target_masks = target_masks[:, None] with torch.no_grad(): # sample point_coords point_coords = get_uncertain_point_coords_with_randomness( src_masks.float(), 
lambda logits: calculate_uncertainty(logits.float()), self.num_points, self.oversample_ratio, self.importance_sample_ratio, ) # get gt labels point_labels = point_sample( target_masks.float(), point_coords.float(), align_corners=False, ).squeeze(1) point_logits = point_sample( src_masks.float(), point_coords.float(), align_corners=False, ).squeeze(1) losses = { "loss_mask_bce_0": sigmoid_ce_loss_jit(point_logits, point_labels, num_masks), # "loss_mask_bce_0": sigmoid_ce_loss(point_logits, point_labels, num_masks), "loss_mask_dice_0": dice_loss_jit(point_logits, point_labels, num_masks), # "loss_mask_dice_0": dice_loss(point_logits, point_labels, num_masks), } mask_loss = losses["loss_mask_bce_0"] + losses["loss_mask_dice_0"] mask_loss, index = mask_loss.view(-1, 3).min(1) # slprint(index) # if len(targets)>1: # FIXME starting box index is the same # assert targets[0]['box_start'] == targets[1]['box_start'] bs = outputs["pred_masks"].shape[0] box_start = targets[0]['box_start'] # index.view(bs, -1)[:, box_start:] = 0 # all the box index is set to 0 if self.index is None: self.index = index else: index=self.index losses["loss_mask_bce_0"] = torch.gather(losses["loss_mask_bce_0"].view(-1, 3), 1, index.unsqueeze(1)).sum() / num_masks dice_loss = torch.gather(losses["loss_mask_dice_0"].view(-1, 3), 1, index.unsqueeze(1)) losses["loss_mask_dice_0"] = dice_loss.sum() / num_masks target_iou = 1 - dice_loss src_ious = outputs["pred_ious"] iou_idx = ([src_idx[0].view(bs, -1)[:, :src_ious.shape[1]].flatten(), src_idx[1].view(bs, -1)[:, :src_ious.shape[1]].flatten()]) # print("loss_masks1") # slprint(target_iou) # print("loss_masks2") # # slprint(src_ious) # print("loss_masks3") # # slprint(iou_idx) # print("loss_masks4") # # slprint(index.unsqueeze(1)) src_ious = src_ious[iou_idx] src_ious = torch.gather(src_ious, 1, index.unsqueeze(1)) # # if self.iou_loss: losses['iou_score_loss_0'] = iou_score_loss(src_ious, target_iou).sum() / num_masks # 
def prep_for_dn(self, mask_dict):
    """Unpack the denoising bookkeeping carried in ``mask_dict``.

    Args:
        mask_dict: mapping with the keys ``'output_known_lbs_bboxes'``,
            ``'known_indice'``, ``'scalar'`` and ``'pad_size'``.
            ``'known_indice'`` must expose ``.numel()``.

    Returns:
        A 4-tuple ``(output_known_lbs_bboxes, num_tgt, single_pad, scalar)``
        where ``num_tgt`` is ``known_indice.numel()`` and ``single_pad`` is
        ``pad_size // scalar``.

    Raises:
        AssertionError: if ``pad_size`` is not a multiple of ``scalar``.
    """
    known_outputs = mask_dict['output_known_lbs_bboxes']
    known_positions = mask_dict['known_indice']
    group_count = mask_dict['scalar']
    padded_total = mask_dict['pad_size']

    # The padded region has to split evenly into `scalar` groups.
    assert padded_total % group_count == 0
    per_group = padded_total // group_count

    target_count = known_positions.numel()
    return known_outputs, target_count, per_group, group_count
def get_loss(self, loss, outputs, targets, indices, num_masks):
    """Dispatch to the loss routine registered under the name ``loss``.

    The ``'labels'`` entry resolves to ``self.loss_labels_ce`` when
    ``self.semantic_ce_loss`` is truthy and to ``self.loss_labels`` otherwise;
    the ``'boxes'`` entry resolves to ``self.loss_boxes_panoptic`` when
    ``self.panoptic_on`` is truthy and to ``self.loss_boxes`` otherwise.

    Args:
        loss: one of ``'labels'``, ``'labels_o365'``, ``'labels_part'``,
            ``'masks'``, ``'boxes'``, ``'boxes_o365'``.
        outputs, targets, indices, num_masks: forwarded unchanged, in this
            order, to the selected routine.

    Returns:
        Whatever the selected loss routine returns.

    Raises:
        AssertionError: if ``loss`` is not a registered name.
    """
    # Filled in a fixed order so attributes are resolved exactly as before.
    handlers = {}
    handlers['labels'] = self.loss_labels_ce if self.semantic_ce_loss else self.loss_labels
    handlers['labels_o365'] = self.loss_labels_o365
    handlers['labels_part'] = self.loss_labels_part
    handlers['masks'] = self.loss_masks
    handlers['boxes'] = self.loss_boxes_panoptic if self.panoptic_on else self.loss_boxes
    handlers['boxes_o365'] = self.loss_boxes_o365

    assert loss in handlers, f"do you really want to compute {loss} loss?"
    compute = handlers[loss]
    return compute(outputs, targets, indices, num_masks)
else: output_idx = tgt_idx = src_idx = torch.tensor([]).long().cuda() exc_idx.append((src_idx, tgt_idx)) indices = exc_idx # indices = self.matcher(outputs_without_aux, targets) # Compute the average number of target boxes accross all nodes, for normalization purposes num_masks = sum(len(t["boxes"]) for t in targets) num_masks = torch.as_tensor( [num_masks], dtype=torch.float, device=outputs[key].device ) if is_dist_avail_and_initialized(): torch.distributed.all_reduce(num_masks) num_masks = torch.clamp(num_masks / get_world_size(), min=1).item() # Compute all the requested losses losses = {} losses['num_masks'] = num_masks if 'masks' in 'masks': assert 'masks' in self.losses[0], "must calculate mask loss first for match" # slprint(outputs) # slprint(targets) # slprint(indices) for loss in self.losses: if task=='det': if loss=='labels_part': continue if loss=='labels': loss='labels_o365' if loss=='boxes': loss='boxes_o365' if loss == 'masks': l_dict = dict() l_dict['loss_mask_bce_0'] = torch.as_tensor(0.).to('cuda') l_dict['loss_mask_dice_0'] = torch.as_tensor(0.).to('cuda') losses.update(l_dict) else: if 'labels' in loss: continue losses.update(self.get_loss(loss, outputs, targets, indices, num_masks)) index=self.index if "aux_outputs" in outputs: for i, aux_outputs in enumerate(outputs["aux_outputs"]): # indices = self.matcher(aux_outputs, targets) for loss in self.losses: if task == 'det': if loss == 'labels_part': continue if loss == 'labels': loss = 'labels_o365' if loss == 'boxes': loss = 'boxes_o365' if loss=='masks': l_dict=dict() l_dict['loss_mask_bce_0'] = torch.as_tensor(0.).to('cuda') l_dict['loss_mask_dice_0'] = torch.as_tensor(0.).to('cuda') l_dict = {k.replace('_0', f"_{i + 1}"): v for k, v in l_dict.items()} losses.update(l_dict) else: if 'labels' in loss: continue l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_masks) # l_dict = {k + f"_{i}": v for k, v in l_dict.items()} l_dict = {k.replace('_0', f"_{i + 1}"): v for k, v in 
class HookBase:
    """
    Base class for hooks that can be registered with :class:`TrainerBase`.

    A hook may override any of four callbacks. They are invoked as shown in
    this snippet:
    ::
        hook.before_train()
        for iter in range(start_iter, max_iter):
            hook.before_step()
            trainer.run_step()
            hook.after_step()
        iter += 1
        hook.after_train()

    Notes:
        1. Inside a callback, ``self.trainer`` gives access to the surrounding
           context (e.g., model, current iteration, or config if using
           :class:`DefaultTrainer`).

        2. Work done in :meth:`before_step` can usually be done in
           :meth:`after_step` just as well. Anything that takes non-trivial
           time should live in :meth:`after_step`; by convention
           :meth:`before_step` takes negligible time, so that hooks which care
           about the difference between the two (e.g., a timer) keep working.
    """

    # Weak reference to the trainer object; assigned by the trainer when the
    # hook is registered.
    trainer: "TrainerBase" = None

    def before_train(self):
        """
        Runs once, before the first iteration. No-op by default.
        """

    def after_train(self):
        """
        Runs once, after the last iteration. No-op by default.
        """

    def before_step(self):
        """
        Runs before every iteration. No-op by default.
        """

    def after_step(self):
        """
        Runs after every iteration. No-op by default.
        """

    def state_dict(self):
        """
        Return the hook's checkpointable state: an empty dict by default.
        Subclasses become checkpointable by overriding `state_dict` and
        providing `load_state_dict`.
        """
        return dict()
class CallbackHook(HookBase):
    """
    A hook assembled from user-supplied callback functions.
    """

    def __init__(self, *, before_train=None, after_train=None, before_step=None, after_step=None):
        """
        Every argument is an optional function that is called with a single
        argument: the trainer.
        """
        self._before_train = before_train
        self._before_step = before_step
        self._after_step = after_step
        self._after_train = after_train

    def before_train(self):
        callback = self._before_train
        if callback:
            callback(self.trainer)

    def after_train(self):
        callback = self._after_train
        if callback:
            callback(self.trainer)
        # The callbacks may be closures holding a reference to the trainer,
        # so drop all of them to avoid a circular reference.
        for attr in ("_before_train", "_after_train", "_before_step", "_after_step"):
            delattr(self, attr)

    def before_step(self):
        callback = self._before_step
        if callback:
            callback(self.trainer)

    def after_step(self):
        callback = self._after_step
        if callback:
            callback(self.trainer)
class PeriodicWriter(HookBase):
    """
    Periodically flush events to EventStorage by calling ``writer.write()``.

    Writing happens every ``period`` iterations and after the last iteration.
    ``period`` has no influence on how each writer smooths its data.
    """

    def __init__(self, writers, period=20):
        """
        Args:
            writers (list[EventWriter]): a list of EventWriter objects
            period (int):
        """
        self._writers = writers
        for candidate in writers:
            assert isinstance(candidate, EventWriter), candidate
        self._period = period

    def after_step(self):
        current = self.trainer.iter
        period_hit = (current + 1) % self._period == 0
        # Write on a period boundary or on the final iteration only.
        if not period_hit and current != self.trainer.max_iter - 1:
            return
        for sink in self._writers:
            sink.write()

    def after_train(self):
        for sink in self._writers:
            # Flush anything produced late (e.g. by another hook's
            # after_train) before the writer is closed.
            sink.write()
            sink.close()
""" def __init__( self, eval_period: int, checkpointer: Checkpointer, val_metric: str, mode: str = "max", file_prefix: str = "model_best", ) -> None: """ Args: eval_period (int): the period `EvalHook` is set to run. checkpointer: the checkpointer object used to save checkpoints. val_metric (str): validation metric to track for best checkpoint, e.g. "bbox/AP50" mode (str): one of {'max', 'min'}. controls whether the chosen val metric should be maximized or minimized, e.g. for "bbox/AP50" it should be "max" file_prefix (str): the prefix of checkpoint's filename, defaults to "model_best" """ self._logger = logging.getLogger(__name__) self._period = eval_period self._val_metric = val_metric assert mode in [ "max", "min", ], f'Mode "{mode}" to `BestCheckpointer` is unknown. It should be one of {"max", "min"}.' if mode == "max": self._compare = operator.gt else: self._compare = operator.lt self._checkpointer = checkpointer self._file_prefix = file_prefix self.best_metric = None self.best_iter = None def _update_best(self, val, iteration): if math.isnan(val) or math.isinf(val): return False self.best_metric = val self.best_iter = iteration return True def _best_checking(self): metric_tuple = self.trainer.storage.latest().get(self._val_metric) if metric_tuple is None: self._logger.warning( f"Given val metric {self._val_metric} does not seem to be computed/stored." "Will not be checkpointing based on it." 
class LRScheduler(HookBase):
    """
    Hook that steps a learning-rate scheduler after every iteration and
    records the current LR in the trainer's storage.
    """

    def __init__(self, optimizer=None, scheduler=None):
        """
        Args:
            optimizer (torch.optim.Optimizer):
            scheduler (torch.optim.LRScheduler or fvcore.common.param_scheduler.ParamScheduler):
                a :class:`ParamScheduler` is interpreted as the multiplier
                applied to the optimizer's base LR.

        Arguments that are not given are looked up on the trainer instead.
        """
        self._optimizer = optimizer
        self._scheduler = scheduler

    def before_train(self):
        self._optimizer = self._optimizer or self.trainer.optimizer
        active = self.scheduler
        if isinstance(active, ParamScheduler):
            # Wrap the multiplier schedule so it can be stepped like a torch
            # scheduler.
            self._scheduler = LRMultiplier(
                self._optimizer,
                active,
                self.trainer.max_iter,
                last_iter=self.trainer.iter - 1,
            )
        self._best_param_group_id = LRScheduler.get_best_param_group_id(self._optimizer)

    @staticmethod
    def get_best_param_group_id(optimizer):
        """
        Pick the index of the param group whose LR is reported.

        Heuristic: take the first group holding the largest number of
        parameters. If every group holds exactly one parameter, take the
        first group that uses the most common LR instead.
        """
        groups = optimizer.param_groups
        biggest = max(len(group["params"]) for group in groups)

        if biggest == 1:
            lr_histogram = Counter([group["lr"] for group in groups])
            dominant_lr = lr_histogram.most_common()[0][0]

            def is_candidate(group):
                return group["lr"] == dominant_lr

        else:

            def is_candidate(group):
                return len(group["params"]) == biggest

        for position, group in enumerate(groups):
            if is_candidate(group):
                return position

    def after_step(self):
        reported_group = self._optimizer.param_groups[self._best_param_group_id]
        self.trainer.storage.put_scalar("lr", reported_group["lr"], smoothing_hint=False)
        self.scheduler.step()

    @property
    def scheduler(self):
        # Fall back to the trainer's scheduler when none was supplied.
        return self._scheduler or self.trainer.scheduler

    def state_dict(self):
        active = self.scheduler
        if not isinstance(active, torch.optim.lr_scheduler._LRScheduler):
            return {}
        return active.state_dict()

    def load_state_dict(self, state_dict):
        active = self.scheduler
        if not isinstance(active, torch.optim.lr_scheduler._LRScheduler):
            return
        logging.getLogger(__name__).info("Loading scheduler from state_dict ...")
        active.load_state_dict(state_dict)
We did not profile the first few iterations because they are typically slower than the rest. The result files can be loaded in the ``chrome://tracing`` page in chrome browser, and the tensorboard visualizations can be visualized using ``tensorboard --logdir OUTPUT_DIR/log`` """ def __init__(self, enable_predicate, output_dir, *, activities=None, save_tensorboard=True): """ Args: enable_predicate (callable[trainer -> bool]): a function which takes a trainer, and returns whether to enable the profiler. It will be called once every step, and can be used to select which steps to profile. output_dir (str): the output directory to dump tracing files. activities (iterable): same as in `torch.profiler.profile`. save_tensorboard (bool): whether to save tensorboard visualizations at (output_dir)/log/ """ self._enable_predicate = enable_predicate self._activities = activities self._output_dir = output_dir self._save_tensorboard = save_tensorboard def before_step(self): if self._enable_predicate(self.trainer): if self._save_tensorboard: on_trace_ready = torch.profiler.tensorboard_trace_handler( os.path.join( self._output_dir, "log", "profiler-tensorboard-iter{}".format(self.trainer.iter), ), f"worker{comm.get_rank()}", ) else: on_trace_ready = None self._profiler = torch.profiler.profile( activities=self._activities, on_trace_ready=on_trace_ready, record_shapes=True, profile_memory=True, with_stack=True, with_flops=True, ) self._profiler.__enter__() else: self._profiler = None def after_step(self): if self._profiler is None: return self._profiler.__exit__(None, None, None) if not self._save_tensorboard: PathManager.mkdirs(self._output_dir) out_file = os.path.join( self._output_dir, "profiler-trace-iter{}.json".format(self.trainer.iter) ) if "://" not in out_file: self._profiler.export_chrome_trace(out_file) else: # Support non-posix filesystems with tempfile.TemporaryDirectory(prefix="detectron2_profiler") as d: tmp_file = os.path.join(d, "tmp.json") 
class AutogradProfiler(TorchProfiler):
    """
    A hook which runs `torch.autograd.profiler.profile`.

    Examples:
    ::
        hooks.AutogradProfiler(
             lambda trainer: 10 < trainer.iter < 20, self.cfg.OUTPUT_DIR
        )

    The above example will run the profiler for iteration 10~20 and dump
    results to ``OUTPUT_DIR``. We did not profile the first few iterations
    because they are typically slower than the rest.
    The result files can be loaded in the ``chrome://tracing`` page in chrome browser.

    Note:
        When used together with NCCL on older version of GPUs,
        autograd profiler may cause deadlock because it unnecessarily allocates
        memory on every device it sees. The memory management calls, if
        interleaved with NCCL calls, lead to deadlock on GPUs that do not
        support ``cudaLaunchCooperativeKernelMultiDevice``.
    """

    def __init__(self, enable_predicate, output_dir, *, use_cuda=True):
        """
        Args:
            enable_predicate (callable[trainer -> bool]): a function which takes a trainer,
                and returns whether to enable the profiler.
                It will be called once every step, and can be used to select which steps to profile.
            output_dir (str): the output directory to dump tracing files.
            use_cuda (bool): same as in `torch.autograd.profiler.profile`.
        """
        warnings.warn("AutogradProfiler has been deprecated in favor of TorchProfiler.")
        # Delegate to the parent so every attribute the inherited after_step()
        # reads exists. In particular `_save_tensorboard` must be defined
        # (False here, which selects the chrome-trace export path); leaving it
        # unset made after_step() fail with AttributeError on profiled steps.
        super().__init__(enable_predicate, output_dir, save_tensorboard=False)
        self._use_cuda = use_cuda

    def before_step(self):
        # Start a fresh autograd profiler only on the steps selected by the
        # predicate; after_step() treats None as "nothing to export".
        if self._enable_predicate(self.trainer):
            self._profiler = torch.autograd.profiler.profile(use_cuda=self._use_cuda)
            self._profiler.__enter__()
        else:
            self._profiler = None
""" def __init__(self, eval_period, eval_function, eval_after_train=True): """ Args: eval_period (int): the period to run `eval_function`. Set to 0 to not evaluate periodically (but still evaluate after the last iteration if `eval_after_train` is True). eval_function (callable): a function which takes no arguments, and returns a nested dict of evaluation metrics. eval_after_train (bool): whether to evaluate after the last iteration Note: This hook must be enabled in all or none workers. If you would like only certain workers to perform evaluation, give other workers a no-op function (`eval_function=lambda: None`). """ self._period = eval_period self._func = eval_function self._eval_after_train = eval_after_train def _do_eval(self): results = self._func() if results: assert isinstance( results, dict ), "Eval function must return a dict. Got {} instead.".format(results) flattened_results = flatten_results_dict(results) for k, v in flattened_results.items(): try: v = float(v) except Exception as e: raise ValueError( "[EvalHook] eval_function should return a nested dict of float. " "Got '{}: {}' instead.".format(k, v) ) from e self.trainer.storage.put_scalars(**flattened_results, smoothing_hint=False) # Evaluation may take different time among workers. # A barrier make them start the next iteration together. 
class PreciseBN(HookBase):
    """
    The standard implementation of BatchNorm uses EMA in inference, which is
    sometimes suboptimal.
    This class computes the true average of statistics rather than the moving average,
    and put true averages to every BN layer in the given model.

    It is executed every ``period`` iterations and after the last iteration.
    """

    def __init__(self, period, model, data_loader, num_iter):
        """
        Args:
            period (int): the period this hook is run, or 0 to not run during training.
                The hook will always run in the end of training.
            model (nn.Module): a module whose all BN layers in training mode will be
                updated by precise BN.
                Note that user is responsible for ensuring the BN layers to be
                updated are in training mode when this hook is triggered.
            data_loader (iterable): it will produce data to be run by `model(data)`.
            num_iter (int): number of iterations used to compute the precise
                statistics.
        """
        self._logger = logging.getLogger(__name__)
        # Assign every attribute before checking for BN layers: after_step()
        # reads `self._period` even when the hook is disabled, so an early
        # return that skipped these assignments raised AttributeError on the
        # first non-final iteration.
        self._model = model
        self._data_loader = data_loader
        self._num_iter = num_iter
        self._period = period
        self._data_iter = None
        self._disabled = len(get_bn_modules(model)) == 0
        if self._disabled:
            self._logger.info(
                "PreciseBN is disabled because model does not contain BN layers in training mode."
            )

    def after_step(self):
        next_iter = self.trainer.iter + 1
        is_final = next_iter == self.trainer.max_iter
        if is_final or (self._period > 0 and next_iter % self._period == 0):
            self.update_stats()

    def update_stats(self):
        """
        Update the model with precise statistics. Users can manually call this method.
        """
        if self._disabled:
            return

        # Create the iterator lazily and keep it so later calls continue from
        # where the previous one stopped.
        if self._data_iter is None:
            self._data_iter = iter(self._data_loader)

        def data_loader():
            for num_iter in itertools.count(1):
                if num_iter % 100 == 0:
                    self._logger.info(
                        "Running precise-BN ... {}/{} iterations.".format(num_iter, self._num_iter)
                    )
                # This way we can reuse the same iterator
                yield next(self._data_iter)

        with EventStorage():  # capture events in a new storage to discard them
            self._logger.info(
                "Running precise-BN for {} iterations...  ".format(self._num_iter)
                + "Note that this could produce different statistics every time."
            )
            update_bn_stats(self._model, data_loader(), self._num_iter)
""" def __init__(self, period=20, max_runs=10): """ Args: period (int): Output stats each 'period' iterations max_runs (int): Stop the logging after 'max_runs' """ self._logger = logging.getLogger(__name__) self._period = period self._max_runs = max_runs self._runs = 0 def after_step(self): if self._runs > self._max_runs: return if (self.trainer.iter + 1) % self._period == 0 or ( self.trainer.iter == self.trainer.max_iter - 1 ): if torch.cuda.is_available(): max_reserved_mb = torch.cuda.max_memory_reserved() / 1024.0 / 1024.0 reserved_mb = torch.cuda.memory_reserved() / 1024.0 / 1024.0 max_allocated_mb = torch.cuda.max_memory_allocated() / 1024.0 / 1024.0 allocated_mb = torch.cuda.memory_allocated() / 1024.0 / 1024.0 self._logger.info( ( " iter: {} " " max_reserved_mem: {:.0f}MB " " reserved_mem: {:.0f}MB " " max_allocated_mem: {:.0f}MB " " allocated_mem: {:.0f}MB " ).format( self.trainer.iter, max_reserved_mb, reserved_mb, max_allocated_mb, allocated_mb, ) ) self._runs += 1 if self._runs == self._max_runs: mem_summary = torch.cuda.memory_summary() self._logger.info("\n" + mem_summary) torch.cuda.reset_peak_memory_stats() ================================================ FILE: llava/model/semsam/modules/matcher.py ================================================ # ------------------------------------------------------------------------ # DINO # Copyright (c) 2022 IDEA. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------ # Modified from DINO https://github.com/IDEA-Research/DINO by Feng Li and Hao Zhang. """ Modules to compute the matching cost and solve the corresponding LSAP. 
def batch_dice_loss(inputs: torch.Tensor, targets: torch.Tensor):
    """
    Pairwise DICE loss between every prediction and every target (similar to
    generalized IOU for masks).

    Args:
        inputs: A float tensor of logits with a leading dimension ``n``; it is
                passed through a sigmoid and flattened from dimension 1 on.
        targets: A float tensor of shape ``(m, c)`` where ``c`` matches the
                 flattened size of one ``inputs`` row. Holds the binary label
                 of each element (0 for the negative class and 1 for the
                 positive class).
    Returns:
        Tensor of shape ``(n, m)`` holding the loss of each pair.
    """
    probs = torch.sigmoid(inputs).flatten(1)
    # 2 * <p_i, t_j> for every prediction i and target j.
    overlap = 2 * torch.einsum("nc,mc->nm", probs, targets)
    # Per-prediction mass down the rows, per-target mass across the columns.
    total = probs.sum(-1).unsqueeze(1) + targets.sum(-1).unsqueeze(0)
    # The +1 smoothing keeps the ratio defined for empty masks.
    return 1 - (overlap + 1) / (total + 1)
class HungarianMatcher(nn.Module):
    """Compute a 1-to-1 assignment between network predictions and targets.

    For efficiency reasons the targets do not include a "no object" entry, so
    in general there are more predictions than targets. The best predictions
    are matched 1-to-1 and the remaining ones stay unmatched (and are thus
    treated as non-objects).
    """

    def __init__(self, cost_class: float = 1, cost_mask: float = 1, cost_dice: float = 1,
                 num_points: int = 0, cost_box: float = 0, cost_giou: float = 0,
                 panoptic_on: bool = False):
        """Create the matcher.

        Params:
            cost_class: relative weight of the classification term in the matching cost
            cost_mask: relative weight of the sigmoid cross-entropy mask term
            cost_dice: relative weight of the dice mask term
            num_points: number of random points sampled from each mask when
                the mask terms are computed
            cost_box: relative weight of the L1 box distance
            cost_giou: relative weight of the (negated) generalized box IoU
            panoptic_on: when True, the box terms of targets whose label is
                >= 80 are replaced by the mean box term of the other targets
        """
        super().__init__()
        self.cost_class = cost_class
        self.cost_mask = cost_mask
        self.cost_dice = cost_dice
        self.cost_box = cost_box
        self.cost_giou = cost_giou
        self.panoptic_on = panoptic_on
        assert cost_class != 0 or cost_mask != 0 or cost_dice != 0, "all costs cant be 0"
        self.num_points = num_points

    @staticmethod
    def _as_index_tensors(indices):
        """Convert (row, col) index pairs from the solver into int64 tensors."""
        return [
            (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))
            for i, j in indices
        ]

    def _sampled_mask_costs(self, out_mask, tgt_mask):
        """Return (cost_mask, cost_dice) evaluated on one shared set of random points."""
        out_mask = out_mask[:, None]
        tgt_mask = tgt_mask[:, None]
        # all masks share the same set of points for efficient matching!
        point_coords = torch.rand(1, self.num_points, 2, device=out_mask.device, dtype=tgt_mask.dtype)
        # get gt labels
        tgt_mask = point_sample(
            tgt_mask,
            point_coords.repeat(tgt_mask.shape[0], 1, 1),
            align_corners=False,
        ).squeeze(1)
        out_mask = point_sample(
            out_mask,
            point_coords.repeat(out_mask.shape[0], 1, 1),
            align_corners=False,
        ).squeeze(1)

        # The pairwise losses are evaluated in float32 with autocast disabled.
        with autocast(enabled=False):
            out_mask = out_mask.float()
            tgt_mask = tgt_mask.float()
            cost_mask = batch_sigmoid_ce_loss_jit(out_mask, tgt_mask)
            cost_dice = batch_dice_loss_jit(out_mask, tgt_mask)
        return cost_mask, cost_dice

    def _caption_class_cost(self, v_emb, t_emb, caption_target_count, b, temperature):
        """Return (num_queries, cost_class) for batch element ``b`` of the caption modes.

        Note that ``v_emb[b]`` is L2-normalised in place.
        """
        # NOTE(review): `vl_similarity` is not imported in this module (the
        # import at the top of the file is commented out), so the caption
        # modes raise NameError until that import is restored.
        v_emb[b] = v_emb[b] / (v_emb[b].norm(dim=-1, keepdim=True) + 1e-7)
        num_queries = len(v_emb[b])
        out_prob = vl_similarity(v_emb[b][None,], t_emb, temperature=temperature).softmax(-1)[0]
        tgt_ids = list(range(caption_target_count[b], caption_target_count[b + 1]))
        # The cost is the negated probability of each of this image's captions.
        return num_queries, -out_prob[:, tgt_ids]

    @torch.no_grad()
    def memory_efficient_forward(self, outputs, targets, cost=("cls", "box", "mask")):
        """More memory-friendly matching.

        ``cost`` selects which terms enter the matching cost: the box terms
        are only computed when it contains "box" and the mask terms only when
        it contains "mask"; the classification term is always computed.
        """
        bs, num_queries = outputs["pred_logits"].shape[:2]
        use_box = "box" in cost
        use_mask = "mask" in cost

        indices = []

        # Iterate through batch size
        for b in range(bs):
            out_bbox = outputs["pred_boxes"][b]
            if use_box:
                tgt_bbox = targets[b]["boxes"]
                cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)
                cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox))
            else:
                cost_bbox = torch.tensor(0).to(out_bbox)
                cost_giou = torch.tensor(0).to(out_bbox)

            out_prob = outputs["pred_logits"][b].sigmoid()  # [num_queries, num_classes]
            tgt_ids = targets[b]["labels"]
            # Focal-style classification cost.
            alpha = 0.25
            gamma = 2.0
            neg_cost_class = (1 - alpha) * (out_prob ** gamma) * (-(1 - out_prob + 1e-6).log())
            pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-6).log())
            cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids]

            if use_mask:
                out_mask = outputs["pred_masks"][b]  # [num_queries, H_pred, W_pred]
                # gt masks are already padded when preparing target
                tgt_mask = targets[b]["masks"].to(out_mask)
                cost_mask, cost_dice = self._sampled_mask_costs(out_mask, tgt_mask)
            else:
                cost_mask = torch.tensor(0).to(out_bbox)
                cost_dice = torch.tensor(0).to(out_bbox)

            # Without the box terms cost_bbox / cost_giou are 0-dim
            # placeholders that cannot be indexed per target, so the panoptic
            # adjustment only applies when the box terms were computed.
            if self.panoptic_on and use_box:
                # presumably labels < 80 are the "thing" classes -- TODO confirm
                isthing = tgt_ids < 80
                cost_bbox[:, ~isthing] = cost_bbox[:, isthing].mean()
                cost_giou[:, ~isthing] = cost_giou[:, isthing].mean()
                # The mean over an empty selection is NaN; neutralise it.
                cost_bbox[cost_bbox.isnan()] = 0.0
                cost_giou[cost_giou.isnan()] = 0.0

            # Final cost matrix
            C = (
                self.cost_mask * cost_mask
                + self.cost_class * cost_class
                + self.cost_dice * cost_dice
                + self.cost_box * cost_bbox
                + self.cost_giou * cost_giou
            )
            C = C.reshape(num_queries, -1).cpu()
            indices.append(linear_sum_assignment(C))

        return self._as_index_tensors(indices)

    @torch.no_grad()
    def grounding_forward(self, outputs, targets, extra):
        """Match grounding masks ("pred_gmasks") to "grounding_masks" targets.

        Returns None when the batch or the target list is empty.
        """
        bs, num_queries = outputs["pred_gmasks"].shape[:2]

        if bs == 0 or len(targets) == 0:
            return None

        indices = []

        # Iterate through batch size
        for b in range(bs):
            out_prob = outputs["pred_logits"][b]
            # The cost is the negated softmax (taken over dim 0) of the logits.
            cost_class = -out_prob.softmax(dim=0)

            out_mask = outputs["pred_gmasks"][b]  # [num_queries, H_pred, W_pred]
            # gt masks are already padded when preparing target
            tgt_mask = targets[b]["grounding_masks"].to(out_mask)
            cost_mask, cost_dice = self._sampled_mask_costs(out_mask, tgt_mask)

            # Final cost matrix
            C = (
                self.cost_mask * cost_mask
                + self.cost_class * cost_class
                + self.cost_dice * cost_dice
            )
            C = C.reshape(num_queries, -1).cpu()
            indices.append(linear_sum_assignment(C))

        return self._as_index_tensors(indices)

    @torch.no_grad()
    def caption_forward_womask(self, outputs, targets, extra):
        """Match caption embeddings to target captions using the class term only.

        Returns None when the batch or the target list is empty.
        """
        bs, _ = outputs["pred_logits"].shape[:2]

        if bs == 0 or len(targets) == 0:
            return None

        indices = []
        t_emb = torch.cat([t['captions'] for t in targets])
        v_emb = outputs['unmatched_pred_captions']
        # Offsets of each image's captions inside the concatenated t_emb.
        caption_target_count = np.cumsum([0] + [len(t['captions']) for t in targets])

        # Iterate through batch size
        for b in range(bs):
            num_queries, cost_class = self._caption_class_cost(
                v_emb, t_emb, caption_target_count, b, extra['temperature']
            )

            # Final cost matrix
            C = (self.cost_class * cost_class)
            C = C.reshape(num_queries, -1).cpu()
            indices.append(linear_sum_assignment(C))

        return self._as_index_tensors(indices)

    @torch.no_grad()
    def caption_forward_wmask(self, outputs, targets, extra):
        """Match caption embeddings to target captions using class and mask terms.

        Returns None when the batch or the target list is empty.
        """
        bs, _ = outputs["pred_logits"].shape[:2]

        if bs == 0 or len(targets) == 0:
            return None

        indices = []
        t_emb = torch.cat([t['captions'] for t in targets])
        v_emb = outputs['unmatched_pred_captions']
        # Offsets of each image's captions inside the concatenated t_emb.
        caption_target_count = np.cumsum([0] + [len(t['captions']) for t in targets])

        # Iterate through batch size
        for b in range(bs):
            num_queries, cost_class = self._caption_class_cost(
                v_emb, t_emb, caption_target_count, b, extra['temperature']
            )

            out_mask = outputs["pred_masks"][b]  # [num_queries, H_pred, W_pred]
            # gt masks are already padded when preparing target
            tgt_mask = targets[b]["masks"].to(out_mask)
            cost_mask, cost_dice = self._sampled_mask_costs(out_mask, tgt_mask)

            # Final cost matrix
            C = (
                self.cost_mask * cost_mask
                + self.cost_class * cost_class
                + self.cost_dice * cost_dice
            )
            C = C.reshape(num_queries, -1).cpu()
            indices.append(linear_sum_assignment(C))

        return self._as_index_tensors(indices)

    @torch.no_grad()
    def forward(self, outputs, targets, cost=("cls", "box", "mask"), mode='default', extra=None):
        """Performs the matching

        Params:
            outputs: This is a dict that contains at least these entries:
                 "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
                 "pred_masks": Tensor of dim [batch_size, num_queries, H_pred, W_pred] with the predicted masks

            targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
                 "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
                           objects in the target) containing the class labels
                 "masks": Tensor of dim [num_target_boxes, H_gt, W_gt] containing the target masks

            cost: names of the cost terms used in 'default' mode.
            mode: one of 'default', 'grounding', 'caption_womask', 'caption_wmask'.
            extra: extra options forwarded to the non-default modes (the
                   caption modes read extra['temperature']).

        Returns:
            A list of size batch_size, containing tuples of (index_i, index_j) where:
                - index_i is the indices of the selected predictions (in order)
                - index_j is the indices of the corresponding selected targets (in order)
            For each batch element, it holds:
                len(index_i) = len(index_j) = min(num_queries, num_target_boxes)

        Raises:
            AssertionError: if ``mode`` is not one of the supported modes.
        """
        if extra is None:
            extra = {}

        if mode == 'default':
            return self.memory_efficient_forward(outputs, targets, cost)
        if mode == 'grounding':
            return self.grounding_forward(outputs, targets, extra)
        if mode == 'caption_womask':
            return self.caption_forward_womask(outputs, targets, extra)
        if mode == 'caption_wmask':
            return self.caption_forward_wmask(outputs, targets, extra)
        # Raised explicitly so the check survives `python -O`.
        raise AssertionError("Mode {} is not supported.".format(mode))

    def __repr__(self, _repr_indent=4):
        head = "Matcher " + self.__class__.__name__
        body = [
            "cost_class: {}".format(self.cost_class),
            "cost_mask: {}".format(self.cost_mask),
            "cost_dice: {}".format(self.cost_dice),
        ]
        lines = [head] + [" " * _repr_indent + line for line in body]
        return "\n".join(lines)
def generate_regular_grid_point_coords(R, side_size, device):
    """
    Build a regular square grid of points in [0, 1] x [0, 1] coordinate space.

    Args:
        R (int): The number of grids to return, one for each region.
        side_size (int): The side size of the regular grid.
        device (torch.device): Desired device of returned tensor.

    Returns:
        (Tensor): A tensor of shape (R, side_size^2, 2) that contains coordinates
            for the regular grids. All R grids are expanded views of one grid.
    """
    # Affine map p -> 0.5 * p + 0.5, taking affine_grid's [-1, 1] range to [0, 1].
    theta = torch.tensor([[[0.5, 0, 0.5], [0, 0.5, 0.5]]], device=device)
    grid_shape = torch.Size((1, 1, side_size, side_size))
    grid = F.affine_grid(theta, grid_shape, align_corners=False)
    flat_grid = grid.view(1, -1, 2)
    return flat_grid.expand(R, -1, -1)
def get_uncertain_point_coords_on_grid(uncertainty_map, num_points):
    """
    Pick the `num_points` most uncertain cells of the `uncertainty_map` grid.

    Args:
        uncertainty_map (Tensor): A tensor of shape (N, 1, H, W) that contains uncertainty
            values for a set of points on a regular H x W grid.
        num_points (int): The number of points P to select; capped at H * W.

    Returns:
        point_indices (Tensor): A tensor of shape (N, P) that contains indices from
            [0, H x W) of the most uncertain points.
        point_coords (Tensor): A float32 tensor of shape (N, P, 2) that contains
            [0, 1] x [0, 1] normalized (x, y) coordinates of the centres of the
            selected grid cells.
    """
    num_rois, _, height, width = uncertainty_map.shape
    k = min(height * width, num_points)

    flat_uncertainty = uncertainty_map.view(num_rois, height * width)
    _, point_indices = torch.topk(flat_uncertainty, k=k, dim=1)

    cell_w = 1.0 / float(width)
    cell_h = 1.0 / float(height)
    # A flat index decomposes into (row, col) = (index // W, index % W).
    cols = (point_indices % width).to(torch.float)
    rows = (point_indices // width).to(torch.float)
    # Half a cell is added so each coordinate is the cell centre.
    xs = cell_w / 2.0 + cols * cell_w
    ys = cell_h / 2.0 + rows * cell_h

    point_coords = torch.stack((xs, ys), dim=2)
    return point_indices, point_coords
def get_point_coords_wrt_image(boxes_coords, point_coords):
    """
    Map box-normalized [0, 1] x [0, 1] point coordinates into the coordinate
    frame of the boxes.

    Args:
        boxes_coords (Tensor): A tensor of shape (R, 4) that contains bounding boxes
            as (x0, y0, x1, y1).
        point_coords (Tensor): A tensor of shape (R, P, 2) that contains
            [0, 1] x [0, 1] box-normalized coordinates of the P sampled points.

    Returns:
        point_coords_wrt_image (Tensor): A new tensor of shape (R, P, 2) with each
            point scaled by its box's width/height and shifted by the box's
            (x0, y0) corner. The input `point_coords` is left untouched.
    """
    with torch.no_grad():
        left = boxes_coords[:, None, 0]
        top = boxes_coords[:, None, 1]
        box_w = boxes_coords[:, None, 2] - left
        box_h = boxes_coords[:, None, 3] - top

        # Work on a copy so the caller's tensor is not modified.
        image_coords = point_coords.clone()
        image_coords[:, :, 0] = image_coords[:, :, 0] * box_w
        image_coords[:, :, 1] = image_coords[:, :, 1] * box_h
        image_coords[:, :, 0] += left
        image_coords[:, :, 1] += top
    return image_coords
class PositionEmbeddingSine(nn.Module):
    """
    Sine/cosine position embedding, very similar to the one used by the
    "Attention is all you need" paper, generalized to work on images.
    """

    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        """
        Args:
            num_pos_feats: number of channels produced per axis (the output has
                twice as many channels: y features followed by x features).
            temperature: base of the geometric frequency progression.
            normalize: when True, positions are rescaled to [0, scale].
            scale: multiplier used with ``normalize``; defaults to 2 * pi.

        Raises:
            ValueError: if ``scale`` is given while ``normalize`` is False.
        """
        super().__init__()
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        self.scale = 2 * math.pi if scale is None else scale

    def forward(self, x, mask=None):
        """
        Args:
            x: tensor of shape (B, C, H, W); only its size, dtype and device are used.
            mask: optional bool tensor of shape (B, H, W), True at positions to
                exclude from the cumulative position count. Defaults to all False.

        Returns:
            Tensor of shape (B, 2 * num_pos_feats, H, W).
        """
        if mask is None:
            mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
        valid = ~mask
        # 1-based running position of every valid pixel along each axis.
        y_embed = valid.cumsum(1, dtype=x.dtype)
        x_embed = valid.cumsum(2, dtype=x.dtype)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        channel_idx = torch.arange(self.num_pos_feats, dtype=x.dtype, device=x.device)
        # Consecutive channel pairs share one frequency.
        dim_t = self.temperature ** (2 * (channel_idx // 2) / self.num_pos_feats)

        def _interleaved_sin_cos(embed):
            # sin on even channels, cos on odd channels, interleaved back together.
            angles = embed[:, :, :, None] / dim_t
            return torch.stack(
                (angles[:, :, :, 0::2].sin(), angles[:, :, :, 1::2].cos()), dim=4
            ).flatten(3)

        pos = torch.cat((_interleaved_sin_cos(y_embed), _interleaved_sin_cos(x_embed)), dim=3)
        return pos.permute(0, 3, 1, 2)

    def __repr__(self, _repr_indent=4):
        head = "Positional encoding " + self.__class__.__name__
        indent = " " * _repr_indent
        lines = [head]
        for attr in ("num_pos_feats", "temperature", "normalize", "scale"):
            lines.append(indent + "{}: {}".format(attr, getattr(self, attr)))
        return "\n".join(lines)
import torch from torch.nn import functional as F from detectron2.structures import Instances, ROIMasks # perhaps should rename to "resize_instance" def detector_postprocess( results: Instances, output_height: int, output_width: int, mask_threshold: float = 0.5 ): """ Resize the output instances. The input images are often resized when entering an object detector. As a result, we often need the outputs of the detector in a different resolution from its inputs. This function will resize the raw outputs of an R-CNN detector to produce outputs according to the desired output resolution. Args: results (Instances): the raw outputs from the detector. `results.image_size` contains the input image resolution the detector sees. This object might be modified in-place. output_height, output_width: the desired output resolution. Returns: Instances: the resized output from the model, based on the output resolution """ if isinstance(output_width, torch.Tensor): # This shape might (but not necessarily) be tensors during tracing. # Converts integer tensors to float temporaries to ensure true # division is performed when computing scale_x and scale_y. output_width_tmp = output_width.float() output_height_tmp = output_height.float() new_size = torch.stack([output_height, output_width]) else: new_size = (output_height, output_width) output_width_tmp = output_width output_height_tmp = output_height scale_x, scale_y = ( output_width_tmp / results.image_size[1], output_height_tmp / results.image_size[0], ) results = Instances(new_size, **results.get_fields()) if results.has("pred_boxes"): output_boxes = results.pred_boxes elif results.has("proposal_boxes"): output_boxes = results.proposal_boxes else: output_boxes = None assert output_boxes is not None, "Predictions must contain boxes!" 
def bbox_postprocess(result, input_size, img_size, output_height, output_width):
    """
    Convert raw box predictions to corner boxes at the output resolution.

    result: box logits [xc,yc,w,h]; after a sigmoid they lie in [0,1] and are
        converted to [x1,y1,x2,y2] in range [0,w], [0,h].
    input_size: (height, width) used to scale the normalized boxes to pixels.
    img_size: (h, w) the corner coordinates are clamped to.
    output_height, output_width: resolution the clamped boxes are rescaled to.

    Returns None when `result` is None, otherwise a tensor of shape (N, 4).
    """
    if result is None:
        return None

    in_h, in_w = input_size[0], input_size[1]
    to_pixels = torch.tensor([in_w, in_h, in_w, in_h])[None, :].to(result.device)
    result = result.sigmoid() * to_pixels

    centre_x, centre_y = result[:, 0], result[:, 1]
    half_w, half_h = result[:, 2] / 2, result[:, 3] / 2

    h, w = img_size
    # Clamp every corner to the image extent.
    x1 = (centre_x - half_w).clamp(min=0, max=w)
    y1 = (centre_y - half_h).clamp(min=0, max=h)
    x2 = (centre_x + half_w).clamp(min=0, max=w)
    y2 = (centre_y + half_h).clamp(min=0, max=h)

    corners = torch.stack([x1, y1, x2, y2]).t()
    to_output = torch.tensor(
        [output_width / w, output_height / h, output_width / w, output_height / h]
    )[None, :].to(result.device)
    return corners * to_output
A tensor of shape (C, H, W), where C is the number of classes, and H, W are the height and width of the prediction. img_size (tuple): image size that segmentor is taking as input. output_height, output_width: the desired output resolution. Returns: semantic segmentation prediction (Tensor): A tensor of the shape (C, output_height, output_width) that contains per-pixel soft predictions. """ result = result[:, : img_size[0], : img_size[1]].expand(1, -1, -1, -1) result = F.interpolate( result, size=(output_height, output_width), mode="bicubic", align_corners=False )[0] return result ================================================ FILE: llava/model/semsam/utils/__init__.py ================================================ from .config import * from .misc import * # from .dist import * ================================================ FILE: llava/model/semsam/utils/box_ops.py ================================================ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved """ Utilities for bounding box manipulation and GIoU. 
def masks_to_boxes(masks):
    """Compute the bounding boxes around the provided masks

    The masks should be in format [N, H, W] where N is the number of masks,
    (H, W) are the spatial dimensions.

    Returns a [N, 4] tensor, with the boxes in xyxy format. Coordinates are
    pixel indices, so x_max / y_max are the index of the last set column / row.
    A mask with no set pixel yields 1e8 for its x_min / y_min.
    """
    if masks.numel() == 0:
        return torch.zeros((0, 4), device=masks.device)

    h, w = masks.shape[-2:]

    # Build the coordinate ramps on the masks' device; leaving them on the
    # CPU breaks the products below for masks on another device.
    y = torch.arange(0, h, dtype=torch.float, device=masks.device)
    x = torch.arange(0, w, dtype=torch.float, device=masks.device)

    # Broadcasting (1, 1, W) / (1, H, 1) against (N, H, W) yields the same
    # per-pixel coordinates as a meshgrid without materialising it.
    x_mask = masks * x.view(1, 1, w)
    x_max = x_mask.flatten(1).max(-1)[0]
    # Background pixels get a large sentinel so they never win the min.
    x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]

    y_mask = masks * y.view(1, h, 1)
    y_max = y_mask.flatten(1).max(-1)[0]
    y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]

    return torch.stack([x_min, y_min, x_max, y_max], 1)
def _called_with_cfg(*args, **kwargs):
    """
    Returns:
        bool: whether the arguments contain a config object (a plain ``dict``
            or an omegaconf ``DictConfig``) and should be considered forwarded
            to from_config. The config is looked for as the first positional
            argument or as the ``cfg`` keyword argument.
    """
    from omegaconf import DictConfig

    cfg_types = (dict, DictConfig)
    if len(args) and isinstance(args[0], cfg_types):
        return True
    # Accept the same types by keyword as by position, so `A(cfg=dict_config)`
    # is recognised just like `A(dict_config)`.
    if isinstance(kwargs.get("cfg"), cfg_types):
        return True
    # `from_config`'s first argument is forced to be "cfg".
    # So the above check covers all cases.
    return False
def get_iou(gt_masks, pred_masks, ignore_label=-1):
    """Compute the per-mask IoU between ground-truth and predicted masks.

    Pixels whose ground-truth value equals ``ignore_label`` are excluded from
    both the intersection and the union.

    Args:
        gt_masks: tensor unpacked as ``(n, h, w)``; converted with ``.bool()``.
        pred_masks: tensor combined with ``gt_masks`` via ``&`` / ``|``
            (presumably boolean with the same shape -- confirm against callers).
        ignore_label: ground-truth value marking pixels to ignore.

    Returns:
        Tensor of ``n`` IoU values. A mask whose union is empty yields ``nan``
        (0 / 0), exactly as the plain division does.
    """
    rev_ignore_mask = ~(gt_masks == ignore_label)
    gt_masks = gt_masks.bool()
    n, h, w = gt_masks.shape
    intersection = ((gt_masks & pred_masks) & rev_ignore_mask).reshape(n, h * w).sum(dim=-1)
    union = ((gt_masks | pred_masks) & rev_ignore_mask).reshape(n, h * w).sum(dim=-1)
    return intersection / union


def _max_by_axis(the_list):
    # type: (List[List[int]]) -> List[int]
    """Return the element-wise maximum over a list of equally long int lists."""
    # Copy the first row: updating it in place would corrupt the caller's data.
    maxes = list(the_list[0])
    for sublist in the_list[1:]:
        for index, item in enumerate(sublist):
            maxes[index] = max(maxes[index], item)
    return maxes


class NestedTensor(object):
    """A batched tensor paired with an optional padding mask."""

    def __init__(self, tensors, mask: Optional[Tensor]):
        self.tensors = tensors
        self.mask = mask

    def to(self, device):
        # type: (Device) -> NestedTensor # noqa
        """Return a new NestedTensor with tensors (and mask, if any) moved to ``device``."""
        cast_tensor = self.tensors.to(device)
        mask = self.mask
        cast_mask = mask.to(device) if mask is not None else None
        return NestedTensor(cast_tensor, cast_mask)

    def decompose(self):
        """Return the ``(tensors, mask)`` pair."""
        return self.tensors, self.mask

    def __repr__(self):
        return str(self.tensors)


def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
    """Pad a list of tensors to a common size and stack them into a NestedTensor.

    3-D inputs are treated as ``(c, h, w)`` images and 2-D inputs as
    ``(c, l)`` sequences. The returned mask is ``True`` on padded positions and
    ``False`` where real data was copied.

    Raises:
        ValueError: if the first tensor is neither 2-D nor 3-D.
    """
    # TODO make this more general
    first = tensor_list[0]
    if first.ndim == 3:
        if torchvision._is_tracing():
            # nested_tensor_from_tensor_list() does not export well to ONNX
            # call _onnx_nested_tensor_from_tensor_list() instead
            return _onnx_nested_tensor_from_tensor_list(tensor_list)

        # TODO make it support different-sized images
        max_size = _max_by_axis([list(img.shape) for img in tensor_list])
        batch_shape = [len(tensor_list)] + max_size
        b, c, h, w = batch_shape
        tensor = torch.zeros(batch_shape, dtype=first.dtype, device=first.device)
        mask = torch.ones((b, h, w), dtype=torch.bool, device=first.device)
        for img, pad_img, m in zip(tensor_list, tensor, mask):
            pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
            m[: img.shape[1], : img.shape[2]] = False
    elif first.ndim == 2:
        if torchvision._is_tracing():
            # NOTE(review): the ONNX helper indexes three padding entries, so it
            # looks 3-D only -- confirm before tracing with 2-D inputs.
            return _onnx_nested_tensor_from_tensor_list(tensor_list)

        max_size = _max_by_axis([list(txt.shape) for txt in tensor_list])
        batch_shape = [len(tensor_list)] + max_size
        b, c, l = batch_shape
        tensor = torch.zeros(batch_shape, dtype=first.dtype, device=first.device)
        mask = torch.ones((b, l), dtype=torch.bool, device=first.device)
        for txt, pad_txt, m in zip(tensor_list, tensor, mask):
            pad_txt[: txt.shape[0], : txt.shape[1]] = txt
            m[: txt.shape[1]] = False
    else:
        raise ValueError("not supported")
    return NestedTensor(tensor, mask)


def _collate_and_pad_divisibility(tensor_list: list, div=32):
    """Pad every tensor to the batch maximum, rounded up to a multiple of ``div``.

    Only the last two dimensions are rounded up; the first is padded to the
    batch maximum. Returns the list of padded tensors (not stacked).
    """
    max_size = []
    for i in range(tensor_list[0].dim()):
        max_size_i = torch.max(
            torch.tensor([img.shape[i] for img in tensor_list]).to(torch.float32)
        ).to(torch.int64)
        max_size.append(max_size_i)

    c, h, w = max_size
    pad_h = (div - h % div) if h % div != 0 else 0
    pad_w = (div - w % div) if w % div != 0 else 0
    max_size = (c, h + pad_h, w + pad_w)

    # work around for
    # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
    # which is not yet supported in onnx
    padded_imgs = []
    for img in tensor_list:
        padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
        # F.pad takes (left, right) pairs starting from the last dimension.
        padded_imgs.append(
            torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
        )
    return padded_imgs


# _onnx_nested_tensor_from_tensor_list() is an implementation of
# nested_tensor_from_tensor_list() that is supported by ONNX tracing.
@torch.jit.unused
def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
    """ONNX-traceable variant of :func:`nested_tensor_from_tensor_list`."""
    max_size = []
    for i in range(tensor_list[0].dim()):
        max_size_i = torch.max(
            torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)
        ).to(torch.int64)
        max_size.append(max_size_i)
    max_size = tuple(max_size)

    # work around for
    # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
    # m[: img.shape[1], :img.shape[2]] = False
    # which is not yet supported in onnx
    padded_imgs = []
    padded_masks = []
    for img in tensor_list:
        padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
        padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
        padded_imgs.append(padded_img)

        # Zeros over the real image area, padded with ones: True marks padding.
        m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
        padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
        padded_masks.append(padded_mask.to(torch.bool))

    tensor = torch.stack(padded_imgs)
    mask = torch.stack(padded_masks)
    return NestedTensor(tensor, mask=mask)


def is_dist_avail_and_initialized():
    """Return True only when torch.distributed is both available and initialized."""
    return dist.is_available() and dist.is_initialized()


def get_class_names(name, background=True):
    """Return the list of class names for the dataset called ``name``.

    The dataset is identified mostly by substring matching, so the order of
    the checks below matters (earlier matches win). Unknown names fall back to
    ``["none"]``.

    Args:
        name: dataset name, or None.
        background: when False, the first ``"background"`` entry is removed
            from the returned list.

    Returns:
        A list of class names, or None when ``name`` is None.
    """
    if name is None:
        return None
    if 'refcoco' in name:
        class_names = ["none"]
    elif 'pascal' in name:
        class_names = PASCAL_CLASSES_PART + ["background"]
    elif 'sam' in name:
        class_names = ['foreground'] + ["background"]
    elif 'coco' in name and 'pan' not in name:
        class_names = COCO_INSTANCE_CLASSES + ["background"]
    elif 'coco' in name:
        class_names = COCO_PANOPTIC_CLASSES + ["background"]
    elif 'ade20k_full' in name:
        class_names = ADE20K_847 + ["background"]
    elif 'ade' in name:
        class_names = ADE_PANOPTIC_CLASSES + ["background"]
    elif 'voc' in name:
        class_names = PASCAL_CLASSES + ["background"]
    elif 'vlp' in name:
        class_names = ["noun"]
    elif 'tsv' in name:
        class_names = ["noun"]
    elif 'phrasecut' in name:
        class_names = ["noun"]
    elif 'mapillary' in name:
        class_names = MAPILLARY_VISTAS_SEM_SEG_CATEGORIES
    elif 'openimage' in name:
        class_names = ["noun"]
    elif 'imagenet' in name:
        class_names = IMAGENET_CLASSES
    elif 'context_459' in name:
        class_names = PASCAL_CONTEXT_459 + ["background"]
    elif 'context_59' in name:
        class_names = PASCAL_CONTEXT_59 + ["background"]
    elif 'context_33' in name:
        class_names = PASCAL_CONTEXT_33
    elif 'sunrgbd_37' in name:
        class_names = SUN_RGBD_37
    elif 'scannet_41' in name:
        class_names = SCAN_40
    elif 'scannet_38' in name:
        class_names = SCAN_37
    elif 'scannet_21' in name:
        class_names = SCAN_20
    elif 'object365' in name:
        class_names = OBJECT365
    elif 'lvis' in name:
        class_names = LVIS_CATEGORIES
    elif 'seginw' in name:
        class_names = SEGINW_CATEGORIES[name.replace('_train', '').replace('_val', '')] + ["background"]
    elif name == 'cityscapes_fine_sem_seg_val':
        class_names = CITYSCAPES
    elif name == 'cityscapes_fine_instance_seg_val':
        class_names = CITYSCAPES_THING + ["background"]
    elif name in ['cityscapes_fine_panoptic_val', 'cityscapes_fine_panoptic_train']:
        class_names = CITYSCAPES + ["background"]
    elif name == 'bdd10k_val_sem_seg':
        class_names = BDD_SEM
    elif name == 'bdd10k_40_panoptic_val':
        class_names = BDD_PANO
    else:
        class_names = ["none"]
        # assert False, "text dataset name {} is not defined".format(name)

    if background == False and "background" in class_names:  # noqa: E712 (original comparison kept)
        # Several branches return a module-level constant directly; remove the
        # entry from a copy so the shared constant is never modified.
        class_names = list(class_names)
        class_names.remove("background")

    return class_names
def auto_upgrade(config):
    """Interactively upgrade a v0 LLaVA checkpoint config to the current format.

    Loads the config found at ``config`` with ``AutoConfig``. When the path
    contains ``'llava'`` but the stored ``model_type`` does not, the user is
    asked to confirm; on confirmation the model type and architecture are
    rewritten and the config is saved back to ``config``.

    Args:
        config: value passed to ``AutoConfig.from_pretrained`` and, after an
            upgrade, to ``save_pretrained``.

    Raises:
        SystemExit: with status 1 when the user declines the upgrade.
    """
    cfg = AutoConfig.from_pretrained(config)
    if 'llava' not in config or 'llava' in cfg.model_type:
        # Nothing to upgrade.
        return

    assert cfg.model_type == 'llama'
    print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.")
    print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
    confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
    if confirm.lower() not in ["y", "yes"]:
        print("Checkpoint upgrade aborted.")
        # exit() is injected by the `site` module for interactive use and may
        # be missing; raising SystemExit has the same effect everywhere.
        raise SystemExit(1)

    print("Upgrading checkpoint...")
    assert len(cfg.architectures) == 1
    # model_type is set on the config class itself, as the original code did.
    setattr(cfg.__class__, "model_type", "llava")
    cfg.architectures[0] = 'LlavaLlamaForCausalLM'
    cfg.save_pretrained(config)
    print("Checkpoint upgraded.")
def load_image(image_file):
    """Load an image from a local path or an http(s) URL and convert it to RGB.

    Raises:
        requests.exceptions.RequestException: if the download fails, times
            out, or returns an HTTP error status.
    """
    if image_file.startswith(('http://', 'https://')):
        # Bound the request so an unresponsive host cannot hang the CLI, and
        # fail on HTTP errors instead of handing an error page to PIL.
        response = requests.get(image_file, timeout=30)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        image = Image.open(image_file).convert('RGB')
    return image


def main(args):
    """Run an interactive chat about a single image on the command line.

    Reads ``model_path``, ``model_base``, ``load_8bit``, ``load_4bit``,
    ``conv_mode``, ``image_file`` and ``debug`` from ``args``. ``temperature``
    and ``max_new_tokens`` are honoured when present; a namespace without them
    falls back to the previous hard-coded 0.2 / 1024. An empty input line or
    EOF ends the session.
    """
    # Model
    disable_torch_init()

    model_name = get_model_name_from_path(args.model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(
        args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit)

    lowered_name = model_name.lower()
    if 'llama-2' in lowered_name:
        conv_mode = "llava_llama_2"
    elif "v1" in lowered_name:
        conv_mode = "llava_v1"
    elif "mpt" in lowered_name:
        conv_mode = "mpt"
    else:
        conv_mode = "llava_v0"

    if args.conv_mode is not None and conv_mode != args.conv_mode:
        print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode))
    else:
        args.conv_mode = conv_mode

    conv = conv_templates[args.conv_mode].copy()
    if "mpt" in lowered_name:
        roles = ('user', 'assistant')
    else:
        roles = conv.roles

    image = load_image(args.image_file)
    image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().cuda()

    # Use the parsed CLI options instead of hard-coded generation settings.
    temperature = getattr(args, "temperature", 0.2)
    max_new_tokens = getattr(args, "max_new_tokens", 1024)
    # A temperature of 0 means greedy decoding; sampling needs a positive value.
    do_sample = temperature > 0

    while True:
        try:
            inp = input(f"{roles[0]}: ")
        except EOFError:
            inp = ""
        if not inp:
            print("exit...")
            break

        print(f"{roles[1]}: ", end="")

        if image is not None:
            # first message: prepend the image placeholder exactly once
            if model.config.mm_use_im_start_end:
                inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
            else:
                inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
            conv.append_message(conv.roles[0], inp)
            image = None
        else:
            # later messages
            conv.append_message(conv.roles[0], inp)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
        streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

        generate_kwargs = dict(
            images=image_tensor,
            do_sample=do_sample,
            max_new_tokens=max_new_tokens,
            streamer=streamer,
            use_cache=True,
            stopping_criteria=[stopping_criteria],
        )
        if do_sample:
            generate_kwargs["temperature"] = temperature

        with torch.inference_mode():
            output_ids = model.generate(input_ids, **generate_kwargs)

        # Keep only the newly generated tokens for the conversation history.
        outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
        conv.messages[-1][-1] = outputs

        if args.debug:
            print("\n", {"prompt": prompt, "outputs": outputs}, "\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--image-file", type=str, required=True)
    parser.add_argument("--num-gpus", type=int, default=1)
    parser.add_argument("--conv-mode", type=str, default=None)
    parser.add_argument("--temperature", type=float, default=0.2)
    parser.add_argument("--max-new-tokens", type=int, default=512)
    parser.add_argument("--load-8bit", action="store_true")
    parser.add_argument("--load-4bit", action="store_true")
    parser.add_argument("--debug", action="store_true")
    args = parser.parse_args()
    main(args)
class DispatchMethod(Enum):
    """Strategy used by the controller to pick a worker for a model."""

    LOTTERY = auto()
    SHORTEST_QUEUE = auto()

    @classmethod
    def from_str(cls, name):
        """Map ``"lottery"`` / ``"shortest_queue"`` to the enum member.

        Raises:
            ValueError: for any other name.
        """
        if name == "lottery":
            return cls.LOTTERY
        elif name == "shortest_queue":
            return cls.SHORTEST_QUEUE
        else:
            raise ValueError(f"Invalid dispatch method: {name}")


@dataclasses.dataclass
class WorkerInfo:
    """Controller-side record of one registered worker."""

    model_names: List[str]
    speed: int
    queue_length: int
    check_heart_beat: bool
    # Stored from time.time(), i.e. seconds since the epoch.
    last_heart_beat: float


def heart_beat_controller(controller):
    """Background loop: periodically drop workers whose heart beat expired."""
    while True:
        time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION)
        controller.remove_stable_workers_by_expiration()


class Controller:
    """Tracks registered workers and hands out worker addresses per model."""

    def __init__(self, dispatch_method: str):
        # Dict[str -> WorkerInfo]
        self.worker_info = {}
        self.dispatch_method = DispatchMethod.from_str(dispatch_method)

        self.heart_beat_thread = threading.Thread(
            target=heart_beat_controller, args=(self,))
        self.heart_beat_thread.start()

        logger.info("Init controller")

    def register_worker(self, worker_name: str, check_heart_beat: bool,
                        worker_status: dict):
        """Add or refresh a worker; returns False when its status is unavailable.

        When ``worker_status`` is empty the status is fetched from the worker.
        """
        if worker_name not in self.worker_info:
            logger.info(f"Register a new worker: {worker_name}")
        else:
            logger.info(f"Register an existing worker: {worker_name}")

        if not worker_status:
            worker_status = self.get_worker_status(worker_name)
        if not worker_status:
            return False

        self.worker_info[worker_name] = WorkerInfo(
            worker_status["model_names"], worker_status["speed"], worker_status["queue_length"],
            check_heart_beat, time.time())

        logger.info(f"Register done: {worker_name}, {worker_status}")
        return True

    def get_worker_status(self, worker_name: str):
        """POST to the worker's status endpoint; returns its JSON or None on failure."""
        try:
            r = requests.post(worker_name + "/worker_get_status", timeout=5)
        except requests.exceptions.RequestException as e:
            logger.error(f"Get status fails: {worker_name}, {e}")
            return None

        if r.status_code != 200:
            logger.error(f"Get status fails: {worker_name}, {r}")
            return None

        return r.json()

    def remove_worker(self, worker_name: str):
        """Forget a worker. Raises KeyError if it is not registered."""
        del self.worker_info[worker_name]

    def refresh_all_workers(self):
        """Re-register every known worker, dropping those that no longer respond."""
        old_info = dict(self.worker_info)
        self.worker_info = {}

        for w_name, w_info in old_info.items():
            if not self.register_worker(w_name, w_info.check_heart_beat, None):
                logger.info(f"Remove stale worker: {w_name}")

    def list_models(self):
        """Return the distinct model names served by the registered workers."""
        model_names = set()

        for w_info in self.worker_info.values():
            model_names.update(w_info.model_names)

        return list(model_names)

    def get_worker_address(self, model_name: str):
        """Pick a worker serving ``model_name``; returns ``""`` when none is usable.

        LOTTERY draws a worker at random, weighted by its speed.
        SHORTEST_QUEUE picks the smallest ``queue_length / speed`` and
        increments that worker's queue length.

        Raises:
            ValueError: if the dispatch method is unknown.
        """
        if self.dispatch_method == DispatchMethod.LOTTERY:
            worker_names = []
            worker_speeds = []
            for w_name, w_info in self.worker_info.items():
                if model_name in w_info.model_names:
                    worker_names.append(w_name)
                    worker_speeds.append(w_info.speed)
            worker_speeds = np.array(worker_speeds, dtype=np.float32)
            norm = np.sum(worker_speeds)
            if norm < 1e-4:
                # No candidates, or every candidate reports (near) zero speed.
                return ""
            worker_speeds = worker_speeds / norm

            # Directly return the drawn address without probing the worker.
            pt = np.random.choice(np.arange(len(worker_names)),
                                  p=worker_speeds)
            return worker_names[pt]
        elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE:
            worker_names = []
            worker_qlen = []
            for w_name, w_info in self.worker_info.items():
                if model_name in w_info.model_names:
                    worker_names.append(w_name)
                    worker_qlen.append(w_info.queue_length / w_info.speed)
            if len(worker_names) == 0:
                return ""
            min_index = np.argmin(worker_qlen)
            w_name = worker_names[min_index]
            # Count the request we are about to send until the next heart beat.
            self.worker_info[w_name].queue_length += 1
            logger.info(f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}")
            return w_name
        else:
            raise ValueError(f"Invalid dispatch method: {self.dispatch_method}")

    def receive_heart_beat(self, worker_name: str, queue_length: int):
        """Record a heart beat; returns False for a worker that is not registered."""
        if worker_name not in self.worker_info:
            logger.info(f"Receive unknown heart beat. {worker_name}")
            return False

        self.worker_info[worker_name].queue_length = queue_length
        self.worker_info[worker_name].last_heart_beat = time.time()
        logger.info(f"Receive heart beat. {worker_name}")
        return True

    def remove_stable_workers_by_expiration(self):
        """Remove heart-beat-checked workers whose last beat is older than the expiration."""
        expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION
        # Collect first: the dict must not change size while being iterated.
        to_delete = [
            worker_name
            for worker_name, w_info in self.worker_info.items()
            if w_info.check_heart_beat and w_info.last_heart_beat < expire
        ]

        for worker_name in to_delete:
            self.remove_worker(worker_name)

    def worker_api_generate_stream(self, params):
        """Proxy a generation request to a worker, yielding NUL-terminated chunks.

        Yields a single error record (code 2) when no worker serves
        ``params["model"]``, or an error record (code 3) when the request to
        the worker fails.
        """
        worker_addr = self.get_worker_address(params["model"])
        if not worker_addr:
            logger.info(f"no worker: {params['model']}")
            ret = {
                "text": server_error_msg,
                "error_code": 2,
            }
            yield json.dumps(ret).encode() + b"\0"
            # Stop here: posting to an empty address would only fail and
            # emit a second, misleading error record.
            return

        try:
            response = requests.post(worker_addr + "/worker_generate_stream",
                                     json=params, stream=True, timeout=5)
            for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
                if chunk:
                    yield chunk + b"\0"
        except requests.exceptions.RequestException as e:
            logger.info(f"worker timeout: {worker_addr}")
            ret = {
                "text": server_error_msg,
                "error_code": 3,
            }
            yield json.dumps(ret).encode() + b"\0"

    # Let the controller act as a worker to achieve hierarchical
    # management. This can be used to connect isolated sub networks.
    def worker_api_get_status(self):
        """Aggregate model names, speed and queue length over all reachable workers."""
        model_names = set()
        speed = 0
        queue_length = 0

        for w_name in self.worker_info:
            worker_status = self.get_worker_status(w_name)
            if worker_status is not None:
                model_names.update(worker_status["model_names"])
                speed += worker_status["speed"]
                queue_length += worker_status["queue_length"]

        return {
            "model_names": list(model_names),
            "speed": speed,
            "queue_length": queue_length,
        }


app = FastAPI()


@app.post("/register_worker")
async def register_worker(request: Request):
    """Register the worker described in the JSON request body."""
    data = await request.json()
    controller.register_worker(
        data["worker_name"], data["check_heart_beat"],
        data.get("worker_status", None))


@app.post("/refresh_all_workers")
async def refresh_all_workers():
    """Re-probe every registered worker."""
    controller.refresh_all_workers()


@app.post("/list_models")
async def list_models():
    """Return the models currently served."""
    models = controller.list_models()
    return {"models": models}


@app.post("/get_worker_address")
async def get_worker_address(request: Request):
    """Return a worker address for the requested model ("" if none)."""
    data = await request.json()
    addr = controller.get_worker_address(data["model"])
    return {"address": addr}


@app.post("/receive_heart_beat")
async def receive_heart_beat(request: Request):
    """Record a worker heart beat and report whether the worker is known."""
    data = await request.json()
    exist = controller.receive_heart_beat(
        data["worker_name"], data["queue_length"])
    return {"exist": exist}


@app.post("/worker_generate_stream")
async def worker_api_generate_stream(request: Request):
    """Stream a generation request through the controller."""
    params = await request.json()
    generator = controller.worker_api_generate_stream(params)
    return StreamingResponse(generator)


@app.post("/worker_get_status")
async def worker_api_get_status(request: Request):
    """Expose the aggregated status so the controller can act as a worker."""
    return controller.worker_api_get_status()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=21001)
    parser.add_argument("--dispatch-method", type=str, choices=[
        "lottery", "shortest_queue"], default="shortest_queue")
    args = parser.parse_args()
    logger.info(f"args: {args}")

    controller = Controller(args.dispatch_method)
    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
logger = build_logger("gradio_web_server", "gradio_web_server.log")

headers = {"User-Agent": "LLaVA Client"}

# Reusable button updates shared by all event handlers.
no_change_btn = gr.Button.update()
enable_btn = gr.Button.update(interactive=True)
disable_btn = gr.Button.update(interactive=False)

# Sort keys that pin selected models to the top of the model list.
priority = {
    "vicuna-13b": "aaaaaaa",
    "koala-13b": "aaaaaab",
}


def get_conv_log_filename():
    """Return the path of today's conversation log inside LOGDIR."""
    now = datetime.datetime.now()
    return os.path.join(LOGDIR, f"{now.year}-{now.month:02d}-{now.day:02d}-conv.json")


def get_model_list():
    """Ask the controller to refresh its workers and return the sorted model names."""
    refresh = requests.post(args.controller_url + "/refresh_all_workers")
    assert refresh.status_code == 200
    listing = requests.post(args.controller_url + "/list_models")
    models = listing.json()["models"]
    # Models with a priority entry sort by it; all others sort by name.
    models.sort(key=lambda x: priority.get(x, x))
    logger.info(f"Models: {models}")
    return models


get_window_url_params = """
function() {
    const params = new URLSearchParams(window.location.search);
    url_params = Object.fromEntries(params);
    console.log(url_params);
    return url_params;
    }
"""


def load_demo(url_params, request: gr.Request):
    """Reveal the UI on page load, pre-selecting the model named in the URL if known."""
    logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")

    requested = url_params["model"] if "model" in url_params else None
    # `models` is the module-level list of known model names.
    if requested is not None and requested in models:
        dropdown_update = gr.Dropdown.update(value=requested, visible=True)
    else:
        dropdown_update = gr.Dropdown.update(visible=True)

    state = default_conversation.copy()
    return (
        state,
        dropdown_update,
        gr.Chatbot.update(visible=True),
        gr.Textbox.update(visible=True),
        gr.Button.update(visible=True),
        gr.Row.update(visible=True),
        gr.Accordion.update(visible=True),
    )


def load_demo_refresh_model_list(request: gr.Request):
    """Reveal the UI on page load after re-fetching the model list from the controller."""
    logger.info(f"load_demo. ip: {request.client.host}")
    models = get_model_list()
    state = default_conversation.copy()
    selected = models[0] if len(models) > 0 else ""
    return (
        state,
        gr.Dropdown.update(choices=models, value=selected),
        gr.Chatbot.update(visible=True),
        gr.Textbox.update(visible=True),
        gr.Button.update(visible=True),
        gr.Row.update(visible=True),
        gr.Accordion.update(visible=True),
    )


def vote_last_response(state, vote_type, model_selector, request: gr.Request):
    """Append one vote record (as a JSON line) to today's conversation log."""
    record = {
        "tstamp": round(time.time(), 4),
        "type": vote_type,
        "model": model_selector,
        "state": state.dict(),
        "ip": request.client.host,
    }
    with open(get_conv_log_filename(), "a") as fout:
        fout.write(json.dumps(record) + "\n")


def upvote_last_response(state, model_selector, request: gr.Request):
    """Log an upvote, clear the textbox and disable the three vote buttons."""
    logger.info(f"upvote. ip: {request.client.host}")
    vote_last_response(state, "upvote", model_selector, request)
    return ("",) + (disable_btn,) * 3


def downvote_last_response(state, model_selector, request: gr.Request):
    """Log a downvote, clear the textbox and disable the three vote buttons."""
    logger.info(f"downvote. ip: {request.client.host}")
    vote_last_response(state, "downvote", model_selector, request)
    return ("",) + (disable_btn,) * 3


def flag_last_response(state, model_selector, request: gr.Request):
    """Log a flag, clear the textbox and disable the three vote buttons."""
    logger.info(f"flag. ip: {request.client.host}")
    vote_last_response(state, "flag", model_selector, request)
    return ("",) + (disable_btn,) * 3


def regenerate(state, image_process_mode, request: gr.Request):
    """Blank the last answer so it is generated again with the current image mode."""
    logger.info(f"regenerate. ip: {request.client.host}")
    state.messages[-1][-1] = None
    human_turn = state.messages[-2]
    if type(human_turn[1]) in (tuple, list):
        # Keep the first two entries of the message, swap in the new mode.
        human_turn[1] = (*human_turn[1][:2], image_process_mode)
    state.skip_next = False
    return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
def clear_history(request: gr.Request):
    """Reset the conversation to the default template and disable all buttons."""
    logger.info(f"clear_history. ip: {request.client.host}")
    state = default_conversation.copy()
    return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5


def add_text(state, text, image, image_process_mode, request: gr.Request):
    """Append the user's turn (text and optional image) to the conversation.

    Empty submissions and, when ``args.moderate`` is set, flagged text mark the
    state with ``skip_next`` so the following generation step does nothing.
    Submitting an image while the conversation already contains one starts a
    fresh conversation.
    """
    logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
    if len(text) <= 0 and image is None:
        state.skip_next = True
        return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
    if args.moderate:
        flagged = violates_moderation(text)
        if flagged:
            state.skip_next = True
            return (state, state.to_gradio_chatbot(), moderation_msg, None) + (
                no_change_btn,) * 5

    text = text[:1536]  # Hard cut-off
    if image is not None:
        text = text[:1200]  # Hard cut-off for images
        # Make sure the prompt carries exactly one image placeholder. (An
        # empty-string membership test is always true and would never add it.)
        if '<image>' not in text:
            text = text + '\n<image>'
        text = (text, image, image_process_mode)
        if len(state.get_images(return_pil=True)) > 0:
            state = default_conversation.copy()
    state.append_message(state.roles[0], text)
    state.append_message(state.roles[1], None)
    state.skip_next = False
    return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5


def _select_template_name(model_name):
    """Pick the conversation template name for ``model_name``."""
    lowered = model_name.lower()
    is_mmtag = 'mmtag' in lowered or ('plain' in lowered and 'finetune' not in lowered)
    if "llava" in lowered:
        if 'llama-2' in lowered:
            return "llava_llama_2"
        if "v1" in lowered:
            return "v1_mmtag" if is_mmtag else "llava_v1"
        if "mpt" in lowered:
            return "mpt"
        return "v0_mmtag" if is_mmtag else "llava_v0"
    # The non-LLaVA checks are case-sensitive, as in the original chain.
    if "mpt" in model_name:
        return "mpt_text"
    if "llama-2" in model_name:
        return "llama_2"
    return "vicuna_v1"


def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request):
    """Stream the model's answer for the latest user turn into the chatbot.

    Generator used as a Gradio event handler: every yielded tuple is the
    updated state, the chatbot content and five button updates. On the first
    round the conversation is re-created from the template matching the
    selected model. When generation finishes normally, one "chat" record is
    appended to today's conversation log.
    """
    logger.info(f"http_bot. ip: {request.client.host}")
    start_tstamp = time.time()
    model_name = model_selector

    if state.skip_next:
        # This generate call is skipped due to invalid inputs
        yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
        return

    if len(state.messages) == state.offset + 2:
        # First round of conversation
        template_name = _select_template_name(model_name)
        new_state = conv_templates[template_name].copy()
        new_state.append_message(new_state.roles[0], state.messages[-2][1])
        new_state.append_message(new_state.roles[1], None)
        state = new_state

    # Query worker address
    controller_url = args.controller_url
    ret = requests.post(controller_url + "/get_worker_address",
                        json={"model": model_name})
    worker_addr = ret.json()["address"]
    logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")

    # No available worker
    if worker_addr == "":
        state.messages[-1][-1] = server_error_msg
        yield (state, state.to_gradio_chatbot(), disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
        return

    # Construct prompt
    prompt = state.get_prompt()

    # Save each image once per day, named by the MD5 of its pixel data.
    all_images = state.get_images(return_pil=True)
    all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
    for image, image_hash in zip(all_images, all_image_hash):
        t = datetime.datetime.now()
        filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{image_hash}.jpg")
        if not os.path.isfile(filename):
            os.makedirs(os.path.dirname(filename), exist_ok=True)
            image.save(filename)

    # Make requests
    pload = {
        "model": model_name,
        "prompt": prompt,
        "temperature": float(temperature),
        "top_p": float(top_p),
        "max_new_tokens": min(int(max_new_tokens), 1536),
        "stop": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2,
        "images": f'List of {len(state.get_images())} images: {all_image_hash}',
    }
    # Log the payload with the image summary, then attach the real images.
    logger.info(f"==== request ====\n{pload}")

    pload['images'] = state.get_images()

    state.messages[-1][-1] = "▌"
    yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5

    # Defined up front so the final log line works even if nothing is streamed.
    output = ""
    try:
        # Stream output
        response = requests.post(worker_addr + "/worker_generate_stream",
                                 headers=headers, json=pload, stream=True, timeout=10)
        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
            if chunk:
                data = json.loads(chunk.decode())
                if data["error_code"] == 0:
                    # Strip the leading prompt-length prefix from the worker's text.
                    output = data["text"][len(prompt):].strip()
                    state.messages[-1][-1] = output + "▌"
                    yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
                else:
                    output = data["text"] + f" (error_code: {data['error_code']})"
                    state.messages[-1][-1] = output
                    yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
                    return
                time.sleep(0.03)
    except requests.exceptions.RequestException as e:
        state.messages[-1][-1] = server_error_msg
        yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
        return

    # Drop the trailing cursor character.
    state.messages[-1][-1] = state.messages[-1][-1][:-1]
    yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5

    finish_tstamp = time.time()
    logger.info(f"{output}")

    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name,
            "start": round(start_tstamp, 4),
            # Record when generation actually finished, not when it started.
            "finish": round(finish_tstamp, 4),
            "state": state.dict(),
            "images": all_image_hash,
            "ip": request.client.host,
        }
        fout.write(json.dumps(data) + "\n")
""") def build_demo(embed_mode): textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", visible=False, container=False) with gr.Blocks(title="LLaVA", theme=gr.themes.Base()) as demo: state = gr.State() if not embed_mode: gr.Markdown(title_markdown) with gr.Row(): with gr.Column(scale=3): with gr.Row(elem_id="model_selector_row"): model_selector = gr.Dropdown( choices=models, value=models[0] if len(models) > 0 else "", interactive=True, show_label=False, container=False) imagebox = gr.Image(type="pil") image_process_mode = gr.Radio( ["Crop", "Resize", "Pad"], value="Crop", label="Preprocess for non-square image") cur_dir = os.path.dirname(os.path.abspath(__file__)) gr.Examples(examples=[ [f"{cur_dir}/examples/extreme_ironing.jpg", "What is unusual about this image?"], [f"{cur_dir}/examples/waterview.jpg", "What are the things I should be cautious about when I visit here?"], ], inputs=[imagebox, textbox]) with gr.Accordion("Parameters", open=False, visible=False) as parameter_row: temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature",) top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",) max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",) with gr.Column(scale=6): chatbot = gr.Chatbot(elem_id="chatbot", label="LLaVA Chatbot", visible=False, height=550) with gr.Row(): with gr.Column(scale=8): textbox.render() with gr.Column(scale=1, min_width=60): submit_btn = gr.Button(value="Submit", visible=False) with gr.Row(visible=False) as button_row: upvote_btn = gr.Button(value="👍 Upvote", interactive=False) downvote_btn = gr.Button(value="👎 Downvote", interactive=False) flag_btn = gr.Button(value="⚠️ Flag", interactive=False) #stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False) regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False) clear_btn = 
gr.Button(value="🗑️ Clear history", interactive=False) if not embed_mode: gr.Markdown(tos_markdown) gr.Markdown(learn_more_markdown) url_params = gr.JSON(visible=False) # Register listeners btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn] upvote_btn.click(upvote_last_response, [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn]) downvote_btn.click(downvote_last_response, [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn]) flag_btn.click(flag_last_response, [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn]) regenerate_btn.click(regenerate, [state, image_process_mode], [state, chatbot, textbox, imagebox] + btn_list).then( http_bot, [state, model_selector, temperature, top_p, max_output_tokens], [state, chatbot] + btn_list) clear_btn.click(clear_history, None, [state, chatbot, textbox, imagebox] + btn_list) textbox.submit(add_text, [state, textbox, imagebox, image_process_mode], [state, chatbot, textbox, imagebox] + btn_list ).then(http_bot, [state, model_selector, temperature, top_p, max_output_tokens], [state, chatbot] + btn_list) submit_btn.click(add_text, [state, textbox, imagebox, image_process_mode], [state, chatbot, textbox, imagebox] + btn_list ).then(http_bot, [state, model_selector, temperature, top_p, max_output_tokens], [state, chatbot] + btn_list) if args.model_list_mode == "once": demo.load(load_demo, [url_params], [state, model_selector, chatbot, textbox, submit_btn, button_row, parameter_row], _js=get_window_url_params) elif args.model_list_mode == "reload": demo.load(load_demo_refresh_model_list, None, [state, model_selector, chatbot, textbox, submit_btn, button_row, parameter_row]) else: raise ValueError(f"Unknown model list mode: {args.model_list_mode}") return demo if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--host", type=str, default="0.0.0.0") parser.add_argument("--port", type=int) parser.add_argument("--controller-url", 
type=str, default="http://localhost:21001") parser.add_argument("--concurrency-count", type=int, default=8) parser.add_argument("--model-list-mode", type=str, default="once", choices=["once", "reload"]) parser.add_argument("--share", action="store_true") parser.add_argument("--moderate", action="store_true") parser.add_argument("--embed", action="store_true") args = parser.parse_args() logger.info(f"args: {args}") models = get_model_list() logger.info(args) demo = build_demo(args.embed) demo.queue(concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False).launch( server_name=args.host, server_port=args.port, share=args.share) ================================================ FILE: llava/serve/register_worker.py ================================================ """ Manually register workers. Usage: python3 -m fastchat.serve.register_worker --controller http://localhost:21001 --worker-name http://localhost:21002 """ import argparse import requests if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--controller-address", type=str) parser.add_argument("--worker-name", type=str) parser.add_argument("--check-heart-beat", action="store_true") args = parser.parse_args() url = args.controller_address + "/register_worker" data = { "worker_name": args.worker_name, "check_heart_beat": args.check_heart_beat, "worker_status": None, } r = requests.post(url, json=data) assert r.status_code == 200 ================================================ FILE: llava/serve/test_message.py ================================================ import argparse import json import requests from llava.conversation import default_conversation def main(): if args.worker_address: worker_addr = args.worker_address else: controller_addr = args.controller_address ret = requests.post(controller_addr + "/refresh_all_workers") ret = requests.post(controller_addr + "/list_models") models = ret.json()["models"] models.sort() print(f"Models: {models}") ret = 
requests.post(controller_addr + "/get_worker_address", json={"model": args.model_name}) worker_addr = ret.json()["address"] print(f"worker_addr: {worker_addr}") if worker_addr == "": return conv = default_conversation.copy() conv.append_message(conv.roles[0], args.message) prompt = conv.get_prompt() headers = {"User-Agent": "LLaVA Client"} pload = { "model": args.model_name, "prompt": prompt, "max_new_tokens": args.max_new_tokens, "temperature": 0.7, "stop": conv.sep, } response = requests.post(worker_addr + "/worker_generate_stream", headers=headers, json=pload, stream=True) print(prompt.replace(conv.sep, "\n"), end="") for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"): if chunk: data = json.loads(chunk.decode("utf-8")) output = data["text"].split(conv.sep)[-1] print(output, end="\r") print("") if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--controller-address", type=str, default="http://localhost:21001") parser.add_argument("--worker-address", type=str) parser.add_argument("--model-name", type=str, default="facebook/opt-350m") parser.add_argument("--max-new-tokens", type=int, default=32) parser.add_argument("--message", type=str, default= "Tell me a story with more than 1000 words.") args = parser.parse_args() main() ================================================ FILE: llava/train/llama_flash_attn_monkey_patch.py ================================================ from typing import List, Optional, Tuple import logging import torch from torch import nn import transformers from transformers.models.llama.modeling_llama import apply_rotary_pos_emb from einops import rearrange try: from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func except ImportError: from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func from flash_attn.bert_padding import unpad_input, pad_input def forward( self, hidden_states: 
torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: bool = False, use_cache: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel attention_mask: [bsz, q_len] """ bsz, q_len, _ = hidden_states.size() query_states = ( self.q_proj(hidden_states) .view(bsz, q_len, self.num_heads, self.head_dim) .transpose(1, 2) ) key_states = ( self.k_proj(hidden_states) .view(bsz, q_len, self.num_heads, self.head_dim) .transpose(1, 2) ) value_states = ( self.v_proj(hidden_states) .view(bsz, q_len, self.num_heads, self.head_dim) .transpose(1, 2) ) # [bsz, q_len, nh, hd] # [bsz, nh, q_len, hd] kv_seq_len = key_states.shape[-2] assert past_key_value is None, "past_key_value is not supported" cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb( query_states, key_states, cos, sin, position_ids ) # [bsz, nh, t, hd] assert not output_attentions, "output_attentions is not supported" assert not use_cache, "use_cache is not supported" # Flash attention codes from # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py # transform the data into the format required by flash attention qkv = torch.stack( [query_states, key_states, value_states], dim=2 ) # [bsz, nh, 3, q_len, hd] qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd] # We have disabled _prepare_decoder_attention_mask in LlamaModel # the attention_mask should be the same as the key_padding_mask key_padding_mask = attention_mask if key_padding_mask is None: qkv = rearrange(qkv, "b s ... 
-> (b s) ...") max_s = q_len cu_q_lens = torch.arange( 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device ) output = flash_attn_unpadded_qkvpacked_func( qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True ) output = rearrange(output, "(b s) ... -> b s ...", b=bsz) else: nheads = qkv.shape[-2] x = rearrange(qkv, "b s three h d -> b s (three h d)") x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask) x_unpad = rearrange( x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads ) output_unpad = flash_attn_unpadded_qkvpacked_func( x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True ) output = rearrange( pad_input( rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len ), "b s (h d) -> b s h d", h=nheads, ) return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, None # Disable the transformation of the attention mask in LlamaModel as the flash attention # requires the attention mask to be the same as the key_padding_mask def _prepare_decoder_attention_mask( self, attention_mask, input_shape, inputs_embeds, past_key_values_length ): # [bsz, seq_len] return attention_mask def replace_llama_attn_with_flash_attn(): cuda_major, cuda_minor = torch.cuda.get_device_capability() if cuda_major < 8: logging.warning( "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward." 
"ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593" ) transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( _prepare_decoder_attention_mask ) transformers.models.llama.modeling_llama.LlamaAttention.forward = forward ================================================ FILE: llava/train/llava_trainer.py ================================================ import os import torch from transformers import Trainer from typing import Optional def maybe_zero_3(param, ignore_status=False, name=None): from deepspeed import zero from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus if hasattr(param, "ds_id"): if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: if not ignore_status: print(name, 'no ignore status') with zero.GatheredParameters([param]): param = param.data.detach().cpu().clone() else: param = param.detach().cpu().clone() return param def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match): to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)} to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()} return to_return class LLaVATrainer(Trainer): def _save_checkpoint(self, model, trial, metrics=None): if getattr(self.args, 'tune_mm_mlp_adapter', False): from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" run_dir = self._get_output_dir(trial=trial) output_dir = os.path.join(run_dir, checkpoint_folder) # Only save Adapter keys_to_match = ['mm_projector'] if getattr(self.args, "use_im_start_end", False) or getattr(self.args, "new_tokens", False): keys_to_match.extend(['embed_tokens', 'embed_in','lm_head']) # import pdb; pdb.set_trace() weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match) if self.args.local_rank == 0 or self.args.local_rank == -1: 
self.model.config.save_pretrained(output_dir) torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin')) else: super(LLaVATrainer, self)._save_checkpoint(model, trial, metrics) def _save(self, output_dir: Optional[str] = None, state_dict=None): if getattr(self.args, 'tune_mm_mlp_adapter', False): pass else: super(LLaVATrainer, self)._save(output_dir, state_dict) ================================================ FILE: llava/train/llava_trainer_gd.py ================================================ import os import torch import torch.nn as nn from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler # from transformers import Trainer from typing import Optional from transformers.trainer import * from datasets_os import build_train_dataloader def maybe_zero_3(param, ignore_status=False, name=None): from deepspeed import zero from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus if hasattr(param, "ds_id"): if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: if not ignore_status: print(name, 'no ignore status') with zero.GatheredParameters([param]): param = param.data.detach().cpu().clone() else: param = param.detach().cpu().clone() return param def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match): to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)} to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()} return to_return class TrainerLLavaGD(Trainer): """ Trainer is a simple but feature-complete training and eval loop for PyTorch, optimized for 🤗 Transformers. Args: model ([`PreTrainedModel`] or `torch.nn.Module`, *optional*): The model to train, evaluate or use for predictions. If not provided, a `model_init` must be passed. [`Trainer`] is optimized to work with the [`PreTrainedModel`] provided by the library. 
You can still use your own models defined as `torch.nn.Module` as long as they work the same way as the 🤗 Transformers models. args ([`TrainingArguments`], *optional*): The arguments to tweak for training. Will default to a basic instance of [`TrainingArguments`] with the `output_dir` set to a directory named *tmp_trainer* in the current directory if not provided. data_collator (`DataCollator`, *optional*): The function to use to form a batch from a list of elements of `train_dataset` or `eval_dataset`. Will default to [`default_data_collator`] if no `tokenizer` is provided, an instance of [`DataCollatorWithPadding`] otherwise. train_dataset (`torch.utils.data.Dataset` or `torch.utils.data.IterableDataset`, *optional*): The dataset to use for training. If it is a [`~datasets.Dataset`], columns not accepted by the `model.forward()` method are automatically removed. Note that if it's a `torch.utils.data.IterableDataset` with some randomization and you are training in a distributed fashion, your iterable dataset should either use a internal attribute `generator` that is a `torch.Generator` for the randomization that must be identical on all processes (and the Trainer will manually set the seed of this `generator` at each epoch) or have a `set_epoch()` method that internally sets the seed of the RNGs used. eval_dataset (Union[`torch.utils.data.Dataset`, Dict[str, `torch.utils.data.Dataset`]), *optional*): The dataset to use for evaluation. If it is a [`~datasets.Dataset`], columns not accepted by the `model.forward()` method are automatically removed. If it is a dictionary, it will evaluate on each dataset prepending the dictionary key to the metric name. tokenizer ([`PreTrainedTokenizerBase`], *optional*): The tokenizer used to preprocess the data. If provided, will be used to automatically pad the inputs to the maximum length when batching inputs, and it will be saved along the model to make it easier to rerun an interrupted training or reuse the fine-tuned model. 
model_init (`Callable[[], PreTrainedModel]`, *optional*): A function that instantiates the model to be used. If provided, each call to [`~Trainer.train`] will start from a new instance of the model as given by this function. The function may have zero argument, or a single one containing the optuna/Ray Tune/SigOpt trial object, to be able to choose different architectures according to hyper parameters (such as layer count, sizes of inner layers, dropout probabilities etc). compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*): The function that will be used to compute metrics at evaluation. Must take a [`EvalPrediction`] and return a dictionary string to metric values. callbacks (List of [`TrainerCallback`], *optional*): A list of callbacks to customize the training loop. Will add those to the list of default callbacks detailed in [here](callback). If you want to remove one of the default callbacks used, use the [`Trainer.remove_callback`] method. optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*): A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`. preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*): A function that preprocess the logits right before caching them at each evaluation step. Must take two tensors, the logits and the labels, and return the logits once processed as desired. The modifications made by this function will be reflected in the predictions received by `compute_metrics`. Note that the labels (second parameter) will be `None` if the dataset does not have them. Important attributes: - **model** -- Always points to the core model. If using a transformers model, it will be a [`PreTrainedModel`] subclass. 
- **model_wrapped** -- Always points to the most external model in case one or more other modules wrap the original model. This is the model that should be used for the forward pass. For example, under `DeepSpeed`, the inner model is wrapped in `DeepSpeed` and then again in `torch.nn.DistributedDataParallel`. If the inner model hasn't been wrapped, then `self.model_wrapped` is the same as `self.model`. - **is_model_parallel** -- Whether or not a model has been switched to a model parallel mode (different from data parallelism, this means some of the model layers are split on different GPUs). - **place_model_on_device** -- Whether or not to automatically place the model on the device - it will be set to `False` if model parallel or deepspeed is used, or if the default `TrainingArguments.place_model_on_device` is overridden to return `False` . - **is_in_train** -- Whether or not a model is currently running `train` (e.g. when `evaluate` is called while in `train`) """ # Those are used as methods of the Trainer in examples. 
def __init__( self, model: Union[PreTrainedModel, nn.Module] = None, args: TrainingArguments = None, data_collator: Optional[DataCollator] = None, train_dataset: Optional[Dataset] = None, eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, tokenizer: Optional[PreTrainedTokenizerBase] = None, model_init: Optional[Callable[[], PreTrainedModel]] = None, compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, callbacks: Optional[List[TrainerCallback]] = None, optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, data_loader_args=None, cfg=None, ): self.cfg=cfg if args is None: output_dir = "tmp_trainer" logger.info(f"No `TrainingArguments` passed, using `output_dir={output_dir}`.") args = TrainingArguments(output_dir=output_dir) self.args = args # Seed must be set before instantiating the model when using model enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed) self.hp_name = None self.deepspeed = None self.is_in_train = False self.data_loader_args=data_loader_args self.create_accelerator_and_postprocess() # memory metrics - must set up as early as possible self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics) self._memory_tracker.start() # set the correct log level depending on the node log_level = args.get_process_log_level() logging.set_verbosity(log_level) # force device and distributed setup init explicitly args._setup_devices if model is None: if model_init is not None: self.model_init = model_init model = self.call_model_init() else: raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument") else: if model_init is not None: warnings.warn( "`Trainer` requires either a `model` or `model_init` argument, but not both. `model_init` will" " overwrite your model when calling the `train` method. 
This will become a fatal error in the next" " release.", FutureWarning, ) self.model_init = model_init if model.__class__.__name__ in MODEL_MAPPING_NAMES: raise ValueError( f"The model you have picked ({model.__class__.__name__}) cannot be used as is for training: it only " "computes hidden states and does not accept any labels. You should choose a model with a head " "suitable for your task like any of the `AutoModelForXxx` listed at " "https://huggingface.co/docs/transformers/model_doc/auto." ) if hasattr(model, "is_parallelizable") and model.is_parallelizable and model.model_parallel: self.is_model_parallel = True else: self.is_model_parallel = False if getattr(model, "hf_device_map", None) is not None: devices = [device for device in set(model.hf_device_map.values()) if device not in ["cpu", "disk"]] if len(devices) > 1: self.is_model_parallel = True else: self.is_model_parallel = self.args.device != torch.device(devices[0]) # warn users logger.info( "You have loaded a model on multiple GPUs. `is_model_parallel` attribute will be force-set" " to `True` to avoid any unexpected behavior such as device placement mismatching." ) # At this stage the model is already loaded if getattr(model, "is_quantized", False): if getattr(model, "_is_quantized_training_enabled", False): logger.info( "The model is loaded in 8-bit precision. To train this model you need to add additional modules" " inside the model such as adapters using `peft` library and freeze the model weights. Please" " check " " the examples in https://github.com/huggingface/peft for more details." ) else: raise ValueError( "The model you want to train is loaded in 8-bit precision. if you want to fine-tune an 8-bit" " model, please make sure that you have installed `bitsandbytes>=0.37.0`. 
" ) # Setup Sharded DDP training self.sharded_ddp = None if len(args.sharded_ddp) > 0: if self.is_deepspeed_enabled: raise ValueError( "Using --sharded_ddp xxx together with --deepspeed is not possible, deactivate one of those flags." ) if len(args.fsdp) > 0: raise ValueError( "Using --sharded_ddp xxx together with --fsdp is not possible, deactivate one of those flags." ) if args.parallel_mode != ParallelMode.DISTRIBUTED: raise ValueError("Using sharded DDP only works in distributed training.") elif not is_fairscale_available(): raise ImportError("Sharded DDP training requires fairscale: `pip install fairscale`.") elif ShardedDDPOption.SIMPLE not in args.sharded_ddp and FullyShardedDDP is None: raise ImportError( "Sharded DDP in a mode other than simple training requires fairscale version >= 0.3, found " f"{fairscale.__version__}. Upgrade your fairscale library: `pip install --upgrade fairscale`." ) elif ShardedDDPOption.SIMPLE in args.sharded_ddp: self.sharded_ddp = ShardedDDPOption.SIMPLE elif ShardedDDPOption.ZERO_DP_2 in args.sharded_ddp: self.sharded_ddp = ShardedDDPOption.ZERO_DP_2 elif ShardedDDPOption.ZERO_DP_3 in args.sharded_ddp: self.sharded_ddp = ShardedDDPOption.ZERO_DP_3 self.fsdp = None if len(args.fsdp) > 0: if self.is_deepspeed_enabled: raise ValueError( "Using --fsdp xxx together with --deepspeed is not possible, deactivate one of those flags." ) if not args.fsdp_config["xla"] and args.parallel_mode != ParallelMode.DISTRIBUTED: raise ValueError("Using fsdp only works in distributed training.") # dep_version_check("torch>=1.12.0") # Would have to update setup.py with torch>=1.12.0 # which isn't ideally given that it will force people not using FSDP to also use torch>=1.12.0 # below is the current alternative. 
if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.12.0"): raise ValueError("FSDP requires PyTorch >= 1.12.0") from torch.distributed.fsdp.fully_sharded_data_parallel import BackwardPrefetch, ShardingStrategy if FSDPOption.FULL_SHARD in args.fsdp: self.fsdp = ShardingStrategy.FULL_SHARD elif FSDPOption.SHARD_GRAD_OP in args.fsdp: self.fsdp = ShardingStrategy.SHARD_GRAD_OP elif FSDPOption.NO_SHARD in args.fsdp: self.fsdp = ShardingStrategy.NO_SHARD self.backward_prefetch = BackwardPrefetch.BACKWARD_PRE if "backward_prefetch" in self.args.fsdp_config and "backward_post" in self.args.fsdp_config.get( "backward_prefetch", [] ): self.backward_prefetch = BackwardPrefetch.BACKWARD_POST self.forward_prefetch = False if self.args.fsdp_config.get("forward_prefect", False): self.forward_prefetch = True self.limit_all_gathers = False if self.args.fsdp_config.get("limit_all_gathers", False): self.limit_all_gathers = True # one place to sort out whether to place the model on device or not # postpone switching model to cuda when: # 1. MP - since we are trying to fit a much bigger than 1 gpu model # 2. fp16-enabled DeepSpeed loads the model in half the size and it doesn't need .to() anyway, # and we only use deepspeed for training at the moment # 3. full bf16 or fp16 eval - since the model needs to be cast to the right dtype first # 4. Sharded DDP - same as MP # 5. 
FSDP - same as MP self.place_model_on_device = args.place_model_on_device if ( self.is_model_parallel or self.is_deepspeed_enabled or ((args.fp16_full_eval or args.bf16_full_eval) and not args.do_train) or (self.sharded_ddp in [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3]) or (self.fsdp is not None) or self.is_fsdp_enabled ): self.place_model_on_device = False default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer) self.data_collator = data_collator if data_collator is not None else default_collator self.train_dataset = train_dataset self.eval_dataset = eval_dataset self.tokenizer = tokenizer # Quantized models doesn't support `.to` operation. if self.place_model_on_device and not getattr(model, "is_quantized", False): self._move_model_to_device(model, args.device) # Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs if self.is_model_parallel: self.args._n_gpu = 1 # later use `self.model is self.model_wrapped` to check if it's wrapped or not self.model_wrapped = model self.model = model self.compute_metrics = compute_metrics self.preprocess_logits_for_metrics = preprocess_logits_for_metrics self.optimizer, self.lr_scheduler = optimizers if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None): raise RuntimeError( "Passing a `model_init` is incompatible with providing the `optimizers` argument. " "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method." 
) if is_torch_tpu_available() and self.optimizer is not None: for param in self.model.parameters(): model_device = param.device break for param_group in self.optimizer.param_groups: if len(param_group["params"]) > 0: optimizer_device = param_group["params"][0].device break if model_device != optimizer_device: raise ValueError( "The model and the optimizer parameters are not on the same device, which probably means you" " created an optimizer around your model **before** putting on the device and passing it to the" " `Trainer`. Make sure the lines `import torch_xla.core.xla_model as xm` and" " `model.to(xm.xla_device())` is performed before the optimizer creation in your script." ) if ((self.sharded_ddp is not None) or self.is_deepspeed_enabled or (self.fsdp is not None)) and ( self.optimizer is not None or self.lr_scheduler is not None ): raise RuntimeError( "Passing `optimizers` is not allowed if Fairscale, Deepspeed or PyTorch FSDP is enabled." "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method." ) default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to) callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks self.callback_handler = CallbackHandler( callbacks, self.model, self.tokenizer, self.optimizer, self.lr_scheduler ) self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK) # Will be set to True by `self._setup_loggers()` on first call to `self.log()`. self._loggers_initialized = False # Create clone of distant repo and output directory if needed if self.args.push_to_hub: self.init_git_repo(at_init=True) # In case of pull, we need to make sure every process has the latest. 
if is_torch_tpu_available(): xm.rendezvous("init git repo") elif args.parallel_mode == ParallelMode.DISTRIBUTED: dist.barrier() if self.args.should_save: os.makedirs(self.args.output_dir, exist_ok=True) if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)): raise ValueError("The `data_collator` should be a simple callable (function, class with `__call__`).") if args.max_steps > 0: logger.info("max_steps is given, it will override any value given in num_train_epochs") if train_dataset is not None and not has_length(train_dataset) and args.max_steps <= 0: raise ValueError( "The train_dataset does not implement __len__, max_steps has to be specified. " "The number of steps needs to be known in advance for the learning rate scheduler." ) if ( train_dataset is not None and isinstance(train_dataset, torch.utils.data.IterableDataset) and args.group_by_length ): raise ValueError("the `--group_by_length` option is only available for `Dataset`, not `IterableDataset") self._signature_columns = None # Mixed precision setup self.use_apex = False self.use_cuda_amp = False self.use_cpu_amp = False # Mixed precision setup for SageMaker Model Parallel if is_sagemaker_mp_enabled(): # BF16 + model parallelism in SageMaker: currently not supported, raise an error if args.bf16: raise ValueError("SageMaker Model Parallelism does not support BF16 yet. Please use FP16 instead ") if IS_SAGEMAKER_MP_POST_1_10: # When there's mismatch between SMP config and trainer argument, use SMP config as truth if args.fp16 != smp.state.cfg.fp16: logger.warning( f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}," f"but FP16 provided in trainer argument is {args.fp16}," f"setting to {smp.state.cfg.fp16}" ) args.fp16 = smp.state.cfg.fp16 else: # smp < 1.10 does not support fp16 in trainer. 
if hasattr(smp.state.cfg, "fp16"): logger.warning( f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, " "but SageMaker Model Parallelism < 1.10 does not support FP16 in trainer." ) if (args.fp16 or args.bf16) and self.sharded_ddp is not None: if args.half_precision_backend == "auto": if args.device == torch.device("cpu"): if args.fp16: raise ValueError("Tried to use `fp16` but it is not supported on cpu") else: args.half_precision_backend = "cpu_amp" else: args.half_precision_backend = "cuda_amp" logger.info(f"Using {args.half_precision_backend} half precision backend") self.do_grad_scaling = False if (args.fp16 or args.bf16) and not (self.is_deepspeed_enabled or is_sagemaker_mp_enabled()): # deepspeed and SageMaker Model Parallel manage their own half precision if self.sharded_ddp is not None: if args.half_precision_backend == "cuda_amp": self.use_cuda_amp = True self.amp_dtype = torch.float16 if args.fp16 else torch.bfloat16 # bf16 does not need grad scaling self.do_grad_scaling = self.amp_dtype == torch.float16 if self.do_grad_scaling: if self.sharded_ddp is not None: self.scaler = ShardedGradScaler() elif self.fsdp is not None: from torch.distributed.fsdp.sharded_grad_scaler import ( ShardedGradScaler as FSDPShardedGradScaler, ) self.scaler = FSDPShardedGradScaler() elif is_torch_tpu_available(): from torch_xla.amp import GradScaler self.scaler = GradScaler() else: self.scaler = torch.cuda.amp.GradScaler() elif args.half_precision_backend == "cpu_amp": self.use_cpu_amp = True self.amp_dtype = torch.bfloat16 elif args.half_precision_backend == "apex": if not is_apex_available(): raise ImportError( "Using FP16 with APEX but APEX is not installed, please refer to" " https://www.github.com/nvidia/apex." ) self.use_apex = True # FP16 + model parallelism in SageMaker: gradient clipping does not work for now so we raise a helpful error. 
def add_callback(self, callback):
    """
    Add a callback to the current list of [`~transformer.TrainerCallback`].

    Args:
       callback (`type` or [`~transformer.TrainerCallback`]):
           A [`~transformer.TrainerCallback`] class or an instance of a [`~transformer.TrainerCallback`]. In the
           first case, will instantiate a member of that class.
    """
    self.callback_handler.add_callback(callback)


def pop_callback(self, callback):
    """
    Remove a callback from the current list of [`~transformer.TrainerCallback`] and return it.

    If the callback is not found, returns `None` (and no error is raised).

    Args:
       callback (`type` or [`~transformer.TrainerCallback`]):
           A [`~transformer.TrainerCallback`] class or an instance of a [`~transformer.TrainerCallback`]. In the
           first case, will pop the first member of that class found in the list of callbacks.

    Returns:
        [`~transformer.TrainerCallback`]: The callback removed, if found.
    """
    return self.callback_handler.pop_callback(callback)


def remove_callback(self, callback):
    """
    Remove a callback from the current list of [`~transformer.TrainerCallback`].

    Args:
       callback (`type` or [`~transformer.TrainerCallback`]):
           A [`~transformer.TrainerCallback`] class or an instance of a [`~transformer.TrainerCallback`]. In the
           first case, will remove the first member of that class found in the list of callbacks.
    """
    self.callback_handler.remove_callback(callback)


def _move_model_to_device(self, model, device):
    """Move `model` to `device` (relying on `model.to`) and re-tie weights on TPU.

    Nothing is returned; the caller keeps using its own reference to the model.
    """
    model = model.to(device)
    # Moving a model to an XLA device disconnects the tied weights, so we have to retie them.
    if self.args.parallel_mode == ParallelMode.TPU and hasattr(model, "tie_weights"):
        model.tie_weights()


def _set_signature_columns_if_needed(self):
    """Cache the column names accepted by `self.model.forward`, plus the label columns.

    The result is stored in `self._signature_columns`; once set it is never recomputed.
    Label columns are appended in a deterministic order (`label`, `label_ids`, then
    `self.label_names`, duplicates among them dropped). Going through `set()` would make
    the order depend on string hash randomisation.
    """
    if self._signature_columns is not None:
        return
    # Inspect model forward signature to keep only the arguments it accepts.
    forward_params = inspect.signature(self.model.forward).parameters
    # Labels may be named label or label_ids, the default data collator handles that.
    label_columns = dict.fromkeys(["label", "label_ids", *self.label_names])
    self._signature_columns = [*forward_params, *label_columns]


def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
    """Drop the dataset columns that `self.model.forward` does not accept.

    Args:
        dataset: The dataset to filter; it must expose `column_names`.
        description: Optional split name, used only in the log message.

    Returns:
        `dataset` untouched when `args.remove_unused_columns` is false. Otherwise the
        dataset restricted to the signature columns: formatted in place for
        `datasets` < 1.4.0, or a new dataset from `remove_columns` for newer versions.
    """
    if not self.args.remove_unused_columns:
        return dataset
    self._set_signature_columns_if_needed()
    signature_columns = self._signature_columns

    # Keep the dataset's own column order so the log line and the removal are reproducible.
    accepted = set(signature_columns)
    ignored_columns = [name for name in dataset.column_names if name not in accepted]
    if ignored_columns:
        dset_description = "" if description is None else f"in the {description} set"
        logger.info(
            f"The following columns {dset_description} don't have a corresponding argument in "
            f"`{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}."
            f" If {', '.join(ignored_columns)} are not expected by `{self.model.__class__.__name__}.forward`, "
            " you can safely ignore this message."
        )

    if version.parse(datasets.__version__) < version.parse("1.4.0"):
        # Old `datasets` releases cannot remove columns; restrict the output format instead.
        columns = [k for k in signature_columns if k in dataset.column_names]
        dataset.set_format(
            type=dataset.format["type"], columns=columns, format_kwargs=dataset.format["format_kwargs"]
        )
        return dataset
    return dataset.remove_columns(ignored_columns)


def _get_collator_with_removed_columns(
    self, data_collator: Callable, description: Optional[str] = None
) -> Callable:
    """Wrap the data collator in a callable removing unused columns.

    Returns `data_collator` unchanged when `args.remove_unused_columns` is false.
    """
    if not self.args.remove_unused_columns:
        return data_collator
    self._set_signature_columns_if_needed()
    return RemoveColumnsCollator(
        data_collator=data_collator,
        signature_columns=self._signature_columns,
        logger=logger,
        description=description,
        model_name=self.model.__class__.__name__,
    )
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
    """Pick the sampler used by the training dataloader.

    Returns `None` when there is no training dataset or it has no length, a
    `LengthGroupedSampler` when `args.group_by_length` is set, and a `RandomSampler`
    otherwise.
    """
    dataset = self.train_dataset
    if dataset is None or not has_length(dataset):
        return None

    if not self.args.group_by_length:
        return RandomSampler(dataset)

    # Pre-computed lengths are only read from a `datasets.Dataset` that has the column.
    lengths = None
    if is_datasets_available() and isinstance(dataset, datasets.Dataset):
        if self.args.length_column_name in dataset.column_names:
            lengths = dataset[self.args.length_column_name]

    if self.tokenizer is None:
        model_input_name = None
    else:
        model_input_name = self.tokenizer.model_input_names[0]

    return LengthGroupedSampler(
        self.args.train_batch_size * self.args.gradient_accumulation_steps,
        dataset=dataset,
        lengths=lengths,
        model_input_name=model_input_name,
    )


def get_train_dataloader(self) -> DataLoader:
    """
    Returns the training [`~torch.utils.data.DataLoader`].

    Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed
    training if necessary) otherwise.

    Subclass and override this method if you want to inject some custom behavior.

    Raises:
        ValueError: If the trainer has no `train_dataset`.
    """
    dataset = self.train_dataset
    if dataset is None:
        raise ValueError("Trainer: training requires a train_dataset.")

    collate_fn = self.data_collator
    if is_datasets_available() and isinstance(dataset, datasets.Dataset):
        # Unused columns are dropped from the dataset itself ...
        dataset = self._remove_unused_columns(dataset, description="training")
    else:
        # ... or, for any other dataset type, inside the collator.
        collate_fn = self._get_collator_with_removed_columns(collate_fn, description="training")

    loader_kwargs = {
        "batch_size": self._train_batch_size,
        "collate_fn": collate_fn,
        "num_workers": self.args.dataloader_num_workers,
        "pin_memory": self.args.dataloader_pin_memory,
    }
    if not isinstance(dataset, torch.utils.data.IterableDataset):
        loader_kwargs["sampler"] = self._get_train_sampler()
        loader_kwargs["drop_last"] = self.args.dataloader_drop_last
        loader_kwargs["worker_init_fn"] = seed_worker

    loader = DataLoader(dataset, **loader_kwargs)
    return self.accelerator.prepare(loader)
def get_train_dataloaderd2(self) -> DataLoader:
    """Build the training dataloader through the project's `build_train_dataloader`.

    Uses `self.cfg` and the first three entries of `self.data_loader_args`, passed as
    `tokenizer`, `data_args` and `preprocess` respectively.
    """
    return build_train_dataloader(
        self.cfg,
        tokenizer=self.data_loader_args[0],
        data_args=self.data_loader_args[1],
        preprocess=self.data_loader_args[2],
    )


def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]:
    """Pick the sampler for evaluation/test dataloaders.

    Returns a sequential sampler for single-process runs and `None` when
    `args.world_size` is greater than one. The legacy prediction loop keeps its own
    distributed samplers.
    """
    # Deprecated code
    if self.args.use_legacy_prediction_loop:
        if is_torch_tpu_available():
            return SequentialDistributedSampler(
                eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()
            )
        elif is_sagemaker_mp_enabled():
            return SequentialDistributedSampler(
                eval_dataset,
                num_replicas=smp.dp_size(),
                rank=smp.dp_rank(),
                batch_size=self.args.per_device_eval_batch_size,
            )
        else:
            return SequentialSampler(eval_dataset)

    if self.args.world_size <= 1:
        return SequentialSampler(eval_dataset)
    return None


def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
    """
    Returns the evaluation [`~torch.utils.data.DataLoader`].

    Subclass and override this method if you want to inject some custom behavior.

    Args:
        eval_dataset (`torch.utils.data.Dataset`, *optional*):
            If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted
            by the `model.forward()` method are automatically removed. It must implement `__len__`.

    Raises:
        ValueError: If neither `eval_dataset` nor `self.eval_dataset` is set.
    """
    if eval_dataset is None and self.eval_dataset is None:
        raise ValueError("Trainer: evaluation requires an eval_dataset.")
    eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
    data_collator = self.data_collator

    if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
        eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation")
    else:
        data_collator = self._get_collator_with_removed_columns(data_collator, description="evaluation")

    dataloader_params = {
        "batch_size": self.args.eval_batch_size,
        "collate_fn": data_collator,
        "num_workers": self.args.dataloader_num_workers,
        "pin_memory": self.args.dataloader_pin_memory,
    }

    if not isinstance(eval_dataset, torch.utils.data.IterableDataset):
        dataloader_params["sampler"] = self._get_eval_sampler(eval_dataset)
        dataloader_params["drop_last"] = self.args.dataloader_drop_last

    return self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params))


def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
    """
    Returns the test [`~torch.utils.data.DataLoader`].

    Subclass and override this method if you want to inject some custom behavior.

    Args:
        test_dataset (`torch.utils.data.Dataset`, *optional*):
            The test dataset to use. If it is a [`~datasets.Dataset`], columns not accepted by the
            `model.forward()` method are automatically removed. It must implement `__len__`.
    """
    data_collator = self.data_collator

    if is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
        test_dataset = self._remove_unused_columns(test_dataset, description="test")
    else:
        data_collator = self._get_collator_with_removed_columns(data_collator, description="test")

    # We use the same batch_size as for eval.
    dataloader_params = {
        "batch_size": self.args.eval_batch_size,
        "collate_fn": data_collator,
        "num_workers": self.args.dataloader_num_workers,
        "pin_memory": self.args.dataloader_pin_memory,
    }

    if not isinstance(test_dataset, torch.utils.data.IterableDataset):
        dataloader_params["sampler"] = self._get_eval_sampler(test_dataset)
        dataloader_params["drop_last"] = self.args.dataloader_drop_last

    return self.accelerator.prepare(DataLoader(test_dataset, **dataloader_params))


def create_optimizer_and_scheduler(self, num_training_steps: int):
    """
    Setup the optimizer and the learning rate scheduler.

    We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
    Trainer's init through `optimizers`, or subclass and override this method (or `create_optimizer` and/or
    `create_scheduler`) in a subclass.
    """
    self.create_optimizer()
    if IS_SAGEMAKER_MP_POST_1_10 and smp.state.cfg.fp16:
        # If smp >= 1.10 and fp16 is enabled, we unwrap the optimizer
        optimizer = self.optimizer.optimizer
    else:
        optimizer = self.optimizer
    self.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)


def _build_optimizer_param_groups(self, opt_model, base_lr):
    """Split the trainable parameters of `opt_model` into per-learning-rate groups.

    Parameters are classified by name, in `named_parameters()` order:

    * group 0 (`base_lr`): with `args.tune_mm_mlp_adapter` set, every trainable parameter
      that is neither a segmentation backbone nor a segmentation projection parameter;
      otherwise only the remaining `seg_model` parameters.
    * group 1 (`base_lr * 0.1`): names containing both `seg_model` and `backbone`.
    * group 2 (`base_lr * 0.1`): names containing `seg_model` and either
      `reference_points` or `sampling_offsets`.
    * group 3 (fixed `2e-5`, only without `tune_mm_mlp_adapter`): names starting with
      `model.`.

    Without `tune_mm_mlp_adapter`, trainable parameters matching none of the rules end up
    in no group and are therefore not optimised.

    Args:
        opt_model: Object exposing `named_parameters()` yielding `(name, parameter)`.
        base_lr: Learning rate of group 0.

    Returns:
        A list of `{"params": [...], "lr": float}` dicts in the order described above.
    """
    backbone_keys = ("backbone",)
    linear_proj_keys = ("reference_points", "sampling_offsets")
    seg_model_key = "seg_model"
    adapter_only = getattr(self.args, "tune_mm_mlp_adapter", False)

    main_params, backbone_params, proj_params, llm_params = [], [], [], []
    for name, param in opt_model.named_parameters():
        if not param.requires_grad:
            continue
        in_seg_model = seg_model_key in name
        is_backbone = in_seg_model and any(key in name for key in backbone_keys)
        is_linear_proj = in_seg_model and any(key in name for key in linear_proj_keys)

        # The checks are independent on purpose: a name matching several rules lands in
        # several groups, exactly as with one list comprehension per group.
        if is_backbone:
            backbone_params.append(param)
        if is_linear_proj:
            proj_params.append(param)
        if not (is_backbone or is_linear_proj) and (adapter_only or in_seg_model):
            main_params.append(param)
        if not adapter_only and name.startswith("model."):
            llm_params.append(param)

    groups = [
        {"params": main_params, "lr": base_lr},
        {"params": backbone_params, "lr": base_lr * 0.1},
        {"params": proj_params, "lr": base_lr * 0.1},
    ]
    if not adapter_only:
        # Learning rate used for the `model.*` parameters in joint training.
        groups.append({"params": llm_params, "lr": 2e-5})
    return groups


def create_optimizer(self):
    """
    Setup the optimizer.

    Unlike the stock Trainer, parameters are grouped by learning rate (see
    `_build_optimizer_param_groups`) rather than by weight decay. If you want to use something else, you can pass
    a tuple in the Trainer's init through `optimizers`, or subclass and override this method in a subclass.

    NOTE: no group sets `weight_decay`, and the kwargs returned by `get_optimizer_cls_and_kwargs` do not contain
    it either, so `args.weight_decay` is not forwarded to the optimizer built here.

    Returns:
        The optimizer stored in `self.optimizer` (an existing one is reused as is).
    """
    opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model

    if self.optimizer is None:
        optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
        optimizer_grouped_parameters = self._build_optimizer_param_groups(opt_model, optimizer_kwargs["lr"])

        if self.sharded_ddp == ShardedDDPOption.SIMPLE:
            self.optimizer = OSS(
                params=optimizer_grouped_parameters,
                optim=optimizer_cls,
                **optimizer_kwargs,
            )
        else:
            self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
            if optimizer_cls.__name__ == "Adam8bit":
                import bitsandbytes

                manager = bitsandbytes.optim.GlobalOptimManager.get_instance()

                # Embedding weights are registered to be optimised with 32-bit state.
                skipped = 0
                for module in opt_model.modules():
                    if isinstance(module, nn.Embedding):
                        skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
                        logger.info(f"skipped {module}: {skipped/2**20}M params")
                        manager.register_module_override(module, "weight", {"optim_bits": 32})
                        logger.debug(f"bitsandbytes: will optimize {module} in fp32")
                logger.info(f"skipped: {skipped/2**20}M params")

    if is_sagemaker_mp_enabled():
        self.optimizer = smp.DistributedOptimizer(self.optimizer)

    return self.optimizer
@staticmethod
def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]:
    """
    Returns the optimizer class and optimizer parameters based on the training arguments.

    Args:
        args (`transformers.training_args.TrainingArguments`):
            The training arguments for the training session.

    Returns:
        A `(optimizer_cls, optimizer_kwargs)` tuple. The kwargs always contain `lr`; the other entries depend on
        `args.optim`.

    Raises:
        ValueError: If `args.optim` is not supported or the package providing it cannot be imported.
    """
    # parse args.optim_args ("key1=value1,key2=value2"); only the first "=" splits, so values may contain "=".
    optim_args = {}
    if args.optim_args:
        for mapping in args.optim_args.replace(" ", "").split(","):
            key, value = mapping.split("=", 1)
            optim_args[key] = value

    optimizer_kwargs = {"lr": args.learning_rate}

    adam_kwargs = {
        "betas": (args.adam_beta1, args.adam_beta2),
        "eps": args.adam_epsilon,
    }
    if args.optim == OptimizerNames.ADAFACTOR:
        optimizer_cls = Adafactor
        optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
    elif args.optim == OptimizerNames.ADAMW_HF:
        # NOTE(review): relative import inherited from the upstream Trainer; confirm that an
        # `optimization` module exists next to this file.
        from .optimization import AdamW

        optimizer_cls = AdamW
        optimizer_kwargs.update(adam_kwargs)
    elif args.optim in [OptimizerNames.ADAMW_TORCH, OptimizerNames.ADAMW_TORCH_FUSED]:
        from torch.optim import AdamW

        optimizer_cls = AdamW
        optimizer_kwargs.update(adam_kwargs)
        if args.optim == OptimizerNames.ADAMW_TORCH_FUSED:
            optimizer_kwargs.update({"fused": True})
    elif args.optim == OptimizerNames.ADAMW_TORCH_XLA:
        try:
            from torch_xla.amp.syncfree import AdamW
        except ImportError as err:
            raise ValueError("Trainer failed to import syncfree AdamW from torch_xla.") from err
        optimizer_cls = AdamW
        optimizer_kwargs.update(adam_kwargs)
    elif args.optim == OptimizerNames.ADAMW_APEX_FUSED:
        try:
            from apex.optimizers import FusedAdam
        except ImportError as err:
            raise ValueError("Trainer tried to instantiate apex FusedAdam but apex is not installed!") from err
        optimizer_cls = FusedAdam
        optimizer_kwargs.update(adam_kwargs)
    elif args.optim in [
        OptimizerNames.ADAMW_BNB,
        OptimizerNames.ADAMW_8BIT,
        OptimizerNames.PAGED_ADAMW,
        OptimizerNames.PAGED_ADAMW_8BIT,
        OptimizerNames.LION,
        OptimizerNames.LION_8BIT,
        OptimizerNames.PAGED_LION,
        OptimizerNames.PAGED_LION_8BIT,
    ]:
        # All bitsandbytes optimizers are handled here; the flavour is derived from the optimizer name.
        try:
            from bitsandbytes.optim import AdamW, Lion
        except ImportError as err:
            raise ValueError("Trainer tried to instantiate bnb optimizer but bnb is not installed!") from err

        is_paged = "paged" in args.optim
        optim_bits = 8 if "8bit" in args.optim else 32
        optimizer_cls = None
        additional_optim_kwargs = adam_kwargs
        if "adam" in args.optim:
            optimizer_cls = AdamW
        elif "lion" in args.optim:
            optimizer_cls = Lion
            # Lion only receives the betas, not `eps`.
            additional_optim_kwargs = {"betas": (args.adam_beta1, args.adam_beta2)}

        bnb_kwargs = {"is_paged": is_paged, "optim_bits": optim_bits}
        optimizer_kwargs.update(additional_optim_kwargs)
        optimizer_kwargs.update(bnb_kwargs)
    elif args.optim == OptimizerNames.ADAMW_ANYPRECISION:
        try:
            from torchdistx.optimizers import AnyPrecisionAdamW
        except ImportError as err:
            raise ValueError("Please install https://github.com/pytorch/torchdistx") from err

        optimizer_cls = AnyPrecisionAdamW
        optimizer_kwargs.update(adam_kwargs)

        # TODO Change dtypes back to M=FP32, Var = BF16, Kahan = False once they can be cast together in torchdistx.
        optimizer_kwargs.update(
            {
                "use_kahan_summation": strtobool(optim_args.get("use_kahan_summation", "False")),
                "momentum_dtype": getattr(torch, optim_args.get("momentum_dtype", "float32")),
                "variance_dtype": getattr(torch, optim_args.get("variance_dtype", "float32")),
                "compensation_buffer_dtype": getattr(
                    torch, optim_args.get("compensation_buffer_dtype", "bfloat16")
                ),
            }
        )
    elif args.optim == OptimizerNames.SGD:
        optimizer_cls = torch.optim.SGD
    elif args.optim == OptimizerNames.ADAGRAD:
        optimizer_cls = torch.optim.Adagrad
    else:
        raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
    return optimizer_cls, optimizer_kwargs
def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
    """
    Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
    passed as an argument.

    An already existing `self.lr_scheduler` is returned untouched.

    Args:
        num_training_steps (int): The number of training steps to do.
    """
    if self.lr_scheduler is not None:
        return self.lr_scheduler

    target_optimizer = self.optimizer if optimizer is None else optimizer
    self.lr_scheduler = get_scheduler(
        self.args.lr_scheduler_type,
        optimizer=target_optimizer,
        num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
        num_training_steps=num_training_steps,
    )
    self._created_lr_scheduler = True
    return self.lr_scheduler


def num_examples(self, dataloader: DataLoader) -> int:
    """
    Helper to get number of samples in a [`~torch.utils.data.DataLoader`] by accessing its dataset. When
    dataloader.dataset does not exist or has no length, estimates as best it can
    """
    try:
        dataset = dataloader.dataset
        # Special case for IterableDatasetShard, we need to dig deeper
        sized = dataset.dataset if isinstance(dataset, IterableDatasetShard) else dataset
        return len(sized)
    except (NameError, AttributeError, TypeError):
        # no dataset or length, estimate by length of dataloader
        return len(dataloader) * self.args.per_device_train_batch_size


def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
    """Apply the hyperparameters of `trial` to `self.args` before a search run.

    The trial is always remembered in `self._trial`. Nothing else happens without an
    active search backend or trial.
    """
    self._trial = trial

    backend = self.hp_search_backend
    if backend is None or trial is None:
        return

    if backend == HPSearchBackend.OPTUNA:
        params = self.hp_space(trial)
    elif backend == HPSearchBackend.RAY:
        params = trial
        params.pop("wandb", None)
    elif backend == HPSearchBackend.SIGOPT:
        params = {k: int(v) if isinstance(v, str) else v for k, v in trial.assignments.items()}
    elif backend == HPSearchBackend.WANDB:
        params = trial

    for key, value in params.items():
        if hasattr(self.args, key):
            current = getattr(self.args, key, None)
            # Casting value to the proper type
            if current is not None:
                value = type(current)(value)
            setattr(self.args, key, value)
        else:
            logger.warning(
                f"Trying to set {key} in the hyperparameter search but there is no corresponding field in"
                " `TrainingArguments`."
            )

    if self.hp_search_backend == HPSearchBackend.OPTUNA:
        logger.info(f"Trial: {trial.params}")
    if self.hp_search_backend == HPSearchBackend.SIGOPT:
        logger.info(f"SigOpt Assignments: {trial.assignments}")
    if self.hp_search_backend == HPSearchBackend.WANDB:
        logger.info(f"W&B Sweep parameters: {trial}")

    if self.is_deepspeed_enabled:
        if self.args.deepspeed is None:
            raise ValueError("For sweeps with deepspeed, `args.deepspeed` must be set")
        # Rebuild the deepspeed config to reflect the updated training parameters
        from accelerate.utils import DeepSpeedPlugin
        from transformers.deepspeed import HfTrainerDeepSpeedConfig

        self.args.hf_deepspeed_config = HfTrainerDeepSpeedConfig(self.args.deepspeed)
        self.args.hf_deepspeed_config.trainer_config_process(self.args)
        self.args.deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=self.args.hf_deepspeed_config)

    self.create_accelerator_and_postprocess()
def _report_to_hp_search(self, trial: Union["optuna.Trial", Dict[str, Any]], step: int, metrics: Dict[str, float]):
    """Report the objective computed from `metrics` to the active search backend.

    Does nothing without a search backend or trial. With Optuna a pruned trial ends
    training and raises `optuna.TrialPruned`; with Ray Tune a checkpoint is written first
    when `self.control.should_save` is set.
    """
    backend = self.hp_search_backend
    if backend is None or trial is None:
        return

    self.objective = self.compute_objective(metrics.copy())

    if backend == HPSearchBackend.OPTUNA:
        import optuna

        trial.report(self.objective, step)
        if trial.should_prune():
            self.callback_handler.on_train_end(self.args, self.state, self.control)
            raise optuna.TrialPruned()
    elif backend == HPSearchBackend.RAY:
        from ray import tune

        if self.control.should_save:
            self._tune_save_checkpoint()
        tune.report(objective=self.objective, **metrics)


def _tune_save_checkpoint(self):
    """Save model, trainer state, optimizer and scheduler into a Ray Tune checkpoint directory.

    A no-op unless `self.use_tune_checkpoints` is set. The state, optimizer and scheduler
    files are only written when `args.should_save` is true.
    """
    from ray import tune

    if not self.use_tune_checkpoints:
        return

    step = self.state.global_step
    with tune.checkpoint_dir(step=step) as checkpoint_dir:
        output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{step}")
        self.save_model(output_dir, _internal_call=True)
        if not self.args.should_save:
            return
        self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
        torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
        torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
def call_model_init(self, trial=None):
    """Instantiate a fresh model through `self.model_init`.

    `model_init` is called without arguments or with `trial`, depending on how many
    arguments it declares.

    Raises:
        RuntimeError: If `model_init` takes more than one argument or returns `None`.
    """
    argcount = number_of_arguments(self.model_init)
    if argcount not in (0, 1):
        raise RuntimeError("model_init should have 0 or 1 argument.")

    model = self.model_init() if argcount == 0 else self.model_init(trial)
    if model is None:
        raise RuntimeError("model_init should not return None.")
    return model


def torch_jit_model_eval(self, model, dataloader, training=False):
    """Try to replace `model` by a traced and frozen TorchScript copy for evaluation.

    The model is returned unchanged when `training` is true, when there is no dataloader
    to draw an example batch from, or when tracing fails (a warning is logged). On
    success the CPU/CUDA amp flags of the trainer are switched off.
    """
    if training:
        return model
    if dataloader is None:
        logger.warning("failed to use PyTorch jit mode due to current dataloader is none.")
        return model

    example_batch = self._prepare_inputs(next(iter(dataloader)))
    try:
        traced = copy.copy(model)
        traced.eval()
        # remove mixed precision hooks from the model
        original_forward = traced.__dict__.pop("_original_forward", None)
        if original_forward:
            traced.forward = original_forward

        with self.accelerator.autocast(cache_enabled=False), torch.no_grad():
            torch_release = version.parse(version.parse(torch.__version__).base_version)
            if torch_release >= version.parse("2.0.0"):
                if isinstance(example_batch, dict):
                    traced = torch.jit.trace(traced, example_kwarg_inputs=example_batch, strict=False)
                else:
                    traced = torch.jit.trace(
                        traced,
                        example_kwarg_inputs={key: example_batch[key] for key in example_batch},
                        strict=False,
                    )
            else:
                # Older torch traces with positional dummy inputs shaped like the batch entries.
                jit_inputs = tuple(torch.ones_like(example_batch[key]) for key in example_batch)
                traced = torch.jit.trace(traced, jit_inputs, strict=False)

        traced = torch.jit.freeze(traced)
        # Two calls on the example batch before the traced model is handed back.
        with torch.no_grad():
            traced(**example_batch)
            traced(**example_batch)

        model = traced
        self.use_cpu_amp = False
        self.use_cuda_amp = False
    except (RuntimeError, TypeError, ValueError, NameError, IndexError) as e:
        logger.warning(f"failed to use PyTorch jit mode due to: {e}.")

    return model
def ipex_optimize_model(self, model, training=False, dtype=torch.float32):
    """Run `ipex.optimize` (Intel Extension for PyTorch) on `model`.

    For training the trainer's optimizer is optimised together with the model and stored
    back on `self.optimizer`. For evaluation the model is put in eval mode and bf16 is
    used when `args.bf16_full_eval` is set outside of a training run.

    Raises:
        ImportError: If IPEX is not available.
    """
    if not is_ipex_available():
        raise ImportError(
            "Using IPEX but IPEX is not installed or IPEX's version does not match current PyTorch, please refer"
            " to https://github.com/intel/intel-extension-for-pytorch."
        )

    import intel_extension_for_pytorch as ipex

    if training:
        if not model.training:
            model.train()
        model, self.optimizer = ipex.optimize(
            model, dtype=dtype, optimizer=self.optimizer, inplace=True, level="O1"
        )
        return model

    model.eval()
    if not self.is_in_train and self.args.bf16_full_eval:
        dtype = torch.bfloat16
    # conv_bn_folding is disabled as it fails in symbolic tracing, resulting in ipex warnings
    return ipex.optimize(model, dtype=dtype, level="O1", conv_bn_folding=False, inplace=not self.is_in_train)
def _wrap_model(self, model, training=True, dataloader=None):
    """Wrap `model` for the configured execution backend and return the wrapped model.

    The steps run in a fixed order: IPEX optimisation, SageMaker model parallelism, an
    "already wrapped" check, apex, `nn.DataParallel`, jit tracing for evaluation, and
    finally (training only) one of the distributed strategies.

    Args:
        model: The model to wrap.
        training: Whether the model is wrapped for training; the distributed wrappers are
            skipped when false.
        dataloader: Only forwarded to `torch_jit_model_eval` when `args.jit_mode_eval` is set.
    """
    if self.args.use_ipex:
        dtype = torch.bfloat16 if self.use_cpu_amp else torch.float32
        model = self.ipex_optimize_model(model, training, dtype=dtype)

    if is_sagemaker_mp_enabled():
        # Wrapping the base model twice in a DistributedModel will raise an error.
        if isinstance(self.model_wrapped, smp.model.DistributedModel):
            return self.model_wrapped
        return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps)

    # train/eval could be run multiple-times - if already wrapped, don't re-wrap it again
    if unwrap_model(model) is not model:
        return model

    # Mixed precision training with apex (torch < 1.6)
    if self.use_apex and training:
        model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)

    # Multi-gpu training (should be after apex fp16 initialization) / 8bit models does not support DDP
    if self.args.n_gpu > 1 and not getattr(model, "is_loaded_in_8bit", False):
        model = nn.DataParallel(model)

    if self.args.jit_mode_eval:
        # The time spent tracing is kept on the trainer as `jit_compilation_time`.
        start_time = time.time()
        model = self.torch_jit_model_eval(model, dataloader, training)
        self.jit_compilation_time = round(time.time() - start_time, 4)

    # Note: in torch.distributed mode, there's no point in wrapping the model
    # inside a DistributedDataParallel as we'll be under `no_grad` anyways.
    if not training:
        return model

    # Distributed training (should be after apex fp16 initialization)
    if self.sharded_ddp is not None:
        # Sharded DDP!
        if self.sharded_ddp == ShardedDDPOption.SIMPLE:
            model = ShardedDDP(model, self.optimizer)
        else:
            mixed_precision = self.args.fp16 or self.args.bf16
            cpu_offload = ShardedDDPOption.OFFLOAD in self.args.sharded_ddp
            zero_3 = self.sharded_ddp == ShardedDDPOption.ZERO_DP_3
            # XXX: Breaking the self.model convention but I see no way around it for now.
            if ShardedDDPOption.AUTO_WRAP in self.args.sharded_ddp:
                model = auto_wrap(model)
            self.model = model = FullyShardedDDP(
                model,
                mixed_precision=mixed_precision,
                reshard_after_forward=zero_3,
                cpu_offload=cpu_offload,
            ).to(self.args.device)
    # Distributed training using PyTorch FSDP
    elif self.fsdp is not None and self.args.fsdp_config["xla"]:
        try:
            from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
            from torch_xla.distributed.fsdp import checkpoint_module
            from torch_xla.distributed.fsdp.wrap import (
                size_based_auto_wrap_policy,
                transformer_auto_wrap_policy,
            )
        except ImportError:
            raise ImportError("Missing XLA FSDP related module; please make sure to use torch-xla >= 2.0.")
        auto_wrap_policy = None
        auto_wrapper_callable = None
        # A size-based policy takes precedence over the transformer-layer-class policy.
        if self.args.fsdp_config["fsdp_min_num_params"] > 0:
            auto_wrap_policy = functools.partial(
                size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"]
            )
        elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
            transformer_cls_to_wrap = set()
            for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]:
                transformer_cls = get_module_class_from_name(model, layer_class)
                if transformer_cls is None:
                    raise Exception("Could not find the transformer layer class to wrap in the model.")
                else:
                    transformer_cls_to_wrap.add(transformer_cls)
            auto_wrap_policy = functools.partial(
                transformer_auto_wrap_policy,
                # Transformer layer class to wrap
                transformer_layer_cls=transformer_cls_to_wrap,
            )
        fsdp_kwargs = self.args.xla_fsdp_config
        if self.args.fsdp_config["xla_fsdp_grad_ckpt"]:
            # Apply gradient checkpointing to auto-wrapped sub-modules if specified
            def auto_wrapper_callable(m, *args, **kwargs):
                return FSDP(checkpoint_module(m), *args, **kwargs)

        # Wrap the base model with an outer FSDP wrapper
        self.model = model = FSDP(
            model,
            auto_wrap_policy=auto_wrap_policy,
            auto_wrapper_callable=auto_wrapper_callable,
            **fsdp_kwargs,
        )

        # Patch `xm.optimizer_step` should not reduce gradients in this case,
        # as FSDP does not need gradient reduction over sharded parameters.
        # NOTE: this assignment replaces the function on the `xm` module itself.
        def patched_optimizer_step(optimizer, barrier=False, optimizer_args={}):
            loss = optimizer.step(**optimizer_args)
            if barrier:
                xm.mark_step()
            return loss

        xm.optimizer_step = patched_optimizer_step
    elif is_sagemaker_dp_enabled():
        model = nn.parallel.DistributedDataParallel(
            model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))]
        )
    elif self.args.parallel_mode == ParallelMode.DISTRIBUTED:
        if is_torch_neuroncore_available():
            return model
        kwargs = {}
        if self.args.ddp_find_unused_parameters is not None:
            kwargs["find_unused_parameters"] = self.args.ddp_find_unused_parameters
        elif isinstance(model, PreTrainedModel):
            # find_unused_parameters breaks checkpointing as per
            # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
            kwargs["find_unused_parameters"] = not model.is_gradient_checkpointing
        else:
            kwargs["find_unused_parameters"] = True

        if self.args.ddp_bucket_cap_mb is not None:
            kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb

        if self.args.ddp_broadcast_buffers is not None:
            kwargs["broadcast_buffers"] = self.args.ddp_broadcast_buffers

        # The model itself is not wrapped in this branch; the DDP options are only
        # recorded on the accelerator.
        self.accelerator.ddp_handler = DistributedDataParallelKwargs(**kwargs)

    return model
ignore_keys_for_eval (`List[str]`, *optional*) A list of keys in the output of your model (if it is a dictionary) that should be ignored when gathering predictions for evaluation during the training. kwargs (`Dict[str, Any]`, *optional*): Additional keyword arguments used to hide deprecated arguments """ if resume_from_checkpoint is False: resume_from_checkpoint = None # memory metrics - must set up as early as possible self._memory_tracker.start() args = self.args self.is_in_train = True # do_train is not a reliable argument, as it might not be set and .train() still called, so # the following is a workaround: if (args.fp16_full_eval or args.bf16_full_eval) and not args.do_train: self._move_model_to_device(self.model, args.device) if "model_path" in kwargs: resume_from_checkpoint = kwargs.pop("model_path") warnings.warn( "`model_path` is deprecated and will be removed in a future version. Use `resume_from_checkpoint` " "instead.", FutureWarning, ) if len(kwargs) > 0: raise TypeError(f"train() received got unexpected keyword arguments: {', '.join(list(kwargs.keys()))}.") # This might change the seed so needs to run first. self._hp_search_setup(trial) self._train_batch_size = self.args.train_batch_size # Model re-init model_reloaded = False if self.model_init is not None: # Seed must be set before instantiating the model when using model_init. 
enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed) self.model = self.call_model_init(trial) model_reloaded = True # Reinitializes optimizer and scheduler self.optimizer, self.lr_scheduler = None, None # Load potential model checkpoint if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint: resume_from_checkpoint = get_last_checkpoint(args.output_dir) if resume_from_checkpoint is None: raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})") if resume_from_checkpoint is not None and not is_sagemaker_mp_enabled() and not self.is_deepspeed_enabled: self._load_from_checkpoint(resume_from_checkpoint) # If model was re-initialized, put it on the right device and update self.model_wrapped if model_reloaded: if self.place_model_on_device: self._move_model_to_device(self.model, args.device) self.model_wrapped = self.model inner_training_loop = find_executable_batch_size( self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size ) return inner_training_loop( args=args, resume_from_checkpoint=resume_from_checkpoint, trial=trial, ignore_keys_for_eval=ignore_keys_for_eval, ) def _inner_training_loop( self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None ): self.accelerator.free_memory() self._train_batch_size = batch_size logger.debug(f"Currently training with a batch size of: {self._train_batch_size}") # Data loader and number of training steps train_dataloader = self.get_train_dataloaderd2() # Setting up training control variables: # number of training epochs: num_train_epochs # number of training steps per epoch: num_update_steps_per_epoch # total number of training steps to execute: max_steps total_train_batch_size = self._train_batch_size * args.gradient_accumulation_steps * args.world_size len_dataloader = None if args.max_steps<0: args.max_steps=100 if has_length(train_dataloader): len_dataloader = 
len(train_dataloader) num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1) num_examples = self.num_examples(train_dataloader) if args.max_steps > 0: max_steps = args.max_steps num_train_epochs = args.max_steps // num_update_steps_per_epoch + int( args.max_steps % num_update_steps_per_epoch > 0 ) # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's # the best we can do. num_train_samples = args.max_steps * total_train_batch_size else: max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch) num_train_epochs = math.ceil(args.num_train_epochs) num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs elif args.max_steps > 0: # Rely on max_steps when dataloader does not have a working size max_steps = args.max_steps # Setting a very large number of epochs so we go as many times as necessary over the iterator. num_train_epochs = sys.maxsize num_update_steps_per_epoch = max_steps num_examples = total_train_batch_size * args.max_steps num_train_samples = args.max_steps * total_train_batch_size else: raise ValueError( "args.max_steps must be set to a positive value if dataloader does not have a length, was" f" {args.max_steps}" ) # Compute absolute values for logging, eval, and save if given as ratio if args.logging_steps and args.logging_steps < 1: args.logging_steps = math.ceil(max_steps * args.logging_steps) if args.eval_steps and args.eval_steps < 1: args.eval_steps = math.ceil(max_steps * args.eval_steps) if args.save_steps and args.save_steps < 1: args.save_steps = math.ceil(max_steps * args.save_steps) if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug: if self.args.n_gpu > 1: # nn.DataParallel(model) replicates the model, creating new variables and module # references registered here no longer work on other gpus, breaking the module raise ValueError( "Currently --debug underflow_overflow 
is not supported under DP. Please use DDP" " (torch.distributed.launch)." ) else: debug_overflow = DebugUnderflowOverflow(self.model) # noqa delay_optimizer_creation = ( self.sharded_ddp is not None and self.sharded_ddp != ShardedDDPOption.SIMPLE or is_sagemaker_mp_enabled() or self.fsdp is not None ) # We need to reset the scheduler, as its parameters may be different on subsequent calls if self._created_lr_scheduler: self.lr_scheduler = None self._created_lr_scheduler = False if self.is_deepspeed_enabled: self.optimizer, self.lr_scheduler = deepspeed_init(self, num_training_steps=max_steps) if not delay_optimizer_creation: self.create_optimizer_and_scheduler(num_training_steps=max_steps) self.state = TrainerState() self.state.is_hyper_param_search = trial is not None # Activate gradient checkpointing if needed if args.gradient_checkpointing: self.model.gradient_checkpointing_enable() model = self._wrap_model(self.model_wrapped) if is_sagemaker_mp_enabled() and resume_from_checkpoint is not None: self._load_from_checkpoint(resume_from_checkpoint, model) # as the model is wrapped, don't use `accelerator.prepare` # this is for unhandled cases such as # Fairscale Sharded DDP, FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX use_accelerator_prepare = True if model is self.model else False if delay_optimizer_creation: self.create_optimizer_and_scheduler(num_training_steps=max_steps) # prepare using `accelerator` prepare if use_accelerator_prepare: self.model.train() if hasattr(self.lr_scheduler, "step"): if self.use_apex: model = self.accelerator.prepare(self.model) else: model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer) else: # to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config. 
model, self.optimizer, self.lr_scheduler = self.accelerator.prepare( self.model, self.optimizer, self.lr_scheduler ) if self.is_fsdp_enabled: self.model = model # for the rest of this function `model` is the outside model, whether it was wrapped or not if model is not self.model: self.model_wrapped = model # backward compatibility if self.is_deepspeed_enabled: self.deepspeed = self.model_wrapped # deepspeed ckpt loading if resume_from_checkpoint is not None and self.is_deepspeed_enabled: deepspeed_load_checkpoint(self.model_wrapped, resume_from_checkpoint, load_optimizer_states=self.args.load_optimizer_states, load_lr_scheduler_states=self.args.load_lr_scheduler_states) # Check if saved optimizer or scheduler states exist self._load_optimizer_and_scheduler(resume_from_checkpoint) # important: at this point: # self.model is the Transformers Model # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), etc. # Train! logger.info("***** Running training *****") logger.info(f" Num examples = {num_examples:,}") logger.info(f" Num Epochs = {num_train_epochs:,}") logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}") if self.args.per_device_train_batch_size != self._train_batch_size: logger.info(f" Training with DataParallel so batch size has been adjusted to: {self._train_batch_size:,}") logger.info(f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_train_batch_size:,}") logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") logger.info(f" Total optimization steps = {max_steps:,}") logger.info(f" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}") self.state.epoch = 0 start_time = time.time() epochs_trained = 0 steps_trained_in_current_epoch = 0 steps_trained_progress_bar = None # Check if continuing training from a checkpoint if resume_from_checkpoint is not None and os.path.isfile( os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME) ): self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)) epochs_trained = self.state.global_step // num_update_steps_per_epoch if not args.ignore_data_skip: steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch) steps_trained_in_current_epoch *= args.gradient_accumulation_steps else: steps_trained_in_current_epoch = 0 logger.info(" Continuing training from checkpoint, will skip to saved global_step") logger.info(f" Continuing training from epoch {epochs_trained}") logger.info(f" Continuing training from global step {self.state.global_step}") if not args.ignore_data_skip: logger.info( f" Will skip the first {epochs_trained} epochs then the first" f" {steps_trained_in_current_epoch} batches in the first epoch." ) # Update the references self.callback_handler.model = self.model self.callback_handler.optimizer = self.optimizer self.callback_handler.lr_scheduler = self.lr_scheduler self.callback_handler.train_dataloader = train_dataloader if self.hp_name is not None and self._trial is not None: # use self._trial because the SigOpt/Optuna hpo only call `_hp_search_setup(trial)` instead of passing trial # parameter to Train when using DDP. 
self.state.trial_name = self.hp_name(self._trial) if trial is not None: assignments = trial.assignments if self.hp_search_backend == HPSearchBackend.SIGOPT else trial self.state.trial_params = hp_params(assignments) else: self.state.trial_params = None # This should be the same if the state has been saved but in case the training arguments changed, it's safer # to set this after the load. self.state.max_steps = max_steps self.state.num_train_epochs = num_train_epochs self.state.is_local_process_zero = self.is_local_process_zero() self.state.is_world_process_zero = self.is_world_process_zero() # tr_loss is a tensor to avoid synchronization of TPUs through .item() tr_loss_ = torch.tensor(0.0).to(args.device) # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses self._total_loss_scalar = 0.0 self._globalstep_last_logged = self.state.global_step model.zero_grad() self.control = self.callback_handler.on_train_begin(args, self.state, self.control) # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point. if not args.ignore_data_skip: for epoch in range(epochs_trained): for _ in train_dataloader: break total_batched_samples = 0 tr_loss = dict() for epoch in range(epochs_trained, num_train_epochs): epoch_iterator = train_dataloader # Reset the past mems state at the beginning of each epoch if necessary. 
if args.past_index >= 0: self._past = None steps_in_epoch = ( len(epoch_iterator) if len_dataloader is not None else args.max_steps * args.gradient_accumulation_steps ) self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control) if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0: self._load_rng_state(resume_from_checkpoint) rng_to_sync = False steps_skipped = 0 # if steps_trained_in_current_epoch > 0: # epoch_iterator = skip_first_batches(epoch_iterator, steps_trained_in_current_epoch) # steps_skipped = steps_trained_in_current_epoch # steps_trained_in_current_epoch = 0 # rng_to_sync = True step = -1 for step, inputs in enumerate(epoch_iterator): total_batched_samples += 1 if rng_to_sync: self._load_rng_state(resume_from_checkpoint) rng_to_sync = False # Skip past any already trained steps if resuming training if steps_trained_in_current_epoch > 0: steps_trained_in_current_epoch =0 if steps_trained_progress_bar is not None: steps_trained_progress_bar.update(steps_trained_in_current_epoch) if steps_trained_in_current_epoch == 0: self._load_rng_state(resume_from_checkpoint) continue elif steps_trained_progress_bar is not None: steps_trained_progress_bar.close() steps_trained_progress_bar = None if step % args.gradient_accumulation_steps == 0: self.control = self.callback_handler.on_step_begin(args, self.state, self.control) with self.accelerator.accumulate(model): tr_loss_step = self.training_step(model, inputs) if len(tr_loss)==0: tr_loss={k:tr_loss_.clone() for k in tr_loss_step.keys()} for k, loss in tr_loss.items(): if ( args.logging_nan_inf_filter and not is_torch_tpu_available() and (torch.isnan(tr_loss_step[k]) or torch.isinf(tr_loss_step[k])) ): # if loss is nan or inf simply add the average of previous logged losses tr_loss[k] += loss / (1 + self.state.global_step - self._globalstep_last_logged) else: tr_loss[k] += tr_loss_step[k] # if ( # args.logging_nan_inf_filter # and not 
is_torch_tpu_available() # and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step)) # ): # # if loss is nan or inf simply add the average of previous logged losses # tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged) # else: # tr_loss += tr_loss_step self.current_flos += float(self.floating_point_ops(inputs)) is_last_step_and_steps_less_than_grad_acc = ( steps_in_epoch <= args.gradient_accumulation_steps and (step + 1) == steps_in_epoch ) if ( total_batched_samples % args.gradient_accumulation_steps == 0 or # last step in epoch but step is always smaller than gradient_accumulation_steps is_last_step_and_steps_less_than_grad_acc ): # the `or` condition of `is_last_step_and_steps_less_than_grad_acc` is not covered # in accelerate. So, explicitly enable sync gradients to True in that case. if is_last_step_and_steps_less_than_grad_acc or ( version.parse(accelerate_version) <= version.parse("0.20.3") ): self.accelerator.gradient_state._set_sync_gradients(True) # Gradient clipping if args.max_grad_norm is not None and args.max_grad_norm > 0: # deepspeed does its own clipping if self.do_grad_scaling: # Reduce gradients first for XLA if is_torch_tpu_available(): gradients = xm._fetch_gradients(self.optimizer) xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size()) # AMP: gradients need unscaling self.scaler.unscale_(self.optimizer) if is_sagemaker_mp_enabled() and args.fp16: self.optimizer.clip_master_grads(args.max_grad_norm) elif hasattr(self.optimizer, "clip_grad_norm"): # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping self.optimizer.clip_grad_norm(args.max_grad_norm) elif hasattr(model, "clip_grad_norm_"): # Some models (like FullyShardedDDP) have a specific way to do gradient clipping model.clip_grad_norm_(args.max_grad_norm) elif self.use_apex: # Revert to normal clipping otherwise, handling Apex or full precision nn.utils.clip_grad_norm_( amp.master_params(self.optimizer), 
args.max_grad_norm, ) else: self.accelerator.clip_grad_norm_( model.parameters(), args.max_grad_norm, ) # Optimizer step optimizer_was_run = True if is_torch_tpu_available(): if self.do_grad_scaling: self.scaler.step(self.optimizer) self.scaler.update() else: # tpu-comment: accelerate wrapped optimizers call xm.optimizer_step self.optimizer.step() elif self.do_grad_scaling: scale_before = self.scaler.get_scale() self.scaler.step(self.optimizer) self.scaler.update() scale_after = self.scaler.get_scale() optimizer_was_run = scale_before <= scale_after else: self.optimizer.step() optimizer_was_run = not self.accelerator.optimizer_step_was_skipped if optimizer_was_run: # Delay optimizer scheduling until metrics are generated if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): self.lr_scheduler.step() model.zero_grad() self.state.global_step += 1 self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch self.control = self.callback_handler.on_step_end(args, self.state, self.control) self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval) else: self.control = self.callback_handler.on_substep_end(args, self.state, self.control) if self.control.should_epoch_stop or self.control.should_training_stop: break if step < 0: logger.warning( "There seems to be not a single sample in your epoch_iterator, stopping training at step" f" {self.state.global_step}! This is expected if you're using an IterableDataset and set" f" num_steps ({max_steps}) higher than the number of available samples." ) self.control.should_training_stop = True self.control = self.callback_handler.on_epoch_end(args, self.state, self.control) self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval) if DebugOption.TPU_METRICS_DEBUG in self.args.debug: if is_torch_tpu_available(): # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) 
xm.master_print(met.metrics_report()) else: logger.warning( "You enabled PyTorch/XLA debug metrics but you don't have a TPU " "configured. Check your training configuration if this is unexpected." ) if self.control.should_training_stop: break if args.past_index and hasattr(self, "_past"): # Clean the state at the end of training delattr(self, "_past") logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n") if args.load_best_model_at_end and self.state.best_model_checkpoint is not None: # Wait for everyone to get here so we are sur the model has been saved by process 0. if is_torch_tpu_available(): xm.rendezvous("load_best_model_at_end") elif args.parallel_mode == ParallelMode.DISTRIBUTED: dist.barrier() elif is_sagemaker_mp_enabled(): smp.barrier() self._load_best_model() # add remaining tr_loss # self._total_loss_scalar += tr_loss.item() self._total_loss_scalar += tr_loss['loss_total'].item() train_loss = self._total_loss_scalar / self.state.global_step metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps) self.store_flos() metrics["total_flos"] = self.state.total_flos metrics["train_loss"] = train_loss self.is_in_train = False self._memory_tracker.stop_and_update_metrics(metrics) self.log(metrics) run_dir = self._get_output_dir(trial) checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir) # Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint and process allowed to save. 
if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1: for checkpoint in checkpoints_sorted: if checkpoint != self.state.best_model_checkpoint: logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit") shutil.rmtree(checkpoint) self.control = self.callback_handler.on_train_end(args, self.state, self.control) return TrainOutput(self.state.global_step, train_loss, metrics) def _get_output_dir(self, trial): if self.hp_search_backend is not None and trial is not None: if self.hp_search_backend == HPSearchBackend.OPTUNA: run_id = trial.number elif self.hp_search_backend == HPSearchBackend.RAY: from ray import tune run_id = tune.get_trial_id() elif self.hp_search_backend == HPSearchBackend.SIGOPT: run_id = trial.id elif self.hp_search_backend == HPSearchBackend.WANDB: import wandb run_id = wandb.run.id run_name = self.hp_name(trial) if self.hp_name is not None else f"run-{run_id}" run_dir = os.path.join(self.args.output_dir, run_name) else: run_dir = self.args.output_dir return run_dir def _load_from_checkpoint(self, resume_from_checkpoint, model=None): if model is None: model = self.model config_file = os.path.join(resume_from_checkpoint, CONFIG_NAME) adapter_weights_file = os.path.join(resume_from_checkpoint, ADAPTER_WEIGHTS_NAME) adapter_safe_weights_file = os.path.join(resume_from_checkpoint, ADAPTER_SAFE_WEIGHTS_NAME) weights_file = os.path.join(resume_from_checkpoint, WEIGHTS_NAME) weights_index_file = os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME) safe_weights_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_NAME) safe_weights_index_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_INDEX_NAME) if not any( os.path.isfile(f) for f in [ weights_file, safe_weights_file, weights_index_file, safe_weights_index_file, adapter_weights_file, adapter_safe_weights_file, ] ): raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}") 
logger.info(f"Loading model from {resume_from_checkpoint}.") if os.path.isfile(config_file): config = PretrainedConfig.from_json_file(config_file) checkpoint_version = config.transformers_version if checkpoint_version is not None and checkpoint_version != __version__: logger.warning( f"You are resuming training from a checkpoint trained with {checkpoint_version} of " f"Transformers but your current version is {__version__}. This is not recommended and could " "yield to errors or unwanted behaviors." ) if os.path.isfile(weights_file) or os.path.isfile(safe_weights_file): # If the model is on the GPU, it still works! if is_sagemaker_mp_enabled(): if os.path.isfile(os.path.join(resume_from_checkpoint, "user_content.pt")): # If the 'user_content.pt' file exists, load with the new smp api. # Checkpoint must have been saved with the new smp api. smp.resume_from_checkpoint( path=resume_from_checkpoint, tag=WEIGHTS_NAME, partial=False, load_optimizer=False ) else: # If the 'user_content.pt' file does NOT exist, load with the old smp api. # Checkpoint must have been saved with the old smp api. if hasattr(self.args, "fp16") and self.args.fp16 is True: logger.warning( "Enabling FP16 and loading from smp < 1.10 checkpoint together is not suppported." ) state_dict = torch.load(weights_file, map_location="cpu") # Required for smp to not auto-translate state_dict from hf to smp (is already smp). state_dict["_smp_is_partial"] = False load_result = model.load_state_dict(state_dict, strict=True) # release memory del state_dict elif self.is_fsdp_enabled: load_fsdp_model(self.accelerator.state.fsdp_plugin, self.accelerator, model, resume_from_checkpoint) else: # We load the model state dict on the CPU to avoid an OOM error. 
if self.args.save_safetensors and os.path.isfile(safe_weights_file): state_dict = safetensors.torch.load_file(safe_weights_file, device="cpu") else: state_dict = torch.load(weights_file, map_location="cpu") # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963 # which takes *args instead of **kwargs load_result = model.load_state_dict(state_dict, False) # release memory del state_dict self._issue_warnings_after_load(load_result) # Load adapters following PR # 24096 elif is_peft_available() and isinstance(model, PeftModel): # If train a model using PEFT & LoRA, assume that adapter have been saved properly. if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"): if os.path.exists(resume_from_checkpoint): model.load_adapter(resume_from_checkpoint, model.active_adapter, is_trainable=True) else: logger.warning( "The intermediate checkpoints of PEFT may not be saved correctly, " f"consider using a custom callback to save {ADAPTER_WEIGHTS_NAME} in corresponding saving folders. 
def _load_best_model(self):
    """
    Load the weights found in `self.state.best_model_checkpoint` into the current model.

    The loader is chosen from the environment: DeepSpeed, SageMaker model parallel (new or old
    checkpoint API), accelerate FSDP, a PEFT adapter, or a plain `state_dict`. When none of the
    single-file weights exist, a sharded checkpoint is tried (detected through its index file).
    If nothing can be located, a warning is logged and the model is left untouched.

    Missing/unexpected key warnings are only issued when a loader actually produced a load result
    and SageMaker model parallel is not enabled.
    """
    best_checkpoint = self.state.best_model_checkpoint
    logger.info(f"Loading best model from {best_checkpoint} (score: {self.state.best_metric}).")
    best_model_path = os.path.join(best_checkpoint, WEIGHTS_NAME)
    best_safe_model_path = os.path.join(best_checkpoint, SAFE_WEIGHTS_NAME)
    best_adapter_model_path = os.path.join(best_checkpoint, ADAPTER_WEIGHTS_NAME)
    best_safe_adapter_model_path = os.path.join(best_checkpoint, ADAPTER_SAFE_WEIGHTS_NAME)

    model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
    single_file_candidates = (
        best_model_path,
        best_safe_model_path,
        best_adapter_model_path,
        best_safe_adapter_model_path,
    )
    if any(os.path.exists(path) for path in single_file_candidates):
        if self.is_deepspeed_enabled:
            deepspeed_load_checkpoint(self.model_wrapped, best_checkpoint)
        else:
            # Stays None on every branch that does not yield a load result (new smp API, FSDP,
            # failed PEFT load), so the warning step below can never touch an unbound name.
            load_result = None
            if is_sagemaker_mp_enabled():
                if os.path.isfile(os.path.join(best_checkpoint, "user_content.pt")):
                    # If the 'user_content.pt' file exists, load with the new smp api.
                    # Checkpoint must have been saved with the new smp api.
                    smp.resume_from_checkpoint(
                        path=best_checkpoint,
                        tag=WEIGHTS_NAME,
                        partial=False,
                        load_optimizer=False,
                    )
                else:
                    # If the 'user_content.pt' file does NOT exist, load with the old smp api.
                    # Checkpoint must have been saved with the old smp api.
                    if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
                        state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
                    else:
                        state_dict = torch.load(best_model_path, map_location="cpu")

                    state_dict["_smp_is_partial"] = False
                    load_result = model.load_state_dict(state_dict, strict=True)
            elif self.is_fsdp_enabled:
                load_fsdp_model(self.accelerator.state.fsdp_plugin, self.accelerator, model, best_checkpoint)
            elif is_peft_available() and isinstance(model, PeftModel):
                # If train a model using PEFT & LoRA, assume that adapter have been saved properly.
                if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"):
                    if os.path.exists(best_adapter_model_path) or os.path.exists(best_safe_adapter_model_path):
                        model.load_adapter(best_checkpoint, model.active_adapter)
                        # Load_adapter has no return value present, modify it when appropriate.
                        from torch.nn.modules.module import _IncompatibleKeys

                        load_result = _IncompatibleKeys([], [])
                    else:
                        logger.warning(
                            "The intermediate checkpoints of PEFT may not be saved correctly, "
                            f"consider using a custom callback to save {ADAPTER_WEIGHTS_NAME} in corresponding saving folders. "
                            "Check some examples here: https://github.com/huggingface/peft/issues/96"
                        )
                else:
                    logger.warning("Could not load adapter model, make sure to have `peft>=0.3.0` installed")
            else:
                # We load the model state dict on the CPU to avoid an OOM error.
                if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
                    state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
                else:
                    state_dict = torch.load(best_model_path, map_location="cpu")

                # If the model is on the GPU, it still works!
                # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
                # which takes *args instead of **kwargs
                load_result = model.load_state_dict(state_dict, False)

            if load_result is not None and not is_sagemaker_mp_enabled():
                self._issue_warnings_after_load(load_result)
    elif os.path.exists(os.path.join(best_checkpoint, WEIGHTS_INDEX_NAME)):
        load_result = load_sharded_checkpoint(model, best_checkpoint, strict=is_sagemaker_mp_enabled())
        if not is_sagemaker_mp_enabled():
            self._issue_warnings_after_load(load_result)
    else:
        logger.warning(
            f"Could not locate the best model at {best_model_path}, if you are running a distributed training "
            "on multiple nodes, you should activate `--save_on_each_node`."
        )
) def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval): if self.control.should_log: if is_torch_tpu_available(): xm.mark_step() logs: Dict[str, float] = {} # all_gather + mean() to get average loss over all processes # tr_loss_scalar = self._nested_gather(tr_loss).mean().item() tr_loss_scalar = {k: self._nested_gather(tr_loss[k]).mean().item() for k in tr_loss.keys()} # reset tr_loss to zero for _,loss in tr_loss.items(): loss -= loss # tr_loss -= tr_loss # logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4) for k,loss in tr_loss_scalar.items(): logs[k]=round(loss / (self.state.global_step - self._globalstep_last_logged), 4) logs["learning_rate"] = self._get_learning_rate() self._total_loss_scalar += tr_loss_scalar['loss_total'] self._globalstep_last_logged = self.state.global_step self.store_flos() self.log(logs) metrics = None if self.control.should_evaluate: if isinstance(self.eval_dataset, dict): metrics = {} for eval_dataset_name, eval_dataset in self.eval_dataset.items(): dataset_metrics = self.evaluate( eval_dataset=eval_dataset, ignore_keys=ignore_keys_for_eval, metric_key_prefix=f"eval_{eval_dataset_name}", ) metrics.update(dataset_metrics) else: metrics = self.evaluate(ignore_keys=ignore_keys_for_eval) self._report_to_hp_search(trial, self.state.global_step, metrics) # Run delayed LR scheduler now that metrics are populated if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): metric_to_check = self.args.metric_for_best_model if not metric_to_check.startswith("eval_"): metric_to_check = f"eval_{metric_to_check}" self.lr_scheduler.step(metrics[metric_to_check]) if self.control.should_save: self._save_checkpoint(model, trial, metrics=metrics) self.control = self.callback_handler.on_save(self.args, self.state, self.control) def _load_rng_state(self, checkpoint): # Load RNG states from `checkpoint` if checkpoint is None: return if self.args.world_size > 1: 
def _save_checkpoint(self, model, trial, metrics=None):
    """
    Write a training checkpoint into `<run_dir>/<PREFIX_CHECKPOINT_DIR>-<global_step>`.

    The checkpoint consists of the model (via `save_model`), the optimizer / LR-scheduler / grad-scaler states
    (through the backend-specific branches below), the trainer state JSON and this process' RNG states. When
    `metrics` is given, the best-metric bookkeeping on `self.state` is refreshed. Older checkpoints are rotated
    at the end.

    Args:
        model: Not read by this method (its only mention is in a commented-out assertion); what gets saved is
            `self.model` / `self.model_wrapped`.
        trial: Hyperparameter-search trial. Used to decide whether `store_flos` is called and passed to
            `_get_output_dir` to resolve the run directory.
        metrics: Optional evaluation metrics. Together with `args.metric_for_best_model` they drive the update
            of `state.best_metric` and `state.best_model_checkpoint`.
    """
    # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
    # want to save except FullyShardedDDP.
    # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"

    # Save model checkpoint
    checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"

    # Outside of a hyperparameter search, fold the pending FLOs into the state before it is serialized.
    if self.hp_search_backend is None and trial is None:
        self.store_flos()

    run_dir = self._get_output_dir(trial=trial)
    output_dir = os.path.join(run_dir, checkpoint_folder)
    self.save_model(output_dir, _internal_call=True)
    if self.is_deepspeed_enabled:
        # under zero3 model file itself doesn't get saved since it's bogus! Unless deepspeed
        # config `stage3_gather_16bit_weights_on_model_save` is True
        self.model_wrapped.save_checkpoint(output_dir)

    # Save optimizer and scheduler
    if self.sharded_ddp == ShardedDDPOption.SIMPLE:
        self.optimizer.consolidate_state_dict()

    if self.fsdp or self.is_fsdp_enabled:
        if self.is_fsdp_enabled:
            save_fsdp_optimizer(
                self.accelerator.state.fsdp_plugin, self.accelerator, self.optimizer, self.model, output_dir
            )
        else:
            # FSDP has a different interface for saving optimizer states.
            # Needs to be called on all ranks to gather all states.
            # full_optim_state_dict will be deprecated after Pytorch 2.2!
            # `full_osd` is written to disk further down, in the generic (non-TPU, non-SageMaker) branch.
            full_osd = self.model.__class__.full_optim_state_dict(self.model, self.optimizer)

    if is_torch_tpu_available():
        xm.rendezvous("saving_optimizer_states")
        xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
        with warnings.catch_warnings(record=True) as caught_warnings:
            xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
            reissue_pt_warnings(caught_warnings)
    elif is_sagemaker_mp_enabled():
        opt_state_dict = self.optimizer.local_state_dict(gather_if_shard=False)
        smp.barrier()
        if smp.rdp_rank() == 0 or smp.state.cfg.shard_optimizer_state:
            smp.save(
                opt_state_dict,
                os.path.join(output_dir, OPTIMIZER_NAME),
                partial=True,
                v3=smp.state.cfg.shard_optimizer_state,
            )
        if self.args.should_save:
            with warnings.catch_warnings(record=True) as caught_warnings:
                torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
            reissue_pt_warnings(caught_warnings)
            if self.do_grad_scaling:
                torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
    elif self.args.should_save and not self.is_deepspeed_enabled:
        # deepspeed.save_checkpoint above saves model/optim/sched
        if self.fsdp and not self.is_fsdp_enabled:
            torch.save(full_osd, os.path.join(output_dir, OPTIMIZER_NAME))
        else:
            torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))

        with warnings.catch_warnings(record=True) as caught_warnings:
            torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
        reissue_pt_warnings(caught_warnings)
        if self.do_grad_scaling:
            torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))

    # Determine the new best metric / best model checkpoint
    if metrics is not None and self.args.metric_for_best_model is not None:
        metric_to_check = self.args.metric_for_best_model
        if not metric_to_check.startswith("eval_"):
            metric_to_check = f"eval_{metric_to_check}"
        metric_value = metrics[metric_to_check]

        operator = np.greater if self.args.greater_is_better else np.less
        if (
            self.state.best_metric is None
            or self.state.best_model_checkpoint is None
            or operator(metric_value, self.state.best_metric)
        ):
            self.state.best_metric = metric_value
            self.state.best_model_checkpoint = output_dir

    # Save the Trainer state
    if self.args.should_save:
        self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))

    # Save the RNG states of this process (done in both distributed and non-distributed runs)
    rng_states = {
        "python": random.getstate(),
        "numpy": np.random.get_state(),
        "cpu": torch.random.get_rng_state(),
    }
    if torch.cuda.is_available():
        if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
            # Distributed mode: store the RNG state of all CUDA devices
            rng_states["cuda"] = torch.cuda.random.get_rng_state_all()
        else:
            # Otherwise store only the current device's CUDA RNG state
            rng_states["cuda"] = torch.cuda.random.get_rng_state()

    if is_torch_tpu_available():
        rng_states["xla"] = xm.get_rng_state()

    # A process can arrive here before the process 0 has a chance to save the model, in which case output_dir may
    # not yet exist.
    os.makedirs(output_dir, exist_ok=True)

    # Single-process runs use one file; multi-process runs write one file per process index.
    if self.args.world_size <= 1:
        torch.save(rng_states, os.path.join(output_dir, "rng_state.pth"))
    else:
        torch.save(rng_states, os.path.join(output_dir, f"rng_state_{self.args.process_index}.pth"))

    if self.args.push_to_hub:
        self._push_from_checkpoint(output_dir)

    # Maybe delete some older checkpoints.
    if self.args.should_save:
        self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
def hyperparameter_search(
    self,
    hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None,
    compute_objective: Optional[Callable[[Dict[str, float]], float]] = None,
    n_trials: int = 20,
    direction: str = "minimize",
    backend: Optional[Union["str", HPSearchBackend]] = None,
    hp_name: Optional[Callable[["optuna.Trial"], str]] = None,
    **kwargs,
) -> BestRun:
    """
    Launch a hyperparameter search using `optuna` or `Ray Tune` or `SigOpt`. The optimized quantity is determined
    by `compute_objective`, which defaults to a function returning the evaluation loss when no metric is provided,
    the sum of all metrics otherwise.

    To use this method, you need to have provided a `model_init` when initializing your [`Trainer`]: we need to
    reinitialize the model at each new run. This is incompatible with the `optimizers` argument, so you need to
    subclass [`Trainer`] and override the method [`~Trainer.create_optimizer_and_scheduler`] for custom
    optimizer/scheduler.

    Args:
        hp_space (`Callable[["optuna.Trial"], Dict[str, float]]`, *optional*):
            A function that defines the hyperparameter search space. Will default to
            [`~trainer_utils.default_hp_space_optuna`] or [`~trainer_utils.default_hp_space_ray`] or
            [`~trainer_utils.default_hp_space_sigopt`] depending on your backend.
        compute_objective (`Callable[[Dict[str, float]], float]`, *optional*):
            A function computing the objective to minimize or maximize from the metrics returned by the `evaluate`
            method. Will default to [`~trainer_utils.default_compute_objective`].
        n_trials (`int`, *optional*, defaults to 20):
            The number of trial runs to test.
        direction (`str`, *optional*, defaults to `"minimize"`):
            Whether to optimize greater or lower objectives. Can be `"minimize"` or `"maximize"`, you should pick
            `"minimize"` when optimizing the validation loss, `"maximize"` when optimizing one or several metrics.
        backend (`str` or [`~trainer_utils.HPSearchBackend`], *optional*):
            The backend to use for hyperparameter search. Will default to optuna or Ray Tune or SigOpt, depending
            on which one is installed. If all are installed, will default to optuna.
        hp_name (`Callable[["optuna.Trial"], str]]`, *optional*):
            A function that defines the trial/run name. Will default to None.
        kwargs (`Dict[str, Any]`, *optional*):
            Additional keyword arguments passed along to `optuna.create_study` or `ray.tune.run`. For more
            information see:

            - the documentation of
              [optuna.create_study](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.create_study.html)
            - the documentation of [tune.run](https://docs.ray.io/en/latest/tune/api_docs/execution.html#tune-run)
            - the documentation of [sigopt](https://app.sigopt.com/docs/endpoints/experiments/create)

    Returns:
        [`trainer_utils.BestRun`]: All the information about the best run. Experiment summary can be found in
        `run_summary` attribute for Ray backend.

    Raises:
        RuntimeError: If the trainer was created without a `model_init` function.
    """
    if backend is None:
        backend = default_hp_search_backend()
    backend = HPSearchBackend(backend)
    backend_obj = ALL_HYPERPARAMETER_SEARCH_BACKENDS[backend]()
    backend_obj.ensure_available()

    # Validate before touching any trainer state so a rejected call leaves the trainer unchanged.
    if self.model_init is None:
        raise RuntimeError(
            "To use hyperparameter search, you need to pass your model through a model_init function."
        )

    self.hp_search_backend = backend
    self.hp_space = backend_obj.default_hp_space if hp_space is None else hp_space
    self.hp_name = hp_name
    self.compute_objective = default_compute_objective if compute_objective is None else compute_objective

    try:
        best_run = backend_obj.run(self, n_trials, direction, **kwargs)
    finally:
        # Always leave "search mode", even when a trial raises, so later plain training is not affected.
        self.hp_search_backend = None
    return best_run
""" if isinstance(data, Mapping): return type(data)({k: self._prepare_input(v) for k, v in data.items()}) elif isinstance(data, (tuple, list)): return type(data)(self._prepare_input(v) for v in data) elif isinstance(data, torch.Tensor): kwargs = {"device": self.args.device} if self.is_deepspeed_enabled and (torch.is_floating_point(data) or torch.is_complex(data)): # NLP models inputs are int/uint and those get adjusted to the right dtype of the # embedding. Other models such as wav2vec2's inputs are already float and thus # may need special handling to match the dtypes of the model kwargs.update({"dtype": self.accelerator.state.deepspeed_plugin.hf_ds_config.dtype()}) return data.to(**kwargs) return data def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]: """ Prepare `inputs` before feeding them to the model, converting them to tensors if they are not already and handling potential state. """ inputs = self._prepare_input(inputs) if len(inputs) == 0: raise ValueError( "The batch received was empty, your model won't be able to train on it. Double-check that your " f"training dataset contains keys expected by the model: {','.join(self._signature_columns)}." ) if self.args.past_index >= 0 and self._past is not None: inputs["mems"] = self._past return inputs def compute_loss_context_manager(self): """ A helper wrapper to group together context managers. """ return self.autocast_smart_context_manager() def autocast_smart_context_manager(self, cache_enabled: Optional[bool] = True): """ A helper wrapper that creates an appropriate context manager for `autocast` while feeding it the desired arguments, depending on the situation. 
""" if self.use_cuda_amp or self.use_cpu_amp: ctx_manager = ( torch.cpu.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype) if self.use_cpu_amp else torch.cuda.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype) ) else: ctx_manager = contextlib.nullcontext() if sys.version_info >= (3, 7) else contextlib.suppress() return ctx_manager def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor: """ Perform a training step on a batch of inputs. Subclass and override to inject custom behavior. Args: model (`nn.Module`): The model to train. inputs (`Dict[str, Union[torch.Tensor, Any]]`): The inputs and targets of the model. The dictionary will be unpacked before being fed to the model. Most models expect the targets under the argument `labels`. Check your model's documentation for all accepted arguments. Return: `torch.Tensor`: The tensor with training loss on this batch. """ model.train() inputs = self._prepare_inputs(inputs) if is_sagemaker_mp_enabled(): loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps) return loss_mb.reduce_mean().detach().to(self.args.device) with self.compute_loss_context_manager(): loss = self.compute_loss(model, inputs) if self.args.n_gpu > 1: for k, ls in loss.items(): loss[k] = loss[k].mean() # mean() to average on multi-gpu parallel training if self.do_grad_scaling: self.scaler.scale(loss['loss_total']).backward() elif self.use_apex: with amp.scale_loss(loss['loss_total'], self.optimizer) as scaled_loss: scaled_loss.backward() else: self.accelerator.backward(loss['loss_total']) # return loss.detach() / self.args.gradient_accumulation_steps return {k:v.detach()/self.args.gradient_accumulation_steps for k,v in loss.items()} def compute_loss(self, model, inputs, return_outputs=False): """ How the loss is computed by Trainer. By default, all models return the loss in the first element. Subclass and override for custom behavior. 
""" if self.label_smoother is not None and "labels" in inputs: labels = inputs.pop("labels") else: labels = None outputs = model(**inputs) # Save past state if it exists # TODO: this needs to be fixed and made cleaner later. if self.args.past_index >= 0: self._past = outputs[self.args.past_index] if labels is not None: if unwrap_model(model)._get_name() in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values(): loss = self.label_smoother(outputs, labels, shift_labels=True) else: loss = self.label_smoother(outputs, labels) else: if isinstance(outputs, dict) and "loss" not in outputs: raise ValueError( "The model did not return a loss from the inputs, only the following keys: " f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}." ) # We don't use .loss here since the model may return tuples instead of ModelOutput. loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0] return (loss, outputs) if return_outputs else loss def is_local_process_zero(self) -> bool: """ Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on several machines) main process. """ return self.args.local_process_index == 0 def is_world_process_zero(self) -> bool: """ Whether or not this process is the global main process (when training in a distributed fashion on several machines, this is only going to be `True` for one process). """ # Special case for SageMaker ModelParallel since there process_index is dp_process_index, not the global # process index. if is_sagemaker_mp_enabled(): return smp.rank() == 0 else: return self.args.process_index == 0 def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False): """ Will save the model, so you can reload it using `from_pretrained()`. Will only save from the main process. 
""" if output_dir is None: output_dir = self.args.output_dir if is_torch_tpu_available(): self._save_tpu(output_dir) elif is_sagemaker_mp_enabled(): # Calling the state_dict needs to be done on the wrapped model and on all processes. os.makedirs(output_dir, exist_ok=True) state_dict = self.model_wrapped.state_dict() if self.args.should_save: self._save(output_dir, state_dict=state_dict) if IS_SAGEMAKER_MP_POST_1_10: # 'user_content.pt' indicates model state_dict saved with smp >= 1.10 Path(os.path.join(output_dir, "user_content.pt")).touch() elif ( ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp or self.fsdp is not None or self.is_fsdp_enabled ): state_dict = self.model.state_dict() if self.args.should_save: self._save(output_dir, state_dict=state_dict) if self.is_fsdp_enabled: save_fsdp_model(self.accelerator.state.fsdp_plugin, self.accelerator, self.model, output_dir) elif self.is_deepspeed_enabled: # this takes care of everything as long as we aren't under zero3 if version.parse(accelerate_version) <= version.parse("0.20.3"): raise ValueError("Install Accelerate from main branch") try: state_dict = self.accelerator.get_state_dict(self.deepspeed) if self.args.should_save: self._save(output_dir, state_dict=state_dict) except ValueError: logger.warning( " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use" " zero_to_fp32.py to recover weights" ) self.model_wrapped.save_checkpoint(output_dir) elif self.args.should_save: self._save(output_dir) # Push to the Hub when `save_model` is called by the user. 
def _save_tpu(self, output_dir: Optional[str] = None):
    """
    Save the model (and tokenizer, if any) through the `xm` save helpers.

    Args:
        output_dir: Target directory; falls back to `args.output_dir` when `None`.
    """
    if output_dir is None:
        output_dir = self.args.output_dir
    logger.info(f"Saving model checkpoint to {output_dir}")

    # Only the master ordinal creates the directory and stores the training arguments.
    if xm.is_master_ordinal():
        os.makedirs(output_dir, exist_ok=True)
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

    # All processes meet here before the weights are written with `save_pretrained()` (reloadable with
    # `from_pretrained()`) or as a bare state dict.
    xm.rendezvous("saving_checkpoint")

    should_save = self.args.should_save
    if isinstance(self.model, PreTrainedModel):
        self.model.save_pretrained(output_dir, is_main_process=should_save, save_function=xm.save)
    elif isinstance(unwrap_model(self.model), PreTrainedModel):
        # Wrapped pretrained model: save through the inner model, using the wrapper's state dict.
        unwrap_model(self.model).save_pretrained(
            output_dir,
            is_main_process=should_save,
            state_dict=self.model.state_dict(),
            save_function=xm.save,
        )
    else:
        logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
        weights = self.model.state_dict()
        xm.save(weights, os.path.join(output_dir, WEIGHTS_NAME))

    if self.tokenizer is not None and self.args.should_save:
        self.tokenizer.save_pretrained(output_dir)
def _sorted_checkpoints(
    self, output_dir=None, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False
) -> List[str]:
    """
    List the checkpoint directories under `output_dir`, oldest first.

    Args:
        output_dir: Directory that is globbed for `<checkpoint_prefix>-*` sub-directories.
        checkpoint_prefix: Prefix of the checkpoint directory names.
        use_mtime: Order by modification time when `True`; otherwise order by the integer that follows the
            prefix (directories without such a number are dropped).

    Returns:
        The checkpoint paths as strings. If `state.best_model_checkpoint` is one of them, it is moved towards
        the end of the list (second to last position) so that rotation, which deletes from the front, keeps it.
        A best checkpoint that is not among the listed directories is ignored instead of raising.
    """
    ordering_and_checkpoint_path = []

    glob_checkpoints = [str(x) for x in Path(output_dir).glob(f"{checkpoint_prefix}-*") if os.path.isdir(x)]

    for path in glob_checkpoints:
        if use_mtime:
            ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
        else:
            regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
            if regex_match is not None and regex_match.groups() is not None:
                ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))

    checkpoints_sorted = [path for _, path in sorted(ordering_and_checkpoint_path)]

    # Make sure we don't delete the best model.
    best_checkpoint = self.state.best_model_checkpoint
    if best_checkpoint is not None:
        best_path = str(Path(best_checkpoint))
        # The best checkpoint may live elsewhere or may already be gone; `list.index` would raise then.
        if best_path in checkpoints_sorted:
            best_model_index = checkpoints_sorted.index(best_path)
            for i in range(best_model_index, len(checkpoints_sorted) - 2):
                checkpoints_sorted[i], checkpoints_sorted[i + 1] = checkpoints_sorted[i + 1], checkpoints_sorted[i]
    return checkpoints_sorted
Args: eval_dataset (`Dataset`, *optional*): Pass a dataset if you wish to override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted by the `model.forward()` method are automatically removed. It must implement the `__len__` method. ignore_keys (`List[str]`, *optional*): A list of keys in the output of your model (if it is a dictionary) that should be ignored when gathering predictions. metric_key_prefix (`str`, *optional*, defaults to `"eval"`): An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named "eval_bleu" if the prefix is "eval" (default) Returns: A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The dictionary also contains the epoch number which comes from the training state. """ # memory metrics - must set up as early as possible self._memory_tracker.start() eval_dataloader = self.get_eval_dataloader(eval_dataset) start_time = time.time() eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop output = eval_loop( eval_dataloader, description="Evaluation", # No point gathering the predictions if there are no metrics, otherwise we defer to # self.args.prediction_loss_only prediction_loss_only=True if self.compute_metrics is None else None, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix, ) total_batch_size = self.args.eval_batch_size * self.args.world_size if f"{metric_key_prefix}_jit_compilation_time" in output.metrics: start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"] output.metrics.update( speed_metrics( metric_key_prefix, start_time, num_samples=output.num_samples, num_steps=math.ceil(output.num_samples / total_batch_size), ) ) self.log(output.metrics) if DebugOption.TPU_METRICS_DEBUG in self.args.debug: # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) 
xm.master_print(met.metrics_report()) self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics) self._memory_tracker.stop_and_update_metrics(output.metrics) return output.metrics def predict( self, test_dataset: Dataset, ignore_keys: Optional[List[str]] = None, metric_key_prefix: str = "test" ) -> PredictionOutput: """ Run prediction and returns predictions and potential metrics. Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method will also return metrics, like in `evaluate()`. Args: test_dataset (`Dataset`): Dataset to run the predictions on. If it is an `datasets.Dataset`, columns not accepted by the `model.forward()` method are automatically removed. Has to implement the method `__len__` ignore_keys (`List[str]`, *optional*): A list of keys in the output of your model (if it is a dictionary) that should be ignored when gathering predictions. metric_key_prefix (`str`, *optional*, defaults to `"test"`): An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named "test_bleu" if the prefix is "test" (default) If your predictions or labels have different sequence length (for instance because you're doing dynamic padding in a token classification task) the predictions will be padded (on the right) to allow for concatenation into one array. The padding index is -100. Returns: *NamedTuple* A namedtuple with the following keys: - predictions (`np.ndarray`): The predictions on `test_dataset`. - label_ids (`np.ndarray`, *optional*): The labels (if the dataset contained some). - metrics (`Dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained labels). 
""" # memory metrics - must set up as early as possible self._memory_tracker.start() test_dataloader = self.get_test_dataloader(test_dataset) start_time = time.time() eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop output = eval_loop( test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix ) total_batch_size = self.args.eval_batch_size * self.args.world_size if f"{metric_key_prefix}_jit_compilation_time" in output.metrics: start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"] output.metrics.update( speed_metrics( metric_key_prefix, start_time, num_samples=output.num_samples, num_steps=math.ceil(output.num_samples / total_batch_size), ) ) self.control = self.callback_handler.on_predict(self.args, self.state, self.control, output.metrics) self._memory_tracker.stop_and_update_metrics(output.metrics) return PredictionOutput(predictions=output.predictions, label_ids=output.label_ids, metrics=output.metrics) def evaluation_loop( self, dataloader: DataLoader, description: str, prediction_loss_only: Optional[bool] = None, ignore_keys: Optional[List[str]] = None, metric_key_prefix: str = "eval", ) -> EvalLoopOutput: """ Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`. Works both with or without labels. 
""" args = self.args prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only # if eval is called w/o train, handle model prep here if self.is_deepspeed_enabled and self.deepspeed is None: _, _ = deepspeed_init(self, num_training_steps=0, inference=True) model = self._wrap_model(self.model, training=False, dataloader=dataloader) if len(self.accelerator._models) == 0 and model is self.model: model = ( self.accelerator.prepare(model) if self.is_deepspeed_enabled else self.accelerator.prepare_model(model, evaluation_mode=True) ) if self.is_fsdp_enabled: self.model = model # for the rest of this function `model` is the outside model, whether it was wrapped or not if model is not self.model: self.model_wrapped = model # backward compatibility if self.is_deepspeed_enabled: self.deepspeed = self.model_wrapped # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called # while ``train`` is running, cast it to the right dtype first and then put on device if not self.is_in_train: if args.fp16_full_eval: model = model.to(dtype=torch.float16, device=args.device) elif args.bf16_full_eval: model = model.to(dtype=torch.bfloat16, device=args.device) batch_size = self.args.eval_batch_size logger.info(f"***** Running {description} *****") if has_length(dataloader): logger.info(f" Num examples = {self.num_examples(dataloader)}") else: logger.info(" Num examples: Unknown") logger.info(f" Batch size = {batch_size}") model.eval() self.callback_handler.eval_dataloader = dataloader # Do this before wrapping. 
eval_dataset = getattr(dataloader, "dataset", None) if args.past_index >= 0: self._past = None # Initialize containers # losses/preds/labels on GPU/TPU (accumulated for eval_accumulation_steps) losses_host = None preds_host = None labels_host = None inputs_host = None # losses/preds/labels on CPU (final containers) all_losses = None all_preds = None all_labels = None all_inputs = None # Will be useful when we have an iterable dataset so don't know its length. observed_num_examples = 0 # Main evaluation loop for step, inputs in enumerate(dataloader): # Update the observed num examples observed_batch_size = find_batch_size(inputs) if observed_batch_size is not None: observed_num_examples += observed_batch_size # For batch samplers, batch_size is not known by the dataloader in advance. if batch_size is None: batch_size = observed_batch_size # Prediction step loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys) inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None if is_torch_tpu_available(): xm.mark_step() # Update containers on host if loss is not None: losses = self.accelerator.gather_for_metrics((loss.repeat(batch_size))) losses_host = losses if losses_host is None else nested_concat(losses_host, losses, padding_index=-100) if labels is not None: labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100) if inputs_decode is not None: inputs_decode = self.accelerator.pad_across_processes(inputs_decode, dim=1, pad_index=-100) inputs_decode = self.accelerator.gather_for_metrics((inputs_decode)) inputs_host = ( inputs_decode if inputs_host is None else nested_concat(inputs_host, inputs_decode, padding_index=-100) ) if logits is not None: logits = self.accelerator.pad_across_processes(logits, dim=1, pad_index=-100) if self.preprocess_logits_for_metrics is not None: logits = self.preprocess_logits_for_metrics(logits, labels) logits = 
self.accelerator.gather_for_metrics((logits)) preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100) if labels is not None: labels = self.accelerator.gather_for_metrics((labels)) labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100) self.control = self.callback_handler.on_prediction_step(args, self.state, self.control) # Gather all tensors and put them back on the CPU if we have done enough accumulation steps. if args.eval_accumulation_steps is not None and self.accelerator.sync_gradients: if losses_host is not None: losses = nested_numpify(losses_host) all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0) if preds_host is not None: logits = nested_numpify(preds_host) all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100) if inputs_host is not None: inputs_decode = nested_numpify(inputs_host) all_inputs = ( inputs_decode if all_inputs is None else nested_concat(all_inputs, inputs_decode, padding_index=-100) ) if labels_host is not None: labels = nested_numpify(labels_host) all_labels = ( labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100) ) # Set back to None to begin a new accumulation losses_host, preds_host, inputs_host, labels_host = None, None, None, None if args.past_index and hasattr(self, "_past"): # Clean the state at the end of the evaluation loop delattr(self, "_past") # Gather all remaining tensors and put them back on the CPU if losses_host is not None: losses = nested_numpify(losses_host) all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0) if preds_host is not None: logits = nested_numpify(preds_host) all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100) if inputs_host is not None: inputs_decode = nested_numpify(inputs_host) all_inputs = ( inputs_decode if 
all_inputs is None else nested_concat(all_inputs, inputs_decode, padding_index=-100) ) if labels_host is not None: labels = nested_numpify(labels_host) all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100) # Number of samples if has_length(eval_dataset): num_samples = len(eval_dataset) # The instance check is weird and does not actually check for the type, but whether the dataset has the right # methods. Therefore we need to make sure it also has the attribute. elif isinstance(eval_dataset, IterableDatasetShard) and getattr(eval_dataset, "num_examples", 0) > 0: num_samples = eval_dataset.num_examples else: if has_length(dataloader): num_samples = self.num_examples(dataloader) else: # both len(dataloader.dataset) and len(dataloader) fail num_samples = observed_num_examples if num_samples == 0 and observed_num_examples > 0: num_samples = observed_num_examples # Metrics! if self.compute_metrics is not None and all_preds is not None and all_labels is not None: if args.include_inputs_for_metrics: metrics = self.compute_metrics( EvalPrediction(predictions=all_preds, label_ids=all_labels, inputs=all_inputs) ) else: metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels)) else: metrics = {} # To be JSON-serializable, we need to remove numpy types or zero-d tensors metrics = denumpify_detensorize(metrics) if all_losses is not None: metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item() if hasattr(self, "jit_compilation_time"): metrics[f"{metric_key_prefix}_jit_compilation_time"] = self.jit_compilation_time # Prefix all keys with metric_key_prefix + '_' for key in list(metrics.keys()): if not key.startswith(f"{metric_key_prefix}_"): metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key) return EvalLoopOutput(predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=num_samples) def _nested_gather(self, tensors, name=None): """ Gather value of `tensors` (tensor or 
list/tuple of nested tensors) and convert them to numpy before concatenating them to `gathered` """ if tensors is None: return if is_torch_tpu_available(): if name is None: name = "nested_gather" tensors = nested_xla_mesh_reduce(tensors, name) elif is_sagemaker_mp_enabled(): tensors = smp_gather(tensors) elif (self.args.distributed_state is not None and self.args.distributed_state.distributed_type != "NO") or ( self.args.distributed_state is None and self.args.local_rank != -1 ): tensors = distributed_concat(tensors) return tensors def prediction_step( self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], prediction_loss_only: bool, ignore_keys: Optional[List[str]] = None, ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: """ Perform an evaluation step on `model` using `inputs`. Subclass and override to inject custom behavior. Args: model (`nn.Module`): The model to evaluate. inputs (`Dict[str, Union[torch.Tensor, Any]]`): The inputs and targets of the model. The dictionary will be unpacked before being fed to the model. Most models expect the targets under the argument `labels`. Check your model's documentation for all accepted arguments. prediction_loss_only (`bool`): Whether or not to return the loss only. ignore_keys (`List[str]`, *optional*): A list of keys in the output of your model (if it is a dictionary) that should be ignored when gathering predictions. Return: Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and labels (each being optional). """ has_labels = False if len(self.label_names) == 0 else all(inputs.get(k) is not None for k in self.label_names) # For CLIP-like models capable of returning loss values. # If `return_loss` is not specified or being `None` in `inputs`, we check if the default value of `return_loss` # is `True` in `model.forward`. 
return_loss = inputs.get("return_loss", None) if return_loss is None: return_loss = self.can_return_loss loss_without_labels = True if len(self.label_names) == 0 and return_loss else False inputs = self._prepare_inputs(inputs) if ignore_keys is None: if hasattr(self.model, "config"): ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", []) else: ignore_keys = [] # labels may be popped when computing the loss (label smoothing for instance) so we grab them first. if has_labels or loss_without_labels: labels = nested_detach(tuple(inputs.get(name) for name in self.label_names)) if len(labels) == 1: labels = labels[0] else: labels = None with torch.no_grad(): if is_sagemaker_mp_enabled(): raw_outputs = smp_forward_only(model, inputs) if has_labels or loss_without_labels: if isinstance(raw_outputs, dict): loss_mb = raw_outputs["loss"] logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys + ["loss"]) else: loss_mb = raw_outputs[0] logits_mb = raw_outputs[1:] loss = loss_mb.reduce_mean().detach().cpu() logits = smp_nested_concat(logits_mb) else: loss = None if isinstance(raw_outputs, dict): logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys) else: logits_mb = raw_outputs logits = smp_nested_concat(logits_mb) else: if has_labels or loss_without_labels: with self.compute_loss_context_manager(): loss, outputs = self.compute_loss(model, inputs, return_outputs=True) loss = loss.mean().detach() if isinstance(outputs, dict): logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"]) else: logits = outputs[1:] else: loss = None with self.compute_loss_context_manager(): outputs = model(**inputs) if isinstance(outputs, dict): logits = tuple(v for k, v in outputs.items() if k not in ignore_keys) else: logits = outputs # TODO: this needs to be fixed and made cleaner later. 
def floating_point_ops(self, inputs: Dict[str, Union[torch.Tensor, Any]]):
    """
    Report the number of floating-point operations for `inputs`.

    The computation is delegated to `self.model.floating_point_ops` when the model exposes such a method. To
    support another kind of model, implement that method on the model or override this one in a subclass.

    Args:
        inputs (`Dict[str, Union[torch.Tensor, Any]]`):
            The inputs and targets of the model.

    Returns:
        `int`: The value returned by the model's own counter, or `0` when the model has none.
    """
    model = self.model
    if not hasattr(model, "floating_point_ops"):
        return 0
    return model.floating_point_ops(inputs)
def create_model_card(
    self,
    language: Optional[str] = None,
    license: Optional[str] = None,
    tags: Union[str, List[str], None] = None,
    model_name: Optional[str] = None,
    finetuned_from: Optional[str] = None,
    tasks: Union[str, List[str], None] = None,
    dataset_tags: Union[str, List[str], None] = None,
    dataset: Union[str, List[str], None] = None,
    dataset_args: Union[str, List[str], None] = None,
):
    """
    Creates a draft of a model card using the information available to the `Trainer` and writes it to
    `README.md` inside `self.args.output_dir`. Does nothing on processes other than the world-zero process.

    Args:
        language (`str`, *optional*):
            The language of the model (if applicable)
        license (`str`, *optional*):
            The license of the model. Will default to the license of the pretrained model used, if the original
            model given to the `Trainer` comes from a repo on the Hub.
        tags (`str` or `List[str]`, *optional*):
            Some tags to be included in the metadata of the model card.
        model_name (`str`, *optional*):
            The name of the model.
        finetuned_from (`str`, *optional*):
            The name of the model used to fine-tune this one (if applicable). Will default to the name of the repo
            of the original model given to the `Trainer` (if it comes from the Hub).
        tasks (`str` or `List[str]`, *optional*):
            One or several task identifiers, to be included in the metadata of the model card.
        dataset_tags (`str` or `List[str]`, *optional*):
            One or several dataset tags, to be included in the metadata of the model card.
        dataset (`str` or `List[str]`, *optional*):
            One or several dataset identifiers, to be included in the metadata of the model card.
        dataset_args (`str` or `List[str]`, *optional*):
            One or several dataset arguments, to be included in the metadata of the model card.
    """
    # Only one process writes the card so that concurrent writers cannot clobber the file.
    if not self.is_world_process_zero():
        return

    training_summary = TrainingSummary.from_trainer(
        self,
        language=language,
        license=license,
        tags=tags,
        model_name=model_name,
        finetuned_from=finetuned_from,
        tasks=tasks,
        dataset_tags=dataset_tags,
        dataset=dataset,
        dataset_args=dataset_args,
    )
    model_card = training_summary.to_model_card()

    # Write as UTF-8 explicitly: the card may contain non-ASCII text, and the platform default encoding
    # (e.g. cp1252) would otherwise make this write fail or produce a file in the wrong encoding.
    readme_path = os.path.join(self.args.output_dir, "README.md")
    with open(readme_path, "w", encoding="utf-8") as f:
        f.write(model_card)
if self.push_in_progress is not None and not self.push_in_progress.is_done: return output_dir = self.args.output_dir # To avoid a new synchronization of all model weights, we just copy the file from the checkpoint folder modeling_files = [CONFIG_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_NAME] if is_peft_available(): modeling_files.extend([ADAPTER_CONFIG_NAME, ADAPTER_WEIGHTS_NAME, ADAPTER_SAFE_WEIGHTS_NAME]) for modeling_file in modeling_files: if os.path.isfile(os.path.join(checkpoint_folder, modeling_file)): shutil.copy(os.path.join(checkpoint_folder, modeling_file), os.path.join(output_dir, modeling_file)) # Saving the tokenizer is fast and we don't know how many files it may have spawned, so we resave it to be sure. if self.tokenizer is not None: self.tokenizer.save_pretrained(output_dir) # Same for the training arguments torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) try: if self.args.hub_strategy == HubStrategy.CHECKPOINT: # Temporarily move the checkpoint just saved for the push tmp_checkpoint = os.path.join(output_dir, "last-checkpoint") # We have to remove the "last-checkpoint" dir if it exists, otherwise the checkpoint is moved as a # subfolder. if os.path.isdir(tmp_checkpoint): shutil.rmtree(tmp_checkpoint) shutil.move(checkpoint_folder, tmp_checkpoint) if self.args.save_strategy == IntervalStrategy.STEPS: commit_message = f"Training in progress, step {self.state.global_step}" else: commit_message = f"Training in progress, epoch {int(self.state.epoch)}" push_work = self.repo.push_to_hub(commit_message=commit_message, blocking=False, auto_lfs_prune=True) # Return type of `Repository.push_to_hub` is either None or a tuple. 
def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str:
    """
    Upload *self.model* and *self.tokenizer* to the 🤗 model hub on the repo *self.args.hub_model_id*.

    Parameters:
        commit_message (`str`, *optional*, defaults to `"End of training"`):
            Message to commit while pushing.
        blocking (`bool`, *optional*, defaults to `True`):
            Whether the function should return only when the `git push` has finished.
        kwargs (`Dict[str, Any]`, *optional*):
            Additional keyword arguments passed along to [`~Trainer.create_model_card`]. A `model_name` entry is
            consumed here and forwarded explicitly.

    Returns:
        The url of the commit of your model in the given repository if `blocking=True`, a tuple with the url of
        the commit and an object to track the progress of the commit if `blocking=False`. `None` is returned on
        processes other than the world-zero process.
    """
    # If a user calls manually `push_to_hub` with `self.args.push_to_hub = False`, we try to create the repo but
    # it might fail.
    if not hasattr(self, "repo"):
        self.init_git_repo()

    model_name = kwargs.pop("model_name", None)
    if model_name is None and self.args.should_save:
        if self.args.hub_model_id is None:
            model_name = Path(self.args.output_dir).name
        else:
            model_name = self.args.hub_model_id.split("/")[-1]

    # Needs to be executed on all processes for TPU training, but will only save on the processed determined by
    # self.args.should_save.
    self.save_model(_internal_call=True)

    # Only push from one node.
    if not self.is_world_process_zero():
        return

    # Cancel any async push in progress if blocking=True. The commits will all be pushed together.
    if blocking and self.push_in_progress is not None and not self.push_in_progress.is_done:
        self.push_in_progress._process.kill()
        self.push_in_progress = None

    git_head_commit_url = self.repo.push_to_hub(
        commit_message=commit_message, blocking=blocking, auto_lfs_prune=True
    )

    # Push the model card separately so it stays independent from the rest of the model.
    if self.args.should_save:
        self.create_model_card(model_name=model_name, **kwargs)
        try:
            self.repo.push_to_hub(
                commit_message="update model card README.md", blocking=blocking, auto_lfs_prune=True
            )
        except EnvironmentError as exc:
            # Best effort: the weights are already pushed, so a failing card update is logged, not raised.
            logger.error(f"Error pushing update to the model card. Please read logs and retry.\n{exc}")

    return git_head_commit_url
""" args = self.args if not has_length(dataloader): raise ValueError("dataloader must implement a working __len__") prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only # if eval is called w/o train, handle model prep here if self.is_deepspeed_enabled and self.deepspeed is None: _, _ = deepspeed_init(self, num_training_steps=0, inference=True) model = self._wrap_model(self.model, training=False, dataloader=dataloader) if len(self.accelerator._models) == 0 and model is self.model: model = ( self.accelerator.prepare(model) if self.is_deepspeed_enabled else self.accelerator.prepare_model(model, evaluation_mode=True) ) if self.is_fsdp_enabled: self.model = model # for the rest of this function `model` is the outside model, whether it was wrapped or not if model is not self.model: self.model_wrapped = model # backward compatibility if self.is_deepspeed_enabled: self.deepspeed = self.model_wrapped # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called # while ``train`` is running, cast it to the right dtype first and then put on device if not self.is_in_train: if args.fp16_full_eval: model = model.to(dtype=torch.float16, device=args.device) elif args.bf16_full_eval: model = model.to(dtype=torch.bfloat16, device=args.device) batch_size = dataloader.batch_size num_examples = self.num_examples(dataloader) logger.info(f"***** Running {description} *****") logger.info(f" Num examples = {num_examples}") logger.info(f" Batch size = {batch_size}") losses_host: torch.Tensor = None preds_host: Union[torch.Tensor, List[torch.Tensor]] = None labels_host: Union[torch.Tensor, List[torch.Tensor]] = None inputs_host: Union[torch.Tensor, List[torch.Tensor]] = None world_size = max(1, args.world_size) eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size) if not prediction_loss_only: # The actual number of eval_sample can be greater than 
num_examples in distributed settings (when we pass # a batch size to the sampler) make_multiple_of = None if hasattr(dataloader, "sampler") and isinstance(dataloader.sampler, SequentialDistributedSampler): make_multiple_of = dataloader.sampler.batch_size preds_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of) labels_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of) inputs_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of) model.eval() if args.past_index >= 0: self._past = None self.callback_handler.eval_dataloader = dataloader for step, inputs in enumerate(dataloader): loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys) inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None if loss is not None: losses = loss.repeat(batch_size) losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0) if logits is not None: preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100) if labels is not None: labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100) if inputs_decode is not None: inputs_host = ( inputs_decode if inputs_host is None else nested_concat(inputs_host, inputs_decode, padding_index=-100) ) self.control = self.callback_handler.on_prediction_step(args, self.state, self.control) # Gather all tensors and put them back on the CPU if we have done enough accumulation steps. 
if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0: eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses")) if not prediction_loss_only: preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds")) labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids")) inputs_gatherer.add_arrays(self._gather_and_numpify(inputs_host, "eval_inputs_ids")) # Set back to None to begin a new accumulation losses_host, preds_host, labels_host, inputs_host = None, None, None, None if args.past_index and hasattr(self, "_past"): # Clean the state at the end of the evaluation loop delattr(self, "_past") # Gather all remaining tensors and put them back on the CPU eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses")) if not prediction_loss_only: preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds")) labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids")) inputs_gatherer.add_arrays(self._gather_and_numpify(inputs_host, "eval_inputs_ids")) eval_loss = eval_losses_gatherer.finalize() preds = preds_gatherer.finalize() if not prediction_loss_only else None label_ids = labels_gatherer.finalize() if not prediction_loss_only else None inputs_ids = inputs_gatherer.finalize() if not prediction_loss_only else None if self.compute_metrics is not None and preds is not None and label_ids is not None: if args.include_inputs_for_metrics: metrics = self.compute_metrics( EvalPrediction(predictions=preds, label_ids=label_ids, inputs=inputs_ids) ) else: metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids)) else: metrics = {} # To be JSON-serializable, we need to remove numpy types or zero-d tensors metrics = denumpify_detensorize(metrics) if eval_loss is not None: metrics[f"{metric_key_prefix}_loss"] = eval_loss.mean().item() # Prefix all keys with metric_key_prefix + '_' for key in 
list(metrics.keys()): if not key.startswith(f"{metric_key_prefix}_"): metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key) return EvalLoopOutput(predictions=preds, label_ids=label_ids, metrics=metrics, num_samples=num_examples) def _gather_and_numpify(self, tensors, name): """ Gather value of `tensors` (tensor or list/tuple of nested tensors) and convert them to numpy before concatenating them to `gathered` """ if tensors is None: return if is_torch_tpu_available(): tensors = nested_xla_mesh_reduce(tensors, name) elif is_sagemaker_mp_enabled(): tensors = smp_gather(tensors) elif self.args.parallel_mode == ParallelMode.DISTRIBUTED: tensors = distributed_concat(tensors) return nested_numpify(tensors) def _add_sm_patterns_to_gitignore(self) -> None: """Add SageMaker Checkpointing patterns to .gitignore file.""" # Make sure we only do this on the main process if not self.is_world_process_zero(): return patterns = ["*.sagemaker-uploading", "*.sagemaker-uploaded"] # Get current .gitignore content if os.path.exists(os.path.join(self.repo.local_dir, ".gitignore")): with open(os.path.join(self.repo.local_dir, ".gitignore"), "r") as f: current_content = f.read() else: current_content = "" # Add the patterns to .gitignore content = current_content for pattern in patterns: if pattern not in content: if content.endswith("\n"): content += pattern else: content += f"\n{pattern}" # Write the .gitignore file if it has changed if content != current_content: with open(os.path.join(self.repo.local_dir, ".gitignore"), "w") as f: logger.debug(f"Writing .gitignore file. 
def create_accelerator_and_postprocess(self):
    """
    Build `self.accelerator` from the training arguments, record whether DeepSpeed and FSDP are active, and
    apply the plugin-specific setup that has to happen after the accelerator exists.

    Sets `self.accelerator`, `self.is_deepspeed_enabled` and `self.is_fsdp_enabled`.
    """
    args = self.args

    plugin_kwargs = {"num_steps": args.gradient_accumulation_steps}
    # `sync_with_dataloader` is only passed for accelerate releases newer than 0.20.3.
    if version.parse(accelerate_version) > version.parse("0.20.3"):
        plugin_kwargs["sync_with_dataloader"] = False
    grad_acc_plugin = GradientAccumulationPlugin(**plugin_kwargs)

    self.accelerator = Accelerator(
        deepspeed_plugin=args.deepspeed_plugin,
        gradient_accumulation_plugin=grad_acc_plugin,
    )

    # A plugin may come from the trainer args or from the accelerate launcher, so read it off the state.
    accel_state = self.accelerator.state
    self.is_deepspeed_enabled = getattr(accel_state, "deepspeed_plugin", None) is not None
    self.is_fsdp_enabled = getattr(accel_state, "fsdp_plugin", None) is not None

    if self.is_fsdp_enabled:
        fsdp = accel_state.fsdp_plugin
        # Values from `args.fsdp_config` win; otherwise the plugin keeps its current setting.
        for option in ("limit_all_gathers", "use_orig_params"):
            setattr(fsdp, option, args.fsdp_config.get(option, getattr(fsdp, option)))

    if not self.is_deepspeed_enabled or getattr(args, "hf_deepspeed_config", None) is not None:
        return

    # No HF DeepSpeed config on the args yet: wrap the plugin's config and let it process the trainer args.
    from transformers.deepspeed import HfTrainerDeepSpeedConfig

    ds = accel_state.deepspeed_plugin
    ds.hf_ds_config = HfTrainerDeepSpeedConfig(ds.hf_ds_config.config)
    ds.deepspeed_config = ds.hf_ds_config.config
    ds.hf_ds_config.trainer_config_process(args)
def maybe_zero_3(param, ignore_status=False, name=None):
    """
    Return a detached CPU clone of `param`.

    A parameter carrying a `ds_id` attribute (presumably a DeepSpeed ZeRO-3 partitioned parameter; confirm
    against the DeepSpeed docs) is copied inside `zero.GatheredParameters`. Any other parameter is copied
    directly.

    Args:
        param: The parameter/tensor to copy.
        ignore_status (`bool`): When `False`, a notice is printed if the parameter's `ds_status` is
            `ZeroParamStatus.NOT_AVAILABLE`.
        name: Label included in that printed notice.

    Returns:
        The cloned tensor on CPU.
    """
    from deepspeed import zero
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

    if not hasattr(param, "ds_id"):
        return param.detach().cpu().clone()

    if param.ds_status == ZeroParamStatus.NOT_AVAILABLE and not ignore_status:
        print(name, 'no ignore status')

    with zero.GatheredParameters([param]):
        cpu_copy = param.data.detach().cpu().clone()
    return cpu_copy
@dataclass
class DataCollatorForSupervisedDatasetEmpty:
    """Pass-through collator for supervised fine-tuning.

    The list of instances is returned exactly as received: no padding, truncation, attention mask or image
    stacking is done here.
    """

    # Kept as a field so the collator is constructed like a regular one; `__call__` does not read it.
    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]):
        """Return `instances` unchanged."""
        return instances
data_collator (`DataCollator`, *optional*): The function to use to form a batch from a list of elements of `train_dataset` or `eval_dataset`. Will default to [`default_data_collator`] if no `tokenizer` is provided, an instance of [`DataCollatorWithPadding`] otherwise. train_dataset (`torch.utils.data.Dataset` or `torch.utils.data.IterableDataset`, *optional*): The dataset to use for training. If it is a [`~datasets.Dataset`], columns not accepted by the `model.forward()` method are automatically removed. Note that if it's a `torch.utils.data.IterableDataset` with some randomization and you are training in a distributed fashion, your iterable dataset should either use a internal attribute `generator` that is a `torch.Generator` for the randomization that must be identical on all processes (and the Trainer will manually set the seed of this `generator` at each epoch) or have a `set_epoch()` method that internally sets the seed of the RNGs used. eval_dataset (Union[`torch.utils.data.Dataset`, Dict[str, `torch.utils.data.Dataset`]), *optional*): The dataset to use for evaluation. If it is a [`~datasets.Dataset`], columns not accepted by the `model.forward()` method are automatically removed. If it is a dictionary, it will evaluate on each dataset prepending the dictionary key to the metric name. tokenizer ([`PreTrainedTokenizerBase`], *optional*): The tokenizer used to preprocess the data. If provided, will be used to automatically pad the inputs to the maximum length when batching inputs, and it will be saved along the model to make it easier to rerun an interrupted training or reuse the fine-tuned model. model_init (`Callable[[], PreTrainedModel]`, *optional*): A function that instantiates the model to be used. If provided, each call to [`~Trainer.train`] will start from a new instance of the model as given by this function. 
The function may have zero argument, or a single one containing the optuna/Ray Tune/SigOpt trial object, to be able to choose different architectures according to hyper parameters (such as layer count, sizes of inner layers, dropout probabilities etc). compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*): The function that will be used to compute metrics at evaluation. Must take a [`EvalPrediction`] and return a dictionary string to metric values. callbacks (List of [`TrainerCallback`], *optional*): A list of callbacks to customize the training loop. Will add those to the list of default callbacks detailed in [here](callback). If you want to remove one of the default callbacks used, use the [`Trainer.remove_callback`] method. optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*): A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`. preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*): A function that preprocess the logits right before caching them at each evaluation step. Must take two tensors, the logits and the labels, and return the logits once processed as desired. The modifications made by this function will be reflected in the predictions received by `compute_metrics`. Note that the labels (second parameter) will be `None` if the dataset does not have them. Important attributes: - **model** -- Always points to the core model. If using a transformers model, it will be a [`PreTrainedModel`] subclass. - **model_wrapped** -- Always points to the most external model in case one or more other modules wrap the original model. This is the model that should be used for the forward pass. For example, under `DeepSpeed`, the inner model is wrapped in `DeepSpeed` and then again in `torch.nn.DistributedDataParallel`. 
def __init__(
    self,
    model: Union[PreTrainedModel, nn.Module] = None,
    args: TrainingArguments = None,
    data_collator: Optional[DataCollator] = None,
    train_dataset: Optional[Dataset] = None,
    eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
    tokenizer: Optional[PreTrainedTokenizerBase] = None,
    model_init: Optional[Callable[[], PreTrainedModel]] = None,
    compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
    callbacks: Optional[List[TrainerCallback]] = None,
    optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
    preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
    data_loader_args=None,
    cfg=None,
):
    """
    Build the trainer state: arguments, model placement, distributed/sharding options, mixed precision,
    callbacks and bookkeeping.

    Besides the arguments described in the class docstring, two extra keyword arguments are accepted:

    Args:
        data_loader_args (*optional*):
            Stored as `self.data_loader_args`. `get_train_dataloaderd2` reads entries 0, 1 and 2 of it and
            passes them on as `tokenizer`, `data_args` and `preprocess`.
        cfg (*optional*):
            Stored as `self.cfg` and passed as the first argument of `build_train_dataloader` by
            `get_train_dataloaderd2`.

    Raises:
        RuntimeError: if neither `model` nor `model_init` is given, if `model_init` is combined with
            `optimizers`, if `optimizers` is combined with Fairscale/DeepSpeed/FSDP, or if `torch_compile` is
            requested but unavailable.
        ValueError: for the unsupported argument combinations checked below.
        ImportError: when a requested backend (fairscale, apex) is not installed.

    NOTE(review): this constructor never calls `super().__init__()`; every attribute is assigned here.
    """
    self.cfg=cfg
    if args is None:
        output_dir = "tmp_trainer"
        logger.info(f"No `TrainingArguments` passed, using `output_dir={output_dir}`.")
        args = TrainingArguments(output_dir=output_dir)
    self.args = args
    # Seed must be set before instantiating the model when using model
    enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
    self.hp_name = None
    self.deepspeed = None
    self.is_in_train = False
    # Kept for `get_train_dataloaderd2`, which indexes entries 0..2 of this value.
    self.data_loader_args=data_loader_args
    self.create_accelerator_and_postprocess()

    # memory metrics - must set up as early as possible
    self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)
    self._memory_tracker.start()

    # set the correct log level depending on the node
    log_level = args.get_process_log_level()
    logging.set_verbosity(log_level)

    # force device and distributed setup init explicitly
    args._setup_devices

    if model is None:
        if model_init is not None:
            self.model_init = model_init
            model = self.call_model_init()
        else:
            raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument")
    else:
        if model_init is not None:
            warnings.warn(
                "`Trainer` requires either a `model` or `model_init` argument, but not both. `model_init` will"
                " overwrite your model when calling the `train` method. This will become a fatal error in the next"
                " release.",
                FutureWarning,
            )
        self.model_init = model_init

    if model.__class__.__name__ in MODEL_MAPPING_NAMES:
        raise ValueError(
            f"The model you have picked ({model.__class__.__name__}) cannot be used as is for training: it only "
            "computes hidden states and does not accept any labels. You should choose a model with a head "
            "suitable for your task like any of the `AutoModelForXxx` listed at "
            "https://huggingface.co/docs/transformers/model_doc/auto."
        )

    if hasattr(model, "is_parallelizable") and model.is_parallelizable and model.model_parallel:
        self.is_model_parallel = True
    else:
        self.is_model_parallel = False

    # A device map on the model overrides the flag computed just above.
    if getattr(model, "hf_device_map", None) is not None:
        devices = [device for device in set(model.hf_device_map.values()) if device not in ["cpu", "disk"]]
        if len(devices) > 1:
            self.is_model_parallel = True
        else:
            self.is_model_parallel = self.args.device != torch.device(devices[0])

        # warn users
        logger.info(
            "You have loaded a model on multiple GPUs. `is_model_parallel` attribute will be force-set"
            " to `True` to avoid any unexpected behavior such as device placement mismatching."
        )

    # At this stage the model is already loaded
    if getattr(model, "is_quantized", False):
        if getattr(model, "_is_quantized_training_enabled", False):
            logger.info(
                "The model is loaded in 8-bit precision. To train this model you need to add additional modules"
                " inside the model such as adapters using `peft` library and freeze the model weights. Please"
                " check "
                " the examples in https://github.com/huggingface/peft for more details."
            )
        else:
            raise ValueError(
                "The model you want to train is loaded in 8-bit precision.  if you want to fine-tune an 8-bit"
                " model, please make sure that you have installed `bitsandbytes>=0.37.0`. "
            )

    # Setup Sharded DDP training
    self.sharded_ddp = None
    if len(args.sharded_ddp) > 0:
        if self.is_deepspeed_enabled:
            raise ValueError(
                "Using --sharded_ddp xxx together with --deepspeed is not possible, deactivate one of those flags."
            )
        if len(args.fsdp) > 0:
            raise ValueError(
                "Using --sharded_ddp xxx together with --fsdp is not possible, deactivate one of those flags."
            )
        if args.parallel_mode != ParallelMode.DISTRIBUTED:
            raise ValueError("Using sharded DDP only works in distributed training.")
        elif not is_fairscale_available():
            raise ImportError("Sharded DDP training requires fairscale: `pip install fairscale`.")
        elif ShardedDDPOption.SIMPLE not in args.sharded_ddp and FullyShardedDDP is None:
            raise ImportError(
                "Sharded DDP in a mode other than simple training requires fairscale version >= 0.3, found "
                f"{fairscale.__version__}. Upgrade your fairscale library: `pip install --upgrade fairscale`."
            )
        elif ShardedDDPOption.SIMPLE in args.sharded_ddp:
            self.sharded_ddp = ShardedDDPOption.SIMPLE
        elif ShardedDDPOption.ZERO_DP_2 in args.sharded_ddp:
            self.sharded_ddp = ShardedDDPOption.ZERO_DP_2
        elif ShardedDDPOption.ZERO_DP_3 in args.sharded_ddp:
            self.sharded_ddp = ShardedDDPOption.ZERO_DP_3

    self.fsdp = None
    if len(args.fsdp) > 0:
        if self.is_deepspeed_enabled:
            raise ValueError(
                "Using --fsdp xxx together with --deepspeed is not possible, deactivate one of those flags."
            )
        if not args.fsdp_config["xla"] and args.parallel_mode != ParallelMode.DISTRIBUTED:
            raise ValueError("Using fsdp only works in distributed training.")

        # dep_version_check("torch>=1.12.0")
        # Would have to update setup.py with torch>=1.12.0
        # which isn't ideally given that it will force people not using FSDP to also use torch>=1.12.0
        # below is the current alternative.
        if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.12.0"):
            raise ValueError("FSDP requires PyTorch >= 1.12.0")

        from torch.distributed.fsdp.fully_sharded_data_parallel import BackwardPrefetch, ShardingStrategy

        if FSDPOption.FULL_SHARD in args.fsdp:
            self.fsdp = ShardingStrategy.FULL_SHARD
        elif FSDPOption.SHARD_GRAD_OP in args.fsdp:
            self.fsdp = ShardingStrategy.SHARD_GRAD_OP
        elif FSDPOption.NO_SHARD in args.fsdp:
            self.fsdp = ShardingStrategy.NO_SHARD

        self.backward_prefetch = BackwardPrefetch.BACKWARD_PRE
        if "backward_prefetch" in self.args.fsdp_config and "backward_post" in self.args.fsdp_config.get(
            "backward_prefetch", []
        ):
            self.backward_prefetch = BackwardPrefetch.BACKWARD_POST

        self.forward_prefetch = False
        # NOTE(review): the key read here is spelled "forward_prefect" - confirm that is the intended config key.
        if self.args.fsdp_config.get("forward_prefect", False):
            self.forward_prefetch = True

        self.limit_all_gathers = False
        if self.args.fsdp_config.get("limit_all_gathers", False):
            self.limit_all_gathers = True

    # one place to sort out whether to place the model on device or not
    # postpone switching model to cuda when:
    # 1. MP - since we are trying to fit a much bigger than 1 gpu model
    # 2. fp16-enabled DeepSpeed loads the model in half the size and it doesn't need .to() anyway,
    #    and we only use deepspeed for training at the moment
    # 3. full bf16 or fp16 eval - since the model needs to be cast to the right dtype first
    # 4. Sharded DDP - same as MP
    # 5. FSDP - same as MP
    self.place_model_on_device = args.place_model_on_device
    if (
        self.is_model_parallel
        or self.is_deepspeed_enabled
        or ((args.fp16_full_eval or args.bf16_full_eval) and not args.do_train)
        or (self.sharded_ddp in [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3])
        or (self.fsdp is not None)
        or self.is_fsdp_enabled
    ):
        self.place_model_on_device = False

    default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
    self.data_collator = data_collator if data_collator is not None else default_collator
    self.train_dataset = train_dataset
    self.eval_dataset = eval_dataset
    self.tokenizer = tokenizer

    # Quantized models doesn't support `.to` operation.
    if self.place_model_on_device and not getattr(model, "is_quantized", False):
        self._move_model_to_device(model, args.device)

    # Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs
    if self.is_model_parallel:
        self.args._n_gpu = 1

    # later use `self.model is self.model_wrapped` to check if it's wrapped or not
    self.model_wrapped = model
    self.model = model

    self.compute_metrics = compute_metrics
    self.preprocess_logits_for_metrics = preprocess_logits_for_metrics
    self.optimizer, self.lr_scheduler = optimizers
    if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None):
        raise RuntimeError(
            "Passing a `model_init` is incompatible with providing the `optimizers` argument. "
            "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
        )
    if is_torch_tpu_available() and self.optimizer is not None:
        # Compare the device of the first model parameter with the device of the first optimizer parameter.
        for param in self.model.parameters():
            model_device = param.device
            break
        for param_group in self.optimizer.param_groups:
            if len(param_group["params"]) > 0:
                optimizer_device = param_group["params"][0].device
                break
        if model_device != optimizer_device:
            raise ValueError(
                "The model and the optimizer parameters are not on the same device, which probably means you"
                " created an optimizer around your model **before** putting on the device and passing it to the"
                " `Trainer`. Make sure the lines `import torch_xla.core.xla_model as xm` and"
                " `model.to(xm.xla_device())` is performed before the optimizer creation in your script."
            )
    if ((self.sharded_ddp is not None) or self.is_deepspeed_enabled or (self.fsdp is not None)) and (
        self.optimizer is not None or self.lr_scheduler is not None
    ):
        raise RuntimeError(
            "Passing `optimizers` is not allowed if Fairscale, Deepspeed or PyTorch FSDP is enabled."
            "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
        )
    default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
    callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
    self.callback_handler = CallbackHandler(
        callbacks, self.model, self.tokenizer, self.optimizer, self.lr_scheduler
    )
    self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)

    # Will be set to True by `self._setup_loggers()` on first call to `self.log()`.
    self._loggers_initialized = False

    # Create clone of distant repo and output directory if needed
    if self.args.push_to_hub:
        self.init_git_repo(at_init=True)
        # In case of pull, we need to make sure every process has the latest.
        if is_torch_tpu_available():
            xm.rendezvous("init git repo")
        elif args.parallel_mode == ParallelMode.DISTRIBUTED:
            dist.barrier()

    if self.args.should_save:
        os.makedirs(self.args.output_dir, exist_ok=True)

    if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)):
        raise ValueError("The `data_collator` should be a simple callable (function, class with `__call__`).")

    if args.max_steps > 0:
        logger.info("max_steps is given, it will override any value given in num_train_epochs")

    if train_dataset is not None and not has_length(train_dataset) and args.max_steps <= 0:
        raise ValueError(
            "The train_dataset does not implement __len__, max_steps has to be specified. "
            "The number of steps needs to be known in advance for the learning rate scheduler."
        )

    if (
        train_dataset is not None
        and isinstance(train_dataset, torch.utils.data.IterableDataset)
        and args.group_by_length
    ):
        raise ValueError("the `--group_by_length` option is only available for `Dataset`, not `IterableDataset")

    self._signature_columns = None

    # Mixed precision setup
    self.use_apex = False
    self.use_cuda_amp = False
    self.use_cpu_amp = False

    # Mixed precision setup for SageMaker Model Parallel
    if is_sagemaker_mp_enabled():
        # BF16 + model parallelism in SageMaker: currently not supported, raise an error
        if args.bf16:
            raise ValueError("SageMaker Model Parallelism does not support BF16 yet. Please use FP16 instead ")

        if IS_SAGEMAKER_MP_POST_1_10:
            # When there's mismatch between SMP config and trainer argument, use SMP config as truth
            if args.fp16 != smp.state.cfg.fp16:
                logger.warning(
                    f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16},"
                    f"but FP16 provided in trainer argument is {args.fp16},"
                    f"setting to {smp.state.cfg.fp16}"
                )
                args.fp16 = smp.state.cfg.fp16
        else:
            # smp < 1.10 does not support fp16 in trainer.
            if hasattr(smp.state.cfg, "fp16"):
                logger.warning(
                    f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, "
                    "but SageMaker Model Parallelism < 1.10 does not support FP16 in trainer."
                )

    # The "auto" backend is only resolved here when sharded DDP is active.
    if (args.fp16 or args.bf16) and self.sharded_ddp is not None:
        if args.half_precision_backend == "auto":
            if args.device == torch.device("cpu"):
                if args.fp16:
                    raise ValueError("Tried to use `fp16` but it is not supported on cpu")
                else:
                    args.half_precision_backend = "cpu_amp"
            else:
                args.half_precision_backend = "cuda_amp"

        logger.info(f"Using {args.half_precision_backend} half precision backend")

    self.do_grad_scaling = False
    if (args.fp16 or args.bf16) and not (self.is_deepspeed_enabled or is_sagemaker_mp_enabled()):
        # deepspeed and SageMaker Model Parallel manage their own half precision
        if self.sharded_ddp is not None:
            if args.half_precision_backend == "cuda_amp":
                self.use_cuda_amp = True
                self.amp_dtype = torch.float16 if args.fp16 else torch.bfloat16
                # bf16 does not need grad scaling
                self.do_grad_scaling = self.amp_dtype == torch.float16
                if self.do_grad_scaling:
                    if self.sharded_ddp is not None:
                        self.scaler = ShardedGradScaler()
                    elif self.fsdp is not None:
                        from torch.distributed.fsdp.sharded_grad_scaler import (
                            ShardedGradScaler as FSDPShardedGradScaler,
                        )

                        self.scaler = FSDPShardedGradScaler()
                    elif is_torch_tpu_available():
                        from torch_xla.amp import GradScaler

                        self.scaler = GradScaler()
                    else:
                        self.scaler = torch.cuda.amp.GradScaler()
            elif args.half_precision_backend == "cpu_amp":
                self.use_cpu_amp = True
                self.amp_dtype = torch.bfloat16
        elif args.half_precision_backend == "apex":
            if not is_apex_available():
                raise ImportError(
                    "Using FP16 with APEX but APEX is not installed, please refer to"
                    " https://www.github.com/nvidia/apex."
                )
            self.use_apex = True

    # FP16 + model parallelism in SageMaker: gradient clipping does not work for now so we raise a helpful error.
    if (
        is_sagemaker_mp_enabled()
        and self.use_cuda_amp
        and args.max_grad_norm is not None
        and args.max_grad_norm > 0
    ):
        raise ValueError(
            "SageMaker Model Parallelism in mixed precision mode does not support gradient clipping yet. Pass "
            "along 'max_grad_norm': 0 in your hyperparameters."
        )

    # Label smoothing
    if self.args.label_smoothing_factor != 0:
        self.label_smoother = LabelSmoother(epsilon=self.args.label_smoothing_factor)
    else:
        self.label_smoother = None

    self.state = TrainerState(
        is_local_process_zero=self.is_local_process_zero(),
        is_world_process_zero=self.is_world_process_zero(),
    )

    self.control = TrainerControl()
    # Internal variable to count flos in each process, will be accumulated in `self.state.total_flos` then
    # returned to 0 every time flos need to be logged
    self.current_flos = 0
    self.hp_search_backend = None
    self.use_tune_checkpoints = False
    default_label_names = find_labels(self.model.__class__)
    self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
    self.can_return_loss = can_return_loss(self.model.__class__)
    self.control = self.callback_handler.on_init_end(self.args, self.state, self.control)

    # Internal variables to help with automatic batch size reduction
    self._train_batch_size = args.train_batch_size
    self._created_lr_scheduler = False

    # very last
    self._memory_tracker.stop_and_update_metrics()

    # torch.compile
    if args.torch_compile and not is_torch_compile_available():
        raise RuntimeError("Using torch.compile requires PyTorch 2.0 or higher.")
def add_callback(self, callback):
    """
    Register a callback with this trainer's callback handler.

    Args:
        callback (`type` or [`~transformer.TrainerCallback`]):
            A [`~transformer.TrainerCallback`] class or instance. It is forwarded unchanged to
            `self.callback_handler.add_callback`; when a class is given, a member of that class is instantiated.
    """
    handler = self.callback_handler
    handler.add_callback(callback)


def pop_callback(self, callback):
    """
    Detach a callback from the handler and hand it back to the caller.

    Args:
        callback (`type` or [`~transformer.TrainerCallback`]):
            A [`~transformer.TrainerCallback`] class or instance. When a class is given, the first registered
            member of that class is popped.

    Returns:
        [`~transformer.TrainerCallback`]: The removed callback, or `None` when nothing matched (no error is
        raised in that case).
    """
    handler = self.callback_handler
    return handler.pop_callback(callback)


def remove_callback(self, callback):
    """
    Detach a callback from the handler without returning it.

    Args:
        callback (`type` or [`~transformer.TrainerCallback`]):
            A [`~transformer.TrainerCallback`] class or instance. When a class is given, the first registered
            member of that class is removed.
    """
    handler = self.callback_handler
    handler.remove_callback(callback)


def _move_model_to_device(self, model, device):
    """Move `model` to `device`, re-tying weights afterwards when running in TPU parallel mode."""
    moved = model.to(device)
    running_on_tpu = self.args.parallel_mode == ParallelMode.TPU
    # Moving a model to an XLA device disconnects the tied weights, so they have to be tied again.
    if running_on_tpu and hasattr(moved, "tie_weights"):
        moved.tie_weights()


def _set_signature_columns_if_needed(self):
    """
    Fill `self._signature_columns` once; later calls are no-ops.

    The list starts with the parameter names of `self.model.forward`, followed by "label", "label_ids" and
    every entry of `self.label_names` (deduplicated among themselves, in set order).
    """
    if self._signature_columns is not None:
        return

    # Inspect model forward signature to keep only the arguments it accepts.
    forward_parameters = inspect.signature(self.model.forward).parameters
    self._signature_columns = list(forward_parameters.keys())
    # Labels may be named label or label_ids, the default data collator handles that.
    label_like = set(["label", "label_ids"] + self.label_names)
    self._signature_columns += list(label_like)
def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
    """
    Drop the dataset columns that `self.model.forward` does not accept.

    Returns `dataset` untouched when `self.args.remove_unused_columns` is false. Otherwise the ignored columns
    are logged and either hidden through `set_format` (datasets < 1.4.0, same object returned) or removed with
    `remove_columns` (newer versions).
    """
    if not self.args.remove_unused_columns:
        return dataset

    self._set_signature_columns_if_needed()
    accepted_columns = self._signature_columns

    ignored_columns = list(set(dataset.column_names) - set(accepted_columns))
    if ignored_columns:
        dset_description = f"in the {description} set" if description is not None else ""
        logger.info(
            f"The following columns {dset_description} don't have a corresponding argument in "
            f"`{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}."
            f" If {', '.join(ignored_columns)} are not expected by `{self.model.__class__.__name__}.forward`, "
            " you can safely ignore this message."
        )

    # Only the legacy `set_format` path below consumes this list.
    kept_columns = [name for name in accepted_columns if name in dataset.column_names]

    installed_version = version.parse(datasets.__version__)
    if installed_version >= version.parse("1.4.0"):
        return dataset.remove_columns(ignored_columns)

    dataset.set_format(
        type=dataset.format["type"], columns=kept_columns, format_kwargs=dataset.format["format_kwargs"]
    )
    return dataset


def _get_collator_with_removed_columns(
    self, data_collator: Callable, description: Optional[str] = None
) -> Callable:
    """
    Return `data_collator` wrapped in a `RemoveColumnsCollator`, or unchanged when
    `self.args.remove_unused_columns` is false.
    """
    if not self.args.remove_unused_columns:
        return data_collator

    self._set_signature_columns_if_needed()
    return RemoveColumnsCollator(
        data_collator=data_collator,
        signature_columns=self._signature_columns,
        logger=logger,
        description=description,
        model_name=self.model.__class__.__name__,
    )


def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
    """
    Pick the sampler for the training set.

    Returns `None` when there is no training dataset or it has no length, a `LengthGroupedSampler` when
    `self.args.group_by_length` is set, and a `RandomSampler` otherwise.
    """
    train_dataset = self.train_dataset
    if train_dataset is None or not has_length(train_dataset):
        return None

    if not self.args.group_by_length:
        return RandomSampler(train_dataset)

    # Pre-computed lengths are only read from a `datasets.Dataset` that carries the configured column.
    lengths = None
    if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
        length_column = self.args.length_column_name
        if length_column in train_dataset.column_names:
            lengths = train_dataset[length_column]

    model_input_name = None
    if self.tokenizer is not None:
        model_input_name = self.tokenizer.model_input_names[0]

    return LengthGroupedSampler(
        self.args.train_batch_size * self.args.gradient_accumulation_steps,
        dataset=train_dataset,
        lengths=lengths,
        model_input_name=model_input_name,
    )
def get_train_dataloader(self) -> DataLoader:
    """
    Returns the training [`~torch.utils.data.DataLoader`].

    Will use no sampler if `train_dataset` is an `IterableDataset`; otherwise the sampler comes from
    `self._get_train_sampler()`.

    Unlike the evaluation and test loaders, this loader does not use `self.data_collator`: batches are collated
    with `DataCollatorForSupervisedDatasetEmpty`, so each batch is the raw list of dataset items.

    Raises:
        ValueError: if `self.train_dataset` is `None`.
    """
    if self.train_dataset is None:
        raise ValueError("Trainer: training requires a train_dataset.")

    train_dataset = self.train_dataset
    # Deliberately bypasses `self.data_collator` (see docstring).
    data_collator = DataCollatorForSupervisedDatasetEmpty(tokenizer=self.tokenizer)
    if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
        train_dataset = self._remove_unused_columns(train_dataset, description="training")
    else:
        data_collator = self._get_collator_with_removed_columns(data_collator, description="training")

    dataloader_params = {
        "batch_size": self._train_batch_size,
        "collate_fn": data_collator,
        "num_workers": self.args.dataloader_num_workers,
        "pin_memory": self.args.dataloader_pin_memory,
    }

    # Sampler-related options are not valid for iterable datasets.
    if not isinstance(train_dataset, torch.utils.data.IterableDataset):
        dataloader_params["sampler"] = self._get_train_sampler()
        dataloader_params["drop_last"] = self.args.dataloader_drop_last
        dataloader_params["worker_init_fn"] = seed_worker

    return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))


def get_train_dataloaderd2(self) -> DataLoader:
    """
    Build the joint training loader through `build_train_dataloader`.

    The regular loader from `get_train_dataloader()` is passed along as `llava_cap_loader`, together with
    `self.cfg` and the first three entries of `self.data_loader_args`, used as `tokenizer`, `data_args` and
    `preprocess` respectively.

    Raises:
        ValueError: if the trainer was built without `data_loader_args`. This is checked first so the failure
            is reported before any data loader is constructed.
    """
    loader_args = self.data_loader_args
    if loader_args is None:
        raise ValueError(
            "get_train_dataloaderd2 requires `data_loader_args` (tokenizer, data_args, preprocess); "
            "pass it when constructing the trainer."
        )
    tokenizer, data_args, preprocess = loader_args[0], loader_args[1], loader_args[2]

    llava_cap_loader = self.get_train_dataloader()
    return build_train_dataloader(
        self.cfg,
        tokenizer=tokenizer,
        data_args=data_args,
        preprocess=preprocess,
        llava_cap_loader=llava_cap_loader,
    )
llava_cap_loader=self.get_train_dataloader() return build_train_dataloader(self.cfg,tokenizer=self.data_loader_args[0],data_args=self.data_loader_args[1],preprocess=self.data_loader_args[2],llava_cap_loader=llava_cap_loader) def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]: # Deprecated code if self.args.use_legacy_prediction_loop: if is_torch_tpu_available(): return SequentialDistributedSampler( eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal() ) elif is_sagemaker_mp_enabled(): return SequentialDistributedSampler( eval_dataset, num_replicas=smp.dp_size(), rank=smp.dp_rank(), batch_size=self.args.per_device_eval_batch_size, ) else: return SequentialSampler(eval_dataset) if self.args.world_size <= 1: return SequentialSampler(eval_dataset) else: return None def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader: """ Returns the evaluation [`~torch.utils.data.DataLoader`]. Subclass and override this method if you want to inject some custom behavior. Args: eval_dataset (`torch.utils.data.Dataset`, *optional*): If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted by the `model.forward()` method are automatically removed. It must implement `__len__`. 
""" if eval_dataset is None and self.eval_dataset is None: raise ValueError("Trainer: evaluation requires an eval_dataset.") eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset data_collator = self.data_collator if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset): eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation") else: data_collator = self._get_collator_with_removed_columns(data_collator, description="evaluation") dataloader_params = { "batch_size": self.args.eval_batch_size, "collate_fn": data_collator, "num_workers": self.args.dataloader_num_workers, "pin_memory": self.args.dataloader_pin_memory, } if not isinstance(eval_dataset, torch.utils.data.IterableDataset): dataloader_params["sampler"] = self._get_eval_sampler(eval_dataset) dataloader_params["drop_last"] = self.args.dataloader_drop_last return self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params)) def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader: """ Returns the test [`~torch.utils.data.DataLoader`]. Subclass and override this method if you want to inject some custom behavior. Args: test_dataset (`torch.utils.data.Dataset`, *optional*): The test dataset to use. If it is a [`~datasets.Dataset`], columns not accepted by the `model.forward()` method are automatically removed. It must implement `__len__`. 
""" data_collator = self.data_collator if is_datasets_available() and isinstance(test_dataset, datasets.Dataset): test_dataset = self._remove_unused_columns(test_dataset, description="test") else: data_collator = self._get_collator_with_removed_columns(data_collator, description="test") dataloader_params = { "batch_size": self.args.eval_batch_size, "collate_fn": data_collator, "num_workers": self.args.dataloader_num_workers, "pin_memory": self.args.dataloader_pin_memory, } if not isinstance(test_dataset, torch.utils.data.IterableDataset): dataloader_params["sampler"] = self._get_eval_sampler(test_dataset) dataloader_params["drop_last"] = self.args.dataloader_drop_last # We use the same batch_size as for eval. return self.accelerator.prepare(DataLoader(test_dataset, **dataloader_params)) def create_optimizer_and_scheduler(self, num_training_steps: int): """ Setup the optimizer and the learning rate scheduler. We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the Trainer's init through `optimizers`, or subclass and override this method (or `create_optimizer` and/or `create_scheduler`) in a subclass. """ self.create_optimizer() if IS_SAGEMAKER_MP_POST_1_10 and smp.state.cfg.fp16: # If smp >= 1.10 and fp16 is enabled, we unwrap the optimizer optimizer = self.optimizer.optimizer else: optimizer = self.optimizer self.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer) def create_optimizer(self): """ Setup the optimizer. We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the Trainer's init through `optimizers`, or subclass and override this method in a subclass. 
""" opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model if self.optimizer is None: decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS) decay_parameters = [name for name in decay_parameters if "bias" not in name] # optimizer_grouped_parameters = [ # { # "params": [ # p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad) # ], # "weight_decay": self.args.weight_decay, # }, # { # "params": [ # p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad) # ], # "weight_decay": 0.0, # }, # ] optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args) def match_name_keywords(n, name_keywords): out = False for b in name_keywords: if b in n: out = True break return out lr_backbone_names=['backbone'] lr_linear_proj_names=['reference_points', 'sampling_offsets'] seg_model_names=['seg_model'] optimizer_grouped_parameters = [ { "params": [p for n, p in opt_model.named_parameters() if not (match_name_keywords(n, lr_backbone_names) and match_name_keywords(n, seg_model_names)) and not (match_name_keywords(n, lr_linear_proj_names) and match_name_keywords(n, seg_model_names)) and p.requires_grad], "lr": optimizer_kwargs['lr'], }, { "params": [p for n, p in opt_model.named_parameters() if match_name_keywords(n, lr_backbone_names) and match_name_keywords(n, seg_model_names) and p.requires_grad], "lr": optimizer_kwargs['lr']*0.1, }, { "params": [p for n, p in opt_model.named_parameters() if match_name_keywords(n, lr_linear_proj_names) and match_name_keywords(n, seg_model_names) and p.requires_grad], "lr": optimizer_kwargs['lr']*0.1, }, ] if not getattr(self.args, 'tune_mm_mlp_adapter', False): optimizer_grouped_parameters[0] = { "params": [p for n, p in opt_model.named_parameters() if not (match_name_keywords(n, lr_backbone_names) and match_name_keywords(n, seg_model_names)) and not (match_name_keywords(n, lr_linear_proj_names) and 
match_name_keywords(n, seg_model_names)) and match_name_keywords(n,seg_model_names) and p.requires_grad], "lr": optimizer_kwargs['lr'], } llm_dict= { "params": [p for n, p in opt_model.named_parameters() if n.startswith('model.') and p.requires_grad], "lr": 2e-5, } optimizer_grouped_parameters.append(llm_dict) if getattr(self.args, 'train_interactive', False): interactive_model_names=['interactive_model'] optimizer_grouped_parameters_inter = [ { "params": [p for n, p in opt_model.named_parameters() if match_name_keywords(n, interactive_model_names) and not (match_name_keywords(n, lr_backbone_names) and match_name_keywords(n, interactive_model_names)) and not (match_name_keywords(n, lr_linear_proj_names) and match_name_keywords(n, interactive_model_names)) and p.requires_grad], "lr": optimizer_kwargs['lr'], }, { "params": [p for n, p in opt_model.named_parameters() if match_name_keywords(n, lr_backbone_names) and match_name_keywords(n, interactive_model_names) and p.requires_grad], "lr": optimizer_kwargs['lr'] * 0.1, }, { "params": [p for n, p in opt_model.named_parameters() if match_name_keywords(n, lr_linear_proj_names) and match_name_keywords(n, interactive_model_names) and p.requires_grad], "lr": optimizer_kwargs['lr'] * 0.1, }, ] optimizer_grouped_parameters.extend(optimizer_grouped_parameters_inter) if self.sharded_ddp == ShardedDDPOption.SIMPLE: self.optimizer = OSS( params=optimizer_grouped_parameters, optim=optimizer_cls, **optimizer_kwargs, ) else: self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) if optimizer_cls.__name__ == "Adam8bit": import bitsandbytes manager = bitsandbytes.optim.GlobalOptimManager.get_instance() skipped = 0 for module in opt_model.modules(): if isinstance(module, nn.Embedding): skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values()) logger.info(f"skipped {module}: {skipped/2**20}M params") manager.register_module_override(module, "weight", {"optim_bits": 32}) 
@staticmethod
def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]:
    """
    Returns the optimizer class and optimizer parameters based on the training arguments.

    Args:
        args (`transformers.training_args.TrainingArguments`):
            The training arguments for the training session.

    Returns:
        A `(optimizer_cls, optimizer_kwargs)` tuple. `optimizer_kwargs` always contains `"lr"` (taken from
        `args.learning_rate`) plus whatever extra keyword arguments the selected optimizer family needs.

    Raises:
        ValueError: If the package backing the requested optimizer cannot be imported, or if `args.optim` is not
            one of the supported optimizer names.
    """
    # parse args.optim_args: a comma separated list of `key=value` pairs (spaces are ignored)
    optim_args = {}
    if args.optim_args:
        for mapping in args.optim_args.replace(" ", "").split(","):
            key, value = mapping.split("=")
            optim_args[key] = value

    optimizer_kwargs = {"lr": args.learning_rate}

    # Shared by every Adam-style optimizer below.
    adam_kwargs = {
        "betas": (args.adam_beta1, args.adam_beta2),
        "eps": args.adam_epsilon,
    }
    if args.optim == OptimizerNames.ADAFACTOR:
        optimizer_cls = Adafactor
        optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
    elif args.optim == OptimizerNames.ADAMW_HF:
        from .optimization import AdamW

        optimizer_cls = AdamW
        optimizer_kwargs.update(adam_kwargs)
    elif args.optim in [OptimizerNames.ADAMW_TORCH, OptimizerNames.ADAMW_TORCH_FUSED]:
        from torch.optim import AdamW

        optimizer_cls = AdamW
        optimizer_kwargs.update(adam_kwargs)
        if args.optim == OptimizerNames.ADAMW_TORCH_FUSED:
            optimizer_kwargs.update({"fused": True})
    elif args.optim == OptimizerNames.ADAMW_TORCH_XLA:
        try:
            from torch_xla.amp.syncfree import AdamW

            optimizer_cls = AdamW
            optimizer_kwargs.update(adam_kwargs)
        except ImportError as err:
            # Chain the cause so the underlying import failure stays visible in the traceback.
            raise ValueError("Trainer failed to import syncfree AdamW from torch_xla.") from err
    elif args.optim == OptimizerNames.ADAMW_APEX_FUSED:
        try:
            from apex.optimizers import FusedAdam

            optimizer_cls = FusedAdam
            optimizer_kwargs.update(adam_kwargs)
        except ImportError as err:
            raise ValueError("Trainer tried to instantiate apex FusedAdam but apex is not installed!") from err
    elif args.optim in [
        OptimizerNames.ADAMW_BNB,
        OptimizerNames.ADAMW_8BIT,
        OptimizerNames.PAGED_ADAMW,
        OptimizerNames.PAGED_ADAMW_8BIT,
        OptimizerNames.LION,
        OptimizerNames.LION_8BIT,
        OptimizerNames.PAGED_LION,
        OptimizerNames.PAGED_LION_8BIT,
    ]:
        # NOTE: this branch handles every bitsandbytes optimizer, including ADAMW_BNB. A former dedicated
        # `ADAMW_BNB` branch that followed it could never be reached and has been dropped.
        try:
            from bitsandbytes.optim import AdamW, Lion

            is_paged = False
            optim_bits = 32
            optimizer_cls = None
            additional_optim_kwargs = adam_kwargs
            # The flavour (paged / 8bit / adam / lion) is read off the optimizer name itself.
            if "paged" in args.optim:
                is_paged = True
            if "8bit" in args.optim:
                optim_bits = 8
            if "adam" in args.optim:
                optimizer_cls = AdamW
            elif "lion" in args.optim:
                optimizer_cls = Lion
                # Lion only receives the betas, not `eps`.
                additional_optim_kwargs = {"betas": (args.adam_beta1, args.adam_beta2)}

            bnb_kwargs = {"is_paged": is_paged, "optim_bits": optim_bits}
            optimizer_kwargs.update(additional_optim_kwargs)
            optimizer_kwargs.update(bnb_kwargs)
        except ImportError as err:
            raise ValueError("Trainer tried to instantiate bnb optimizer but bnb is not installed!") from err
    elif args.optim == OptimizerNames.ADAMW_ANYPRECISION:
        try:
            from torchdistx.optimizers import AnyPrecisionAdamW

            optimizer_cls = AnyPrecisionAdamW
            optimizer_kwargs.update(adam_kwargs)

            # TODO Change dtypes back to M=FP32, Var = BF16, Kahan = False once they can be cast together in torchdistx.
            optimizer_kwargs.update(
                {
                    "use_kahan_summation": strtobool(optim_args.get("use_kahan_summation", "False")),
                    "momentum_dtype": getattr(torch, optim_args.get("momentum_dtype", "float32")),
                    "variance_dtype": getattr(torch, optim_args.get("variance_dtype", "float32")),
                    "compensation_buffer_dtype": getattr(
                        torch, optim_args.get("compensation_buffer_dtype", "bfloat16")
                    ),
                }
            )
        except ImportError as err:
            raise ValueError("Please install https://github.com/pytorch/torchdistx") from err
    elif args.optim == OptimizerNames.SGD:
        optimizer_cls = torch.optim.SGD
    elif args.optim == OptimizerNames.ADAGRAD:
        optimizer_cls = torch.optim.Adagrad
    else:
        raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
    return optimizer_cls, optimizer_kwargs
def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
    """
    Copy the hyperparameters of `trial` onto `self.args` before a hyperparameter-search run.

    The trial is always remembered in `self._trial`. Nothing else happens when no search backend is configured
    or `trial` is `None`.

    Args:
        trial (`optuna.Trial` or `Dict[str, Any]`):
            The trial object or hyperparameter dictionary handed over by the search backend.

    Raises:
        ValueError: If DeepSpeed is enabled but `args.deepspeed` is not set.
    """
    self._trial = trial

    backend = self.hp_search_backend
    if backend is None or trial is None:
        return

    # Each backend exposes the sampled values differently.
    if backend == HPSearchBackend.OPTUNA:
        params = self.hp_space(trial)
    elif backend == HPSearchBackend.RAY:
        params = trial
        params.pop("wandb", None)
    elif backend == HPSearchBackend.SIGOPT:
        params = {}
        for name, raw in trial.assignments.items():
            # String assignments are converted to int.
            params[name] = int(raw) if isinstance(raw, str) else raw
    elif backend == HPSearchBackend.WANDB:
        params = trial

    for key, value in params.items():
        if hasattr(self.args, key):
            current = getattr(self.args, key, None)
            if current is not None:
                # Cast the sampled value to the type of the value it replaces.
                value = type(current)(value)
            setattr(self.args, key, value)
        else:
            logger.warning(
                f"Trying to set {key} in the hyperparameter search but there is no corresponding field in"
                " `TrainingArguments`."
            )

    if backend == HPSearchBackend.OPTUNA:
        logger.info(f"Trial: {trial.params}")
    if backend == HPSearchBackend.SIGOPT:
        logger.info(f"SigOpt Assignments: {trial.assignments}")
    if backend == HPSearchBackend.WANDB:
        logger.info(f"W&B Sweep parameters: {trial}")

    if self.is_deepspeed_enabled:
        if self.args.deepspeed is None:
            raise ValueError("For sweeps with deepspeed, `args.deepspeed` must be set")
        # The DeepSpeed config has to be rebuilt so that it reflects the updated training arguments.
        from accelerate.utils import DeepSpeedPlugin
        from transformers.deepspeed import HfTrainerDeepSpeedConfig

        ds_config = HfTrainerDeepSpeedConfig(self.args.deepspeed)
        self.args.hf_deepspeed_config = ds_config
        self.args.hf_deepspeed_config.trainer_config_process(self.args)
        self.args.deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=self.args.hf_deepspeed_config)

    # NOTE(review): assumed to run for every backend, not only under DeepSpeed - confirm nesting.
    self.create_accelerator_and_postprocess()
def call_model_init(self, trial=None):
    """
    Build a fresh model through `self.model_init`.

    `self.model_init` is called without arguments when it takes none, and with `trial` when it takes one.

    Args:
        trial: Value forwarded to `self.model_init` when that callable accepts exactly one argument.

    Returns:
        The model produced by `self.model_init`.

    Raises:
        RuntimeError: If `self.model_init` takes more than one argument or returns `None`.
    """
    n_params = number_of_arguments(self.model_init)
    if n_params not in (0, 1):
        raise RuntimeError("model_init should have 0 or 1 argument.")

    model = self.model_init() if n_params == 0 else self.model_init(trial)
    if model is None:
        raise RuntimeError("model_init should not return None.")
    return model
def ipex_optimize_model(self, model, training=False, dtype=torch.float32):
    """
    Run `model` through `ipex.optimize` (Intel Extension for PyTorch) at level "O1".

    Args:
        model: The model to optimize.
        training (`bool`, defaults to `False`): Whether the model is optimized for training (together with
            `self.optimizer`) or for evaluation.
        dtype: Data type passed to `ipex.optimize`. For evaluation outside of training it is replaced by
            `torch.bfloat16` when `self.args.bf16_full_eval` is set.

    Returns:
        The optimized model.

    Raises:
        ImportError: If `is_ipex_available()` reports that IPEX cannot be used.
    """
    if not is_ipex_available():
        raise ImportError(
            "Using IPEX but IPEX is not installed or IPEX's version does not match current PyTorch, please refer"
            " to https://github.com/intel/intel-extension-for-pytorch."
        )

    import intel_extension_for_pytorch as ipex

    if training:
        if not model.training:
            model.train()
        # In training mode the optimizer is optimized alongside the model and stored back on the trainer.
        model, self.optimizer = ipex.optimize(
            model, dtype=dtype, optimizer=self.optimizer, inplace=True, level="O1"
        )
        return model

    model.eval()
    if not self.is_in_train and self.args.bf16_full_eval:
        dtype = torch.bfloat16
    # conv_bn_folding is disabled as it fails in symbolic tracing, resulting in ipex warnings
    return ipex.optimize(model, dtype=dtype, level="O1", conv_bn_folding=False, inplace=not self.is_in_train)
if isinstance(self.model_wrapped, smp.model.DistributedModel): return self.model_wrapped return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps) # train/eval could be run multiple-times - if already wrapped, don't re-wrap it again if unwrap_model(model) is not model: return model # Mixed precision training with apex (torch < 1.6) if self.use_apex and training: model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level) # Multi-gpu training (should be after apex fp16 initialization) / 8bit models does not support DDP if self.args.n_gpu > 1 and not getattr(model, "is_loaded_in_8bit", False): model = nn.DataParallel(model) if self.args.jit_mode_eval: start_time = time.time() model = self.torch_jit_model_eval(model, dataloader, training) self.jit_compilation_time = round(time.time() - start_time, 4) # Note: in torch.distributed mode, there's no point in wrapping the model # inside a DistributedDataParallel as we'll be under `no_grad` anyways. if not training: return model # Distributed training (should be after apex fp16 initialization) if self.sharded_ddp is not None: # Sharded DDP! if self.sharded_ddp == ShardedDDPOption.SIMPLE: model = ShardedDDP(model, self.optimizer) else: mixed_precision = self.args.fp16 or self.args.bf16 cpu_offload = ShardedDDPOption.OFFLOAD in self.args.sharded_ddp zero_3 = self.sharded_ddp == ShardedDDPOption.ZERO_DP_3 # XXX: Breaking the self.model convention but I see no way around it for now. 
if ShardedDDPOption.AUTO_WRAP in self.args.sharded_ddp: model = auto_wrap(model) self.model = model = FullyShardedDDP( model, mixed_precision=mixed_precision, reshard_after_forward=zero_3, cpu_offload=cpu_offload, ).to(self.args.device) # Distributed training using PyTorch FSDP elif self.fsdp is not None and self.args.fsdp_config["xla"]: try: from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP from torch_xla.distributed.fsdp import checkpoint_module from torch_xla.distributed.fsdp.wrap import ( size_based_auto_wrap_policy, transformer_auto_wrap_policy, ) except ImportError: raise ImportError("Missing XLA FSDP related module; please make sure to use torch-xla >= 2.0.") auto_wrap_policy = None auto_wrapper_callable = None if self.args.fsdp_config["fsdp_min_num_params"] > 0: auto_wrap_policy = functools.partial( size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"] ) elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None: transformer_cls_to_wrap = set() for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]: transformer_cls = get_module_class_from_name(model, layer_class) if transformer_cls is None: raise Exception("Could not find the transformer layer class to wrap in the model.") else: transformer_cls_to_wrap.add(transformer_cls) auto_wrap_policy = functools.partial( transformer_auto_wrap_policy, # Transformer layer class to wrap transformer_layer_cls=transformer_cls_to_wrap, ) fsdp_kwargs = self.args.xla_fsdp_config if self.args.fsdp_config["xla_fsdp_grad_ckpt"]: # Apply gradient checkpointing to auto-wrapped sub-modules if specified def auto_wrapper_callable(m, *args, **kwargs): return FSDP(checkpoint_module(m), *args, **kwargs) # Wrap the base model with an outer FSDP wrapper self.model = model = FSDP( model, auto_wrap_policy=auto_wrap_policy, auto_wrapper_callable=auto_wrapper_callable, **fsdp_kwargs, ) # Patch `xm.optimizer_step` should not reduce 
def train(
    self,
    resume_from_checkpoint: Optional[Union[str, bool]] = None,
    trial: Union["optuna.Trial", Dict[str, Any]] = None,
    ignore_keys_for_eval: Optional[List[str]] = None,
    **kwargs,
):
    """
    Main training entry point.

    Args:
        resume_from_checkpoint (`str` or `bool`, *optional*):
            If a `str`, local path to a saved checkpoint as saved by a previous instance of [`Trainer`]. If a
            `bool` and equals `True`, load the last checkpoint in *args.output_dir* as saved by a previous instance
            of [`Trainer`]. If present, training will resume from the model/optimizer/scheduler states loaded here.
        trial (`optuna.Trial` or `Dict[str, Any]`, *optional*):
            The trial run or the hyperparameter dictionary for hyperparameter search.
        ignore_keys_for_eval (`List[str]`, *optional*)
            A list of keys in the output of your model (if it is a dictionary) that should be ignored when
            gathering predictions for evaluation during the training.
        kwargs (`Dict[str, Any]`, *optional*):
            Additional keyword arguments used to hide deprecated arguments

    Returns:
        Whatever the inner training loop (`self._inner_training_loop`, wrapped by `find_executable_batch_size`)
        returns.

    Raises:
        TypeError: If keyword arguments other than the deprecated `model_path` are passed.
        ValueError: If `resume_from_checkpoint` is `True` but no checkpoint is found in `args.output_dir`.
    """
    if resume_from_checkpoint is False:
        resume_from_checkpoint = None

    # memory metrics - must set up as early as possible
    self._memory_tracker.start()

    args = self.args

    self.is_in_train = True

    # do_train is not a reliable argument, as it might not be set and .train() still called, so
    # the following is a workaround:
    if (args.fp16_full_eval or args.bf16_full_eval) and not args.do_train:
        self._move_model_to_device(self.model, args.device)

    if "model_path" in kwargs:
        resume_from_checkpoint = kwargs.pop("model_path")
        warnings.warn(
            "`model_path` is deprecated and will be removed in a future version. Use `resume_from_checkpoint` "
            "instead.",
            FutureWarning,
        )
    if kwargs:
        # Anything left over is not a recognised (deprecated) argument.
        raise TypeError(f"train() received unexpected keyword arguments: {', '.join(kwargs)}.")

    # This might change the seed so needs to run first.
    self._hp_search_setup(trial)
    self._train_batch_size = self.args.train_batch_size

    # Model re-init
    model_reloaded = False
    if self.model_init is not None:
        # Seed must be set before instantiating the model when using model_init.
        if self.args.full_determinism:
            enable_full_determinism(self.args.seed)
        else:
            set_seed(self.args.seed)
        self.model = self.call_model_init(trial)
        model_reloaded = True
        # Reinitializes optimizer and scheduler
        self.optimizer, self.lr_scheduler = None, None

    # Load potential model checkpoint
    if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint:
        resume_from_checkpoint = get_last_checkpoint(args.output_dir)
        if resume_from_checkpoint is None:
            raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})")

    # With SageMaker MP or DeepSpeed enabled the checkpoint is not loaded here.
    if resume_from_checkpoint is not None and not is_sagemaker_mp_enabled() and not self.is_deepspeed_enabled:
        self._load_from_checkpoint(resume_from_checkpoint)

    # If model was re-initialized, put it on the right device and update self.model_wrapped
    if model_reloaded:
        if self.place_model_on_device:
            self._move_model_to_device(self.model, args.device)
        self.model_wrapped = self.model

    inner_training_loop = find_executable_batch_size(
        self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size
    )
    return inner_training_loop(
        args=args,
        resume_from_checkpoint=resume_from_checkpoint,
        trial=trial,
        ignore_keys_for_eval=ignore_keys_for_eval,
    )
len(train_dataloader) num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1) num_examples = self.num_examples(train_dataloader) if args.max_steps > 0: max_steps = args.max_steps num_train_epochs = args.max_steps // num_update_steps_per_epoch + int( args.max_steps % num_update_steps_per_epoch > 0 ) # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's # the best we can do. num_train_samples = args.max_steps * total_train_batch_size else: max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch) num_train_epochs = math.ceil(args.num_train_epochs) num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs elif args.max_steps > 0: # Rely on max_steps when dataloader does not have a working size max_steps = args.max_steps # Setting a very large number of epochs so we go as many times as necessary over the iterator. num_train_epochs = sys.maxsize num_update_steps_per_epoch = max_steps num_examples = total_train_batch_size * args.max_steps num_train_samples = args.max_steps * total_train_batch_size else: raise ValueError( "args.max_steps must be set to a positive value if dataloader does not have a length, was" f" {args.max_steps}" ) # Compute absolute values for logging, eval, and save if given as ratio if args.logging_steps and args.logging_steps < 1: args.logging_steps = math.ceil(max_steps * args.logging_steps) if args.eval_steps and args.eval_steps < 1: args.eval_steps = math.ceil(max_steps * args.eval_steps) if args.save_steps and args.save_steps < 1: args.save_steps = math.ceil(max_steps * args.save_steps) if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug: if self.args.n_gpu > 1: # nn.DataParallel(model) replicates the model, creating new variables and module # references registered here no longer work on other gpus, breaking the module raise ValueError( "Currently --debug underflow_overflow 
is not supported under DP. Please use DDP" " (torch.distributed.launch)." ) else: debug_overflow = DebugUnderflowOverflow(self.model) # noqa delay_optimizer_creation = ( self.sharded_ddp is not None and self.sharded_ddp != ShardedDDPOption.SIMPLE or is_sagemaker_mp_enabled() or self.fsdp is not None ) # We need to reset the scheduler, as its parameters may be different on subsequent calls if self._created_lr_scheduler: self.lr_scheduler = None self._created_lr_scheduler = False if self.is_deepspeed_enabled: self.optimizer, self.lr_scheduler = deepspeed_init(self, num_training_steps=max_steps) if not delay_optimizer_creation: self.create_optimizer_and_scheduler(num_training_steps=max_steps) self.state = TrainerState() self.state.is_hyper_param_search = trial is not None # Activate gradient checkpointing if needed if args.gradient_checkpointing: self.model.gradient_checkpointing_enable() model = self._wrap_model(self.model_wrapped) if is_sagemaker_mp_enabled() and resume_from_checkpoint is not None: self._load_from_checkpoint(resume_from_checkpoint, model) # as the model is wrapped, don't use `accelerator.prepare` # this is for unhandled cases such as # Fairscale Sharded DDP, FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX use_accelerator_prepare = True if model is self.model else False if delay_optimizer_creation: self.create_optimizer_and_scheduler(num_training_steps=max_steps) # prepare using `accelerator` prepare if use_accelerator_prepare: self.model.train() if hasattr(self.lr_scheduler, "step"): if self.use_apex: model = self.accelerator.prepare(self.model) else: model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer) else: # to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config. 
model, self.optimizer, self.lr_scheduler = self.accelerator.prepare( self.model, self.optimizer, self.lr_scheduler ) if self.is_fsdp_enabled: self.model = model # for the rest of this function `model` is the outside model, whether it was wrapped or not if model is not self.model: self.model_wrapped = model # backward compatibility if self.is_deepspeed_enabled: self.deepspeed = self.model_wrapped # deepspeed ckpt loading if resume_from_checkpoint is not None and self.is_deepspeed_enabled: deepspeed_load_checkpoint(self.model_wrapped, resume_from_checkpoint)#, load_optimizer_states=self.args.load_optimizer_states, load_lr_scheduler_states=self.args.load_lr_scheduler_states) # Check if saved optimizer or scheduler states exist self._load_optimizer_and_scheduler(resume_from_checkpoint) # important: at this point: # self.model is the Transformers Model # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), etc. # Train! logger.info("***** Running training *****") logger.info(f" Num examples = {num_examples:,}") logger.info(f" Num Epochs = {num_train_epochs:,}") logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}") if self.args.per_device_train_batch_size != self._train_batch_size: logger.info(f" Training with DataParallel so batch size has been adjusted to: {self._train_batch_size:,}") logger.info(f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_train_batch_size:,}") logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") logger.info(f" Total optimization steps = {max_steps:,}") logger.info(f" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}") self.state.epoch = 0 start_time = time.time() epochs_trained = 0 steps_trained_in_current_epoch = 0 steps_trained_progress_bar = None # Check if continuing training from a checkpoint if resume_from_checkpoint is not None and os.path.isfile( os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME) ): self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)) epochs_trained = self.state.global_step // num_update_steps_per_epoch if not args.ignore_data_skip: steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch) steps_trained_in_current_epoch *= args.gradient_accumulation_steps else: steps_trained_in_current_epoch = 0 logger.info(" Continuing training from checkpoint, will skip to saved global_step") logger.info(f" Continuing training from epoch {epochs_trained}") logger.info(f" Continuing training from global step {self.state.global_step}") if not args.ignore_data_skip: logger.info( f" Will skip the first {epochs_trained} epochs then the first" f" {steps_trained_in_current_epoch} batches in the first epoch." ) # Update the references self.callback_handler.model = self.model self.callback_handler.optimizer = self.optimizer self.callback_handler.lr_scheduler = self.lr_scheduler self.callback_handler.train_dataloader = train_dataloader if self.hp_name is not None and self._trial is not None: # use self._trial because the SigOpt/Optuna hpo only call `_hp_search_setup(trial)` instead of passing trial # parameter to Train when using DDP. 
self.state.trial_name = self.hp_name(self._trial) if trial is not None: assignments = trial.assignments if self.hp_search_backend == HPSearchBackend.SIGOPT else trial self.state.trial_params = hp_params(assignments) else: self.state.trial_params = None # This should be the same if the state has been saved but in case the training arguments changed, it's safer # to set this after the load. self.state.max_steps = max_steps self.state.num_train_epochs = num_train_epochs self.state.is_local_process_zero = self.is_local_process_zero() self.state.is_world_process_zero = self.is_world_process_zero() # tr_loss is a tensor to avoid synchronization of TPUs through .item() tr_loss_ = torch.tensor(0.0).to(args.device) # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses self._total_loss_scalar = 0.0 self._globalstep_last_logged = self.state.global_step model.zero_grad() self.control = self.callback_handler.on_train_begin(args, self.state, self.control) # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point. if not args.ignore_data_skip: for epoch in range(epochs_trained): for _ in train_dataloader: break total_batched_samples = 0 tr_loss = dict() for epoch in range(epochs_trained, num_train_epochs): epoch_iterator = train_dataloader # Reset the past mems state at the beginning of each epoch if necessary. 
if args.past_index >= 0: self._past = None steps_in_epoch = ( len(epoch_iterator) if len_dataloader is not None else args.max_steps * args.gradient_accumulation_steps ) self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control) if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0: self._load_rng_state(resume_from_checkpoint) rng_to_sync = False steps_skipped = 0 # if steps_trained_in_current_epoch > 0: # epoch_iterator = skip_first_batches(epoch_iterator, steps_trained_in_current_epoch) # steps_skipped = steps_trained_in_current_epoch # steps_trained_in_current_epoch = 0 # rng_to_sync = True step = -1 for step, inputs in enumerate(epoch_iterator): total_batched_samples += 1 if rng_to_sync: self._load_rng_state(resume_from_checkpoint) rng_to_sync = False # Skip past any already trained steps if resuming training if steps_trained_in_current_epoch > 0: steps_trained_in_current_epoch =0 if steps_trained_progress_bar is not None: steps_trained_progress_bar.update(steps_trained_in_current_epoch) if steps_trained_in_current_epoch == 0: self._load_rng_state(resume_from_checkpoint) continue elif steps_trained_progress_bar is not None: steps_trained_progress_bar.close() steps_trained_progress_bar = None if step % args.gradient_accumulation_steps == 0: self.control = self.callback_handler.on_step_begin(args, self.state, self.control) with self.accelerator.accumulate(model): tr_loss_step = self.training_step(model, inputs) if len(tr_loss)==0: tr_loss={k:tr_loss_.clone() for k in tr_loss_step.keys()} for k, loss in tr_loss.items(): if ( args.logging_nan_inf_filter and not is_torch_tpu_available() and (torch.isnan(tr_loss_step[k]) or torch.isinf(tr_loss_step[k])) ): # if loss is nan or inf simply add the average of previous logged losses tr_loss[k] += loss / (1 + self.state.global_step - self._globalstep_last_logged) else: tr_loss[k] += tr_loss_step[k] # if ( # args.logging_nan_inf_filter # and not 
is_torch_tpu_available() # and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step)) # ): # # if loss is nan or inf simply add the average of previous logged losses # tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged) # else: # tr_loss += tr_loss_step self.current_flos += float(self.floating_point_ops(inputs)) is_last_step_and_steps_less_than_grad_acc = ( steps_in_epoch <= args.gradient_accumulation_steps and (step + 1) == steps_in_epoch ) if ( total_batched_samples % args.gradient_accumulation_steps == 0 or # last step in epoch but step is always smaller than gradient_accumulation_steps is_last_step_and_steps_less_than_grad_acc ): # the `or` condition of `is_last_step_and_steps_less_than_grad_acc` is not covered # in accelerate. So, explicitly enable sync gradients to True in that case. if is_last_step_and_steps_less_than_grad_acc or ( version.parse(accelerate_version) <= version.parse("0.20.3") ): self.accelerator.gradient_state._set_sync_gradients(True) # Gradient clipping if args.max_grad_norm is not None and args.max_grad_norm > 0: # deepspeed does its own clipping if self.do_grad_scaling: # Reduce gradients first for XLA if is_torch_tpu_available(): gradients = xm._fetch_gradients(self.optimizer) xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size()) # AMP: gradients need unscaling self.scaler.unscale_(self.optimizer) if is_sagemaker_mp_enabled() and args.fp16: self.optimizer.clip_master_grads(args.max_grad_norm) elif hasattr(self.optimizer, "clip_grad_norm"): # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping self.optimizer.clip_grad_norm(args.max_grad_norm) elif hasattr(model, "clip_grad_norm_"): # Some models (like FullyShardedDDP) have a specific way to do gradient clipping model.clip_grad_norm_(args.max_grad_norm) elif self.use_apex: # Revert to normal clipping otherwise, handling Apex or full precision nn.utils.clip_grad_norm_( amp.master_params(self.optimizer), 
args.max_grad_norm, ) else: self.accelerator.clip_grad_norm_( model.parameters(), args.max_grad_norm, ) # Optimizer step optimizer_was_run = True if is_torch_tpu_available(): if self.do_grad_scaling: self.scaler.step(self.optimizer) self.scaler.update() else: # tpu-comment: accelerate wrapped optimizers call xm.optimizer_step self.optimizer.step() elif self.do_grad_scaling: scale_before = self.scaler.get_scale() self.scaler.step(self.optimizer) self.scaler.update() scale_after = self.scaler.get_scale() optimizer_was_run = scale_before <= scale_after else: self.optimizer.step() optimizer_was_run = not self.accelerator.optimizer_step_was_skipped if optimizer_was_run: # Delay optimizer scheduling until metrics are generated if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): self.lr_scheduler.step() model.zero_grad() self.state.global_step += 1 self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch self.control = self.callback_handler.on_step_end(args, self.state, self.control) self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval) else: self.control = self.callback_handler.on_substep_end(args, self.state, self.control) if self.control.should_epoch_stop or self.control.should_training_stop: break if step < 0: logger.warning( "There seems to be not a single sample in your epoch_iterator, stopping training at step" f" {self.state.global_step}! This is expected if you're using an IterableDataset and set" f" num_steps ({max_steps}) higher than the number of available samples." ) self.control.should_training_stop = True self.control = self.callback_handler.on_epoch_end(args, self.state, self.control) self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval) if DebugOption.TPU_METRICS_DEBUG in self.args.debug: if is_torch_tpu_available(): # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) 
xm.master_print(met.metrics_report()) else: logger.warning( "You enabled PyTorch/XLA debug metrics but you don't have a TPU " "configured. Check your training configuration if this is unexpected." ) if self.control.should_training_stop: break if args.past_index and hasattr(self, "_past"): # Clean the state at the end of training delattr(self, "_past") logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n") if args.load_best_model_at_end and self.state.best_model_checkpoint is not None: # Wait for everyone to get here so we are sur the model has been saved by process 0. if is_torch_tpu_available(): xm.rendezvous("load_best_model_at_end") elif args.parallel_mode == ParallelMode.DISTRIBUTED: dist.barrier() elif is_sagemaker_mp_enabled(): smp.barrier() self._load_best_model() # add remaining tr_loss # self._total_loss_scalar += tr_loss.item() self._total_loss_scalar += tr_loss['loss_total'].item() train_loss = self._total_loss_scalar / self.state.global_step metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps) self.store_flos() metrics["total_flos"] = self.state.total_flos metrics["train_loss"] = train_loss self.is_in_train = False self._memory_tracker.stop_and_update_metrics(metrics) self.log(metrics) run_dir = self._get_output_dir(trial) checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir) # Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint and process allowed to save. 
if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1: for checkpoint in checkpoints_sorted: if checkpoint != self.state.best_model_checkpoint: logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit") shutil.rmtree(checkpoint) self.control = self.callback_handler.on_train_end(args, self.state, self.control) return TrainOutput(self.state.global_step, train_loss, metrics) def _get_output_dir(self, trial): if self.hp_search_backend is not None and trial is not None: if self.hp_search_backend == HPSearchBackend.OPTUNA: run_id = trial.number elif self.hp_search_backend == HPSearchBackend.RAY: from ray import tune run_id = tune.get_trial_id() elif self.hp_search_backend == HPSearchBackend.SIGOPT: run_id = trial.id elif self.hp_search_backend == HPSearchBackend.WANDB: import wandb run_id = wandb.run.id run_name = self.hp_name(trial) if self.hp_name is not None else f"run-{run_id}" run_dir = os.path.join(self.args.output_dir, run_name) else: run_dir = self.args.output_dir return run_dir def _load_from_checkpoint(self, resume_from_checkpoint, model=None): if model is None: model = self.model config_file = os.path.join(resume_from_checkpoint, CONFIG_NAME) adapter_weights_file = os.path.join(resume_from_checkpoint, ADAPTER_WEIGHTS_NAME) adapter_safe_weights_file = os.path.join(resume_from_checkpoint, ADAPTER_SAFE_WEIGHTS_NAME) weights_file = os.path.join(resume_from_checkpoint, WEIGHTS_NAME) weights_index_file = os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME) safe_weights_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_NAME) safe_weights_index_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_INDEX_NAME) if not any( os.path.isfile(f) for f in [ weights_file, safe_weights_file, weights_index_file, safe_weights_index_file, adapter_weights_file, adapter_safe_weights_file, ] ): raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}") 
logger.info(f"Loading model from {resume_from_checkpoint}.") if os.path.isfile(config_file): config = PretrainedConfig.from_json_file(config_file) checkpoint_version = config.transformers_version if checkpoint_version is not None and checkpoint_version != __version__: logger.warning( f"You are resuming training from a checkpoint trained with {checkpoint_version} of " f"Transformers but your current version is {__version__}. This is not recommended and could " "yield to errors or unwanted behaviors." ) if os.path.isfile(weights_file) or os.path.isfile(safe_weights_file): # If the model is on the GPU, it still works! if is_sagemaker_mp_enabled(): if os.path.isfile(os.path.join(resume_from_checkpoint, "user_content.pt")): # If the 'user_content.pt' file exists, load with the new smp api. # Checkpoint must have been saved with the new smp api. smp.resume_from_checkpoint( path=resume_from_checkpoint, tag=WEIGHTS_NAME, partial=False, load_optimizer=False ) else: # If the 'user_content.pt' file does NOT exist, load with the old smp api. # Checkpoint must have been saved with the old smp api. if hasattr(self.args, "fp16") and self.args.fp16 is True: logger.warning( "Enabling FP16 and loading from smp < 1.10 checkpoint together is not suppported." ) state_dict = torch.load(weights_file, map_location="cpu") # Required for smp to not auto-translate state_dict from hf to smp (is already smp). state_dict["_smp_is_partial"] = False load_result = model.load_state_dict(state_dict, strict=True) # release memory del state_dict elif self.is_fsdp_enabled: load_fsdp_model(self.accelerator.state.fsdp_plugin, self.accelerator, model, resume_from_checkpoint) else: # We load the model state dict on the CPU to avoid an OOM error. 
if self.args.save_safetensors and os.path.isfile(safe_weights_file): state_dict = safetensors.torch.load_file(safe_weights_file, device="cpu") else: state_dict = torch.load(weights_file, map_location="cpu") # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963 # which takes *args instead of **kwargs load_result = model.load_state_dict(state_dict, False) # release memory del state_dict self._issue_warnings_after_load(load_result) # Load adapters following PR # 24096 elif is_peft_available() and isinstance(model, PeftModel): # If train a model using PEFT & LoRA, assume that adapter have been saved properly. if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"): if os.path.exists(resume_from_checkpoint): model.load_adapter(resume_from_checkpoint, model.active_adapter, is_trainable=True) else: logger.warning( "The intermediate checkpoints of PEFT may not be saved correctly, " f"consider using a custom callback to save {ADAPTER_WEIGHTS_NAME} in corresponding saving folders. 
def _load_best_model(self):
    """
    Reload the weights stored in `self.state.best_model_checkpoint` into the model.

    The loader is picked from what is found in the checkpoint folder and from the active backend
    (DeepSpeed, SageMaker Model Parallel, FSDP, PEFT adapters, plain or sharded state dict). When a loader
    reports missing/unexpected keys, they are forwarded to `_issue_warnings_after_load`. If no weights can be
    located, a warning is logged and the model is left untouched.
    """
    logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
    best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
    best_safe_model_path = os.path.join(self.state.best_model_checkpoint, SAFE_WEIGHTS_NAME)
    best_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_WEIGHTS_NAME)
    best_safe_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_SAFE_WEIGHTS_NAME)

    model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
    if (
        os.path.exists(best_model_path)
        or os.path.exists(best_safe_model_path)
        or os.path.exists(best_adapter_model_path)
        or os.path.exists(best_safe_adapter_model_path)
    ):
        if self.is_deepspeed_enabled:
            deepspeed_load_checkpoint(self.model_wrapped, self.state.best_model_checkpoint)
        else:
            has_been_loaded = True
            # Not every loader below yields a result object (new smp API, FSDP on some accelerate
            # versions, failed PEFT load). Start from None so the shared warning step can skip those
            # paths instead of failing on an unbound local.
            load_result = None
            if is_sagemaker_mp_enabled():
                if os.path.isfile(os.path.join(self.state.best_model_checkpoint, "user_content.pt")):
                    # If the 'user_content.pt' file exists, load with the new smp api.
                    # Checkpoint must have been saved with the new smp api.
                    smp.resume_from_checkpoint(
                        path=self.state.best_model_checkpoint,
                        tag=WEIGHTS_NAME,
                        partial=False,
                        load_optimizer=False,
                    )
                else:
                    # If the 'user_content.pt' file does NOT exist, load with the old smp api.
                    # Checkpoint must have been saved with the old smp api.
                    if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
                        state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
                    else:
                        state_dict = torch.load(best_model_path, map_location="cpu")

                    state_dict["_smp_is_partial"] = False
                    load_result = model.load_state_dict(state_dict, strict=True)
            elif self.is_fsdp_enabled:
                # NOTE(review): the return value depends on the installed accelerate version and may be
                # None; the `is not None` guard below covers that case.
                load_result = load_fsdp_model(
                    self.accelerator.state.fsdp_plugin, self.accelerator, model, self.state.best_model_checkpoint
                )
            else:
                if is_peft_available() and isinstance(model, PeftModel):
                    # If train a model using PEFT & LoRA, assume that adapter have been saved properly.
                    if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"):
                        if os.path.exists(best_adapter_model_path) or os.path.exists(best_safe_adapter_model_path):
                            model.load_adapter(self.state.best_model_checkpoint, model.active_adapter)
                            # Load_adapter has no return value present, modify it when appropriate.
                            from torch.nn.modules.module import _IncompatibleKeys

                            load_result = _IncompatibleKeys([], [])
                        else:
                            logger.warning(
                                "The intermediate checkpoints of PEFT may not be saved correctly, "
                                f"consider using a custom callback to save {ADAPTER_WEIGHTS_NAME} in corresponding saving folders. "
                                "Check some examples here: https://github.com/huggingface/peft/issues/96"
                            )
                            has_been_loaded = False
                    else:
                        logger.warning("Could not load adapter model, make sure to have `peft>=0.3.0` installed")
                        has_been_loaded = False
                else:
                    # We load the model state dict on the CPU to avoid an OOM error.
                    if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
                        state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
                    else:
                        state_dict = torch.load(best_model_path, map_location="cpu")

                    # If the model is on the GPU, it still works!
                    # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
                    # which takes *args instead of **kwargs
                    load_result = model.load_state_dict(state_dict, False)
            if not is_sagemaker_mp_enabled() and has_been_loaded and load_result is not None:
                self._issue_warnings_after_load(load_result)
    elif os.path.exists(os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)):
        load_result = load_sharded_checkpoint(
            model, self.state.best_model_checkpoint, strict=is_sagemaker_mp_enabled()
        )
        if not is_sagemaker_mp_enabled():
            self._issue_warnings_after_load(load_result)
    else:
        logger.warning(
            f"Could not locate the best model at {best_model_path}, if you are running a distributed training "
            "on multiple nodes, you should activate `--save_on_each_node`."
        )
) def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval): if self.control.should_log: if is_torch_tpu_available(): xm.mark_step() logs: Dict[str, float] = {} # all_gather + mean() to get average loss over all processes # tr_loss_scalar = self._nested_gather(tr_loss).mean().item() tr_loss_scalar = {k: self._nested_gather(tr_loss[k]).mean().item() for k in tr_loss.keys()} # reset tr_loss to zero for _,loss in tr_loss.items(): loss -= loss # tr_loss -= tr_loss # logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4) for k,loss in tr_loss_scalar.items(): logs[k]=round(loss / (self.state.global_step - self._globalstep_last_logged), 4) logs["learning_rate"] = self._get_learning_rate() self._total_loss_scalar += tr_loss_scalar['loss_total'] self._globalstep_last_logged = self.state.global_step self.store_flos() self.log(logs) metrics = None if self.control.should_evaluate: if isinstance(self.eval_dataset, dict): metrics = {} for eval_dataset_name, eval_dataset in self.eval_dataset.items(): dataset_metrics = self.evaluate( eval_dataset=eval_dataset, ignore_keys=ignore_keys_for_eval, metric_key_prefix=f"eval_{eval_dataset_name}", ) metrics.update(dataset_metrics) else: metrics = self.evaluate(ignore_keys=ignore_keys_for_eval) self._report_to_hp_search(trial, self.state.global_step, metrics) # Run delayed LR scheduler now that metrics are populated if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): metric_to_check = self.args.metric_for_best_model if not metric_to_check.startswith("eval_"): metric_to_check = f"eval_{metric_to_check}" self.lr_scheduler.step(metrics[metric_to_check]) if self.control.should_save: self._save_checkpoint(model, trial, metrics=metrics) self.control = self.callback_handler.on_save(self.args, self.state, self.control) def _load_rng_state(self, checkpoint): # Load RNG states from `checkpoint` if checkpoint is None: return if self.args.world_size > 1: 
def _save_checkpoint(self, model, trial, metrics=None):
    """
    Write a full training checkpoint for the current global step.

    The checkpoint goes to `<run_dir>/<PREFIX_CHECKPOINT_DIR>-<global_step>` and contains the model, the
    optimizer / LR scheduler (and grad scaler when grad scaling is on), the trainer state and the RNG states.
    When `metrics` is given and `args.metric_for_best_model` is set, the best metric / best checkpoint
    recorded in `self.state` is updated as well. Older checkpoints are rotated at the end.

    Args:
        model: The model being trained (only referenced by the commented-out sanity check below; the model
            that is saved is the one `self.save_model` writes).
        trial: The hyperparameter-search trial, or `None` outside of a search.
        metrics: Evaluation metrics computed right before saving, if any.
    """
    # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
    # want to save except FullyShardedDDP.
    # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"

    # Save model checkpoint
    checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"

    if self.hp_search_backend is None and trial is None:
        self.store_flos()

    run_dir = self._get_output_dir(trial=trial)
    output_dir = os.path.join(run_dir, checkpoint_folder)
    self.save_model(output_dir, _internal_call=True)
    if self.is_deepspeed_enabled:
        # under zero3 model file itself doesn't get saved since it's bogus! Unless deepspeed
        # config `stage3_gather_16bit_weights_on_model_save` is True
        self.model_wrapped.save_checkpoint(output_dir)

    # Save optimizer and scheduler
    if self.sharded_ddp == ShardedDDPOption.SIMPLE:
        self.optimizer.consolidate_state_dict()

    if self.fsdp or self.is_fsdp_enabled:
        if self.is_fsdp_enabled:
            save_fsdp_optimizer(
                self.accelerator.state.fsdp_plugin, self.accelerator, self.optimizer, self.model, output_dir
            )
        else:
            # FSDP has a different interface for saving optimizer states.
            # Needs to be called on all ranks to gather all states.
            # full_optim_state_dict will be deprecated after Pytorch 2.2!
            # `full_osd` is only bound on this path; it is written out further down under the matching
            # `self.fsdp and not self.is_fsdp_enabled` condition.
            full_osd = self.model.__class__.full_optim_state_dict(self.model, self.optimizer)

    if is_torch_tpu_available():
        xm.rendezvous("saving_optimizer_states")
        xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
        with warnings.catch_warnings(record=True) as caught_warnings:
            xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
            reissue_pt_warnings(caught_warnings)
    elif is_sagemaker_mp_enabled():
        opt_state_dict = self.optimizer.local_state_dict(gather_if_shard=False)
        smp.barrier()
        if smp.rdp_rank() == 0 or smp.state.cfg.shard_optimizer_state:
            smp.save(
                opt_state_dict,
                os.path.join(output_dir, OPTIMIZER_NAME),
                partial=True,
                v3=smp.state.cfg.shard_optimizer_state,
            )
        if self.args.should_save:
            with warnings.catch_warnings(record=True) as caught_warnings:
                torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
            reissue_pt_warnings(caught_warnings)
            if self.do_grad_scaling:
                torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
    elif self.args.should_save and not self.is_deepspeed_enabled:
        # deepspeed.save_checkpoint above saves model/optim/sched
        if self.fsdp and not self.is_fsdp_enabled:
            torch.save(full_osd, os.path.join(output_dir, OPTIMIZER_NAME))
        else:
            torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))

        with warnings.catch_warnings(record=True) as caught_warnings:
            torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
        reissue_pt_warnings(caught_warnings)
        if self.do_grad_scaling:
            torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))

    # Determine the new best metric / best model checkpoint
    if metrics is not None and self.args.metric_for_best_model is not None:
        metric_to_check = self.args.metric_for_best_model
        # Metric names are looked up with the "eval_" prefix; add it when the configured name lacks it.
        if not metric_to_check.startswith("eval_"):
            metric_to_check = f"eval_{metric_to_check}"
        metric_value = metrics[metric_to_check]

        operator = np.greater if self.args.greater_is_better else np.less
        if (
            self.state.best_metric is None
            or self.state.best_model_checkpoint is None
            or operator(metric_value, self.state.best_metric)
        ):
            self.state.best_metric = metric_value
            self.state.best_model_checkpoint = output_dir

    # Save the Trainer state
    if self.args.should_save:
        self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))

    # Save the RNG states. Every process reaches this point; the file name chosen below is per-process
    # when world_size > 1.
    rng_states = {
        "python": random.getstate(),
        "numpy": np.random.get_state(),
        "cpu": torch.random.get_rng_state(),
    }
    if torch.cuda.is_available():
        if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
            # In distributed mode, the CUDA RNG states are collected with `get_rng_state_all()`
            rng_states["cuda"] = torch.cuda.random.get_rng_state_all()
        else:
            # Otherwise only the state returned by `get_rng_state()` is stored
            rng_states["cuda"] = torch.cuda.random.get_rng_state()

    if is_torch_tpu_available():
        rng_states["xla"] = xm.get_rng_state()

    # A process can arrive here before the process 0 has a chance to save the model, in which case output_dir may
    # not yet exist.
    os.makedirs(output_dir, exist_ok=True)

    if self.args.world_size <= 1:
        torch.save(rng_states, os.path.join(output_dir, "rng_state.pth"))
    else:
        torch.save(rng_states, os.path.join(output_dir, f"rng_state_{self.args.process_index}.pth"))

    if self.args.push_to_hub:
        self._push_from_checkpoint(output_dir)

    # Maybe delete some older checkpoints.
    if self.args.should_save:
        self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
def hyperparameter_search(
    self,
    hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None,
    compute_objective: Optional[Callable[[Dict[str, float]], float]] = None,
    n_trials: int = 20,
    direction: str = "minimize",
    backend: Optional[Union["str", HPSearchBackend]] = None,
    hp_name: Optional[Callable[["optuna.Trial"], str]] = None,
    **kwargs,
) -> BestRun:
    """
    Launch a hyperparameter search using `optuna` or `Ray Tune` or `SigOpt`. The optimized quantity is determined
    by `compute_objective`, which defaults to a function returning the evaluation loss when no metric is provided,
    the sum of all metrics otherwise.

    <Tip warning={true}>

    To use this method, you need to have provided a `model_init` when initializing your [`Trainer`]: we need to
    reinitialize the model at each new run. This is incompatible with the `optimizers` argument, so you need to
    subclass [`Trainer`] and override the method [`~Trainer.create_optimizer_and_scheduler`] for custom
    optimizer/scheduler.

    </Tip>

    Args:
        hp_space (`Callable[["optuna.Trial"], Dict[str, float]]`, *optional*):
            A function that defines the hyperparameter search space. Will default to
            [`~trainer_utils.default_hp_space_optuna`] or [`~trainer_utils.default_hp_space_ray`] or
            [`~trainer_utils.default_hp_space_sigopt`] depending on your backend.
        compute_objective (`Callable[[Dict[str, float]], float]`, *optional*):
            A function computing the objective to minimize or maximize from the metrics returned by the
            `evaluate` method. Will default to [`~trainer_utils.default_compute_objective`].
        n_trials (`int`, *optional*, defaults to 20):
            The number of trial runs to test.
        direction (`str`, *optional*, defaults to `"minimize"`):
            Whether to optimize greater or lower objects. Can be `"minimize"` or `"maximize"`, you should pick
            `"minimize"` when optimizing the validation loss, `"maximize"` when optimizing one or several metrics.
        backend (`str` or [`~training_utils.HPSearchBackend`], *optional*):
            The backend to use for hyperparameter search. Will default to optuna or Ray Tune or SigOpt, depending
            on which one is installed. If all are installed, will default to optuna.
        hp_name (`Callable[["optuna.Trial"], str]]`, *optional*):
            A function that defines the trial/run name. Will default to None.
        kwargs (`Dict[str, Any]`, *optional*):
            Additional keyword arguments passed along to `optuna.create_study` or `ray.tune.run`. For more
            information see:

            - the documentation of
              [optuna.create_study](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.create_study.html)
            - the documentation of [tune.run](https://docs.ray.io/en/latest/tune/api_docs/execution.html#tune-run)
            - the documentation of [sigopt](https://app.sigopt.com/docs/endpoints/experiments/create)

    Returns:
        [`trainer_utils.BestRun`]: All the information about the best run. Experiment summary can be found in
        `run_summary` attribute for Ray backend.

    Raises:
        RuntimeError: If the trainer was created without a `model_init` function.
    """
    if backend is None:
        backend = default_hp_search_backend()
    backend = HPSearchBackend(backend)
    backend_obj = ALL_HYPERPARAMETER_SEARCH_BACKENDS[backend]()
    backend_obj.ensure_available()
    # Validate before touching trainer state so a rejected call does not leave the trainer in search mode.
    if self.model_init is None:
        raise RuntimeError(
            "To use hyperparameter search, you need to pass your model through a model_init function."
        )

    self.hp_search_backend = backend
    self.hp_space = backend_obj.default_hp_space if hp_space is None else hp_space
    self.hp_name = hp_name
    self.compute_objective = default_compute_objective if compute_objective is None else compute_objective

    try:
        best_run = backend_obj.run(self, n_trials, direction, **kwargs)
    finally:
        # Always leave search mode, even if the search raises: `_get_output_dir` and `_save_checkpoint`
        # branch on `hp_search_backend`.
        self.hp_search_backend = None
    return best_run
""" if isinstance(data, Mapping): return type(data)({k: self._prepare_input(v) for k, v in data.items()}) elif isinstance(data, (tuple, list)): return type(data)(self._prepare_input(v) for v in data) elif isinstance(data, torch.Tensor): kwargs = {"device": self.args.device} if self.is_deepspeed_enabled and (torch.is_floating_point(data) or torch.is_complex(data)): # NLP models inputs are int/uint and those get adjusted to the right dtype of the # embedding. Other models such as wav2vec2's inputs are already float and thus # may need special handling to match the dtypes of the model kwargs.update({"dtype": self.accelerator.state.deepspeed_plugin.hf_ds_config.dtype()}) return data.to(**kwargs) return data def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]: """ Prepare `inputs` before feeding them to the model, converting them to tensors if they are not already and handling potential state. """ inputs = self._prepare_input(inputs) if len(inputs) == 0: raise ValueError( "The batch received was empty, your model won't be able to train on it. Double-check that your " f"training dataset contains keys expected by the model: {','.join(self._signature_columns)}." ) if self.args.past_index >= 0 and self._past is not None: inputs["mems"] = self._past return inputs def compute_loss_context_manager(self): """ A helper wrapper to group together context managers. """ return self.autocast_smart_context_manager() def autocast_smart_context_manager(self, cache_enabled: Optional[bool] = True): """ A helper wrapper that creates an appropriate context manager for `autocast` while feeding it the desired arguments, depending on the situation. 
def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, torch.Tensor]:
    """
    Perform a training step on a batch of inputs.

    Subclass and override to inject custom behavior.

    Args:
        model (`nn.Module`):
            The model to train.
        inputs (`Dict[str, Union[torch.Tensor, Any]]`):
            The inputs and targets of the model.

            The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
            argument `labels`. Check your model's documentation for all accepted arguments.

    Return:
        `Dict[str, torch.Tensor]`: The detached loss terms for this batch, keyed by name. The `"loss_total"`
        entry is the one that was back-propagated. Outside of SageMaker Model Parallel, every entry is divided
        by `gradient_accumulation_steps`.
    """
    model.train()
    inputs = self._prepare_inputs(inputs)

    if is_sagemaker_mp_enabled():
        loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
        # The training loop reads the step result by key (`.keys()`, `"loss_total"`), so this path must
        # return a dict like the regular path does rather than a bare tensor.
        return {"loss_total": loss_mb.reduce_mean().detach().to(self.args.device)}

    with self.compute_loss_context_manager():
        loss = self.compute_loss(model, inputs)

    if self.args.n_gpu > 1:
        # mean() to average on multi-gpu parallel training; build a new dict rather than rewriting the
        # model's output in place.
        loss = {name: term.mean() for name, term in loss.items()}

    total_loss = loss['loss_total']
    if self.do_grad_scaling:
        self.scaler.scale(total_loss).backward()
    elif self.use_apex:
        with amp.scale_loss(total_loss, self.optimizer) as scaled_loss:
            scaled_loss.backward()
    else:
        self.accelerator.backward(total_loss)

    accumulation_steps = self.args.gradient_accumulation_steps
    return {name: term.detach() / accumulation_steps for name, term in loss.items()}
def is_local_process_zero(self) -> bool:
    """
    Tell whether this process is the local main process, i.e. the one whose `self.args.local_process_index`
    is `0` (the main process of its own machine when training is distributed over several machines).
    """
    local_index = self.args.local_process_index
    return local_index == 0
""" if output_dir is None: output_dir = self.args.output_dir if is_torch_tpu_available(): self._save_tpu(output_dir) elif is_sagemaker_mp_enabled(): # Calling the state_dict needs to be done on the wrapped model and on all processes. os.makedirs(output_dir, exist_ok=True) state_dict = self.model_wrapped.state_dict() if self.args.should_save: self._save(output_dir, state_dict=state_dict) if IS_SAGEMAKER_MP_POST_1_10: # 'user_content.pt' indicates model state_dict saved with smp >= 1.10 Path(os.path.join(output_dir, "user_content.pt")).touch() elif ( ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp or self.fsdp is not None or self.is_fsdp_enabled ): state_dict = self.model.state_dict() if self.args.should_save: self._save(output_dir, state_dict=state_dict) if self.is_fsdp_enabled: save_fsdp_model(self.accelerator.state.fsdp_plugin, self.accelerator, self.model, output_dir) elif self.is_deepspeed_enabled: # this takes care of everything as long as we aren't under zero3 if version.parse(accelerate_version) <= version.parse("0.20.3"): raise ValueError("Install Accelerate from main branch") try: state_dict = self.accelerator.get_state_dict(self.deepspeed) if self.args.should_save: self._save(output_dir, state_dict=state_dict) except ValueError: logger.warning( " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use" " zero_to_fp32.py to recover weights" ) self.model_wrapped.save_checkpoint(output_dir) elif self.args.should_save: self._save(output_dir) # Push to the Hub when `save_model` is called by the user. 
def _save_tpu(self, output_dir: Optional[str] = None):
    """
    Save the model (and, when `self.args.should_save` is set, the tokenizer) to `output_dir` through the `xm`
    save helpers.

    Args:
        output_dir (`str`, *optional*):
            Target directory. `self.args.output_dir` is used when it is `None`.
    """
    if output_dir is None:
        output_dir = self.args.output_dir
    logger.info(f"Saving model checkpoint to {output_dir}")

    # Only the master ordinal creates the directory and stores the training arguments.
    if xm.is_master_ordinal():
        os.makedirs(output_dir, exist_ok=True)
        args_path = os.path.join(output_dir, TRAINING_ARGS_NAME)
        torch.save(self.args, args_path)

    # Save a trained model and configuration using `save_pretrained()`.
    # They can then be reloaded using `from_pretrained()`
    xm.rendezvous("saving_checkpoint")
    model = self.model
    if isinstance(model, PreTrainedModel):
        model.save_pretrained(output_dir, is_main_process=self.args.should_save, save_function=xm.save)
    else:
        unwrapped = unwrap_model(model)
        if isinstance(unwrapped, PreTrainedModel):
            # The weights come from the outer (wrapped) model, the saving logic from the unwrapped one.
            unwrapped.save_pretrained(
                output_dir,
                is_main_process=self.args.should_save,
                state_dict=model.state_dict(),
                save_function=xm.save,
            )
        else:
            logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
            weights = model.state_dict()
            xm.save(weights, os.path.join(output_dir, WEIGHTS_NAME))

    tokenizer = self.tokenizer
    if tokenizer is not None and self.args.should_save:
        tokenizer.save_pretrained(output_dir)
def _sorted_checkpoints(
    self, output_dir=None, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False
) -> List[str]:
    """
    List the checkpoint directories found in `output_dir`, oldest first.

    Args:
        output_dir (`str`, *optional*):
            Directory to scan. Falls back to `self.args.output_dir` when `None`.
        checkpoint_prefix (`str`, *optional*):
            Only sub-directories named `{checkpoint_prefix}-*` are considered.
        use_mtime (`bool`, *optional*, defaults to `False`):
            Sort by modification time instead of by the integer that follows the prefix in the directory name.

    Returns:
        `List[str]`: the checkpoint paths in ascending order. When `self.state.best_model_checkpoint` is one of
        them and sits earlier than the second-to-last slot, it is moved to the second-to-last slot so that
        deleting from the front of the list does not remove it.
    """
    # `Path(None)` raises a `TypeError`; use the configured output directory like the save helpers do.
    if output_dir is None:
        output_dir = self.args.output_dir

    ordering_and_checkpoint_path = []

    glob_checkpoints = [str(x) for x in Path(output_dir).glob(f"{checkpoint_prefix}-*") if os.path.isdir(x)]

    for path in glob_checkpoints:
        if use_mtime:
            ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
        else:
            regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
            if regex_match is not None and regex_match.groups() is not None:
                ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))

    checkpoints_sorted = [path for _, path in sorted(ordering_and_checkpoint_path)]

    # Make sure we don't delete the best model.
    best_model_checkpoint = self.state.best_model_checkpoint
    if best_model_checkpoint is not None:
        best_model_path = str(Path(best_model_checkpoint))
        # The best checkpoint may be missing from the listing (e.g. it lives elsewhere or was removed);
        # `list.index` would raise a `ValueError` in that case, so only reorder when it is present.
        if best_model_path in checkpoints_sorted:
            best_model_index = checkpoints_sorted.index(best_model_path)
            for i in range(best_model_index, len(checkpoints_sorted) - 2):
                checkpoints_sorted[i], checkpoints_sorted[i + 1] = checkpoints_sorted[i + 1], checkpoints_sorted[i]
    return checkpoints_sorted
Args: eval_dataset (`Dataset`, *optional*): Pass a dataset if you wish to override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted by the `model.forward()` method are automatically removed. It must implement the `__len__` method. ignore_keys (`List[str]`, *optional*): A list of keys in the output of your model (if it is a dictionary) that should be ignored when gathering predictions. metric_key_prefix (`str`, *optional*, defaults to `"eval"`): An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named "eval_bleu" if the prefix is "eval" (default) Returns: A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The dictionary also contains the epoch number which comes from the training state. """ # memory metrics - must set up as early as possible self._memory_tracker.start() eval_dataloader = self.get_eval_dataloader(eval_dataset) start_time = time.time() eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop output = eval_loop( eval_dataloader, description="Evaluation", # No point gathering the predictions if there are no metrics, otherwise we defer to # self.args.prediction_loss_only prediction_loss_only=True if self.compute_metrics is None else None, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix, ) total_batch_size = self.args.eval_batch_size * self.args.world_size if f"{metric_key_prefix}_jit_compilation_time" in output.metrics: start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"] output.metrics.update( speed_metrics( metric_key_prefix, start_time, num_samples=output.num_samples, num_steps=math.ceil(output.num_samples / total_batch_size), ) ) self.log(output.metrics) if DebugOption.TPU_METRICS_DEBUG in self.args.debug: # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) 
def predict(
    self, test_dataset: Dataset, ignore_keys: Optional[List[str]] = None, metric_key_prefix: str = "test"
) -> PredictionOutput:
    """
    Run the prediction loop over `test_dataset` and return its predictions, label ids and metrics.

    Args:
        test_dataset (`Dataset`):
            Dataset to run the predictions on. It is turned into a dataloader by `self.get_test_dataloader`.
        ignore_keys (`List[str]`, *optional*):
            Forwarded unchanged to the prediction loop.
        metric_key_prefix (`str`, *optional*, defaults to `"test"`):
            Prefix handed to the loop and to `speed_metrics`; it is also used to look up the
            `"{metric_key_prefix}_jit_compilation_time"` entry in the loop metrics.

    Returns:
        `PredictionOutput` built from the loop output:

        - predictions: `output.predictions`
        - label_ids: `output.label_ids`
        - metrics: `output.metrics`, extended with the values returned by `speed_metrics`
    """
    # memory metrics - must set up as early as possible
    self._memory_tracker.start()

    loader = self.get_test_dataloader(test_dataset)
    started_at = time.time()

    # Pick the legacy loop only when it was explicitly requested.
    if self.args.use_legacy_prediction_loop:
        run_loop = self.prediction_loop
    else:
        run_loop = self.evaluation_loop
    output = run_loop(
        loader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
    )

    samples_per_step = self.args.eval_batch_size * self.args.world_size
    jit_key = f"{metric_key_prefix}_jit_compilation_time"
    if jit_key in output.metrics:
        # Shift the start time by the reported compilation time (presumably so that `speed_metrics` does not
        # count it).
        started_at += output.metrics[jit_key]
    step_count = math.ceil(output.num_samples / samples_per_step)
    output.metrics.update(
        speed_metrics(
            metric_key_prefix,
            started_at,
            num_samples=output.num_samples,
            num_steps=step_count,
        )
    )

    self.control = self.callback_handler.on_predict(self.args, self.state, self.control, output.metrics)
    self._memory_tracker.stop_and_update_metrics(output.metrics)

    return PredictionOutput(predictions=output.predictions, label_ids=output.label_ids, metrics=output.metrics)
""" args = self.args prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only # if eval is called w/o train, handle model prep here if self.is_deepspeed_enabled and self.deepspeed is None: _, _ = deepspeed_init(self, num_training_steps=0, inference=True) model = self._wrap_model(self.model, training=False, dataloader=dataloader) if len(self.accelerator._models) == 0 and model is self.model: model = ( self.accelerator.prepare(model) if self.is_deepspeed_enabled else self.accelerator.prepare_model(model, evaluation_mode=True) ) if self.is_fsdp_enabled: self.model = model # for the rest of this function `model` is the outside model, whether it was wrapped or not if model is not self.model: self.model_wrapped = model # backward compatibility if self.is_deepspeed_enabled: self.deepspeed = self.model_wrapped # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called # while ``train`` is running, cast it to the right dtype first and then put on device if not self.is_in_train: if args.fp16_full_eval: model = model.to(dtype=torch.float16, device=args.device) elif args.bf16_full_eval: model = model.to(dtype=torch.bfloat16, device=args.device) batch_size = self.args.eval_batch_size logger.info(f"***** Running {description} *****") if has_length(dataloader): logger.info(f" Num examples = {self.num_examples(dataloader)}") else: logger.info(" Num examples: Unknown") logger.info(f" Batch size = {batch_size}") model.eval() self.callback_handler.eval_dataloader = dataloader # Do this before wrapping. 
eval_dataset = getattr(dataloader, "dataset", None) if args.past_index >= 0: self._past = None # Initialize containers # losses/preds/labels on GPU/TPU (accumulated for eval_accumulation_steps) losses_host = None preds_host = None labels_host = None inputs_host = None # losses/preds/labels on CPU (final containers) all_losses = None all_preds = None all_labels = None all_inputs = None # Will be useful when we have an iterable dataset so don't know its length. observed_num_examples = 0 # Main evaluation loop for step, inputs in enumerate(dataloader): # Update the observed num examples observed_batch_size = find_batch_size(inputs) if observed_batch_size is not None: observed_num_examples += observed_batch_size # For batch samplers, batch_size is not known by the dataloader in advance. if batch_size is None: batch_size = observed_batch_size # Prediction step loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys) inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None if is_torch_tpu_available(): xm.mark_step() # Update containers on host if loss is not None: losses = self.accelerator.gather_for_metrics((loss.repeat(batch_size))) losses_host = losses if losses_host is None else nested_concat(losses_host, losses, padding_index=-100) if labels is not None: labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100) if inputs_decode is not None: inputs_decode = self.accelerator.pad_across_processes(inputs_decode, dim=1, pad_index=-100) inputs_decode = self.accelerator.gather_for_metrics((inputs_decode)) inputs_host = ( inputs_decode if inputs_host is None else nested_concat(inputs_host, inputs_decode, padding_index=-100) ) if logits is not None: logits = self.accelerator.pad_across_processes(logits, dim=1, pad_index=-100) if self.preprocess_logits_for_metrics is not None: logits = self.preprocess_logits_for_metrics(logits, labels) logits = 
self.accelerator.gather_for_metrics((logits)) preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100) if labels is not None: labels = self.accelerator.gather_for_metrics((labels)) labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100) self.control = self.callback_handler.on_prediction_step(args, self.state, self.control) # Gather all tensors and put them back on the CPU if we have done enough accumulation steps. if args.eval_accumulation_steps is not None and self.accelerator.sync_gradients: if losses_host is not None: losses = nested_numpify(losses_host) all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0) if preds_host is not None: logits = nested_numpify(preds_host) all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100) if inputs_host is not None: inputs_decode = nested_numpify(inputs_host) all_inputs = ( inputs_decode if all_inputs is None else nested_concat(all_inputs, inputs_decode, padding_index=-100) ) if labels_host is not None: labels = nested_numpify(labels_host) all_labels = ( labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100) ) # Set back to None to begin a new accumulation losses_host, preds_host, inputs_host, labels_host = None, None, None, None if args.past_index and hasattr(self, "_past"): # Clean the state at the end of the evaluation loop delattr(self, "_past") # Gather all remaining tensors and put them back on the CPU if losses_host is not None: losses = nested_numpify(losses_host) all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0) if preds_host is not None: logits = nested_numpify(preds_host) all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100) if inputs_host is not None: inputs_decode = nested_numpify(inputs_host) all_inputs = ( inputs_decode if 
def _nested_gather(self, tensors, name=None):
    """
    Gather `tensors` (a tensor or a nested list/tuple of tensors) with the helper that matches the current
    backend and return the result.

    `None` is returned when `tensors` is `None`. When none of the TPU, SageMaker model-parallel or distributed
    conditions hold, `tensors` is returned unchanged.
    """
    if tensors is None:
        return

    if is_torch_tpu_available():
        gather_name = "nested_gather" if name is None else name
        return nested_xla_mesh_reduce(tensors, gather_name)

    if is_sagemaker_mp_enabled():
        return smp_gather(tensors)

    # With a distributed state, trust its type; without one, fall back to the local rank.
    if self.args.distributed_state is not None:
        needs_concat = self.args.distributed_state.distributed_type != "NO"
    else:
        needs_concat = self.args.local_rank != -1

    if needs_concat:
        tensors = distributed_concat(tensors)
    return tensors
return_loss = inputs.get("return_loss", None) if return_loss is None: return_loss = self.can_return_loss loss_without_labels = True if len(self.label_names) == 0 and return_loss else False inputs = self._prepare_inputs(inputs) if ignore_keys is None: if hasattr(self.model, "config"): ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", []) else: ignore_keys = [] # labels may be popped when computing the loss (label smoothing for instance) so we grab them first. if has_labels or loss_without_labels: labels = nested_detach(tuple(inputs.get(name) for name in self.label_names)) if len(labels) == 1: labels = labels[0] else: labels = None with torch.no_grad(): if is_sagemaker_mp_enabled(): raw_outputs = smp_forward_only(model, inputs) if has_labels or loss_without_labels: if isinstance(raw_outputs, dict): loss_mb = raw_outputs["loss"] logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys + ["loss"]) else: loss_mb = raw_outputs[0] logits_mb = raw_outputs[1:] loss = loss_mb.reduce_mean().detach().cpu() logits = smp_nested_concat(logits_mb) else: loss = None if isinstance(raw_outputs, dict): logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys) else: logits_mb = raw_outputs logits = smp_nested_concat(logits_mb) else: if has_labels or loss_without_labels: with self.compute_loss_context_manager(): loss, outputs = self.compute_loss(model, inputs, return_outputs=True) loss = loss.mean().detach() if isinstance(outputs, dict): logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"]) else: logits = outputs[1:] else: loss = None with self.compute_loss_context_manager(): outputs = model(**inputs) if isinstance(outputs, dict): logits = tuple(v for k, v in outputs.items() if k not in ignore_keys) else: logits = outputs # TODO: this needs to be fixed and made cleaner later. 
def floating_point_ops(self, inputs: Dict[str, Union[torch.Tensor, Any]]):
    """
    Return the number of floating-point operations for `inputs` as reported by the model.

    The call is delegated to `self.model.floating_point_ops(inputs)` when the model exposes such an attribute.
    Otherwise `0` is returned; implement that method on the model or override this one to get a real count.

    Args:
        inputs (`Dict[str, Union[torch.Tensor, Any]]`):
            The inputs and targets of the model.

    Returns:
        `int`: The number of floating-point operations.
    """
    model = self.model
    if not hasattr(model, "floating_point_ops"):
        return 0
    return model.floating_point_ops(inputs)
def create_model_card(
    self,
    language: Optional[str] = None,
    license: Optional[str] = None,
    tags: Union[str, List[str], None] = None,
    model_name: Optional[str] = None,
    finetuned_from: Optional[str] = None,
    tasks: Union[str, List[str], None] = None,
    dataset_tags: Union[str, List[str], None] = None,
    dataset: Union[str, List[str], None] = None,
    dataset_args: Union[str, List[str], None] = None,
):
    """
    Creates a draft of a model card using the information available to the `Trainer` and writes it to
    `README.md` inside `self.args.output_dir`.

    Nothing is done on processes for which `self.is_world_process_zero()` is false. Every argument is forwarded
    unchanged to `TrainingSummary.from_trainer`.

    Args:
        language (`str`, *optional*):
            The language of the model (if applicable)
        license (`str`, *optional*):
            The license of the model.
        tags (`str` or `List[str]`, *optional*):
            Some tags to be included in the metadata of the model card.
        model_name (`str`, *optional*):
            The name of the model.
        finetuned_from (`str`, *optional*):
            The name of the model used to fine-tune this one (if applicable).
        tasks (`str` or `List[str]`, *optional*):
            One or several task identifiers, to be included in the metadata of the model card.
        dataset_tags (`str` or `List[str]`, *optional*):
            One or several dataset tags, to be included in the metadata of the model card.
        dataset (`str` or `List[str]`, *optional*):
            One or several dataset identifiers, to be included in the metadata of the model card.
        dataset_args (`str` or `List[str]`, *optional*):
            One or several dataset arguments, to be included in the metadata of the model card.
    """
    if not self.is_world_process_zero():
        return

    training_summary = TrainingSummary.from_trainer(
        self,
        language=language,
        license=license,
        tags=tags,
        model_name=model_name,
        finetuned_from=finetuned_from,
        tasks=tasks,
        dataset_tags=dataset_tags,
        dataset=dataset,
        dataset_args=dataset_args,
    )
    model_card = training_summary.to_model_card()
    # Write with an explicit encoding: the locale default may be unable to encode non-ASCII characters in the
    # card, and the `.gitignore` written next to it is UTF-8 as well.
    with open(os.path.join(self.args.output_dir, "README.md"), "w", encoding="utf-8") as f:
        f.write(model_card)
if self.push_in_progress is not None and not self.push_in_progress.is_done: return output_dir = self.args.output_dir # To avoid a new synchronization of all model weights, we just copy the file from the checkpoint folder modeling_files = [CONFIG_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_NAME] if is_peft_available(): modeling_files.extend([ADAPTER_CONFIG_NAME, ADAPTER_WEIGHTS_NAME, ADAPTER_SAFE_WEIGHTS_NAME]) for modeling_file in modeling_files: if os.path.isfile(os.path.join(checkpoint_folder, modeling_file)): shutil.copy(os.path.join(checkpoint_folder, modeling_file), os.path.join(output_dir, modeling_file)) # Saving the tokenizer is fast and we don't know how many files it may have spawned, so we resave it to be sure. if self.tokenizer is not None: self.tokenizer.save_pretrained(output_dir) # Same for the training arguments torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) try: if self.args.hub_strategy == HubStrategy.CHECKPOINT: # Temporarily move the checkpoint just saved for the push tmp_checkpoint = os.path.join(output_dir, "last-checkpoint") # We have to remove the "last-checkpoint" dir if it exists, otherwise the checkpoint is moved as a # subfolder. if os.path.isdir(tmp_checkpoint): shutil.rmtree(tmp_checkpoint) shutil.move(checkpoint_folder, tmp_checkpoint) if self.args.save_strategy == IntervalStrategy.STEPS: commit_message = f"Training in progress, step {self.state.global_step}" else: commit_message = f"Training in progress, epoch {int(self.state.epoch)}" push_work = self.repo.push_to_hub(commit_message=commit_message, blocking=False, auto_lfs_prune=True) # Return type of `Repository.push_to_hub` is either None or a tuple. 
def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str:
    """
    Upload *self.model* and *self.tokenizer* to the model hub on the repo *self.args.hub_model_id*.

    Parameters:
        commit_message (`str`, *optional*, defaults to `"End of training"`):
            Message to commit while pushing.
        blocking (`bool`, *optional*, defaults to `True`):
            Whether the function should return only when the `git push` has finished.
        kwargs (`Dict[str, Any]`, *optional*):
            Additional keyword arguments passed along to [`~Trainer.create_model_card`]. A `model_name` entry is
            used as the model name; otherwise the name is derived from `self.args.hub_model_id` or, when that is
            `None`, from `self.args.output_dir`.

    Returns:
        Whatever `self.repo.push_to_hub` returns for the model commit: presumably the url of the commit when
        `blocking=True`, and a tuple with the url of the commit and an object to track the progress of the
        commit when `blocking=False` (confirm against the `Repository.push_to_hub` documentation). `None` is
        returned on processes for which `self.is_world_process_zero()` is false.
    """
    # If a user calls manually `push_to_hub` with `self.args.push_to_hub = False`, we try to create the repo but
    # it might fail.
    if not hasattr(self, "repo"):
        self.init_git_repo()

    model_name = kwargs.pop("model_name", None)
    if model_name is None and self.args.should_save:
        if self.args.hub_model_id is None:
            model_name = Path(self.args.output_dir).name
        else:
            model_name = self.args.hub_model_id.split("/")[-1]

    # Needs to be executed on all processes for TPU training, but will only save on the processed determined by
    # self.args.should_save.
    self.save_model(_internal_call=True)

    # Only push from one node.
    if not self.is_world_process_zero():
        return

    # Cancel any async push in progress if blocking=True. The commits will all be pushed together.
    if blocking and self.push_in_progress is not None and not self.push_in_progress.is_done:
        self.push_in_progress._process.kill()
        self.push_in_progress = None

    git_head_commit_url = self.repo.push_to_hub(
        commit_message=commit_message, blocking=blocking, auto_lfs_prune=True
    )
    # push separately the model card to be independent from the rest of the model
    if self.args.should_save:
        self.create_model_card(model_name=model_name, **kwargs)
        try:
            self.repo.push_to_hub(
                commit_message="update model card README.md", blocking=blocking, auto_lfs_prune=True
            )
        except EnvironmentError as exc:
            # Best effort: a failed model-card push must not hide the url of the model commit.
            logger.error(f"Error pushing update to the model card. Please read logs and retry.\n{exc}")

    return git_head_commit_url
""" args = self.args if not has_length(dataloader): raise ValueError("dataloader must implement a working __len__") prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only # if eval is called w/o train, handle model prep here if self.is_deepspeed_enabled and self.deepspeed is None: _, _ = deepspeed_init(self, num_training_steps=0, inference=True) model = self._wrap_model(self.model, training=False, dataloader=dataloader) if len(self.accelerator._models) == 0 and model is self.model: model = ( self.accelerator.prepare(model) if self.is_deepspeed_enabled else self.accelerator.prepare_model(model, evaluation_mode=True) ) if self.is_fsdp_enabled: self.model = model # for the rest of this function `model` is the outside model, whether it was wrapped or not if model is not self.model: self.model_wrapped = model # backward compatibility if self.is_deepspeed_enabled: self.deepspeed = self.model_wrapped # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called # while ``train`` is running, cast it to the right dtype first and then put on device if not self.is_in_train: if args.fp16_full_eval: model = model.to(dtype=torch.float16, device=args.device) elif args.bf16_full_eval: model = model.to(dtype=torch.bfloat16, device=args.device) batch_size = dataloader.batch_size num_examples = self.num_examples(dataloader) logger.info(f"***** Running {description} *****") logger.info(f" Num examples = {num_examples}") logger.info(f" Batch size = {batch_size}") losses_host: torch.Tensor = None preds_host: Union[torch.Tensor, List[torch.Tensor]] = None labels_host: Union[torch.Tensor, List[torch.Tensor]] = None inputs_host: Union[torch.Tensor, List[torch.Tensor]] = None world_size = max(1, args.world_size) eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size) if not prediction_loss_only: # The actual number of eval_sample can be greater than 
num_examples in distributed settings (when we pass # a batch size to the sampler) make_multiple_of = None if hasattr(dataloader, "sampler") and isinstance(dataloader.sampler, SequentialDistributedSampler): make_multiple_of = dataloader.sampler.batch_size preds_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of) labels_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of) inputs_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of) model.eval() if args.past_index >= 0: self._past = None self.callback_handler.eval_dataloader = dataloader for step, inputs in enumerate(dataloader): loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys) inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None if loss is not None: losses = loss.repeat(batch_size) losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0) if logits is not None: preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100) if labels is not None: labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100) if inputs_decode is not None: inputs_host = ( inputs_decode if inputs_host is None else nested_concat(inputs_host, inputs_decode, padding_index=-100) ) self.control = self.callback_handler.on_prediction_step(args, self.state, self.control) # Gather all tensors and put them back on the CPU if we have done enough accumulation steps. 
if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0: eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses")) if not prediction_loss_only: preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds")) labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids")) inputs_gatherer.add_arrays(self._gather_and_numpify(inputs_host, "eval_inputs_ids")) # Set back to None to begin a new accumulation losses_host, preds_host, labels_host, inputs_host = None, None, None, None if args.past_index and hasattr(self, "_past"): # Clean the state at the end of the evaluation loop delattr(self, "_past") # Gather all remaining tensors and put them back on the CPU eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses")) if not prediction_loss_only: preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds")) labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids")) inputs_gatherer.add_arrays(self._gather_and_numpify(inputs_host, "eval_inputs_ids")) eval_loss = eval_losses_gatherer.finalize() preds = preds_gatherer.finalize() if not prediction_loss_only else None label_ids = labels_gatherer.finalize() if not prediction_loss_only else None inputs_ids = inputs_gatherer.finalize() if not prediction_loss_only else None if self.compute_metrics is not None and preds is not None and label_ids is not None: if args.include_inputs_for_metrics: metrics = self.compute_metrics( EvalPrediction(predictions=preds, label_ids=label_ids, inputs=inputs_ids) ) else: metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids)) else: metrics = {} # To be JSON-serializable, we need to remove numpy types or zero-d tensors metrics = denumpify_detensorize(metrics) if eval_loss is not None: metrics[f"{metric_key_prefix}_loss"] = eval_loss.mean().item() # Prefix all keys with metric_key_prefix + '_' for key in 
list(metrics.keys()): if not key.startswith(f"{metric_key_prefix}_"): metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key) return EvalLoopOutput(predictions=preds, label_ids=label_ids, metrics=metrics, num_samples=num_examples) def _gather_and_numpify(self, tensors, name): """ Gather value of `tensors` (tensor or list/tuple of nested tensors) and convert them to numpy before concatenating them to `gathered` """ if tensors is None: return if is_torch_tpu_available(): tensors = nested_xla_mesh_reduce(tensors, name) elif is_sagemaker_mp_enabled(): tensors = smp_gather(tensors) elif self.args.parallel_mode == ParallelMode.DISTRIBUTED: tensors = distributed_concat(tensors) return nested_numpify(tensors) def _add_sm_patterns_to_gitignore(self) -> None: """Add SageMaker Checkpointing patterns to .gitignore file.""" # Make sure we only do this on the main process if not self.is_world_process_zero(): return patterns = ["*.sagemaker-uploading", "*.sagemaker-uploaded"] # Get current .gitignore content if os.path.exists(os.path.join(self.repo.local_dir, ".gitignore")): with open(os.path.join(self.repo.local_dir, ".gitignore"), "r") as f: current_content = f.read() else: current_content = "" # Add the patterns to .gitignore content = current_content for pattern in patterns: if pattern not in content: if content.endswith("\n"): content += pattern else: content += f"\n{pattern}" # Write the .gitignore file if it has changed if content != current_content: with open(os.path.join(self.repo.local_dir, ".gitignore"), "w") as f: logger.debug(f"Writing .gitignore file. 
def create_accelerator_and_postprocess(self):
    """
    Build `self.accelerator` and derive the DeepSpeed / FSDP flags from its state.

    Sets `self.accelerator`, `self.is_deepspeed_enabled` and `self.is_fsdp_enabled`, then copies the relevant
    trainer arguments onto whichever plugin is active.
    """
    accumulation_options = {"num_steps": self.args.gradient_accumulation_steps}
    # `sync_with_dataloader` is only passed for accelerate versions newer than 0.20.3.
    if version.parse(accelerate_version) > version.parse("0.20.3"):
        accumulation_options["sync_with_dataloader"] = False

    # create accelerator object
    self.accelerator = Accelerator(
        deepspeed_plugin=self.args.deepspeed_plugin,
        gradient_accumulation_plugin=GradientAccumulationPlugin(**accumulation_options),
    )

    # deepspeed and accelerate flags covering both trainer args and accelerate launcher
    accelerator_state = self.accelerator.state
    self.is_deepspeed_enabled = getattr(accelerator_state, "deepspeed_plugin", None) is not None
    self.is_fsdp_enabled = getattr(accelerator_state, "fsdp_plugin", None) is not None

    # post accelerator creation setup
    if self.is_fsdp_enabled:
        fsdp_plugin = self.accelerator.state.fsdp_plugin
        # Values from `args.fsdp_config` win; otherwise the plugin keeps its current setting.
        for option in ("limit_all_gathers", "use_orig_params"):
            setattr(fsdp_plugin, option, self.args.fsdp_config.get(option, getattr(fsdp_plugin, option)))

    if self.is_deepspeed_enabled and getattr(self.args, "hf_deepspeed_config", None) is None:
        from transformers.deepspeed import HfTrainerDeepSpeedConfig

        ds_plugin = self.accelerator.state.deepspeed_plugin
        # Re-wrap the plugin's config as a trainer-aware config and let it process the trainer arguments.
        ds_plugin.hf_ds_config = HfTrainerDeepSpeedConfig(ds_plugin.hf_ds_config.config)
        ds_plugin.deepspeed_config = ds_plugin.hf_ds_config.config
        ds_plugin.hf_ds_config.trainer_config_process(self.args)
run_dir = self._get_output_dir(trial=trial) # output_dir = os.path.join(run_dir, checkpoint_folder) # # # Only save Adapter # keys_to_match = ['mm_projector'] # if getattr(self.args, "use_im_start_end", False) or getattr(self.args, "new_tokens", False): # keys_to_match.extend(['embed_tokens', 'embed_in','lm_head']) # # import pdb; pdb.set_trace() # weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match) # # if self.args.local_rank == 0 or self.args.local_rank == -1: # self.model.config.save_pretrained(output_dir) # torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin')) # else: super(LLaVATrainer, self)._save_checkpoint(model, trial, metrics) def _save(self, output_dir: Optional[str] = None, state_dict=None): # if getattr(self.args, 'tune_mm_mlp_adapter', False): # pass # else: super(LLaVATrainer, self)._save(output_dir, state_dict) ================================================ FILE: llava/train/train.py ================================================ # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright: # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright: # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
@dataclass
class DataArguments:
    """Command-line arguments describing the training data and how images are prepared."""

    # Path to the JSON training file; it defaults to None, hence Optional.
    data_path: Optional[str] = field(default=None,
                                     metadata={"help": "Path to the training data."})
    # NOTE(review): not read anywhere in the visible part of this file — confirm it is still used.
    lazy_preprocess: bool = False
    # Whether samples without an image still get a placeholder image tensor in the dataset.
    is_multimodal: bool = False
    # Directory that each sample's relative `image` file name is joined onto.
    image_folder: Optional[str] = field(default=None)
    # 'pad' makes the dataset pad images to a square before preprocessing; other values skip the padding.
    image_aspect_ratio: str = 'square'
    # Copied verbatim onto the model config by `train()`.
    image_grid_pinpoints: Optional[str] = field(default=None)
# Borrowed from peft.utils.get_peft_model_state_dict
def get_peft_state_maybe_zero_3(named_params, bias):
    """Select the LoRA-related parameters and pass each one through `maybe_zero_3`.

    Args:
        named_params: Iterable of `(name, parameter)` pairs. It is consumed once.
        bias: Which bias tensors to keep alongside the `lora_` parameters:
            `"none"` keeps no bias, `"all"` keeps every parameter whose name contains
            `"bias"`, `"lora_only"` keeps only the biases that belong to a module which
            also has LoRA parameters.

    Returns:
        Dict mapping parameter name to the value returned by `maybe_zero_3`.

    Raises:
        NotImplementedError: If `bias` is not one of the three supported modes.
    """
    if bias == "none":
        to_return = {k: t for k, t in named_params if "lora_" in k}
    elif bias == "all":
        to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
    elif bias == "lora_only":
        to_return = {}
        maybe_lora_bias = {}
        lora_bias_names = set()
        for k, t in named_params:
            if "lora_" in k:
                to_return[k] = t
                # The bias of the adapted module lives next to its `lora_*` parameters.
                lora_bias_names.add(k.split("lora_")[0] + "bias")
            elif "bias" in k:
                maybe_lora_bias[k] = t
        # Iterate over (name, tensor) pairs and test each candidate's own name. The previous
        # loop iterated the dict's keys and checked a stale `bias_name` left over from the
        # loop above.
        for k, t in maybe_lora_bias.items():
            if k in lora_bias_names:
                to_return[k] = t
    else:
        raise NotImplementedError
    return {k: maybe_zero_3(v, name=k) for k, v in to_return.items()}
def find_all_linear_names(model):
    """Return the distinct last name components of every `torch.nn.Linear` submodule of `model`.

    `lm_head` is dropped from the result. The order of the returned list is unspecified
    because the names are collected in a set.
    """
    leaf_names = {
        qualified_name.split('.')[-1]
        for qualified_name, submodule in model.named_modules()
        if isinstance(submodule, torch.nn.Linear)
    }
    # needed for 16-bit
    leaf_names.discard('lm_head')
    return list(leaf_names)
def _mask_targets(target, tokenized_lens, speakers):
    """Overwrite, in place, the header and the human turns of `target` with IGNORE_INDEX.

    `tokenized_lens[0]` is the length of the header; the remaining entries are the lengths
    of the turns, paired in order with `speakers`. Nothing is returned.
    """
    header_len = tokenized_lens[0]
    turn_lens = tokenized_lens[1:]

    target[:header_len] = IGNORE_INDEX

    turn_start = header_len
    for turn_len, speaker in zip(turn_lens, speakers):
        turn_end = turn_start + turn_len
        if speaker == "human":
            # The first two positions of a human turn stay unmasked.
            # NOTE(review): presumably the tokens of the "### " signal — confirm.
            target[turn_start + 2:turn_end] = IGNORE_INDEX
        turn_start = turn_end
def preprocess_multimodal(
    sources: Sequence[str],
    data_args: DataArguments
) -> Dict:
    """Normalise the image placeholder in every sentence of every conversation, in place.

    When `data_args.is_multimodal` is false, `sources` is returned untouched. Otherwise, for
    each sentence containing DEFAULT_IMAGE_TOKEN, the token is moved to the front of the
    text, followed by a newline. For conversation templates whose version contains "mmtag"
    the token is additionally wrapped in `<Image>`/`</Image>` tags. Finally, if
    `data_args.mm_use_im_start_end` is set, every image token is surrounded by the image
    start/end tokens.

    Returns:
        The same `sources` object that was passed in (its sentences are mutated).
    """
    is_multimodal = data_args.is_multimodal
    if not is_multimodal:
        return sources

    for source in sources:
        for sentence in source:
            if DEFAULT_IMAGE_TOKEN in sentence['value']:
                # Strip the token from wherever it was written and put a single one first.
                sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip()
                sentence['value'] = DEFAULT_IMAGE_TOKEN + '\n' + sentence['value']
                sentence['value'] = sentence['value'].strip()
                if "mmtag" in conversation_lib.default_conversation.version:
                    # Wrap the placeholder in <Image> tags as upstream LLaVA does; wrapping it in
                    # empty strings, as before, left the text unchanged.
                    sentence['value'] = sentence['value'].replace(
                        DEFAULT_IMAGE_TOKEN, '<Image>' + DEFAULT_IMAGE_TOKEN + '</Image>')
            replace_token = DEFAULT_IMAGE_TOKEN
            if data_args.mm_use_im_start_end:
                replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
            sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token)

    return sources
def preprocess_v1(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False
) -> Dict:
    """Render conversations with the default (two-separator) template, tokenize them, and
    build labels in which the instruction part of every round is set to IGNORE_INDEX.

    Args:
        sources: List of conversations; each is a list of dicts with "from" ("human"/"gpt")
            and "value" keys.
        tokenizer: Tokenizer used for the prompts.
        has_image: When true, prompts are tokenized with `tokenizer_image_token` and stacked;
            otherwise the tokenizer is called directly with padding and truncation.

    Returns:
        `dict(input_ids=..., labels=...)`, where `labels` is a masked clone of `input_ids`.
    """
    conv = conversation_lib.default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    # Apply prompt templates
    conversations = []
    for source_idx, source in enumerate(sources):
        # Skip the first one if it is not from human
        if roles[source[0]["from"]] != conv.roles[0]:
            source = source[1:]

        conv.messages = []
        for turn_idx, sentence in enumerate(source):
            role = roles[sentence["from"]]
            # Turns must strictly alternate, starting with the human role.
            assert role == conv.roles[turn_idx % 2], f"{source_idx}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    # Tokenize conversations
    if has_image:
        prompt_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]
        input_ids = torch.stack(prompt_ids, dim=0)
    else:
        input_ids = tokenizer(
            conversations,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ).input_ids

    targets = input_ids.clone()

    assert conv.sep_style == conversation_lib.SeparatorStyle.TWO

    def count_tokens(text):
        # Token count of `text`, using the same tokenization path as the prompts above.
        if has_image:
            return len(tokenizer_image_token(text, tokenizer))
        return len(tokenizer(text).input_ids)

    # Mask targets
    sep = conv.sep + conv.roles[1] + ": "
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        # The first position is always ignored.
        # NOTE(review): presumably the BOS token — confirm.
        cur_len = 1
        target[:cur_len] = IGNORE_INDEX
        for rou in conversation.split(conv.sep2):
            if rou == "":
                break

            parts = rou.split(sep)
            if len(parts) != 2:
                break
            instruction = parts[0] + sep

            round_len = count_tokens(rou)
            # NOTE(review): the -2 looks like a correction for tokens added by the tokenizer — confirm.
            instruction_len = count_tokens(instruction) - 2

            target[cur_len:cur_len + instruction_len] = IGNORE_INDEX
            cur_len += round_len
        target[cur_len:] = IGNORE_INDEX

        # If the per-round lengths do not add up to the real length, drop the whole sample from the loss.
        if cur_len < tokenizer.model_max_length and cur_len != total_len:
            target[:] = IGNORE_INDEX
            print(
                f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                f" (ignored)"
            )

    return dict(
        input_ids=input_ids,
        labels=targets,
    )
def preprocess_plain(
    sources: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """Build input ids and labels for the plain (image + caption) format.

    Each source must hold exactly two sentences, the first of which contains
    DEFAULT_IMAGE_TOKEN. The first sentence is overwritten in place with the bare image
    token, and the prompt becomes image token + second sentence + the default
    conversation separator. In the labels, the leading image-token part is IGNORE_INDEX.

    Returns:
        `dict(input_ids=..., labels=...)`, both lists with one tensor per source.
    """
    separator_owner = conversation_lib.default_conversation

    # add end signal and concatenate together
    conversations = []
    for source in sources:
        assert len(source) == 2
        image_turn, caption_turn = source[0], source[1]
        assert DEFAULT_IMAGE_TOKEN in image_turn['value']
        # Whatever surrounded the image token in the first sentence is discarded.
        image_turn['value'] = DEFAULT_IMAGE_TOKEN
        conversations.append(image_turn['value'] + caption_turn['value'] + separator_owner.sep)

    # tokenize conversations
    input_ids = [
        tokenizer_image_token(prompt, tokenizer, return_tensors='pt')
        for prompt in conversations
    ]
    targets = copy.deepcopy(input_ids)
    for target, source in zip(targets, sources):
        prefix_len = len(tokenizer_image_token(source[0]['value'], tokenizer))
        target[:prefix_len] = IGNORE_INDEX

    return dict(input_ids=input_ids, labels=targets)
class LazySupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning.

    The JSON file at `data_path` is loaded once. Images are opened and conversations are
    tokenized lazily, one sample per `__getitem__` call.
    """

    def __init__(self, data_path: str,
                 tokenizer: transformers.PreTrainedTokenizer,
                 data_args: DataArguments):
        super(LazySupervisedDataset, self).__init__()
        # Close the file as soon as it has been parsed.
        with open(data_path, "r") as f:
            list_data_dict = json.load(f)

        rank0_print("Formatting inputs...Skip in lazy mode")
        self.tokenizer = tokenizer
        self.list_data_dict = list_data_dict
        self.data_args = data_args

    def __len__(self):
        return len(self.list_data_dict)

    @staticmethod
    def _expand2square(pil_img, background_color):
        """Return `pil_img` centred on a square canvas filled with `background_color`."""
        width, height = pil_img.size
        if width == height:
            return pil_img
        if width > height:
            result = Image.new(pil_img.mode, (width, width), background_color)
            result.paste(pil_img, (0, (width - height) // 2))
            return result
        result = Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result

    def _load_item(self, i) -> Dict[str, torch.Tensor]:
        """Load, preprocess and tokenize sample `i`; any failure propagates to the caller."""
        sources = self.list_data_dict[i]
        if isinstance(i, int):
            sources = [sources]
        assert len(sources) == 1, "Don't know why it is wrapped to a list"  # FIXME
        if 'image' in sources[0]:
            image_file = self.list_data_dict[i]['image']
            image_folder = self.data_args.image_folder
            processor = self.data_args.image_processor
            image = Image.open(os.path.join(image_folder, image_file)).convert('RGB')
            if self.data_args.image_aspect_ratio == 'pad':
                # Pad to a square with the processor's mean colour before preprocessing.
                image = self._expand2square(image, tuple(int(x*255) for x in processor.image_mean))
            image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
            sources = preprocess_multimodal(
                copy.deepcopy([e["conversations"] for e in sources]),
                self.data_args)
        else:
            sources = copy.deepcopy([e["conversations"] for e in sources])
        data_dict = preprocess(
            sources,
            self.tokenizer,
            has_image=('image' in self.list_data_dict[i]))
        if isinstance(i, int):
            data_dict = dict(input_ids=data_dict["input_ids"][0],
                             labels=data_dict["labels"][0])

        # image exist in the data
        if 'image' in self.list_data_dict[i]:
            data_dict['image'] = image
        elif self.data_args.is_multimodal:
            # image does not exist in the data, but the model is multimodal
            crop_size = self.data_args.image_processor.crop_size
            data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width'])
        return data_dict

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        """Return sample `i`, falling back to the following samples if it cannot be loaded.

        For an integer index, a failing sample is reported and the next one is tried,
        wrapping around at the end of the dataset, so one broken sample does not stop
        training.

        Raises:
            IndexError: If an integer `i` is outside the dataset.
            RuntimeError: If every sample was tried and none could be loaded.
        """
        if not isinstance(i, int):
            # There is no "next sample" for a non-integer index, so nothing to fall back to.
            return self._load_item(i)

        num_samples = len(self.list_data_dict)
        if not -num_samples <= i < num_samples:
            raise IndexError(f"index {i} is out of range for a dataset of size {num_samples}")

        index = i % num_samples
        last_error = None
        # One pass over the dataset at most, so a fully broken dataset fails loudly
        # instead of recursing without end.
        for _ in range(num_samples):
            try:
                return self._load_item(index)
            except Exception as e:
                last_error = e
                sample = self.list_data_dict[index]
                # Text-only samples have no 'image' key; do not let the report itself fail.
                image_name = sample.get('image', '<no image>') if isinstance(sample, dict) else '<unknown>'
                logging.warning("sample %d (%s) failed: %r; trying the next sample", index, image_name, e)
                index = (index + 1) % num_samples
        raise RuntimeError(f"no sample could be loaded starting from index {i}") from last_error
[4, 8]: from transformers import BitsAndBytesConfig bnb_model_from_pretrained_args.update(dict( device_map={"": training_args.device}, load_in_4bit=training_args.bits == 4, load_in_8bit=training_args.bits == 8, quantization_config=BitsAndBytesConfig( load_in_4bit=training_args.bits == 4, load_in_8bit=training_args.bits == 8, llm_int8_threshold=6.0, llm_int8_has_fp16_weight=False, bnb_4bit_compute_dtype=compute_dtype, bnb_4bit_use_double_quant=training_args.double_quant, bnb_4bit_quant_type=training_args.quant_type # {'fp4', 'nf4'} ) )) if model_args.vision_tower is not None: if 'mpt' in model_args.model_name_or_path: config = transformers.AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True) config.attn_config['attn_impl'] = training_args.mpt_attn_impl model = LlavaMPTForCausalLM.from_pretrained( model_args.model_name_or_path, config=config, cache_dir=training_args.cache_dir, **bnb_model_from_pretrained_args ) else: model = LlavaLlamaForCausalLM.from_pretrained( model_args.model_name_or_path, cache_dir=training_args.cache_dir, **bnb_model_from_pretrained_args ) else: model = transformers.LlamaForCausalLM.from_pretrained( model_args.model_name_or_path, cache_dir=training_args.cache_dir, **bnb_model_from_pretrained_args ) model.config.use_cache = False if model_args.freeze_backbone: model.model.requires_grad_(False) if training_args.bits in [4, 8]: from peft import prepare_model_for_kbit_training model.config.torch_dtype=(torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)) model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing) if training_args.gradient_checkpointing: if hasattr(model, "enable_input_require_grads"): model.enable_input_require_grads() else: def make_inputs_require_grad(module, input, output): output.requires_grad_(True) model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) if 
training_args.lora_enable: from peft import LoraConfig, get_peft_model lora_config = LoraConfig( r=training_args.lora_r, lora_alpha=training_args.lora_alpha, target_modules=find_all_linear_names(model), lora_dropout=training_args.lora_dropout, bias=training_args.lora_bias, task_type="CAUSAL_LM", ) if training_args.bits == 16: if training_args.bf16: model.to(torch.bfloat16) if training_args.fp16: model.to(torch.float16) rank0_print("Adding LoRA adapters...") model = get_peft_model(model, lora_config) if 'mpt' in model_args.model_name_or_path: tokenizer = transformers.AutoTokenizer.from_pretrained( model_args.model_name_or_path, cache_dir=training_args.cache_dir, model_max_length=training_args.model_max_length, padding_side="right" ) else: tokenizer = transformers.AutoTokenizer.from_pretrained( model_args.model_name_or_path, cache_dir=training_args.cache_dir, model_max_length=training_args.model_max_length, padding_side="right", use_fast=False, ) if model_args.version == "v0": if tokenizer.pad_token is None: smart_tokenizer_and_embedding_resize( special_tokens_dict=dict(pad_token="[PAD]"), tokenizer=tokenizer, model=model, ) elif model_args.version == "v0.5": tokenizer.pad_token = tokenizer.unk_token else: tokenizer.pad_token = tokenizer.unk_token if model_args.version in conversation_lib.conv_templates: conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version] else: conversation_lib.default_conversation = conversation_lib.conv_templates["vicuna_v1"] if model_args.vision_tower is not None: model.get_model().initialize_vision_modules( model_args=model_args, fsdp=training_args.fsdp ) vision_tower = model.get_vision_tower() vision_tower.to(dtype=torch.float16, device=training_args.device) data_args.image_processor = vision_tower.image_processor data_args.is_multimodal = True model.config.image_aspect_ratio = data_args.image_aspect_ratio model.config.image_grid_pinpoints = data_args.image_grid_pinpoints model.config.tune_mm_mlp_adapter = 
training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter if model_args.tune_mm_mlp_adapter: model.requires_grad_(False) for p in model.get_model().mm_projector.parameters(): p.requires_grad = True model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter if training_args.freeze_mm_mlp_adapter: for p in model.get_model().mm_projector.parameters(): p.requires_grad = False if training_args.bits in [4, 8]: model.get_model().mm_projector.to(dtype=compute_dtype, device=training_args.device) model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = model_args.mm_use_im_start_end training_args.use_im_start_end = model_args.mm_use_im_start_end model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer) if training_args.bits in [4, 8]: from peft.tuners.lora import LoraLayer for name, module in model.named_modules(): if isinstance(module, LoraLayer): if training_args.bf16: module = module.to(torch.bfloat16) if 'norm' in name: module = module.to(torch.float32) if 'lm_head' in name or 'embed_tokens' in name: if hasattr(module, 'weight'): if training_args.bf16 and module.weight.dtype == torch.float32: module = module.to(torch.bfloat16) data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args) trainer = LLaVATrainer(model=model, tokenizer=tokenizer, args=training_args, **data_module) if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")): trainer.train(resume_from_checkpoint=True) else: trainer.train() trainer.save_state() model.config.use_cache = True if training_args.lora_enable: state_dict = get_peft_state_maybe_zero_3( model.named_parameters(), training_args.lora_bias ) non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3( model.named_parameters() ) if training_args.local_rank == 0 or training_args.local_rank == -1: model.config.save_pretrained(training_args.output_dir) model.save_pretrained(training_args.output_dir, 
state_dict=state_dict) torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin')) else: safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir) if __name__ == "__main__": train() ================================================ FILE: llava/train/train_grounding_1st.py ================================================ # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright: # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright: # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
@dataclass
class ModelArguments:
    """Command-line options describing which model to build and how.

    Parsed by ``transformers.HfArgumentParser`` in ``train()``.
    """

    # Checkpoint name/path given to ``from_pretrained`` for both model and
    # tokenizer; the substring 'mpt' selects the MPT code path.
    model_name_or_path: Optional[str] = "facebook/opt-125m"
    # Directory scanned for ``pytorch_model*.bin`` shards when ``load_model``
    # is set; the substring "stage1" triggers embedding restoration.
    whole_model: Optional[str] = "facebook/opt-125m"
    # Conversation/prompt version; drives pad-token setup and template choice.
    version: Optional[str] = "v0"
    # When true, ``model.model`` is frozen.
    freeze_backbone: bool = False
    # When true, everything except ``mm_projector`` is frozen.
    tune_mm_mlp_adapter: bool = False
    # ``None`` means a plain ``LlamaForCausalLM`` is loaded (no vision modules).
    vision_tower: Optional[str] = None
    mm_vision_select_layer: Optional[int] = -1   # default to the last layer
    # Not read directly in this file; forwarded inside ``model_args`` to
    # ``initialize_vision_modules``.
    pretrain_mm_mlp_adapter: Optional[str] = None
    # Copied onto ``model.config`` / ``data_args`` / ``training_args``.
    mm_use_im_start_end: bool = False
    # Enables loading weights from ``whole_model`` before training.
    load_model: bool = False
    # Copied onto ``model.config``.
    mm_use_im_patch_token: bool = True
    # Not read directly in this file; forwarded inside ``model_args``.
    mm_vision_select_feature: Optional[str] = "patch"
    # Comma-separated overrides applied with ``LazyConfig.apply_overrides``.
    opt: Optional[str] = ""
    # Path loaded with ``LazyConfig.load`` in ``setup()``.
    config_file: Optional[str] = ""
# Borrowed from peft.utils.get_peft_model_state_dict
def get_peft_state_maybe_zero_3(named_params, bias):
    """Collect the LoRA-related parameters selected by ``bias``.

    Args:
        named_params: Iterable of ``(name, parameter)`` pairs, e.g.
            ``model.named_parameters()``.
        bias: One of ``"none"`` (LoRA weights only), ``"all"`` (LoRA weights
            plus every parameter whose name contains ``"bias"``) or
            ``"lora_only"`` (LoRA weights plus the bias that sits next to each
            LoRA-adapted module).

    Returns:
        A dict mapping parameter name to the value produced by
        ``maybe_zero_3`` for that parameter.

    Raises:
        NotImplementedError: If ``bias`` is not one of the supported modes.
    """
    if bias == "none":
        to_return = {k: t for k, t in named_params if "lora_" in k}
    elif bias == "all":
        to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
    elif bias == "lora_only":
        to_return = {}
        maybe_lora_bias = {}
        lora_bias_names = set()
        for k, t in named_params:
            if "lora_" in k:
                to_return[k] = t
                # Name of the bias belonging to the module this LoRA weight adapts.
                lora_bias_names.add(k.split("lora_")[0] + "bias")
            elif "bias" in k:
                maybe_lora_bias[k] = t
        # Keep only the biases that belong to a LoRA-adapted module. The
        # membership test must use the candidate's own name, and the dict
        # must be iterated as (name, tensor) pairs.
        for k, t in maybe_lora_bias.items():
            if k in lora_bias_names:
                to_return[k] = t
    else:
        raise NotImplementedError
    to_return = {k: maybe_zero_3(v, name=k) for k, v in to_return.items()}
    return to_return
""" num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) model.resize_token_embeddings(len(tokenizer)) if num_new_tokens > 0: input_embeddings = model.get_input_embeddings().weight.data output_embeddings = model.get_output_embeddings().weight.data input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( dim=0, keepdim=True) output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( dim=0, keepdim=True) input_embeddings[-num_new_tokens:] = input_embeddings_avg output_embeddings[-num_new_tokens:] = output_embeddings_avg def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict: """Tokenize a list of strings.""" tokenized_list = [ tokenizer( text, return_tensors="pt", padding="longest", max_length=tokenizer.model_max_length, truncation=True, ) for text in strings ] input_ids = labels = [ tokenized.input_ids[0] for tokenized in tokenized_list ] input_ids_lens = labels_lens = [ tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list ] return dict( input_ids=input_ids, labels=labels, input_ids_lens=input_ids_lens, labels_lens=labels_lens, ) def _mask_targets(target, tokenized_lens, speakers): # cur_idx = 0 cur_idx = tokenized_lens[0] tokenized_lens = tokenized_lens[1:] target[:cur_idx] = IGNORE_INDEX for tokenized_len, speaker in zip(tokenized_lens, speakers): if speaker == "human": target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX cur_idx += tokenized_len def _add_speaker_and_signal(header, source, get_conversation=True): """Add speaker and start/end signal on each round.""" BEGIN_SIGNAL = "### " END_SIGNAL = "\n" conversation = header for sentence in source: from_str = sentence["from"] if from_str.lower() == "human": from_str = conversation_lib.default_conversation.roles[0] elif from_str.lower() == "gpt": from_str = conversation_lib.default_conversation.roles[1] else: from_str = 'unknown' sentence["value"] = (BEGIN_SIGNAL + from_str + ": " + sentence["value"] + 
def preprocess_multimodal(
    sources: Sequence[str],
    data_args: DataArguments
) -> Dict:
    """Normalise image placeholders inside every sentence of ``sources``.

    When ``data_args.is_multimodal`` is false the input is returned untouched.
    Otherwise each sentence is edited in place: an image token found anywhere
    in the text is moved to the front (followed by a newline), and the token
    is wrapped in start/end markers when ``data_args.mm_use_im_start_end`` is
    set. The same ``sources`` object is returned.
    """
    if not data_args.is_multimodal:
        return sources

    for source in sources:
        for sentence in source:
            if DEFAULT_IMAGE_TOKEN in sentence['value']:
                text = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip()
                text = (DEFAULT_IMAGE_TOKEN + '\n' + text).strip()
                if "mmtag" in conversation_lib.default_conversation.version:
                    # As written the wrapper strings are empty, so this leaves
                    # the text unchanged.
                    text = text.replace(DEFAULT_IMAGE_TOKEN, '' + DEFAULT_IMAGE_TOKEN + '')
                sentence['value'] = text

            if data_args.mm_use_im_start_end:
                replace_token = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
            else:
                replace_token = DEFAULT_IMAGE_TOKEN
            sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token)

    return sources
def preprocess_v1(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False
) -> Dict:
    """Build input ids and masked labels for ``SeparatorStyle.TWO`` templates.

    Args:
        sources: Sequence of conversations; each conversation is a list of
            dicts with a ``"from"`` key (``"human"`` or ``"gpt"``) and a
            ``"value"`` key.
        tokenizer: Tokenizer used both for the full prompt and for measuring
            the length of each round.
        has_image: When true, prompts are tokenized with
            ``tokenizer_image_token`` instead of calling the tokenizer directly.

    Returns:
        ``dict(input_ids=..., labels=...)`` where ``labels`` is a clone of
        ``input_ids`` with the instruction parts replaced by ``IGNORE_INDEX``.
    """
    conv = conversation_lib.default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    # Apply prompt templates
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]

        # Reuse the same template object for every conversation.
        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            # Turns must strictly alternate, starting with roles[0].
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    # Tokenize conversations

    if has_image:
        # torch.stack requires every tokenized prompt to have the same length.
        input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
    else:
        input_ids = tokenizer(
            conversations,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ).input_ids

    targets = input_ids.clone()

    assert conv.sep_style == conversation_lib.SeparatorStyle.TWO

    # Mask targets
    # Text that separates the instruction part of a round from the answer.
    sep = conv.sep + conv.roles[1] + ": "
    for conversation, target in zip(conversations, targets):
        # Number of non-padding positions; used below as a consistency check.
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        rounds = conversation.split(conv.sep2)
        # NOTE(review): presumably position 0 is a BOS token -- confirm.
        cur_len = 1
        target[:cur_len] = IGNORE_INDEX
        for i, rou in enumerate(rounds):
            if rou == "":
                break

            parts = rou.split(sep)
            # A round that does not split into exactly instruction + answer
            # stops the masking loop.
            if len(parts) != 2:
                break
            parts[0] += sep

            # NOTE(review): the "- 2" looks like a correction for tokens added
            # when a fragment is tokenized on its own -- confirm.
            if has_image:
                round_len = len(tokenizer_image_token(rou, tokenizer))
                instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
            else:
                round_len = len(tokenizer(rou).input_ids)
                instruction_len = len(tokenizer(parts[0]).input_ids) - 2

            # Hide the instruction part of this round from the loss.
            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX

            cur_len += round_len
        # Everything after the last processed round is ignored as well.
        target[cur_len:] = IGNORE_INDEX

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                # The per-round lengths did not add up to the real length, so
                # the whole sample is excluded from the loss.
                target[:] = IGNORE_INDEX
                print(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" (ignored)"
                )

    return dict(
        input_ids=input_ids,
        labels=targets,
    )
def preprocess_plain(
    sources: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """Tokenize two-turn (image, caption) samples for the PLAIN template.

    Every source must contain exactly two turns and the first must include the
    image token. The first turn's text is replaced in place by the bare image
    token, the second turn and the template separator are appended, and the
    positions covering the first turn are masked with IGNORE_INDEX in the labels.
    """
    prompts = []
    for pair in sources:
        assert len(pair) == 2
        question, answer = pair[0], pair[1]
        assert DEFAULT_IMAGE_TOKEN in question['value']
        # The question text is discarded; only the image placeholder remains.
        question['value'] = DEFAULT_IMAGE_TOKEN
        prompts.append(
            question['value'] + answer['value'] + conversation_lib.default_conversation.sep)

    input_ids = [
        tokenizer_image_token(text, tokenizer, return_tensors='pt') for text in prompts
    ]
    targets = copy.deepcopy(input_ids)

    for target, pair in zip(targets, sources):
        prefix_len = len(tokenizer_image_token(pair[0]['value'], tokenizer))
        target[:prefix_len] = IGNORE_INDEX

    return dict(input_ids=input_ids, labels=targets)
class LazySupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning.

    The JSON annotation file is read once at construction; images are opened
    and conversations tokenized only when a sample is requested.
    """

    def __init__(self, data_path: str,
                 tokenizer: transformers.PreTrainedTokenizer,
                 data_args: DataArguments):
        super(LazySupervisedDataset, self).__init__()
        # Close the annotation file as soon as it has been parsed.
        with open(data_path, "r") as f:
            list_data_dict = json.load(f)

        rank0_print("Formatting inputs...Skip in lazy mode")
        self.tokenizer = tokenizer
        self.list_data_dict = list_data_dict
        self.data_args = data_args

    def __len__(self):
        return len(self.list_data_dict)

    @staticmethod
    def _expand_to_square(pil_img, background_color):
        """Pad ``pil_img`` to a square canvas, centring it on the short axis."""
        width, height = pil_img.size
        if width == height:
            return pil_img
        side = max(width, height)
        result = Image.new(pil_img.mode, (side, side), background_color)
        if width > height:
            result.paste(pil_img, (0, (width - height) // 2))
        else:
            result.paste(pil_img, ((height - width) // 2, 0))
        return result

    def _load_image(self, image_file):
        """Open ``image_file`` from the image folder and run the image processor."""
        processor = self.data_args.image_processor
        path = os.path.join(self.data_args.image_folder, image_file)
        # The context manager releases the file handle once the RGB copy exists.
        with Image.open(path) as raw:
            image = raw.convert('RGB')
        if self.data_args.image_aspect_ratio == 'pad':
            # Pad with the processor's mean colour so the padding is neutral
            # after normalisation.
            image = self._expand_to_square(
                image, tuple(int(x*255) for x in processor.image_mean))
        return processor.preprocess(image, return_tensors='pt')['pixel_values'][0]

    def _build_example(self, i) -> Dict[str, torch.Tensor]:
        """Produce the training dict for index ``i``; may raise on bad data."""
        sources = self.list_data_dict[i]
        if isinstance(i, int):
            sources = [sources]
        assert len(sources) == 1, "Don't know why it is wrapped to a list"  # FIXME

        has_image = 'image' in sources[0]
        conversations = copy.deepcopy([e["conversations"] for e in sources])
        if has_image:
            image = self._load_image(self.list_data_dict[i]['image'])
            conversations = preprocess_multimodal(conversations, self.data_args)

        data_dict = preprocess(
            conversations,
            self.tokenizer,
            has_image=('image' in self.list_data_dict[i]))
        if isinstance(i, int):
            data_dict = dict(input_ids=data_dict["input_ids"][0],
                             labels=data_dict["labels"][0])

        if 'image' in self.list_data_dict[i]:
            # image exist in the data
            data_dict['image'] = image
        elif self.data_args.is_multimodal:
            # image does not exist in the data, but the model is multimodal
            crop_size = self.data_args.image_processor.crop_size
            data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width'])
        return data_dict

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        """Return sample ``i``, falling back to later samples if it fails.

        A sample that cannot be built is reported and the next index is tried,
        wrapping around at the end of the dataset. Each sample is attempted at
        most once per call.

        Raises:
            IndexError: If an integer ``i`` is outside the dataset.
            RuntimeError: If no sample could be built.
        """
        num_samples = len(self.list_data_dict)
        if isinstance(i, int) and not -num_samples <= i < num_samples:
            raise IndexError(f"index {i} out of range for dataset of size {num_samples}")

        index = i
        last_error = None
        for _ in range(num_samples):
            try:
                return self._build_example(index)
            except Exception as err:
                last_error = err
                # Best effort: report the broken record and the reason, then
                # move on instead of aborting training.
                print(self.list_data_dict[index], "failed", repr(err))
                index = (index + 1) % num_samples
        raise RuntimeError(
            f"no loadable sample found after {num_samples} attempts") from last_error
""" cfg = LazyConfig.load(args.config_file) # import pdb;pdb.set_trace() opt=args.opt.split(',') cfg = LazyConfig.apply_overrides(cfg, opt) # cfg.freeze() # default_setup(cfg, args) # setup_logger(output=cfg.OUTPUT_DIR, distributed_rank=comm.get_rank(), name="maskdino") return cfg def train(): global local_rank parser = transformers.HfArgumentParser( (ModelArguments, DataArguments, TrainingArguments)) model_args, data_args, training_args = parser.parse_args_into_dataclasses() local_rank = training_args.local_rank compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)) cfg=setup(model_args) bnb_model_from_pretrained_args = {} if training_args.bits in [4, 8]: from transformers import BitsAndBytesConfig bnb_model_from_pretrained_args.update(dict( device_map={"": training_args.device}, load_in_4bit=training_args.bits == 4, load_in_8bit=training_args.bits == 8, quantization_config=BitsAndBytesConfig( load_in_4bit=training_args.bits == 4, load_in_8bit=training_args.bits == 8, llm_int8_threshold=6.0, llm_int8_has_fp16_weight=False, bnb_4bit_compute_dtype=compute_dtype, bnb_4bit_use_double_quant=training_args.double_quant, bnb_4bit_quant_type=training_args.quant_type # {'fp4', 'nf4'} ) )) if model_args.vision_tower is not None: if 'mpt' in model_args.model_name_or_path: config = transformers.AutoConfig.from_pretrained(model_args.model_name_or_path,cache_dir="/comp_robot/zhanghao/.cache/hugging_face/", trust_remote_code=True) config.attn_config['attn_impl'] = training_args.mpt_attn_impl model = LlavaMPTForCausalLM.from_pretrained( model_args.model_name_or_path, config=config, cache_dir="/comp_robot/zhanghao/.cache/hugging_face/", **bnb_model_from_pretrained_args ) else: model = LlavaLlamaForCausalLM_gd.from_pretrained( model_args.model_name_or_path, cache_dir="/comp_robot/zhanghao/.cache/hugging_face/", **bnb_model_from_pretrained_args ) else: model = transformers.LlamaForCausalLM.from_pretrained( 
model_args.model_name_or_path, cache_dir="/comp_robot/zhanghao/.cache/hugging_face/", **bnb_model_from_pretrained_args ) model.config.use_cache = False if model_args.freeze_backbone: model.model.requires_grad_(False) if training_args.bits in [4, 8]: from peft import prepare_model_for_kbit_training model.config.torch_dtype=(torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)) model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing) if training_args.gradient_checkpointing: if hasattr(model, "enable_input_require_grads"): model.enable_input_require_grads() else: def make_inputs_require_grad(module, input, output): output.requires_grad_(True) model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) if training_args.lora_enable: from peft import LoraConfig, get_peft_model lora_config = LoraConfig( r=training_args.lora_r, lora_alpha=training_args.lora_alpha, target_modules=find_all_linear_names(model), lora_dropout=training_args.lora_dropout, bias=training_args.lora_bias, task_type="CAUSAL_LM", ) if training_args.bits == 16: if training_args.bf16: model.to(torch.bfloat16) if training_args.fp16: model.to(torch.float16) rank0_print("Adding LoRA adapters...") model = get_peft_model(model, lora_config) if 'mpt' in model_args.model_name_or_path: tokenizer = transformers.AutoTokenizer.from_pretrained( model_args.model_name_or_path,cache_dir="/comp_robot/zhanghao/.cache/hugging_face/", model_max_length=training_args.model_max_length, padding_side="right" ) else: tokenizer = transformers.AutoTokenizer.from_pretrained( model_args.model_name_or_path, cache_dir="/comp_robot/zhanghao/.cache/hugging_face/", model_max_length=training_args.model_max_length, padding_side="right", use_fast=False, ) if model_args.version == "v0": if tokenizer.pad_token is None: smart_tokenizer_and_embedding_resize( special_tokens_dict=dict(pad_token="[PAD]"), tokenizer=tokenizer, 
model=model, ) elif model_args.version == "v0.5": tokenizer.pad_token = tokenizer.unk_token else: tokenizer.pad_token = tokenizer.unk_token if model_args.version in conversation_lib.conv_templates: conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version] else: conversation_lib.default_conversation = conversation_lib.conv_templates["vicuna_v1"] if model_args.vision_tower is not None: model.get_model().initialize_vision_modules( model_args=model_args, fsdp=training_args.fsdp ) vision_tower = model.get_vision_tower() vision_tower.to(dtype=torch.float16, device=training_args.device) data_args.image_processor = vision_tower.image_processor data_args.is_multimodal = True model.config.image_aspect_ratio = data_args.image_aspect_ratio model.config.image_grid_pinpoints = data_args.image_grid_pinpoints model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter if model_args.tune_mm_mlp_adapter or training_args.dbg: model.requires_grad_(False) for p in model.get_model().mm_projector.parameters(): p.requires_grad = True model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter if training_args.freeze_mm_mlp_adapter: for p in model.get_model().mm_projector.parameters(): p.requires_grad = False if training_args.bits in [4, 8]: model.get_model().mm_projector.to(dtype=compute_dtype, device=training_args.device) model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = model_args.mm_use_im_start_end training_args.use_im_start_end = model_args.mm_use_im_start_end model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer) model.initialize_seg_modules( cfg=cfg, ) if training_args.freeze_segmentation: model.freeze_seg_modules() if training_args.bits in [4, 8]: from peft.tuners.lora import LoraLayer for name, module in model.named_modules(): if isinstance(module, LoraLayer): if training_args.bf16: module = 
module.to(torch.bfloat16) if 'norm' in name: module = module.to(torch.float32) if 'lm_head' in name or 'embed_tokens' in name: if hasattr(module, 'weight'): if training_args.bf16 and module.weight.dtype == torch.float32: module = module.to(torch.bfloat16) data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args) print(model) if model_args.load_model: loaded_dict = dict() if "stage1" in model_args.whole_model: old_emb_in=model.get_input_embeddings().weight.clone() old_emb_out=model.get_output_embeddings().weight.clone() for model_file in os.listdir(model_args.whole_model): if model_file.endswith('.bin') and model_file.startswith('pytorch_model'): loaded_dict.update(torch.load(os.path.join(model_args.whole_model, model_file), map_location='cpu')) model.load_state_dict(loaded_dict, strict=False) if "stage1" in model_args.whole_model: with torch.no_grad(): model.get_input_embeddings().weight[:-3]=old_emb_in[:-3] model.get_output_embeddings().weight[:-3]=old_emb_out[:-3] print(loaded_dict.keys()) trainer = LLaVATrainer(model=model, tokenizer=tokenizer, args=training_args,cfg=cfg,data_loader_args=(tokenizer, data_args,preprocess), **data_module) if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")): trainer.train(resume_from_checkpoint=True) else: trainer.train() trainer.save_state() model.config.use_cache = True if training_args.lora_enable: state_dict = get_peft_state_maybe_zero_3( model.named_parameters(), training_args.lora_bias ) non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3( model.named_parameters() ) if training_args.local_rank == 0 or training_args.local_rank == -1: model.config.save_pretrained(training_args.output_dir) model.save_pretrained(training_args.output_dir, state_dict=state_dict) torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin')) else: safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir) if __name__ == "__main__": train() 
local_rank = None


def rank0_print(*args):
    """Print ``args`` only when the module-level ``local_rank`` equals 0.

    ``local_rank`` starts out as ``None`` and is reassigned inside ``train()``
    from ``training_args.local_rank``. For every other value (including the
    initial ``None``) the call does nothing.
    """
    if local_rank != 0:
        return
    print(*args)
@dataclass
class DataArguments:
    # Command-line options describing the training data.

    # JSON file that LazySupervisedDataset loads with json.load.
    data_path: str = field(
        default=None,
        metadata={"help": "Path to the training data."},
    )
    # Parsed from the command line; not read by the code visible in this file section.
    lazy_preprocess: bool = field(default=False)
    # Set to True by train() once a vision tower has been initialised.
    is_multimodal: bool = field(default=False)
    # Directory joined with each sample's 'image' entry when opening the picture.
    image_folder: Optional[str] = None
    # 'pad' makes the dataset pad images to a square before running the processor.
    image_aspect_ratio: str = field(default='square')
    # Copied verbatim into model.config.image_grid_pinpoints by train().
    image_grid_pinpoints: Optional[str] = None
def maybe_zero_3(param, ignore_status=False, name=None):
    """Return a detached CPU clone of ``param``.

    Parameters that carry a ``ds_id`` attribute are treated as DeepSpeed ZeRO
    parameters and are gathered with ``zero.GatheredParameters`` before being
    copied; every other parameter is copied directly.

    Args:
        param: Tensor or parameter to copy.
        ignore_status: When false, a warning is logged if the parameter's
            ``ds_status`` is ``ZeroParamStatus.NOT_AVAILABLE``.
        name: Label used in that warning.

    Returns:
        A detached clone of the parameter data living on the CPU.
    """
    if hasattr(param, "ds_id"):
        # DeepSpeed is only needed for parameters it manages, so import it
        # lazily; plain tensors can then be copied without DeepSpeed installed.
        from deepspeed import zero
        from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

        if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
            if not ignore_status:
                # NOTE(review): the message text says "!=" although this branch
                # runs when the status *is* NOT_AVAILABLE; wording kept as-is.
                logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}")
        with zero.GatheredParameters([param]):
            param = param.data.detach().cpu().clone()
    else:
        param = param.detach().cpu().clone()
    return param


# Borrowed from peft.utils.get_peft_model_state_dict
def get_peft_state_maybe_zero_3(named_params, bias):
    """Collect the LoRA (and optionally bias) parameters as CPU tensors.

    Args:
        named_params: Iterable of ``(name, parameter)`` pairs; it is consumed
            in a single pass, so a generator such as
            ``model.named_parameters()`` is fine.
        bias: ``"none"`` keeps only names containing ``"lora_"``; ``"all"``
            additionally keeps every name containing ``"bias"``;
            ``"lora_only"`` additionally keeps only those bias parameters
            whose name matches ``<prefix>bias`` for some LoRA parameter
            ``<prefix>lora_...``.

    Returns:
        Dict mapping parameter names to detached CPU copies.

    Raises:
        NotImplementedError: If ``bias`` is not one of the three modes above.
    """
    if bias == "none":
        to_return = {k: t for k, t in named_params if "lora_" in k}
    elif bias == "all":
        to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
    elif bias == "lora_only":
        to_return = {}
        maybe_lora_bias = {}
        lora_bias_names = set()
        for k, t in named_params:
            if "lora_" in k:
                to_return[k] = t
                lora_bias_names.add(k.split("lora_")[0] + "bias")
            elif "bias" in k:
                maybe_lora_bias[k] = t
        # Match after the full pass so a bias listed before its LoRA weights
        # is still found. Iterate items (not keys) and test each candidate's
        # own name rather than a stale loop variable.
        for k, t in maybe_lora_bias.items():
            if k in lora_bias_names:
                to_return[k] = t
    else:
        raise NotImplementedError(f"Unsupported LoRA bias mode: {bias!r}")
    to_return = {k: maybe_zero_3(v, name=k) for k, v in to_return.items()}
    return to_return


def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
    """Collect the non-LoRA parameters as CPU tensors.

    Args:
        named_params: Iterable of ``(name, parameter)`` pairs.
        require_grad_only: When true (default), keep only parameters whose
            ``requires_grad`` flag is set.

    Returns:
        Dict mapping the names that do not contain ``"lora_"`` to detached
        CPU copies.
    """
    to_return = {k: t for k, t in named_params if "lora_" not in k}
    if require_grad_only:
        to_return = {k: t for k, t in to_return.items() if t.requires_grad}
    to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
    return to_return
def find_all_linear_names(model):
    """List the distinct leaf names of every ``torch.nn.Linear`` submodule.

    The leaf name is the last dot-separated component of the qualified module
    name reported by ``model.named_modules()``. ``'lm_head'`` is always left
    out of the result (needed for 16-bit). The order of the returned list is
    unspecified because the names are collected in a set.
    """
    linear_type = torch.nn.Linear
    leaf_names = {
        qualified_name.split('.')[-1]
        for qualified_name, submodule in model.named_modules()
        if isinstance(submodule, linear_type)
    }
    leaf_names.discard('lm_head')  # needed for 16-bit
    return list(leaf_names)
""" num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) model.resize_token_embeddings(len(tokenizer)) if num_new_tokens > 0: input_embeddings = model.get_input_embeddings().weight.data output_embeddings = model.get_output_embeddings().weight.data input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( dim=0, keepdim=True) output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( dim=0, keepdim=True) input_embeddings[-num_new_tokens:] = input_embeddings_avg output_embeddings[-num_new_tokens:] = output_embeddings_avg def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict: """Tokenize a list of strings.""" tokenized_list = [ tokenizer( text, return_tensors="pt", padding="longest", max_length=tokenizer.model_max_length, truncation=True, ) for text in strings ] input_ids = labels = [ tokenized.input_ids[0] for tokenized in tokenized_list ] input_ids_lens = labels_lens = [ tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list ] return dict( input_ids=input_ids, labels=labels, input_ids_lens=input_ids_lens, labels_lens=labels_lens, ) def _mask_targets(target, tokenized_lens, speakers): # cur_idx = 0 cur_idx = tokenized_lens[0] tokenized_lens = tokenized_lens[1:] target[:cur_idx] = IGNORE_INDEX for tokenized_len, speaker in zip(tokenized_lens, speakers): if speaker == "human": target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX cur_idx += tokenized_len def _add_speaker_and_signal(header, source, get_conversation=True): """Add speaker and start/end signal on each round.""" BEGIN_SIGNAL = "### " END_SIGNAL = "\n" conversation = header for sentence in source: from_str = sentence["from"] if from_str.lower() == "human": from_str = conversation_lib.default_conversation.roles[0] elif from_str.lower() == "gpt": from_str = conversation_lib.default_conversation.roles[1] else: from_str = 'unknown' sentence["value"] = (BEGIN_SIGNAL + from_str + ": " + sentence["value"] + 
def preprocess_multimodal(
    sources: Sequence[str],
    data_args: DataArguments
) -> Dict:
    """Normalise image placeholders in every conversation turn, in place.

    When ``data_args.is_multimodal`` is false the input is returned untouched.
    Otherwise, for each turn whose text contains ``DEFAULT_IMAGE_TOKEN``, all
    occurrences are removed and a single token followed by a newline is put at
    the front of the stripped text. If the active conversation template's
    version contains ``"mmtag"``, the token is additionally replaced by itself
    wrapped in two string literals that are empty in this file, so that step
    currently leaves the text unchanged. Finally, when
    ``data_args.mm_use_im_start_end`` is set, every image token is surrounded
    by ``DEFAULT_IM_START_TOKEN`` / ``DEFAULT_IM_END_TOKEN``.

    Returns the same ``sources`` object that was passed in.
    """
    if not data_args.is_multimodal:
        return sources

    for conversation in sources:
        for turn in conversation:
            text = turn['value']
            if DEFAULT_IMAGE_TOKEN in text:
                text = text.replace(DEFAULT_IMAGE_TOKEN, '').strip()
                text = (DEFAULT_IMAGE_TOKEN + '\n' + text).strip()
                if "mmtag" in conversation_lib.default_conversation.version:
                    text = text.replace(DEFAULT_IMAGE_TOKEN, '' + DEFAULT_IMAGE_TOKEN + '')
            if data_args.mm_use_im_start_end:
                replace_token = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
            else:
                replace_token = DEFAULT_IMAGE_TOKEN
            turn["value"] = text.replace(DEFAULT_IMAGE_TOKEN, replace_token)

    return sources
def preprocess_v1(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False
) -> Dict:
    """Build prompts, token ids and masked labels for ``SeparatorStyle.TWO`` templates.

    Each entry of ``sources`` is a list of messages with ``"from"``
    (``"human"`` / ``"gpt"``) and ``"value"`` keys. The messages are rendered
    with a copy of ``conversation_lib.default_conversation``, tokenized, and
    the labels are a clone of the ids in which the instruction part of every
    round is replaced by ``IGNORE_INDEX``.

    Args:
        sources: Conversations to encode.
        tokenizer: Tokenizer used both for the full prompts and for measuring
            the per-round lengths.
        has_image: When true, prompts are encoded with
            ``tokenizer_image_token`` and stacked; otherwise the tokenizer is
            called directly with padding and truncation.

    Returns:
        ``dict(input_ids=..., labels=...)``.
    """
    conv = conversation_lib.default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    # Apply prompt templates
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            # Speakers must strictly alternate, starting with the first role.
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    # Tokenize conversations

    if has_image:
        # torch.stack needs every encoded prompt to have the same length.
        input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
    else:
        input_ids = tokenizer(
            conversations,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ).input_ids

    targets = input_ids.clone()

    assert conv.sep_style == conversation_lib.SeparatorStyle.TWO

    # Mask targets
    sep = conv.sep + conv.roles[1] + ": "
    for conversation, target in zip(conversations, targets):
        # Number of non-padding tokens in this sample.
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        rounds = conversation.split(conv.sep2)
        # The first position is always masked (presumably the BOS token -- TODO confirm).
        cur_len = 1
        target[:cur_len] = IGNORE_INDEX
        for i, rou in enumerate(rounds):
            if rou == "":
                break

            # A well-formed round splits into exactly [instruction, answer].
            parts = rou.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep

            if has_image:
                round_len = len(tokenizer_image_token(rou, tokenizer))
                # NOTE(review): the "- 2" looks like a correction for tokens the
                # tokenizer adds around the instruction -- confirm against the tokenizer.
                instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
            else:
                round_len = len(tokenizer(rou).input_ids)
                instruction_len = len(tokenizer(parts[0]).input_ids) - 2

            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX

            cur_len += round_len
        # Everything after the last processed round is ignored as well.
        target[cur_len:] = IGNORE_INDEX

        # The consistency check only runs when the accumulated length is below
        # model_max_length; on a mismatch the whole sample is masked out.
        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_INDEX
                print(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" (ignored)"
                )

    return dict(
        input_ids=input_ids,
        labels=targets,
    )
def preprocess_plain(
    sources: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """Encode two-message (image prompt, answer) samples for the plain template.

    Every source must hold exactly two messages, the first of which contains
    ``DEFAULT_IMAGE_TOKEN``. The first message's text is overwritten in place
    with the bare image token, and the prompt is that token, the second
    message's text and the default conversation separator concatenated.
    Labels are a deep copy of the ids with the positions covered by the first
    message set to ``IGNORE_INDEX``.

    Returns:
        ``dict(input_ids=..., labels=...)`` where both values are lists with
        one entry per source.
    """
    # add end signal and concatenate together
    prompts = []
    for pair in sources:
        assert len(pair) == 2
        question, answer = pair[0], pair[1]
        assert DEFAULT_IMAGE_TOKEN in question['value']
        question['value'] = DEFAULT_IMAGE_TOKEN
        prompts.append(
            question['value'] + answer['value'] + conversation_lib.default_conversation.sep
        )

    # tokenize conversations
    input_ids = [
        tokenizer_image_token(text, tokenizer, return_tensors='pt')
        for text in prompts
    ]
    targets = copy.deepcopy(input_ids)
    for label, pair in zip(targets, sources):
        prefix_len = len(tokenizer_image_token(pair[0]['value'], tokenizer))
        label[:prefix_len] = IGNORE_INDEX

    return dict(input_ids=input_ids, labels=targets)
class LazySupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning.

    The JSON file at ``data_path`` is loaded once; images are opened and
    conversations are tokenized lazily, one sample per ``__getitem__`` call.
    """

    def __init__(self, data_path: str,
                 tokenizer: transformers.PreTrainedTokenizer,
                 data_args: DataArguments):
        super().__init__()
        # Close the file deterministically instead of leaking the handle.
        with open(data_path, "r") as data_file:
            list_data_dict = json.load(data_file)

        rank0_print("Formatting inputs...Skip in lazy mode")
        self.tokenizer = tokenizer
        self.list_data_dict = list_data_dict
        self.data_args = data_args

    def __len__(self):
        return len(self.list_data_dict)

    @staticmethod
    def _expand2square(pil_img, background_color):
        """Pad ``pil_img`` to a centred square filled with ``background_color``."""
        width, height = pil_img.size
        if width == height:
            return pil_img
        if width > height:
            result = Image.new(pil_img.mode, (width, width), background_color)
            result.paste(pil_img, (0, (width - height) // 2))
            return result
        result = Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result

    def _load_sample(self, i) -> Dict[str, torch.Tensor]:
        """Build the training dict for entry ``i``; raises on any failure."""
        sources = self.list_data_dict[i]
        if isinstance(i, int):
            sources = [sources]
        assert len(sources) == 1, "Don't know why it is wrapped to a list"  # FIXME
        if 'image' in sources[0]:
            image_file = self.list_data_dict[i]['image']
            image_folder = self.data_args.image_folder
            processor = self.data_args.image_processor
            image = Image.open(os.path.join(image_folder, image_file)).convert('RGB')
            if self.data_args.image_aspect_ratio == 'pad':
                # Pad with the processor's mean colour before preprocessing.
                image = self._expand2square(image, tuple(int(x*255) for x in processor.image_mean))
            image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
            sources = preprocess_multimodal(
                copy.deepcopy([e["conversations"] for e in sources]),
                self.data_args)
        else:
            sources = copy.deepcopy([e["conversations"] for e in sources])
        data_dict = preprocess(
            sources,
            self.tokenizer,
            has_image=('image' in self.list_data_dict[i]))
        if isinstance(i, int):
            data_dict = dict(input_ids=data_dict["input_ids"][0],
                             labels=data_dict["labels"][0])

        # image exist in the data
        if 'image' in self.list_data_dict[i]:
            data_dict['image'] = image
        elif self.data_args.is_multimodal:
            # image does not exist in the data, but the model is multimodal
            crop_size = self.data_args.image_processor.crop_size
            data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width'])
        return data_dict

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        """Return sample ``i``, falling back to the following samples on failure.

        A sample that cannot be loaded is reported and the next index is tried
        instead. The search wraps around the end of the dataset and visits
        every other entry at most once, so a broken last sample no longer
        ends in an ``IndexError`` and repeated failures do not recurse.

        Raises:
            IndexError: If ``i`` itself is out of range.
            RuntimeError: If no sample in the dataset can be loaded.
        """
        try:
            return self._load_sample(i)
        except Exception as err:
            # Indexing here re-raises IndexError for an out-of-range ``i``.
            print(self.list_data_dict[i], "failed")
            logging.warning("Sample %s could not be loaded: %r", i, err)

        num_samples = len(self.list_data_dict)
        for step in range(1, num_samples):
            fallback = (i + step) % num_samples
            try:
                return self._load_sample(fallback)
            except Exception as err:
                print(self.list_data_dict[fallback], "failed")
                logging.warning("Sample %s could not be loaded: %r", fallback, err)
        raise RuntimeError("No sample in the dataset could be loaded.")
def setup(args):
    """
    Create configs and perform basic setups.

    Loads the lazy config named by ``args.config_file`` and applies the
    comma-separated overrides given in ``args.opt``.

    Args:
        args: Object exposing ``config_file`` (path of the config to load)
            and ``opt`` (comma-separated override strings; may be empty).

    Returns:
        The loaded config with the overrides applied.
    """
    cfg = LazyConfig.load(args.config_file)
    # ``opt`` defaults to "" and "".split(',') yields [''], so blank entries
    # (also produced by a trailing comma) are dropped instead of being passed
    # on as overrides.
    overrides = [item for item in (args.opt or "").split(',') if item.strip()]
    cfg = LazyConfig.apply_overrides(cfg, overrides)
    return cfg
)) if model_args.vision_tower is not None: if 'mpt' in model_args.model_name_or_path: config = transformers.AutoConfig.from_pretrained(model_args.model_name_or_path,cache_dir="/comp_robot/zhanghao/.cache/hugging_face/", trust_remote_code=True) config.attn_config['attn_impl'] = training_args.mpt_attn_impl model = LlavaMPTForCausalLM.from_pretrained( model_args.model_name_or_path, config=config, cache_dir="/comp_robot/zhanghao/.cache/hugging_face/", **bnb_model_from_pretrained_args ) else: model = LlavaLlamaForCausalLM_joint.from_pretrained( model_args.model_name_or_path, cache_dir="/comp_robot/zhanghao/.cache/hugging_face/", **bnb_model_from_pretrained_args ) else: model = transformers.LlamaForCausalLM.from_pretrained( model_args.model_name_or_path, cache_dir="/comp_robot/zhanghao/.cache/hugging_face/", **bnb_model_from_pretrained_args ) model.config.use_cache = False if model_args.freeze_backbone: model.model.requires_grad_(False) if training_args.bits in [4, 8]: from peft import prepare_model_for_kbit_training model.config.torch_dtype=(torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)) model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing) if training_args.gradient_checkpointing: if hasattr(model, "enable_input_require_grads"): model.enable_input_require_grads() else: def make_inputs_require_grad(module, input, output): output.requires_grad_(True) model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) if training_args.lora_enable: from peft import LoraConfig, get_peft_model lora_config = LoraConfig( r=training_args.lora_r, lora_alpha=training_args.lora_alpha, target_modules=find_all_linear_names(model), lora_dropout=training_args.lora_dropout, bias=training_args.lora_bias, task_type="CAUSAL_LM", ) if training_args.bits == 16: if training_args.bf16: model.to(torch.bfloat16) if training_args.fp16: model.to(torch.float16) rank0_print("Adding 
LoRA adapters...") model = get_peft_model(model, lora_config) if 'mpt' in model_args.model_name_or_path: tokenizer = transformers.AutoTokenizer.from_pretrained( model_args.model_name_or_path,cache_dir="/comp_robot/zhanghao/.cache/hugging_face/", model_max_length=training_args.model_max_length, padding_side="right" ) else: tokenizer = transformers.AutoTokenizer.from_pretrained( model_args.model_name_or_path, cache_dir="/comp_robot/zhanghao/.cache/hugging_face/", model_max_length=training_args.model_max_length, padding_side="right", use_fast=False, ) if model_args.version == "v0": if tokenizer.pad_token is None: smart_tokenizer_and_embedding_resize( special_tokens_dict=dict(pad_token="[PAD]"), tokenizer=tokenizer, model=model, ) elif model_args.version == "v0.5": tokenizer.pad_token = tokenizer.unk_token else: tokenizer.pad_token = tokenizer.unk_token if model_args.version in conversation_lib.conv_templates: conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version] else: conversation_lib.default_conversation = conversation_lib.conv_templates["vicuna_v1"] if model_args.vision_tower is not None: model.get_model().initialize_vision_modules( model_args=model_args, fsdp=training_args.fsdp ) vision_tower = model.get_vision_tower() vision_tower.to(dtype=torch.float16, device=training_args.device) data_args.image_processor = vision_tower.image_processor data_args.is_multimodal = True model.config.image_aspect_ratio = data_args.image_aspect_ratio model.config.image_grid_pinpoints = data_args.image_grid_pinpoints model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter if model_args.tune_mm_mlp_adapter or training_args.dbg: model.requires_grad_(False) for p in model.get_model().mm_projector.parameters(): p.requires_grad = True model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter if training_args.freeze_mm_mlp_adapter: for p in model.get_model().mm_projector.parameters(): 
p.requires_grad = False if training_args.bits in [4, 8]: model.get_model().mm_projector.to(dtype=compute_dtype, device=training_args.device) model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = model_args.mm_use_im_start_end training_args.use_im_start_end = model_args.mm_use_im_start_end model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer) cfg.MODEL.DIM_PROJ=model.get_model().config.hidden_size model.initialize_seg_modules( cfg=cfg ) if training_args.bits in [4, 8]: from peft.tuners.lora import LoraLayer for name, module in model.named_modules(): if isinstance(module, LoraLayer): if training_args.bf16: module = module.to(torch.bfloat16) if 'norm' in name: module = module.to(torch.float32) if 'lm_head' in name or 'embed_tokens' in name: if hasattr(module, 'weight'): if training_args.bf16 and module.weight.dtype == torch.float32: module = module.to(torch.bfloat16) data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args) print(model) if model_args.load_model: loaded_dict = dict() if "stage1" in model_args.whole_model: old_emb_in=model.get_input_embeddings().weight.clone() old_emb_out=model.get_output_embeddings().weight.clone() for model_file in os.listdir(model_args.whole_model): if model_file.endswith('.bin') and model_file.startswith('pytorch_model'): loaded_dict.update(torch.load(os.path.join(model_args.whole_model, model_file), map_location='cpu')) model.load_state_dict(loaded_dict, strict=False) if "stage1" in model_args.whole_model: with torch.no_grad(): model.get_input_embeddings().weight[:-3]=old_emb_in[:-3] model.get_output_embeddings().weight[:-3]=old_emb_out[:-3] print(loaded_dict.keys()) trainer = LLaVATrainer(model=model, tokenizer=tokenizer, args=training_args,cfg=cfg,data_loader_args=(tokenizer, data_args,preprocess), **data_module) if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")): 
trainer.train(resume_from_checkpoint=True) else: trainer.train() trainer.save_state() model.config.use_cache = True if training_args.lora_enable: state_dict = get_peft_state_maybe_zero_3( model.named_parameters(), training_args.lora_bias ) non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3( model.named_parameters() ) if training_args.local_rank == 0 or training_args.local_rank == -1: model.config.save_pretrained(training_args.output_dir) model.save_pretrained(training_args.output_dir, state_dict=state_dict) torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin')) else: safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir) if __name__ == "__main__": train() ================================================ FILE: llava/train/train_joint_2st.py ================================================ # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright: # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright: # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from llava.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn replace_llama_attn_with_flash_attn() import os import copy from dataclasses import dataclass, field import json import logging import pathlib from typing import Dict, Optional, Sequence, List import torch import transformers from llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN from torch.utils.data import Dataset from llava.train.llava_trainer_joint_train import LLaVATrainer from llava import conversation as conversation_lib from llava.model import * from llava.mm_utils import tokenizer_image_token from PIL import Image local_rank = None def rank0_print(*args): if local_rank == 0: print(*args) @dataclass class ModelArguments: model_name_or_path: Optional[str] = field(default="facebook/opt-125m") whole_model: Optional[str] = field(default="facebook/opt-125m") version: Optional[str] = field(default="v0") freeze_backbone: bool = field(default=False) tune_mm_mlp_adapter: bool = field(default=False) vision_tower: Optional[str] = field(default=None) mm_vision_select_layer: Optional[int] = field(default=-1) # default to the last layer pretrain_mm_mlp_adapter: Optional[str] = field(default=None) mm_use_im_start_end: bool = field(default=False) load_model: bool = field(default=False) mm_use_im_patch_token: bool = field(default=True) mm_vision_select_feature: Optional[str] = field(default="patch") opt: Optional[str] = field(default="") config_file: Optional[str] = field(default="") @dataclass class DataArguments: data_path: str = field(default=None, metadata={"help": "Path to the training data."}) lazy_preprocess: bool = False is_multimodal: bool = False image_folder: Optional[str] = field(default=None) image_aspect_ratio: str = 'square' image_grid_pinpoints: Optional[str] = field(default=None) @dataclass class TrainingArguments(transformers.TrainingArguments): cache_dir: Optional[str] = field(default=None) optim: 
str = field(default="adamw_torch") remove_unused_columns: bool = field(default=False) freeze_mm_mlp_adapter: bool = field(default=False) mpt_attn_impl: Optional[str] = field(default="triton") model_max_length: int = field( default=512, metadata={ "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)." }, ) double_quant: bool = field( default=True, metadata={"help": "Compress the quantization statistics through double quantization."} ) quant_type: str = field( default="nf4", metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."} ) bits: int = field( default=16, metadata={"help": "How many bits to use."} ) lora_enable: bool = False new_tokens: bool = True lora_r: int = 64 lora_alpha: int = 16 lora_dropout: float = 0.05 lora_weight_path: str = "" lora_bias: str = "none" dbg: bool = False load_optimizer_states: bool = True load_lr_scheduler_states: bool = True def maybe_zero_3(param, ignore_status=False, name=None): from deepspeed import zero from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus if hasattr(param, "ds_id"): if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: if not ignore_status: logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}") with zero.GatheredParameters([param]): param = param.data.detach().cpu().clone() else: param = param.detach().cpu().clone() return param # Borrowed from peft.utils.get_peft_model_state_dict def get_peft_state_maybe_zero_3(named_params, bias): if bias == "none": to_return = {k: t for k, t in named_params if "lora_" in k} elif bias == "all": to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k} elif bias == "lora_only": to_return = {} maybe_lora_bias = {} lora_bias_names = set() for k, t in named_params: if "lora_" in k: to_return[k] = t bias_name = k.split("lora_")[0] + "bias" lora_bias_names.add(bias_name) elif "bias" in k: maybe_lora_bias[k] = t for k, t in maybe_lora_bias: if 
bias_name in lora_bias_names: to_return[bias_name] = t else: raise NotImplementedError to_return = {k: maybe_zero_3(v, name=k) for k, v in to_return.items()} return to_return def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True): to_return = {k: t for k, t in named_params if "lora_" not in k} if require_grad_only: to_return = {k: t for k, t in to_return.items() if t.requires_grad} to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()} return to_return def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match): to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)} to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()} return to_return def find_all_linear_names(model): cls = torch.nn.Linear lora_module_names = set() for name, module in model.named_modules(): if isinstance(module, cls): names = name.split('.') lora_module_names.add(names[0] if len(names) == 1 else names[-1]) if 'lm_head' in lora_module_names: # needed for 16-bit lora_module_names.remove('lm_head') return list(lora_module_names) def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str): """Collects the state dict and dump to disk.""" if trainer.deepspeed: torch.cuda.synchronize() trainer.save_model(output_dir) return state_dict = trainer.model.state_dict() if trainer.args.should_save: cpu_state_dict = { key: value.cpu() for key, value in state_dict.items() } del state_dict trainer._save(output_dir, state_dict=cpu_state_dict) # noqa def smart_tokenizer_and_embedding_resize( special_tokens_dict: Dict, tokenizer: transformers.PreTrainedTokenizer, model: transformers.PreTrainedModel, ): """Resize tokenizer and embedding. Note: This is the unoptimized version that may make your embedding size not be divisible by 64. 
""" num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) model.resize_token_embeddings(len(tokenizer)) if num_new_tokens > 0: input_embeddings = model.get_input_embeddings().weight.data output_embeddings = model.get_output_embeddings().weight.data input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( dim=0, keepdim=True) output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( dim=0, keepdim=True) input_embeddings[-num_new_tokens:] = input_embeddings_avg output_embeddings[-num_new_tokens:] = output_embeddings_avg def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict: """Tokenize a list of strings.""" tokenized_list = [ tokenizer( text, return_tensors="pt", padding="longest", max_length=tokenizer.model_max_length, truncation=True, ) for text in strings ] input_ids = labels = [ tokenized.input_ids[0] for tokenized in tokenized_list ] input_ids_lens = labels_lens = [ tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list ] return dict( input_ids=input_ids, labels=labels, input_ids_lens=input_ids_lens, labels_lens=labels_lens, ) def _mask_targets(target, tokenized_lens, speakers): # cur_idx = 0 cur_idx = tokenized_lens[0] tokenized_lens = tokenized_lens[1:] target[:cur_idx] = IGNORE_INDEX for tokenized_len, speaker in zip(tokenized_lens, speakers): if speaker == "human": target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX cur_idx += tokenized_len def _add_speaker_and_signal(header, source, get_conversation=True): """Add speaker and start/end signal on each round.""" BEGIN_SIGNAL = "### " END_SIGNAL = "\n" conversation = header for sentence in source: from_str = sentence["from"] if from_str.lower() == "human": from_str = conversation_lib.default_conversation.roles[0] elif from_str.lower() == "gpt": from_str = conversation_lib.default_conversation.roles[1] else: from_str = 'unknown' sentence["value"] = (BEGIN_SIGNAL + from_str + ": " + sentence["value"] + 
END_SIGNAL) if get_conversation: conversation += sentence["value"] conversation += BEGIN_SIGNAL return conversation def preprocess_multimodal( sources: Sequence[str], data_args: DataArguments ) -> Dict: is_multimodal = data_args.is_multimodal if not is_multimodal: return sources for source in sources: for sentence in source: if DEFAULT_IMAGE_TOKEN in sentence['value']: sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip() sentence['value'] = DEFAULT_IMAGE_TOKEN + '\n' + sentence['value'] sentence['value'] = sentence['value'].strip() if "mmtag" in conversation_lib.default_conversation.version: sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '' + DEFAULT_IMAGE_TOKEN + '') replace_token = DEFAULT_IMAGE_TOKEN if data_args.mm_use_im_start_end: replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token) return sources def preprocess_llama_2( sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False ) -> Dict: conv = conversation_lib.default_conversation.copy() roles = {"human": conv.roles[0], "gpt": conv.roles[1]} # Apply prompt templates conversations = [] for i, source in enumerate(sources): if roles[source[0]["from"]] != conv.roles[0]: # Skip the first one if it is not from human source = source[1:] conv.messages = [] for j, sentence in enumerate(source): role = roles[sentence["from"]] assert role == conv.roles[j % 2], f"{i}" conv.append_message(role, sentence["value"]) conversations.append(conv.get_prompt()) # Tokenize conversations if has_image: input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0) else: input_ids = tokenizer( conversations, return_tensors="pt", padding="longest", max_length=tokenizer.model_max_length, truncation=True, ).input_ids targets = input_ids.clone() assert conv.sep_style == 
conversation_lib.SeparatorStyle.LLAMA_2 # Mask targets sep = "[/INST] " for conversation, target in zip(conversations, targets): total_len = int(target.ne(tokenizer.pad_token_id).sum()) rounds = conversation.split(conv.sep2) cur_len = 1 target[:cur_len] = IGNORE_INDEX for i, rou in enumerate(rounds): if rou == "": break parts = rou.split(sep) if len(parts) != 2: break parts[0] += sep if has_image: round_len = len(tokenizer_image_token(rou, tokenizer)) instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2 else: round_len = len(tokenizer(rou).input_ids) instruction_len = len(tokenizer(parts[0]).input_ids) - 2 target[cur_len : cur_len + instruction_len] = IGNORE_INDEX cur_len += round_len target[cur_len:] = IGNORE_INDEX if cur_len < tokenizer.model_max_length: if cur_len != total_len: target[:] = IGNORE_INDEX print( f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." f" (ignored)" ) return dict( input_ids=input_ids, labels=targets, ) def preprocess_v1( sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False ) -> Dict: conv = conversation_lib.default_conversation.copy() roles = {"human": conv.roles[0], "gpt": conv.roles[1]} # Apply prompt templates conversations = [] for i, source in enumerate(sources): if roles[source[0]["from"]] != conv.roles[0]: # Skip the first one if it is not from human source = source[1:] conv.messages = [] for j, sentence in enumerate(source): role = roles[sentence["from"]] assert role == conv.roles[j % 2], f"{i}" conv.append_message(role, sentence["value"]) conversations.append(conv.get_prompt()) # Tokenize conversations if has_image: input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0) else: input_ids = tokenizer( conversations, return_tensors="pt", padding="longest", max_length=tokenizer.model_max_length, truncation=True, ).input_ids targets = input_ids.clone() assert conv.sep_style == conversation_lib.SeparatorStyle.TWO # 
Mask targets sep = conv.sep + conv.roles[1] + ": " for conversation, target in zip(conversations, targets): total_len = int(target.ne(tokenizer.pad_token_id).sum()) rounds = conversation.split(conv.sep2) cur_len = 1 target[:cur_len] = IGNORE_INDEX for i, rou in enumerate(rounds): if rou == "": break parts = rou.split(sep) if len(parts) != 2: break parts[0] += sep if has_image: round_len = len(tokenizer_image_token(rou, tokenizer)) instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2 else: round_len = len(tokenizer(rou).input_ids) instruction_len = len(tokenizer(parts[0]).input_ids) - 2 target[cur_len : cur_len + instruction_len] = IGNORE_INDEX cur_len += round_len target[cur_len:] = IGNORE_INDEX if cur_len < tokenizer.model_max_length: if cur_len != total_len: target[:] = IGNORE_INDEX print( f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." f" (ignored)" ) return dict( input_ids=input_ids, labels=targets, ) def preprocess_mpt( sources, tokenizer: transformers.PreTrainedTokenizer, ) -> Dict: conv = conversation_lib.default_conversation.copy() roles = {"human": conv.roles[0], "gpt": conv.roles[1]} # Apply prompt templates conversations = [] for i, source in enumerate(sources): if roles[source[0]["from"]] != conv.roles[0]: # Skip the first one if it is not from human source = source[1:] conv.messages = [] for j, sentence in enumerate(source): role = roles[sentence["from"]] assert role == conv.roles[j % 2], f"{i}" conv.append_message(role, sentence["value"]) conversations.append(conv.get_prompt()) # Tokenize conversations input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0) targets = input_ids.clone() assert conv.sep_style == conversation_lib.SeparatorStyle.MPT # Mask targets sep = conv.sep + conv.roles[1] for conversation, target in zip(conversations, targets): total_len = int(target.ne(tokenizer.pad_token_id).sum()) rounds = conversation.split(conv.sep) re_rounds = 
[conv.sep.join(rounds[:3])] # system + user + gpt for conv_idx in range(3, len(rounds), 2): re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx+2])) # user + gpt cur_len = 0 target[:cur_len] = IGNORE_INDEX for i, rou in enumerate(re_rounds): if rou == "": break parts = rou.split(sep) if len(parts) != 2: break parts[0] += sep round_len = len(tokenizer_image_token(rou, tokenizer)) + len(tokenizer_image_token(conv.sep, tokenizer)) instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) target[cur_len : cur_len + instruction_len] = IGNORE_INDEX cur_len += round_len target[cur_len:] = IGNORE_INDEX if cur_len < tokenizer.model_max_length: if cur_len != total_len: target[:] = IGNORE_INDEX print( f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." f" (ignored)" ) return dict( input_ids=input_ids, labels=targets, ) def preprocess_plain( sources: Sequence[str], tokenizer: transformers.PreTrainedTokenizer, ) -> Dict: # add end signal and concatenate together conversations = [] for source in sources: assert len(source) == 2 assert DEFAULT_IMAGE_TOKEN in source[0]['value'] source[0]['value'] = DEFAULT_IMAGE_TOKEN conversation = source[0]['value'] + source[1]['value'] + conversation_lib.default_conversation.sep conversations.append(conversation) # tokenize conversations input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations] targets = copy.deepcopy(input_ids) for target, source in zip(targets, sources): tokenized_len = len(tokenizer_image_token(source[0]['value'], tokenizer)) target[:tokenized_len] = IGNORE_INDEX return dict(input_ids=input_ids, labels=targets) def preprocess( sources: Sequence[str], tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False ) -> Dict: """ Given a list of sources, each is a conversation list. This transform: 1. Add signal '### ' at the beginning each sentence, with end signal '\n'; 2. Concatenate conversations together; 3. Tokenize the concatenated conversation; 4. 
Make a deepcopy as the target. Mask human words with IGNORE_INDEX. """ if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN: return preprocess_plain(sources, tokenizer) if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.LLAMA_2: return preprocess_llama_2(sources, tokenizer, has_image=has_image) if conversation_lib.default_conversation.version.startswith("v1"): return preprocess_v1(sources, tokenizer, has_image=has_image) if conversation_lib.default_conversation.version == "mpt": return preprocess_mpt(sources, tokenizer) # add end signal and concatenate together conversations = [] for source in sources: header = f"{conversation_lib.default_conversation.system}\n\n" conversation = _add_speaker_and_signal(header, source) conversations.append(conversation) # tokenize conversations def get_tokenize_len(prompts): return [len(tokenizer_image_token(prompt, tokenizer)) for prompt in prompts] if has_image: input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations] else: conversations_tokenized = _tokenize_fn(conversations, tokenizer) input_ids = conversations_tokenized["input_ids"] targets = copy.deepcopy(input_ids) for target, source in zip(targets, sources): if has_image: tokenized_lens = get_tokenize_len([header] + [s["value"] for s in source]) else: tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source], tokenizer)["input_ids_lens"] speakers = [sentence["from"] for sentence in source] _mask_targets(target, tokenized_lens, speakers) return dict(input_ids=input_ids, labels=targets) class LazySupervisedDataset(Dataset): """Dataset for supervised fine-tuning.""" def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer, data_args: DataArguments): super(LazySupervisedDataset, self).__init__() list_data_dict = json.load(open(data_path, "r")) rank0_print("Formatting inputs...Skip in lazy mode") self.tokenizer = tokenizer 
self.list_data_dict = list_data_dict self.data_args = data_args def __len__(self): return len(self.list_data_dict) def __getitem__(self, i) -> Dict[str, torch.Tensor]: try: sources = self.list_data_dict[i] if isinstance(i, int): sources = [sources] assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME if 'image' in sources[0]: image_file = self.list_data_dict[i]['image'] image_folder = self.data_args.image_folder processor = self.data_args.image_processor image = Image.open(os.path.join(image_folder, image_file)).convert('RGB') if self.data_args.image_aspect_ratio == 'pad': def expand2square(pil_img, background_color): width, height = pil_img.size if width == height: return pil_img elif width > height: result = Image.new(pil_img.mode, (width, width), background_color) result.paste(pil_img, (0, (width - height) // 2)) return result else: result = Image.new(pil_img.mode, (height, height), background_color) result.paste(pil_img, ((height - width) // 2, 0)) return result image = expand2square(image, tuple(int(x*255) for x in processor.image_mean)) image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0] else: image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0] sources = preprocess_multimodal( copy.deepcopy([e["conversations"] for e in sources]), self.data_args) else: sources = copy.deepcopy([e["conversations"] for e in sources]) data_dict = preprocess( sources, self.tokenizer, has_image=('image' in self.list_data_dict[i])) if isinstance(i, int): data_dict = dict(input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0]) # image exist in the data if 'image' in self.list_data_dict[i]: data_dict['image'] = image elif self.data_args.is_multimodal: # image does not exist in the data, but the model is multimodal crop_size = self.data_args.image_processor.crop_size data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width']) return data_dict except Exception: print(self.list_data_dict[i], 
"failed") return self.__getitem__(i + 1) @dataclass class DataCollatorForSupervisedDataset(object): """Collate examples for supervised fine-tuning.""" tokenizer: transformers.PreTrainedTokenizer def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels")) input_ids = torch.nn.utils.rnn.pad_sequence( input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id) labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX) input_ids = input_ids[:, :self.tokenizer.model_max_length] labels = labels[:, :self.tokenizer.model_max_length] batch = dict( input_ids=input_ids, labels=labels, attention_mask=input_ids.ne(self.tokenizer.pad_token_id), ) if 'image' in instances[0]: images = [instance['image'] for instance in instances] if all(x is not None and x.shape == images[0].shape for x in images): batch['images'] = torch.stack(images) else: batch['images'] = images return batch @dataclass class DataCollatorForSupervisedDatasetEmpty(object): """Collate examples for supervised fine-tuning.""" tokenizer: transformers.PreTrainedTokenizer def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: return instances # input_ids, labels = tuple([instance[key] for instance in instances] # for key in ("input_ids", "labels")) # input_ids = torch.nn.utils.rnn.pad_sequence( # input_ids, # batch_first=True, # padding_value=self.tokenizer.pad_token_id) # labels = torch.nn.utils.rnn.pad_sequence(labels, # batch_first=True, # padding_value=IGNORE_INDEX) # input_ids = input_ids[:, :self.tokenizer.model_max_length] # labels = labels[:, :self.tokenizer.model_max_length] # batch = dict( # input_ids=input_ids, # labels=labels, # attention_mask=input_ids.ne(self.tokenizer.pad_token_id), # ) # # if 'image' in instances[0]: # images = [instance['image'] for instance in instances] # if all(x is not None and x.shape == 
images[0].shape for x in images): # batch['images'] = torch.stack(images) # else: # batch['images'] = images # # return batch def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict: """Make dataset and collator for supervised fine-tuning.""" train_dataset = LazySupervisedDataset(tokenizer=tokenizer, data_path=data_args.data_path, data_args=data_args) data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) # data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator) from detectron2.config import LazyConfig, instantiate def setup(args): """ Create configs and perform basic setups. """ cfg = LazyConfig.load(args.config_file) # import pdb;pdb.set_trace() opt=args.opt.split(',') cfg = LazyConfig.apply_overrides(cfg, opt) # cfg.freeze() # default_setup(cfg, args) # setup_logger(output=cfg.OUTPUT_DIR, distributed_rank=comm.get_rank(), name="maskdino") return cfg def train(): global local_rank parser = transformers.HfArgumentParser( (ModelArguments, DataArguments, TrainingArguments)) model_args, data_args, training_args = parser.parse_args_into_dataclasses() local_rank = training_args.local_rank compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)) cfg=setup(model_args) bnb_model_from_pretrained_args = {} if training_args.bits in [4, 8]: from transformers import BitsAndBytesConfig bnb_model_from_pretrained_args.update(dict( device_map={"": training_args.device}, load_in_4bit=training_args.bits == 4, load_in_8bit=training_args.bits == 8, quantization_config=BitsAndBytesConfig( load_in_4bit=training_args.bits == 4, load_in_8bit=training_args.bits == 8, llm_int8_threshold=6.0, llm_int8_has_fp16_weight=False, bnb_4bit_compute_dtype=compute_dtype, bnb_4bit_use_double_quant=training_args.double_quant, bnb_4bit_quant_type=training_args.quant_type # {'fp4', 'nf4'} ) 
)) if model_args.vision_tower is not None: if 'mpt' in model_args.model_name_or_path: config = transformers.AutoConfig.from_pretrained(model_args.model_name_or_path,cache_dir="/comp_robot/zhanghao/.cache/hugging_face/", trust_remote_code=True) config.attn_config['attn_impl'] = training_args.mpt_attn_impl model = LlavaMPTForCausalLM.from_pretrained( model_args.model_name_or_path, config=config, cache_dir="/comp_robot/zhanghao/.cache/hugging_face/", **bnb_model_from_pretrained_args ) else: model = LlavaLlamaForCausalLM_joint_2st.from_pretrained( model_args.model_name_or_path, cache_dir="/comp_robot/zhanghao/.cache/hugging_face/", **bnb_model_from_pretrained_args ) else: model = transformers.LlamaForCausalLM.from_pretrained( model_args.model_name_or_path, cache_dir="/comp_robot/zhanghao/.cache/hugging_face/", **bnb_model_from_pretrained_args ) model.config.use_cache = False if model_args.freeze_backbone: model.model.requires_grad_(False) if training_args.bits in [4, 8]: from peft import prepare_model_for_kbit_training model.config.torch_dtype=(torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)) model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing) if training_args.gradient_checkpointing: if hasattr(model, "enable_input_require_grads"): model.enable_input_require_grads() else: def make_inputs_require_grad(module, input, output): output.requires_grad_(True) model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) if training_args.lora_enable: from peft import LoraConfig, get_peft_model lora_config = LoraConfig( r=training_args.lora_r, lora_alpha=training_args.lora_alpha, target_modules=find_all_linear_names(model), lora_dropout=training_args.lora_dropout, bias=training_args.lora_bias, task_type="CAUSAL_LM", ) if training_args.bits == 16: if training_args.bf16: model.to(torch.bfloat16) if training_args.fp16: model.to(torch.float16) 
rank0_print("Adding LoRA adapters...") model = get_peft_model(model, lora_config) if 'mpt' in model_args.model_name_or_path: tokenizer = transformers.AutoTokenizer.from_pretrained( model_args.model_name_or_path,cache_dir="/comp_robot/zhanghao/.cache/hugging_face/", model_max_length=training_args.model_max_length, padding_side="right" ) else: tokenizer = transformers.AutoTokenizer.from_pretrained( model_args.model_name_or_path, cache_dir="/comp_robot/zhanghao/.cache/hugging_face/", model_max_length=training_args.model_max_length, padding_side="right", use_fast=False, ) if model_args.version == "v0": if tokenizer.pad_token is None: smart_tokenizer_and_embedding_resize( special_tokens_dict=dict(pad_token="[PAD]"), tokenizer=tokenizer, model=model, ) elif model_args.version == "v0.5": tokenizer.pad_token = tokenizer.unk_token else: tokenizer.pad_token = tokenizer.unk_token if model_args.version in conversation_lib.conv_templates: conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version] else: conversation_lib.default_conversation = conversation_lib.conv_templates["vicuna_v1"] if model_args.vision_tower is not None: model.get_model().initialize_vision_modules( model_args=model_args, fsdp=training_args.fsdp ) vision_tower = model.get_vision_tower() vision_tower.to(dtype=torch.float16, device=training_args.device) data_args.image_processor = vision_tower.image_processor data_args.is_multimodal = True model.config.image_aspect_ratio = data_args.image_aspect_ratio model.config.image_grid_pinpoints = data_args.image_grid_pinpoints model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter if model_args.tune_mm_mlp_adapter or training_args.dbg: model.requires_grad_(False) for p in model.get_model().mm_projector.parameters(): p.requires_grad = True model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter if training_args.freeze_mm_mlp_adapter: for p in 
model.get_model().mm_projector.parameters(): p.requires_grad = False if training_args.bits in [4, 8]: model.get_model().mm_projector.to(dtype=compute_dtype, device=training_args.device) model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = model_args.mm_use_im_start_end training_args.use_im_start_end = model_args.mm_use_im_start_end model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer) cfg.MODEL.DIM_PROJ=model.get_model().config.hidden_size model.initialize_seg_modules( cfg=cfg ) if training_args.bits in [4, 8]: from peft.tuners.lora import LoraLayer for name, module in model.named_modules(): if isinstance(module, LoraLayer): if training_args.bf16: module = module.to(torch.bfloat16) if 'norm' in name: module = module.to(torch.float32) if 'lm_head' in name or 'embed_tokens' in name: if hasattr(module, 'weight'): if training_args.bf16 and module.weight.dtype == torch.float32: module = module.to(torch.bfloat16) data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args) print(model) if model_args.load_model: loaded_dict = dict() if "stage1" in model_args.whole_model: old_emb_in=model.get_input_embeddings().weight.clone() old_emb_out=model.get_output_embeddings().weight.clone() for model_file in os.listdir(model_args.whole_model): if model_file.endswith('.bin') and model_file.startswith('pytorch_model'): loaded_dict.update(torch.load(os.path.join(model_args.whole_model, model_file), map_location='cpu')) model.load_state_dict(loaded_dict, strict=False) if "stage1" in model_args.whole_model: with torch.no_grad(): model.get_input_embeddings().weight[:-3]=old_emb_in[:-3] model.get_output_embeddings().weight[:-3]=old_emb_out[:-3] print(loaded_dict.keys()) trainer = LLaVATrainer(model=model, tokenizer=tokenizer, args=training_args,cfg=cfg,data_loader_args=(tokenizer, data_args,preprocess), **data_module) if 
list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")): trainer.train(resume_from_checkpoint=True) else: trainer.train() trainer.save_state() model.config.use_cache = True if training_args.lora_enable: state_dict = get_peft_state_maybe_zero_3( model.named_parameters(), training_args.lora_bias ) non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3( model.named_parameters() ) if training_args.local_rank == 0 or training_args.local_rank == -1: model.config.save_pretrained(training_args.output_dir) model.save_pretrained(training_args.output_dir, state_dict=state_dict) torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin')) else: safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir) if __name__ == "__main__": train() ================================================ FILE: llava/train/train_joint_2st_interactive_refcoco_coco_instruction.py ================================================ # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright: # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright: # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
@dataclass
class ModelArguments:
    """Options that select the model to build and the weights to load.

    Parsed from the command line by ``transformers.HfArgumentParser`` in
    ``train()``. Field names, order, types and defaults define the CLI.
    """

    # Name or path passed to ``from_pretrained`` for both model and tokenizer.
    model_name_or_path: Optional[str] = "facebook/opt-125m"
    # Directory scanned for ``pytorch_model*.bin`` files when ``load_model`` is set.
    whole_model: Optional[str] = "facebook/opt-125m"
    # Conversation-template key; also decides how the pad token is configured.
    version: Optional[str] = "v0"
    # When true, ``model.model`` has its gradients disabled.
    freeze_backbone: bool = False
    # When true, everything is frozen except the ``mm_projector`` parameters.
    tune_mm_mlp_adapter: bool = False
    # When true, the model and its segmentation modules are frozen before the
    # interactive modules are initialised.
    tune_prompt_adapter: bool = False
    # ``None`` selects a plain ``LlamaForCausalLM`` without vision modules.
    vision_tower: Optional[str] = None
    mm_vision_select_layer: Optional[int] = -1  # default to the last layer
    pretrain_mm_mlp_adapter: Optional[str] = None
    # Wrap the image token in the image start/end tokens during preprocessing.
    mm_use_im_start_end: bool = False
    # Load the extra state dict found under ``whole_model`` after model setup.
    load_model: bool = False
    # Copied verbatim onto ``model.config.mm_use_im_patch_token``.
    mm_use_im_patch_token: bool = True
    mm_vision_select_feature: Optional[str] = "patch"
    # Two ';'-separated groups of ','-separated config overrides (see ``setup``).
    opt: Optional[str] = ""
    # Config files loaded with ``LazyConfig.load`` in ``setup``.
    config_file_gd: Optional[str] = ""
    config_file_it: Optional[str] = ""
def maybe_zero_3(param, ignore_status=False, name=None):
    """Return a detached CPU clone of ``param``.

    Parameters that carry a ``ds_id`` attribute are gathered through
    ``deepspeed.zero.GatheredParameters`` before being copied; any other
    tensor is simply detached, moved to CPU and cloned.

    Args:
        param: Tensor or parameter to copy.
        ignore_status: When false, log a warning if the parameter's
            ``ds_status`` is ``ZeroParamStatus.NOT_AVAILABLE``.
        name: Parameter name, used only in the warning message.

    Returns:
        A CPU tensor that shares no storage with ``param``.
    """
    if not hasattr(param, "ds_id"):
        return param.detach().cpu().clone()

    # Imported lazily so that runs without DeepSpeed-managed parameters do not
    # need the package installed.
    from deepspeed import zero
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

    if param.ds_status == ZeroParamStatus.NOT_AVAILABLE and not ignore_status:
        logging.warning(
            f"{name}: param.ds_status == ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}"
        )
    with zero.GatheredParameters([param]):
        return param.data.detach().cpu().clone()


# Borrowed from peft.utils.get_peft_model_state_dict
def get_peft_state_maybe_zero_3(named_params, bias):
    """Collect the LoRA entries of ``named_params`` as CPU tensors.

    Args:
        named_params: Iterable of ``(name, tensor)`` pairs. It is consumed
            once, so a generator such as ``model.named_parameters()`` works.
        bias: Which bias tensors to keep alongside the ``lora_`` tensors:
            ``"none"`` keeps none, ``"all"`` keeps every tensor whose name
            contains ``"bias"``, and ``"lora_only"`` keeps only the biases
            whose name equals ``<prefix>bias`` for some ``<prefix>lora_...``
            tensor.

    Returns:
        Dict mapping parameter names to CPU copies made by ``maybe_zero_3``.

    Raises:
        NotImplementedError: If ``bias`` is not one of the three modes above.
    """
    if bias == "none":
        to_return = {k: t for k, t in named_params if "lora_" in k}
    elif bias == "all":
        to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
    elif bias == "lora_only":
        to_return = {}
        maybe_lora_bias = {}
        lora_bias_names = set()
        for k, t in named_params:
            if "lora_" in k:
                to_return[k] = t
                lora_bias_names.add(k.split("lora_")[0] + "bias")
            elif "bias" in k:
                maybe_lora_bias[k] = t
        # Iterate over (name, tensor) pairs and test each candidate's own name;
        # iterating the dict directly would try to unpack the key strings.
        for k, t in maybe_lora_bias.items():
            if k in lora_bias_names:
                to_return[k] = t
    else:
        raise NotImplementedError
    return {k: maybe_zero_3(v, name=k) for k, v in to_return.items()}
Note: This is the unoptimized version that may make your embedding size not be divisible by 64. """ num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) model.resize_token_embeddings(len(tokenizer)) if num_new_tokens > 0: input_embeddings = model.get_input_embeddings().weight.data output_embeddings = model.get_output_embeddings().weight.data input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( dim=0, keepdim=True) output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( dim=0, keepdim=True) input_embeddings[-num_new_tokens:] = input_embeddings_avg output_embeddings[-num_new_tokens:] = output_embeddings_avg def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict: """Tokenize a list of strings.""" tokenized_list = [ tokenizer( text, return_tensors="pt", padding="longest", max_length=tokenizer.model_max_length, truncation=True, ) for text in strings ] input_ids = labels = [ tokenized.input_ids[0] for tokenized in tokenized_list ] input_ids_lens = labels_lens = [ tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list ] return dict( input_ids=input_ids, labels=labels, input_ids_lens=input_ids_lens, labels_lens=labels_lens, ) def _mask_targets(target, tokenized_lens, speakers): # cur_idx = 0 cur_idx = tokenized_lens[0] tokenized_lens = tokenized_lens[1:] target[:cur_idx] = IGNORE_INDEX for tokenized_len, speaker in zip(tokenized_lens, speakers): if speaker == "human": target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX cur_idx += tokenized_len def _add_speaker_and_signal(header, source, get_conversation=True): """Add speaker and start/end signal on each round.""" BEGIN_SIGNAL = "### " END_SIGNAL = "\n" conversation = header for sentence in source: from_str = sentence["from"] if from_str.lower() == "human": from_str = conversation_lib.default_conversation.roles[0] elif from_str.lower() == "gpt": from_str = conversation_lib.default_conversation.roles[1] else: 
def preprocess_multimodal(
    sources: Sequence[str],
    data_args: DataArguments
) -> Dict:
    """Normalise image placeholders in every conversation turn, in place.

    When ``data_args.is_multimodal`` is false the input is returned untouched.
    Otherwise, for each turn containing ``DEFAULT_IMAGE_TOKEN`` the token is
    moved to the front of the text on its own line; afterwards every turn has
    the token replaced by its start/end-wrapped form when
    ``data_args.mm_use_im_start_end`` is set.

    Returns:
        The same ``sources`` object that was passed in.
    """
    if not data_args.is_multimodal:
        return sources

    for conversation in sources:
        for turn in conversation:
            text = turn['value']
            if DEFAULT_IMAGE_TOKEN in text:
                text = text.replace(DEFAULT_IMAGE_TOKEN, '').strip()
                text = (DEFAULT_IMAGE_TOKEN + '\n' + text).strip()
                if "mmtag" in conversation_lib.default_conversation.version:
                    # NOTE(review): both wrapper literals are empty here, which
                    # makes this replace a no-op - confirm against the intended
                    # "mmtag" markup.
                    text = text.replace(DEFAULT_IMAGE_TOKEN, '' + DEFAULT_IMAGE_TOKEN + '')

            if data_args.mm_use_im_start_end:
                image_markup = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
            else:
                image_markup = DEFAULT_IMAGE_TOKEN
            turn["value"] = text.replace(DEFAULT_IMAGE_TOKEN, image_markup)

    return sources
def preprocess_v1(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False
) -> Dict:
    """Build prompts and masked labels for ``SeparatorStyle.TWO`` templates.

    Each source conversation is rendered through a copy of
    ``conversation_lib.default_conversation``, tokenized, and cloned into
    labels. In the labels, the instruction part of every round and everything
    past the last processed round are set to ``IGNORE_INDEX``.

    Args:
        sources: Conversations; each is a list of dicts read through their
            ``"from"`` (``"human"``/``"gpt"``) and ``"value"`` keys.
        tokenizer: Tokenizer used directly for text-only prompts and passed to
            the image-aware tokenizer helpers when ``has_image`` is true.
        has_image: Selects the image-aware tokenization path.

    Returns:
        ``dict(input_ids=..., labels=...)`` with one row per conversation.
    """
    conv = conversation_lib.default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    # Apply prompt templates
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            # Turns must strictly alternate between the two roles.
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        # NOTE(review): the search string of this replace is empty as written,
        # which inserts "% " around every character - presumably a marker
        # string was lost here; confirm against the intended source.
        conversations.append(conv.get_prompt().replace("", "% "))

    # Tokenize conversations
    if has_image:
        # torch.stack requires every prompt to tokenize to the same length.
        input_ids = torch.stack([tokenizer_image_token_inter(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
    else:
        input_ids = tokenizer(
            conversations,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ).input_ids

    targets = input_ids.clone()

    assert conv.sep_style == conversation_lib.SeparatorStyle.TWO

    # Mask targets
    sep = conv.sep + conv.roles[1] + ": "
    for conversation, target in zip(conversations, targets):
        # Number of non-padding tokens; used below as a consistency check.
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        rounds = conversation.split(conv.sep2)
        # The first token is always masked (presumably BOS - TODO confirm).
        cur_len = 1
        target[:cur_len] = IGNORE_INDEX
        for i, rou in enumerate(rounds):
            if rou == "":
                break

            # A well-formed round splits into exactly [instruction, answer].
            parts = rou.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep

            # NOTE(review): lengths are measured with tokenizer_image_token
            # while input_ids above come from tokenizer_image_token_inter;
            # confirm both produce the same token counts.
            # The "- 2" offset is presumably a tokenizer-specific correction
            # for special/boundary tokens - TODO confirm.
            if has_image:
                round_len = len(tokenizer_image_token(rou, tokenizer))
                instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
            else:
                round_len = len(tokenizer(rou).input_ids)
                instruction_len = len(tokenizer(parts[0]).input_ids) - 2

            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX

            cur_len += round_len
        # Everything after the last processed round is ignored.
        target[cur_len:] = IGNORE_INDEX

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                # Round lengths did not add up to the tokenized length, so the
                # whole sample is excluded from the loss.
                target[:] = IGNORE_INDEX
                print(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" (ignored)"
                )

    return dict(
        input_ids=input_ids,
        labels=targets,
    )
def preprocess_plain(
    sources: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """Build prompts and labels for the plain two-message format.

    Every source must hold exactly two messages, the first of which contains
    ``DEFAULT_IMAGE_TOKEN``. The first message is overwritten (in place) with
    just the image token; the prompt is that token, the second message and the
    default conversation separator. In the labels, the tokens of the first
    message are set to ``IGNORE_INDEX``.

    Returns:
        ``{"input_ids": [...], "labels": [...]}`` with one tensor per source.
    """
    # Assemble one prompt string per source.
    prompts = []
    for pair in sources:
        assert len(pair) == 2
        question, answer = pair[0], pair[1]
        assert DEFAULT_IMAGE_TOKEN in question['value']
        question['value'] = DEFAULT_IMAGE_TOKEN
        prompts.append(
            question['value'] + answer['value'] + conversation_lib.default_conversation.sep
        )

    # Tokenize, then mask the image-token prefix in an independent copy.
    input_ids = [
        tokenizer_image_token(text, tokenizer, return_tensors='pt') for text in prompts
    ]
    labels = copy.deepcopy(input_ids)
    for label, pair in zip(labels, sources):
        prefix_len = len(tokenizer_image_token(pair[0]['value'], tokenizer))
        label[:prefix_len] = IGNORE_INDEX

    return {"input_ids": input_ids, "labels": labels}
class LazySupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning.

    The JSON annotation file is read once at construction; images are opened
    and conversations tokenized only when a sample is requested.
    """

    def __init__(self, data_path: str,
                 tokenizer: transformers.PreTrainedTokenizer,
                 data_args: DataArguments):
        """Read the sample list from ``data_path`` and keep the helpers.

        Args:
            data_path: Path of a JSON file holding a list of sample dicts.
            tokenizer: Tokenizer forwarded to ``preprocess``.
            data_args: Data options (image folder, processor, aspect ratio...).
        """
        super(LazySupervisedDataset, self).__init__()
        # Use a context manager so the file handle is closed deterministically.
        with open(data_path, "r") as f:
            list_data_dict = json.load(f)

        rank0_print("Formatting inputs...Skip in lazy mode")
        self.tokenizer = tokenizer
        self.list_data_dict = list_data_dict
        self.data_args = data_args

    def __len__(self):
        return len(self.list_data_dict)

    @staticmethod
    def _expand2square(pil_img, background_color):
        """Pad ``pil_img`` to a square canvas, centring it on the short axis."""
        width, height = pil_img.size
        if width == height:
            return pil_img
        if width > height:
            result = Image.new(pil_img.mode, (width, width), background_color)
            result.paste(pil_img, (0, (width - height) // 2))
            return result
        result = Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result

    def _load_sample(self, i) -> Dict[str, torch.Tensor]:
        """Build the training dict for sample ``i``; errors propagate."""
        sources = self.list_data_dict[i]
        if isinstance(i, int):
            sources = [sources]
        assert len(sources) == 1, "Don't know why it is wrapped to a list"  # FIXME
        if 'image' in sources[0]:
            image_file = self.list_data_dict[i]['image']
            image_folder = self.data_args.image_folder
            processor = self.data_args.image_processor
            image = Image.open(os.path.join(image_folder, image_file)).convert('RGB')
            if self.data_args.image_aspect_ratio == 'pad':
                # Pad with the processor's mean colour before preprocessing.
                image = self._expand2square(
                    image, tuple(int(x*255) for x in processor.image_mean))
            image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
            sources = preprocess_multimodal(
                copy.deepcopy([e["conversations"] for e in sources]),
                self.data_args)
        else:
            sources = copy.deepcopy([e["conversations"] for e in sources])
        data_dict = preprocess(
            sources,
            self.tokenizer,
            has_image=('image' in self.list_data_dict[i]))
        if isinstance(i, int):
            data_dict = dict(input_ids=data_dict["input_ids"][0],
                             labels=data_dict["labels"][0])

        # image exist in the data
        if 'image' in self.list_data_dict[i]:
            data_dict['image'] = image
        elif self.data_args.is_multimodal:
            # image does not exist in the data, but the model is multimodal
            crop_size = self.data_args.image_processor.crop_size
            data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width'])
        return data_dict

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        """Return sample ``i``, falling back to later samples if it fails.

        A failing sample is reported on stdout and the next index is tried,
        wrapping around at the end of the data so that a broken last sample
        no longer runs off the list. Each sample is tried at most once.

        Raises:
            IndexError: If an integer ``i`` is outside the dataset.
            RuntimeError: If every sample fails to load.
        """
        if not isinstance(i, int):
            # There is no meaningful "next" index for non-integer keys.
            return self._load_sample(i)

        total = len(self.list_data_dict)
        if not -total <= i < total:
            raise IndexError(f"index {i} is out of range for {total} samples")

        last_error = None
        for offset in range(total):
            idx = (i + offset) % total
            try:
                return self._load_sample(idx)
            except Exception as err:  # best-effort: skip unreadable samples
                last_error = err
                print(self.list_data_dict[idx], "failed")
        raise RuntimeError("no sample in the dataset could be loaded") from last_error
def _split_override_groups(opt):
    """Split ``"a=1,b=2;c=3"`` into ``(["a=1", "b=2"], ["c=3"])``.

    A missing group and empty entries yield empty override lists.
    """
    groups = (opt or "").split(';')
    if len(groups) > 2:
        raise ValueError(
            f"expected at most two ';'-separated override groups, got {len(groups)}: {opt!r}"
        )
    groups.extend([""] * (2 - len(groups)))
    first, second = ([item for item in group.split(',') if item] for group in groups)
    return first, second


def setup(args):
    """
    Create configs and perform basic setups.

    Loads ``args.config_file_gd`` and ``args.config_file_it`` with
    ``LazyConfig.load`` and applies the overrides from ``args.opt``: the part
    before ``';'`` goes to the first config, the part after it to the second,
    each part being a ``','``-separated list. An empty ``opt`` (the default)
    or a missing second part applies no overrides to that config.

    Returns:
        Tuple ``(cfg1, cfg2)`` of the two configs after overrides.

    Raises:
        ValueError: If ``args.opt`` contains more than one ``';'``.
    """
    cfg1 = LazyConfig.load(args.config_file_gd)
    cfg2 = LazyConfig.load(args.config_file_it)
    opt1, opt2 = _split_override_groups(args.opt)
    cfg1 = LazyConfig.apply_overrides(cfg1, opt1)
    cfg2 = LazyConfig.apply_overrides(cfg2, opt2)
    return cfg1, cfg2
load_in_8bit=training_args.bits == 8, quantization_config=BitsAndBytesConfig( load_in_4bit=training_args.bits == 4, load_in_8bit=training_args.bits == 8, llm_int8_threshold=6.0, llm_int8_has_fp16_weight=False, bnb_4bit_compute_dtype=compute_dtype, bnb_4bit_use_double_quant=training_args.double_quant, bnb_4bit_quant_type=training_args.quant_type # {'fp4', 'nf4'} ) )) if model_args.vision_tower is not None: if 'mpt' in model_args.model_name_or_path: config = transformers.AutoConfig.from_pretrained(model_args.model_name_or_path,cache_dir="/comp_robot/zhanghao/.cache/hugging_face/", trust_remote_code=True) config.attn_config['attn_impl'] = training_args.mpt_attn_impl model = LlavaMPTForCausalLM.from_pretrained( model_args.model_name_or_path, config=config, cache_dir="/comp_robot/zhanghao/.cache/hugging_face/", **bnb_model_from_pretrained_args ) else: model = LlavaLlamaForCausalLM_joint_2st_it_only_ref_instr.from_pretrained( model_args.model_name_or_path, cache_dir="/comp_robot/zhanghao/.cache/hugging_face/", **bnb_model_from_pretrained_args ) else: model = transformers.LlamaForCausalLM.from_pretrained( model_args.model_name_or_path, cache_dir="/comp_robot/zhanghao/.cache/hugging_face/", **bnb_model_from_pretrained_args ) model.config.use_cache = False if model_args.freeze_backbone: model.model.requires_grad_(False) if training_args.bits in [4, 8]: from peft import prepare_model_for_kbit_training model.config.torch_dtype=(torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)) model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing) if training_args.gradient_checkpointing: if hasattr(model, "enable_input_require_grads"): model.enable_input_require_grads() else: def make_inputs_require_grad(module, input, output): output.requires_grad_(True) model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) if training_args.lora_enable: from peft import 
LoraConfig, get_peft_model lora_config = LoraConfig( r=training_args.lora_r, lora_alpha=training_args.lora_alpha, target_modules=find_all_linear_names(model), lora_dropout=training_args.lora_dropout, bias=training_args.lora_bias, task_type="CAUSAL_LM", ) if training_args.bits == 16: if training_args.bf16: model.to(torch.bfloat16) if training_args.fp16: model.to(torch.float16) rank0_print("Adding LoRA adapters...") model = get_peft_model(model, lora_config) if 'mpt' in model_args.model_name_or_path: tokenizer = transformers.AutoTokenizer.from_pretrained( model_args.model_name_or_path,cache_dir="/comp_robot/zhanghao/.cache/hugging_face/", model_max_length=training_args.model_max_length, padding_side="right" ) else: tokenizer = transformers.AutoTokenizer.from_pretrained( model_args.model_name_or_path, cache_dir="/comp_robot/zhanghao/.cache/hugging_face/", model_max_length=training_args.model_max_length, padding_side="right", use_fast=False, ) if model_args.version == "v0": if tokenizer.pad_token is None: smart_tokenizer_and_embedding_resize( special_tokens_dict=dict(pad_token="[PAD]"), tokenizer=tokenizer, model=model, ) elif model_args.version == "v0.5": tokenizer.pad_token = tokenizer.unk_token else: tokenizer.pad_token = tokenizer.unk_token if model_args.version in conversation_lib.conv_templates: conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version] else: conversation_lib.default_conversation = conversation_lib.conv_templates["vicuna_v1"] if model_args.vision_tower is not None: model.get_model().initialize_vision_modules( model_args=model_args, fsdp=training_args.fsdp ) vision_tower = model.get_vision_tower() vision_tower.to(dtype=torch.float16, device=training_args.device) data_args.image_processor = vision_tower.image_processor data_args.is_multimodal = True model.config.image_aspect_ratio = data_args.image_aspect_ratio model.config.image_grid_pinpoints = data_args.image_grid_pinpoints model.config.tune_mm_mlp_adapter = 
training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter model.config.tune_prompt_adapter = training_args.tune_prompt_adapter = model_args.tune_prompt_adapter if model_args.tune_mm_mlp_adapter or training_args.dbg: model.requires_grad_(False) for p in model.get_model().mm_projector.parameters(): p.requires_grad = True model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter if training_args.freeze_mm_mlp_adapter: for p in model.get_model().mm_projector.parameters(): p.requires_grad = False if training_args.bits in [4, 8]: model.get_model().mm_projector.to(dtype=compute_dtype, device=training_args.device) model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = model_args.mm_use_im_start_end training_args.use_im_start_end = model_args.mm_use_im_start_end model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer) cfg.MODEL.DIM_PROJ=model.get_model().config.hidden_size model.initialize_seg_modules( cfg=cfg ) if model_args.tune_prompt_adapter: model.requires_grad_(False) model.freeze_seg_modules() model.initialize_interactive_modules(cfg=cfg2,model_args=model_args) if training_args.bits in [4, 8]: from peft.tuners.lora import LoraLayer for name, module in model.named_modules(): if isinstance(module, LoraLayer): if training_args.bf16: module = module.to(torch.bfloat16) if 'norm' in name: module = module.to(torch.float32) if 'lm_head' in name or 'embed_tokens' in name: if hasattr(module, 'weight'): if training_args.bf16 and module.weight.dtype == torch.float32: module = module.to(torch.bfloat16) data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args) print(model) if model_args.load_model: loaded_dict = dict() if "stage1" in model_args.whole_model: old_emb_in=model.get_input_embeddings().weight.clone() old_emb_out=model.get_output_embeddings().weight.clone() for model_file in os.listdir(model_args.whole_model): if 
model_file.endswith('.bin') and model_file.startswith('pytorch_model'): loaded_dict.update(torch.load(os.path.join(model_args.whole_model, model_file), map_location='cpu')) model.load_state_dict(loaded_dict, strict=False) if "stage1" in model_args.whole_model: with torch.no_grad(): model.get_input_embeddings().weight[:-3]=old_emb_in[:-3] model.get_output_embeddings().weight[:-3]=old_emb_out[:-3] print(loaded_dict.keys()) training_args.train_interactive = True trainer = LLaVATrainer(model=model, tokenizer=tokenizer, args=training_args,cfg=cfg,data_loader_args=(tokenizer, data_args,preprocess), **data_module) if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")): trainer.train(resume_from_checkpoint=True) else: trainer.train() trainer.save_state() model.config.use_cache = True if training_args.lora_enable: state_dict = get_peft_state_maybe_zero_3( model.named_parameters(), training_args.lora_bias ) non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3( model.named_parameters() ) if training_args.local_rank == 0 or training_args.local_rank == -1: model.config.save_pretrained(training_args.output_dir) model.save_pretrained(training_args.output_dir, state_dict=state_dict) torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin')) else: safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir) if __name__ == "__main__": train() ================================================ FILE: llava/train/train_mem.py ================================================ # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright: # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright: # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn. # Need to call this before importing transformers. 
def build_logger(logger_name, logger_filename):
    """Return an INFO-level logger and route stdout/stderr through logging.

    On every call the first root handler is given the shared formatter and
    ``sys.stdout`` / ``sys.stderr`` are replaced by ``StreamToLogger``
    wrappers. On the first call only, a ``TimedRotatingFileHandler``
    (``when='D'``, UTC) writing to ``LOGDIR/logger_filename`` is created,
    stored in the module-level ``handler`` and attached to every logger that
    is registered at that moment.

    Args:
        logger_name: Name passed to ``logging.getLogger``.
        logger_filename: File name, relative to ``LOGDIR``, for the shared
            file handler.

    Returns:
        The ``logging.Logger`` named ``logger_name``.
    """
    global handler

    log_format = logging.Formatter(
        fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    # Make sure the root logger has a handler, then format it.
    root = logging.getLogger()
    if not root.handlers:
        logging.basicConfig(level=logging.INFO)
    root.handlers[0].setFormatter(log_format)

    # Redirect stdout first, then stderr; the order is kept because each
    # wrapper is constructed while the previous replacement is already active.
    for stream_name, level in (("stdout", logging.INFO), ("stderr", logging.ERROR)):
        stream_logger = logging.getLogger(stream_name)
        stream_logger.setLevel(level)
        setattr(sys, stream_name, StreamToLogger(stream_logger, level))

    named_logger = logging.getLogger(logger_name)
    named_logger.setLevel(logging.INFO)

    # The shared file handler is created once per process.
    if handler is not None:
        return named_logger

    os.makedirs(LOGDIR, exist_ok=True)
    handler = logging.handlers.TimedRotatingFileHandler(
        os.path.join(LOGDIR, logger_filename), when='D', utc=True)
    handler.setFormatter(log_format)
    for registered in logging.root.manager.loggerDict.values():
        # loggerDict may also hold non-Logger placeholder entries.
        if isinstance(registered, logging.Logger):
            registered.addHandler(handler)

    return named_logger
def violates_moderation(text):
    """
    Check whether the text violates OpenAI moderation API.

    Newlines are removed from ``text`` and the result is POSTed as JSON to
    the moderation endpoint with a 5 second timeout. The check is
    best-effort: any request failure or unexpected response shape is treated
    as "not flagged".

    Args:
        text: The user text to check.

    Returns:
        The ``flagged`` value of the first moderation result, or ``False``
        when the request fails or the response cannot be interpreted.

    Raises:
        KeyError: If the ``OPENAI_API_KEY`` environment variable is not set.
    """
    # Function-scope import: this module does not import json at file level.
    import json

    url = "https://api.openai.com/v1/moderations"
    headers = {
        "Content-Type": "application/json",
        "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"],
    }
    text = text.replace("\n", "")
    # Serialize with json.dumps so quotes, backslashes and control characters
    # in the text are escaped instead of producing invalid or injected JSON.
    data = json.dumps({"input": text}, ensure_ascii=False).encode("utf-8")
    try:
        ret = requests.post(url, headers=headers, data=data, timeout=5)
        flagged = ret.json()["results"][0]["flagged"]
    except requests.exceptions.RequestException:
        flagged = False
    except (KeyError, IndexError, TypeError, ValueError):
        # Error payloads, empty result lists and non-JSON bodies all mean
        # "no usable verdict"; keep the best-effort contract.
        flagged = False

    return flagged
def convert_to_llava(base_dir, split, prompt_format="QCM-LEPA"):
    """Convert one split of problems into a LLaVA-style conversation JSON file.

    Reads ``pid_splits.json`` and ``problems.json`` from ``base_dir``, builds
    the prompts with ``build_prompt_chatbot`` and writes
    ``llava_{split}_{prompt_format}.json`` back into ``base_dir``.

    Args:
        base_dir: Directory holding the input JSON files; also the output
            directory.
        split: Key into ``pid_splits.json`` selecting the problem ids.
        prompt_format: "<input>-<output>" layout code passed to
            ``build_prompt_chatbot``.
    """
    # Context managers close the input files deterministically.
    with open(os.path.join(base_dir, "pid_splits.json"), encoding="utf-8") as f:
        split_indices = json.load(f)[split]
    with open(os.path.join(base_dir, "problems.json"), encoding="utf-8") as f:
        problems = json.load(f)

    split_problems = build_prompt_chatbot(
        problems, split_indices, prompt_format,
        use_caption=False, is_test=False)

    question_prefix = 'Question: '
    answer_prefix = 'Answer: '

    target_format = []
    for prob_id, (question, answer) in split_problems.items():
        # Strip only the leading marker; occurrences of the same text later
        # in the prompt belong to the content and must be kept.
        if question.startswith(question_prefix):
            question = question[len(question_prefix):]
        if answer.startswith(answer_prefix):
            answer = answer[len(answer_prefix):]

        sample = {"id": prob_id}
        image = problems[prob_id]['image']
        if image is None:
            human_value = f"{question}"
        else:
            sample["image"] = os.path.join(prob_id, image)
            # NOTE(review): image samples only get a trailing newline here;
            # confirm whether an image placeholder token is expected after it.
            human_value = f"{question}\n"
        sample["conversations"] = [
            {'from': 'human', 'value': human_value},
            {'from': 'gpt', 'value': f"{answer}"},
        ]
        target_format.append(sample)

    print(f'Number of samples: {len(target_format)}')

    with open(os.path.join(base_dir, f"llava_{split}_{prompt_format}.json"), "w") as f:
        json.dump(target_format, f, indent=2)
def get_choice_text(probelm, options):
    """Render the problem's choices as one line: "(A) first (B) second ...".

    Args:
        probelm: Problem dict; its ``'choices'`` entry is the list of choice
            texts.
        options: Sequence of option labels, indexed by choice position.

    Returns:
        The labelled choices joined by single spaces (empty string when there
        are no choices).
    """
    labelled = [
        "({}) {}".format(options[position], text)
        for position, text in enumerate(probelm['choices'])
    ]
    return " ".join(labelled)
def create_one_example_chatbot(format, question, context, choice, answer, lecture, solution, test_example=True):
    """Build one (input, output) prompt pair for chatbot-style training data.

    Args:
        format: "<input>-<output>" layout code, e.g. "QCM-LEPA". The input
            part selects the order of question/context/options (and optional
            "BECAUSE:" rationale); the output part selects the answer layout.
        question: Question text.
        context: Context text.
        choice: Pre-rendered options text.
        answer: Answer label inserted into the output.
        lecture: Lecture text (may be empty).
        solution: Solution text (may be empty).
        test_example: When true the output is just "Answer:" and the output
            part of ``format`` is not consulted.

    Returns:
        Tuple ``(input, output)`` of stripped strings.

    Raises:
        ValueError: If ``format`` does not contain exactly one "-" or names
            an unknown input/output layout.
    """
    input_format, output_format = format.split("-")

    fields = {
        "question": question,
        "context": context,
        "choice": choice,
        "answer": answer,
        "lecture": lecture,
        "solution": solution,
    }

    ## Inputs
    input_templates = {
        "CQM": "Context: {context}\nQuestion: {question}\nOptions: {choice}\n",
        "QCM": "Question: {question}\nContext: {context}\nOptions: {choice}\n",
        # upper bound experiment
        "QCML": "Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture}\n",
        "QCME": "Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {solution}\n",
        "QCMLE": "Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture} {solution}\n",
        "QCLM": "Question: {question}\nContext: {context}\nBECAUSE: {lecture}\nOptions: {choice}\n",
        "QCEM": "Question: {question}\nContext: {context}\nBECAUSE: {solution}\nOptions: {choice}\n",
        "QCLEM": "Question: {question}\nContext: {context}\nBECAUSE: {lecture} {solution}\nOptions: {choice}\n",
    }
    if input_format not in input_templates:
        raise ValueError(f"Unknown input format: {input_format!r}")
    prompt = input_templates[input_format].format(**fields)

    ## Outputs
    # NOTE(review): 'AL' uses the solution and 'AE' uses the lecture, which
    # looks swapped relative to 'LA'/'EA'; kept as-is, confirm the intent.
    output_templates = {
        "A": "Answer: The answer is {answer}.",
        "AL": "Answer: The answer is {answer}. BECAUSE: {solution}",
        "AE": "Answer: The answer is {answer}. BECAUSE: {lecture}",
        "ALE": "Answer: The answer is {answer}. BECAUSE: {lecture} {solution}",
        "AEL": "Answer: The answer is {answer}. BECAUSE: {solution} {lecture}",
        "LA": "Answer: {lecture} The answer is {answer}.",
        "EA": "Answer: {solution} The answer is {answer}.",
        "LEA": "Answer: {lecture} {solution} The answer is {answer}.",
        "ELA": "Answer: {solution} {lecture} The answer is {answer}.",
    }
    if test_example:
        response = "Answer:"
    elif output_format == "LEPA":
        # Lecture and solution sections are emitted only when non-blank.
        pieces = []
        if lecture.strip():
            pieces.append(f"LECTURE: {lecture}\n")
        if solution.strip():
            pieces.append(f"SOLUTION: {solution}\n")
        pieces.append('###\n')
        pieces.append(f"ANSWER: {answer}.")
        response = "".join(pieces)
    elif output_format in output_templates:
        response = output_templates[output_format].format(**fields)
    else:
        raise ValueError(f"Unknown output format: {output_format!r}")

    # Empty lecture/solution fields leave doubled spaces behind; collapse them.
    prompt = prompt.replace("  ", " ").strip()
    response = response.replace("  ", " ").strip()

    # Drop a dangling rationale marker, but only the trailing one so that the
    # same text inside the question or answer body is preserved.
    marker = "BECAUSE:"
    if prompt.endswith(marker):
        prompt = prompt[:-len(marker)].strip()
    if response.endswith(marker):
        response = response[:-len(marker)].strip()

    return prompt, response
def create_one_example_gpt4(format, question, context, choice, answer, lecture, solution, test_example=True):
    """Build one user/assistant chat-message pair for a problem.

    Args:
        format: "<input>-<output>" layout code, e.g. "QCM-ALE". The input
            part selects the order of question/context/options (and optional
            "BECAUSE:" rationale); the output part selects the answer layout.
        question: Question text.
        context: Context text.
        choice: Pre-rendered options text.
        answer: Answer label inserted into the output.
        lecture: Lecture text (may be empty).
        solution: Solution text (may be empty).
        test_example: When true the assistant content is just "Answer:" and
            the output part of ``format`` is not consulted.

    Returns:
        Tuple ``(user_prompt, assistant_prompt)`` of dicts with ``"role"``
        and ``"content"`` keys.

    Raises:
        ValueError: If ``format`` does not contain exactly one "-" or names
            an unknown input/output layout.
    """
    input_format, output_format = format.split("-")

    fields = {
        "question": question,
        "context": context,
        "choice": choice,
        "answer": answer,
        "lecture": lecture,
        "solution": solution,
    }

    ## Inputs
    input_templates = {
        "CQM": "Context: {context}\nQuestion: {question}\nOptions: {choice}\n",
        "QCM": "Question: {question}\nContext: {context}\nOptions: {choice}\n",
        # upper bound experiment
        "QCML": "Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture}\n",
        "QCME": "Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {solution}\n",
        "QCMLE": "Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture} {solution}\n",
        "QCLM": "Question: {question}\nContext: {context}\nBECAUSE: {lecture}\nOptions: {choice}\n",
        "QCEM": "Question: {question}\nContext: {context}\nBECAUSE: {solution}\nOptions: {choice}\n",
        "QCLEM": "Question: {question}\nContext: {context}\nBECAUSE: {lecture} {solution}\nOptions: {choice}\n",
    }
    if input_format not in input_templates:
        raise ValueError(f"Unknown input format: {input_format!r}")
    prompt = input_templates[input_format].format(**fields)

    ## Outputs
    # NOTE(review): 'AL' uses the solution and 'AE' uses the lecture, which
    # looks swapped relative to 'LA'/'EA'; kept as-is, confirm the intent.
    output_templates = {
        "A": "Answer: The answer is {answer}.",
        "AL": "Answer: The answer is {answer}. BECAUSE: {solution}",
        "AE": "Answer: The answer is {answer}. BECAUSE: {lecture}",
        "ALE": "Answer: The answer is {answer}. BECAUSE: {lecture} {solution}",
        "AEL": "Answer: The answer is {answer}. BECAUSE: {solution} {lecture}",
        "LA": "Answer: {lecture} The answer is {answer}.",
        "EA": "Answer: {solution} The answer is {answer}.",
        "LEA": "Answer: {lecture} {solution} The answer is {answer}.",
        "ELA": "Answer: {solution} {lecture} The answer is {answer}.",
    }
    if test_example:
        response = "Answer:"
    elif output_format in output_templates:
        response = output_templates[output_format].format(**fields)
    else:
        raise ValueError(f"Unknown output format: {output_format!r}")

    # Empty lecture/solution fields leave doubled spaces behind; collapse them.
    prompt = prompt.replace("  ", " ").strip()
    response = response.replace("  ", " ").strip()

    # Drop a dangling rationale marker, but only the trailing one so that the
    # same text inside the answer body is preserved.
    marker = "BECAUSE:"
    if response.endswith(marker):
        response = response[:-len(marker)].strip()

    user_prompt = {"role": "user", "content": f"Can you explain {prompt}?"}
    assistant_prompt = {"role": "assistant", "content": f"{response}"}

    return user_prompt, assistant_prompt
def build_prompt_gpt4(problems, shot_qids, test_qid, args):
    """Assemble a chat-message list: system message, n shots, then the test item.

    Args:
        problems: Mapping from question id to problem dict.
        shot_qids: Question ids rendered as worked examples
            (``test_example=False``).
        test_qid: Question id rendered last with ``test_example=True``.
        args: Object providing ``use_caption``, ``options`` and
            ``prompt_format``.

    Returns:
        List of message dicts, starting with the system message and followed
        by a user/assistant pair for every shot and for the test question.
    """
    def _message_pair(qid, is_test):
        # Collect the text pieces for one problem and render its two messages.
        problem_question = get_question_text(problems[qid])
        problem_context = get_context_text(problems[qid], args.use_caption)
        problem_choice = get_choice_text(problems[qid], args.options)
        problem_answer = get_answer(problems[qid], args.options)
        problem_lecture = get_lecture_text(problems[qid])
        problem_solution = get_solution_text(problems[qid])
        return create_one_example_gpt4(
            args.prompt_format,
            problem_question,
            problem_context,
            problem_choice,
            problem_answer,
            problem_lecture,
            problem_solution,
            test_example=is_test,
        )

    prompt_array = [{"role": "system", "content": "You are a helpful assistant."}]

    # n-shot training examples
    for shot_qid in shot_qids:
        user_msg, assistant_msg = _message_pair(shot_qid, False)
        prompt_array += [user_msg, assistant_msg]

    # test example
    user_msg, assistant_msg = _message_pair(test_qid, True)
    prompt_array += [user_msg, assistant_msg]

    return prompt_array
\ --warmup_ratio 0.03 \ --lr_scheduler_type "cosine" \ --logging_steps 1 \ --tf32 True \ --model_max_length 2400 \ --gradient_checkpointing True \ --dataloader_num_workers 4 \ --lazy_preprocess True \ --report_to wandb \ --max_steps 10000 \ --config_file \ configs/openseed/openseed_swint_lang_joint_2st.yaml \ --opt \ MODEL.DECODER.WEIGHT_MULTIPLIER=0.1,MODEL.DECODER.COST_CLASS_WEIGHT=4.0,flickr.TRAIN.BATCH_SIZE_TOTAL=6,coco_instruct.TEST.BATCH_SIZE_TOTAL=${bs},coco_instruct.TRAIN.BATCH_SIZE_TOTAL=${bs},MODEL.WEIGHTS=ckpts/openseed_o365.pt \ >> $out_dir/log 2>&1 ================================================ FILE: scripts/finetune_visual_prompt.sh ================================================ # Uncomment and set the following variables correspondingly to run this script: ################## VICUNA ################## PROMPT_VERSION=v1 # MODEL_VERSION="vicuna-v1-3-7b" ################## VICUNA ################## ################## LLaMA-2 ################## # PROMPT_VERSION="llava_llama_2" # MODEL_VERSION="llama-2-7b-chat" ################## LLaMA-2 ################## out_dir=output/llava_stage2_visual_prompt load=output/llava_grounding_stage2/ mkdir -p $out_dir echo $out_dir/log export DATASET=datasets/ num_gpu=8 bs=$(( 8 * $num_gpu )) deepspeed llava/train/train_joint_2st_interactive_refcoco_coco_instruction.py \ --deepspeed scripts/zero2.json \ --model_name_or_path ckpts/vicuna/vicuna-7b-v1.3/ \ --whole_model $load \ --load_model True \ --version $PROMPT_VERSION \ --data_path datasets/llava/annotations/llava_instruct_150k.json \ --image_folder datasets/coco/train2017/ \ --vision_tower openai/clip-vit-large-patch14 \ --pretrain_mm_mlp_adapter output/llava_stage1/mm_projector.bin \ --mm_vision_select_layer -2 \ --mm_use_im_start_end False \ --tune_prompt_adapter True \ --mm_use_im_patch_token False \ --bf16 True \ --output_dir $out_dir \ --num_train_epochs 1 \ --per_device_train_batch_size 2 \ --per_device_eval_batch_size 4 \ --gradient_accumulation_steps 1 \ 
def merge_lora(args):
    """Load a checkpoint on CPU with ``load_pretrained_model`` and re-save it.

    Args:
        args: Namespace providing ``model_path``, ``model_base`` and
            ``save_model_path``. The loaded model and tokenizer are both
            written to ``save_model_path``.
    """
    name = get_model_name_from_path(args.model_path)
    loaded = load_pretrained_model(
        args.model_path,
        args.model_base,
        name,
        device_map='cpu',
    )
    # The image processor and context length are returned but not needed here.
    tokenizer, model, _image_processor, _context_len = loaded

    for artifact in (model, tokenizer):
        artifact.save_pretrained(args.save_model_path)
================================================ # Uncomment and set the following variables correspondingly to run this script: # MODEL_VERSION=vicuna-v1-3-7b # MODEL_VERSION=llama-2-7b-chat ########### DO NOT CHANGE ########### ########### USE THIS FOR BOTH ########### PROMPT_VERSION=v1 ########### DO NOT CHANGE ########### out_dir=output/llava_grounding_stage1 mkdir -p $out_dir echo $out_dir/log export DATASET=datasets/ n_gpu=4 deepspeed --include=localhost:1,2,3,7 llava/train/train_joint_1st.py \ --deepspeed scripts/zero2.json \ --model_name_or_path ckpts/vicuna/vicuna-7b-v1.3/ \ --version $PROMPT_VERSION \ --data_path datasets/llava/annotations/cap600k_brackets_all.json \ --image_folder datasets/ConceptualCaptionsFiltered/ \ --vision_tower openai/clip-vit-large-patch14 \ --pretrain_mm_mlp_adapter output/llava_stage1/mm_projector.bin \ --tune_mm_mlp_adapter True \ --mm_vision_select_layer -2 \ --mm_use_im_start_end False \ --mm_use_im_patch_token False \ --bf16 True \ --output_dir $out_dir \ --max_steps 30000 \ --num_train_epochs 1 \ --per_device_train_batch_size 8 \ --per_device_eval_batch_size 4 \ --gradient_accumulation_steps 1 \ --evaluation_strategy "no" \ --save_strategy "steps" \ --save_steps 1000 \ --save_total_limit 100 \ --learning_rate 1e-4 \ --weight_decay 0. \ --warmup_ratio 0.03 \ --lr_scheduler_type "cosine" \ --logging_steps 1 \ --tf32 True \ --model_max_length 2048 \ --gradient_checkpointing True \ --dataloader_num_workers 4 \ --lazy_preprocess True \ --report_to wandb \ --config_file \ configs/openseed/openseed_swint_lang_joint.yaml \ --opt \ flickr.TRAIN.BATCH_SIZE_TOTAL=8,COCO.TRAIN.BATCH_SIZE_TOTAL=24,MODEL.WEIGHTS=ckpts/openseed_o365.pt \ >> $out_dir/log 2>&1 ================================================ FILE: utils/Config.py ================================================ from fvcore.common.config import CfgNode as _CfgNode class CfgNode(_CfgNode): """ The same as `fvcore.common.config.CfgNode`, but different in: 1. 
def load_config_dict_to_opt(opt, config_dict):
    """
    Load the key, value pairs from config_dict to opt, overriding existing values in opt
    if there is any.

    A key containing "." addresses a nested dict: "A.B.C" sets
    ``opt["A"]["B"]["C"]``, creating the intermediate dicts when missing.

    Args:
        opt (dict): Target options dictionary, updated in place.
        config_dict (dict): Mapping from (possibly dotted) key to value.

    Raises:
        TypeError: If ``config_dict`` is not a dict.
        AssertionError: If a dotted key passes through an existing value that
            is not a dict.
    """
    if not isinstance(config_dict, dict):
        raise TypeError("Config must be a Python dictionary")

    for k, v in config_dict.items():
        *parent_parts, leaf = k.split('.')
        pointer = opt
        for k_part in parent_parts:
            if k_part not in pointer:
                pointer[k_part] = {}
            pointer = pointer[k_part]
            # Raised explicitly (not via `assert`) so the check survives
            # `python -O`; the exception type is unchanged for callers.
            if not isinstance(pointer, dict):
                raise AssertionError("Overriding key needs to be inside a Python dict.")

        # Decide by key presence, not truthiness, so overriding False/0/""
        # is reported as well.
        had_value = leaf in pointer
        ori_value = pointer.get(leaf)
        pointer[leaf] = v
        if had_value:
            logger.warning("Overrided %s from %s to %s", k, ori_value, pointer[leaf])
def load_opt_command(args):
    """Parse the command line and build the merged option dictionary.

    Options are layered in this order: config files (``--conf_files``), the
    JSON string given via ``--config_overrides``, the ``key value`` pairs
    given via ``--overrides`` (each value is cast to the type of the value
    currently stored under that key), and finally every parsed command-line
    argument that is not ``None``.

    Args:
        args: Argument list to parse; when empty or ``None`` the process
            arguments are parsed instead.

    Returns:
        Tuple ``(opt, cmdline_args)``: the merged option dict and the parsed
        ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description='Pretrain or fine-tune models for NLP tasks.')
    parser.add_argument('command', help='Command: train/evaluate/train-and-evaluate')
    parser.add_argument('--conf_files', nargs='+', required=True, help='Path(s) to the config file(s).')
    parser.add_argument('--user_dir', help='Path to the user defined module for tasks (models, criteria), optimizers, and lr schedulers.')
    parser.add_argument('--config_overrides', nargs='*', help='Override parameters on config with a json style string, e.g. {"": , "..": }. A key with "." updates the object in the corresponding nested dict. Remember to escape " in command line.')
    parser.add_argument('--overrides', help='arguments that used to override the config file in cmdline', nargs=argparse.REMAINDER)

    cmdline_args = parser.parse_args(args) if args else parser.parse_args()

    opt = load_opt_from_config_files(cmdline_args.conf_files)

    if cmdline_args.config_overrides:
        overrides_json = ' '.join(cmdline_args.config_overrides)
        logger.warning(f"Command line config overrides: {overrides_json}")
        load_config_dict_to_opt(opt, json.loads(overrides_json))

    if cmdline_args.overrides:
        flat_pairs = cmdline_args.overrides
        assert len(flat_pairs) % 2 == 0, "overrides arguments is not paired, required: key value"
        keys = flat_pairs[0::2]
        # A five-character "false"/"False" becomes "" so that casting it with
        # bool() yields False.
        vals = [
            raw.replace('false', '').replace('False', '') if len(raw.replace(' ', '')) == 5 else raw
            for raw in flat_pairs[1::2]
        ]

        def _current_type(dotted_key):
            # Walk the nested option dict down to the value the key names.
            node = opt
            for part in dotted_key.split('.'):
                node = node[part]
            return type(node)

        types = [_current_type(dotted_key) for dotted_key in keys]
        typed_overrides = {key: cast(raw) for key, raw, cast in zip(keys, vals, types)}
        load_config_dict_to_opt(opt, typed_overrides)

    # combine cmdline_args into opt dictionary
    opt.update({name: value for name, value in vars(cmdline_args).items() if value is not None})

    return opt, cmdline_args
grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl", "quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet", "coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck", "red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby", "koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch", "snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab", "fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab", "isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron", "great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot", "bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher", "pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion", "Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel", "Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle", "Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound", "English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound", "Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound", "Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier", "Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier", "Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier", "Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier", "Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer", "Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier", "Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White 
Terrier", "Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever", "Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla", "English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel", "English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel", "Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard", "Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie", "Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann", "Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog", "Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff", "French Bulldog", "Great Dane", "St. Bernard", "husky", "Alaskan Malamute", "Siberian Husky", "Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog", "Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon", "Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle", "Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf", "red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox", "kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat", "Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger", "cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose", "meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle", "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper", "cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper", "lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly", "monarch butterfly", "small white 
butterfly", "sulphur butterfly", "gossamer-winged butterfly", "starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit", "hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse", "zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison", "ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)", "gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat", "black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan", "gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas monkey", "baboon", "macaque", "langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin", "howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey", "ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda", "giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish", "sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown", "accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance", "amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle", "backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo", "baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel", "wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel", "bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)", "beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini", "ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet", "bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra", "breakwater", "breastplate", "broom", 
"bucket", "buckle", "bulletproof vest", "high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe", "can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton", "car wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran", "CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw", "storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking", "church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker", "coffee mug", "coffeemaker", "spiral or coil", "combination lock", "computer keyboard", "candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot", "cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed", "Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer", "rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table", "dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig", "drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar", "electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder", "feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute", "folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed", "freight car", "French horn", "frying pan", "fur coat", "garbage truck", "gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola", "gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine", "hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer", "handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet", "holster", "home theater", "honeycomb", "hook", "hoop skirt", 
"gymnastic horizontal bar", "horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep", "T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat", "ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library", "lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion", "music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag", "mailbox", "tights", "one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask", "matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone", "microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile", "mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor", "moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa", "mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail", "neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina", "odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart", "oxygen mask", "product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush", "pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench", "parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case", "pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", "Pickelhaube", "picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball", "pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag", "plate rack", "farm plow", "plunger", "Polaroid camera", "pole", "police van", "poncho", "pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug", "printer", "prison", "projectile", "projector", "hockey 
puck", "punching bag", "purse", "quill", "quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel", "recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator", "remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser", "rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal", "sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard", "CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store", "shoji screen / room divider", "shopping basket", "shopping cart", "shovel", "shower cap", "shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door", "slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock", "solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater", "space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight", "stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf", "stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa", "submarine", "suit", "sundial", "sunglasses", "dark glasses", "sunscreen", "suspension bridge", "mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe", "table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball", "thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof", "toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store", "tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod", "triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard", "umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling", "velvet fabric", "vending machine", 
"vestment", "viaduct", "violin", "volleyball", "waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink", "washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle", "hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing", "wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website", "comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu", "plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette", "bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", "cabbage", "broccoli", "cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber", "artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange", "lemon", "fig", "pineapple", "banana", "jackfruit", "cherimoya (custard apple)", "pomegranate", "hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito", "red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef", "geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player", "bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn", "rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom", "earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper"] IMAGENET_FOLDER_NAMES = ['n01440764', 'n01443537', 'n01484850', 'n01491361', 'n01494475', 'n01496331', 'n01498041', 'n01514668', 'n01514859', 'n01518878', 'n01530575', 'n01531178', 'n01532829', 'n01534433', 'n01537544', 'n01558993', 'n01560419', 'n01580077', 'n01582220', 'n01592084', 'n01601694', 'n01608432', 'n01614925', 'n01616318', 'n01622779', 'n01629819', 'n01630670', 'n01631663', 'n01632458', 'n01632777', 'n01641577', 'n01644373', 'n01644900', 'n01664065', 'n01665541', 
'n01667114', 'n01667778', 'n01669191', 'n01675722', 'n01677366', 'n01682714', 'n01685808', 'n01687978', 'n01688243', 'n01689811', 'n01692333', 'n01693334', 'n01694178', 'n01695060', 'n01697457', 'n01698640', 'n01704323', 'n01728572', 'n01728920', 'n01729322', 'n01729977', 'n01734418', 'n01735189', 'n01737021', 'n01739381', 'n01740131', 'n01742172', 'n01744401', 'n01748264', 'n01749939', 'n01751748', 'n01753488', 'n01755581', 'n01756291', 'n01768244', 'n01770081', 'n01770393', 'n01773157', 'n01773549', 'n01773797', 'n01774384', 'n01774750', 'n01775062', 'n01776313', 'n01784675', 'n01795545', 'n01796340', 'n01797886', 'n01798484', 'n01806143', 'n01806567', 'n01807496', 'n01817953', 'n01818515', 'n01819313', 'n01820546', 'n01824575', 'n01828970', 'n01829413', 'n01833805', 'n01843065', 'n01843383', 'n01847000', 'n01855032', 'n01855672', 'n01860187', 'n01871265', 'n01872401', 'n01873310', 'n01877812', 'n01882714', 'n01883070', 'n01910747', 'n01914609', 'n01917289', 'n01924916', 'n01930112', 'n01943899', 'n01944390', 'n01945685', 'n01950731', 'n01955084', 'n01968897', 'n01978287', 'n01978455', 'n01980166', 'n01981276', 'n01983481', 'n01984695', 'n01985128', 'n01986214', 'n01990800', 'n02002556', 'n02002724', 'n02006656', 'n02007558', 'n02009229', 'n02009912', 'n02011460', 'n02012849', 'n02013706', 'n02017213', 'n02018207', 'n02018795', 'n02025239', 'n02027492', 'n02028035', 'n02033041', 'n02037110', 'n02051845', 'n02056570', 'n02058221', 'n02066245', 'n02071294', 'n02074367', 'n02077923', 'n02085620', 'n02085782', 'n02085936', 'n02086079', 'n02086240', 'n02086646', 'n02086910', 'n02087046', 'n02087394', 'n02088094', 'n02088238', 'n02088364', 'n02088466', 'n02088632', 'n02089078', 'n02089867', 'n02089973', 'n02090379', 'n02090622', 'n02090721', 'n02091032', 'n02091134', 'n02091244', 'n02091467', 'n02091635', 'n02091831', 'n02092002', 'n02092339', 'n02093256', 'n02093428', 'n02093647', 'n02093754', 'n02093859', 'n02093991', 'n02094114', 'n02094258', 'n02094433', 
'n02095314', 'n02095570', 'n02095889', 'n02096051', 'n02096177', 'n02096294', 'n02096437', 'n02096585', 'n02097047', 'n02097130', 'n02097209', 'n02097298', 'n02097474', 'n02097658', 'n02098105', 'n02098286', 'n02098413', 'n02099267', 'n02099429', 'n02099601', 'n02099712', 'n02099849', 'n02100236', 'n02100583', 'n02100735', 'n02100877', 'n02101006', 'n02101388', 'n02101556', 'n02102040', 'n02102177', 'n02102318', 'n02102480', 'n02102973', 'n02104029', 'n02104365', 'n02105056', 'n02105162', 'n02105251', 'n02105412', 'n02105505', 'n02105641', 'n02105855', 'n02106030', 'n02106166', 'n02106382', 'n02106550', 'n02106662', 'n02107142', 'n02107312', 'n02107574', 'n02107683', 'n02107908', 'n02108000', 'n02108089', 'n02108422', 'n02108551', 'n02108915', 'n02109047', 'n02109525', 'n02109961', 'n02110063', 'n02110185', 'n02110341', 'n02110627', 'n02110806', 'n02110958', 'n02111129', 'n02111277', 'n02111500', 'n02111889', 'n02112018', 'n02112137', 'n02112350', 'n02112706', 'n02113023', 'n02113186', 'n02113624', 'n02113712', 'n02113799', 'n02113978', 'n02114367', 'n02114548', 'n02114712', 'n02114855', 'n02115641', 'n02115913', 'n02116738', 'n02117135', 'n02119022', 'n02119789', 'n02120079', 'n02120505', 'n02123045', 'n02123159', 'n02123394', 'n02123597', 'n02124075', 'n02125311', 'n02127052', 'n02128385', 'n02128757', 'n02128925', 'n02129165', 'n02129604', 'n02130308', 'n02132136', 'n02133161', 'n02134084', 'n02134418', 'n02137549', 'n02138441', 'n02165105', 'n02165456', 'n02167151', 'n02168699', 'n02169497', 'n02172182', 'n02174001', 'n02177972', 'n02190166', 'n02206856', 'n02219486', 'n02226429', 'n02229544', 'n02231487', 'n02233338', 'n02236044', 'n02256656', 'n02259212', 'n02264363', 'n02268443', 'n02268853', 'n02276258', 'n02277742', 'n02279972', 'n02280649', 'n02281406', 'n02281787', 'n02317335', 'n02319095', 'n02321529', 'n02325366', 'n02326432', 'n02328150', 'n02342885', 'n02346627', 'n02356798', 'n02361337', 'n02363005', 'n02364673', 'n02389026', 'n02391049', 
'n02395406', 'n02396427', 'n02397096', 'n02398521', 'n02403003', 'n02408429', 'n02410509', 'n02412080', 'n02415577', 'n02417914', 'n02422106', 'n02422699', 'n02423022', 'n02437312', 'n02437616', 'n02441942', 'n02442845', 'n02443114', 'n02443484', 'n02444819', 'n02445715', 'n02447366', 'n02454379', 'n02457408', 'n02480495', 'n02480855', 'n02481823', 'n02483362', 'n02483708', 'n02484975', 'n02486261', 'n02486410', 'n02487347', 'n02488291', 'n02488702', 'n02489166', 'n02490219', 'n02492035', 'n02492660', 'n02493509', 'n02493793', 'n02494079', 'n02497673', 'n02500267', 'n02504013', 'n02504458', 'n02509815', 'n02510455', 'n02514041', 'n02526121', 'n02536864', 'n02606052', 'n02607072', 'n02640242', 'n02641379', 'n02643566', 'n02655020', 'n02666196', 'n02667093', 'n02669723', 'n02672831', 'n02676566', 'n02687172', 'n02690373', 'n02692877', 'n02699494', 'n02701002', 'n02704792', 'n02708093', 'n02727426', 'n02730930', 'n02747177', 'n02749479', 'n02769748', 'n02776631', 'n02777292', 'n02782093', 'n02783161', 'n02786058', 'n02787622', 'n02788148', 'n02790996', 'n02791124', 'n02791270', 'n02793495', 'n02794156', 'n02795169', 'n02797295', 'n02799071', 'n02802426', 'n02804414', 'n02804610', 'n02807133', 'n02808304', 'n02808440', 'n02814533', 'n02814860', 'n02815834', 'n02817516', 'n02823428', 'n02823750', 'n02825657', 'n02834397', 'n02835271', 'n02837789', 'n02840245', 'n02841315', 'n02843684', 'n02859443', 'n02860847', 'n02865351', 'n02869837', 'n02870880', 'n02871525', 'n02877765', 'n02879718', 'n02883205', 'n02892201', 'n02892767', 'n02894605', 'n02895154', 'n02906734', 'n02909870', 'n02910353', 'n02916936', 'n02917067', 'n02927161', 'n02930766', 'n02939185', 'n02948072', 'n02950826', 'n02951358', 'n02951585', 'n02963159', 'n02965783', 'n02966193', 'n02966687', 'n02971356', 'n02974003', 'n02977058', 'n02978881', 'n02979186', 'n02980441', 'n02981792', 'n02988304', 'n02992211', 'n02992529', 'n02999410', 'n03000134', 'n03000247', 'n03000684', 'n03014705', 'n03016953', 
'n03017168', 'n03018349', 'n03026506', 'n03028079', 'n03032252', 'n03041632', 'n03042490', 'n03045698', 'n03047690', 'n03062245', 'n03063599', 'n03063689', 'n03065424', 'n03075370', 'n03085013', 'n03089624', 'n03095699', 'n03100240', 'n03109150', 'n03110669', 'n03124043', 'n03124170', 'n03125729', 'n03126707', 'n03127747', 'n03127925', 'n03131574', 'n03133878', 'n03134739', 'n03141823', 'n03146219', 'n03160309', 'n03179701', 'n03180011', 'n03187595', 'n03188531', 'n03196217', 'n03197337', 'n03201208', 'n03207743', 'n03207941', 'n03208938', 'n03216828', 'n03218198', 'n03220513', 'n03223299', 'n03240683', 'n03249569', 'n03250847', 'n03255030', 'n03259280', 'n03271574', 'n03272010', 'n03272562', 'n03290653', 'n03291819', 'n03297495', 'n03314780', 'n03325584', 'n03337140', 'n03344393', 'n03345487', 'n03347037', 'n03355925', 'n03372029', 'n03376595', 'n03379051', 'n03384352', 'n03388043', 'n03388183', 'n03388549', 'n03393912', 'n03394916', 'n03400231', 'n03404251', 'n03417042', 'n03424325', 'n03425413', 'n03443371', 'n03444034', 'n03445777', 'n03445924', 'n03447447', 'n03447721', 'n03450230', 'n03452741', 'n03457902', 'n03459775', 'n03461385', 'n03467068', 'n03476684', 'n03476991', 'n03478589', 'n03481172', 'n03482405', 'n03483316', 'n03485407', 'n03485794', 'n03492542', 'n03494278', 'n03495258', 'n03496892', 'n03498962', 'n03527444', 'n03529860', 'n03530642', 'n03532672', 'n03534580', 'n03535780', 'n03538406', 'n03544143', 'n03584254', 'n03584829', 'n03590841', 'n03594734', 'n03594945', 'n03595614', 'n03598930', 'n03599486', 'n03602883', 'n03617480', 'n03623198', 'n03627232', 'n03630383', 'n03633091', 'n03637318', 'n03642806', 'n03649909', 'n03657121', 'n03658185', 'n03661043', 'n03662601', 'n03666591', 'n03670208', 'n03673027', 'n03676483', 'n03680355', 'n03690938', 'n03691459', 'n03692522', 'n03697007', 'n03706229', 'n03709823', 'n03710193', 'n03710637', 'n03710721', 'n03717622', 'n03720891', 'n03721384', 'n03724870', 'n03729826', 'n03733131', 'n03733281', 
'n03733805', 'n03742115', 'n03743016', 'n03759954', 'n03761084', 'n03763968', 'n03764736', 'n03769881', 'n03770439', 'n03770679', 'n03773504', 'n03775071', 'n03775546', 'n03776460', 'n03777568', 'n03777754', 'n03781244', 'n03782006', 'n03785016', 'n03786901', 'n03787032', 'n03788195', 'n03788365', 'n03791053', 'n03792782', 'n03792972', 'n03793489', 'n03794056', 'n03796401', 'n03803284', 'n03804744', 'n03814639', 'n03814906', 'n03825788', 'n03832673', 'n03837869', 'n03838899', 'n03840681', 'n03841143', 'n03843555', 'n03854065', 'n03857828', 'n03866082', 'n03868242', 'n03868863', 'n03871628', 'n03873416', 'n03874293', 'n03874599', 'n03876231', 'n03877472', 'n03877845', 'n03884397', 'n03887697', 'n03888257', 'n03888605', 'n03891251', 'n03891332', 'n03895866', 'n03899768', 'n03902125', 'n03903868', 'n03908618', 'n03908714', 'n03916031', 'n03920288', 'n03924679', 'n03929660', 'n03929855', 'n03930313', 'n03930630', 'n03933933', 'n03935335', 'n03937543', 'n03938244', 'n03942813', 'n03944341', 'n03947888', 'n03950228', 'n03954731', 'n03956157', 'n03958227', 'n03961711', 'n03967562', 'n03970156', 'n03976467', 'n03976657', 'n03977966', 'n03980874', 'n03982430', 'n03983396', 'n03991062', 'n03992509', 'n03995372', 'n03998194', 'n04004767', 'n04005630', 'n04008634', 'n04009552', 'n04019541', 'n04023962', 'n04026417', 'n04033901', 'n04033995', 'n04037443', 'n04039381', 'n04040759', 'n04041544', 'n04044716', 'n04049303', 'n04065272', 'n04067472', 'n04069434', 'n04070727', 'n04074963', 'n04081281', 'n04086273', 'n04090263', 'n04099969', 'n04111531', 'n04116512', 'n04118538', 'n04118776', 'n04120489', 'n04125021', 'n04127249', 'n04131690', 'n04133789', 'n04136333', 'n04141076', 'n04141327', 'n04141975', 'n04146614', 'n04147183', 'n04149813', 'n04152593', 'n04153751', 'n04154565', 'n04162706', 'n04179913', 'n04192698', 'n04200800', 'n04201297', 'n04204238', 'n04204347', 'n04208210', 'n04209133', 'n04209239', 'n04228054', 'n04229816', 'n04235860', 'n04238763', 'n04239074', 
'n04243546', 'n04251144', 'n04252077', 'n04252225', 'n04254120', 'n04254680', 'n04254777', 'n04258138', 'n04259630', 'n04263257', 'n04264628', 'n04265275', 'n04266014', 'n04270147', 'n04273569', 'n04275548', 'n04277352', 'n04285008', 'n04286575', 'n04296562', 'n04310018', 'n04311004', 'n04311174', 'n04317175', 'n04325704', 'n04326547', 'n04328186', 'n04330267', 'n04332243', 'n04335435', 'n04336792', 'n04344873', 'n04346328', 'n04347754', 'n04350905', 'n04355338', 'n04355933', 'n04356056', 'n04357314', 'n04366367', 'n04367480', 'n04370456', 'n04371430', 'n04371774', 'n04372370', 'n04376876', 'n04380533', 'n04389033', 'n04392985', 'n04398044', 'n04399382', 'n04404412', 'n04409515', 'n04417672', 'n04418357', 'n04423845', 'n04428191', 'n04429376', 'n04435653', 'n04442312', 'n04443257', 'n04447861', 'n04456115', 'n04458633', 'n04461696', 'n04462240', 'n04465501', 'n04467665', 'n04476259', 'n04479046', 'n04482393', 'n04483307', 'n04485082', 'n04486054', 'n04487081', 'n04487394', 'n04493381', 'n04501370', 'n04505470', 'n04507155', 'n04509417', 'n04515003', 'n04517823', 'n04522168', 'n04523525', 'n04525038', 'n04525305', 'n04532106', 'n04532670', 'n04536866', 'n04540053', 'n04542943', 'n04548280', 'n04548362', 'n04550184', 'n04552348', 'n04553703', 'n04554684', 'n04557648', 'n04560804', 'n04562935', 'n04579145', 'n04579432', 'n04584207', 'n04589890', 'n04590129', 'n04591157', 'n04591713', 'n04592741', 'n04596742', 'n04597913', 'n04599235', 'n04604644', 'n04606251', 'n04612504', 'n04613696', 'n06359193', 'n06596364', 'n06785654', 'n06794110', 'n06874185', 'n07248320', 'n07565083', 'n07579787', 'n07583066', 'n07584110', 'n07590611', 'n07613480', 'n07614500', 'n07615774', 'n07684084', 'n07693725', 'n07695742', 'n07697313', 'n07697537', 'n07711569', 'n07714571', 'n07714990', 'n07715103', 'n07716358', 'n07716906', 'n07717410', 'n07717556', 'n07718472', 'n07718747', 'n07720875', 'n07730033', 'n07734744', 'n07742313', 'n07745940', 'n07747607', 'n07749582', 'n07753113', 
'n07753275', 'n07753592', 'n07754684', 'n07760859', 'n07768694', 'n07802026', 'n07831146', 'n07836838', 'n07860988', 'n07871810', 'n07873807', 'n07875152', 'n07880968', 'n07892512', 'n07920052', 'n07930864', 'n07932039', 'n09193705', 'n09229709', 'n09246464', 'n09256479', 'n09288635', 'n09332890', 'n09399592', 'n09421951', 'n09428293', 'n09468604', 'n09472597', 'n09835506', 'n10148035', 'n10565667', 'n11879895', 'n11939491', 'n12057211', 'n12144580', 'n12267677', 'n12620546', 'n12768682', 'n12985857', 'n12998815', 'n13037406', 'n13040303', 'n13044778', 'n13052670', 'n13054560', 'n13133613', 'n15075141'] IMAGENETS_919_FOLDER_NAMES = ['n01440764', 'n01443537', 'n01484850', 'n01491361', 'n01494475', 'n01496331', 'n01498041', 'n01514668', 'n01514859', 'n01518878', 'n01530575', 'n01531178', 'n01532829', 'n01534433', 'n01537544', 'n01558993', 'n01560419', 'n01580077', 'n01582220', 'n01592084', 'n01601694', 'n01608432', 'n01614925', 'n01616318', 'n01622779', 'n01629819', 'n01630670', 'n01631663', 'n01632458', 'n01632777', 'n01641577', 'n01644373', 'n01644900', 'n01664065', 'n01665541', 'n01667114', 'n01667778', 'n01669191', 'n01675722', 'n01677366', 'n01682714', 'n01685808', 'n01687978', 'n01688243', 'n01689811', 'n01692333', 'n01693334', 'n01694178', 'n01695060', 'n01697457', 'n01698640', 'n01704323', 'n01728572', 'n01728920', 'n01729322', 'n01729977', 'n01734418', 'n01735189', 'n01737021', 'n01739381', 'n01740131', 'n01742172', 'n01744401', 'n01748264', 'n01749939', 'n01751748', 'n01753488', 'n01755581', 'n01756291', 'n01768244', 'n01770081', 'n01770393', 'n01773157', 'n01773549', 'n01773797', 'n01774384', 'n01774750', 'n01775062', 'n01776313', 'n01784675', 'n01795545', 'n01796340', 'n01797886', 'n01798484', 'n01806143', 'n01806567', 'n01807496', 'n01817953', 'n01818515', 'n01819313', 'n01820546', 'n01824575', 'n01828970', 'n01829413', 'n01833805', 'n01843065', 'n01843383', 'n01847000', 'n01855032', 'n01855672', 'n01860187', 'n01871265', 'n01872401', 'n01873310', 
'n01877812', 'n01882714', 'n01883070', 'n01910747', 'n01914609', 'n01917289', 'n01924916', 'n01930112', 'n01943899', 'n01944390', 'n01945685', 'n01950731', 'n01955084', 'n01968897', 'n01978287', 'n01978455', 'n01980166', 'n01981276', 'n01983481', 'n01984695', 'n01985128', 'n01986214', 'n01990800', 'n02002556', 'n02002724', 'n02006656', 'n02007558', 'n02009229', 'n02009912', 'n02011460', 'n02012849', 'n02013706', 'n02017213', 'n02018207', 'n02018795', 'n02025239', 'n02027492', 'n02028035', 'n02033041', 'n02037110', 'n02051845', 'n02056570', 'n02058221', 'n02066245', 'n02071294', 'n02074367', 'n02077923', 'n02085620', 'n02085782', 'n02085936', 'n02086079', 'n02086240', 'n02086646', 'n02086910', 'n02087046', 'n02087394', 'n02088094', 'n02088238', 'n02088364', 'n02088466', 'n02088632', 'n02089078', 'n02089867', 'n02089973', 'n02090379', 'n02090622', 'n02090721', 'n02091032', 'n02091134', 'n02091244', 'n02091467', 'n02091635', 'n02091831', 'n02092002', 'n02092339', 'n02093256', 'n02093428', 'n02093647', 'n02093754', 'n02093859', 'n02093991', 'n02094114', 'n02094258', 'n02094433', 'n02095314', 'n02095570', 'n02095889', 'n02096051', 'n02096177', 'n02096294', 'n02096437', 'n02096585', 'n02097047', 'n02097130', 'n02097209', 'n02097298', 'n02097474', 'n02097658', 'n02098105', 'n02098286', 'n02098413', 'n02099267', 'n02099429', 'n02099601', 'n02099712', 'n02099849', 'n02100236', 'n02100583', 'n02100735', 'n02100877', 'n02101006', 'n02101388', 'n02101556', 'n02102040', 'n02102177', 'n02102318', 'n02102480', 'n02102973', 'n02104029', 'n02104365', 'n02105056', 'n02105162', 'n02105251', 'n02105412', 'n02105505', 'n02105641', 'n02105855', 'n02106030', 'n02106166', 'n02106382', 'n02106550', 'n02106662', 'n02107142', 'n02107312', 'n02107574', 'n02107683', 'n02107908', 'n02108000', 'n02108089', 'n02108422', 'n02108551', 'n02108915', 'n02109047', 'n02109525', 'n02109961', 'n02110063', 'n02110185', 'n02110341', 'n02110627', 'n02110806', 'n02110958', 'n02111129', 'n02111277', 
'n02111500', 'n02111889', 'n02112018', 'n02112137', 'n02112350', 'n02112706', 'n02113023', 'n02113186', 'n02113624', 'n02113712', 'n02113799', 'n02113978', 'n02114367', 'n02114548', 'n02114712', 'n02114855', 'n02115641', 'n02115913', 'n02116738', 'n02117135', 'n02119022', 'n02119789', 'n02120079', 'n02120505', 'n02123045', 'n02123159', 'n02123394', 'n02123597', 'n02124075', 'n02125311', 'n02127052', 'n02128385', 'n02128757', 'n02128925', 'n02129165', 'n02129604', 'n02130308', 'n02132136', 'n02133161', 'n02134084', 'n02134418', 'n02137549', 'n02138441', 'n02165105', 'n02165456', 'n02167151', 'n02168699', 'n02169497', 'n02172182', 'n02174001', 'n02177972', 'n02190166', 'n02206856', 'n02219486', 'n02226429', 'n02229544', 'n02231487', 'n02233338', 'n02236044', 'n02256656', 'n02259212', 'n02264363', 'n02268443', 'n02268853', 'n02276258', 'n02277742', 'n02279972', 'n02280649', 'n02281406', 'n02281787', 'n02317335', 'n02319095', 'n02321529', 'n02325366', 'n02326432', 'n02328150', 'n02342885', 'n02346627', 'n02356798', 'n02361337', 'n02363005', 'n02364673', 'n02389026', 'n02391049', 'n02395406', 'n02396427', 'n02397096', 'n02398521', 'n02403003', 'n02408429', 'n02410509', 'n02412080', 'n02415577', 'n02417914', 'n02422106', 'n02422699', 'n02423022', 'n02437312', 'n02437616', 'n02441942', 'n02442845', 'n02443114', 'n02443484', 'n02444819', 'n02445715', 'n02447366', 'n02454379', 'n02457408', 'n02480495', 'n02480855', 'n02481823', 'n02483362', 'n02483708', 'n02484975', 'n02486261', 'n02486410', 'n02487347', 'n02488291', 'n02488702', 'n02489166', 'n02490219', 'n02492035', 'n02492660', 'n02493509', 'n02493793', 'n02494079', 'n02497673', 'n02500267', 'n02504013', 'n02504458', 'n02509815', 'n02510455', 'n02514041', 'n02526121', 'n02536864', 'n02606052', 'n02607072', 'n02640242', 'n02641379', 'n02643566', 'n02655020', 'n02666196', 'n02667093', 'n02669723', 'n02672831', 'n02676566', 'n02687172', 'n02690373', 'n02692877', 'n02701002', 'n02704792', 'n02708093', 'n02727426', 
'n02730930', 'n02747177', 'n02749479', 'n02769748', 'n02782093', 'n02783161', 'n02786058', 'n02787622', 'n02790996', 'n02791124', 'n02793495', 'n02794156', 'n02795169', 'n02797295', 'n02799071', 'n02802426', 'n02804414', 'n02804610', 'n02808304', 'n02808440', 'n02814533', 'n02814860', 'n02815834', 'n02817516', 'n02823428', 'n02823750', 'n02834397', 'n02835271', 'n02840245', 'n02841315', 'n02843684', 'n02859443', 'n02860847', 'n02865351', 'n02869837', 'n02870880', 'n02879718', 'n02883205', 'n02892201', 'n02892767', 'n02895154', 'n02906734', 'n02909870', 'n02910353', 'n02916936', 'n02917067', 'n02930766', 'n02939185', 'n02948072', 'n02950826', 'n02951358', 'n02951585', 'n02963159', 'n02965783', 'n02966687', 'n02971356', 'n02978881', 'n02979186', 'n02980441', 'n02981792', 'n02992211', 'n02992529', 'n02999410', 'n03000134', 'n03000684', 'n03014705', 'n03016953', 'n03017168', 'n03018349', 'n03026506', 'n03028079', 'n03041632', 'n03045698', 'n03047690', 'n03062245', 'n03063599', 'n03063689', 'n03075370', 'n03095699', 'n03100240', 'n03109150', 'n03110669', 'n03124043', 'n03124170', 'n03125729', 'n03126707', 'n03127747', 'n03127925', 'n03131574', 'n03133878', 'n03134739', 'n03141823', 'n03146219', 'n03179701', 'n03187595', 'n03188531', 'n03196217', 'n03197337', 'n03201208', 'n03207743', 'n03207941', 'n03223299', 'n03240683', 'n03249569', 'n03250847', 'n03255030', 'n03259280', 'n03271574', 'n03272010', 'n03272562', 'n03291819', 'n03297495', 'n03314780', 'n03325584', 'n03337140', 'n03344393', 'n03345487', 'n03347037', 'n03355925', 'n03372029', 'n03376595', 'n03379051', 'n03384352', 'n03388183', 'n03388549', 'n03393912', 'n03394916', 'n03400231', 'n03404251', 'n03417042', 'n03424325', 'n03425413', 'n03443371', 'n03444034', 'n03445777', 'n03445924', 'n03447447', 'n03447721', 'n03450230', 'n03452741', 'n03467068', 'n03476684', 'n03476991', 'n03478589', 'n03481172', 'n03482405', 'n03483316', 'n03485407', 'n03485794', 'n03492542', 'n03494278', 'n03495258', 'n03496892', 
'n03498962', 'n03527444', 'n03530642', 'n03532672', 'n03534580', 'n03535780', 'n03538406', 'n03544143', 'n03584254', 'n03584829', 'n03590841', 'n03594734', 'n03594945', 'n03595614', 'n03598930', 'n03599486', 'n03602883', 'n03617480', 'n03623198', 'n03627232', 'n03630383', 'n03633091', 'n03649909', 'n03657121', 'n03658185', 'n03662601', 'n03666591', 'n03670208', 'n03673027', 'n03676483', 'n03680355', 'n03690938', 'n03691459', 'n03692522', 'n03706229', 'n03709823', 'n03710193', 'n03717622', 'n03720891', 'n03721384', 'n03724870', 'n03729826', 'n03733131', 'n03733805', 'n03742115', 'n03743016', 'n03759954', 'n03761084', 'n03763968', 'n03764736', 'n03769881', 'n03770439', 'n03770679', 'n03773504', 'n03775071', 'n03775546', 'n03776460', 'n03777568', 'n03777754', 'n03781244', 'n03782006', 'n03785016', 'n03786901', 'n03787032', 'n03788195', 'n03788365', 'n03791053', 'n03792782', 'n03792972', 'n03793489', 'n03794056', 'n03796401', 'n03803284', 'n03804744', 'n03814639', 'n03814906', 'n03825788', 'n03832673', 'n03837869', 'n03838899', 'n03840681', 'n03841143', 'n03843555', 'n03854065', 'n03857828', 'n03866082', 'n03868242', 'n03868863', 'n03873416', 'n03874293', 'n03874599', 'n03876231', 'n03877472', 'n03877845', 'n03884397', 'n03888257', 'n03888605', 'n03891251', 'n03891332', 'n03895866', 'n03902125', 'n03903868', 'n03908618', 'n03908714', 'n03916031', 'n03920288', 'n03924679', 'n03929660', 'n03929855', 'n03930313', 'n03930630', 'n03933933', 'n03935335', 'n03937543', 'n03938244', 'n03942813', 'n03944341', 'n03947888', 'n03950228', 'n03954731', 'n03956157', 'n03958227', 'n03961711', 'n03970156', 'n03976467', 'n03976657', 'n03977966', 'n03980874', 'n03982430', 'n03983396', 'n03991062', 'n03992509', 'n03995372', 'n03998194', 'n04004767', 'n04005630', 'n04009552', 'n04019541', 'n04023962', 'n04026417', 'n04033901', 'n04033995', 'n04037443', 'n04039381', 'n04040759', 'n04041544', 'n04044716', 'n04049303', 'n04065272', 'n04067472', 'n04069434', 'n04070727', 'n04074963', 
'n04086273', 'n04090263', 'n04099969', 'n04116512', 'n04118538', 'n04118776', 'n04120489', 'n04125021', 'n04127249', 'n04131690', 'n04133789', 'n04136333', 'n04141076', 'n04141327', 'n04141975', 'n04146614', 'n04147183', 'n04149813', 'n04153751', 'n04154565', 'n04162706', 'n04179913', 'n04192698', 'n04201297', 'n04204238', 'n04204347', 'n04208210', 'n04209133', 'n04209239', 'n04228054', 'n04229816', 'n04235860', 'n04238763', 'n04243546', 'n04252077', 'n04252225', 'n04254120', 'n04254680', 'n04254777', 'n04258138', 'n04259630', 'n04263257', 'n04265275', 'n04266014', 'n04270147', 'n04273569', 'n04275548', 'n04277352', 'n04285008', 'n04286575', 'n04310018', 'n04311004', 'n04311174', 'n04317175', 'n04325704', 'n04326547', 'n04328186', 'n04330267', 'n04332243', 'n04335435', 'n04336792', 'n04344873', 'n04346328', 'n04347754', 'n04350905', 'n04355338', 'n04355933', 'n04366367', 'n04367480', 'n04370456', 'n04371430', 'n04371774', 'n04376876', 'n04380533', 'n04389033', 'n04398044', 'n04399382', 'n04404412', 'n04409515', 'n04417672', 'n04418357', 'n04423845', 'n04428191', 'n04429376', 'n04435653', 'n04442312', 'n04447861', 'n04456115', 'n04458633', 'n04461696', 'n04465501', 'n04467665', 'n04476259', 'n04479046', 'n04482393', 'n04483307', 'n04485082', 'n04486054', 'n04487081', 'n04487394', 'n04505470', 'n04507155', 'n04509417', 'n04515003', 'n04517823', 'n04522168', 'n04525038', 'n04525305', 'n04532106', 'n04532670', 'n04536866', 'n04540053', 'n04542943', 'n04548280', 'n04548362', 'n04550184', 'n04552348', 'n04553703', 'n04554684', 'n04557648', 'n04560804', 'n04562935', 'n04579145', 'n04579432', 'n04584207', 'n04589890', 'n04590129', 'n04591157', 'n04591713', 'n04596742', 'n04597913', 'n04604644', 'n04606251', 'n04612504', 'n04613696', 'n06596364', 'n06794110', 'n06874185', 'n07565083', 'n07579787', 'n07583066', 'n07584110', 'n07613480', 'n07614500', 'n07615774', 'n07684084', 'n07693725', 'n07695742', 'n07697313', 'n07697537', 'n07711569', 'n07714571', 'n07714990', 
'n07715103', 'n07716358', 'n07716906', 'n07717410', 'n07717556', 'n07718472', 'n07718747', 'n07720875', 'n07730033', 'n07734744', 'n07742313', 'n07745940', 'n07747607', 'n07749582', 'n07753113', 'n07753275', 'n07753592', 'n07754684', 'n07760859', 'n07768694', 'n07802026', 'n07831146', 'n07860988', 'n07871810', 'n07873807', 'n07875152', 'n07880968', 'n07930864', 'n07932039', 'n09229709', 'n09246464', 'n09256479', 'n09835506', 'n10565667', 'n11879895', 'n11939491', 'n12057211', 'n12144580', 'n12267677', 'n12620546', 'n12768682', 'n12985857', 'n12998815', 'n13037406', 'n13040303', 'n13044778', 'n13052670', 'n13054560', 'n13133613', 'n15075141'] IMAGENETS_FOLDER_NAMES = IMAGENETS_919_FOLDER_NAMES IMAGENETS_300_FOLDER_NAMES = ['n01440764', 'n01443537', 'n01491361', 'n01494475', 'n01496331', 'n01518878', 'n01531178', 'n01532829', 'n01537544', 'n01608432', 'n01630670', 'n01632777', 'n01644373', 'n01644900', 'n01667114', 'n01675722', 'n01682714', 'n01685808', 'n01694178', 'n01695060', 'n01698640', 'n01704323', 'n01728572', 'n01728920', 'n01734418', 'n01744401', 'n01753488', 'n01770081', 'n01770393', 'n01773797', 'n01776313', 'n01817953', 'n01820546', 'n01855032', 'n01877812', 'n01882714', 'n01910747', 'n01914609', 'n01943899', 'n01980166', 'n01983481', 'n01984695', 'n01990800', 'n02011460', 'n02012849', 'n02013706', 'n02018795', 'n02058221', 'n02087046', 'n02088094', 'n02088632', 'n02090622', 'n02090721', 'n02091134', 'n02091244', 'n02093256', 'n02093754', 'n02094433', 'n02095570', 'n02097130', 'n02097209', 'n02097298', 'n02098413', 'n02100735', 'n02101556', 'n02102040', 'n02102177', 'n02104029', 'n02105412', 'n02107142', 'n02107312', 'n02110063', 'n02110958', 'n02111129', 'n02111500', 'n02111889', 'n02112706', 'n02113186', 'n02114855', 'n02119022', 'n02119789', 'n02120505', 'n02123394', 'n02123597', 'n02125311', 'n02127052', 'n02129604', 'n02133161', 'n02134418', 'n02165456', 'n02169497', 'n02177972', 'n02206856', 'n02256656', 'n02259212', 'n02268853', 'n02277742', 
'n02280649', 'n02281406', 'n02321529', 'n02325366', 'n02326432', 'n02342885', 'n02396427', 'n02398521', 'n02415577', 'n02417914', 'n02447366', 'n02457408', 'n02480495', 'n02483362', 'n02488702', 'n02493793', 'n02494079', 'n02497673', 'n02504013', 'n02504458', 'n02510455', 'n02514041', 'n02526121', 'n02536864', 'n02669723', 'n02672831', 'n02690373', 'n02701002', 'n02708093', 'n02747177', 'n02769748', 'n02782093', 'n02783161', 'n02790996', 'n02793495', 'n02804610', 'n02808304', 'n02814533', 'n02835271', 'n02841315', 'n02859443', 'n02869837', 'n02870880', 'n02879718', 'n02892201', 'n02895154', 'n02917067', 'n02950826', 'n02951585', 'n02966687', 'n02978881', 'n02992529', 'n03000684', 'n03014705', 'n03018349', 'n03047690', 'n03075370', 'n03095699', 'n03109150', 'n03127925', 'n03133878', 'n03197337', 'n03201208', 'n03207941', 'n03223299', 'n03259280', 'n03271574', 'n03272562', 'n03291819', 'n03337140', 'n03376595', 'n03379051', 'n03393912', 'n03394916', 'n03404251', 'n03417042', 'n03443371', 'n03445777', 'n03452741', 'n03478589', 'n03482405', 'n03492542', 'n03494278', 'n03496892', 'n03532672', 'n03535780', 'n03538406', 'n03584829', 'n03590841', 'n03630383', 'n03633091', 'n03658185', 'n03673027', 'n03710193', 'n03743016', 'n03763968', 'n03764736', 'n03775546', 'n03781244', 'n03786901', 'n03788365', 'n03791053', 'n03792782', 'n03792972', 'n03794056', 'n03814906', 'n03825788', 'n03840681', 'n03874599', 'n03877845', 'n03888605', 'n03891251', 'n03903868', 'n03908714', 'n03929855', 'n03938244', 'n03956157', 'n03958227', 'n03976467', 'n03976657', 'n03991062', 'n04026417', 'n04033995', 'n04040759', 'n04041544', 'n04044716', 'n04049303', 'n04069434', 'n04070727', 'n04090263', 'n04099969', 'n04116512', 'n04118776', 'n04120489', 'n04179913', 'n04192698', 'n04201297', 'n04228054', 'n04229816', 'n04243546', 'n04254120', 'n04254680', 'n04254777', 'n04263257', 'n04265275', 'n04275548', 'n04277352', 'n04285008', 'n04311004', 'n04317175', 'n04335435', 'n04347754', 'n04371430', 
'n04376876', 'n04380533', 'n04389033', 'n04399382', 'n04404412', 'n04429376', 'n04435653', 'n04447861', 'n04479046', 'n04483307', 'n04505470', 'n04507155', 'n04522168', 'n04540053', 'n04550184', 'n04552348', 'n04554684', 'n04557648', 'n04562935', 'n04579145', 'n04584207', 'n04591713', 'n04596742', 'n04606251', 'n04612504', 'n04613696', 'n06794110', 'n06874185', 'n07584110', 'n07614500', 'n07693725', 'n07697313', 'n07697537', 'n07711569', 'n07716906', 'n07720875', 'n07730033', 'n07742313', 'n07745940', 'n07749582', 'n07831146', 'n07880968', 'n07930864', 'n09256479', 'n12057211', 'n12768682', 'n12998815', 'n13037406', 'n13044778', 'n13054560'] IMAGENETS_50_FOLDER_NAMES = ['n01443537', 'n01491361', 'n01531178', 'n01644373', 'n02104029', 'n02119022', 'n02123597', 'n02133161', 'n02165456', 'n02281406', 'n02325366', 'n02342885', 'n02396427', 'n02483362', 'n02504458', 'n02510455', 'n02690373', 'n02747177', 'n02783161', 'n02814533', 'n02859443', 'n02917067', 'n02992529', 'n03014705', 'n03047690', 'n03095699', 'n03197337', 'n03201208', 'n03445777', 'n03452741', 'n03584829', 'n03630383', 'n03775546', 'n03791053', 'n03874599', 'n03891251', 'n04026417', 'n04335435', 'n04380533', 'n04404412', 'n04447861', 'n04507155', 'n04522168', 'n04557648', 'n04562935', 'n04612504', 'n06794110', 'n07749582', 'n07831146', 'n12998815'] IMAGENET_DEFAULT_TEMPLATES = [ '{}.', 'a bad photo of a {}.', 'a photo of many {}.', 'a sculpture of a {}.', 'a photo of the hard to see {}.', 'a low resolution photo of the {}.', 'a rendering of a {}.', 'graffiti of a {}.', 'a bad photo of the {}.', 'a cropped photo of the {}.', 'a tattoo of a {}.', 'the embroidered {}.', 'a photo of a hard to see {}.', 'a bright photo of a {}.', 'a photo of a clean {}.', 'a photo of a dirty {}.', 'a dark photo of the {}.', 'a drawing of a {}.', 'a photo of my {}.', 'the plastic {}.', 'a photo of the cool {}.', 'a close-up photo of a {}.', 'a black and white photo of the {}.', 'a painting of the {}.', 'a painting of a {}.', 'a 
pixelated photo of the {}.', 'a sculpture of the {}.', 'a bright photo of the {}.', 'a cropped photo of a {}.', 'a plastic {}.', 'a photo of the dirty {}.', 'a jpeg corrupted photo of a {}.', 'a blurry photo of the {}.', 'a photo of the {}.', 'a good photo of the {}.', 'a rendering of the {}.', 'a {} in a video game.', 'a photo of one {}.', 'a doodle of a {}.', 'a close-up photo of the {}.', 'a photo of a {}.', 'the origami {}.', 'the {} in a video game.', 'a sketch of a {}.', 'a doodle of the {}.', 'a origami {}.', 'a low resolution photo of a {}.', 'the toy {}.', 'a rendition of the {}.', 'a photo of the clean {}.', 'a photo of a large {}.', 'a rendition of a {}.', 'a photo of a nice {}.', 'a photo of a weird {}.', 'a blurry photo of a {}.', 'a cartoon {}.', 'art of a {}.', 'a sketch of the {}.', 'a embroidered {}.', 'a pixelated photo of a {}.', 'itap of the {}.', 'a jpeg corrupted photo of the {}.', 'a good photo of a {}.', 'a plushie {}.', 'a photo of the nice {}.', 'a photo of the small {}.', 'a photo of the weird {}.', 'the cartoon {}.', 'art of the {}.', 'a drawing of the {}.', 'a photo of the large {}.', 'a black and white photo of a {}.', 'the plushie {}.', 'a dark photo of a {}.', 'itap of a {}.', 'graffiti of the {}.', 'a toy {}.', 'itap of my {}.', 'a photo of a cool {}.', 'a photo of a small {}.', 'a tattoo of the {}.', ] IMAGENET_SIMPLE_TEMPLATES = [ 'a photo of {}.', ] COCO_INSTANCE_CLASSES = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot 
dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'] COCO_PANOPTIC_CLASSES = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush', 'banner', 'blanket', 'bridge', 'cardboard', 'counter', 'curtain', 'door-stuff', 'floor-wood', 'flower', 'fruit', 'gravel', 'house', 'light', 'mirror-stuff', 'net', 'pillow', 'platform', 'playingfield', 'railroad', 'river', 'road', 'roof', 'sand', 'sea', 'shelf', 'snow', 'stairs', 'tent', 'towel', 'wall-brick', 'wall-stone', 'wall-tile', 'wall-wood', 'water-other', 'window-blind', 'window-other', 'tree-merged', 'fence-merged', 'ceiling-merged', 'sky-other-merged', 'cabinet-merged', 'table-merged', 'floor-other-merged', 'pavement-merged', 'mountain-merged', 'grass-merged', 'dirt-merged', 'paper-merged', 'food-other-merged', 'building-other-merged', 'rock-merged', 'wall-other-merged', 'rug-merged'] COCO_IMAGENET_INDEX_PAIR = [[0, 591], [1, 444], [2, 479], [3, 671], [4, 405], [5, 779], [6, 466], [7, 717], [8, 
472], [9, 920], [10, 686], [11, 920], [12, 704], [13, 708], [14, 134], [15, 282], [16, 215], [17, 291], [18, 349], [19, 341], [20, 366], [21, 294], [22, 340], [23, 286], [24, 414], [25, 879], [26, 636], [27, 906], [28, 519], [29, 162], [30, 795], [31, 671], [32, 852], [33, 405], [34, 805], [35, 560], [36, 671], [37, 671], [38, 752], [39, 907], [40, 907], [41, 787], [42, 792], [43, 792], [44, 792], [45, 538], [46, 954], [47, 948], [48, 697], [49, 950], [50, 937], [51, 936], [52, 934], [53, 963], [54, 931], [55, 415], [56, 423], [57, 831], [58, 738], [59, 520], [60, 532], [61, 896], [62, 851], [63, 620], [64, 674], [65, 761], [66, 508], [67, 487], [68, 447], [69, 827], [70, 859], [71, 896], [72, 760], [73, 454], [74, 892], [75, 883], [76, 591], [77, 850], [78, 584], [79, 696], [80, 879], [81, 672], [82, 839], [83, 519], [84, 896], [85, 854], [86, 799], [87, 858], [88, 947], [89, 953], [90, 792], [91, 425], [92, 862], [93, 475], [94, 916], [95, 721], [96, 538], [97, 703], [98, 705], [99, 979], [100, 888], [101, 538], [102, 774], [103, 390], [104, 519], [105, 803], [106, 708], [107, 672], [108, 879], [109, 825], [110, 825], [111, 858], [112, 825], [113, 898], [114, 904], [115, 904], [116, 947], [117, 862], [118, 858], [119, 862], [120, 648], [121, 736], [122, 904], [123, 708], [124, 970], [125, 936], [126, 792], [127, 549], [128, 712], [129, 647], [130, 972], [131, 904], [132, 824]] PASCAL_CLASSES = [ "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor" ] # PASCAL_CLASSES = [ # "airplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", # "chair", "cow", "dining table", "dog", "horse", "motorcycle", "person", # "potted plant", "sheep", "couch", "train", "tv" # ] PASCAL_LABELS = [ [0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0], [192, 0, 0], [64, 
128, 0], [192, 128, 0], [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128], ] # ADE_PANOPTIC_CLASSES = ['wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed', 'window ', 'grass', 'cabinet', 'sidewalk', 'person', 'ground', 'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car', 'water', 'picture', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug', 'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'closet', 'light', 'tub', 'rail', 'cushion', 'pedestal', 'box', 'column', 'signboard', 'dresser', 'counter', 'sand', 'sink', 'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path', 'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door', 'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table', 'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove', 'palm tree', 'kitchen island', 'computer', 'swivel chair', 'boat', 'pub', 'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower', 'ceiling light', 'awning', 'street light', 'booth', 'tv', 'airplane', 'dirt road', 'clothes', 'pole', 'land', 'bannister', 'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van', 'ship', 'fountain', 'transporter', 'canopy', 'washer', 'plaything', 'pool', 'stool', 'cylinder', 'basket', 'falls', 'tent', 'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'stair', 'storage tank', 'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake', 'dishwasher', 'screen', 'blanket', 'sculpture', 'exhaust hood', 'sconce', 'vase', 'traffic light', 'tray', 'trash can', 'fan', 'pier', 'crt screen', 'plate', 'tvmonitor', 'bulletin board', 'shower', 'heater', 'drinking glass', 'clock', 'flag'] # ADE_PANOPTIC_CLASSES = ['wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed', 'window ', 'grass', 'cabinet', 'sidewalk', 'person', 'earth', 'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car', 'water', 'painting', 'sofa', 
'shelf', 'house', 'sea', 'mirror', 'rug', 'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'closet', 'lamp', 'tub', 'rail', 'cushion', 'base', 'box', 'column', 'signboard', 'chest of drawers', 'counter', 'sand', 'sink', 'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path', 'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door', 'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table', 'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove', 'palm tree', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar', 'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower', 'chandelier', 'awning', 'street lamp', 'booth', 'tv', 'airplane', 'dirt road', 'clothes', 'pole', 'land', 'bannister', 'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van', 'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything', 'pool', 'stool', 'barrel', 'basket', 'falls', 'tent', 'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank', 'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake', 'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce', 'vase', 'traffic light', 'tray', 'trash can', 'fan', 'pier', 'crt screen', 'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass', 'clock', 'flag'] ADE_PANOPTIC_CLASSES = ['wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed', 'window', 'grass', 'cabinet', 'sidewalk', 'person', 'earth', 'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car', 'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug', 'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe', 'lamp', 'tub', 'rail', 'cushion', 'base', 'box', 'column', 'signboard', 'chest of drawers', 'counter', 'sand', 'sink', 'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path', 'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door', 'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table', 'toilet', 
'flower', 'book', 'hill', 'bench', 'countertop', 'stove', 'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar', 'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower', 'chandelier', 'awning', 'street lamp', 'booth', 'tv', 'plane', 'dirt track', 'clothes', 'pole', 'land', 'bannister', 'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van', 'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything', 'pool', 'stool', 'barrel', 'basket', 'falls', 'tent', 'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank', 'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake', 'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce', 'vase', 'traffic light', 'tray', 'trash can', 'fan', 'pier', 'crt screen', 'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass', 'clock', 'flag'] COCO_INTER_ADE = ['wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed', 'window', 'grass', 'cabinet', 'pavement', 'person', 'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car', 'water', 'shelf', 'house', 'sea', 'mirror', 'rug', 'fence', 'rock', 'stone', 'sign', 'counter', 'sand', 'sink', 'refrigerator', 'stairs', 'table', 'pillow', 'door', 'river', 'bridge', 'blind', 'table', 'toilet', 'flower', 'book', 'bench', 'tree', 'chair', 'boat', 'bus', 'towel', 'light', 'truck', 'tv', 'dirt', 'bottle', 'counter', 'tent', 'oven', 'ball', 'food', 'microwave', 'bicycle', 'blanket', 'vase', 'traffic', 'light', 'glass', 'clock'] PASCAL_CONTEXT_459 = ['accordion', 'aeroplane', 'air conditioner', 'antenna', 'artillery', 'ashtray', 'atrium', 'baby carriage', 'bag', 'ball', 'balloon', 'bamboo weaving', 'barrel', 'baseball bat', 'basket', 'basketball backboard', 'bathtub', 'bed', 'bedclothes', 'beer', 'bell', 'bench', 'bicycle', 'binoculars', 'bird', 'bird cage', 'bird feeder', 'bird nest', 'blackboard', 'board', 'boat', 'bone', 'book', 'bottle', 'bottle opener', 'bowl', 'box', 'bracelet', 'brick', 
'bridge', 'broom', 'brush', 'bucket', 'building', 'bus', 'cabinet', 'cabinet door', 'cage', 'cake', 'calculator', 'calendar', 'camel', 'camera', 'camera lens', 'can', 'candle', 'candle holder', 'cap', 'car', 'card', 'cart', 'case', 'casette recorder', 'cash register', 'cat', 'cd', 'cd player', 'ceiling', 'cell phone', 'cello', 'chain', 'chair', 'chessboard', 'chicken', 'chopstick', 'clip', 'clippers', 'clock', 'closet', 'cloth', 'clothes tree', 'coffee', 'coffee machine', 'comb', 'computer', 'concrete', 'cone', 'container', 'control booth', 'controller', 'cooker', 'copying machine', 'coral', 'cork', 'corkscrew', 'counter', 'court', 'cow', 'crabstick', 'crane', 'crate', 'cross', 'crutch', 'cup', 'curtain', 'cushion', 'cutting board', 'dais', 'disc', 'disc case', 'dishwasher', 'dock', 'dog', 'dolphin', 'door', 'drainer', 'dray', 'drink dispenser', 'drinking machine', 'drop', 'drug', 'drum', 'drum kit', 'duck', 'dumbbell', 'earphone', 'earrings', 'egg', 'electric fan', 'electric iron', 'electric pot', 'electric saw', 'electronic keyboard', 'engine', 'envelope', 'equipment', 'escalator', 'exhibition booth', 'extinguisher', 'eyeglass', 'fan', 'faucet', 'fax machine', 'fence', 'ferris wheel', 'fire extinguisher', 'fire hydrant', 'fire place', 'fish', 'fish tank', 'fishbowl', 'fishing net', 'fishing pole', 'flag', 'flagstaff', 'flame', 'flashlight', 'floor', 'flower', 'fly', 'foam', 'food', 'footbridge', 'forceps', 'fork', 'forklift', 'fountain', 'fox', 'frame', 'fridge', 'frog', 'fruit', 'funnel', 'furnace', 'game controller', 'game machine', 'gas cylinder', 'gas hood', 'gas stove', 'gift box', 'glass', 'glass marble', 'globe', 'glove', 'goal', 'grandstand', 'grass', 'gravestone', 'ground', 'guardrail', 'guitar', 'gun', 'hammer', 'hand cart', 'handle', 'handrail', 'hanger', 'hard disk drive', 'hat', 'hay', 'headphone', 'heater', 'helicopter', 'helmet', 'holder', 'hook', 'horse', 'horse-drawn carriage', 'hot-air balloon', 'hydrovalve', 'ice', 'inflator pump', 'ipod', 
'iron', 'ironing board', 'jar', 'kart', 'kettle', 'key', 'keyboard', 'kitchen range', 'kite', 'knife', 'knife block', 'ladder', 'ladder truck', 'ladle', 'laptop', 'leaves', 'lid', 'life buoy', 'light', 'light bulb', 'lighter', 'line', 'lion', 'lobster', 'lock', 'machine', 'mailbox', 'mannequin', 'map', 'mask', 'mat', 'match book', 'mattress', 'menu', 'metal', 'meter box', 'microphone', 'microwave', 'mirror', 'missile', 'model', 'money', 'monkey', 'mop', 'motorbike', 'mountain', 'mouse', 'mouse pad', 'musical instrument', 'napkin', 'net', 'newspaper', 'oar', 'ornament', 'outlet', 'oven', 'oxygen bottle', 'pack', 'pan', 'paper', 'paper box', 'paper cutter', 'parachute', 'parasol', 'parterre', 'patio', 'pelage', 'pen', 'pen container', 'pencil', 'person', 'photo', 'piano', 'picture', 'pig', 'pillar', 'pillow', 'pipe', 'pitcher', 'plant', 'plastic', 'plate', 'platform', 'player', 'playground', 'pliers', 'plume', 'poker', 'poker chip', 'pole', 'pool table', 'postcard', 'poster', 'pot', 'pottedplant', 'printer', 'projector', 'pumpkin', 'rabbit', 'racket', 'radiator', 'radio', 'rail', 'rake', 'ramp', 'range hood', 'receiver', 'recorder', 'recreational machines', 'remote control', 'road', 'robot', 'rock', 'rocket', 'rocking horse', 'rope', 'rug', 'ruler', 'runway', 'saddle', 'sand', 'saw', 'scale', 'scanner', 'scissors', 'scoop', 'screen', 'screwdriver', 'sculpture', 'scythe', 'sewer', 'sewing machine', 'shed', 'sheep', 'shell', 'shelves', 'shoe', 'shopping cart', 'shovel', 'sidecar', 'sidewalk', 'sign', 'signal light', 'sink', 'skateboard', 'ski', 'sky', 'sled', 'slippers', 'smoke', 'snail', 'snake', 'snow', 'snowmobiles', 'sofa', 'spanner', 'spatula', 'speaker', 'speed bump', 'spice container', 'spoon', 'sprayer', 'squirrel', 'stage', 'stair', 'stapler', 'stick', 'sticky note', 'stone', 'stool', 'stove', 'straw', 'stretcher', 'sun', 'sunglass', 'sunshade', 'surveillance camera', 'swan', 'sweeper', 'swim ring', 'swimming pool', 'swing', 'switch', 'table', 'tableware', 
'tank', 'tap', 'tape', 'tarp', 'telephone', 'telephone booth', 'tent', 'tire', 'toaster', 'toilet', 'tong', 'tool', 'toothbrush', 'towel', 'toy', 'toy car', 'track', 'train', 'trampoline', 'trash bin', 'tray', 'tree', 'tricycle', 'tripod', 'trophy', 'truck', 'tube', 'turtle', 'tvmonitor', 'tweezers', 'typewriter', 'umbrella', 'unknown', 'vacuum cleaner', 'vending machine', 'video camera', 'video game console', 'video player', 'video tape', 'violin', 'wakeboard', 'wall', 'wallet', 'wardrobe', 'washing machine', 'watch', 'water', 'water dispenser', 'water pipe', 'water skate board', 'watermelon', 'whale', 'wharf', 'wheel', 'wheelchair', 'window', 'window blinds', 'wineglass', 'wire', 'wood', 'wool'] PASCAL_CONTEXT_33 = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor', 'sky', 'grass', 'ground', 'road', 'building', 'tree', 'water', 'mountain', 'wall', 'floor', 'track', 'keyboard', 'ceiling'] PASCAL_CONTEXT_59 = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'dining table', 'dog', 'horse', 'motorbike', 'person', 'potted plant', 'sheep', 'sofa', 'train', 'tv monitor', 'bag', 'bed', 'bench', 'book', 'building', 'cabinet', 'ceiling', 'cloth', 'computer', 'cup', 'door', 'fence', 'floor', 'flower', 'food', 'grass', 'ground', 'keyboard', 'light', 'mountain', 'mouse', 'curtain', 'platform', 'sign', 'plate', 'road', 'rock', 'shelves', 'side walk', 'sky', 'snow', 'bed clothes', 'track', 'tree', 'truck', 'wall', 'water', 'window', 'wood'] SUN_RGBD_37 = ['wall', 'floor', 'cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window', 'bookshelf', 'picture', 'counter', 'blinds', 'desk', 'shelves', 'curtain', 'dresser', 'pillow', 'mirror', 'floor mat', 'clothes', 'ceiling', 'books', 'refridgerator', 'television', 'paper', 'towel', 'shower curtain', 'box', 'whiteboard', 'person', 'night stand', 
'toilet', 'sink', 'lamp', 'bathtub', 'bag'] SCAN_37 = ['wall', 'floor', 'cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window', 'bookshelf', 'picture', 'counter', 'blinds', 'desk', 'shelves', 'curtain', 'dresser', 'pillow', 'mirror', 'floor mat', 'clothes', 'ceiling', 'books', 'refridgerator', 'television', 'paper', 'towel', 'shower curtain', 'box', 'whiteboard', 'person', 'night stand', 'toilet', 'sink', 'lamp', 'bathtub', 'bag'] SCAN_40 = ['wall', 'floor', 'cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window', 'bookshelf', 'picture', 'counter', 'blinds', 'desk', 'shelves', 'curtain', 'dresser', 'pillow', 'mirror', 'floor mat', 'clothes', 'ceiling', 'books', 'refridgerator', 'television', 'paper', 'towel', 'shower curtain', 'box', 'whiteboard', 'person', 'night stand', 'toilet', 'sink', 'lamp', 'bathtub', 'bag', 'otherstructure', 'otherfurniture', 'otherprop'] SCAN_20 = ["wall", "floor", "cabinet", "bed", "chair", "sofa", "table", "door", "window", "bookshelf", "picture", "counter", "desk", "curtain", "refrigerator", "shower curtain", "toilet", "sink", "bathtub", "otherfurniture"] CITYSCAPES = ['road', 'sidewalk', 'building', 'wall', 'fence', 'pole', 'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle', 'bicycle'] CITYSCAPES_THING = ["person", "rider", "car", "truck", "bus", "train", "motorcycle", "bicycle"] BDD_SEM = ["road", "sidewalk", "building", "wall", "fence", "pole", "traffic light", "traffic sign", "vegetation", "terrain", "sky", "person", "rider", "car", "truck", "bus", "train", "motorcycle", "bicycle"] BDD_PANO = ['dynamic', 'ego vehicle', 'ground', 'static', 'parking', 'rail track', 'road', 'sidewalk', 'bridge', 'building', 'fence', 'garage', 'guard rail', 'tunnel', 'wall', 'banner', 'billboard', 'lane divider', 'parking sign', 'pole', 'polegroup', 'street light', 'traffic cone', 'traffic device', 'traffic light', 'traffic sign', 'traffic sign frame', 'terrain', 
'vegetation', 'sky', 'person', 'rider', 'bicycle', 'bus', 'car', 'caravan', 'motorcycle', 'trailer', 'train', 'truck'] OBJECT365 = ['Person', 'Sneakers', 'Chair', 'Other Shoes', 'Hat', 'Car', 'Lamp', 'Glasses', 'Bottle', 'Desk', 'Cup', 'Street Lights', 'Cabinet/shelf', 'Handbag/Satchel', 'Bracelet', 'Plate', 'Picture/Frame', 'Helmet', 'Book', 'Gloves', 'Storage box', 'Boat', 'Leather Shoes', 'Flower', 'Bench', 'Potted Plant', 'Bowl/Basin', 'Flag', 'Pillow', 'Boots', 'Vase', 'Microphone', 'Necklace', 'Ring', 'SUV', 'Wine Glass', 'Belt', 'Moniter/TV', 'Backpack', 'Umbrella', 'Traffic Light', 'Speaker', 'Watch', 'Tie', 'Trash bin Can', 'Slippers', 'Bicycle', 'Stool', 'Barrel/bucket', 'Van', 'Couch', 'Sandals', 'Bakset', 'Drum', 'Pen/Pencil', 'Bus', 'Wild Bird', 'High Heels', 'Motorcycle', 'Guitar', 'Carpet', 'Cell Phone', 'Bread', 'Camera', 'Canned', 'Truck', 'Traffic cone', 'Cymbal', 'Lifesaver', 'Towel', 'Stuffed Toy', 'Candle', 'Sailboat', 'Laptop', 'Awning', 'Bed', 'Faucet', 'Tent', 'Horse', 'Mirror', 'Power outlet', 'Sink', 'Apple', 'Air Conditioner', 'Knife', 'Hockey Stick', 'Paddle', 'Pickup Truck', 'Fork', 'Traffic Sign', 'Ballon', 'Tripod', 'Dog', 'Spoon', 'Clock', 'Pot', 'Cow', 'Cake', 'Dinning Table', 'Sheep', 'Hanger', 'Blackboard/Whiteboard', 'Napkin', 'Other Fish', 'Orange/Tangerine', 'Toiletry', 'Keyboard', 'Tomato', 'Lantern', 'Machinery Vehicle', 'Fan', 'Green Vegetables', 'Banana', 'Baseball Glove', 'Airplane', 'Mouse', 'Train', 'Pumpkin', 'Soccer', 'Skiboard', 'Luggage', 'Nightstand', 'Tea pot', 'Telephone', 'Trolley', 'Head Phone', 'Sports Car', 'Stop Sign', 'Dessert', 'Scooter', 'Stroller', 'Crane', 'Remote', 'Refrigerator', 'Oven', 'Lemon', 'Duck', 'Baseball Bat', 'Surveillance Camera', 'Cat', 'Jug', 'Broccoli', 'Piano', 'Pizza', 'Elephant', 'Skateboard', 'Surfboard', 'Gun', 'Skating and Skiing shoes', 'Gas stove', 'Donut', 'Bow Tie', 'Carrot', 'Toilet', 'Kite', 'Strawberry', 'Other Balls', 'Shovel', 'Pepper', 'Computer Box', 'Toilet Paper', 
'Cleaning Products', 'Chopsticks', 'Microwave', 'Pigeon', 'Baseball', 'Cutting/chopping Board', 'Coffee Table', 'Side Table', 'Scissors', 'Marker', 'Pie', 'Ladder', 'Snowboard', 'Cookies', 'Radiator', 'Fire Hydrant', 'Basketball', 'Zebra', 'Grape', 'Giraffe', 'Potato', 'Sausage', 'Tricycle', 'Violin', 'Egg', 'Fire Extinguisher', 'Candy', 'Fire Truck', 'Billards', 'Converter', 'Bathtub', 'Wheelchair', 'Golf Club', 'Briefcase', 'Cucumber', 'Cigar/Cigarette ', 'Paint Brush', 'Pear', 'Heavy Truck', 'Hamburger', 'Extractor', 'Extention Cord', 'Tong', 'Tennis Racket', 'Folder', 'American Football', 'earphone', 'Mask', 'Kettle', 'Tennis', 'Ship', 'Swing', 'Coffee Machine', 'Slide', 'Carriage', 'Onion', 'Green beans', 'Projector', 'Frisbee', 'Washing Machine/Drying Machine', 'Chicken', 'Printer', 'Watermelon', 'Saxophone', 'Tissue', 'Toothbrush', 'Ice cream', 'Hotair ballon', 'Cello', 'French Fries', 'Scale', 'Trophy', 'Cabbage', 'Hot dog', 'Blender', 'Peach', 'Rice', 'Wallet/Purse', 'Volleyball', 'Deer', 'Goose', 'Tape', 'Tablet', 'Cosmetics', 'Trumpet', 'Pineapple', 'Golf Ball', 'Ambulance', 'Parking meter', 'Mango', 'Key', 'Hurdle', 'Fishing Rod', 'Medal', 'Flute', 'Brush', 'Penguin', 'Megaphone', 'Corn', 'Lettuce', 'Garlic', 'Swan', 'Helicopter', 'Green Onion', 'Sandwich', 'Nuts', 'Speed Limit Sign', 'Induction Cooker', 'Broom', 'Trombone', 'Plum', 'Rickshaw', 'Goldfish', 'Kiwi fruit', 'Router/modem', 'Poker Card', 'Toaster', 'Shrimp', 'Sushi', 'Cheese', 'Notepaper', 'Cherry', 'Pliers', 'CD', 'Pasta', 'Hammer', 'Cue', 'Avocado', 'Hamimelon', 'Flask', 'Mushroon', 'Screwdriver', 'Soap', 'Recorder', 'Bear', 'Eggplant', 'Board Eraser', 'Coconut', 'Tape Measur/ Ruler', 'Pig', 'Showerhead', 'Globe', 'Chips', 'Steak', 'Crosswalk Sign', 'Stapler', 'Campel', 'Formula 1 ', 'Pomegranate', 'Dishwasher', 'Crab', 'Hoverboard', 'Meat ball', 'Rice Cooker', 'Tuba', 'Calculator', 'Papaya', 'Antelope', 'Parrot', 'Seal', 'Buttefly', 'Dumbbell', 'Donkey', 'Lion', 'Urinal', 'Dolphin', 
'Electric Drill', 'Hair Dryer', 'Egg tart', 'Jellyfish', 'Treadmill', 'Lighter', 'Grapefruit', 'Game board', 'Mop', 'Radish', 'Baozi', 'Target', 'French', 'Spring Rolls', 'Monkey', 'Rabbit', 'Pencil Case', 'Yak', 'Red Cabbage', 'Binoculars', 'Asparagus', 'Barbell', 'Scallop', 'Noddles', 'Comb', 'Dumpling', 'Oyster', 'Table Teniis paddle', 'Cosmetics Brush/Eyeliner Pencil', 'Chainsaw', 'Eraser', 'Lobster', 'Durian', 'Okra', 'Lipstick', 'Cosmetics Mirror', 'Curling', 'Table Tennis '] OPENIMAGE = ['Tortoise', 'Container', 'Magpie', 'Sea turtle', 'Football', 'Ambulance', 'Ladder', 'Toothbrush', 'Syringe', 'Sink', 'Toy', 'Organ (Musical Instrument)', 'Cassette deck', 'Apple', 'Human eye', 'Cosmetics', 'Paddle', 'Snowman', 'Beer', 'Chopsticks', 'Human beard', 'Bird', 'Parking meter', 'Traffic light', 'Croissant', 'Cucumber', 'Radish', 'Towel', 'Doll', 'Skull', 'Washing machine', 'Glove', 'Tick', 'Belt', 'Sunglasses', 'Banjo', 'Cart', 'Ball', 'Backpack', 'Bicycle', 'Home appliance', 'Centipede', 'Boat', 'Surfboard', 'Boot', 'Headphones', 'Hot dog', 'Shorts', 'Fast food', 'Bus', 'Boy', 'Screwdriver', 'Bicycle wheel', 'Barge', 'Laptop', 'Miniskirt', 'Drill (Tool)', 'Dress', 'Bear', 'Waffle', 'Pancake', 'Brown bear', 'Woodpecker', 'Blue jay', 'Pretzel', 'Bagel', 'Tower', 'Teapot', 'Person', 'Bow and arrow', 'Swimwear', 'Beehive', 'Brassiere', 'Bee', 'Bat (Animal)', 'Starfish', 'Popcorn', 'Burrito', 'Chainsaw', 'Balloon', 'Wrench', 'Tent', 'Vehicle registration plate', 'Lantern', 'Toaster', 'Flashlight', 'Billboard', 'Tiara', 'Limousine', 'Necklace', 'Carnivore', 'Scissors', 'Stairs', 'Computer keyboard', 'Printer', 'Traffic sign', 'Chair', 'Shirt', 'Poster', 'Cheese', 'Sock', 'Fire hydrant', 'Land vehicle', 'Earrings', 'Tie', 'Watercraft', 'Cabinetry', 'Suitcase', 'Muffin', 'Bidet', 'Snack', 'Snowmobile', 'Clock', 'Medical equipment', 'Cattle', 'Cello', 'Jet ski', 'Camel', 'Coat', 'Suit', 'Desk', 'Cat', 'Bronze sculpture', 'Juice', 'Gondola', 'Beetle', 'Cannon', 'Computer 
mouse', 'Cookie', 'Office building', 'Fountain', 'Coin', 'Calculator', 'Cocktail', 'Computer monitor', 'Box', 'Stapler', 'Christmas tree', 'Cowboy hat', 'Hiking equipment', 'Studio couch', 'Drum', 'Dessert', 'Wine rack', 'Drink', 'Zucchini', 'Ladle', 'Human mouth', 'Dairy Product', 'Dice', 'Oven', 'Dinosaur', 'Ratchet (Device)', 'Couch', 'Cricket ball', 'Winter melon', 'Spatula', 'Whiteboard', 'Pencil sharpener', 'Door', 'Hat', 'Shower', 'Eraser', 'Fedora', 'Guacamole', 'Dagger', 'Scarf', 'Dolphin', 'Sombrero', 'Tin can', 'Mug', 'Tap', 'Harbor seal', 'Stretcher', 'Can opener', 'Goggles', 'Human body', 'Roller skates', 'Coffee cup', 'Cutting board', 'Blender', 'Plumbing fixture', 'Stop sign', 'Office supplies', 'Volleyball (Ball)', 'Vase', 'Slow cooker', 'Wardrobe', 'Coffee', 'Whisk', 'Paper towel', 'Personal care', 'Food', 'Sun hat', 'Tree house', 'Flying disc', 'Skirt', 'Gas stove', 'Salt and pepper shakers', 'Mechanical fan', 'Face powder', 'Fax', 'Fruit', 'French fries', 'Nightstand', 'Barrel', 'Kite', 'Tart', 'Treadmill', 'Fox', 'Flag', 'French horn', 'Window blind', 'Human foot', 'Golf cart', 'Jacket', 'Egg (Food)', 'Street light', 'Guitar', 'Pillow', 'Human leg', 'Isopod', 'Grape', 'Human ear', 'Power plugs and sockets', 'Panda', 'Giraffe', 'Woman', 'Door handle', 'Rhinoceros', 'Bathtub', 'Goldfish', 'Houseplant', 'Goat', 'Baseball bat', 'Baseball glove', 'Mixing bowl', 'Marine invertebrates', 'Kitchen utensil', 'Light switch', 'House', 'Horse', 'Stationary bicycle', 'Hammer', 'Ceiling fan', 'Sofa bed', 'Adhesive tape', 'Harp', 'Sandal', 'Bicycle helmet', 'Saucer', 'Harpsichord', 'Human hair', 'Heater', 'Harmonica', 'Hamster', 'Curtain', 'Bed', 'Kettle', 'Fireplace', 'Scale', 'Drinking straw', 'Insect', 'Hair dryer', 'Kitchenware', 'Indoor rower', 'Invertebrate', 'Food processor', 'Bookcase', 'Refrigerator', 'Wood-burning stove', 'Punching bag', 'Common fig', 'Cocktail shaker', 'Jaguar (Animal)', 'Golf ball', 'Fashion accessory', 'Alarm clock', 'Filing 
cabinet', 'Artichoke', 'Table', 'Tableware', 'Kangaroo', 'Koala', 'Knife', 'Bottle', 'Bottle opener', 'Lynx', 'Lavender (Plant)', 'Lighthouse', 'Dumbbell', 'Human head', 'Bowl', 'Humidifier', 'Porch', 'Lizard', 'Billiard table', 'Mammal', 'Mouse', 'Motorcycle', 'Musical instrument', 'Swim cap', 'Frying pan', 'Snowplow', 'Bathroom cabinet', 'Missile', 'Bust', 'Man', 'Waffle iron', 'Milk', 'Ring binder', 'Plate', 'Mobile phone', 'Baked goods', 'Mushroom', 'Crutch', 'Pitcher (Container)', 'Mirror', 'Personal flotation device', 'Table tennis racket', 'Pencil case', 'Musical keyboard', 'Scoreboard', 'Briefcase', 'Kitchen knife', 'Nail (Construction)', 'Tennis ball', 'Plastic bag', 'Oboe', 'Chest of drawers', 'Ostrich', 'Piano', 'Girl', 'Plant', 'Potato', 'Hair spray', 'Sports equipment', 'Pasta', 'Penguin', 'Pumpkin', 'Pear', 'Infant bed', 'Polar bear', 'Mixer', 'Cupboard', 'Jacuzzi', 'Pizza', 'Digital clock', 'Pig', 'Reptile', 'Rifle', 'Lipstick', 'Skateboard', 'Raven', 'High heels', 'Red panda', 'Rose', 'Rabbit', 'Sculpture', 'Saxophone', 'Shotgun', 'Seafood', 'Submarine sandwich', 'Snowboard', 'Sword', 'Picture frame', 'Sushi', 'Loveseat', 'Ski', 'Squirrel', 'Tripod', 'Stethoscope', 'Submarine', 'Scorpion', 'Segway', 'Training bench', 'Snake', 'Coffee table', 'Skyscraper', 'Sheep', 'Television', 'Trombone', 'Tea', 'Tank', 'Taco', 'Telephone', 'Torch', 'Tiger', 'Strawberry', 'Trumpet', 'Tree', 'Tomato', 'Train', 'Tool', 'Picnic basket', 'Cooking spray', 'Trousers', 'Bowling equipment', 'Football helmet', 'Truck', 'Measuring cup', 'Coffeemaker', 'Violin', 'Vehicle', 'Handbag', 'Paper cutter', 'Wine', 'Weapon', 'Wheel', 'Worm', 'Wok', 'Whale', 'Zebra', 'Auto part', 'Jug', 'Pizza cutter', 'Cream', 'Monkey', 'Lion', 'Bread', 'Platter', 'Chicken', 'Eagle', 'Helicopter', 'Owl', 'Duck', 'Turtle', 'Hippopotamus', 'Crocodile', 'Toilet', 'Toilet paper', 'Squid', 'Clothing', 'Footwear', 'Lemon', 'Spider', 'Deer', 'Frog', 'Banana', 'Rocket', 'Wine glass', 'Countertop', 'Tablet 
computer', 'Waste container', 'Swimming pool', 'Dog', 'Book', 'Elephant', 'Shark', 'Candle', 'Leopard', 'Axe', 'Hand dryer', 'Soap dispenser', 'Porcupine', 'Flower', 'Canary', 'Cheetah', 'Palm tree', 'Hamburger', 'Maple', 'Building', 'Fish', 'Lobster', 'Garden Asparagus', 'Furniture', 'Hedgehog', 'Airplane', 'Spoon', 'Otter', 'Bull', 'Oyster', 'Horizontal bar', 'Convenience store', 'Bomb', 'Bench', 'Ice cream', 'Caterpillar', 'Butterfly', 'Parachute', 'Orange', 'Antelope', 'Beaker', 'Moths and butterflies', 'Window', 'Closet', 'Castle', 'Jellyfish', 'Goose', 'Mule', 'Swan', 'Peach', 'Coconut', 'Seat belt', 'Raccoon', 'Chisel', 'Fork', 'Lamp', 'Camera', 'Squash (Plant)', 'Racket', 'Human face', 'Human arm', 'Vegetable', 'Diaper', 'Unicycle', 'Falcon', 'Chime', 'Snail', 'Shellfish', 'Cabbage', 'Carrot', 'Mango', 'Jeans', 'Flowerpot', 'Pineapple', 'Drawer', 'Stool', 'Envelope', 'Cake', 'Dragonfly', 'Common sunflower', 'Microwave oven', 'Honeycomb', 'Marine mammal', 'Sea lion', 'Ladybug', 'Shelf', 'Watch', 'Candy', 'Salad', 'Parrot', 'Handgun', 'Sparrow', 'Van', 'Grinder', 'Spice rack', 'Light bulb', 'Corded phone', 'Sports uniform', 'Tennis racket', 'Wall clock', 'Serving tray', 'Kitchen & dining room table', 'Dog bed', 'Cake stand', 'Cat furniture', 'Bathroom accessory', 'Facial tissue holder', 'Pressure cooker', 'Kitchen appliance', 'Tire', 'Ruler', 'Luggage and bags', 'Microphone', 'Broccoli', 'Umbrella', 'Pastry', 'Grapefruit', 'Band-aid', 'Animal', 'Bell pepper', 'Turkey', 'Lily', 'Pomegranate', 'Doughnut', 'Glasses', 'Human nose', 'Pen', 'Ant', 'Car', 'Aircraft', 'Human hand', 'Skunk', 'Teddy bear', 'Watermelon', 'Cantaloupe', 'Dishwasher', 'Flute', 'Balance beam', 'Sandwich', 'Shrimp', 'Sewing machine', 'Binoculars', 'Rays and skates', 'Ipod', 'Accordion', 'Willow', 'Crab', 'Crown', 'Seahorse', 'Perfume', 'Alpaca', 'Taxi', 'Canoe', 'Remote control', 'Wheelchair', 'Rugby ball', 'Armadillo', 'Maracas', 'Helmet'] ADE20K_847 = ['wall', 'building', 'sky', 'tree', 
'road', 'floor', 'ceiling', 'bed', 'sidewalk', 'earth', 'cabinet', 'person', 'grass', 'windowpane', 'car', 'mountain', 'plant', 'table', 'chair', 'curtain', 'door', 'sofa', 'sea', 'painting', 'water', 'mirror', 'house', 'rug', 'shelf', 'armchair', 'fence', 'field', 'lamp', 'rock', 'seat', 'river', 'desk', 'bathtub', 'railing', 'signboard', 'cushion', 'path', 'work surface', 'stairs', 'column', 'sink', 'wardrobe', 'snow', 'refrigerator', 'base', 'bridge', 'blind', 'runway', 'cliff', 'sand', 'fireplace', 'pillow', 'screen door', 'toilet', 'skyscraper', 'grandstand', 'box', 'pool table', 'palm', 'double door', 'coffee table', 'counter', 'countertop', 'chest of drawers', 'kitchen island', 'boat', 'waterfall', 'stove', 'flower', 'bookcase', 'controls', 'book', 'stairway', 'streetlight', 'computer', 'bus', 'swivel chair', 'light', 'bench', 'case', 'towel', 'fountain', 'embankment', 'television receiver', 'van', 'hill', 'awning', 'poster', 'truck', 'airplane', 'pole', 'tower', 'court', 'ball', 'aircraft carrier', 'buffet', 'hovel', 'apparel', 'minibike', 'animal', 'chandelier', 'step', 'booth', 'bicycle', 'doorframe', 'sconce', 'pond', 'trade name', 'bannister', 'bag', 'traffic light', 'gazebo', 'escalator', 'land', 'board', 'arcade machine', 'eiderdown', 'bar', 'stall', 'playground', 'ship', 'ottoman', 'ashcan', 'bottle', 'cradle', 'pot', 'conveyer belt', 'train', 'stool', 'lake', 'tank', 'ice', 'basket', 'manhole', 'tent', 'canopy', 'microwave', 'barrel', 'dirt track', 'beam', 'dishwasher', 'plate', 'screen', 'ruins', 'washer', 'blanket', 'plaything', 'food', 'screen', 'oven', 'stage', 'beacon', 'umbrella', 'sculpture', 'aqueduct', 'container', 'scaffolding', 'hood', 'curb', 'roller coaster', 'horse', 'catwalk', 'glass', 'vase', 'central reservation', 'carousel', 'radiator', 'closet', 'machine', 'pier', 'fan', 'inflatable bounce game', 'pitch', 'paper', 'arcade', 'hot tub', 'helicopter', 'tray', 'partition', 'vineyard', 'bowl', 'bullring', 'flag', 'pot', 'footbridge', 
'shower', 'bag', 'bulletin board', 'confessional booth', 'trunk', 'forest', 'elevator door', 'laptop', 'instrument panel', 'bucket', 'tapestry', 'platform', 'jacket', 'gate', 'monitor', 'telephone booth', 'spotlight', 'ring', 'control panel', 'blackboard', 'air conditioner', 'chest', 'clock', 'sand dune', 'pipe', 'vault', 'table football', 'cannon', 'swimming pool', 'fluorescent', 'statue', 'loudspeaker', 'exhibitor', 'ladder', 'carport', 'dam', 'pulpit', 'skylight', 'water tower', 'grill', 'display board', 'pane', 'rubbish', 'ice rink', 'fruit', 'patio', 'vending machine', 'telephone', 'net', 'backpack', 'jar', 'track', 'magazine', 'shutter', 'roof', 'banner', 'landfill', 'post', 'altarpiece', 'hat', 'arch', 'table game', 'bag', 'document', 'dome', 'pier', 'shanties', 'forecourt', 'crane', 'dog', 'piano', 'drawing', 'cabin', 'ad', 'amphitheater', 'monument', 'henhouse', 'cockpit', 'heater', 'windmill', 'pool', 'elevator', 'decoration', 'labyrinth', 'text', 'printer', 'mezzanine', 'mattress', 'straw', 'stalls', 'patio', 'billboard', 'bus stop', 'trouser', 'console table', 'rack', 'notebook', 'shrine', 'pantry', 'cart', 'steam shovel', 'porch', 'postbox', 'figurine', 'recycling bin', 'folding screen', 'telescope', 'deck chair', 'kennel', 'coffee maker', 'altar', 'fish', 'easel', 'artificial golf green', 'iceberg', 'candlestick', 'shower stall', 'television stand', 'wall socket', 'skeleton', 'grand piano', 'candy', 'grille door', 'pedestal', 'jersey', 'shoe', 'gravestone', 'shanty', 'structure', 'rocking chair', 'bird', 'place mat', 'tomb', 'big top', 'gas pump', 'lockers', 'cage', 'finger', 'bleachers', 'ferris wheel', 'hairdresser chair', 'mat', 'stands', 'aquarium', 'streetcar', 'napkin', 'dummy', 'booklet', 'sand trap', 'shop', 'table cloth', 'service station', 'coffin', 'drawer', 'cages', 'slot machine', 'balcony', 'volleyball court', 'table tennis', 'control table', 'shirt', 'merchandise', 'railway', 'parterre', 'chimney', 'can', 'tanks', 'fabric', 'alga', 
'system', 'map', 'greenhouse', 'mug', 'barbecue', 'trailer', 'toilet tissue', 'organ', 'dishrag', 'island', 'keyboard', 'trench', 'basket', 'steering wheel', 'pitcher', 'goal', 'bread', 'beds', 'wood', 'file cabinet', 'newspaper', 'motorboat', 'rope', 'guitar', 'rubble', 'scarf', 'barrels', 'cap', 'leaves', 'control tower', 'dashboard', 'bandstand', 'lectern', 'switch', 'baseboard', 'shower room', 'smoke', 'faucet', 'bulldozer', 'saucepan', 'shops', 'meter', 'crevasse', 'gear', 'candelabrum', 'sofa bed', 'tunnel', 'pallet', 'wire', 'kettle', 'bidet', 'baby buggy', 'music stand', 'pipe', 'cup', 'parking meter', 'ice hockey rink', 'shelter', 'weeds', 'temple', 'patty', 'ski slope', 'panel', 'wallet', 'wheel', 'towel rack', 'roundabout', 'canister', 'rod', 'soap dispenser', 'bell', 'canvas', 'box office', 'teacup', 'trellis', 'workbench', 'valley', 'toaster', 'knife', 'podium', 'ramp', 'tumble dryer', 'fireplug', 'gym shoe', 'lab bench', 'equipment', 'rocky formation', 'plastic', 'calendar', 'caravan', 'check-in-desk', 'ticket counter', 'brush', 'mill', 'covered bridge', 'bowling alley', 'hanger', 'excavator', 'trestle', 'revolving door', 'blast furnace', 'scale', 'projector', 'soap', 'locker', 'tractor', 'stretcher', 'frame', 'grating', 'alembic', 'candle', 'barrier', 'cardboard', 'cave', 'puddle', 'tarp', 'price tag', 'watchtower', 'meters', 'light bulb', 'tracks', 'hair dryer', 'skirt', 'viaduct', 'paper towel', 'coat', 'sheet', 'fire extinguisher', 'water wheel', 'pottery', 'magazine rack', 'teapot', 'microphone', 'support', 'forklift', 'canyon', 'cash register', 'leaf', 'remote control', 'soap dish', 'windshield', 'cat', 'cue', 'vent', 'videos', 'shovel', 'eaves', 'antenna', 'shipyard', 'hen', 'traffic cone', 'washing machines', 'truck crane', 'cds', 'niche', 'scoreboard', 'briefcase', 'boot', 'sweater', 'hay', 'pack', 'bottle rack', 'glacier', 'pergola', 'building materials', 'television camera', 'first floor', 'rifle', 'tennis table', 'stadium', 'safety belt', 
'cover', 'dish rack', 'synthesizer', 'pumpkin', 'gutter', 'fruit stand', 'ice floe', 'handle', 'wheelchair', 'mousepad', 'diploma', 'fairground ride', 'radio', 'hotplate', 'junk', 'wheelbarrow', 'stream', 'toll plaza', 'punching bag', 'trough', 'throne', 'chair desk', 'weighbridge', 'extractor fan', 'hanging clothes', 'dish', 'alarm clock', 'ski lift', 'chain', 'garage', 'mechanical shovel', 'wine rack', 'tramway', 'treadmill', 'menu', 'block', 'well', 'witness stand', 'branch', 'duck', 'casserole', 'frying pan', 'desk organizer', 'mast', 'spectacles', 'service elevator', 'dollhouse', 'hammock', 'clothes hanging', 'photocopier', 'notepad', 'golf cart', 'footpath', 'cross', 'baptismal font', 'boiler', 'skip', 'rotisserie', 'tables', 'water mill', 'helmet', 'cover curtain', 'brick', 'table runner', 'ashtray', 'street box', 'stick', 'hangers', 'cells', 'urinal', 'centerpiece', 'portable fridge', 'dvds', 'golf club', 'skirting board', 'water cooler', 'clipboard', 'camera', 'pigeonhole', 'chips', 'food processor', 'post box', 'lid', 'drum', 'blender', 'cave entrance', 'dental chair', 'obelisk', 'canoe', 'mobile', 'monitors', 'pool ball', 'cue rack', 'baggage carts', 'shore', 'fork', 'paper filer', 'bicycle rack', 'coat rack', 'garland', 'sports bag', 'fish tank', 'towel dispenser', 'carriage', 'brochure', 'plaque', 'stringer', 'iron', 'spoon', 'flag pole', 'toilet brush', 'book stand', 'water faucet', 'ticket office', 'broom', 'dvd', 'ice bucket', 'carapace', 'tureen', 'folders', 'chess', 'root', 'sewing machine', 'model', 'pen', 'violin', 'sweatshirt', 'recycling materials', 'mitten', 'chopping board', 'mask', 'log', 'mouse', 'grill', 'hole', 'target', 'trash bag', 'chalk', 'sticks', 'balloon', 'score', 'hair spray', 'roll', 'runner', 'engine', 'inflatable glove', 'games', 'pallets', 'baskets', 'coop', 'dvd player', 'rocking horse', 'buckets', 'bread rolls', 'shawl', 'watering can', 'spotlights', 'post-it', 'bowls', 'security camera', 'runner cloth', 'lock', 'alarm', 
'side', 'roulette', 'bone', 'cutlery', 'pool balls', 'wheels', 'spice rack', 'plant pots', 'towel ring', 'bread box', 'video', 'funfair', 'breads', 'tripod', 'ironing board', 'skimmer', 'hollow', 'scratching post', 'tricycle', 'file box', 'mountain pass', 'tombstones', 'cooker', 'card game', 'golf bag', 'towel paper', 'chaise lounge', 'sun', 'toilet paper holder', 'rake', 'key', 'umbrella stand', 'dartboard', 'transformer', 'fireplace utensils', 'sweatshirts', 'cellular telephone', 'tallboy', 'stapler', 'sauna', 'test tube', 'palette', 'shopping carts', 'tools', 'push button', 'star', 'roof rack', 'barbed wire', 'spray', 'ear', 'sponge', 'racket', 'tins', 'eyeglasses', 'file', 'scarfs', 'sugar bowl', 'flip flop', 'headstones', 'laptop bag', 'leash', 'climbing frame', 'suit hanger', 'floor spotlight', 'plate rack', 'sewer', 'hard drive', 'sprinkler', 'tools box', 'necklace', 'bulbs', 'steel industry', 'club', 'jack', 'door bars', 'control panel', 'hairbrush', 'napkin holder', 'office', 'smoke detector', 'utensils', 'apron', 'scissors', 'terminal', 'grinder', 'entry phone', 'newspaper stand', 'pepper shaker', 'onions', 'central processing unit', 'tape', 'bat', 'coaster', 'calculator', 'potatoes', 'luggage rack', 'salt', 'street number', 'viewpoint', 'sword', 'cd', 'rowing machine', 'plug', 'andiron', 'pepper', 'tongs', 'bonfire', 'dog dish', 'belt', 'dumbbells', 'videocassette recorder', 'hook', 'envelopes', 'shower faucet', 'watch', 'padlock', 'swimming pool ladder', 'spanners', 'gravy boat', 'notice board', 'trash bags', 'fire alarm', 'ladle', 'stethoscope', 'rocket', 'funnel', 'bowling pins', 'valve', 'thermometer', 'cups', 'spice jar', 'night light', 'soaps', 'games table', 'slotted spoon', 'reel', 'scourer', 'sleeping robe', 'desk mat', 'dumbbell', 'hammer', 'tie', 'typewriter', 'shaker', 'cheese dish', 'sea star', 'racquet', 'butane gas cylinder', 'paper weight', 'shaving brush', 'sunglasses', 'gear shift', 'towel rail', 'adding machine'] LVIS_CATEGORIES = 
['aerosol_can', 'air_conditioner', 'airplane', 'alarm_clock', 'alcohol', 'alligator', 'almond', 'ambulance', 'amplifier', 'anklet', 'antenna', 'apple', 'applesauce', 'apricot', 'apron', 'aquarium', 'arctic_(type_of_shoe)', 'armband', 'armchair', 'armoire', 'armor', 'artichoke', 'trash_can', 'ashtray', 'asparagus', 'atomizer', 'avocado', 'award', 'awning', 'ax', 'baboon', 'baby_buggy', 'basketball_backboard', 'backpack', 'handbag', 'suitcase', 'bagel', 'bagpipe', 'baguet', 'bait', 'ball', 'ballet_skirt', 'balloon', 'bamboo', 'banana', 'Band_Aid', 'bandage', 'bandanna', 'banjo', 'banner', 'barbell', 'barge', 'barrel', 'barrette', 'barrow', 'baseball_base', 'baseball', 'baseball_bat', 'baseball_cap', 'baseball_glove', 'basket', 'basketball', 'bass_horn', 'bat_(animal)', 'bath_mat', 'bath_towel', 'bathrobe', 'bathtub', 'batter_(food)', 'battery', 'beachball', 'bead', 'bean_curd', 'beanbag', 'beanie', 'bear', 'bed', 'bedpan', 'bedspread', 'cow', 'beef_(food)', 'beeper', 'beer_bottle', 'beer_can', 'beetle', 'bell', 'bell_pepper', 'belt', 'belt_buckle', 'bench', 'beret', 'bib', 'Bible', 'bicycle', 'visor', 'billboard', 'binder', 'binoculars', 'bird', 'birdfeeder', 'birdbath', 'birdcage', 'birdhouse', 'birthday_cake', 'birthday_card', 'pirate_flag', 'black_sheep', 'blackberry', 'blackboard', 'blanket', 'blazer', 'blender', 'blimp', 'blinker', 'blouse', 'blueberry', 'gameboard', 'boat', 'bob', 'bobbin', 'bobby_pin', 'boiled_egg', 'bolo_tie', 'deadbolt', 'bolt', 'bonnet', 'book', 'bookcase', 'booklet', 'bookmark', 'boom_microphone', 'boot', 'bottle', 'bottle_opener', 'bouquet', 'bow_(weapon)', 'bow_(decorative_ribbons)', 'bow-tie', 'bowl', 'pipe_bowl', 'bowler_hat', 'bowling_ball', 'box', 'boxing_glove', 'suspenders', 'bracelet', 'brass_plaque', 'brassiere', 'bread-bin', 'bread', 'breechcloth', 'bridal_gown', 'briefcase', 'broccoli', 'broach', 'broom', 'brownie', 'brussels_sprouts', 'bubble_gum', 'bucket', 'horse_buggy', 'horned_cow', 'bulldog', 'bulldozer', 'bullet_train', 
'bulletin_board', 'bulletproof_vest', 'bullhorn', 'bun', 'bunk_bed', 'buoy', 'burrito', 'bus_(vehicle)', 'business_card', 'butter', 'butterfly', 'button', 'cab_(taxi)', 'cabana', 'cabin_car', 'cabinet', 'locker', 'cake', 'calculator', 'calendar', 'calf', 'camcorder', 'camel', 'camera', 'camera_lens', 'camper_(vehicle)', 'can', 'can_opener', 'candle', 'candle_holder', 'candy_bar', 'candy_cane', 'walking_cane', 'canister', 'canoe', 'cantaloup', 'canteen', 'cap_(headwear)', 'bottle_cap', 'cape', 'cappuccino', 'car_(automobile)', 'railcar_(part_of_a_train)', 'elevator_car', 'car_battery', 'identity_card', 'card', 'cardigan', 'cargo_ship', 'carnation', 'horse_carriage', 'carrot', 'tote_bag', 'cart', 'carton', 'cash_register', 'casserole', 'cassette', 'cast', 'cat', 'cauliflower', 'cayenne_(spice)', 'CD_player', 'celery', 'cellular_telephone', 'chain_mail', 'chair', 'chaise_longue', 'chalice', 'chandelier', 'chap', 'checkbook', 'checkerboard', 'cherry', 'chessboard', 'chicken_(animal)', 'chickpea', 'chili_(vegetable)', 'chime', 'chinaware', 'crisp_(potato_chip)', 'poker_chip', 'chocolate_bar', 'chocolate_cake', 'chocolate_milk', 'chocolate_mousse', 'choker', 'chopping_board', 'chopstick', 'Christmas_tree', 'slide', 'cider', 'cigar_box', 'cigarette', 'cigarette_case', 'cistern', 'clarinet', 'clasp', 'cleansing_agent', 'cleat_(for_securing_rope)', 'clementine', 'clip', 'clipboard', 'clippers_(for_plants)', 'cloak', 'clock', 'clock_tower', 'clothes_hamper', 'clothespin', 'clutch_bag', 'coaster', 'coat', 'coat_hanger', 'coatrack', 'cock', 'cockroach', 'cocoa_(beverage)', 'coconut', 'coffee_maker', 'coffee_table', 'coffeepot', 'coil', 'coin', 'colander', 'coleslaw', 'coloring_material', 'combination_lock', 'pacifier', 'comic_book', 'compass', 'computer_keyboard', 'condiment', 'cone', 'control', 'convertible_(automobile)', 'sofa_bed', 'cooker', 'cookie', 'cooking_utensil', 'cooler_(for_food)', 'cork_(bottle_plug)', 'corkboard', 'corkscrew', 'edible_corn', 'cornbread', 
'cornet', 'cornice', 'cornmeal', 'corset', 'costume', 'cougar', 'coverall', 'cowbell', 'cowboy_hat', 'crab_(animal)', 'crabmeat', 'cracker', 'crape', 'crate', 'crayon', 'cream_pitcher', 'crescent_roll', 'crib', 'crock_pot', 'crossbar', 'crouton', 'crow', 'crowbar', 'crown', 'crucifix', 'cruise_ship', 'police_cruiser', 'crumb', 'crutch', 'cub_(animal)', 'cube', 'cucumber', 'cufflink', 'cup', 'trophy_cup', 'cupboard', 'cupcake', 'hair_curler', 'curling_iron', 'curtain', 'cushion', 'cylinder', 'cymbal', 'dagger', 'dalmatian', 'dartboard', 'date_(fruit)', 'deck_chair', 'deer', 'dental_floss', 'desk', 'detergent', 'diaper', 'diary', 'die', 'dinghy', 'dining_table', 'tux', 'dish', 'dish_antenna', 'dishrag', 'dishtowel', 'dishwasher', 'dishwasher_detergent', 'dispenser', 'diving_board', 'Dixie_cup', 'dog', 'dog_collar', 'doll', 'dollar', 'dollhouse', 'dolphin', 'domestic_ass', 'doorknob', 'doormat', 'doughnut', 'dove', 'dragonfly', 'drawer', 'underdrawers', 'dress', 'dress_hat', 'dress_suit', 'dresser', 'drill', 'drone', 'dropper', 'drum_(musical_instrument)', 'drumstick', 'duck', 'duckling', 'duct_tape', 'duffel_bag', 'dumbbell', 'dumpster', 'dustpan', 'eagle', 'earphone', 'earplug', 'earring', 'easel', 'eclair', 'eel', 'egg', 'egg_roll', 'egg_yolk', 'eggbeater', 'eggplant', 'electric_chair', 'refrigerator', 'elephant', 'elk', 'envelope', 'eraser', 'escargot', 'eyepatch', 'falcon', 'fan', 'faucet', 'fedora', 'ferret', 'Ferris_wheel', 'ferry', 'fig_(fruit)', 'fighter_jet', 'figurine', 'file_cabinet', 'file_(tool)', 'fire_alarm', 'fire_engine', 'fire_extinguisher', 'fire_hose', 'fireplace', 'fireplug', 'first-aid_kit', 'fish', 'fish_(food)', 'fishbowl', 'fishing_rod', 'flag', 'flagpole', 'flamingo', 'flannel', 'flap', 'flash', 'flashlight', 'fleece', 'flip-flop_(sandal)', 'flipper_(footwear)', 'flower_arrangement', 'flute_glass', 'foal', 'folding_chair', 'food_processor', 'football_(American)', 'football_helmet', 'footstool', 'fork', 'forklift', 'freight_car', 
'French_toast', 'freshener', 'frisbee', 'frog', 'fruit_juice', 'frying_pan', 'fudge', 'funnel', 'futon', 'gag', 'garbage', 'garbage_truck', 'garden_hose', 'gargle', 'gargoyle', 'garlic', 'gasmask', 'gazelle', 'gelatin', 'gemstone', 'generator', 'giant_panda', 'gift_wrap', 'ginger', 'giraffe', 'cincture', 'glass_(drink_container)', 'globe', 'glove', 'goat', 'goggles', 'goldfish', 'golf_club', 'golfcart', 'gondola_(boat)', 'goose', 'gorilla', 'gourd', 'grape', 'grater', 'gravestone', 'gravy_boat', 'green_bean', 'green_onion', 'griddle', 'grill', 'grits', 'grizzly', 'grocery_bag', 'guitar', 'gull', 'gun', 'hairbrush', 'hairnet', 'hairpin', 'halter_top', 'ham', 'hamburger', 'hammer', 'hammock', 'hamper', 'hamster', 'hair_dryer', 'hand_glass', 'hand_towel', 'handcart', 'handcuff', 'handkerchief', 'handle', 'handsaw', 'hardback_book', 'harmonium', 'hat', 'hatbox', 'veil', 'headband', 'headboard', 'headlight', 'headscarf', 'headset', 'headstall_(for_horses)', 'heart', 'heater', 'helicopter', 'helmet', 'heron', 'highchair', 'hinge', 'hippopotamus', 'hockey_stick', 'hog', 'home_plate_(baseball)', 'honey', 'fume_hood', 'hook', 'hookah', 'hornet', 'horse', 'hose', 'hot-air_balloon', 'hotplate', 'hot_sauce', 'hourglass', 'houseboat', 'hummingbird', 'hummus', 'polar_bear', 'icecream', 'popsicle', 'ice_maker', 'ice_pack', 'ice_skate', 'igniter', 'inhaler', 'iPod', 'iron_(for_clothing)', 'ironing_board', 'jacket', 'jam', 'jar', 'jean', 'jeep', 'jelly_bean', 'jersey', 'jet_plane', 'jewel', 'jewelry', 'joystick', 'jumpsuit', 'kayak', 'keg', 'kennel', 'kettle', 'key', 'keycard', 'kilt', 'kimono', 'kitchen_sink', 'kitchen_table', 'kite', 'kitten', 'kiwi_fruit', 'knee_pad', 'knife', 'knitting_needle', 'knob', 'knocker_(on_a_door)', 'koala', 'lab_coat', 'ladder', 'ladle', 'ladybug', 'lamb_(animal)', 'lamb-chop', 'lamp', 'lamppost', 'lampshade', 'lantern', 'lanyard', 'laptop_computer', 'lasagna', 'latch', 'lawn_mower', 'leather', 'legging_(clothing)', 'Lego', 'legume', 'lemon', 
'lemonade', 'lettuce', 'license_plate', 'life_buoy', 'life_jacket', 'lightbulb', 'lightning_rod', 'lime', 'limousine', 'lion', 'lip_balm', 'liquor', 'lizard', 'log', 'lollipop', 'speaker_(stero_equipment)', 'loveseat', 'machine_gun', 'magazine', 'magnet', 'mail_slot', 'mailbox_(at_home)', 'mallard', 'mallet', 'mammoth', 'manatee', 'mandarin_orange', 'manger', 'manhole', 'map', 'marker', 'martini', 'mascot', 'mashed_potato', 'masher', 'mask', 'mast', 'mat_(gym_equipment)', 'matchbox', 'mattress', 'measuring_cup', 'measuring_stick', 'meatball', 'medicine', 'melon', 'microphone', 'microscope', 'microwave_oven', 'milestone', 'milk', 'milk_can', 'milkshake', 'minivan', 'mint_candy', 'mirror', 'mitten', 'mixer_(kitchen_tool)', 'money', 'monitor_(computer_equipment) computer_monitor', 'monkey', 'motor', 'motor_scooter', 'motor_vehicle', 'motorcycle', 'mound_(baseball)', 'mouse_(computer_equipment)', 'mousepad', 'muffin', 'mug', 'mushroom', 'music_stool', 'musical_instrument', 'nailfile', 'napkin', 'neckerchief', 'necklace', 'necktie', 'needle', 'nest', 'newspaper', 'newsstand', 'nightshirt', 'nosebag_(for_animals)', 'noseband_(for_animals)', 'notebook', 'notepad', 'nut', 'nutcracker', 'oar', 'octopus_(food)', 'octopus_(animal)', 'oil_lamp', 'olive_oil', 'omelet', 'onion', 'orange_(fruit)', 'orange_juice', 'ostrich', 'ottoman', 'oven', 'overalls_(clothing)', 'owl', 'packet', 'inkpad', 'pad', 'paddle', 'padlock', 'paintbrush', 'painting', 'pajamas', 'palette', 'pan_(for_cooking)', 'pan_(metal_container)', 'pancake', 'pantyhose', 'papaya', 'paper_plate', 'paper_towel', 'paperback_book', 'paperweight', 'parachute', 'parakeet', 'parasail_(sports)', 'parasol', 'parchment', 'parka', 'parking_meter', 'parrot', 'passenger_car_(part_of_a_train)', 'passenger_ship', 'passport', 'pastry', 'patty_(food)', 'pea_(food)', 'peach', 'peanut_butter', 'pear', 'peeler_(tool_for_fruit_and_vegetables)', 'wooden_leg', 'pegboard', 'pelican', 'pen', 'pencil', 'pencil_box', 'pencil_sharpener', 
'pendulum', 'penguin', 'pennant', 'penny_(coin)', 'pepper', 'pepper_mill', 'perfume', 'persimmon', 'person', 'pet', 'pew_(church_bench)', 'phonebook', 'phonograph_record', 'piano', 'pickle', 'pickup_truck', 'pie', 'pigeon', 'piggy_bank', 'pillow', 'pin_(non_jewelry)', 'pineapple', 'pinecone', 'ping-pong_ball', 'pinwheel', 'tobacco_pipe', 'pipe', 'pistol', 'pita_(bread)', 'pitcher_(vessel_for_liquid)', 'pitchfork', 'pizza', 'place_mat', 'plate', 'platter', 'playpen', 'pliers', 'plow_(farm_equipment)', 'plume', 'pocket_watch', 'pocketknife', 'poker_(fire_stirring_tool)', 'pole', 'polo_shirt', 'poncho', 'pony', 'pool_table', 'pop_(soda)', 'postbox_(public)', 'postcard', 'poster', 'pot', 'flowerpot', 'potato', 'potholder', 'pottery', 'pouch', 'power_shovel', 'prawn', 'pretzel', 'printer', 'projectile_(weapon)', 'projector', 'propeller', 'prune', 'pudding', 'puffer_(fish)', 'puffin', 'pug-dog', 'pumpkin', 'puncher', 'puppet', 'puppy', 'quesadilla', 'quiche', 'quilt', 'rabbit', 'race_car', 'racket', 'radar', 'radiator', 'radio_receiver', 'radish', 'raft', 'rag_doll', 'raincoat', 'ram_(animal)', 'raspberry', 'rat', 'razorblade', 'reamer_(juicer)', 'rearview_mirror', 'receipt', 'recliner', 'record_player', 'reflector', 'remote_control', 'rhinoceros', 'rib_(food)', 'rifle', 'ring', 'river_boat', 'road_map', 'robe', 'rocking_chair', 'rodent', 'roller_skate', 'Rollerblade', 'rolling_pin', 'root_beer', 'router_(computer_equipment)', 'rubber_band', 'runner_(carpet)', 'plastic_bag', 'saddle_(on_an_animal)', 'saddle_blanket', 'saddlebag', 'safety_pin', 'sail', 'salad', 'salad_plate', 'salami', 'salmon_(fish)', 'salmon_(food)', 'salsa', 'saltshaker', 'sandal_(type_of_shoe)', 'sandwich', 'satchel', 'saucepan', 'saucer', 'sausage', 'sawhorse', 'saxophone', 'scale_(measuring_instrument)', 'scarecrow', 'scarf', 'school_bus', 'scissors', 'scoreboard', 'scraper', 'screwdriver', 'scrubbing_brush', 'sculpture', 'seabird', 'seahorse', 'seaplane', 'seashell', 'sewing_machine', 'shaker', 
'shampoo', 'shark', 'sharpener', 'Sharpie', 'shaver_(electric)', 'shaving_cream', 'shawl', 'shears', 'sheep', 'shepherd_dog', 'sherbert', 'shield', 'shirt', 'shoe', 'shopping_bag', 'shopping_cart', 'short_pants', 'shot_glass', 'shoulder_bag', 'shovel', 'shower_head', 'shower_cap', 'shower_curtain', 'shredder_(for_paper)', 'signboard', 'silo', 'sink', 'skateboard', 'skewer', 'ski', 'ski_boot', 'ski_parka', 'ski_pole', 'skirt', 'skullcap', 'sled', 'sleeping_bag', 'sling_(bandage)', 'slipper_(footwear)', 'smoothie', 'snake', 'snowboard', 'snowman', 'snowmobile', 'soap', 'soccer_ball', 'sock', 'sofa', 'softball', 'solar_array', 'sombrero', 'soup', 'soup_bowl', 'soupspoon', 'sour_cream', 'soya_milk', 'space_shuttle', 'sparkler_(fireworks)', 'spatula', 'spear', 'spectacles', 'spice_rack', 'spider', 'crawfish', 'sponge', 'spoon', 'sportswear', 'spotlight', 'squid_(food)', 'squirrel', 'stagecoach', 'stapler_(stapling_machine)', 'starfish', 'statue_(sculpture)', 'steak_(food)', 'steak_knife', 'steering_wheel', 'stepladder', 'step_stool', 'stereo_(sound_system)', 'stew', 'stirrer', 'stirrup', 'stool', 'stop_sign', 'brake_light', 'stove', 'strainer', 'strap', 'straw_(for_drinking)', 'strawberry', 'street_sign', 'streetlight', 'string_cheese', 'stylus', 'subwoofer', 'sugar_bowl', 'sugarcane_(plant)', 'suit_(clothing)', 'sunflower', 'sunglasses', 'sunhat', 'surfboard', 'sushi', 'mop', 'sweat_pants', 'sweatband', 'sweater', 'sweatshirt', 'sweet_potato', 'swimsuit', 'sword', 'syringe', 'Tabasco_sauce', 'table-tennis_table', 'table', 'table_lamp', 'tablecloth', 'tachometer', 'taco', 'tag', 'taillight', 'tambourine', 'army_tank', 'tank_(storage_vessel)', 'tank_top_(clothing)', 'tape_(sticky_cloth_or_paper)', 'tape_measure', 'tapestry', 'tarp', 'tartan', 'tassel', 'tea_bag', 'teacup', 'teakettle', 'teapot', 'teddy_bear', 'telephone', 'telephone_booth', 'telephone_pole', 'telephoto_lens', 'television_camera', 'television_set', 'tennis_ball', 'tennis_racket', 'tequila', 'thermometer', 
'thermos_bottle', 'thermostat', 'thimble', 'thread', 'thumbtack', 'tiara', 'tiger', 'tights_(clothing)', 'timer', 'tinfoil', 'tinsel', 'tissue_paper', 'toast_(food)', 'toaster', 'toaster_oven', 'toilet', 'toilet_tissue', 'tomato', 'tongs', 'toolbox', 'toothbrush', 'toothpaste', 'toothpick', 'cover', 'tortilla', 'tow_truck', 'towel', 'towel_rack', 'toy', 'tractor_(farm_equipment)', 'traffic_light', 'dirt_bike', 'trailer_truck', 'train_(railroad_vehicle)', 'trampoline', 'tray', 'trench_coat', 'triangle_(musical_instrument)', 'tricycle', 'tripod', 'trousers', 'truck', 'truffle_(chocolate)', 'trunk', 'vat', 'turban', 'turkey_(food)', 'turnip', 'turtle', 'turtleneck_(clothing)', 'typewriter', 'umbrella', 'underwear', 'unicycle', 'urinal', 'urn', 'vacuum_cleaner', 'vase', 'vending_machine', 'vent', 'vest', 'videotape', 'vinegar', 'violin', 'vodka', 'volleyball', 'vulture', 'waffle', 'waffle_iron', 'wagon', 'wagon_wheel', 'walking_stick', 'wall_clock', 'wall_socket', 'wallet', 'walrus', 'wardrobe', 'washbasin', 'automatic_washer', 'watch', 'water_bottle', 'water_cooler', 'water_faucet', 'water_heater', 'water_jug', 'water_gun', 'water_scooter', 'water_ski', 'water_tower', 'watering_can', 'watermelon', 'weathervane', 'webcam', 'wedding_cake', 'wedding_ring', 'wet_suit', 'wheel', 'wheelchair', 'whipped_cream', 'whistle', 'wig', 'wind_chime', 'windmill', 'window_box_(for_plants)', 'windshield_wiper', 'windsock', 'wine_bottle', 'wine_bucket', 'wineglass', 'blinder_(for_horses)', 'wok', 'wolf', 'wooden_spoon', 'wreath', 'wrench', 'wristband', 'wristlet', 'yacht', 'yogurt', 'yoke_(animal_equipment)', 'zebra', 'zucchini'] ================================================ FILE: utils/constants_ori.py ================================================ COCO_PANOPTIC_CLASSES = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 
'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush', 'banner', 'blanket', 'bridge', 'cardboard', 'counter', 'curtain', 'door-stuff', 'floor-wood', 'flower', 'fruit', 'gravel', 'house', 'light', 'mirror-stuff', 'net', 'pillow', 'platform', 'playingfield', 'railroad', 'river', 'road', 'roof', 'sand', 'sea', 'shelf', 'snow', 'stairs', 'tent', 'towel', 'wall-brick', 'wall-stone', 'wall-tile', 'wall-wood', 'water-other', 'window-blind', 'window-other', 'tree-merged', 'fence-merged', 'ceiling-merged', 'sky-other-merged', 'cabinet-merged', 'table-merged', 'floor-other-merged', 'pavement-merged', 'mountain-merged', 'grass-merged', 'dirt-merged', 'paper-merged', 'food-other-merged', 'building-other-merged', 'rock-merged', 'wall-other-merged', 'rug-merged'] ADE_PANOPTIC_CLASSES = ['wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed', 'window', 'grass', 'cabinet', 'sidewalk', 'person', 'earth', 'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car', 'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug', 'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe', 'lamp', 'tub', 'rail', 'cushion', 'base', 'box', 'column', 'signboard', 'chest of drawers', 'counter', 'sand', 'sink', 'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path', 'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door', 'stairway', 'river', 'bridge', 
'bookcase', 'blind', 'coffee table', 'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove', 'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar', 'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower', 'chandelier', 'awning', 'street lamp', 'booth', 'tv', 'airplane', 'dirt track', 'clothes', 'pole', 'land', 'bannister', 'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van', 'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything', 'pool', 'stool', 'barrel', 'basket', 'falls', 'tent', 'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank', 'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake', 'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce', 'vase', 'traffic light', 'tray', 'trash can', 'fan', 'pier', 'crt screen', 'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass', 'clock', 'flag'] ADE20K_847 = ['wall', 'building', 'sky', 'tree', 'road', 'floor', 'ceiling', 'bed', 'sidewalk', 'earth', 'cabinet', 'person', 'grass', 'windowpane', 'car', 'mountain', 'plant', 'table', 'chair', 'curtain', 'door', 'sofa', 'sea', 'painting', 'water', 'mirror', 'house', 'rug', 'shelf', 'armchair', 'fence', 'field', 'lamp', 'rock', 'seat', 'river', 'desk', 'bathtub', 'railing', 'signboard', 'cushion', 'path', 'work surface', 'stairs', 'column', 'sink', 'wardrobe', 'snow', 'refrigerator', 'base', 'bridge', 'blind', 'runway', 'cliff', 'sand', 'fireplace', 'pillow', 'screen door', 'toilet', 'skyscraper', 'grandstand', 'box', 'pool table', 'palm', 'double door', 'coffee table', 'counter', 'countertop', 'chest of drawers', 'kitchen island', 'boat', 'waterfall', 'stove', 'flower', 'bookcase', 'controls', 'book', 'stairway', 'streetlight', 'computer', 'bus', 'swivel chair', 'light', 'bench', 'case', 'towel', 'fountain', 'embankment', 'television receiver', 'van', 'hill', 'awning', 'poster', 'truck', 'airplane', 'pole', 'tower', 'court', 'ball', 'aircraft carrier', 'buffet', 
'hovel', 'apparel', 'minibike', 'animal', 'chandelier', 'step', 'booth', 'bicycle', 'doorframe', 'sconce', 'pond', 'trade name', 'bannister', 'bag', 'traffic light', 'gazebo', 'escalator', 'land', 'board', 'arcade machine', 'eiderdown', 'bar', 'stall', 'playground', 'ship', 'ottoman', 'ashcan', 'bottle', 'cradle', 'pot', 'conveyer belt', 'train', 'stool', 'lake', 'tank', 'ice', 'basket', 'manhole', 'tent', 'canopy', 'microwave', 'barrel', 'dirt track', 'beam', 'dishwasher', 'plate', 'screen', 'ruins', 'washer', 'blanket', 'plaything', 'food', 'screen', 'oven', 'stage', 'beacon', 'umbrella', 'sculpture', 'aqueduct', 'container', 'scaffolding', 'hood', 'curb', 'roller coaster', 'horse', 'catwalk', 'glass', 'vase', 'central reservation', 'carousel', 'radiator', 'closet', 'machine', 'pier', 'fan', 'inflatable bounce game', 'pitch', 'paper', 'arcade', 'hot tub', 'helicopter', 'tray', 'partition', 'vineyard', 'bowl', 'bullring', 'flag', 'pot', 'footbridge', 'shower', 'bag', 'bulletin board', 'confessional booth', 'trunk', 'forest', 'elevator door', 'laptop', 'instrument panel', 'bucket', 'tapestry', 'platform', 'jacket', 'gate', 'monitor', 'telephone booth', 'spotlight', 'ring', 'control panel', 'blackboard', 'air conditioner', 'chest', 'clock', 'sand dune', 'pipe', 'vault', 'table football', 'cannon', 'swimming pool', 'fluorescent', 'statue', 'loudspeaker', 'exhibitor', 'ladder', 'carport', 'dam', 'pulpit', 'skylight', 'water tower', 'grill', 'display board', 'pane', 'rubbish', 'ice rink', 'fruit', 'patio', 'vending machine', 'telephone', 'net', 'backpack', 'jar', 'track', 'magazine', 'shutter', 'roof', 'banner', 'landfill', 'post', 'altarpiece', 'hat', 'arch', 'table game', 'bag', 'document', 'dome', 'pier', 'shanties', 'forecourt', 'crane', 'dog', 'piano', 'drawing', 'cabin', 'ad', 'amphitheater', 'monument', 'henhouse', 'cockpit', 'heater', 'windmill', 'pool', 'elevator', 'decoration', 'labyrinth', 'text', 'printer', 'mezzanine', 'mattress', 'straw', 'stalls', 
'patio', 'billboard', 'bus stop', 'trouser', 'console table', 'rack', 'notebook', 'shrine', 'pantry', 'cart', 'steam shovel', 'porch', 'postbox', 'figurine', 'recycling bin', 'folding screen', 'telescope', 'deck chair', 'kennel', 'coffee maker', 'altar', 'fish', 'easel', 'artificial golf green', 'iceberg', 'candlestick', 'shower stall', 'television stand', 'wall socket', 'skeleton', 'grand piano', 'candy', 'grille door', 'pedestal', 'jersey', 'shoe', 'gravestone', 'shanty', 'structure', 'rocking chair', 'bird', 'place mat', 'tomb', 'big top', 'gas pump', 'lockers', 'cage', 'finger', 'bleachers', 'ferris wheel', 'hairdresser chair', 'mat', 'stands', 'aquarium', 'streetcar', 'napkin', 'dummy', 'booklet', 'sand trap', 'shop', 'table cloth', 'service station', 'coffin', 'drawer', 'cages', 'slot machine', 'balcony', 'volleyball court', 'table tennis', 'control table', 'shirt', 'merchandise', 'railway', 'parterre', 'chimney', 'can', 'tanks', 'fabric', 'alga', 'system', 'map', 'greenhouse', 'mug', 'barbecue', 'trailer', 'toilet tissue', 'organ', 'dishrag', 'island', 'keyboard', 'trench', 'basket', 'steering wheel', 'pitcher', 'goal', 'bread', 'beds', 'wood', 'file cabinet', 'newspaper', 'motorboat', 'rope', 'guitar', 'rubble', 'scarf', 'barrels', 'cap', 'leaves', 'control tower', 'dashboard', 'bandstand', 'lectern', 'switch', 'baseboard', 'shower room', 'smoke', 'faucet', 'bulldozer', 'saucepan', 'shops', 'meter', 'crevasse', 'gear', 'candelabrum', 'sofa bed', 'tunnel', 'pallet', 'wire', 'kettle', 'bidet', 'baby buggy', 'music stand', 'pipe', 'cup', 'parking meter', 'ice hockey rink', 'shelter', 'weeds', 'temple', 'patty', 'ski slope', 'panel', 'wallet', 'wheel', 'towel rack', 'roundabout', 'canister', 'rod', 'soap dispenser', 'bell', 'canvas', 'box office', 'teacup', 'trellis', 'workbench', 'valley', 'toaster', 'knife', 'podium', 'ramp', 'tumble dryer', 'fireplug', 'gym shoe', 'lab bench', 'equipment', 'rocky formation', 'plastic', 'calendar', 'caravan', 'check-in-desk', 
'ticket counter', 'brush', 'mill', 'covered bridge', 'bowling alley', 'hanger', 'excavator', 'trestle', 'revolving door', 'blast furnace', 'scale', 'projector', 'soap', 'locker', 'tractor', 'stretcher', 'frame', 'grating', 'alembic', 'candle', 'barrier', 'cardboard', 'cave', 'puddle', 'tarp', 'price tag', 'watchtower', 'meters', 'light bulb', 'tracks', 'hair dryer', 'skirt', 'viaduct', 'paper towel', 'coat', 'sheet', 'fire extinguisher', 'water wheel', 'pottery', 'magazine rack', 'teapot', 'microphone', 'support', 'forklift', 'canyon', 'cash register', 'leaf', 'remote control', 'soap dish', 'windshield', 'cat', 'cue', 'vent', 'videos', 'shovel', 'eaves', 'antenna', 'shipyard', 'hen', 'traffic cone', 'washing machines', 'truck crane', 'cds', 'niche', 'scoreboard', 'briefcase', 'boot', 'sweater', 'hay', 'pack', 'bottle rack', 'glacier', 'pergola', 'building materials', 'television camera', 'first floor', 'rifle', 'tennis table', 'stadium', 'safety belt', 'cover', 'dish rack', 'synthesizer', 'pumpkin', 'gutter', 'fruit stand', 'ice floe', 'handle', 'wheelchair', 'mousepad', 'diploma', 'fairground ride', 'radio', 'hotplate', 'junk', 'wheelbarrow', 'stream', 'toll plaza', 'punching bag', 'trough', 'throne', 'chair desk', 'weighbridge', 'extractor fan', 'hanging clothes', 'dish', 'alarm clock', 'ski lift', 'chain', 'garage', 'mechanical shovel', 'wine rack', 'tramway', 'treadmill', 'menu', 'block', 'well', 'witness stand', 'branch', 'duck', 'casserole', 'frying pan', 'desk organizer', 'mast', 'spectacles', 'service elevator', 'dollhouse', 'hammock', 'clothes hanging', 'photocopier', 'notepad', 'golf cart', 'footpath', 'cross', 'baptismal font', 'boiler', 'skip', 'rotisserie', 'tables', 'water mill', 'helmet', 'cover curtain', 'brick', 'table runner', 'ashtray', 'street box', 'stick', 'hangers', 'cells', 'urinal', 'centerpiece', 'portable fridge', 'dvds', 'golf club', 'skirting board', 'water cooler', 'clipboard', 'camera', 'pigeonhole', 'chips', 'food processor', 'post 
box', 'lid', 'drum', 'blender', 'cave entrance', 'dental chair', 'obelisk', 'canoe', 'mobile', 'monitors', 'pool ball', 'cue rack', 'baggage carts', 'shore', 'fork', 'paper filer', 'bicycle rack', 'coat rack', 'garland', 'sports bag', 'fish tank', 'towel dispenser', 'carriage', 'brochure', 'plaque', 'stringer', 'iron', 'spoon', 'flag pole', 'toilet brush', 'book stand', 'water faucet', 'ticket office', 'broom', 'dvd', 'ice bucket', 'carapace', 'tureen', 'folders', 'chess', 'root', 'sewing machine', 'model', 'pen', 'violin', 'sweatshirt', 'recycling materials', 'mitten', 'chopping board', 'mask', 'log', 'mouse', 'grill', 'hole', 'target', 'trash bag', 'chalk', 'sticks', 'balloon', 'score', 'hair spray', 'roll', 'runner', 'engine', 'inflatable glove', 'games', 'pallets', 'baskets', 'coop', 'dvd player', 'rocking horse', 'buckets', 'bread rolls', 'shawl', 'watering can', 'spotlights', 'post-it', 'bowls', 'security camera', 'runner cloth', 'lock', 'alarm', 'side', 'roulette', 'bone', 'cutlery', 'pool balls', 'wheels', 'spice rack', 'plant pots', 'towel ring', 'bread box', 'video', 'funfair', 'breads', 'tripod', 'ironing board', 'skimmer', 'hollow', 'scratching post', 'tricycle', 'file box', 'mountain pass', 'tombstones', 'cooker', 'card game', 'golf bag', 'towel paper', 'chaise lounge', 'sun', 'toilet paper holder', 'rake', 'key', 'umbrella stand', 'dartboard', 'transformer', 'fireplace utensils', 'sweatshirts', 'cellular telephone', 'tallboy', 'stapler', 'sauna', 'test tube', 'palette', 'shopping carts', 'tools', 'push button', 'star', 'roof rack', 'barbed wire', 'spray', 'ear', 'sponge', 'racket', 'tins', 'eyeglasses', 'file', 'scarfs', 'sugar bowl', 'flip flop', 'headstones', 'laptop bag', 'leash', 'climbing frame', 'suit hanger', 'floor spotlight', 'plate rack', 'sewer', 'hard drive', 'sprinkler', 'tools box', 'necklace', 'bulbs', 'steel industry', 'club', 'jack', 'door bars', 'control panel', 'hairbrush', 'napkin holder', 'office', 'smoke detector', 'utensils', 
'apron', 'scissors', 'terminal', 'grinder', 'entry phone', 'newspaper stand', 'pepper shaker', 'onions', 'central processing unit', 'tape', 'bat', 'coaster', 'calculator', 'potatoes', 'luggage rack', 'salt', 'street number', 'viewpoint', 'sword', 'cd', 'rowing machine', 'plug', 'andiron', 'pepper', 'tongs', 'bonfire', 'dog dish', 'belt', 'dumbbells', 'videocassette recorder', 'hook', 'envelopes', 'shower faucet', 'watch', 'padlock', 'swimming pool ladder', 'spanners', 'gravy boat', 'notice board', 'trash bags', 'fire alarm', 'ladle', 'stethoscope', 'rocket', 'funnel', 'bowling pins', 'valve', 'thermometer', 'cups', 'spice jar', 'night light', 'soaps', 'games table', 'slotted spoon', 'reel', 'scourer', 'sleeping robe', 'desk mat', 'dumbbell', 'hammer', 'tie', 'typewriter', 'shaker', 'cheese dish', 'sea star', 'racquet', 'butane gas cylinder', 'paper weight', 'shaving brush', 'sunglasses', 'gear shift', 'towel rail', 'adding machine'] SUN_RGBD_37 = ['wall', 'floor', 'cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window', 'bookshelf', 'picture', 'counter', 'blinds', 'desk', 'shelves', 'curtain', 'dresser', 'pillow', 'mirror', 'floor mat', 'clothes', 'ceiling', 'books', 'refridgerator', 'television', 'paper', 'towel', 'shower curtain', 'box', 'whiteboard', 'person', 'night stand', 'toilet', 'sink', 'lamp', 'bathtub', 'bag'] SCAN_37 = ['wall', 'floor', 'cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window', 'bookshelf', 'picture', 'counter', 'blinds', 'desk', 'shelves', 'curtain', 'dresser', 'pillow', 'mirror', 'floor mat', 'clothes', 'ceiling', 'books', 'refridgerator', 'television', 'paper', 'towel', 'shower curtain', 'box', 'whiteboard', 'person', 'night stand', 'toilet', 'sink', 'lamp', 'bathtub', 'bag'] SCAN_40 = ['wall', 'floor', 'cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window', 'bookshelf', 'picture', 'counter', 'blinds', 'desk', 'shelves', 'curtain', 'dresser', 'pillow', 'mirror', 'floor mat', 'clothes', 'ceiling', 'books', 
'refridgerator', 'television', 'paper', 'towel', 'shower curtain', 'box', 'whiteboard', 'person', 'night stand', 'toilet', 'sink', 'lamp', 'bathtub', 'bag', 'otherstructure', 'otherfurniture', 'otherprop'] SCAN_20 = ["wall", "floor", "cabinet", "bed", "chair", "sofa", "table", "door", "window", "bookshelf", "picture", "counter", "desk", "curtain", "refrigerator", "shower curtain", "toilet", "sink", "bathtub", "otherfurniture"] CITYSCAPES = ['road', 'sidewalk', 'building', 'wall', 'fence', 'pole', 'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle', 'bicycle'] CITYSCAPES_THING = ["person", "rider", "car", "truck", "bus", "train", "motorcycle", "bicycle"] BDD_SEM = ["road", "sidewalk", "building", "wall", "fence", "pole", "traffic light", "traffic sign", "vegetation", "terrain", "sky", "person", "rider", "car", "truck", "bus", "train", "motorcycle", "bicycle"] BDD_PANO = ['dynamic', 'ego vehicle', 'ground', 'static', 'parking', 'rail track', 'road', 'sidewalk', 'bridge', 'building', 'fence', 'garage', 'guard rail', 'tunnel', 'wall', 'banner', 'billboard', 'lane divider', 'parking sign', 'pole', 'polegroup', 'street light', 'traffic cone', 'traffic device', 'traffic light', 'traffic sign', 'traffic sign frame', 'terrain', 'vegetation', 'sky', 'person', 'rider', 'bicycle', 'bus', 'car', 'caravan', 'motorcycle', 'trailer', 'train', 'truck'] IMAGENET_CLASSES = ["tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray", "stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco", "indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper", "kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander", "smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog", "tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", 
"terrapin", "box turtle", "banded gecko", "green iguana", "Carolina anole", "desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard", "Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile", "American alligator", "triceratops", "worm snake", "ring-necked snake", "eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake", "vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra", "green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake", "sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider", "barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider", "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl", "quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet", "coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck", "red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby", "koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch", "snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab", "fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab", "isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron", "great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot", "bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher", "pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion", "Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel", "Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle", "Bloodhound", "Bluetick Coonhound", 
"Black and Tan Coonhound", "Treeing Walker Coonhound", "English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound", "Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound", "Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier", "Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier", "Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier", "Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier", "Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer", "Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier", "Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier", "Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever", "Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla", "English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel", "English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel", "Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard", "Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie", "Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann", "Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog", "Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff", "French Bulldog", "Great Dane", "St. 
Bernard", "husky", "Alaskan Malamute", "Siberian Husky", "Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog", "Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon", "Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle", "Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf", "red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox", "kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat", "Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger", "cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose", "meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle", "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper", "cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper", "lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly", "monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly", "starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit", "hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse", "zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison", "ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)", "gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat", "black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan", "gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas monkey", "baboon", "macaque", "langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin", "howler monkey", "titi monkey", "Geoffroy's spider monkey", 
"common squirrel monkey", "ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda", "giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish", "sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown", "accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance", "amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle", "backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo", "baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel", "wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel", "bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)", "beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini", "ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet", "bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra", "breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest", "high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe", "can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton", "car wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran", "CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw", "storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking", "church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker", "coffee mug", "coffeemaker", "spiral or coil", "combination lock", "computer keyboard", "candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot", "cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed", 
"Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer", "rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table", "dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig", "drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar", "electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder", "feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute", "folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed", "freight car", "French horn", "frying pan", "fur coat", "garbage truck", "gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola", "gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine", "hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer", "handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet", "holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar", "horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep", "T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat", "ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library", "lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion", "music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag", "mailbox", "tights", "one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask", "matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone", "microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile", "mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", 
"monitor", "moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa", "mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail", "neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina", "odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart", "oxygen mask", "product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush", "pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench", "parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case", "pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", "Pickelhaube", "picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball", "pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag", "plate rack", "farm plow", "plunger", "Polaroid camera", "pole", "police van", "poncho", "pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug", "printer", "prison", "projectile", "projector", "hockey puck", "punching bag", "purse", "quill", "quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel", "recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator", "remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser", "rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal", "sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard", "CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store", "shoji screen / room divider", "shopping basket", "shopping cart", "shovel", "shower cap", "shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door", "slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", 
"sock", "solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater", "space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight", "stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf", "stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa", "submarine", "suit", "sundial", "sunglasses", "dark glasses", "sunscreen", "suspension bridge", "mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe", "table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball", "thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof", "toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store", "tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod", "triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard", "umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling", "velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball", "waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink", "washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle", "hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing", "wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website", "comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu", "plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette", "bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", "cabbage", "broccoli", "cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber", "artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", 
"orange", "lemon", "fig", "pineapple", "banana", "jackfruit", "cherimoya (custard apple)", "pomegranate", "hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito", "red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef", "geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player", "bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn", "rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom", "earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper"] IMAGENET_FOLDER_NAMES = ['n01440764', 'n01443537', 'n01484850', 'n01491361', 'n01494475', 'n01496331', 'n01498041', 'n01514668', 'n01514859', 'n01518878', 'n01530575', 'n01531178', 'n01532829', 'n01534433', 'n01537544', 'n01558993', 'n01560419', 'n01580077', 'n01582220', 'n01592084', 'n01601694', 'n01608432', 'n01614925', 'n01616318', 'n01622779', 'n01629819', 'n01630670', 'n01631663', 'n01632458', 'n01632777', 'n01641577', 'n01644373', 'n01644900', 'n01664065', 'n01665541', 'n01667114', 'n01667778', 'n01669191', 'n01675722', 'n01677366', 'n01682714', 'n01685808', 'n01687978', 'n01688243', 'n01689811', 'n01692333', 'n01693334', 'n01694178', 'n01695060', 'n01697457', 'n01698640', 'n01704323', 'n01728572', 'n01728920', 'n01729322', 'n01729977', 'n01734418', 'n01735189', 'n01737021', 'n01739381', 'n01740131', 'n01742172', 'n01744401', 'n01748264', 'n01749939', 'n01751748', 'n01753488', 'n01755581', 'n01756291', 'n01768244', 'n01770081', 'n01770393', 'n01773157', 'n01773549', 'n01773797', 'n01774384', 'n01774750', 'n01775062', 'n01776313', 'n01784675', 'n01795545', 'n01796340', 'n01797886', 'n01798484', 'n01806143', 'n01806567', 'n01807496', 'n01817953', 'n01818515', 'n01819313', 'n01820546', 'n01824575', 'n01828970', 'n01829413', 'n01833805', 'n01843065', 'n01843383', 'n01847000', 'n01855032', 'n01855672', 'n01860187', 
'n01871265', 'n01872401', 'n01873310', 'n01877812', 'n01882714', 'n01883070', 'n01910747', 'n01914609', 'n01917289', 'n01924916', 'n01930112', 'n01943899', 'n01944390', 'n01945685', 'n01950731', 'n01955084', 'n01968897', 'n01978287', 'n01978455', 'n01980166', 'n01981276', 'n01983481', 'n01984695', 'n01985128', 'n01986214', 'n01990800', 'n02002556', 'n02002724', 'n02006656', 'n02007558', 'n02009229', 'n02009912', 'n02011460', 'n02012849', 'n02013706', 'n02017213', 'n02018207', 'n02018795', 'n02025239', 'n02027492', 'n02028035', 'n02033041', 'n02037110', 'n02051845', 'n02056570', 'n02058221', 'n02066245', 'n02071294', 'n02074367', 'n02077923', 'n02085620', 'n02085782', 'n02085936', 'n02086079', 'n02086240', 'n02086646', 'n02086910', 'n02087046', 'n02087394', 'n02088094', 'n02088238', 'n02088364', 'n02088466', 'n02088632', 'n02089078', 'n02089867', 'n02089973', 'n02090379', 'n02090622', 'n02090721', 'n02091032', 'n02091134', 'n02091244', 'n02091467', 'n02091635', 'n02091831', 'n02092002', 'n02092339', 'n02093256', 'n02093428', 'n02093647', 'n02093754', 'n02093859', 'n02093991', 'n02094114', 'n02094258', 'n02094433', 'n02095314', 'n02095570', 'n02095889', 'n02096051', 'n02096177', 'n02096294', 'n02096437', 'n02096585', 'n02097047', 'n02097130', 'n02097209', 'n02097298', 'n02097474', 'n02097658', 'n02098105', 'n02098286', 'n02098413', 'n02099267', 'n02099429', 'n02099601', 'n02099712', 'n02099849', 'n02100236', 'n02100583', 'n02100735', 'n02100877', 'n02101006', 'n02101388', 'n02101556', 'n02102040', 'n02102177', 'n02102318', 'n02102480', 'n02102973', 'n02104029', 'n02104365', 'n02105056', 'n02105162', 'n02105251', 'n02105412', 'n02105505', 'n02105641', 'n02105855', 'n02106030', 'n02106166', 'n02106382', 'n02106550', 'n02106662', 'n02107142', 'n02107312', 'n02107574', 'n02107683', 'n02107908', 'n02108000', 'n02108089', 'n02108422', 'n02108551', 'n02108915', 'n02109047', 'n02109525', 'n02109961', 'n02110063', 'n02110185', 'n02110341', 'n02110627', 'n02110806', 
'n02110958', 'n02111129', 'n02111277', 'n02111500', 'n02111889', 'n02112018', 'n02112137', 'n02112350', 'n02112706', 'n02113023', 'n02113186', 'n02113624', 'n02113712', 'n02113799', 'n02113978', 'n02114367', 'n02114548', 'n02114712', 'n02114855', 'n02115641', 'n02115913', 'n02116738', 'n02117135', 'n02119022', 'n02119789', 'n02120079', 'n02120505', 'n02123045', 'n02123159', 'n02123394', 'n02123597', 'n02124075', 'n02125311', 'n02127052', 'n02128385', 'n02128757', 'n02128925', 'n02129165', 'n02129604', 'n02130308', 'n02132136', 'n02133161', 'n02134084', 'n02134418', 'n02137549', 'n02138441', 'n02165105', 'n02165456', 'n02167151', 'n02168699', 'n02169497', 'n02172182', 'n02174001', 'n02177972', 'n02190166', 'n02206856', 'n02219486', 'n02226429', 'n02229544', 'n02231487', 'n02233338', 'n02236044', 'n02256656', 'n02259212', 'n02264363', 'n02268443', 'n02268853', 'n02276258', 'n02277742', 'n02279972', 'n02280649', 'n02281406', 'n02281787', 'n02317335', 'n02319095', 'n02321529', 'n02325366', 'n02326432', 'n02328150', 'n02342885', 'n02346627', 'n02356798', 'n02361337', 'n02363005', 'n02364673', 'n02389026', 'n02391049', 'n02395406', 'n02396427', 'n02397096', 'n02398521', 'n02403003', 'n02408429', 'n02410509', 'n02412080', 'n02415577', 'n02417914', 'n02422106', 'n02422699', 'n02423022', 'n02437312', 'n02437616', 'n02441942', 'n02442845', 'n02443114', 'n02443484', 'n02444819', 'n02445715', 'n02447366', 'n02454379', 'n02457408', 'n02480495', 'n02480855', 'n02481823', 'n02483362', 'n02483708', 'n02484975', 'n02486261', 'n02486410', 'n02487347', 'n02488291', 'n02488702', 'n02489166', 'n02490219', 'n02492035', 'n02492660', 'n02493509', 'n02493793', 'n02494079', 'n02497673', 'n02500267', 'n02504013', 'n02504458', 'n02509815', 'n02510455', 'n02514041', 'n02526121', 'n02536864', 'n02606052', 'n02607072', 'n02640242', 'n02641379', 'n02643566', 'n02655020', 'n02666196', 'n02667093', 'n02669723', 'n02672831', 'n02676566', 'n02687172', 'n02690373', 'n02692877', 'n02699494', 
'n02701002', 'n02704792', 'n02708093', 'n02727426', 'n02730930', 'n02747177', 'n02749479', 'n02769748', 'n02776631', 'n02777292', 'n02782093', 'n02783161', 'n02786058', 'n02787622', 'n02788148', 'n02790996', 'n02791124', 'n02791270', 'n02793495', 'n02794156', 'n02795169', 'n02797295', 'n02799071', 'n02802426', 'n02804414', 'n02804610', 'n02807133', 'n02808304', 'n02808440', 'n02814533', 'n02814860', 'n02815834', 'n02817516', 'n02823428', 'n02823750', 'n02825657', 'n02834397', 'n02835271', 'n02837789', 'n02840245', 'n02841315', 'n02843684', 'n02859443', 'n02860847', 'n02865351', 'n02869837', 'n02870880', 'n02871525', 'n02877765', 'n02879718', 'n02883205', 'n02892201', 'n02892767', 'n02894605', 'n02895154', 'n02906734', 'n02909870', 'n02910353', 'n02916936', 'n02917067', 'n02927161', 'n02930766', 'n02939185', 'n02948072', 'n02950826', 'n02951358', 'n02951585', 'n02963159', 'n02965783', 'n02966193', 'n02966687', 'n02971356', 'n02974003', 'n02977058', 'n02978881', 'n02979186', 'n02980441', 'n02981792', 'n02988304', 'n02992211', 'n02992529', 'n02999410', 'n03000134', 'n03000247', 'n03000684', 'n03014705', 'n03016953', 'n03017168', 'n03018349', 'n03026506', 'n03028079', 'n03032252', 'n03041632', 'n03042490', 'n03045698', 'n03047690', 'n03062245', 'n03063599', 'n03063689', 'n03065424', 'n03075370', 'n03085013', 'n03089624', 'n03095699', 'n03100240', 'n03109150', 'n03110669', 'n03124043', 'n03124170', 'n03125729', 'n03126707', 'n03127747', 'n03127925', 'n03131574', 'n03133878', 'n03134739', 'n03141823', 'n03146219', 'n03160309', 'n03179701', 'n03180011', 'n03187595', 'n03188531', 'n03196217', 'n03197337', 'n03201208', 'n03207743', 'n03207941', 'n03208938', 'n03216828', 'n03218198', 'n03220513', 'n03223299', 'n03240683', 'n03249569', 'n03250847', 'n03255030', 'n03259280', 'n03271574', 'n03272010', 'n03272562', 'n03290653', 'n03291819', 'n03297495', 'n03314780', 'n03325584', 'n03337140', 'n03344393', 'n03345487', 'n03347037', 'n03355925', 'n03372029', 'n03376595', 
'n03379051', 'n03384352', 'n03388043', 'n03388183', 'n03388549', 'n03393912', 'n03394916', 'n03400231', 'n03404251', 'n03417042', 'n03424325', 'n03425413', 'n03443371', 'n03444034', 'n03445777', 'n03445924', 'n03447447', 'n03447721', 'n03450230', 'n03452741', 'n03457902', 'n03459775', 'n03461385', 'n03467068', 'n03476684', 'n03476991', 'n03478589', 'n03481172', 'n03482405', 'n03483316', 'n03485407', 'n03485794', 'n03492542', 'n03494278', 'n03495258', 'n03496892', 'n03498962', 'n03527444', 'n03529860', 'n03530642', 'n03532672', 'n03534580', 'n03535780', 'n03538406', 'n03544143', 'n03584254', 'n03584829', 'n03590841', 'n03594734', 'n03594945', 'n03595614', 'n03598930', 'n03599486', 'n03602883', 'n03617480', 'n03623198', 'n03627232', 'n03630383', 'n03633091', 'n03637318', 'n03642806', 'n03649909', 'n03657121', 'n03658185', 'n03661043', 'n03662601', 'n03666591', 'n03670208', 'n03673027', 'n03676483', 'n03680355', 'n03690938', 'n03691459', 'n03692522', 'n03697007', 'n03706229', 'n03709823', 'n03710193', 'n03710637', 'n03710721', 'n03717622', 'n03720891', 'n03721384', 'n03724870', 'n03729826', 'n03733131', 'n03733281', 'n03733805', 'n03742115', 'n03743016', 'n03759954', 'n03761084', 'n03763968', 'n03764736', 'n03769881', 'n03770439', 'n03770679', 'n03773504', 'n03775071', 'n03775546', 'n03776460', 'n03777568', 'n03777754', 'n03781244', 'n03782006', 'n03785016', 'n03786901', 'n03787032', 'n03788195', 'n03788365', 'n03791053', 'n03792782', 'n03792972', 'n03793489', 'n03794056', 'n03796401', 'n03803284', 'n03804744', 'n03814639', 'n03814906', 'n03825788', 'n03832673', 'n03837869', 'n03838899', 'n03840681', 'n03841143', 'n03843555', 'n03854065', 'n03857828', 'n03866082', 'n03868242', 'n03868863', 'n03871628', 'n03873416', 'n03874293', 'n03874599', 'n03876231', 'n03877472', 'n03877845', 'n03884397', 'n03887697', 'n03888257', 'n03888605', 'n03891251', 'n03891332', 'n03895866', 'n03899768', 'n03902125', 'n03903868', 'n03908618', 'n03908714', 'n03916031', 'n03920288', 
'n03924679', 'n03929660', 'n03929855', 'n03930313', 'n03930630', 'n03933933', 'n03935335', 'n03937543', 'n03938244', 'n03942813', 'n03944341', 'n03947888', 'n03950228', 'n03954731', 'n03956157', 'n03958227', 'n03961711', 'n03967562', 'n03970156', 'n03976467', 'n03976657', 'n03977966', 'n03980874', 'n03982430', 'n03983396', 'n03991062', 'n03992509', 'n03995372', 'n03998194', 'n04004767', 'n04005630', 'n04008634', 'n04009552', 'n04019541', 'n04023962', 'n04026417', 'n04033901', 'n04033995', 'n04037443', 'n04039381', 'n04040759', 'n04041544', 'n04044716', 'n04049303', 'n04065272', 'n04067472', 'n04069434', 'n04070727', 'n04074963', 'n04081281', 'n04086273', 'n04090263', 'n04099969', 'n04111531', 'n04116512', 'n04118538', 'n04118776', 'n04120489', 'n04125021', 'n04127249', 'n04131690', 'n04133789', 'n04136333', 'n04141076', 'n04141327', 'n04141975', 'n04146614', 'n04147183', 'n04149813', 'n04152593', 'n04153751', 'n04154565', 'n04162706', 'n04179913', 'n04192698', 'n04200800', 'n04201297', 'n04204238', 'n04204347', 'n04208210', 'n04209133', 'n04209239', 'n04228054', 'n04229816', 'n04235860', 'n04238763', 'n04239074', 'n04243546', 'n04251144', 'n04252077', 'n04252225', 'n04254120', 'n04254680', 'n04254777', 'n04258138', 'n04259630', 'n04263257', 'n04264628', 'n04265275', 'n04266014', 'n04270147', 'n04273569', 'n04275548', 'n04277352', 'n04285008', 'n04286575', 'n04296562', 'n04310018', 'n04311004', 'n04311174', 'n04317175', 'n04325704', 'n04326547', 'n04328186', 'n04330267', 'n04332243', 'n04335435', 'n04336792', 'n04344873', 'n04346328', 'n04347754', 'n04350905', 'n04355338', 'n04355933', 'n04356056', 'n04357314', 'n04366367', 'n04367480', 'n04370456', 'n04371430', 'n04371774', 'n04372370', 'n04376876', 'n04380533', 'n04389033', 'n04392985', 'n04398044', 'n04399382', 'n04404412', 'n04409515', 'n04417672', 'n04418357', 'n04423845', 'n04428191', 'n04429376', 'n04435653', 'n04442312', 'n04443257', 'n04447861', 'n04456115', 'n04458633', 'n04461696', 'n04462240', 
'n04465501', 'n04467665', 'n04476259', 'n04479046', 'n04482393', 'n04483307', 'n04485082', 'n04486054', 'n04487081', 'n04487394', 'n04493381', 'n04501370', 'n04505470', 'n04507155', 'n04509417', 'n04515003', 'n04517823', 'n04522168', 'n04523525', 'n04525038', 'n04525305', 'n04532106', 'n04532670', 'n04536866', 'n04540053', 'n04542943', 'n04548280', 'n04548362', 'n04550184', 'n04552348', 'n04553703', 'n04554684', 'n04557648', 'n04560804', 'n04562935', 'n04579145', 'n04579432', 'n04584207', 'n04589890', 'n04590129', 'n04591157', 'n04591713', 'n04592741', 'n04596742', 'n04597913', 'n04599235', 'n04604644', 'n04606251', 'n04612504', 'n04613696', 'n06359193', 'n06596364', 'n06785654', 'n06794110', 'n06874185', 'n07248320', 'n07565083', 'n07579787', 'n07583066', 'n07584110', 'n07590611', 'n07613480', 'n07614500', 'n07615774', 'n07684084', 'n07693725', 'n07695742', 'n07697313', 'n07697537', 'n07711569', 'n07714571', 'n07714990', 'n07715103', 'n07716358', 'n07716906', 'n07717410', 'n07717556', 'n07718472', 'n07718747', 'n07720875', 'n07730033', 'n07734744', 'n07742313', 'n07745940', 'n07747607', 'n07749582', 'n07753113', 'n07753275', 'n07753592', 'n07754684', 'n07760859', 'n07768694', 'n07802026', 'n07831146', 'n07836838', 'n07860988', 'n07871810', 'n07873807', 'n07875152', 'n07880968', 'n07892512', 'n07920052', 'n07930864', 'n07932039', 'n09193705', 'n09229709', 'n09246464', 'n09256479', 'n09288635', 'n09332890', 'n09399592', 'n09421951', 'n09428293', 'n09468604', 'n09472597', 'n09835506', 'n10148035', 'n10565667', 'n11879895', 'n11939491', 'n12057211', 'n12144580', 'n12267677', 'n12620546', 'n12768682', 'n12985857', 'n12998815', 'n13037406', 'n13040303', 'n13044778', 'n13052670', 'n13054560', 'n13133613', 'n15075141'] IMAGENET_DEFAULT_TEMPLATES = [ '{}.', 'a bad photo of a {}.', 'a photo of many {}.', 'a sculpture of a {}.', 'a photo of the hard to see {}.', 'a low resolution photo of the {}.', 'a rendering of a {}.', 'graffiti of a {}.', 'a bad photo of the {}.', 'a 
cropped photo of the {}.', 'a tattoo of a {}.', 'the embroidered {}.', 'a photo of a hard to see {}.', 'a bright photo of a {}.', 'a photo of a clean {}.', 'a photo of a dirty {}.', 'a dark photo of the {}.', 'a drawing of a {}.', 'a photo of my {}.', 'the plastic {}.', 'a photo of the cool {}.', 'a close-up photo of a {}.', 'a black and white photo of the {}.', 'a painting of the {}.', 'a painting of a {}.', 'a pixelated photo of the {}.', 'a sculpture of the {}.', 'a bright photo of the {}.', 'a cropped photo of a {}.', 'a plastic {}.', 'a photo of the dirty {}.', 'a jpeg corrupted photo of a {}.', 'a blurry photo of the {}.', 'a photo of the {}.', 'a good photo of the {}.', 'a rendering of the {}.', 'a {} in a video game.', 'a photo of one {}.', 'a doodle of a {}.', 'a close-up photo of the {}.', 'a photo of a {}.', 'the origami {}.', 'the {} in a video game.', 'a sketch of a {}.', 'a doodle of the {}.', 'a origami {}.', 'a low resolution photo of a {}.', 'the toy {}.', 'a rendition of the {}.', 'a photo of the clean {}.', 'a photo of a large {}.', 'a rendition of a {}.', 'a photo of a nice {}.', 'a photo of a weird {}.', 'a blurry photo of a {}.', 'a cartoon {}.', 'art of a {}.', 'a sketch of the {}.', 'a embroidered {}.', 'a pixelated photo of a {}.', 'itap of the {}.', 'a jpeg corrupted photo of the {}.', 'a good photo of a {}.', 'a plushie {}.', 'a photo of the nice {}.', 'a photo of the small {}.', 'a photo of the weird {}.', 'the cartoon {}.', 'art of the {}.', 'a drawing of the {}.', 'a photo of the large {}.', 'a black and white photo of a {}.', 'the plushie {}.', 'a dark photo of a {}.', 'itap of a {}.', 'graffiti of the {}.', 'a toy {}.', 'itap of my {}.', 'a photo of a cool {}.', 'a photo of a small {}.', 'a tattoo of the {}.', ] IMAGENET_SIMPLE_TEMPLATES = [ 'a photo of {}.', ] ================================================ FILE: utils/dist.py ================================================ import functools import io import os import random import 
def init_distributed_mode(args):
    """Fill in the distributed-training fields of ``args`` and start the process group.

    The launch mode is detected from the environment:

    * ``WORLD_SIZE`` set and non-empty (e.g. ``python -m torch.distributed.launch``):
      rank, world size and local rank are read from ``RANK`` / ``WORLD_SIZE`` /
      ``LOCAL_RANK``.
    * ``SLURM_PROCID`` set: the same fields are read from the ``SLURM_*`` variables
      and, unless ``HAND_DEFINE_DIST_URL=1``, ``args.dist_url`` is derived from the
      job's node list (port ``3137 + min(node id) + FIX_DISTRIBUTED_PORT_NUMBER``).
    * otherwise: single-process mode; no process group is created.

    Args:
        args: mutable namespace. On return it carries ``distributed``, ``world_size``,
            ``rank``, ``local_rank`` and ``gpu``; in distributed mode also
            ``dist_backend`` and ``dist_url``.

    Returns:
        None. ``args`` is modified in place.
    """
    if os.environ.get('WORLD_SIZE', '') != '':
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = args.local_rank = int(os.environ['LOCAL_RANK'])

        print('world size: {}, rank: {}, local rank: {}'.format(args.world_size, args.rank, args.local_rank))
        # Log only the launcher variables. Dumping all of os.environ would write
        # any credentials present in the environment into the job log.
        launcher_keys = ('MASTER_ADDR', 'MASTER_PORT', 'RANK', 'LOCAL_RANK', 'WORLD_SIZE')
        launcher_env = {key: os.environ[key] for key in launcher_keys if key in os.environ}
        print(json.dumps(launcher_env, indent=2))
    elif 'SLURM_PROCID' in os.environ:
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.local_rank = int(os.environ['SLURM_LOCALID'])
        args.world_size = int(os.environ['SLURM_NPROCS'])

        if os.environ.get('HAND_DEFINE_DIST_URL', 0) != '1':
            # NOTE(review): project-local helper; this file lives under utils/ but
            # the import path is util. -- confirm the package name.
            import util.hostlist as uh
            nodenames = uh.parse_nodelist(os.environ['SLURM_JOB_NODELIST'])
            # presumably node names look like "<3-char prefix><number>" -- TODO confirm
            gpu_ids = [int(node[3:]) for node in nodenames]
            fixid = int(os.environ.get('FIX_DISTRIBUTED_PORT_NUMBER', 0))
            port = str(3137 + int(min(gpu_ids)) + fixid)
            args.dist_url = "tcp://{ip}:{port}".format(ip=uh.nodename_to_ip(nodenames[0]), port=port)

        print('world size: {}, world rank: {}, local rank: {}, device_count: {}'.format(
            args.world_size, args.rank, args.local_rank, torch.cuda.device_count()))
    else:
        print('Not using distributed mode')
        args.distributed = False
        args.world_size = 1
        args.rank = 0
        # Set gpu as well so every branch leaves the same attributes on args.
        args.gpu = args.local_rank = 0
        return

    print("world_size:{} rank:{} local_rank:{}".format(args.world_size, args.rank, args.local_rank))
    args.distributed = True
    torch.cuda.set_device(args.local_rank)
    args.dist_backend = 'nccl'
    # Launchers that export MASTER_ADDR/MASTER_PORT are served by the env://
    # rendezvous, so use it when the caller supplied no URL.
    if getattr(args, 'dist_url', None) is None:
        args.dist_url = 'env://'
    print('| distributed init (rank {}): {}'.format(args.rank, args.dist_url), flush=True)

    torch.distributed.init_process_group(
        backend=args.dist_backend,
        world_size=args.world_size,
        rank=args.rank,
        init_method=args.dist_url,
    )

    print("Before torch.distributed.barrier()")
    torch.distributed.barrier()
    print("End torch.distributed.barrier()")
'OMPI_COMM_WORLD_SIZE' not in os.environ: # # application was started without MPI # # default to single node with single process # opt['env_info'] = 'no MPI' # opt['world_size'] = 1 # opt['local_size'] = 1 # opt['rank'] = 0 # opt['local_rank'] = 0 # opt['master_address'] = '127.0.0.1' # opt['master_port'] = '8673' # else: # # application was started with MPI # # get MPI parameters # opt['world_size'] = int(os.environ['OMPI_COMM_WORLD_SIZE']) # opt['local_size'] = int(os.environ['OMPI_COMM_WORLD_LOCAL_SIZE']) # opt['rank'] = int(os.environ['OMPI_COMM_WORLD_RANK']) # opt['local_rank'] = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) # # # set up device # if not opt['CUDA']: # assert opt['world_size'] == 1, 'multi-GPU training without CUDA is not supported since we use NCCL as communication backend' # opt['device'] = torch.device("cpu") # else: # torch.cuda.set_device(opt['local_rank']) # opt['device'] = torch.device("cuda", opt['local_rank']) # # apply_distributed(opt) # return opt # # def is_main_process(): # rank = 0 # if 'OMPI_COMM_WORLD_SIZE' in os.environ: # rank = int(os.environ['OMPI_COMM_WORLD_RANK']) # # return rank == 0 # # def get_world_size(): # if not dist.is_available(): # return 1 # if not dist.is_initialized(): # return 1 # return dist.get_world_size() # # def get_rank(): # if not dist.is_available(): # return 0 # if not dist.is_initialized(): # return 0 # return dist.get_rank() # # # def synchronize(): # """ # Helper function to synchronize (barrier) among all processes when # using distributed training # """ # if not dist.is_available(): # return # if not dist.is_initialized(): # return # world_size = dist.get_world_size() # rank = dist.get_rank() # if world_size == 1: # return # # def _send_and_wait(r): # if rank == r: # tensor = torch.tensor(0, device="cuda") # else: # tensor = torch.tensor(1, device="cuda") # dist.broadcast(tensor, r) # while tensor.item() == 1: # time.sleep(1) # # _send_and_wait(0) # # now sync on the main process # 
# HACK for evaluation
def hook_metadata(metadata, name):
    """Patch dataset metadata for evaluation and return the same object.

    Only ``'cityscapes_fine_sem_seg_val'`` is special-cased: its metadata gets
    ``keep_sem_bgd = False``. Any other ``name`` leaves ``metadata`` untouched.
    """
    if name == 'cityscapes_fine_sem_seg_val':
        setattr(metadata, "keep_sem_bgd", False)
    return metadata


def hook_opt(model, name):
    """Set ``model.model.object_mask_threshold`` for the dataset ``name``.

    The listed panoptic validation sets use 0.4; every other dataset uses 0.8.
    """
    low_threshold_sets = (
        'cityscapes_fine_panoptic_val',
        'ade20k_panoptic_val',
        'bdd10k_40_panoptic_val',
        'scannet_21_panoptic_val',
    )
    model.model.object_mask_threshold = 0.4 if name in low_threshold_sets else 0.8


# HACK for evaluation
def hook_switcher(model, name):
    """Switch the semantic / instance / panoptic flags on ``model.model``.

    Args:
        model: wrapper whose ``model`` attribute receives ``semantic_on``,
            ``instance_on`` and ``panoptic_on``.
        name (str): evaluation dataset name.

    Raises:
        AssertionError: if ``name`` is neither a known segmentation dataset nor
            one of the datasets that deliberately keep the current flags.
    """
    semantic_only = (
        'cityscapes_fine_sem_seg_val', 'scannet_21_val_seg', 'scannet_38_val_seg',
        'scannet_41_val_seg', 'sunrgbd_37_val_seg', 'bdd10k_val_sem_seg',
        'ade20k_full_sem_seg_val',
    )
    semantic_and_panoptic = (
        'cityscapes_fine_panoptic_val', 'scannet_21_panoptic_val', 'bdd10k_40_panoptic_val',
    )
    all_tasks = (
        'coco_2017_val_panoptic_with_sem_seg', 'ade20k_panoptic_val', 'coco_2017_test-dev',
    )
    # Datasets for which the flags are intentionally left as they are.
    flags_untouched = (
        "vlp_val", "vlp_captioning_val", "vlp_val2017", "vlp_captioning_val2017",
        "imagenet_val", "refcocog_val_google", "phrasecut_val", "phrasecut_test",
        "refcocop_val_unc", "refcoco_val_unc", "refcocog_val_umd",
    )

    # flags = (semantic_on, instance_on, panoptic_on)
    if name in semantic_only:
        flags = (True, False, False)
    elif name == 'cityscapes_fine_instance_seg_val' or 'seginw' in name:
        flags = (False, True, False)
    elif name in semantic_and_panoptic:
        flags = (True, False, True)
    elif name in all_tasks:
        flags = (True, True, True)
    elif name in flags_untouched:
        return
    else:
        # Raised explicitly: a bare `assert False` disappears under `python -O`,
        # which would silently evaluate with stale flags.
        raise AssertionError("dataset switcher is not defined")

    model.model.semantic_on, model.model.instance_on, model.model.panoptic_on = flags
class AverageMeter(object):
    """Computes and stores the average and current value.

    Attributes:
        val: the most recent value passed to :meth:`update`.
        sum: weighted running sum of the values.
        count: running total of the weights.
        avg: ``sum / count``; stays at its previous value while ``count`` is 0.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1, decay=0):
        """Record ``val`` with weight ``n``.

        Args:
            val: the new measurement.
            n: weight of the measurement (e.g. the batch size).
            decay: if non-zero, keep an exponential moving average instead of a
                plain mean; the previous statistics are scaled by
                ``exp(-n / decay)`` before the new value is blended in.
        """
        self.val = val
        if decay:
            alpha = math.exp(-n / decay)  # weight kept by the old statistics
            self.sum = alpha * self.sum + (1 - alpha) * val * n
            self.count = alpha * self.count + (1 - alpha) * n
        else:
            self.sum += val * n
            self.count += n
        # A zero-weight update on an empty meter must not divide by zero.
        if self.count:
            self.avg = self.sum / self.count
def multiclass_nms(multi_bboxes, multi_scores, score_thr, nms_cfg, max_num=-1, score_factors=None):
    """NMS for multi-class bboxes.

    Args:
        multi_bboxes (Tensor): shape (n, #class*4) or (n, 4)
        multi_scores (Tensor): shape (n, #class), where the 0th column
            contains scores of the background class, but this will be ignored.
        score_thr (float): bbox threshold, bboxes with scores lower than it
            will not be considered.
        nms_cfg (dict): NMS configuration. The optional ``'type'`` entry
            (default ``'nms'``) names the operator looked up on ``nms_wrapper``;
            the remaining entries are passed to it as keyword arguments.
        max_num (int): if there are more than max_num bboxes after NMS,
            only top max_num will be kept. Values <= 0 disable the limit.
        score_factors (Tensor): The factors multiplied to scores before
            applying NMS

    Returns:
        tuple: (bboxes, labels), tensors of shape (k, 5) and (k,). Labels
            are 0-based.
    """
    num_classes = multi_scores.shape[1]
    bboxes, labels = [], []
    nms_cfg_ = nms_cfg.copy()
    nms_type = nms_cfg_.pop('type', 'nms')
    # NOTE(review): `nms_wrapper` is not imported by the visible part of this
    # module (only `torch` is), so this line raises NameError unless the name is
    # provided elsewhere -- confirm the intended import.
    nms_op = getattr(nms_wrapper, nms_type)
    for i in range(1, num_classes):  # column 0 is the background class
        cls_inds = multi_scores[:, i] > score_thr
        if not cls_inds.any():
            continue
        # get bboxes and scores of this class
        if multi_bboxes.shape[1] == 4:
            _bboxes = multi_bboxes[cls_inds, :]
        else:
            _bboxes = multi_bboxes[cls_inds, i * 4:(i + 1) * 4]
        _scores = multi_scores[cls_inds, i]
        if score_factors is not None:
            _scores = _scores * score_factors[cls_inds]
        cls_dets = torch.cat([_bboxes, _scores[:, None]], dim=1)
        cls_dets, _ = nms_op(cls_dets, **nms_cfg_)
        cls_labels = multi_bboxes.new_full((cls_dets.shape[0], ), i - 1, dtype=torch.long)
        bboxes.append(cls_dets)
        labels.append(cls_labels)
    if bboxes:
        bboxes = torch.cat(bboxes)
        labels = torch.cat(labels)
        # Apply the cap only when a positive limit was requested. With the
        # default -1 the old test `shape[0] > max_num` was always true and
        # `inds[:-1]` silently dropped the lowest-scoring detection.
        if max_num > 0 and bboxes.shape[0] > max_num:
            _, inds = bboxes[:, -1].sort(descending=True)
            inds = inds[:max_num]
            bboxes = bboxes[inds]
            labels = labels[inds]
    else:
        bboxes = multi_bboxes.new_zeros((0, 5))
        labels = multi_bboxes.new_zeros((0, ), dtype=torch.long)

    return bboxes, labels
def prompt_engineering(classnames, topk=1, suffix='.'):
    """Render a class name into a randomly chosen prompt template.

    Args:
        classnames (str | list[str]): a class name, or a list of alternative
            names from which one is drawn at random.
        topk (int): the template is drawn uniformly from the first
            ``min(topk, number of templates)`` entries of
            ``get_prompt_templates()``.
        suffix (str): replaces every ``'.'`` of the chosen template.

    Returns:
        str: the formatted prompt; ``','`` is removed from the class name and
        ``'+'`` is turned into a space.
    """
    # Imported here because the module only imports numpy at file level, so the
    # list branch below used to fail with NameError.
    import random

    prompt_templates = get_prompt_templates()
    temp_idx = np.random.randint(min(len(prompt_templates), topk))

    if isinstance(classnames, list):
        classname = random.choice(classnames)
    else:
        classname = classnames

    cleaned_name = classname.replace(',', '').replace('+', ' ')
    return prompt_templates[temp_idx].replace('.', suffix).format(cleaned_name)
def slprint(x, name='x'):
    """Print a short structural summary of ``x`` for debugging.

    Tensors and arrays print their shape; tuples and lists print their type and
    recurse into at most the first ten items; dicts recurse into every value;
    anything else prints its type. ``name`` is the label used in the output.
    """
    if isinstance(x, (torch.Tensor, np.ndarray)):
        print(f'{name}.shape:', x.shape)
        return

    if isinstance(x, (tuple, list)):
        print('type x:', type(x))
        for idx, item in enumerate(x[:10]):
            slprint(item, f'{name}[{idx}]')
        return

    if isinstance(x, dict):
        for key in x:
            slprint(x[key], f'{name}[{key}]')
        return

    print(f'{name}.type:', type(x))
class GenericMask:
    """
    Hold one mask given as an RLE dict, a polygon list or a binary array, and
    convert lazily between the binary-mask and polygon forms.

    Attribute:
        polygons (list[ndarray]): polygons for this mask.
            Each ndarray has format [x, y, x, y, ...]
        mask (ndarray): a binary mask
    """

    def __init__(self, mask_or_polygons, height, width):
        self._mask = None
        self._polygons = None
        self._has_holes = None
        self.height = height
        self.width = width

        source = mask_or_polygons
        if isinstance(source, dict):
            # RLEs
            assert "counts" in source and "size" in source
            if isinstance(source["counts"], list):
                # uncompressed RLEs
                rle_h, rle_w = source["size"]
                assert rle_h == height and rle_w == width
                source = mask_util.frPyObjects(source, rle_h, rle_w)
            self._mask = mask_util.decode(source)[:, :]
        elif isinstance(source, list):
            # list[ndarray]
            self._polygons = [np.asarray(poly).reshape(-1) for poly in source]
        elif isinstance(source, np.ndarray):
            # assumed to be a binary mask
            assert source.shape[1] != 2, source.shape
            assert source.shape == (
                height,
                width,
            ), f"mask shape: {source.shape}, target dims: {height}, {width}"
            self._mask = source.astype("uint8")
        else:
            raise ValueError(
                "GenericMask cannot handle object {} of type '{}'".format(source, type(source))
            )

    @property
    def mask(self):
        """Binary mask, rasterized from the polygons on first access if needed."""
        if self._mask is not None:
            return self._mask
        self._mask = self.polygons_to_mask(self._polygons)
        return self._mask

    @property
    def polygons(self):
        """Polygon list, traced from the binary mask on first access if needed."""
        if self._polygons is not None:
            return self._polygons
        self._polygons, self._has_holes = self.mask_to_polygons(self._mask)
        return self._polygons

    @property
    def has_holes(self):
        """Whether the mask has holes; False when the input was given as polygons."""
        if self._has_holes is not None:
            return self._has_holes
        if self._mask is None:
            # if original format is polygon, does not have holes
            self._has_holes = False
        else:
            self._polygons, self._has_holes = self.mask_to_polygons(self._mask)
        return self._has_holes

    def mask_to_polygons(self, mask):
        """Trace ``mask`` into flat polygons; returns ``(polygons, has_holes)``."""
        # cv2.RETR_CCOMP flag retrieves all the contours and arranges them to a 2-level
        # hierarchy. External contours (boundary) of the object are placed in hierarchy-1.
        # Internal contours (holes) are placed in hierarchy-2.
        # cv2.CHAIN_APPROX_NONE flag gets vertices of polygons from contours.
        contiguous = np.ascontiguousarray(mask)  # some versions of cv2 do not support non-contiguous arrays
        found = cv2.findContours(contiguous.astype("uint8"), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
        hierarchy = found[-1]
        if hierarchy is None:
            # empty mask
            return [], False
        has_holes = (hierarchy.reshape(-1, 4)[:, 3] >= 0).sum() > 0
        polygons = []
        for contour in found[-2]:
            flat = contour.flatten()
            if len(flat) >= 6:
                # These coordinates from OpenCV are integers in range [0, W-1 or H-1].
                # We add 0.5 to turn them into real-value coordinate space. A better solution
                # would be to first +0.5 and then dilate the returned polygon by 0.5.
                polygons.append(flat + 0.5)
        return polygons, has_holes

    def polygons_to_mask(self, polygons):
        """Rasterize ``polygons`` into a binary mask of size (height, width)."""
        merged = mask_util.merge(mask_util.frPyObjects(polygons, self.height, self.width))
        return mask_util.decode(merged)[:, :]

    def area(self):
        """Sum of the binary mask."""
        return self.mask.sum()

    def bbox(self):
        """Bounding box of the polygons, converted from (x, y, w, h) to (x0, y0, x1, y1)."""
        merged = mask_util.merge(mask_util.frPyObjects(self.polygons, self.height, self.width))
        box = mask_util.toBbox(merged)
        box[2] += box[0]
        box[3] += box[1]
        return box
class _PanopticPrediction:
    """
    Unify different panoptic annotation/prediction formats
    """

    def __init__(self, panoptic_seg, segments_info, metadata=None):
        """
        Args:
            panoptic_seg (Tensor): per-pixel segment ids.
            segments_info (list[dict] or None): one dict per segment with at least
                "id" and "isthing". If None, the entries are derived from
                ``panoptic_seg`` and ``metadata``.
            metadata: required only when ``segments_info`` is None; must provide
                ``label_divisor`` and ``thing_dataset_id_to_contiguous_id``.
        """
        if segments_info is None:
            assert metadata is not None
            # If "segments_info" is None, we assume "panoptic_img" is a
            # H*W int32 image storing the panoptic_id in the format of
            # category_id * label_divisor + instance_id. We reserve -1 for
            # VOID label.
            label_divisor = metadata.label_divisor
            segments_info = []
            for panoptic_label in np.unique(panoptic_seg.numpy()):
                if panoptic_label == -1:
                    # VOID region.
                    continue
                pred_class = panoptic_label // label_divisor
                isthing = pred_class in metadata.thing_dataset_id_to_contiguous_id.values()
                segments_info.append(
                    {
                        "id": int(panoptic_label),
                        "category_id": int(pred_class),
                        "isthing": bool(isthing),
                    }
                )
        del metadata

        self._seg = panoptic_seg

        self._sinfo = {s["id"]: s for s in segments_info}  # seg id -> seg info
        segment_ids, areas = torch.unique(panoptic_seg, sorted=True, return_counts=True)
        areas = areas.numpy()
        sorted_idxs = np.argsort(-areas)  # largest segment first
        self._seg_ids, self._seg_areas = segment_ids[sorted_idxs], areas[sorted_idxs]
        self._seg_ids = self._seg_ids.tolist()
        for sid, area in zip(self._seg_ids, self._seg_areas):
            if sid in self._sinfo:
                self._sinfo[sid]["area"] = float(area)

    def non_empty_mask(self):
        """
        Returns:
            (H, W) array, a mask for all pixels that have a prediction
        """
        empty_ids = [sid for sid in self._seg_ids if sid not in self._sinfo]
        if len(empty_ids) == 0:
            # NOTE(review): kept as in the original -- when every id is labelled
            # this returns an all-zero uint8 array, not an all-True mask. Confirm
            # this is what callers expect.
            return np.zeros(self._seg.shape, dtype=np.uint8)
        assert (
            len(empty_ids) == 1
        ), ">1 ids corresponds to no labels. This is currently not supported"
        # The `np.bool` alias was removed in NumPy 1.24; use the builtin `bool`.
        return (self._seg != empty_ids[0]).numpy().astype(bool)

    def semantic_masks(self):
        """Yield ``(mask, segment_info)`` for every non-thing segment, largest first."""
        for sid in self._seg_ids:
            sinfo = self._sinfo.get(sid)
            if sinfo is None or sinfo["isthing"]:
                # Some pixels (e.g. id 0 in PanopticFPN) have no instance or semantic predictions.
                continue
            yield (self._seg == sid).numpy().astype(bool), sinfo

    def instance_masks(self):
        """Yield ``(mask, segment_info)`` for every non-empty thing segment, largest first."""
        for sid in self._seg_ids:
            sinfo = self._sinfo.get(sid)
            if sinfo is None or not sinfo["isthing"]:
                continue
            mask = (self._seg == sid).numpy().astype(bool)
            if mask.sum() > 0:
                yield mask, sinfo
continue yield (self._seg == sid).numpy().astype(np.bool), sinfo def instance_masks(self): for sid in self._seg_ids: sinfo = self._sinfo.get(sid) if sinfo is None or not sinfo["isthing"]: continue mask = (self._seg == sid).numpy().astype(np.bool) if mask.sum() > 0: yield mask, sinfo def _create_text_labels(classes, scores, class_names, is_crowd=None): """ Args: classes (list[int] or None): scores (list[float] or None): class_names (list[str] or None): is_crowd (list[bool] or None): Returns: list[str] or None """ labels = None if classes is not None: if class_names is not None and len(class_names) > 0: labels = [class_names[i] for i in classes] else: labels = [str(i) for i in classes] if scores is not None: if labels is None: labels = ["{:.0f}%".format(s * 100) for s in scores] else: labels = ["{} {:.0f}%".format(l, s * 100) for l, s in zip(labels, scores)] if labels is not None and is_crowd is not None: labels = [l + ("|crowd" if crowd else "") for l, crowd in zip(labels, is_crowd)] return labels class VisImage: def __init__(self, img, scale=1.0): """ Args: img (ndarray): an RGB image of shape (H, W, 3) in range [0, 255]. scale (float): scale the input image """ self.img = img self.scale = scale self.width, self.height = img.shape[1], img.shape[0] self._setup_figure(img) def _setup_figure(self, img): """ Args: Same as in :meth:`__init__()`. Returns: fig (matplotlib.pyplot.figure): top level container for all the image plot elements. ax (matplotlib.pyplot.Axes): contains figure elements and sets the coordinate system. 
""" fig = mplfigure.Figure(frameon=False) self.dpi = fig.get_dpi() # add a small 1e-2 to avoid precision lost due to matplotlib's truncation # (https://github.com/matplotlib/matplotlib/issues/15363) fig.set_size_inches( (self.width * self.scale + 1e-2) / self.dpi, (self.height * self.scale + 1e-2) / self.dpi, ) self.canvas = FigureCanvasAgg(fig) # self.canvas = mpl.backends.backend_cairo.FigureCanvasCairo(fig) ax = fig.add_axes([0.0, 0.0, 1.0, 1.0]) ax.axis("off") self.fig = fig self.ax = ax self.reset_image(img) def reset_image(self, img): """ Args: img: same as in __init__ """ img = img.astype("uint8") self.ax.imshow(img, extent=(0, self.width, self.height, 0), interpolation="nearest") def save(self, filepath): """ Args: filepath (str): a string that contains the absolute path, including the file name, where the visualized image will be saved. """ self.fig.savefig(filepath) def get_image(self): """ Returns: ndarray: the visualized image of shape (H, W, 3) (RGB) in uint8 type. The shape is scaled w.r.t the input image using the given `scale` argument. """ canvas = self.canvas s, (width, height) = canvas.print_to_buffer() # buf = io.BytesIO() # works for cairo backend # canvas.print_rgba(buf) # width, height = self.width, self.height # s = buf.getvalue() buffer = np.frombuffer(s, dtype="uint8") img_rgba = buffer.reshape(height, width, 4) rgb, alpha = np.split(img_rgba, [3], axis=2) return rgb.astype("uint8") class Visualizer: """ Visualizer that draws data about detection/segmentation on images. It contains methods like `draw_{text,box,circle,line,binary_mask,polygon}` that draw primitive objects to images, as well as high-level wrappers like `draw_{instance_predictions,sem_seg,panoptic_seg_predictions,dataset_dict}` that draw composite data in some pre-defined style. Note that the exact visualization style for the high-level wrappers are subject to change. 
def __init__(self, img_rgb, metadata=None, scale=1.0, instance_mode=ColorMode.IMAGE):
    """
    Args:
        img_rgb: a numpy array of shape (H, W, C), where H and W correspond to
            the height and width of the image respectively. C is the number of
            color channels. The image is required to be in RGB format since that
            is a requirement of the Matplotlib library. The image is also expected
            to be in the range [0, 255].
        metadata (Metadata): dataset metadata (e.g. class names and colors)
        scale (float): scale factor passed to :class:`VisImage` for the output
            canvas.
        instance_mode (ColorMode): defines one of the pre-defined style for drawing
            instances on an image.
    """
    self.img = np.asarray(img_rgb).clip(0, 255).astype(np.uint8)
    if metadata is None:
        metadata = MetadataCatalog.get("__nonexist__")
    self.metadata = metadata
    self.output = VisImage(self.img, scale=scale)
    self.cpu_device = torch.device("cpu")

    # The base font size is fixed. The image-size based heuristic that used to
    # be computed here was immediately overwritten by this constant, so the
    # dead computation has been dropped.
    self._default_font_size = 18
    self._instance_mode = instance_mode
    self.keypoint_threshold = _KEYPOINT_THRESHOLD

def draw_instance_predictions(self, predictions, score_threshold=0.5):
    """
    Draw instance-level prediction results on an image.

    Instances whose score does not exceed ``score_threshold`` are dropped
    before drawing. When the predictions carry no "scores" field, nothing is
    filtered.

    Args:
        predictions (Instances): the output of an instance detection/segmentation
            model. Following fields will be used to draw:
            "pred_boxes", "pred_classes", "scores", "pred_masks" (or "pred_masks_rle"),
            "pred_keypoints".
        score_threshold (float): only instances with a score strictly greater
            than this value are drawn.

    Returns:
        output (VisImage): image object with visualizations.
    """
    boxes = predictions.pred_boxes if predictions.has("pred_boxes") else None
    scores = predictions.scores if predictions.has("scores") else None
    classes = predictions.pred_classes.tolist() if predictions.has("pred_classes") else None
    # Labels are built from the unfiltered fields and filtered below together
    # with everything else, so all per-instance sequences stay aligned.
    labels = _create_text_labels(classes, scores, self.metadata.get("thing_classes", None))
    keypoints = predictions.pred_keypoints if predictions.has("pred_keypoints") else None
    mask_array = np.asarray(predictions.pred_masks) if predictions.has("pred_masks") else None

    if scores is not None:
        keep = (scores > score_threshold).cpu()
        keep_np = np.array(keep)
        # Every optional field is guarded: a missing field must stay None
        # rather than crash on indexing.
        if boxes is not None:
            boxes = boxes[keep]
        if classes is not None:
            classes = [c for c, k in zip(classes, keep_np) if k]
        if labels is not None:
            labels = [text for text, k in zip(labels, keep_np) if k]
        if keypoints is not None:
            # Keypoints must be filtered too, otherwise overlay_instances
            # fails its per-field length check.
            keypoints = keypoints[keep]
        if mask_array is not None:
            mask_array = mask_array[keep_np]

    if mask_array is not None:
        masks = [GenericMask(x, self.output.height, self.output.width) for x in mask_array]
    else:
        masks = None

    if (
        self._instance_mode == ColorMode.SEGMENTATION
        and self.metadata.get("thing_colors")
        and classes is not None
    ):
        colors = [
            self._jitter([x / 255 for x in self.metadata.thing_colors[c]]) for c in classes
        ]
    else:
        colors = None
    alpha = 0.4

    if self._instance_mode == ColorMode.IMAGE_BW:
        # Keep the original colors only where a drawn (kept) instance lies.
        colored_area = (mask_array.any(axis=0) > 0) if mask_array is not None else None
        self.output.reset_image(self._create_grayscale_image(colored_area))
        alpha = 0.3

    self.overlay_instances(
        masks=masks,
        boxes=boxes,
        labels=labels,
        keypoints=keypoints,
        assigned_colors=colors,
        alpha=alpha,
    )
    return self.output
def draw_dataset_dict(self, dic):
    """
    Draw the annotations/segmentations of one image given in Detectron2
    Dataset format.

    Instance annotations are drawn first, then the semantic segmentation
    (from "sem_seg" or loaded from "sem_seg_file_name"), then the panoptic
    segmentation (from "pan_seg" or loaded from "pan_seg_file_name").

    Args:
        dic (dict): annotation/segmentation data of one image, in Detectron2 Dataset format.

    Returns:
        output (VisImage): image object with visualizations.
    """
    annotations = dic.get("annotations", None)
    if annotations:
        first = annotations[0]
        # The presence of a field on the first annotation decides whether it
        # is collected for all annotations.
        masks = (
            [anno["segmentation"] for anno in annotations] if "segmentation" in first else None
        )
        keypts = None
        if "keypoints" in first:
            keypts = np.array([anno["keypoints"] for anno in annotations])
            keypts = keypts.reshape(len(annotations), -1, 3)

        boxes = []
        for anno in annotations:
            bbox = anno["bbox"]
            # Only 4-element boxes are converted; anything else is passed through.
            if len(bbox) == 4:
                bbox = BoxMode.convert(bbox, anno["bbox_mode"], BoxMode.XYXY_ABS)
            boxes.append(bbox)

        category_ids = [anno["category_id"] for anno in annotations]
        colors = None
        use_class_colors = self._instance_mode == ColorMode.SEGMENTATION and self.metadata.get(
            "thing_colors"
        )
        if use_class_colors:
            colors = []
            for cat in category_ids:
                base = [channel / 255 for channel in self.metadata.thing_colors[cat]]
                colors.append(self._jitter(base))

        crowd_flags = [anno.get("iscrowd", 0) for anno in annotations]
        labels = _create_text_labels(
            category_ids,
            scores=None,
            class_names=self.metadata.get("thing_classes", None),
            is_crowd=crowd_flags,
        )
        self.overlay_instances(
            labels=labels, boxes=boxes, masks=masks, keypoints=keypts, assigned_colors=colors
        )

    semantic = dic.get("sem_seg", None)
    if semantic is None and "sem_seg_file_name" in dic:
        with PathManager.open(dic["sem_seg_file_name"], "rb") as handle:
            semantic = np.asarray(Image.open(handle), dtype="uint8")
    if semantic is not None:
        self.draw_sem_seg(semantic, area_threshold=0, alpha=0.4)

    panoptic = dic.get("pan_seg", None)
    if panoptic is None and "pan_seg_file_name" in dic:
        with PathManager.open(dic["pan_seg_file_name"], "rb") as handle:
            panoptic = np.asarray(Image.open(handle))
            # Imported lazily: panopticapi is only needed for this code path.
            from panopticapi.utils import rgb2id

            panoptic = rgb2id(panoptic)
    if panoptic is not None:
        self.draw_panoptic_seg(
            torch.tensor(panoptic), dic["segments_info"], area_threshold=0, alpha=0.7
        ) if False else None
        segments_info = dic["segments_info"]
        panoptic = torch.tensor(panoptic)
        self.draw_panoptic_seg(panoptic, segments_info, area_threshold=0, alpha=0.7)
    return self.output
Nx5 numpy array of (x_center, y_center, width, height, angle_degrees) format for the N objects in a single image, labels (list[str]): the text to be displayed for each instance. masks (masks-like object): Supported types are: * :class:`detectron2.structures.PolygonMasks`, :class:`detectron2.structures.BitMasks`. * list[list[ndarray]]: contains the segmentation masks for all objects in one image. The first level of the list corresponds to individual instances. The second level to all the polygon that compose the instance, and the third level to the polygon coordinates. The third level should have the format of [x0, y0, x1, y1, ..., xn, yn] (n >= 3). * list[ndarray]: each ndarray is a binary mask of shape (H, W). * list[dict]: each dict is a COCO-style RLE. keypoints (Keypoint or array like): an array-like object of shape (N, K, 3), where the N is the number of instances and K is the number of keypoints. The last dimension corresponds to (x, y, visibility or score). assigned_colors (list[matplotlib.colors]): a list of colors, where each color corresponds to each mask or box in the image. Refer to 'matplotlib.colors' for full list of formats that the colors are accepted in. Returns: output (VisImage): image object with visualizations. 
""" num_instances = 0 if boxes is not None: boxes = self._convert_boxes(boxes) num_instances = len(boxes) if masks is not None: masks = self._convert_masks(masks) if num_instances: assert len(masks) == num_instances else: num_instances = len(masks) if keypoints is not None: if num_instances: assert len(keypoints) == num_instances else: num_instances = len(keypoints) keypoints = self._convert_keypoints(keypoints) if labels is not None: assert len(labels) == num_instances if assigned_colors is None: assigned_colors = [random_color(rgb=True, maximum=1) for _ in range(num_instances)] if num_instances == 0: return self.output if boxes is not None and boxes.shape[1] == 5: return self.overlay_rotated_instances( boxes=boxes, labels=labels, assigned_colors=assigned_colors ) # Display in largest to smallest order to reduce occlusion. areas = None if boxes is not None: areas = np.prod(boxes[:, 2:] - boxes[:, :2], axis=1) elif masks is not None: areas = np.asarray([x.area() for x in masks]) if areas is not None: sorted_idxs = np.argsort(-areas).tolist() # Re-order overlapped instances in descending order. boxes = boxes[sorted_idxs] if boxes is not None else None labels = [labels[k] for k in sorted_idxs] if labels is not None else None masks = [masks[idx] for idx in sorted_idxs] if masks is not None else None assigned_colors = [assigned_colors[idx] for idx in sorted_idxs] keypoints = keypoints[sorted_idxs] if keypoints is not None else None for i in range(num_instances): color = assigned_colors[i] if boxes is not None: self.draw_box(boxes[i], edge_color=color) if masks is not None: for segment in masks[i].polygons: self.draw_polygon(segment.reshape(-1, 2), color, alpha=alpha) if labels is not None: # first get a box if boxes is not None: x0, y0, x1, y1 = boxes[i] text_pos = (x0, y0) # if drawing boxes, put text on the box corner. 
def overlay_rotated_instances(self, boxes=None, labels=None, assigned_colors=None):
    """
    Draw rotated boxes (with optional labels), largest box first so that
    smaller boxes are not hidden by bigger ones.

    Args:
        boxes (ndarray or array-like or None): an Nx5 array of
            (x_center, y_center, width, height, angle_degrees) format
            for the N objects in a single image. None or an empty input
            draws nothing.
        labels (list[str]): the text to be displayed for each instance.
        assigned_colors (list[matplotlib.colors]): a list of colors, where each color
            corresponds to each mask or box in the image. Refer to 'matplotlib.colors'
            for full list of formats that the colors are accepted in. If None,
            a random color is picked per instance.

    Returns:
        output (VisImage): image object with visualizations.
    """
    # ``boxes`` defaults to None, so it must be handled before taking its length.
    if boxes is None:
        return self.output
    # Accept plain nested lists as well as arrays; column slicing needs an ndarray.
    boxes = np.asarray(boxes)
    num_instances = len(boxes)
    if num_instances == 0:
        return self.output

    if assigned_colors is None:
        assigned_colors = [random_color(rgb=True, maximum=1) for _ in range(num_instances)]

    # Display in largest to smallest order to reduce occlusion.
    areas = boxes[:, 2] * boxes[:, 3]
    sorted_idxs = np.argsort(-areas).tolist()
    # Re-order overlapped instances in descending order.
    boxes = boxes[sorted_idxs]
    labels = [labels[k] for k in sorted_idxs] if labels is not None else None
    colors = [assigned_colors[idx] for idx in sorted_idxs]

    for i in range(num_instances):
        self.draw_rotated_box_with_label(
            boxes[i], edge_color=colors[i], label=labels[i] if labels is not None else None
        )

    return self.output
def draw_text(
    self,
    text,
    position,
    *,
    font_size=None,
    color="g",
    horizontal_alignment="center",
    rotation=0,
):
    """
    Draw a text label on a dark, semi-transparent background box.

    Args:
        text (str): class label
        position (tuple): a tuple of the x and y coordinates to place text on image.
        font_size (int, optional): font of the text. If not provided (or falsy),
            the visualizer's default font size is used.
        color: color of the text. Refer to `matplotlib.colors` for full list
            of formats that are accepted.
        horizontal_alignment (str): see `matplotlib.text.Text`
        rotation: rotation angle in degrees CCW

    Returns:
        output (VisImage): image object with text drawn.
    """
    size = font_size if font_size else self._default_font_size

    # The label sits on a dark box: lift every channel to at least 0.2 and the
    # dominant channel to at least 0.8 so the text never ends up dark.
    rgb = np.maximum(list(mplc.to_rgb(color)), 0.2)
    dominant = np.argmax(rgb)
    rgb[dominant] = max(0.8, np.max(rgb))

    x, y = position
    background = {"facecolor": "black", "alpha": 0.8, "pad": 0.7, "edgecolor": "none"}
    self.output.ax.text(
        x,
        y,
        text,
        size=size * self.output.scale,
        family="sans-serif",
        bbox=background,
        verticalalignment="top",
        horizontalalignment=horizontal_alignment,
        color=rgb,
        zorder=10,
        rotation=rotation,
    )
    return self.output
def draw_circle(self, circle_coord, color, radius=3):
    """
    Draw a filled circle.

    Args:
        circle_coord (list(int) or tuple(int)): contains the x and y coordinates
            of the center of the circle.
        color: color of the circle. Refer to `matplotlib.colors` for a full list of
            formats that are accepted.
        radius (int): radius of the circle.

    Returns:
        output (VisImage): image object with circle drawn.
    """
    # The coordinates are handed to matplotlib as-is; the former unpacking into
    # unused ``x, y`` locals has been removed.
    self.output.ax.add_patch(
        mpl.patches.Circle(circle_coord, radius=radius, fill=True, color=color)
    )
    return self.output
""" if color is None: color = random_color(rgb=True, maximum=1) color = mplc.to_rgb(color) has_valid_segment = False binary_mask = binary_mask.astype("uint8") # opencv needs uint8 mask = GenericMask(binary_mask, self.output.height, self.output.width) shape2d = (binary_mask.shape[0], binary_mask.shape[1]) if not mask.has_holes: # draw polygons for regular masks for segment in mask.polygons: area = mask_util.area(mask_util.frPyObjects([segment], shape2d[0], shape2d[1])) if area < (area_threshold or 0): continue has_valid_segment = True segment = segment.reshape(-1, 2) self.draw_polygon(segment, color=color, edge_color=edge_color, alpha=alpha) else: # TODO: Use Path/PathPatch to draw vector graphics: # https://stackoverflow.com/questions/8919719/how-to-plot-a-complex-polygon rgba = np.zeros(shape2d + (4,), dtype="float32") rgba[:, :, :3] = color rgba[:, :, 3] = (mask.mask == 1).astype("float32") * alpha has_valid_segment = True self.output.ax.imshow(rgba, extent=(0, self.output.width, self.output.height, 0)) if text is not None and has_valid_segment: lighter_color = self._change_color_brightness(color, brightness_factor=0.7) self._draw_text_in_mask(binary_mask, text, lighter_color) return self.output def draw_soft_mask(self, soft_mask, color=None, *, text=None, alpha=0.5): """ Args: soft_mask (ndarray): float array of shape (H, W), each value in [0, 1]. color: color of the mask. Refer to `matplotlib.colors` for a full list of formats that are accepted. If None, will pick a random color. text (str): if None, will be drawn on the object alpha (float): blending efficient. Smaller values lead to more transparent masks. Returns: output (VisImage): image object with mask drawn. 
""" if color is None: color = random_color(rgb=True, maximum=1) color = mplc.to_rgb(color) shape2d = (soft_mask.shape[0], soft_mask.shape[1]) rgba = np.zeros(shape2d + (4,), dtype="float32") rgba[:, :, :3] = color rgba[:, :, 3] = soft_mask * alpha self.output.ax.imshow(rgba, extent=(0, self.output.width, self.output.height, 0)) if text is not None: lighter_color = self._change_color_brightness(color, brightness_factor=0.7) binary_mask = (soft_mask > 0.5).astype("uint8") self._draw_text_in_mask(binary_mask, text, lighter_color) return self.output def draw_polygon(self, segment, color, edge_color=None, alpha=0.5): """ Args: segment: numpy array of shape Nx2, containing all the points in the polygon. color: color of the polygon. Refer to `matplotlib.colors` for a full list of formats that are accepted. edge_color: color of the polygon edges. Refer to `matplotlib.colors` for a full list of formats that are accepted. If not provided, a darker shade of the polygon color will be used instead. alpha (float): blending efficient. Smaller values lead to more transparent masks. Returns: output (VisImage): image object with polygon drawn. """ if edge_color is None: # make edge color darker than the polygon color if alpha > 0.8: edge_color = self._change_color_brightness(color, brightness_factor=-0.7) else: edge_color = color edge_color = mplc.to_rgb(edge_color) + (1,) polygon = mpl.patches.Polygon( segment, fill=True, facecolor=mplc.to_rgb(color) + (alpha,), edgecolor=edge_color, linewidth=max(self._default_font_size // 15 * self.output.scale, 1), ) self.output.ax.add_patch(polygon) return self.output """ Internal methods: """ def _jitter(self, color): """ Randomly modifies given color to produce a slightly different color than the color given. Args: color (tuple[double]): a tuple of 3 elements, containing the RGB values of the color picked. The values in the list are in the [0.0, 1.0] range. 
Returns: jittered_color (tuple[double]): a tuple of 3 elements, containing the RGB values of the color after being jittered. The values in the list are in the [0.0, 1.0] range. """ color = mplc.to_rgb(color) # np.random.seed(0) vec = np.random.rand(3) # better to do it in another color space vec = vec / np.linalg.norm(vec) * 0.5 res = np.clip(vec + color, 0, 1) return tuple(res) def _create_grayscale_image(self, mask=None): """ Create a grayscale version of the original image. The colors in masked area, if given, will be kept. """ img_bw = self.img.astype("f4").mean(axis=2) img_bw = np.stack([img_bw] * 3, axis=2) if mask is not None: img_bw[mask] = self.img[mask] return img_bw def _change_color_brightness(self, color, brightness_factor): """ Depending on the brightness_factor, gives a lighter or darker color i.e. a color with less or more saturation than the original color. Args: color: color of the polygon. Refer to `matplotlib.colors` for a full list of formats that are accepted. brightness_factor (float): a value in [-1.0, 1.0] range. A lightness factor of 0 will correspond to no change, a factor in [-1.0, 0) range will result in a darker color and a factor in (0, 1.0] range will result in a lighter color. Returns: modified_color (tuple[double]): a tuple containing the RGB values of the modified color. Each value in the tuple is in the [0.0, 1.0] range. 
""" assert brightness_factor >= -1.0 and brightness_factor <= 1.0 color = mplc.to_rgb(color) polygon_color = colorsys.rgb_to_hls(*mplc.to_rgb(color)) modified_lightness = polygon_color[1] + (brightness_factor * polygon_color[1]) modified_lightness = 0.0 if modified_lightness < 0.0 else modified_lightness modified_lightness = 1.0 if modified_lightness > 1.0 else modified_lightness modified_color = colorsys.hls_to_rgb(polygon_color[0], modified_lightness, polygon_color[2]) return modified_color def _convert_boxes(self, boxes): """ Convert different format of boxes to an NxB array, where B = 4 or 5 is the box dimension. """ if isinstance(boxes, Boxes) or isinstance(boxes, RotatedBoxes): return boxes.tensor.detach().numpy() else: return np.asarray(boxes) def _convert_masks(self, masks_or_polygons): """ Convert different format of masks or polygons to a tuple of masks and polygons. Returns: list[GenericMask]: """ m = masks_or_polygons if isinstance(m, PolygonMasks): m = m.polygons if isinstance(m, BitMasks): m = m.tensor.numpy() if isinstance(m, torch.Tensor): m = m.numpy() ret = [] for x in m: if isinstance(x, GenericMask): ret.append(x) else: ret.append(GenericMask(x, self.output.height, self.output.width)) return ret def _draw_text_in_mask(self, binary_mask, text, color): """ Find proper places to draw text given a binary mask. """ # TODO sometimes drawn on wrong objects. the heuristics here can improve. _num_cc, cc_labels, stats, centroids = cv2.connectedComponentsWithStats(binary_mask, 8) if stats[1:, -1].size == 0: return largest_component_id = np.argmax(stats[1:, -1]) + 1 # draw text on the largest component, as well as other very large components. 
def _convert_keypoints(self, keypoints):
    """
    Convert keypoints to a numpy array.

    A :class:`Keypoints` object is unwrapped to its underlying tensor first;
    any other input is passed to ``np.asarray`` directly.
    """
    raw = keypoints.tensor if isinstance(keypoints, Keypoints) else keypoints
    return np.asarray(raw)